Skip to main content

nautilus_network/transport/
sockudo.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! `sockudo-ws` backend for the transport abstraction.
17//!
18//! Mirrors the layout of the [`tungstenite`](super::tungstenite) module: provides
19//! `From`/`TryFrom` conversions between the neutral [`Message`] / [`TransportError`]
20//! and sockudo's native types, plus a [`SockudoTransport<S>`] adapter that lifts a
21//! sockudo [`WebSocketStream<S>`] into the backend-agnostic [`WsTransport`] trait.
22//!
//! The `Message` enums are structurally identical: `Text`, `Binary`, `Ping` and `Pong`
//! carry their payload as `bytes::Bytes`, and `Close` carries an optional code/reason
//! pair, so conversions are infallible and never copy a payload.
25//!
26//! sockudo's public HTTP/1.1 client API does not expose custom headers, so this
27//! module provides a small handshake helper for upgrade requests that need them.
28
29use std::{
30    pin::Pin,
31    task::{Context, Poll},
32};
33
34use bytes::{BufMut, Bytes, BytesMut};
35use futures::{Sink, Stream};
36use sockudo_ws::{
37    HandshakeResult,
38    error::{CloseReason as SockudoCloseReason, Error as SockudoError},
39    handshake,
40    protocol::Message as SockudoMessage,
41    stream::WebSocketStream,
42};
43use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
44
45use super::{
46    error::TransportError,
47    message::{CloseFrame, Message},
48    stream::WsTransport,
49};
50
/// Largest HTTP upgrade response, in bytes, that the handshake keeps buffering
/// before giving up with "response too large".
const MAX_HTTP_HEADER_SIZE: usize = 8 * 1024;

/// Header names rejected by [`validate_extra_header`].
///
/// Covers the WebSocket upgrade headers written by the request builder itself
/// and the body-framing headers that have no place on a GET upgrade. Entries are
/// lowercase because they are compared against a parsed `http::HeaderName`.
const RESERVED_UPGRADE_HEADERS: &[&str] = &[
    // Written by the request builder.
    "host",
    "upgrade",
    "connection",
    "sec-websocket-key",
    "sec-websocket-version",
    "sec-websocket-protocol",
    "sec-websocket-extensions",
    // Body framing.
    "content-length",
    "transfer-encoding",
    "te",
    "trailer",
];
68
69/// Mirror of `sockudo_ws::handshake::client_handshake` (1.7.4) with custom headers.
70///
71/// Caller pre-validates `extra_headers` via [`validate_extra_headers`].
72pub(crate) async fn client_handshake_with_headers<S>(
73    stream: &mut S,
74    host: &str,
75    path: &str,
76    protocol: Option<&str>,
77    extra_headers: &[(String, String)],
78) -> Result<HandshakeResult, SockudoError>
79where
80    S: AsyncRead + AsyncWrite + Unpin,
81{
82    use tokio::io::{AsyncReadExt, AsyncWriteExt};
83
84    let key = handshake::generate_key();
85    let request = build_request_with_headers(host, path, &key, protocol, None, extra_headers);
86
87    stream.write_all(&request).await?;
88    stream.flush().await?;
89
90    let mut buf = BytesMut::with_capacity(4096);
91
92    loop {
93        if buf.len() > MAX_HTTP_HEADER_SIZE {
94            return Err(SockudoError::InvalidHttp("response too large"));
95        }
96
97        let n = stream.read_buf(&mut buf).await?;
98        if n == 0 {
99            return Err(SockudoError::ConnectionClosed);
100        }
101
102        let parsed = match handshake::parse_response(&buf) {
103            Ok(parsed) => parsed,
104            Err(e) => {
105                log_handshake_response(host, path, &e, &buf);
106                return Err(e);
107            }
108        };
109
110        if let Some((res, consumed)) = parsed {
111            let accept = res.accept.ok_or_else(|| {
112                let e = SockudoError::HandshakeFailed("missing Sec-WebSocket-Accept");
113                log_handshake_response(host, path, &e, &buf);
114                e
115            })?;
116
117            if !handshake::validate_accept_key(&key, accept) {
118                let e = SockudoError::HandshakeFailed("invalid Sec-WebSocket-Accept");
119                log_handshake_response(host, path, &e, &buf);
120                return Err(e);
121            }
122
123            let res_protocol = res.protocol.map(String::from);
124            let res_extensions = res.extensions.map(String::from);
125            let leftover = if consumed < buf.len() {
126                Some(buf.split_off(consumed).freeze())
127            } else {
128                None
129            };
130
131            return Ok(HandshakeResult {
132                path: path.to_string(),
133                protocol: res_protocol,
134                extensions: res_extensions,
135                leftover,
136            });
137        }
138    }
139}
140
141// Surface the upstream HTTP response on parse failure so non-101 statuses are visible.
142fn log_handshake_response(host: &str, path: &str, err: &SockudoError, buf: &BytesMut) {
143    const PREVIEW_BYTES: usize = 512;
144    let take = buf.len().min(PREVIEW_BYTES);
145    let preview = String::from_utf8_lossy(&buf[..take]);
146    let truncated = if buf.len() > take { " (truncated)" } else { "" };
147    log::error!(
148        "Sockudo handshake failed for {host}{path}: {err}; response{truncated}:\n{preview}"
149    );
150}
151
152// Mirror of `sockudo_ws::handshake::build_request` (1.7.4) with `extra_headers`
153// appended; caller pre-validates.
154fn build_request_with_headers(
155    host: &str,
156    path: &str,
157    key: &str,
158    protocol: Option<&str>,
159    extensions: Option<&str>,
160    extra_headers: &[(String, String)],
161) -> Bytes {
162    let mut buf = BytesMut::with_capacity(512);
163
164    buf.put_slice(b"GET ");
165    buf.put_slice(path.as_bytes());
166    buf.put_slice(b" HTTP/1.1\r\n");
167    buf.put_slice(b"Host: ");
168    buf.put_slice(host.as_bytes());
169    buf.put_slice(b"\r\n");
170    buf.put_slice(b"Upgrade: websocket\r\n");
171    buf.put_slice(b"Connection: Upgrade\r\n");
172    buf.put_slice(b"Sec-WebSocket-Key: ");
173    buf.put_slice(key.as_bytes());
174    buf.put_slice(b"\r\n");
175    buf.put_slice(b"Sec-WebSocket-Version: 13\r\n");
176
177    if let Some(proto) = protocol {
178        buf.put_slice(b"Sec-WebSocket-Protocol: ");
179        buf.put_slice(proto.as_bytes());
180        buf.put_slice(b"\r\n");
181    }
182
183    if let Some(ext) = extensions {
184        buf.put_slice(b"Sec-WebSocket-Extensions: ");
185        buf.put_slice(ext.as_bytes());
186        buf.put_slice(b"\r\n");
187    }
188
189    for (name, value) in extra_headers {
190        buf.put_slice(name.as_bytes());
191        buf.put_slice(b": ");
192        buf.put_slice(value.as_bytes());
193        buf.put_slice(b"\r\n");
194    }
195
196    buf.put_slice(b"\r\n");
197    buf.freeze()
198}
199
200pub(crate) fn validate_extra_headers(headers: &[(String, String)]) -> Result<(), SockudoError> {
201    for (name, value) in headers {
202        validate_extra_header(name, value)?;
203    }
204    Ok(())
205}
206
207fn validate_extra_header(name: &str, value: &str) -> Result<(), SockudoError> {
208    let parsed_name = name
209        .parse::<http::HeaderName>()
210        .map_err(|_| SockudoError::InvalidHttp("invalid header name"))?;
211
212    if RESERVED_UPGRADE_HEADERS.contains(&parsed_name.as_str()) {
213        return Err(SockudoError::InvalidHttp(
214            "reserved upgrade header not allowed in extra_headers",
215        ));
216    }
217
218    http::HeaderValue::from_str(value)
219        .map_err(|_| SockudoError::InvalidHttp("invalid header value"))?;
220    Ok(())
221}
222
223/// Replay bytes read during the handshake before forwarding to the inner IO.
224pub(crate) struct PrefixedIo<S> {
225    inner: S,
226    prefix: Bytes,
227}
228
229impl<S> PrefixedIo<S> {
230    pub(crate) const fn new(inner: S, prefix: Bytes) -> Self {
231        Self { inner, prefix }
232    }
233}
234
235impl<S> AsyncRead for PrefixedIo<S>
236where
237    S: AsyncRead + Unpin,
238{
239    fn poll_read(
240        mut self: Pin<&mut Self>,
241        cx: &mut Context<'_>,
242        buf: &mut ReadBuf<'_>,
243    ) -> Poll<std::io::Result<()>> {
244        if !self.prefix.is_empty() {
245            let n = self.prefix.len().min(buf.remaining());
246            let chunk = self.prefix.split_to(n);
247            buf.put_slice(&chunk);
248            return Poll::Ready(Ok(()));
249        }
250
251        Pin::new(&mut self.inner).poll_read(cx, buf)
252    }
253}
254
255impl<S> AsyncWrite for PrefixedIo<S>
256where
257    S: AsyncWrite + Unpin,
258{
259    fn poll_write(
260        mut self: Pin<&mut Self>,
261        cx: &mut Context<'_>,
262        buf: &[u8],
263    ) -> Poll<std::io::Result<usize>> {
264        Pin::new(&mut self.inner).poll_write(cx, buf)
265    }
266
267    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
268        Pin::new(&mut self.inner).poll_flush(cx)
269    }
270
271    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
272        Pin::new(&mut self.inner).poll_shutdown(cx)
273    }
274}
275
276impl From<SockudoMessage> for Message {
277    fn from(value: SockudoMessage) -> Self {
278        match value {
279            SockudoMessage::Text(b) => Self::Text(b),
280            SockudoMessage::Binary(b) => Self::Binary(b),
281            SockudoMessage::Ping(b) => Self::Ping(b),
282            SockudoMessage::Pong(b) => Self::Pong(b),
283            SockudoMessage::Close(reason) => Self::Close(reason.map(Into::into)),
284        }
285    }
286}
287
288impl From<Message> for SockudoMessage {
289    /// Convert a neutral [`Message`] into a sockudo [`SockudoMessage`].
290    ///
291    /// Conversion is infallible: both enums carry payloads as `bytes::Bytes` across
292    /// all variants. Sockudo validates UTF-8 on Text frames at parse time, not at
293    /// send time, so feeding it non-UTF-8 bytes via [`Self::Text`] is the caller's
294    /// responsibility.
295    fn from(value: Message) -> Self {
296        match value {
297            Message::Text(b) => Self::Text(b),
298            Message::Binary(b) => Self::Binary(b),
299            Message::Ping(b) => Self::Ping(b),
300            Message::Pong(b) => Self::Pong(b),
301            Message::Close(frame) => Self::Close(frame.map(Into::into)),
302        }
303    }
304}
305
306impl From<SockudoCloseReason> for CloseFrame {
307    fn from(value: SockudoCloseReason) -> Self {
308        Self {
309            code: value.code,
310            reason: value.reason,
311        }
312    }
313}
314
315impl From<CloseFrame> for SockudoCloseReason {
316    fn from(value: CloseFrame) -> Self {
317        Self {
318            code: value.code,
319            reason: value.reason,
320        }
321    }
322}
323
324impl From<SockudoError> for TransportError {
325    fn from(value: SockudoError) -> Self {
326        match value {
327            SockudoError::Io(e) => Self::Io(e),
328            SockudoError::ConnectionClosed => Self::ConnectionClosed,
329            SockudoError::ConnectionReset => Self::ConnectionReset,
330            SockudoError::Closed(reason) => Self::ClosedByPeer(reason.map(Into::into)),
331            SockudoError::MessageTooLarge => Self::MessageTooLarge,
332            SockudoError::FrameTooLarge => Self::FrameTooLarge,
333            SockudoError::InvalidUtf8 => Self::InvalidUtf8,
334            SockudoError::InvalidFrame(msg) | SockudoError::Protocol(msg) => {
335                Self::Protocol(msg.to_string())
336            }
337            SockudoError::InvalidHttp(msg) | SockudoError::HandshakeFailed(msg) => {
338                Self::Handshake(msg.to_string())
339            }
340            other => Self::Other(other.to_string()),
341        }
342    }
343}
344
345/// Adapter that lifts a `sockudo-ws` [`WebSocketStream<S>`] into a
346/// backend-agnostic [`WsTransport`].
347///
348/// Translates messages and errors to the neutral types on the way through
349/// `Stream::poll_next` and `Sink<Message>::start_send` / `poll_*`. The
350/// underlying stream is owned and forwarded to via pin projection.
351pub struct SockudoTransport<S> {
352    inner: WebSocketStream<S>,
353    /// Tracks a flush of the inner write buffer that returned `Pending`. The
354    /// next [`Stream::poll_next`] retries the flush before reading so queued
355    /// control responses (Pong, close reply) are not stranded under sustained
356    /// write backpressure on a quiet reader.
357    pending_flush: bool,
358}
359
360impl<S> SockudoTransport<S> {
361    /// Wrap an established sockudo WebSocket stream.
362    #[inline]
363    #[must_use]
364    pub const fn new(inner: WebSocketStream<S>) -> Self {
365        Self {
366            inner,
367            pending_flush: false,
368        }
369    }
370
371    /// Consume the adapter and return the underlying stream.
372    #[inline]
373    pub fn into_inner(self) -> WebSocketStream<S> {
374        self.inner
375    }
376
377    /// Borrow the underlying stream.
378    #[inline]
379    pub const fn get_ref(&self) -> &WebSocketStream<S> {
380        &self.inner
381    }
382}
383
384impl<S> std::fmt::Debug for SockudoTransport<S> {
385    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
386        f.debug_struct(stringify!(SockudoTransport))
387            .finish_non_exhaustive()
388    }
389}
390
391impl<S> Stream for SockudoTransport<S>
392where
393    S: AsyncRead + AsyncWrite + Unpin,
394{
395    type Item = Result<Message, TransportError>;
396
397    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
398        // Drain any flush that returned Pending on a prior poll so queued
399        // control responses (Pong, close reply) reach the peer before the
400        // next read. Errors are dropped here; subsequent writes through the
401        // sink half surface them.
402        if self.pending_flush {
403            match Pin::new(&mut self.inner).poll_flush(cx) {
404                Poll::Ready(_) => self.pending_flush = false,
405                Poll::Pending => {}
406            }
407        }
408
409        let result = match Pin::new(&mut self.inner).poll_next(cx) {
410            Poll::Ready(Some(Ok(msg))) => Poll::Ready(Some(Ok(Message::from(msg)))),
411            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(TransportError::from(e)))),
412            Poll::Ready(None) => Poll::Ready(None),
413            Poll::Pending => return Poll::Pending,
414        };
415
416        // Sockudo queues automatic Pong / close-response frames into the
417        // write buffer during poll_next. Nudge a flush so they reach the peer
418        // promptly even on a reader-only client; track a pending flush so the
419        // next poll retries when backpressure stalls the write socket.
420        match Pin::new(&mut self.inner).poll_flush(cx) {
421            Poll::Ready(_) => self.pending_flush = false,
422            Poll::Pending => self.pending_flush = true,
423        }
424
425        result
426    }
427}
428
429impl<S> Sink<Message> for SockudoTransport<S>
430where
431    S: AsyncRead + AsyncWrite + Unpin,
432{
433    type Error = TransportError;
434
435    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
436        Pin::new(&mut self.inner)
437            .poll_ready(cx)
438            .map_err(TransportError::from)
439    }
440
441    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
442        Pin::new(&mut self.inner)
443            .start_send(SockudoMessage::from(item))
444            .map_err(TransportError::from)
445    }
446
447    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
448        Pin::new(&mut self.inner)
449            .poll_flush(cx)
450            .map_err(TransportError::from)
451    }
452
453    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
454        Pin::new(&mut self.inner)
455            .poll_close(cx)
456            .map_err(TransportError::from)
457    }
458}
459
460const _: fn() = || {
461    fn assert_ws_transport<T: WsTransport>() {}
462    assert_ws_transport::<SockudoTransport<tokio::net::TcpStream>>();
463};
464
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use rstest::rstest;
    #[cfg(not(feature = "turmoil"))]
    use sockudo_ws::handshake::generate_accept_key;
    #[cfg(not(feature = "turmoil"))]
    use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, duplex};

    use super::*;

    /// Reads from `stream` until the blank line that ends the HTTP request head
    /// has been seen, returning everything read so far.
    #[cfg(not(feature = "turmoil"))]
    async fn read_http_request<S>(stream: &mut S) -> Vec<u8>
    where
        S: AsyncRead + Unpin,
    {
        let mut buf = Vec::new();
        let mut chunk = [0u8; 256];

        loop {
            let n = stream.read(&mut chunk).await.unwrap();
            assert!(n > 0, "HTTP request closed before headers completed");
            buf.extend_from_slice(&chunk[..n]);
            if buf.windows(4).any(|window| window == b"\r\n\r\n") {
                return buf;
            }
        }
    }

    /// Builds a 101 response whose accept value matches `sec_websocket_key`,
    /// followed by `extra_bytes` placed directly after the response head.
    #[cfg(not(feature = "turmoil"))]
    fn build_test_response(sec_websocket_key: &str, extra_bytes: &[u8]) -> Vec<u8> {
        let accept = generate_accept_key(sec_websocket_key);
        let mut response = format!(
            concat!(
                "HTTP/1.1 101 Switching Protocols\r\n",
                "Upgrade: websocket\r\n",
                "Connection: Upgrade\r\n",
                "Sec-WebSocket-Accept: {}\r\n",
                "\r\n",
            ),
            accept
        )
        .into_bytes();
        response.extend_from_slice(extra_bytes);
        response
    }

    /// Returns the trimmed value of the first header line whose name matches
    /// `name`, ignoring ASCII case.
    #[cfg(not(feature = "turmoil"))]
    fn extract_header<'a>(request: &'a str, name: &str) -> Option<&'a str> {
        request.lines().find_map(|line| {
            let (header_name, header_value) = line.split_once(':')?;
            if header_name.eq_ignore_ascii_case(name) {
                Some(header_value.trim())
            } else {
                None
            }
        })
    }

    #[tokio::test]
    #[cfg(not(feature = "turmoil"))]
    async fn client_handshake_with_headers_sends_custom_headers() {
        let (mut client, mut server) = duplex(4096);
        let headers = vec![
            ("ok-access-key".to_string(), "key-1".to_string()),
            ("ok-access-passphrase".to_string(), "pass-1".to_string()),
        ];

        let server_task = tokio::spawn(async move {
            let request = read_http_request(&mut server).await;
            let request = String::from_utf8(request).unwrap();

            assert!(request.starts_with("GET /ws/v5/public-sbe?instId=BTC-USDT HTTP/1.1\r\n"));
            assert_eq!(extract_header(&request, "Host"), Some("ws.okx.com:8443"));
            assert_eq!(extract_header(&request, "ok-access-key"), Some("key-1"));
            assert_eq!(
                extract_header(&request, "ok-access-passphrase"),
                Some("pass-1")
            );

            let sec_websocket_key = extract_header(&request, "Sec-WebSocket-Key").unwrap();
            let response = build_test_response(sec_websocket_key, &[]);
            server.write_all(&response).await.unwrap();
        });

        let handshake = client_handshake_with_headers(
            &mut client,
            "ws.okx.com:8443",
            "/ws/v5/public-sbe?instId=BTC-USDT",
            None,
            &headers,
        )
        .await
        .unwrap();

        assert_eq!(handshake.path, "/ws/v5/public-sbe?instId=BTC-USDT");
        assert!(handshake.leftover.is_none());
        server_task.await.unwrap();
    }

    #[rstest]
    #[cfg(not(feature = "turmoil"))]
    #[case::host("Host")]
    #[case::upgrade("Upgrade")]
    #[case::connection("Connection")]
    #[case::sec_websocket_key("Sec-WebSocket-Key")]
    #[case::sec_websocket_version("Sec-WebSocket-Version")]
    #[case::sec_websocket_protocol("Sec-WebSocket-Protocol")]
    #[case::sec_websocket_extensions("Sec-WebSocket-Extensions")]
    #[case::content_length("Content-Length")]
    #[case::transfer_encoding("Transfer-Encoding")]
    #[case::te("TE")]
    #[case::trailer("Trailer")]
    fn validate_extra_header_rejects_reserved_upgrade_headers(#[case] name: &str) {
        let err = validate_extra_header(name, "value").unwrap_err();

        assert!(matches!(
            err,
            SockudoError::InvalidHttp("reserved upgrade header not allowed in extra_headers")
        ));
    }

    #[tokio::test]
    #[cfg(not(feature = "turmoil"))]
    async fn client_handshake_with_headers_rejects_missing_accept() {
        let (mut client, mut server) = duplex(4096);

        let server_task = tokio::spawn(async move {
            let _request = read_http_request(&mut server).await;
            server
                .write_all(
                    b"HTTP/1.1 101 Switching Protocols\r\n\
                      Upgrade: websocket\r\n\
                      Connection: Upgrade\r\n\
                      \r\n",
                )
                .await
                .unwrap();
        });

        let err = client_handshake_with_headers(&mut client, "example.com", "/ws", None, &[])
            .await
            .unwrap_err();

        assert!(matches!(
            err,
            SockudoError::HandshakeFailed("missing Sec-WebSocket-Accept")
        ));
        server_task.await.unwrap();
    }

    #[tokio::test]
    #[cfg(not(feature = "turmoil"))]
    async fn client_handshake_with_headers_returns_leftover_bytes() {
        let (mut client, mut server) = duplex(4096);
        // Bytes written right behind the response head; they must come back as leftover.
        let extra = b"\x81\x05hello";

        let server_task = tokio::spawn(async move {
            let request = read_http_request(&mut server).await;
            let request = String::from_utf8(request).unwrap();
            let sec_websocket_key = extract_header(&request, "Sec-WebSocket-Key").unwrap();
            let response = build_test_response(sec_websocket_key, extra);
            server.write_all(&response).await.unwrap();
        });

        let handshake = client_handshake_with_headers(&mut client, "example.com", "/ws", None, &[])
            .await
            .unwrap();

        assert_eq!(handshake.leftover.as_deref(), Some(extra.as_slice()));
        server_task.await.unwrap();
    }

    #[tokio::test]
    #[cfg(not(feature = "turmoil"))]
    async fn prefixed_io_replays_leftover_before_socket() {
        let (client, mut server) = duplex(4096);
        let mut prefixed = PrefixedIo::new(client, Bytes::from_static(b"abc"));

        let server_task = tokio::spawn(async move {
            server.write_all(b"def").await.unwrap();
        });

        let mut buf = [0u8; 6];
        prefixed.read_exact(&mut buf).await.unwrap();

        assert_eq!(&buf, b"abcdef");
        server_task.await.unwrap();
    }

    #[rstest]
    fn round_trip_text() {
        let original = SockudoMessage::Text(Bytes::from_static(b"hello"));
        let neutral: Message = original.into();
        assert!(neutral.is_text());
        assert_eq!(neutral.as_bytes(), b"hello");

        let back: SockudoMessage = neutral.into();
        match back {
            SockudoMessage::Text(b) => assert_eq!(&b[..], b"hello"),
            other => panic!("expected text, was {other:?}"),
        }
    }

    #[rstest]
    fn round_trip_binary() {
        let original = SockudoMessage::Binary(Bytes::from_static(&[1, 2, 3]));
        let neutral: Message = original.into();
        assert_eq!(neutral.as_bytes(), &[1, 2, 3]);

        let back: SockudoMessage = neutral.into();
        match back {
            SockudoMessage::Binary(b) => assert_eq!(&b[..], &[1, 2, 3]),
            other => panic!("expected binary, was {other:?}"),
        }
    }

    #[rstest]
    fn round_trip_ping_pong() {
        let neutral: Message = SockudoMessage::Ping(Bytes::from_static(b"p")).into();
        assert!(neutral.is_ping());

        // Convert back as well, so the variant and payload are checked in both directions.
        let back: SockudoMessage = neutral.into();
        match back {
            SockudoMessage::Ping(b) => assert_eq!(&b[..], b"p"),
            other => panic!("expected ping, was {other:?}"),
        }

        let neutral: Message = SockudoMessage::Pong(Bytes::from_static(b"q")).into();
        assert!(neutral.is_pong());

        let back: SockudoMessage = neutral.into();
        match back {
            SockudoMessage::Pong(b) => assert_eq!(&b[..], b"q"),
            other => panic!("expected pong, was {other:?}"),
        }
    }

    #[rstest]
    fn close_frame_round_trip() {
        let original = SockudoMessage::Close(Some(SockudoCloseReason {
            code: 1000,
            reason: "bye".into(),
        }));
        let neutral: Message = original.into();
        let Message::Close(Some(frame)) = &neutral else {
            panic!("expected close frame");
        };
        assert_eq!(frame.code, 1000);
        assert_eq!(frame.reason, "bye");

        let back: SockudoMessage = neutral.into();
        let SockudoMessage::Close(Some(reason)) = back else {
            panic!("expected close frame");
        };
        assert_eq!(reason.code, 1000);
        assert_eq!(reason.reason, "bye");
    }

    #[rstest]
    fn error_translation_closed() {
        let err: TransportError = SockudoError::ConnectionClosed.into();
        assert!(matches!(err, TransportError::ConnectionClosed));
    }

    #[rstest]
    fn error_translation_utf8() {
        let err: TransportError = SockudoError::InvalidUtf8.into();
        assert!(matches!(err, TransportError::InvalidUtf8));
    }

    #[rstest]
    fn error_translation_handshake() {
        let err: TransportError = SockudoError::HandshakeFailed("bad").into();
        assert!(matches!(err, TransportError::Handshake(_)));
    }
}