1use std::{
30 pin::Pin,
31 task::{Context, Poll},
32};
33
34use bytes::{BufMut, Bytes, BytesMut};
35use futures::{Sink, Stream};
36use sockudo_ws::{
37 HandshakeResult,
38 error::{CloseReason as SockudoCloseReason, Error as SockudoError},
39 handshake,
40 protocol::Message as SockudoMessage,
41 stream::WebSocketStream,
42};
43use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
44
45use super::{
46 error::TransportError,
47 message::{CloseFrame, Message},
48 stream::WsTransport,
49};
50
/// Upper bound (8 KiB) on how much of the server's handshake response is
/// buffered while waiting for the end of its header section.
const MAX_HTTP_HEADER_SIZE: usize = 8 * 1024;
52
/// Lower-cased header names that callers may not supply via `extra_headers`.
///
/// These are either written by the handshake itself or would change how the
/// request is framed, so letting a caller override them would produce a
/// malformed or ambiguous upgrade request.
const RESERVED_UPGRADE_HEADERS: &[&str] = &[
    // Request target and the upgrade negotiation itself.
    "host",
    "upgrade",
    "connection",
    // WebSocket-specific handshake fields.
    "sec-websocket-key",
    "sec-websocket-version",
    "sec-websocket-protocol",
    "sec-websocket-extensions",
    // Message-body framing headers; an upgrade GET carries no body.
    "content-length",
    "transfer-encoding",
    "te",
    "trailer",
];
68
69pub(crate) async fn client_handshake_with_headers<S>(
73 stream: &mut S,
74 host: &str,
75 path: &str,
76 protocol: Option<&str>,
77 extra_headers: &[(String, String)],
78) -> Result<HandshakeResult, SockudoError>
79where
80 S: AsyncRead + AsyncWrite + Unpin,
81{
82 use tokio::io::{AsyncReadExt, AsyncWriteExt};
83
84 let key = handshake::generate_key();
85 let request = build_request_with_headers(host, path, &key, protocol, None, extra_headers);
86
87 stream.write_all(&request).await?;
88 stream.flush().await?;
89
90 let mut buf = BytesMut::with_capacity(4096);
91
92 loop {
93 if buf.len() > MAX_HTTP_HEADER_SIZE {
94 return Err(SockudoError::InvalidHttp("response too large"));
95 }
96
97 let n = stream.read_buf(&mut buf).await?;
98 if n == 0 {
99 return Err(SockudoError::ConnectionClosed);
100 }
101
102 let parsed = match handshake::parse_response(&buf) {
103 Ok(parsed) => parsed,
104 Err(e) => {
105 log_handshake_response(host, path, &e, &buf);
106 return Err(e);
107 }
108 };
109
110 if let Some((res, consumed)) = parsed {
111 let accept = res.accept.ok_or_else(|| {
112 let e = SockudoError::HandshakeFailed("missing Sec-WebSocket-Accept");
113 log_handshake_response(host, path, &e, &buf);
114 e
115 })?;
116
117 if !handshake::validate_accept_key(&key, accept) {
118 let e = SockudoError::HandshakeFailed("invalid Sec-WebSocket-Accept");
119 log_handshake_response(host, path, &e, &buf);
120 return Err(e);
121 }
122
123 let res_protocol = res.protocol.map(String::from);
124 let res_extensions = res.extensions.map(String::from);
125 let leftover = if consumed < buf.len() {
126 Some(buf.split_off(consumed).freeze())
127 } else {
128 None
129 };
130
131 return Ok(HandshakeResult {
132 path: path.to_string(),
133 protocol: res_protocol,
134 extensions: res_extensions,
135 leftover,
136 });
137 }
138 }
139}
140
141fn log_handshake_response(host: &str, path: &str, err: &SockudoError, buf: &BytesMut) {
143 const PREVIEW_BYTES: usize = 512;
144 let take = buf.len().min(PREVIEW_BYTES);
145 let preview = String::from_utf8_lossy(&buf[..take]);
146 let truncated = if buf.len() > take { " (truncated)" } else { "" };
147 log::error!(
148 "Sockudo handshake failed for {host}{path}: {err}; response{truncated}:\n{preview}"
149 );
150}
151
152fn build_request_with_headers(
155 host: &str,
156 path: &str,
157 key: &str,
158 protocol: Option<&str>,
159 extensions: Option<&str>,
160 extra_headers: &[(String, String)],
161) -> Bytes {
162 let mut buf = BytesMut::with_capacity(512);
163
164 buf.put_slice(b"GET ");
165 buf.put_slice(path.as_bytes());
166 buf.put_slice(b" HTTP/1.1\r\n");
167 buf.put_slice(b"Host: ");
168 buf.put_slice(host.as_bytes());
169 buf.put_slice(b"\r\n");
170 buf.put_slice(b"Upgrade: websocket\r\n");
171 buf.put_slice(b"Connection: Upgrade\r\n");
172 buf.put_slice(b"Sec-WebSocket-Key: ");
173 buf.put_slice(key.as_bytes());
174 buf.put_slice(b"\r\n");
175 buf.put_slice(b"Sec-WebSocket-Version: 13\r\n");
176
177 if let Some(proto) = protocol {
178 buf.put_slice(b"Sec-WebSocket-Protocol: ");
179 buf.put_slice(proto.as_bytes());
180 buf.put_slice(b"\r\n");
181 }
182
183 if let Some(ext) = extensions {
184 buf.put_slice(b"Sec-WebSocket-Extensions: ");
185 buf.put_slice(ext.as_bytes());
186 buf.put_slice(b"\r\n");
187 }
188
189 for (name, value) in extra_headers {
190 buf.put_slice(name.as_bytes());
191 buf.put_slice(b": ");
192 buf.put_slice(value.as_bytes());
193 buf.put_slice(b"\r\n");
194 }
195
196 buf.put_slice(b"\r\n");
197 buf.freeze()
198}
199
200pub(crate) fn validate_extra_headers(headers: &[(String, String)]) -> Result<(), SockudoError> {
201 for (name, value) in headers {
202 validate_extra_header(name, value)?;
203 }
204 Ok(())
205}
206
207fn validate_extra_header(name: &str, value: &str) -> Result<(), SockudoError> {
208 let parsed_name = name
209 .parse::<http::HeaderName>()
210 .map_err(|_| SockudoError::InvalidHttp("invalid header name"))?;
211
212 if RESERVED_UPGRADE_HEADERS.contains(&parsed_name.as_str()) {
213 return Err(SockudoError::InvalidHttp(
214 "reserved upgrade header not allowed in extra_headers",
215 ));
216 }
217
218 http::HeaderValue::from_str(value)
219 .map_err(|_| SockudoError::InvalidHttp("invalid header value"))?;
220 Ok(())
221}
222
223pub(crate) struct PrefixedIo<S> {
225 inner: S,
226 prefix: Bytes,
227}
228
229impl<S> PrefixedIo<S> {
230 pub(crate) const fn new(inner: S, prefix: Bytes) -> Self {
231 Self { inner, prefix }
232 }
233}
234
235impl<S> AsyncRead for PrefixedIo<S>
236where
237 S: AsyncRead + Unpin,
238{
239 fn poll_read(
240 mut self: Pin<&mut Self>,
241 cx: &mut Context<'_>,
242 buf: &mut ReadBuf<'_>,
243 ) -> Poll<std::io::Result<()>> {
244 if !self.prefix.is_empty() {
245 let n = self.prefix.len().min(buf.remaining());
246 let chunk = self.prefix.split_to(n);
247 buf.put_slice(&chunk);
248 return Poll::Ready(Ok(()));
249 }
250
251 Pin::new(&mut self.inner).poll_read(cx, buf)
252 }
253}
254
255impl<S> AsyncWrite for PrefixedIo<S>
256where
257 S: AsyncWrite + Unpin,
258{
259 fn poll_write(
260 mut self: Pin<&mut Self>,
261 cx: &mut Context<'_>,
262 buf: &[u8],
263 ) -> Poll<std::io::Result<usize>> {
264 Pin::new(&mut self.inner).poll_write(cx, buf)
265 }
266
267 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
268 Pin::new(&mut self.inner).poll_flush(cx)
269 }
270
271 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
272 Pin::new(&mut self.inner).poll_shutdown(cx)
273 }
274}
275
276impl From<SockudoMessage> for Message {
277 fn from(value: SockudoMessage) -> Self {
278 match value {
279 SockudoMessage::Text(b) => Self::Text(b),
280 SockudoMessage::Binary(b) => Self::Binary(b),
281 SockudoMessage::Ping(b) => Self::Ping(b),
282 SockudoMessage::Pong(b) => Self::Pong(b),
283 SockudoMessage::Close(reason) => Self::Close(reason.map(Into::into)),
284 }
285 }
286}
287
288impl From<Message> for SockudoMessage {
289 fn from(value: Message) -> Self {
296 match value {
297 Message::Text(b) => Self::Text(b),
298 Message::Binary(b) => Self::Binary(b),
299 Message::Ping(b) => Self::Ping(b),
300 Message::Pong(b) => Self::Pong(b),
301 Message::Close(frame) => Self::Close(frame.map(Into::into)),
302 }
303 }
304}
305
306impl From<SockudoCloseReason> for CloseFrame {
307 fn from(value: SockudoCloseReason) -> Self {
308 Self {
309 code: value.code,
310 reason: value.reason,
311 }
312 }
313}
314
315impl From<CloseFrame> for SockudoCloseReason {
316 fn from(value: CloseFrame) -> Self {
317 Self {
318 code: value.code,
319 reason: value.reason,
320 }
321 }
322}
323
324impl From<SockudoError> for TransportError {
325 fn from(value: SockudoError) -> Self {
326 match value {
327 SockudoError::Io(e) => Self::Io(e),
328 SockudoError::ConnectionClosed => Self::ConnectionClosed,
329 SockudoError::ConnectionReset => Self::ConnectionReset,
330 SockudoError::Closed(reason) => Self::ClosedByPeer(reason.map(Into::into)),
331 SockudoError::MessageTooLarge => Self::MessageTooLarge,
332 SockudoError::FrameTooLarge => Self::FrameTooLarge,
333 SockudoError::InvalidUtf8 => Self::InvalidUtf8,
334 SockudoError::InvalidFrame(msg) | SockudoError::Protocol(msg) => {
335 Self::Protocol(msg.to_string())
336 }
337 SockudoError::InvalidHttp(msg) | SockudoError::HandshakeFailed(msg) => {
338 Self::Handshake(msg.to_string())
339 }
340 other => Self::Other(other.to_string()),
341 }
342 }
343}
344
345pub struct SockudoTransport<S> {
352 inner: WebSocketStream<S>,
353 pending_flush: bool,
358}
359
360impl<S> SockudoTransport<S> {
361 #[inline]
363 #[must_use]
364 pub const fn new(inner: WebSocketStream<S>) -> Self {
365 Self {
366 inner,
367 pending_flush: false,
368 }
369 }
370
371 #[inline]
373 pub fn into_inner(self) -> WebSocketStream<S> {
374 self.inner
375 }
376
377 #[inline]
379 pub const fn get_ref(&self) -> &WebSocketStream<S> {
380 &self.inner
381 }
382}
383
384impl<S> std::fmt::Debug for SockudoTransport<S> {
385 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
386 f.debug_struct(stringify!(SockudoTransport))
387 .finish_non_exhaustive()
388 }
389}
390
391impl<S> Stream for SockudoTransport<S>
392where
393 S: AsyncRead + AsyncWrite + Unpin,
394{
395 type Item = Result<Message, TransportError>;
396
397 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
398 if self.pending_flush {
403 match Pin::new(&mut self.inner).poll_flush(cx) {
404 Poll::Ready(_) => self.pending_flush = false,
405 Poll::Pending => {}
406 }
407 }
408
409 let result = match Pin::new(&mut self.inner).poll_next(cx) {
410 Poll::Ready(Some(Ok(msg))) => Poll::Ready(Some(Ok(Message::from(msg)))),
411 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(TransportError::from(e)))),
412 Poll::Ready(None) => Poll::Ready(None),
413 Poll::Pending => return Poll::Pending,
414 };
415
416 match Pin::new(&mut self.inner).poll_flush(cx) {
421 Poll::Ready(_) => self.pending_flush = false,
422 Poll::Pending => self.pending_flush = true,
423 }
424
425 result
426 }
427}
428
429impl<S> Sink<Message> for SockudoTransport<S>
430where
431 S: AsyncRead + AsyncWrite + Unpin,
432{
433 type Error = TransportError;
434
435 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
436 Pin::new(&mut self.inner)
437 .poll_ready(cx)
438 .map_err(TransportError::from)
439 }
440
441 fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
442 Pin::new(&mut self.inner)
443 .start_send(SockudoMessage::from(item))
444 .map_err(TransportError::from)
445 }
446
447 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
448 Pin::new(&mut self.inner)
449 .poll_flush(cx)
450 .map_err(TransportError::from)
451 }
452
453 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
454 Pin::new(&mut self.inner)
455 .poll_close(cx)
456 .map_err(TransportError::from)
457 }
458}
459
460const _: fn() = || {
461 fn assert_ws_transport<T: WsTransport>() {}
462 assert_ws_transport::<SockudoTransport<tokio::net::TcpStream>>();
463};
464
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use rstest::rstest;
    #[cfg(not(feature = "turmoil"))]
    use sockudo_ws::handshake::generate_accept_key;
    #[cfg(not(feature = "turmoil"))]
    use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, duplex};

    use super::*;

    // Test-server helper: reads until the request head's terminating blank
    // line (`\r\n\r\n`) has arrived and returns everything read so far.
    // Panics if the peer closes first.
    #[cfg(not(feature = "turmoil"))]
    async fn read_http_request<S>(stream: &mut S) -> Vec<u8>
    where
        S: AsyncRead + Unpin,
    {
        let mut buf = Vec::new();
        let mut chunk = [0u8; 256];

        loop {
            let n = stream.read(&mut chunk).await.unwrap();
            assert!(n > 0, "HTTP request closed before headers completed");
            buf.extend_from_slice(&chunk[..n]);
            if buf.windows(4).any(|window| window == b"\r\n\r\n") {
                return buf;
            }
        }
    }

    // Builds a minimal `101 Switching Protocols` response whose
    // `Sec-WebSocket-Accept` matches `sec_websocket_key`, followed by
    // `extra_bytes` (used to simulate data sent right after the handshake).
    #[cfg(not(feature = "turmoil"))]
    fn build_test_response(sec_websocket_key: &str, extra_bytes: &[u8]) -> Vec<u8> {
        let accept = generate_accept_key(sec_websocket_key);
        let mut response = format!(
            concat!(
                "HTTP/1.1 101 Switching Protocols\r\n",
                "Upgrade: websocket\r\n",
                "Connection: Upgrade\r\n",
                "Sec-WebSocket-Accept: {}\r\n",
                "\r\n",
            ),
            accept
        )
        .into_bytes();
        response.extend_from_slice(extra_bytes);
        response
    }

    // Returns the trimmed value of the first line whose text before the first
    // ':' matches `name` case-insensitively. Deliberately naive: every line
    // (including the request line) is considered.
    #[cfg(not(feature = "turmoil"))]
    fn extract_header<'a>(request: &'a str, name: &str) -> Option<&'a str> {
        request.lines().find_map(|line| {
            let (header_name, header_value) = line.split_once(':')?;
            if header_name.eq_ignore_ascii_case(name) {
                Some(header_value.trim())
            } else {
                None
            }
        })
    }

    // The request line, Host, and caller-supplied headers reach the server
    // unchanged, and a response with no trailing bytes yields no leftover.
    #[tokio::test]
    #[cfg(not(feature = "turmoil"))]
    async fn client_handshake_with_headers_sends_custom_headers() {
        let (mut client, mut server) = duplex(4096);
        let headers = vec![
            ("ok-access-key".to_string(), "key-1".to_string()),
            ("ok-access-passphrase".to_string(), "pass-1".to_string()),
        ];

        let server_task = tokio::spawn(async move {
            let request = read_http_request(&mut server).await;
            let request = String::from_utf8(request).unwrap();

            assert!(request.starts_with("GET /ws/v5/public-sbe?instId=BTC-USDT HTTP/1.1\r\n"));
            assert_eq!(extract_header(&request, "Host"), Some("ws.okx.com:8443"));
            assert_eq!(extract_header(&request, "ok-access-key"), Some("key-1"));
            assert_eq!(
                extract_header(&request, "ok-access-passphrase"),
                Some("pass-1")
            );

            let sec_websocket_key = extract_header(&request, "Sec-WebSocket-Key").unwrap();
            let response = build_test_response(sec_websocket_key, &[]);
            server.write_all(&response).await.unwrap();
        });

        let handshake = client_handshake_with_headers(
            &mut client,
            "ws.okx.com:8443",
            "/ws/v5/public-sbe?instId=BTC-USDT",
            None,
            &headers,
        )
        .await
        .unwrap();

        assert_eq!(handshake.path, "/ws/v5/public-sbe?instId=BTC-USDT");
        assert!(handshake.leftover.is_none());
        server_task.await.unwrap();
    }

    // Every reserved name is rejected regardless of the caller's casing.
    #[rstest]
    #[cfg(not(feature = "turmoil"))]
    #[case::host("Host")]
    #[case::upgrade("Upgrade")]
    #[case::connection("Connection")]
    #[case::sec_websocket_key("Sec-WebSocket-Key")]
    #[case::sec_websocket_version("Sec-WebSocket-Version")]
    #[case::sec_websocket_protocol("Sec-WebSocket-Protocol")]
    #[case::sec_websocket_extensions("Sec-WebSocket-Extensions")]
    #[case::content_length("Content-Length")]
    #[case::transfer_encoding("Transfer-Encoding")]
    #[case::te("TE")]
    #[case::trailer("Trailer")]
    fn validate_extra_header_rejects_reserved_upgrade_headers(#[case] name: &str) {
        let err = validate_extra_header(name, "value").unwrap_err();

        assert!(matches!(
            err,
            SockudoError::InvalidHttp("reserved upgrade header not allowed in extra_headers")
        ));
    }

    // A 101 response without `Sec-WebSocket-Accept` must fail the handshake.
    #[tokio::test]
    #[cfg(not(feature = "turmoil"))]
    async fn client_handshake_with_headers_rejects_missing_accept() {
        let (mut client, mut server) = duplex(4096);

        let server_task = tokio::spawn(async move {
            let _request = read_http_request(&mut server).await;
            server
                .write_all(
                    b"HTTP/1.1 101 Switching Protocols\r\n\
                    Upgrade: websocket\r\n\
                    Connection: Upgrade\r\n\
                    \r\n",
                )
                .await
                .unwrap();
        });

        let err = client_handshake_with_headers(&mut client, "example.com", "/ws", None, &[])
            .await
            .unwrap_err();

        assert!(matches!(
            err,
            SockudoError::HandshakeFailed("missing Sec-WebSocket-Accept")
        ));
        server_task.await.unwrap();
    }

    // Bytes written after the response head (here, a small text frame) are
    // returned as `leftover` rather than being dropped.
    #[tokio::test]
    #[cfg(not(feature = "turmoil"))]
    async fn client_handshake_with_headers_returns_leftover_bytes() {
        let (mut client, mut server) = duplex(4096);
        let extra = b"\x81\x05hello";

        let server_task = tokio::spawn(async move {
            let request = read_http_request(&mut server).await;
            let request = String::from_utf8(request).unwrap();
            let sec_websocket_key = extract_header(&request, "Sec-WebSocket-Key").unwrap();
            let response = build_test_response(sec_websocket_key, extra);
            server.write_all(&response).await.unwrap();
        });

        let handshake = client_handshake_with_headers(&mut client, "example.com", "/ws", None, &[])
            .await
            .unwrap();

        assert_eq!(handshake.leftover.as_deref(), Some(extra.as_slice()));
        server_task.await.unwrap();
    }

    // `PrefixedIo` yields its prefix before any bytes from the socket.
    #[tokio::test]
    #[cfg(not(feature = "turmoil"))]
    async fn prefixed_io_replays_leftover_before_socket() {
        let (client, mut server) = duplex(4096);
        let mut prefixed = PrefixedIo::new(client, Bytes::from_static(b"abc"));

        let server_task = tokio::spawn(async move {
            server.write_all(b"def").await.unwrap();
        });

        let mut buf = [0u8; 6];
        prefixed.read_exact(&mut buf).await.unwrap();

        assert_eq!(&buf, b"abcdef");
        server_task.await.unwrap();
    }

    // Message conversions preserve variant and payload in both directions.
    #[rstest]
    fn round_trip_text() {
        let original = SockudoMessage::Text(Bytes::from_static(b"hello"));
        let neutral: Message = original.into();
        assert!(neutral.is_text());
        assert_eq!(neutral.as_bytes(), b"hello");

        let back: SockudoMessage = neutral.into();
        match back {
            SockudoMessage::Text(b) => assert_eq!(&b[..], b"hello"),
            other => panic!("expected text, was {other:?}"),
        }
    }

    #[rstest]
    fn round_trip_binary() {
        let original = SockudoMessage::Binary(Bytes::from_static(&[1, 2, 3]));
        let neutral: Message = original.into();
        assert_eq!(neutral.as_bytes(), &[1, 2, 3]);

        let back: SockudoMessage = neutral.into();
        match back {
            SockudoMessage::Binary(b) => assert_eq!(&b[..], &[1, 2, 3]),
            other => panic!("expected binary, was {other:?}"),
        }
    }

    #[rstest]
    fn round_trip_ping_pong() {
        let neutral: Message = SockudoMessage::Ping(Bytes::from_static(b"p")).into();
        assert!(neutral.is_ping());

        let neutral: Message = SockudoMessage::Pong(Bytes::from_static(b"q")).into();
        assert!(neutral.is_pong());
    }

    // Close code and reason survive sockudo -> neutral -> sockudo.
    #[rstest]
    fn close_frame_round_trip() {
        let original = SockudoMessage::Close(Some(SockudoCloseReason {
            code: 1000,
            reason: "bye".into(),
        }));
        let neutral: Message = original.into();
        let Message::Close(Some(frame)) = &neutral else {
            panic!("expected close frame");
        };
        assert_eq!(frame.code, 1000);
        assert_eq!(frame.reason, "bye");

        let back: SockudoMessage = neutral.into();
        let SockudoMessage::Close(Some(reason)) = back else {
            panic!("expected close frame");
        };
        assert_eq!(reason.code, 1000);
        assert_eq!(reason.reason, "bye");
    }

    // Spot checks of the SockudoError -> TransportError mapping.
    #[rstest]
    fn error_translation_closed() {
        let err: TransportError = SockudoError::ConnectionClosed.into();
        assert!(matches!(err, TransportError::ConnectionClosed));
    }

    #[rstest]
    fn error_translation_utf8() {
        let err: TransportError = SockudoError::InvalidUtf8.into();
        assert!(matches!(err, TransportError::InvalidUtf8));
    }

    #[rstest]
    fn error_translation_handshake() {
        let err: TransportError = SockudoError::HandshakeFailed("bad").into();
        assert!(matches!(err, TransportError::Handshake(_)));
    }
}