Skip to main content

nautilus_network/websocket/
proxy.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! Proxy support for outbound WebSocket connections.
17//!
18//! Implements HTTP `CONNECT` tunneling so a `WebSocketClient` can be reached
19//! through an HTTP or HTTPS forward proxy. The same `proxy_url` field is used
20//! by the HTTP client (via `reqwest::Proxy::all`), keeping a single config
21//! field for both transports.
22//!
23//! `socks5://` / `socks5h://` URLs are recognized but not yet implemented
24//! for the WebSocket path. The dispatcher logs a warning and falls back to
25//! a direct connection so that REST configs that already point at a SOCKS
26//! proxy keep working unchanged. SOCKS support requires the optional
27//! `tokio-socks` crate, which is not yet a workspace dependency.
28//!
29//! The tunnel is established as follows:
30//! 1. TCP connect to the proxy host / port.
31//! 2. If the proxy URL scheme is `https`, layer TLS using the proxy host as
32//!    the SNI and certificate domain.
33//! 3. Send `CONNECT target_host:target_port HTTP/1.1` plus the matching
34//!    `Host:` header (and optional `Proxy-Authorization:` derived from the
35//!    proxy URL user-info).
36//! 4. Read the response line and headers; require a `2xx` status.
37//! 5. If the upstream WebSocket scheme is `wss`, layer a second TLS session
38//!    using the upstream host name.
39//! 6. Hand the resulting stream to `tokio-tungstenite`'s `client_async` so the
40//!    WebSocket handshake completes over the tunnel.
41
42use std::fmt::Write as _;
43
44use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
45use rustls::{ClientConfig, RootCertStore, pki_types::ServerName};
46use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
47use tokio_rustls::{TlsConnector, client::TlsStream};
48use url::Url;
49
50use crate::{net::TcpStream, transport::TransportError};
51
/// Maximum size of a `CONNECT` proxy response we are willing to read.
///
/// Bounds the buffer so a malicious or broken proxy cannot make us allocate
/// indefinitely while we wait for the header terminator (`\r\n\r\n`). When
/// the cap is hit without a terminator the read fails with a handshake error
/// instead of continuing to buffer.
const MAX_PROXY_RESPONSE_BYTES: usize = 16 * 1024;
57
/// Stream produced by `tunnel_via_proxy` once the `CONNECT` tunnel is open.
///
/// There is one variant for each combination of proxy hop (plain or TLS) and
/// upstream scheme (`ws://` without TLS, `wss://` with TLS).
///
/// The TLS-bearing variants are boxed because [`tokio_rustls::client::TlsStream`]
/// is large enough that a flat enum trips `clippy::large_enum_variant`. Boxing
/// keeps the enum itself small and cheap to move while the TLS session state
/// lives on the heap.
#[derive(Debug)]
pub enum ProxiedStream {
    /// Plain TCP after a plain proxy hop.
    Plain(TcpStream),
    /// Plain TCP after a TLS proxy hop.
    PlainOverTlsProxy(Box<TlsStream<TcpStream>>),
    /// Upstream TLS over a plain proxy hop.
    Tls(Box<TlsStream<TcpStream>>),
    /// Upstream TLS over a TLS proxy hop.
    TlsOverTlsProxy(Box<TlsStream<TlsStream<TcpStream>>>),
}
76
/// Parsed components of a target WebSocket URL needed by the proxy hop.
///
/// Values produced by [`WsTarget::parse`] store IPv6 literals without the
/// surrounding brackets and always carry a concrete port (the scheme default
/// is filled in when the URL omits one).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WsTarget {
    /// Host name for DNS / SNI / `CONNECT` request line.
    pub host: String,
    /// TCP port of the WebSocket origin.
    pub port: u16,
    /// `true` when the WebSocket scheme is `wss://`.
    pub is_tls: bool,
}
87
88impl WsTarget {
89    /// Parse a `ws://` or `wss://` URL into the host/port/TLS components.
90    ///
91    /// # Errors
92    ///
93    /// Returns [`TransportError::InvalidUrl`] when the URL fails to parse,
94    /// is missing a hostname, or uses a scheme other than `ws`/`wss`.
95    pub fn parse(url: &str) -> Result<Self, TransportError> {
96        let parsed =
97            Url::parse(url).map_err(|e| TransportError::InvalidUrl(format!("{url}: {e}")))?;
98
99        let is_tls = match parsed.scheme() {
100            "ws" => false,
101            "wss" => true,
102            other => {
103                return Err(TransportError::InvalidUrl(format!(
104                    "expected ws:// or wss:// scheme, was {other}"
105                )));
106            }
107        };
108
109        let raw_host = parsed
110            .host_str()
111            .ok_or_else(|| TransportError::InvalidUrl("missing hostname".to_string()))?;
112
113        // url::Url stores IPv6 literals in bracketed form (`[::1]`); the
114        // `CONNECT` request line and TLS SNI both want the unbracketed form.
115        let host = if raw_host.starts_with('[') && raw_host.ends_with(']') {
116            raw_host[1..raw_host.len() - 1].to_string()
117        } else {
118            raw_host.to_string()
119        };
120
121        let port = parsed.port().unwrap_or(if is_tls { 443 } else { 80 });
122
123        Ok(Self { host, port, is_tls })
124    }
125}
126
/// Outcome of parsing a proxy URL prior to opening a tunnel.
///
/// SOCKS schemes are recognized but not implemented for the WebSocket path
/// yet. They are surfaced as [`ProxyKind::Unsupported`] so callers can log
/// a warning and fall back to a direct connection, preserving compatibility
/// with REST configs that already pointed at a SOCKS proxy.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProxyKind {
    /// HTTP / HTTPS forward proxy reachable via `CONNECT` tunneling.
    Http(ProxyTarget),
    /// Recognized scheme without a working tunnel (currently the SOCKS
    /// family: `socks5`, `socks5h`, `socks4`, `socks4a`).
    Unsupported {
        /// Original URL scheme (e.g. `socks5`).
        scheme: String,
    },
}
143
144impl ProxyKind {
145    /// Parse a proxy URL into a [`ProxyKind`]. Returns
146    /// [`TransportError::InvalidUrl`] for malformed input or non-proxy
147    /// schemes (`ftp://`, `ws://`, etc.).
148    ///
149    /// # Errors
150    ///
151    /// See [`ProxyTarget::parse`] for the underlying validation.
152    pub fn parse(url: &str) -> Result<Self, TransportError> {
153        let parsed =
154            Url::parse(url).map_err(|e| TransportError::InvalidUrl(format!("{url}: {e}")))?;
155
156        match parsed.scheme() {
157            "http" | "https" => ProxyTarget::parse(url).map(ProxyKind::Http),
158            scheme @ ("socks5" | "socks5h" | "socks4" | "socks4a") => {
159                // Reject malformed inputs like `socks5:host:port` that parse as
160                // scheme + opaque path with no authority: surfacing them as
161                // Unsupported would silently fall back to a direct connection
162                // and hide the typo.
163                if parsed.host_str().is_none_or(str::is_empty) {
164                    return Err(TransportError::InvalidUrl(format!(
165                        "proxy URL '{url}' is missing a host (did you mean {scheme}://...)?"
166                    )));
167                }
168                Ok(Self::Unsupported {
169                    scheme: scheme.to_string(),
170                })
171            }
172            other => Err(TransportError::InvalidUrl(format!(
173                "unsupported proxy scheme '{other}'; expected http:// or https://"
174            ))),
175        }
176    }
177}
178
/// Parsed components of a forward proxy URL.
///
/// Values produced by [`ProxyTarget::parse`] store IPv6 literals without the
/// surrounding brackets and always carry a concrete port.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProxyTarget {
    /// Host name of the proxy (used for both DNS and TLS SNI when
    /// [`ProxyTarget::is_tls`] is `true`).
    pub host: String,
    /// TCP port of the proxy.
    pub port: u16,
    /// `true` when the proxy URL scheme is `https`.
    pub is_tls: bool,
    /// Pre-computed `Proxy-Authorization` header value (`Basic <base64>`),
    /// if the URL embeds `user:pass@`.
    pub auth_header: Option<String>,
}
193
194impl ProxyTarget {
195    /// Parse a proxy URL into the components needed to establish the tunnel.
196    ///
197    /// Only `http://` and `https://` schemes are accepted here. Use
198    /// [`ProxyKind::parse`] when callers need to distinguish recognised but
199    /// unsupported schemes (currently SOCKS) from malformed input.
200    ///
201    /// # Errors
202    ///
203    /// Returns [`TransportError::InvalidUrl`] for malformed URLs, missing
204    /// hosts, or any scheme other than `http`/`https`.
205    pub fn parse(url: &str) -> Result<Self, TransportError> {
206        let parsed =
207            Url::parse(url).map_err(|e| TransportError::InvalidUrl(format!("{url}: {e}")))?;
208
209        let is_tls = match parsed.scheme() {
210            "http" => false,
211            "https" => true,
212            "socks5" | "socks5h" | "socks4" | "socks4a" => {
213                return Err(TransportError::InvalidUrl(format!(
214                    "SOCKS proxy scheme '{}' is not yet supported for WebSocket connections; \
215                    use an http:// or https:// proxy",
216                    parsed.scheme()
217                )));
218            }
219            other => {
220                return Err(TransportError::InvalidUrl(format!(
221                    "unsupported proxy scheme '{other}'; expected http:// or https://"
222                )));
223            }
224        };
225
226        let raw_host = parsed
227            .host_str()
228            .ok_or_else(|| TransportError::InvalidUrl("proxy URL missing hostname".to_string()))?;
229
230        // url::Url stores IPv6 literals bracketed (`[::1]`); the bracketed
231        // form is only valid in the HTTP `Host:` header, not for DNS or
232        // TLS SNI, so we keep both representations.
233        let host = if raw_host.starts_with('[') && raw_host.ends_with(']') {
234            raw_host[1..raw_host.len() - 1].to_string()
235        } else {
236            raw_host.to_string()
237        };
238
239        let port = parsed.port().unwrap_or(if is_tls { 443 } else { 80 });
240
241        let auth_header = if parsed.username().is_empty() {
242            None
243        } else {
244            let username = decode_userinfo(parsed.username());
245            let password = decode_userinfo(parsed.password().unwrap_or(""));
246            let credentials = format!("{username}:{password}");
247            Some(format!("Basic {}", BASE64.encode(credentials)))
248        };
249
250        Ok(Self {
251            host,
252            port,
253            is_tls,
254            auth_header,
255        })
256    }
257}
258
259/// Percent-decode a userinfo field from a proxy URL. `url::Url` keeps the
260/// raw percent-encoded form, so we decode it here before assembling the
261/// `Basic` credentials.
262fn decode_userinfo(value: &str) -> String {
263    let bytes = nautilus_core::string::urlencoding::decode_bytes(value.as_bytes());
264    String::from_utf8_lossy(&bytes).into_owned()
265}
266
267/// Establish a tunneled connection through `proxy` to the WebSocket `target`.
268///
269/// On success the returned stream is positioned right after the proxy's
270/// `200`/`2xx` response, ready for the WebSocket handshake. The function does
271/// not perform the WebSocket handshake itself; callers wrap the stream in
272/// `tokio-tungstenite::client_async`.
273///
274/// # Errors
275///
276/// Returns a [`TransportError`] when:
277/// - The TCP connection to the proxy fails ([`TransportError::Io`]).
278/// - The TLS layer to the proxy or upstream cannot be established
279///   ([`TransportError::Tls`]).
280/// - The proxy returns a non-success status, malformed headers, or closes the
281///   stream before completing the response ([`TransportError::Handshake`]).
282pub async fn tunnel_via_proxy(
283    target: &WsTarget,
284    proxy: &ProxyTarget,
285) -> Result<ProxiedStream, TransportError> {
286    let tcp = TcpStream::connect((proxy.host.as_str(), proxy.port))
287        .await
288        .map_err(TransportError::Io)?;
289
290    if let Err(e) = tcp.set_nodelay(true) {
291        log::warn!("Failed to enable TCP_NODELAY on proxy connection: {e:?}");
292    }
293
294    if proxy.is_tls {
295        let proxy_tls = wrap_tls(tcp, &proxy.host).await?;
296        let tunneled = send_connect(proxy_tls, target, proxy).await?;
297        if target.is_tls {
298            let upstream = wrap_tls(tunneled, &target.host).await?;
299            Ok(ProxiedStream::TlsOverTlsProxy(Box::new(upstream)))
300        } else {
301            Ok(ProxiedStream::PlainOverTlsProxy(Box::new(tunneled)))
302        }
303    } else {
304        let tunneled = send_connect(tcp, target, proxy).await?;
305        if target.is_tls {
306            let upstream = wrap_tls(tunneled, &target.host).await?;
307            Ok(ProxiedStream::Tls(Box::new(upstream)))
308        } else {
309            Ok(ProxiedStream::Plain(tunneled))
310        }
311    }
312}
313
314/// Send a `CONNECT` request and return the underlying stream once a `2xx`
315/// status is received. The returned stream is positioned after the empty line
316/// terminating the proxy response headers.
317async fn send_connect<S>(
318    mut stream: S,
319    target: &WsTarget,
320    proxy: &ProxyTarget,
321) -> Result<S, TransportError>
322where
323    S: AsyncRead + AsyncWrite + Unpin,
324{
325    let host_header = format_host_header(&target.host, target.port);
326    let mut request = format!(
327        "CONNECT {host_header} HTTP/1.1\r\n\
328         Host: {host_header}\r\n\
329         Proxy-Connection: Keep-Alive\r\n"
330    );
331
332    if let Some(auth) = &proxy.auth_header {
333        write!(request, "Proxy-Authorization: {auth}\r\n").expect("writing to String never fails");
334    }
335    request.push_str("\r\n");
336
337    stream
338        .write_all(request.as_bytes())
339        .await
340        .map_err(TransportError::Io)?;
341    stream.flush().await.map_err(TransportError::Io)?;
342
343    read_connect_response(&mut stream).await?;
344    Ok(stream)
345}
346
/// Render `host` and `port` as the `host:port` authority used for the
/// `CONNECT` request line and `Host:` header.
///
/// A host containing `:` (an IPv6 literal) is wrapped in brackets unless it
/// already starts with `[` and ends with `]`.
fn format_host_header(host: &str, port: u16) -> String {
    let already_bracketed = host.starts_with('[') && host.ends_with(']');
    let needs_brackets = host.contains(':') && !already_bracketed;

    let (open, close) = if needs_brackets { ("[", "]") } else { ("", "") };
    format!("{open}{host}{close}:{port}")
}
354
355/// Read the proxy's response up to the empty line that terminates the
356/// headers, validating the status line.
357async fn read_connect_response<S>(stream: &mut S) -> Result<(), TransportError>
358where
359    S: AsyncRead + Unpin,
360{
361    let mut buf = Vec::with_capacity(512);
362    let mut byte = [0u8; 1];
363
364    loop {
365        let n = stream.read(&mut byte).await.map_err(TransportError::Io)?;
366        if n == 0 {
367            return Err(TransportError::Handshake(
368                "proxy closed connection before sending CONNECT response".to_string(),
369            ));
370        }
371
372        buf.push(byte[0]);
373
374        if buf.ends_with(b"\r\n\r\n") {
375            break;
376        }
377
378        if buf.len() > MAX_PROXY_RESPONSE_BYTES {
379            return Err(TransportError::Handshake(format!(
380                "proxy CONNECT response exceeded {MAX_PROXY_RESPONSE_BYTES} bytes without terminator"
381            )));
382        }
383    }
384
385    let text = std::str::from_utf8(&buf).map_err(|_| {
386        TransportError::Handshake("proxy CONNECT response was not valid UTF-8".to_string())
387    })?;
388
389    let status_line = text.lines().next().ok_or_else(|| {
390        TransportError::Handshake("proxy CONNECT response missing status line".to_string())
391    })?;
392
393    // Expect: `HTTP/1.1 200 Connection established` (or any 2xx).
394    let mut parts = status_line.splitn(3, ' ');
395    let _version = parts.next().ok_or_else(|| {
396        TransportError::Handshake(format!("malformed status line: {status_line}"))
397    })?;
398    let status_code = parts
399        .next()
400        .ok_or_else(|| TransportError::Handshake(format!("malformed status line: {status_line}")))?
401        .parse::<u16>()
402        .map_err(|_| TransportError::Handshake(format!("non-numeric status: {status_line}")))?;
403
404    if !(200..300).contains(&status_code) {
405        return Err(TransportError::Handshake(format!(
406            "proxy refused CONNECT: {status_line}"
407        )));
408    }
409
410    Ok(())
411}
412
413/// Wrap a stream in a `rustls`-backed TLS session using `webpki_roots`.
414async fn wrap_tls<S>(stream: S, server_name: &str) -> Result<TlsStream<S>, TransportError>
415where
416    S: AsyncRead + AsyncWrite + Unpin,
417{
418    let mut root_store = RootCertStore::empty();
419    root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
420
421    let config = ClientConfig::builder()
422        .with_root_certificates(root_store)
423        .with_no_client_auth();
424
425    let connector = TlsConnector::from(std::sync::Arc::new(config));
426    let domain = ServerName::try_from(server_name.to_string())
427        .map_err(|e| TransportError::Tls(format!("invalid DNS name '{server_name}': {e}")))?;
428
429    connector
430        .connect(domain, stream)
431        .await
432        .map_err(TransportError::Io)
433}
434
#[cfg(test)]
#[cfg(not(feature = "turmoil"))] // proxy hop is not modelled under the turmoil simulator
mod tests {
    use std::net::SocketAddr;

    use rstest::rstest;
    use tokio::net::TcpListener;

    use super::*;

    /// Canned `CONNECT` request sent by every response-reading test before it
    /// inspects the fake proxy's reply.
    const CONNECT_REQUEST: &[u8] = b"CONNECT host:443 HTTP/1.1\r\nHost: host:443\r\n\r\n";

    #[rstest]
    fn ws_target_parses_wss() {
        let target = WsTarget::parse("wss://stream.binance.com:9443/ws/btcusdt@trade").unwrap();
        assert_eq!(target.host, "stream.binance.com");
        assert_eq!(target.port, 9443);
        assert!(target.is_tls);
    }

    #[rstest]
    fn ws_target_default_ports() {
        let plain = WsTarget::parse("ws://example.com/path").unwrap();
        assert_eq!(plain.port, 80);
        assert!(!plain.is_tls);

        let tls = WsTarget::parse("wss://example.com/path").unwrap();
        assert_eq!(tls.port, 443);
        assert!(tls.is_tls);
    }

    #[rstest]
    fn ws_target_strips_ipv6_brackets() {
        let target = WsTarget::parse("wss://[::1]:9443/ws").unwrap();
        assert_eq!(target.host, "::1");
        assert_eq!(target.port, 9443);
    }

    #[rstest]
    fn ws_target_rejects_non_ws_scheme() {
        let err = WsTarget::parse("https://example.com").unwrap_err();
        assert!(matches!(err, TransportError::InvalidUrl(_)));
    }

    #[rstest]
    fn proxy_target_parses_http() {
        let proxy = ProxyTarget::parse("http://127.0.0.1:9999").unwrap();
        assert_eq!(proxy.host, "127.0.0.1");
        assert_eq!(proxy.port, 9999);
        assert!(!proxy.is_tls);
        assert!(proxy.auth_header.is_none());
    }

    #[rstest]
    fn proxy_target_default_ports() {
        let plain = ProxyTarget::parse("http://proxy.example.com").unwrap();
        assert_eq!(plain.port, 80);
        let tls = ProxyTarget::parse("https://proxy.example.com").unwrap();
        assert_eq!(tls.port, 443);
        assert!(tls.is_tls);
    }

    #[rstest]
    fn proxy_target_basic_auth() {
        let proxy =
            ProxyTarget::parse("http://proxytest:fixture42@proxy.example.com:8080").unwrap();
        // base64("proxytest:fixture42") == "cHJveHl0ZXN0OmZpeHR1cmU0Mg=="
        assert_eq!(
            proxy.auth_header.unwrap(),
            "Basic cHJveHl0ZXN0OmZpeHR1cmU0Mg=="
        );
    }

    #[rstest]
    fn proxy_target_basic_auth_decodes_percent_encoded() {
        // `p%40ss` should decode to `p@ss` before assembling Basic credentials
        let proxy = ProxyTarget::parse("http://us%2Fer:p%40ss@proxy.example.com:8080").unwrap();
        let header = proxy.auth_header.unwrap();
        // base64("us/er:p@ss") == "dXMvZXI6cEBzcw=="
        assert_eq!(header, "Basic dXMvZXI6cEBzcw==");
    }

    #[rstest]
    fn proxy_target_strips_ipv6_brackets() {
        let proxy = ProxyTarget::parse("http://[::1]:8080").unwrap();
        assert_eq!(proxy.host, "::1");
        assert_eq!(proxy.port, 8080);
    }

    #[rstest]
    fn proxy_target_rejects_socks() {
        let err = ProxyTarget::parse("socks5://127.0.0.1:1080").unwrap_err();
        let TransportError::InvalidUrl(msg) = err else {
            panic!("expected InvalidUrl");
        };
        assert!(msg.contains("SOCKS"));
    }

    #[rstest]
    fn proxy_kind_classifies_http() {
        let kind = ProxyKind::parse("http://127.0.0.1:9999").unwrap();
        assert!(matches!(kind, ProxyKind::Http(_)));
    }

    #[rstest]
    fn proxy_kind_classifies_socks_as_unsupported() {
        let kind = ProxyKind::parse("socks5://127.0.0.1:1080").unwrap();
        let ProxyKind::Unsupported { scheme } = kind else {
            panic!("expected Unsupported");
        };
        assert_eq!(scheme, "socks5");
    }

    #[rstest]
    fn proxy_kind_rejects_garbage() {
        assert!(ProxyKind::parse("ftp://x").is_err());
        assert!(ProxyKind::parse("").is_err());
    }

    #[rstest]
    fn proxy_kind_rejects_socks_without_authority() {
        // `socks5:host:port` (no `//`) parses as scheme + opaque path; surface
        // as a real error instead of a silent direct-fallback.
        let err = ProxyKind::parse("socks5:127.0.0.1:1080").unwrap_err();
        assert!(matches!(err, TransportError::InvalidUrl(_)));
    }

    #[rstest]
    fn proxy_target_rejects_unknown_scheme() {
        let err = ProxyTarget::parse("ftp://proxy.example.com").unwrap_err();
        assert!(matches!(err, TransportError::InvalidUrl(_)));
    }

    #[rstest]
    fn proxy_target_rejects_empty() {
        let err = ProxyTarget::parse("").unwrap_err();
        assert!(matches!(err, TransportError::InvalidUrl(_)));
    }

    #[rstest]
    fn host_header_brackets_ipv6() {
        assert_eq!(format_host_header("example.com", 443), "example.com:443");
        assert_eq!(format_host_header("::1", 443), "[::1]:443");
        assert_eq!(format_host_header("[::1]", 443), "[::1]:443");
    }

    /// Spawn a fake HTTP proxy that accepts one connection, reads the request
    /// up to the `\r\n\r\n` header terminator (or EOF), then writes the
    /// configured response. Returns the bound address.
    async fn spawn_fake_proxy(response: &'static [u8]) -> SocketAddr {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            // Accumulate everything received so far: TCP may deliver the
            // request in several chunks, and the terminator can straddle two
            // reads. Checking only the latest chunk would then never match
            // and the test would hang.
            let mut request: Vec<u8> = Vec::with_capacity(256);
            let mut chunk = [0u8; 1024];
            loop {
                let n = AsyncReadExt::read(&mut stream, &mut chunk).await.unwrap();
                if n == 0 {
                    break;
                }

                request.extend_from_slice(&chunk[..n]);
                if request.windows(4).any(|w| w == b"\r\n\r\n") {
                    break;
                }
            }
            stream.write_all(response).await.unwrap();
            stream.flush().await.unwrap();
        });
        addr
    }

    /// Connect to the fake proxy at `addr` and send [`CONNECT_REQUEST`],
    /// returning the stream ready for `read_connect_response`.
    async fn connect_and_send_request(addr: SocketAddr) -> TcpStream {
        let mut stream = TcpStream::connect(addr).await.unwrap();
        stream.write_all(CONNECT_REQUEST).await.unwrap();
        stream.flush().await.unwrap();
        stream
    }

    #[tokio::test]
    async fn read_connect_response_accepts_2xx() {
        let addr = spawn_fake_proxy(b"HTTP/1.1 200 Connection established\r\n\r\n").await;
        let mut stream = connect_and_send_request(addr).await;
        read_connect_response(&mut stream).await.unwrap();
    }

    #[tokio::test]
    async fn read_connect_response_rejects_403() {
        let addr = spawn_fake_proxy(b"HTTP/1.1 403 Forbidden\r\n\r\n").await;
        let mut stream = connect_and_send_request(addr).await;
        let err = read_connect_response(&mut stream).await.unwrap_err();
        let TransportError::Handshake(msg) = err else {
            panic!("expected Handshake error");
        };
        assert!(msg.contains("403"));
    }

    /// 300 sits on the upper boundary of the accepted `200..300` range; if
    /// the check is ever loosened to `200..=300` this test fails. 407 is the
    /// classic "Proxy Authentication Required" response. Non-numeric status
    /// probes the parse path.
    #[rstest]
    #[case::status_300(&b"HTTP/1.1 300 Multiple Choices\r\n\r\n"[..], "300")]
    #[case::status_407(
        &b"HTTP/1.1 407 Proxy Authentication Required\r\nProxy-Authenticate: Basic\r\n\r\n"[..],
        "407",
    )]
    #[case::malformed_status(&b"HTTP/1.1 abc Boom\r\n\r\n"[..], "non-numeric")]
    #[tokio::test]
    async fn read_connect_response_rejects_non_2xx(
        #[case] response: &'static [u8],
        #[case] expected_msg_substring: &'static str,
    ) {
        let addr = spawn_fake_proxy(response).await;
        let mut stream = connect_and_send_request(addr).await;
        let err = read_connect_response(&mut stream).await.unwrap_err();
        let TransportError::Handshake(msg) = err else {
            panic!("expected Handshake error, was {err:?}");
        };
        assert!(
            msg.contains(expected_msg_substring),
            "expected error message to contain {expected_msg_substring:?}, was {msg:?}"
        );
    }

    /// Closing the connection mid-response should produce a clear handshake
    /// error rather than spinning on a zero-byte read.
    #[tokio::test]
    async fn read_connect_response_rejects_eof_before_terminator() {
        // Truncated response: missing the empty line that ends the headers
        let addr = spawn_fake_proxy(b"HTTP/1.1 200 OK\r\n").await;
        let mut stream = connect_and_send_request(addr).await;
        let err = read_connect_response(&mut stream).await.unwrap_err();
        let TransportError::Handshake(msg) = err else {
            panic!("expected Handshake error, was {err:?}");
        };
        assert!(
            msg.contains("closed connection"),
            "unexpected handshake error: {msg}"
        );
    }

    /// A proxy that streams headers without ever emitting `\r\n\r\n` should
    /// trip the size cap rather than allocating without bound.
    #[tokio::test]
    async fn read_connect_response_rejects_oversize_headers() {
        let mut response = b"HTTP/1.1 200 OK\r\n".to_vec();
        while response.len() <= MAX_PROXY_RESPONSE_BYTES {
            response.extend_from_slice(b"X-Pad: aaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\r\n");
        }
        // The fake proxy wants a `'static` slice; leaking a few KiB in a test
        // process is the simplest way to provide one.
        let leaked: &'static [u8] = response.leak();
        let addr = spawn_fake_proxy(leaked).await;
        let mut stream = connect_and_send_request(addr).await;
        let err = read_connect_response(&mut stream).await.unwrap_err();
        let TransportError::Handshake(msg) = err else {
            panic!("expected Handshake error, was {err:?}");
        };
        assert!(
            msg.contains("exceeded"),
            "unexpected handshake error: {msg}"
        );
    }

    /// After accepting the 2xx response, the stream cursor must sit immediately
    /// after the terminating `\r\n\r\n` so the WebSocket handshake can read its
    /// own response. Regression guard against over-reading the terminator.
    #[tokio::test]
    async fn read_connect_response_preserves_trailing_bytes() {
        let addr = spawn_fake_proxy(b"HTTP/1.1 200 Connection established\r\n\r\nLEFTOVER").await;
        let mut stream = connect_and_send_request(addr).await;
        read_connect_response(&mut stream).await.unwrap();

        let mut tail = [0u8; b"LEFTOVER".len()];
        AsyncReadExt::read_exact(&mut stream, &mut tail)
            .await
            .unwrap();
        assert_eq!(&tail, b"LEFTOVER");
    }
}
733}