nautilus_network/websocket/
proxy.rs1use std::fmt::Write as _;
43
44use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
45use rustls::{ClientConfig, RootCertStore, pki_types::ServerName};
46use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
47use tokio_rustls::{TlsConnector, client::TlsStream};
48use url::Url;
49
50use crate::{net::TcpStream, transport::TransportError};
51
/// Upper bound (16 KiB) on the proxy's `CONNECT` response head read by
/// `read_connect_response`; a head that grows past this without its terminating
/// blank line is treated as a handshake failure.
const MAX_PROXY_RESPONSE_BYTES: usize = 16_384;
57
/// Connected byte stream to a WebSocket target, as returned by [`tunnel_via_proxy`].
///
/// The variant records which hops are TLS-wrapped: the hop to the proxy itself
/// (an `https://` proxy URL) and/or the end-to-end hop to the target (a `wss://` URL).
/// When produced by [`tunnel_via_proxy`], the HTTP `CONNECT` handshake has already
/// succeeded. TLS streams are boxed (presumably to keep the enum close to the size of a
/// bare `TcpStream`).
#[derive(Debug)]
pub enum ProxiedStream {
    /// Plaintext proxy hop and plaintext target (`http://` proxy, `ws://` target).
    Plain(TcpStream),
    /// TLS to the proxy; the WebSocket bytes inside that session are plaintext
    /// (`https://` proxy, `ws://` target).
    PlainOverTlsProxy(Box<TlsStream<TcpStream>>),
    /// Plaintext proxy hop; TLS to the target runs through the tunnel
    /// (`http://` proxy, `wss://` target).
    Tls(Box<TlsStream<TcpStream>>),
    /// TLS to the proxy with a second, nested TLS session to the target inside it
    /// (`https://` proxy, `wss://` target).
    TlsOverTlsProxy(Box<TlsStream<TlsStream<TcpStream>>>),
}
76
/// WebSocket endpoint parsed from a `ws://` or `wss://` URL; the far end of a proxy tunnel.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WsTarget {
    /// Hostname or IP literal. IPv6 literals are stored without brackets by
    /// [`WsTarget::parse`].
    pub host: String,
    /// TCP port. [`WsTarget::parse`] defaults it to 80 for `ws` and 443 for `wss`.
    pub port: u16,
    /// `true` for `wss://`, meaning a TLS session to the target is established inside the
    /// tunnel.
    pub is_tls: bool,
}
87
88impl WsTarget {
89 pub fn parse(url: &str) -> Result<Self, TransportError> {
96 let parsed =
97 Url::parse(url).map_err(|e| TransportError::InvalidUrl(format!("{url}: {e}")))?;
98
99 let is_tls = match parsed.scheme() {
100 "ws" => false,
101 "wss" => true,
102 other => {
103 return Err(TransportError::InvalidUrl(format!(
104 "expected ws:// or wss:// scheme, was {other}"
105 )));
106 }
107 };
108
109 let raw_host = parsed
110 .host_str()
111 .ok_or_else(|| TransportError::InvalidUrl("missing hostname".to_string()))?;
112
113 let host = if raw_host.starts_with('[') && raw_host.ends_with(']') {
116 raw_host[1..raw_host.len() - 1].to_string()
117 } else {
118 raw_host.to_string()
119 };
120
121 let port = parsed.port().unwrap_or(if is_tls { 443 } else { 80 });
122
123 Ok(Self { host, port, is_tls })
124 }
125}
126
/// Classification of a configured proxy URL, produced by [`ProxyKind::parse`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProxyKind {
    /// `http://` or `https://` proxy that can carry an HTTP `CONNECT` tunnel.
    Http(ProxyTarget),
    /// Recognised SOCKS proxy scheme that WebSocket connections cannot tunnel through yet.
    Unsupported {
        /// The URL scheme as parsed, e.g. `"socks5"`.
        scheme: String,
    },
}
143
144impl ProxyKind {
145 pub fn parse(url: &str) -> Result<Self, TransportError> {
153 let parsed =
154 Url::parse(url).map_err(|e| TransportError::InvalidUrl(format!("{url}: {e}")))?;
155
156 match parsed.scheme() {
157 "http" | "https" => ProxyTarget::parse(url).map(ProxyKind::Http),
158 scheme @ ("socks5" | "socks5h" | "socks4" | "socks4a") => {
159 if parsed.host_str().is_none_or(str::is_empty) {
164 return Err(TransportError::InvalidUrl(format!(
165 "proxy URL '{url}' is missing a host (did you mean {scheme}://...)?"
166 )));
167 }
168 Ok(Self::Unsupported {
169 scheme: scheme.to_string(),
170 })
171 }
172 other => Err(TransportError::InvalidUrl(format!(
173 "unsupported proxy scheme '{other}'; expected http:// or https://"
174 ))),
175 }
176 }
177}
178
/// HTTP(S) proxy endpoint used by [`tunnel_via_proxy`] to open a `CONNECT` tunnel.
#[derive(Clone, PartialEq, Eq)]
pub struct ProxyTarget {
    /// Proxy hostname or IP literal (IPv6 brackets removed when built by
    /// [`ProxyTarget::parse`]).
    pub host: String,
    /// Proxy port; [`ProxyTarget::parse`] defaults it to 80 for `http://` and 443 for
    /// `https://`.
    pub port: u16,
    /// Whether the hop to the proxy itself is wrapped in TLS (`https://` proxy URL).
    pub is_tls: bool,
    /// Pre-computed `Proxy-Authorization` header value (`Basic <base64>`), present when
    /// the proxy URL carried credentials.
    pub auth_header: Option<String>,
}

// Hand-written instead of derived so that `{:?}` (logs, error context, panics) never
// prints the Basic credentials: base64 is an encoding, not a secret-preserving transform.
impl std::fmt::Debug for ProxyTarget {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ProxyTarget")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("is_tls", &self.is_tls)
            .field(
                "auth_header",
                &self.auth_header.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}
193
194impl ProxyTarget {
195 pub fn parse(url: &str) -> Result<Self, TransportError> {
206 let parsed =
207 Url::parse(url).map_err(|e| TransportError::InvalidUrl(format!("{url}: {e}")))?;
208
209 let is_tls = match parsed.scheme() {
210 "http" => false,
211 "https" => true,
212 "socks5" | "socks5h" | "socks4" | "socks4a" => {
213 return Err(TransportError::InvalidUrl(format!(
214 "SOCKS proxy scheme '{}' is not yet supported for WebSocket connections; \
215 use an http:// or https:// proxy",
216 parsed.scheme()
217 )));
218 }
219 other => {
220 return Err(TransportError::InvalidUrl(format!(
221 "unsupported proxy scheme '{other}'; expected http:// or https://"
222 )));
223 }
224 };
225
226 let raw_host = parsed
227 .host_str()
228 .ok_or_else(|| TransportError::InvalidUrl("proxy URL missing hostname".to_string()))?;
229
230 let host = if raw_host.starts_with('[') && raw_host.ends_with(']') {
234 raw_host[1..raw_host.len() - 1].to_string()
235 } else {
236 raw_host.to_string()
237 };
238
239 let port = parsed.port().unwrap_or(if is_tls { 443 } else { 80 });
240
241 let auth_header = if parsed.username().is_empty() {
242 None
243 } else {
244 let username = decode_userinfo(parsed.username());
245 let password = decode_userinfo(parsed.password().unwrap_or(""));
246 let credentials = format!("{username}:{password}");
247 Some(format!("Basic {}", BASE64.encode(credentials)))
248 };
249
250 Ok(Self {
251 host,
252 port,
253 is_tls,
254 auth_header,
255 })
256 }
257}
258
259fn decode_userinfo(value: &str) -> String {
263 let bytes = nautilus_core::string::urlencoding::decode_bytes(value.as_bytes());
264 String::from_utf8_lossy(&bytes).into_owned()
265}
266
267pub async fn tunnel_via_proxy(
283 target: &WsTarget,
284 proxy: &ProxyTarget,
285) -> Result<ProxiedStream, TransportError> {
286 let tcp = TcpStream::connect((proxy.host.as_str(), proxy.port))
287 .await
288 .map_err(TransportError::Io)?;
289
290 if let Err(e) = tcp.set_nodelay(true) {
291 log::warn!("Failed to enable TCP_NODELAY on proxy connection: {e:?}");
292 }
293
294 if proxy.is_tls {
295 let proxy_tls = wrap_tls(tcp, &proxy.host).await?;
296 let tunneled = send_connect(proxy_tls, target, proxy).await?;
297 if target.is_tls {
298 let upstream = wrap_tls(tunneled, &target.host).await?;
299 Ok(ProxiedStream::TlsOverTlsProxy(Box::new(upstream)))
300 } else {
301 Ok(ProxiedStream::PlainOverTlsProxy(Box::new(tunneled)))
302 }
303 } else {
304 let tunneled = send_connect(tcp, target, proxy).await?;
305 if target.is_tls {
306 let upstream = wrap_tls(tunneled, &target.host).await?;
307 Ok(ProxiedStream::Tls(Box::new(upstream)))
308 } else {
309 Ok(ProxiedStream::Plain(tunneled))
310 }
311 }
312}
313
314async fn send_connect<S>(
318 mut stream: S,
319 target: &WsTarget,
320 proxy: &ProxyTarget,
321) -> Result<S, TransportError>
322where
323 S: AsyncRead + AsyncWrite + Unpin,
324{
325 let host_header = format_host_header(&target.host, target.port);
326 let mut request = format!(
327 "CONNECT {host_header} HTTP/1.1\r\n\
328 Host: {host_header}\r\n\
329 Proxy-Connection: Keep-Alive\r\n"
330 );
331
332 if let Some(auth) = &proxy.auth_header {
333 write!(request, "Proxy-Authorization: {auth}\r\n").expect("writing to String never fails");
334 }
335 request.push_str("\r\n");
336
337 stream
338 .write_all(request.as_bytes())
339 .await
340 .map_err(TransportError::Io)?;
341 stream.flush().await.map_err(TransportError::Io)?;
342
343 read_connect_response(&mut stream).await?;
344 Ok(stream)
345}
346
/// Formats `host:port` as the authority used in the `CONNECT` request line and `Host`
/// header, wrapping colon-bearing (IPv6) hosts in brackets unless they already are.
fn format_host_header(host: &str, port: u16) -> String {
    let bracketed = host.starts_with('[') && host.ends_with(']');
    let has_colon = host.contains(':');

    match (has_colon, bracketed) {
        // An unbracketed colon would be ambiguous with the `:port` suffix.
        (true, false) => format!("[{host}]:{port}"),
        _ => format!("{host}:{port}"),
    }
}
354
355async fn read_connect_response<S>(stream: &mut S) -> Result<(), TransportError>
358where
359 S: AsyncRead + Unpin,
360{
361 let mut buf = Vec::with_capacity(512);
362 let mut byte = [0u8; 1];
363
364 loop {
365 let n = stream.read(&mut byte).await.map_err(TransportError::Io)?;
366 if n == 0 {
367 return Err(TransportError::Handshake(
368 "proxy closed connection before sending CONNECT response".to_string(),
369 ));
370 }
371
372 buf.push(byte[0]);
373
374 if buf.ends_with(b"\r\n\r\n") {
375 break;
376 }
377
378 if buf.len() > MAX_PROXY_RESPONSE_BYTES {
379 return Err(TransportError::Handshake(format!(
380 "proxy CONNECT response exceeded {MAX_PROXY_RESPONSE_BYTES} bytes without terminator"
381 )));
382 }
383 }
384
385 let text = std::str::from_utf8(&buf).map_err(|_| {
386 TransportError::Handshake("proxy CONNECT response was not valid UTF-8".to_string())
387 })?;
388
389 let status_line = text.lines().next().ok_or_else(|| {
390 TransportError::Handshake("proxy CONNECT response missing status line".to_string())
391 })?;
392
393 let mut parts = status_line.splitn(3, ' ');
395 let _version = parts.next().ok_or_else(|| {
396 TransportError::Handshake(format!("malformed status line: {status_line}"))
397 })?;
398 let status_code = parts
399 .next()
400 .ok_or_else(|| TransportError::Handshake(format!("malformed status line: {status_line}")))?
401 .parse::<u16>()
402 .map_err(|_| TransportError::Handshake(format!("non-numeric status: {status_line}")))?;
403
404 if !(200..300).contains(&status_code) {
405 return Err(TransportError::Handshake(format!(
406 "proxy refused CONNECT: {status_line}"
407 )));
408 }
409
410 Ok(())
411}
412
413async fn wrap_tls<S>(stream: S, server_name: &str) -> Result<TlsStream<S>, TransportError>
415where
416 S: AsyncRead + AsyncWrite + Unpin,
417{
418 let mut root_store = RootCertStore::empty();
419 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
420
421 let config = ClientConfig::builder()
422 .with_root_certificates(root_store)
423 .with_no_client_auth();
424
425 let connector = TlsConnector::from(std::sync::Arc::new(config));
426 let domain = ServerName::try_from(server_name.to_string())
427 .map_err(|e| TransportError::Tls(format!("invalid DNS name '{server_name}': {e}")))?;
428
429 connector
430 .connect(domain, stream)
431 .await
432 .map_err(TransportError::Io)
433}
434
#[cfg(test)]
#[cfg(not(feature = "turmoil"))]
mod tests {
    use std::net::SocketAddr;

    use rstest::rstest;
    use tokio::net::TcpListener;

    use super::*;

    #[rstest]
    fn ws_target_parses_wss() {
        let target = WsTarget::parse("wss://stream.binance.com:9443/ws/btcusdt@trade").unwrap();
        assert_eq!(target.host, "stream.binance.com");
        assert_eq!(target.port, 9443);
        assert!(target.is_tls);
    }

    #[rstest]
    fn ws_target_default_ports() {
        let plain = WsTarget::parse("ws://example.com/path").unwrap();
        assert_eq!(plain.port, 80);
        assert!(!plain.is_tls);

        let tls = WsTarget::parse("wss://example.com/path").unwrap();
        assert_eq!(tls.port, 443);
        assert!(tls.is_tls);
    }

    #[rstest]
    fn ws_target_strips_ipv6_brackets() {
        let target = WsTarget::parse("wss://[::1]:9443/ws").unwrap();
        assert_eq!(target.host, "::1");
        assert_eq!(target.port, 9443);
    }

    #[rstest]
    fn ws_target_rejects_non_ws_scheme() {
        let err = WsTarget::parse("https://example.com").unwrap_err();
        assert!(matches!(err, TransportError::InvalidUrl(_)));
    }

    #[rstest]
    fn proxy_target_parses_http() {
        let proxy = ProxyTarget::parse("http://127.0.0.1:9999").unwrap();
        assert_eq!(proxy.host, "127.0.0.1");
        assert_eq!(proxy.port, 9999);
        assert!(!proxy.is_tls);
        assert!(proxy.auth_header.is_none());
    }

    #[rstest]
    fn proxy_target_default_ports() {
        let plain = ProxyTarget::parse("http://proxy.example.com").unwrap();
        assert_eq!(plain.port, 80);
        let tls = ProxyTarget::parse("https://proxy.example.com").unwrap();
        assert_eq!(tls.port, 443);
        assert!(tls.is_tls);
    }

    #[rstest]
    fn proxy_target_basic_auth() {
        let proxy =
            ProxyTarget::parse("http://proxytest:fixture42@proxy.example.com:8080").unwrap();
        // base64("proxytest:fixture42")
        assert_eq!(
            proxy.auth_header.unwrap(),
            "Basic cHJveHl0ZXN0OmZpeHR1cmU0Mg=="
        );
    }

    #[rstest]
    fn proxy_target_basic_auth_decodes_percent_encoded() {
        // `%2F` and `%40` must be decoded to `/` and `@` before base64 encoding.
        let proxy = ProxyTarget::parse("http://us%2Fer:p%40ss@proxy.example.com:8080").unwrap();
        let header = proxy.auth_header.unwrap();
        // base64("us/er:p@ss")
        assert_eq!(header, "Basic dXMvZXI6cEBzcw==");
    }

    #[rstest]
    fn proxy_target_strips_ipv6_brackets() {
        let proxy = ProxyTarget::parse("http://[::1]:8080").unwrap();
        assert_eq!(proxy.host, "::1");
        assert_eq!(proxy.port, 8080);
    }

    #[rstest]
    fn proxy_target_rejects_socks() {
        let err = ProxyTarget::parse("socks5://127.0.0.1:1080").unwrap_err();
        let TransportError::InvalidUrl(msg) = err else {
            panic!("expected InvalidUrl");
        };
        assert!(msg.contains("SOCKS"));
    }

    #[rstest]
    fn proxy_kind_classifies_http() {
        let kind = ProxyKind::parse("http://127.0.0.1:9999").unwrap();
        assert!(matches!(kind, ProxyKind::Http(_)));
    }

    #[rstest]
    fn proxy_kind_classifies_socks_as_unsupported() {
        let kind = ProxyKind::parse("socks5://127.0.0.1:1080").unwrap();
        let ProxyKind::Unsupported { scheme } = kind else {
            panic!("expected Unsupported");
        };
        assert_eq!(scheme, "socks5");
    }

    #[rstest]
    fn proxy_kind_rejects_garbage() {
        assert!(ProxyKind::parse("ftp://x").is_err());
        assert!(ProxyKind::parse("").is_err());
    }

    #[rstest]
    fn proxy_kind_rejects_socks_without_authority() {
        // Without `//` the URL has no host; the address lands in the path instead.
        let err = ProxyKind::parse("socks5:127.0.0.1:1080").unwrap_err();
        assert!(matches!(err, TransportError::InvalidUrl(_)));
    }

    #[rstest]
    fn proxy_target_rejects_unknown_scheme() {
        let err = ProxyTarget::parse("ftp://proxy.example.com").unwrap_err();
        assert!(matches!(err, TransportError::InvalidUrl(_)));
    }

    #[rstest]
    fn proxy_target_rejects_empty() {
        let err = ProxyTarget::parse("").unwrap_err();
        assert!(matches!(err, TransportError::InvalidUrl(_)));
    }

    #[rstest]
    fn host_header_brackets_ipv6() {
        assert_eq!(format_host_header("example.com", 443), "example.com:443");
        assert_eq!(format_host_header("::1", 443), "[::1]:443");
        assert_eq!(format_host_header("[::1]", 443), "[::1]:443");
    }

    /// Spawns a one-shot fake proxy that waits for a complete request head, replies with
    /// `response`, and then closes the connection.
    async fn spawn_fake_proxy(response: &'static [u8]) -> SocketAddr {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            // Accumulate across reads so a `\r\n\r\n` split over two reads is still seen.
            let mut request = Vec::new();
            let mut chunk = [0u8; 1024];
            loop {
                let n = AsyncReadExt::read(&mut stream, &mut chunk).await.unwrap();
                if n == 0 {
                    break;
                }

                request.extend_from_slice(&chunk[..n]);
                if request.windows(4).any(|w| w == b"\r\n\r\n") {
                    break;
                }
            }
            stream.write_all(response).await.unwrap();
            stream.flush().await.unwrap();
        });
        addr
    }

    #[tokio::test]
    async fn read_connect_response_accepts_2xx() {
        let addr = spawn_fake_proxy(b"HTTP/1.1 200 Connection established\r\n\r\n").await;
        let mut stream = TcpStream::connect(addr).await.unwrap();
        stream
            .write_all(b"CONNECT host:443 HTTP/1.1\r\nHost: host:443\r\n\r\n")
            .await
            .unwrap();
        stream.flush().await.unwrap();
        read_connect_response(&mut stream).await.unwrap();
    }

    #[tokio::test]
    async fn read_connect_response_rejects_403() {
        let addr = spawn_fake_proxy(b"HTTP/1.1 403 Forbidden\r\n\r\n").await;
        let mut stream = TcpStream::connect(addr).await.unwrap();
        stream
            .write_all(b"CONNECT host:443 HTTP/1.1\r\nHost: host:443\r\n\r\n")
            .await
            .unwrap();
        stream.flush().await.unwrap();
        let err = read_connect_response(&mut stream).await.unwrap_err();
        let TransportError::Handshake(msg) = err else {
            panic!("expected Handshake error");
        };
        assert!(msg.contains("403"));
    }

    // Non-2xx statuses (including redirects and auth challenges) and a non-numeric status
    // code must all be reported as handshake failures.
    #[rstest]
    #[case::status_300(&b"HTTP/1.1 300 Multiple Choices\r\n\r\n"[..], "300")]
    #[case::status_407(
        &b"HTTP/1.1 407 Proxy Authentication Required\r\nProxy-Authenticate: Basic\r\n\r\n"[..],
        "407",
    )]
    #[case::malformed_status(&b"HTTP/1.1 abc Boom\r\n\r\n"[..], "non-numeric")]
    #[tokio::test]
    async fn read_connect_response_rejects_non_2xx(
        #[case] response: &'static [u8],
        #[case] expected_msg_substring: &'static str,
    ) {
        let addr = spawn_fake_proxy(response).await;
        let mut stream = TcpStream::connect(addr).await.unwrap();
        stream
            .write_all(b"CONNECT host:443 HTTP/1.1\r\nHost: host:443\r\n\r\n")
            .await
            .unwrap();
        stream.flush().await.unwrap();
        let err = read_connect_response(&mut stream).await.unwrap_err();
        let TransportError::Handshake(msg) = err else {
            panic!("expected Handshake error, was {err:?}");
        };
        assert!(
            msg.contains(expected_msg_substring),
            "expected error message to contain {expected_msg_substring:?}, was {msg:?}"
        );
    }

    #[tokio::test]
    async fn read_connect_response_rejects_eof_before_terminator() {
        // Status line only; the proxy closes before sending the terminating blank line.
        let addr = spawn_fake_proxy(b"HTTP/1.1 200 OK\r\n").await;
        let mut stream = TcpStream::connect(addr).await.unwrap();
        stream
            .write_all(b"CONNECT host:443 HTTP/1.1\r\nHost: host:443\r\n\r\n")
            .await
            .unwrap();
        stream.flush().await.unwrap();
        let err = read_connect_response(&mut stream).await.unwrap_err();
        let TransportError::Handshake(msg) = err else {
            panic!("expected Handshake error, was {err:?}");
        };
        assert!(
            msg.contains("closed connection"),
            "unexpected handshake error: {msg}"
        );
    }

    #[tokio::test]
    async fn read_connect_response_rejects_oversize_headers() {
        // Pad the head past the limit without ever sending `\r\n\r\n`.
        let mut response = b"HTTP/1.1 200 OK\r\n".to_vec();
        while response.len() <= MAX_PROXY_RESPONSE_BYTES {
            response.extend_from_slice(b"X-Pad: aaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\r\n");
        }
        let leaked: &'static [u8] = response.leak();
        let addr = spawn_fake_proxy(leaked).await;
        let mut stream = TcpStream::connect(addr).await.unwrap();
        stream
            .write_all(b"CONNECT host:443 HTTP/1.1\r\nHost: host:443\r\n\r\n")
            .await
            .unwrap();
        stream.flush().await.unwrap();
        let err = read_connect_response(&mut stream).await.unwrap_err();
        let TransportError::Handshake(msg) = err else {
            panic!("expected Handshake error, was {err:?}");
        };
        assert!(
            msg.contains("exceeded"),
            "unexpected handshake error: {msg}"
        );
    }

    #[tokio::test]
    async fn read_connect_response_preserves_trailing_bytes() {
        // Bytes after the response head belong to the tunnel and must not be consumed.
        let addr = spawn_fake_proxy(b"HTTP/1.1 200 Connection established\r\n\r\nLEFTOVER").await;
        let mut stream = TcpStream::connect(addr).await.unwrap();
        stream
            .write_all(b"CONNECT host:443 HTTP/1.1\r\nHost: host:443\r\n\r\n")
            .await
            .unwrap();
        stream.flush().await.unwrap();
        read_connect_response(&mut stream).await.unwrap();

        let mut tail = [0u8; b"LEFTOVER".len()];
        AsyncReadExt::read_exact(&mut stream, &mut tail)
            .await
            .unwrap();
        assert_eq!(&tail, b"LEFTOVER");
    }
}