Skip to main content

nautilus_network/transport/
tungstenite.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! `tokio-tungstenite` backend for the transport abstraction.
17//!
18//! Provides `From` conversions between the neutral [`Message`] and
19//! [`TransportError`] types and tungstenite's native types, plus the
20//! [`TungsteniteTransport<S>`] adapter that lifts a tungstenite
21//! `WebSocketStream<S>` into a backend-agnostic [`WsTransport`].
22//!
//! The message payload conversions are structural (no copies): tungstenite
//! stores payloads in `Bytes` and `Utf8Bytes`, which we re-wrap directly.
//! Close-frame reasons are the exception: they are copied into an owned `String`.
25
26use std::{
27    pin::Pin,
28    task::{Context, Poll},
29};
30
31use bytes::Bytes;
32use futures::{Sink, Stream};
33use tokio::io::{AsyncRead, AsyncWrite};
34use tokio_tungstenite::{
35    WebSocketStream,
36    tungstenite::{
37        self, Utf8Bytes,
38        protocol::{CloseFrame as TgCloseFrame, frame::coding::CloseCode},
39    },
40};
41
42use super::{
43    error::TransportError,
44    message::{CloseFrame, Message},
45    stream::WsTransport,
46};
47
48impl From<tungstenite::Message> for Message {
49    fn from(value: tungstenite::Message) -> Self {
50        match value {
51            tungstenite::Message::Text(text) => Self::Text(Bytes::from(text)),
52            tungstenite::Message::Binary(data) => Self::Binary(data),
53            tungstenite::Message::Ping(data) => Self::Ping(data),
54            tungstenite::Message::Pong(data) => Self::Pong(data),
55            tungstenite::Message::Close(frame) => Self::Close(frame.map(Into::into)),
56
57            // Tungstenite only emits Frame when explicitly constructed; treat as binary
58            tungstenite::Message::Frame(frame) => Self::Binary(frame.into_payload()),
59        }
60    }
61}
62
63impl TryFrom<Message> for tungstenite::Message {
64    type Error = TransportError;
65
66    /// Convert a neutral [`Message`] into a tungstenite [`tungstenite::Message`].
67    ///
68    /// Validates the `Text` payload as UTF-8 because tungstenite refuses to
69    /// transmit a Text frame whose body is not valid UTF-8. Other variants
70    /// are infallible.
71    ///
72    /// # Errors
73    ///
74    /// Returns [`TransportError::InvalidUtf8`] if a `Text` payload is not
75    /// valid UTF-8.
76    fn try_from(value: Message) -> Result<Self, Self::Error> {
77        Ok(match value {
78            Message::Text(bytes) => match Utf8Bytes::try_from(bytes) {
79                Ok(text) => Self::Text(text),
80                Err(_) => return Err(TransportError::InvalidUtf8),
81            },
82            Message::Binary(bytes) => Self::Binary(bytes),
83            Message::Ping(bytes) => Self::Ping(bytes),
84            Message::Pong(bytes) => Self::Pong(bytes),
85            Message::Close(frame) => Self::Close(frame.map(Into::into)),
86        })
87    }
88}
89
90impl From<TgCloseFrame> for CloseFrame {
91    fn from(value: TgCloseFrame) -> Self {
92        Self {
93            code: u16::from(value.code),
94            reason: value.reason.as_str().to_owned(),
95        }
96    }
97}
98
99impl From<CloseFrame> for TgCloseFrame {
100    fn from(value: CloseFrame) -> Self {
101        Self {
102            code: CloseCode::from(value.code),
103            reason: value.reason.into(),
104        }
105    }
106}
107
108impl From<tungstenite::Error> for TransportError {
109    fn from(value: tungstenite::Error) -> Self {
110        match value {
111            tungstenite::Error::ConnectionClosed | tungstenite::Error::AlreadyClosed => {
112                Self::ConnectionClosed
113            }
114            tungstenite::Error::Io(e) => Self::Io(e),
115            tungstenite::Error::Tls(e) => Self::Tls(e.to_string()),
116            tungstenite::Error::Capacity(e) => match e {
117                tungstenite::error::CapacityError::MessageTooLong { .. } => Self::MessageTooLarge,
118                e @ tungstenite::error::CapacityError::TooManyHeaders => Self::Other(e.to_string()),
119            },
120            tungstenite::Error::Protocol(e) => Self::Protocol(e.to_string()),
121            tungstenite::Error::Utf8(_) => Self::InvalidUtf8,
122            tungstenite::Error::Url(e) => Self::InvalidUrl(e.to_string()),
123            tungstenite::Error::Http(resp) => {
124                Self::Handshake(format!("HTTP status {}", resp.status()))
125            }
126            tungstenite::Error::HttpFormat(e) => Self::Handshake(e.to_string()),
127            other => Self::Other(other.to_string()),
128        }
129    }
130}
131
132/// Adapter that lifts a `tokio-tungstenite` [`WebSocketStream<S>`] into a
133/// backend-agnostic [`WsTransport`].
134///
135/// Translates messages and errors to the neutral types on the way through
136/// `Stream::poll_next` and `Sink<Message>::start_send` / `poll_*`. The
137/// underlying stream is owned and forwarded to via pin projection.
138#[derive(Debug)]
139pub struct TungsteniteTransport<S> {
140    inner: WebSocketStream<S>,
141}
142
143impl<S> TungsteniteTransport<S> {
144    /// Wrap an established tungstenite WebSocket stream.
145    #[inline]
146    #[must_use]
147    pub const fn new(inner: WebSocketStream<S>) -> Self {
148        Self { inner }
149    }
150
151    /// Consume the adapter and return the underlying stream.
152    #[inline]
153    pub fn into_inner(self) -> WebSocketStream<S> {
154        self.inner
155    }
156
157    /// Borrow the underlying stream.
158    #[inline]
159    pub const fn get_ref(&self) -> &WebSocketStream<S> {
160        &self.inner
161    }
162}
163
164impl<S> Stream for TungsteniteTransport<S>
165where
166    S: AsyncRead + AsyncWrite + Unpin,
167{
168    type Item = Result<Message, TransportError>;
169
170    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
171        match Pin::new(&mut self.inner).poll_next(cx) {
172            Poll::Ready(Some(Ok(msg))) => Poll::Ready(Some(Ok(Message::from(msg)))),
173            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(TransportError::from(e)))),
174            Poll::Ready(None) => Poll::Ready(None),
175            Poll::Pending => Poll::Pending,
176        }
177    }
178}
179
180impl<S> Sink<Message> for TungsteniteTransport<S>
181where
182    S: AsyncRead + AsyncWrite + Unpin,
183{
184    type Error = TransportError;
185
186    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
187        Pin::new(&mut self.inner)
188            .poll_ready(cx)
189            .map_err(TransportError::from)
190    }
191
192    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
193        let native = tungstenite::Message::try_from(item)?;
194        Pin::new(&mut self.inner)
195            .start_send(native)
196            .map_err(TransportError::from)
197    }
198
199    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
200        Pin::new(&mut self.inner)
201            .poll_flush(cx)
202            .map_err(TransportError::from)
203    }
204
205    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
206        Pin::new(&mut self.inner)
207            .poll_close(cx)
208            .map_err(TransportError::from)
209    }
210}
211
212const _: fn() = || {
213    fn assert_ws_transport<T: WsTransport>() {}
214    assert_ws_transport::<TungsteniteTransport<tokio::net::TcpStream>>();
215};
216
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use rstest::rstest;
    use tokio_tungstenite::tungstenite::{self, Utf8Bytes};

    use super::*;

    #[rstest]
    fn round_trip_text() {
        let original = tungstenite::Message::Text(Utf8Bytes::from("hello"));
        let neutral: Message = original.into();
        assert!(neutral.is_text());
        assert_eq!(neutral.as_bytes(), b"hello");

        let back = tungstenite::Message::try_from(neutral).unwrap();
        match back {
            tungstenite::Message::Text(t) => assert_eq!(t.as_str(), "hello"),
            other => panic!("expected text, was {other:?}"),
        }
    }

    #[rstest]
    fn try_from_text_rejects_invalid_utf8() {
        let neutral = Message::Text(Bytes::from_static(&[0xFF, 0xFE]));
        let err = tungstenite::Message::try_from(neutral).unwrap_err();
        assert!(matches!(err, TransportError::InvalidUtf8));
    }

    #[rstest]
    fn round_trip_binary() {
        let original = tungstenite::Message::Binary(Bytes::from_static(&[1, 2, 3]));
        let neutral: Message = original.into();
        assert_eq!(neutral.as_bytes(), &[1, 2, 3]);

        let back = tungstenite::Message::try_from(neutral).unwrap();
        match back {
            tungstenite::Message::Binary(b) => assert_eq!(&b[..], &[1, 2, 3]),
            other => panic!("expected binary, was {other:?}"),
        }
    }

    #[rstest]
    fn round_trip_ping_pong() {
        // Ping: neutral variant is preserved and the payload survives the trip back.
        let ping = tungstenite::Message::Ping(Bytes::from_static(b"p"));
        let neutral: Message = ping.into();
        assert!(neutral.is_ping());

        let back = tungstenite::Message::try_from(neutral).unwrap();
        match back {
            tungstenite::Message::Ping(b) => assert_eq!(&b[..], b"p"),
            other => panic!("expected ping, was {other:?}"),
        }

        // Pong: same checks.
        let pong = tungstenite::Message::Pong(Bytes::from_static(b"q"));
        let neutral: Message = pong.into();
        assert!(neutral.is_pong());

        let back = tungstenite::Message::try_from(neutral).unwrap();
        match back {
            tungstenite::Message::Pong(b) => assert_eq!(&b[..], b"q"),
            other => panic!("expected pong, was {other:?}"),
        }
    }

    #[rstest]
    fn close_frame_round_trip() {
        let original = tungstenite::Message::Close(Some(TgCloseFrame {
            code: CloseCode::Normal,
            reason: "bye".into(),
        }));
        let neutral: Message = original.into();
        let Message::Close(Some(frame)) = &neutral else {
            panic!("expected close frame");
        };
        assert_eq!(frame.code, 1000);
        assert_eq!(frame.reason, "bye");

        let back = tungstenite::Message::try_from(neutral).unwrap();
        let tungstenite::Message::Close(Some(frame)) = back else {
            panic!("expected close frame");
        };
        assert_eq!(u16::from(frame.code), 1000);
        assert_eq!(frame.reason.as_str(), "bye");
    }

    #[rstest]
    fn close_without_frame_round_trip() {
        let neutral: Message = tungstenite::Message::Close(None).into();
        assert!(matches!(neutral, Message::Close(None)));

        let back = tungstenite::Message::try_from(neutral).unwrap();
        assert!(matches!(back, tungstenite::Message::Close(None)));
    }

    #[rstest]
    fn error_translation_closed() {
        let err: TransportError = tungstenite::Error::ConnectionClosed.into();
        assert!(matches!(err, TransportError::ConnectionClosed));
    }

    #[rstest]
    fn error_translation_already_closed() {
        let err: TransportError = tungstenite::Error::AlreadyClosed.into();
        assert!(matches!(err, TransportError::ConnectionClosed));
    }

    #[rstest]
    fn error_translation_message_too_long() {
        let err: TransportError = tungstenite::Error::Capacity(
            tungstenite::error::CapacityError::MessageTooLong {
                size: 2,
                max_size: 1,
            },
        )
        .into();
        assert!(matches!(err, TransportError::MessageTooLarge));
    }

    #[rstest]
    fn error_translation_utf8() {
        let err: TransportError = tungstenite::Error::Utf8(String::from("bad")).into();
        assert!(matches!(err, TransportError::InvalidUtf8));
    }
}