nautilus_network/transport/
tungstenite.rs1use std::{
27 pin::Pin,
28 task::{Context, Poll},
29};
30
31use bytes::Bytes;
32use futures::{Sink, Stream};
33use tokio::io::{AsyncRead, AsyncWrite};
34use tokio_tungstenite::{
35 WebSocketStream,
36 tungstenite::{
37 self, Utf8Bytes,
38 protocol::{CloseFrame as TgCloseFrame, frame::coding::CloseCode},
39 },
40};
41
42use super::{
43 error::TransportError,
44 message::{CloseFrame, Message},
45 stream::WsTransport,
46};
47
48impl From<tungstenite::Message> for Message {
49 fn from(value: tungstenite::Message) -> Self {
50 match value {
51 tungstenite::Message::Text(text) => Self::Text(Bytes::from(text)),
52 tungstenite::Message::Binary(data) => Self::Binary(data),
53 tungstenite::Message::Ping(data) => Self::Ping(data),
54 tungstenite::Message::Pong(data) => Self::Pong(data),
55 tungstenite::Message::Close(frame) => Self::Close(frame.map(Into::into)),
56
57 tungstenite::Message::Frame(frame) => Self::Binary(frame.into_payload()),
59 }
60 }
61}
62
63impl TryFrom<Message> for tungstenite::Message {
64 type Error = TransportError;
65
66 fn try_from(value: Message) -> Result<Self, Self::Error> {
77 Ok(match value {
78 Message::Text(bytes) => match Utf8Bytes::try_from(bytes) {
79 Ok(text) => Self::Text(text),
80 Err(_) => return Err(TransportError::InvalidUtf8),
81 },
82 Message::Binary(bytes) => Self::Binary(bytes),
83 Message::Ping(bytes) => Self::Ping(bytes),
84 Message::Pong(bytes) => Self::Pong(bytes),
85 Message::Close(frame) => Self::Close(frame.map(Into::into)),
86 })
87 }
88}
89
90impl From<TgCloseFrame> for CloseFrame {
91 fn from(value: TgCloseFrame) -> Self {
92 Self {
93 code: u16::from(value.code),
94 reason: value.reason.as_str().to_owned(),
95 }
96 }
97}
98
99impl From<CloseFrame> for TgCloseFrame {
100 fn from(value: CloseFrame) -> Self {
101 Self {
102 code: CloseCode::from(value.code),
103 reason: value.reason.into(),
104 }
105 }
106}
107
108impl From<tungstenite::Error> for TransportError {
109 fn from(value: tungstenite::Error) -> Self {
110 match value {
111 tungstenite::Error::ConnectionClosed | tungstenite::Error::AlreadyClosed => {
112 Self::ConnectionClosed
113 }
114 tungstenite::Error::Io(e) => Self::Io(e),
115 tungstenite::Error::Tls(e) => Self::Tls(e.to_string()),
116 tungstenite::Error::Capacity(e) => match e {
117 tungstenite::error::CapacityError::MessageTooLong { .. } => Self::MessageTooLarge,
118 e @ tungstenite::error::CapacityError::TooManyHeaders => Self::Other(e.to_string()),
119 },
120 tungstenite::Error::Protocol(e) => Self::Protocol(e.to_string()),
121 tungstenite::Error::Utf8(_) => Self::InvalidUtf8,
122 tungstenite::Error::Url(e) => Self::InvalidUrl(e.to_string()),
123 tungstenite::Error::Http(resp) => {
124 Self::Handshake(format!("HTTP status {}", resp.status()))
125 }
126 tungstenite::Error::HttpFormat(e) => Self::Handshake(e.to_string()),
127 other => Self::Other(other.to_string()),
128 }
129 }
130}
131
132#[derive(Debug)]
139pub struct TungsteniteTransport<S> {
140 inner: WebSocketStream<S>,
141}
142
143impl<S> TungsteniteTransport<S> {
144 #[inline]
146 #[must_use]
147 pub const fn new(inner: WebSocketStream<S>) -> Self {
148 Self { inner }
149 }
150
151 #[inline]
153 pub fn into_inner(self) -> WebSocketStream<S> {
154 self.inner
155 }
156
157 #[inline]
159 pub const fn get_ref(&self) -> &WebSocketStream<S> {
160 &self.inner
161 }
162}
163
164impl<S> Stream for TungsteniteTransport<S>
165where
166 S: AsyncRead + AsyncWrite + Unpin,
167{
168 type Item = Result<Message, TransportError>;
169
170 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
171 match Pin::new(&mut self.inner).poll_next(cx) {
172 Poll::Ready(Some(Ok(msg))) => Poll::Ready(Some(Ok(Message::from(msg)))),
173 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(TransportError::from(e)))),
174 Poll::Ready(None) => Poll::Ready(None),
175 Poll::Pending => Poll::Pending,
176 }
177 }
178}
179
180impl<S> Sink<Message> for TungsteniteTransport<S>
181where
182 S: AsyncRead + AsyncWrite + Unpin,
183{
184 type Error = TransportError;
185
186 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
187 Pin::new(&mut self.inner)
188 .poll_ready(cx)
189 .map_err(TransportError::from)
190 }
191
192 fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
193 let native = tungstenite::Message::try_from(item)?;
194 Pin::new(&mut self.inner)
195 .start_send(native)
196 .map_err(TransportError::from)
197 }
198
199 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
200 Pin::new(&mut self.inner)
201 .poll_flush(cx)
202 .map_err(TransportError::from)
203 }
204
205 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
206 Pin::new(&mut self.inner)
207 .poll_close(cx)
208 .map_err(TransportError::from)
209 }
210}
211
212const _: fn() = || {
213 fn assert_ws_transport<T: WsTransport>() {}
214 assert_ws_transport::<TungsteniteTransport<tokio::net::TcpStream>>();
215};
216
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use rstest::rstest;
    use tokio_tungstenite::tungstenite::{self, Utf8Bytes};

    use super::*;

    #[rstest]
    fn round_trip_text() {
        let original = tungstenite::Message::Text(Utf8Bytes::from("hello"));
        let neutral: Message = original.into();
        assert!(neutral.is_text());
        assert_eq!(neutral.as_bytes(), b"hello");

        let back = tungstenite::Message::try_from(neutral).unwrap();
        match back {
            tungstenite::Message::Text(t) => assert_eq!(t.as_str(), "hello"),
            other => panic!("expected text, was {other:?}"),
        }
    }

    #[rstest]
    fn try_from_text_rejects_invalid_utf8() {
        let neutral = Message::Text(Bytes::from_static(&[0xFF, 0xFE]));
        let err = tungstenite::Message::try_from(neutral).unwrap_err();
        assert!(matches!(err, TransportError::InvalidUtf8));
    }

    #[rstest]
    fn round_trip_binary() {
        let original = tungstenite::Message::Binary(Bytes::from_static(&[1, 2, 3]));
        let neutral: Message = original.into();
        assert_eq!(neutral.as_bytes(), &[1, 2, 3]);

        let back = tungstenite::Message::try_from(neutral).unwrap();
        match back {
            tungstenite::Message::Binary(b) => assert_eq!(&b[..], &[1, 2, 3]),
            other => panic!("expected binary, was {other:?}"),
        }
    }

    // Exercises both directions so the test matches its name. The payload must survive the trip
    // unchanged.
    #[rstest]
    fn round_trip_ping_pong() {
        let ping = tungstenite::Message::Ping(Bytes::from_static(b"p"));
        let neutral: Message = ping.into();
        assert!(neutral.is_ping());
        match tungstenite::Message::try_from(neutral).unwrap() {
            tungstenite::Message::Ping(p) => assert_eq!(&p[..], b"p"),
            other => panic!("expected ping, was {other:?}"),
        }

        let pong = tungstenite::Message::Pong(Bytes::from_static(b"q"));
        let neutral: Message = pong.into();
        assert!(neutral.is_pong());
        match tungstenite::Message::try_from(neutral).unwrap() {
            tungstenite::Message::Pong(p) => assert_eq!(&p[..], b"q"),
            other => panic!("expected pong, was {other:?}"),
        }
    }

    #[rstest]
    fn close_frame_round_trip() {
        let original = tungstenite::Message::Close(Some(TgCloseFrame {
            code: CloseCode::Normal,
            reason: "bye".into(),
        }));
        let neutral: Message = original.into();
        let Message::Close(Some(frame)) = &neutral else {
            panic!("expected close frame");
        };
        assert_eq!(frame.code, 1000);
        assert_eq!(frame.reason, "bye");

        let back = tungstenite::Message::try_from(neutral).unwrap();
        let tungstenite::Message::Close(Some(frame)) = back else {
            panic!("expected close frame");
        };
        assert_eq!(u16::from(frame.code), 1000);
        assert_eq!(frame.reason.as_str(), "bye");
    }

    // A close without a frame payload must stay payload-less in both directions.
    #[rstest]
    fn close_without_frame_round_trip() {
        let neutral: Message = tungstenite::Message::Close(None).into();
        assert!(matches!(neutral, Message::Close(None)));

        let back = tungstenite::Message::try_from(neutral).unwrap();
        assert!(matches!(back, tungstenite::Message::Close(None)));
    }

    #[rstest]
    #[case::connection_closed(tungstenite::Error::ConnectionClosed)]
    #[case::already_closed(tungstenite::Error::AlreadyClosed)]
    fn error_translation_closed(#[case] source: tungstenite::Error) {
        let err: TransportError = source.into();
        assert!(matches!(err, TransportError::ConnectionClosed));
    }

    #[rstest]
    fn error_translation_utf8() {
        let err: TransportError = tungstenite::Error::Utf8(String::from("bad")).into();
        assert!(matches!(err, TransportError::InvalidUtf8));
    }

    #[rstest]
    fn error_translation_message_too_long() {
        let err: TransportError = tungstenite::Error::Capacity(
            tungstenite::error::CapacityError::MessageTooLong {
                size: 2,
                max_size: 1,
            },
        )
        .into();
        assert!(matches!(err, TransportError::MessageTooLarge));
    }

    #[rstest]
    fn error_translation_too_many_headers() {
        let err: TransportError =
            tungstenite::Error::Capacity(tungstenite::error::CapacityError::TooManyHeaders).into();
        assert!(matches!(err, TransportError::Other(_)));
    }
}