Skip to main content

compio_ws/
lib.rs

1//! WebSocket support based on [`tungstenite`].
2//!
//! This library implements WebSocket handshakes and streams for compio. It
//! builds on the tungstenite crate, which implements all of the required
//! WebSocket protocol logic, and adds compio integration on top of it.
7//!
8//! Each WebSocket stream implements message reading and writing.
9//!
10//! [`tungstenite`]: https://docs.rs/tungstenite
11
12#![cfg_attr(docsrs, feature(doc_cfg))]
13#![allow(unused_features)]
14#![warn(missing_docs)]
15#![deny(rustdoc::broken_intra_doc_links)]
16#![doc(
17    html_logo_url = "https://github.com/compio-rs/compio-logo/raw/refs/heads/master/generated/colored-bold.svg"
18)]
19#![doc(
20    html_favicon_url = "https://github.com/compio-rs/compio-logo/raw/refs/heads/master/generated/colored-bold.svg"
21)]
22
23use std::io::ErrorKind;
24
25use compio_buf::IntoInner;
26use compio_io::{AsyncRead, AsyncWrite, compat::SyncStream};
27use tungstenite::{
28    Error as WsError, HandshakeError, Message, WebSocket,
29    client::IntoClientRequest,
30    handshake::server::{Callback, NoCallback},
31    protocol::{CloseFrame, Role, WebSocketConfig},
32};
33
34mod tls;
35#[cfg(feature = "io-compat")]
36pub use compat::CompatWebSocketStream;
37pub use tls::*;
38pub use tungstenite;
39#[cfg(feature = "io-compat")]
40mod compat;
41
42/// Configuration for compio-ws.
43///
44/// ## API Interface
45///
46/// `_with_config` functions in this crate accept `impl Into<Config>`, so
47/// following are all valid:
48/// - [`Config`]
49/// - [`WebSocketConfig`] (use custom WebSocket config with default remaining
50///   settings)
51/// - [`None`] (use default value)
52pub struct Config {
53    /// WebSocket configuration from tungstenite.
54    websocket: Option<WebSocketConfig>,
55
56    /// Base buffer size
57    buffer_size_base: usize,
58
59    /// Maximum buffer size
60    buffer_size_limit: usize,
61
62    /// Disable Nagle's algorithm. This only affects
63    /// [`connect_async_with_config()`] and [`connect_async_tls_with_config()`].
64    disable_nagle: bool,
65}
66
67impl Config {
68    // 128 KiB, see <https://github.com/compio-rs/compio/pull/532>.
69    const DEFAULT_BUF_SIZE: usize = 128 * 1024;
70    // 64 MiB, the same as [`SyncStream`].
71    const DEFAULT_MAX_BUFFER: usize = 64 * 1024 * 1024;
72
73    /// Creates a new `Config` with default settings.
74    pub fn new() -> Self {
75        Self {
76            websocket: None,
77            buffer_size_base: Self::DEFAULT_BUF_SIZE,
78            buffer_size_limit: Self::DEFAULT_MAX_BUFFER,
79            disable_nagle: false,
80        }
81    }
82
83    /// Get the WebSocket configuration.
84    pub fn websocket_config(&self) -> Option<&WebSocketConfig> {
85        self.websocket.as_ref()
86    }
87
88    /// Get the base buffer size.
89    pub fn buffer_size_base(&self) -> usize {
90        self.buffer_size_base
91    }
92
93    /// Get the maximum buffer size.
94    pub fn buffer_size_limit(&self) -> usize {
95        self.buffer_size_limit
96    }
97
98    /// Set custom base buffer size.
99    ///
100    /// Default to 128 KiB.
101    pub fn with_buffer_size_base(mut self, size: usize) -> Self {
102        self.buffer_size_base = size;
103        self
104    }
105
106    /// Set custom maximum buffer size.
107    ///
108    /// Default to 64 MiB.
109    pub fn with_buffer_size_limit(mut self, size: usize) -> Self {
110        self.buffer_size_limit = size;
111        self
112    }
113
114    /// Set custom buffer sizes.
115    ///
116    /// Default to 128 KiB for base and 64 MiB for limit.
117    pub fn with_buffer_sizes(mut self, base: usize, limit: usize) -> Self {
118        self.buffer_size_base = base;
119        self.buffer_size_limit = limit;
120        self
121    }
122
123    /// Disable Nagle's algorithm, i.e. `set_nodelay(true)`.
124    ///
125    /// Default to `false`. If you don't know what the Nagle's algorithm is,
126    /// better leave it to `false`.
127    pub fn disable_nagle(mut self, disable: bool) -> Self {
128        self.disable_nagle = disable;
129        self
130    }
131}
132
133impl Default for Config {
134    fn default() -> Self {
135        Self::new()
136    }
137}
138
139impl From<WebSocketConfig> for Config {
140    fn from(config: WebSocketConfig) -> Self {
141        Self {
142            websocket: Some(config),
143            ..Default::default()
144        }
145    }
146}
147
148impl From<Option<WebSocketConfig>> for Config {
149    fn from(config: Option<WebSocketConfig>) -> Self {
150        Self {
151            websocket: config,
152            ..Default::default()
153        }
154    }
155}
156
157/// A WebSocket stream that works with compio.
158#[derive(Debug)]
159pub struct WebSocketStream<S> {
160    inner: WebSocket<SyncStream<S>>,
161}
162
163impl<S> WebSocketStream<S> {
164    /// Get a reference to the underlying stream.
165    pub fn get_ref(&self) -> &S {
166        self.inner.get_ref().get_ref()
167    }
168
169    /// Get a mutable reference to the underlying stream.
170    pub fn get_mut(&mut self) -> &mut S {
171        self.inner.get_mut().get_mut()
172    }
173}
174
175impl<S> WebSocketStream<S>
176where
177    S: AsyncRead + AsyncWrite,
178{
179    /// Convert a raw socket into a [`WebSocketStream`] without performing a
180    /// handshake.
181    ///
182    /// `disable_nagle` will be ignored since the socket is already connected
183    /// and the user can set `nodelay` on the socket directly before calling
184    /// this function if needed.
185    pub async fn from_raw_socket(stream: S, role: Role, config: impl Into<Config>) -> Self {
186        let config = config.into();
187        let sync_stream =
188            SyncStream::with_limits(config.buffer_size_base, config.buffer_size_limit, stream);
189
190        WebSocketStream {
191            inner: WebSocket::from_raw_socket(sync_stream, role, config.websocket),
192        }
193    }
194
195    /// Convert a raw socket into a [`WebSocketStream`] without performing a
196    /// handshake.
197    ///
198    /// `disable_nagle` will be ignored since the socket is already connected
199    /// and the user can set `nodelay` on the socket directly before calling
200    /// this function if needed.
201    pub async fn from_partially_read(
202        stream: S,
203        part: Vec<u8>,
204        role: Role,
205        config: impl Into<Config>,
206    ) -> Self {
207        let config = config.into();
208        let sync_stream =
209            SyncStream::with_limits(config.buffer_size_base, config.buffer_size_limit, stream);
210
211        WebSocketStream {
212            inner: WebSocket::from_partially_read(sync_stream, part, role, config.websocket),
213        }
214    }
215
216    /// Send a message on the WebSocket stream.
217    pub async fn send(&mut self, message: Message) -> Result<(), WsError> {
218        match self.inner.write(message) {
219            Ok(()) => {}
220            Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {}
221            Err(e) => return Err(e),
222        }
223        // Need to flush the write buffer before we can send the message
224        self.flush().await
225    }
226
227    /// Read a message from the WebSocket stream.
228    pub async fn read(&mut self) -> Result<Message, WsError> {
229        loop {
230            match self.inner.read() {
231                Ok(msg) => {
232                    self.flush().await?;
233                    return Ok(msg);
234                }
235                Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
236                    // Need more data - fill the read buffer
237                    self.fill_read_buf().await?;
238                }
239                Err(e) => {
240                    let _ = self.flush().await;
241                    return Err(e);
242                }
243            }
244        }
245    }
246
247    /// Flush the WebSocket stream.
248    pub async fn flush(&mut self) -> Result<(), WsError> {
249        loop {
250            match self.inner.flush() {
251                Ok(()) => break,
252                Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
253                    self.flush_write_buf().await?;
254                }
255                Err(WsError::ConnectionClosed) => break,
256                Err(e) => return Err(e),
257            }
258        }
259        self.flush_write_buf().await?;
260        Ok(())
261    }
262
263    /// Close the WebSocket connection.
264    pub async fn close(&mut self, close_frame: Option<CloseFrame>) -> Result<(), WsError> {
265        loop {
266            match self.inner.close(close_frame.clone()) {
267                Ok(()) => break,
268                Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
269                    let flushed = self.flush_write_buf().await?;
270                    if flushed == 0 {
271                        self.fill_read_buf().await?;
272                    }
273                }
274                Err(WsError::ConnectionClosed) => break,
275                Err(e) => return Err(e),
276            }
277        }
278        self.flush().await
279    }
280
281    pub(crate) async fn flush_write_buf(&mut self) -> Result<usize, WsError> {
282        self.inner
283            .get_mut()
284            .flush_write_buf()
285            .await
286            .map_err(WsError::Io)
287    }
288
289    pub(crate) async fn fill_read_buf(&mut self) -> Result<usize, WsError> {
290        self.inner
291            .get_mut()
292            .fill_read_buf()
293            .await
294            .map_err(WsError::Io)
295    }
296
297    /// Convert this stream into a [`futures_util`] compatible stream.
298    #[cfg(feature = "io-compat")]
299    pub fn into_compat(self) -> CompatWebSocketStream<S>
300    // Ensure internal mutability of the stream.
301    where
302        for<'a> &'a S: AsyncRead + AsyncWrite,
303        S: Unpin,
304    {
305        CompatWebSocketStream::new(self.inner)
306    }
307}
308
309impl<S> IntoInner for WebSocketStream<S> {
310    type Inner = WebSocket<SyncStream<S>>;
311
312    fn into_inner(self) -> Self::Inner {
313        self.inner
314    }
315}
316
317/// Accepts a new WebSocket connection with the provided stream.
318///
319/// This function will internally create a handshake representation and returns
320/// a future representing the resolution of the WebSocket handshake. The
321/// returned future will resolve to either [`WebSocketStream<S>`] or [`WsError`]
322/// depending on if it's successful or not.
323///
324/// This is typically used after a socket has been accepted from a
325/// `TcpListener`. That socket is then passed to this function to perform
326/// the server half of accepting a client's websocket connection.
327pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
328where
329    S: AsyncRead + AsyncWrite,
330{
331    accept_hdr_async(stream, NoCallback).await
332}
333
334/// Similar to [`accept_async()`] but user can specify a [`Config`].
335pub async fn accept_async_with_config<S>(
336    stream: S,
337    config: impl Into<Config>,
338) -> Result<WebSocketStream<S>, WsError>
339where
340    S: AsyncRead + AsyncWrite,
341{
342    accept_hdr_with_config_async(stream, NoCallback, config).await
343}
344/// Accepts a new WebSocket connection with the provided stream.
345///
346/// This function does the same as [`accept_async()`] but accepts an extra
347/// callback for header processing. The callback receives headers of the
348/// incoming requests and is able to add extra headers to the reply.
349pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
350where
351    S: AsyncRead + AsyncWrite,
352    C: Callback,
353{
354    accept_hdr_with_config_async(stream, callback, None).await
355}
356
357/// Similar to [`accept_hdr_async()`] but user can specify a [`Config`].
358pub async fn accept_hdr_with_config_async<S, C>(
359    stream: S,
360    callback: C,
361    config: impl Into<Config>,
362) -> Result<WebSocketStream<S>, WsError>
363where
364    S: AsyncRead + AsyncWrite,
365    C: Callback,
366{
367    let config = config.into();
368    let sync_stream =
369        SyncStream::with_limits(config.buffer_size_base, config.buffer_size_limit, stream);
370    let mut handshake_result =
371        tungstenite::accept_hdr_with_config(sync_stream, callback, config.websocket);
372
373    loop {
374        match handshake_result {
375            Ok(mut websocket) => {
376                websocket
377                    .get_mut()
378                    .flush_write_buf()
379                    .await
380                    .map_err(WsError::Io)?;
381                return Ok(WebSocketStream { inner: websocket });
382            }
383            Err(HandshakeError::Interrupted(mut mid_handshake)) => {
384                let sync_stream = mid_handshake.get_mut().get_mut();
385
386                sync_stream.flush_write_buf().await.map_err(WsError::Io)?;
387
388                sync_stream.fill_read_buf().await.map_err(WsError::Io)?;
389
390                handshake_result = mid_handshake.handshake();
391            }
392            Err(HandshakeError::Failure(error)) => {
393                return Err(error);
394            }
395        }
396    }
397}
398
399/// Creates a WebSocket handshake from a request and a stream.
400///
401/// For convenience, the user may call this with a url string, a URL,
402/// or a `Request`. Calling with `Request` allows the user to add
403/// a WebSocket protocol or other custom headers.
404///
405/// Internally, this creates a handshake representation and returns
406/// a future representing the resolution of the WebSocket handshake. The
407/// returned future will resolve to either [`WebSocketStream<S>`] or [`WsError`]
408/// depending on whether the handshake is successful.
409///
410/// This is typically used for clients who have already established, for
411/// example, a TCP connection to the remote server.
412pub async fn client_async<R, S>(
413    request: R,
414    stream: S,
415) -> Result<(WebSocketStream<S>, tungstenite::handshake::client::Response), WsError>
416where
417    R: IntoClientRequest,
418    S: AsyncRead + AsyncWrite,
419{
420    client_async_with_config(request, stream, None).await
421}
422
423/// Similar to [`client_async()`] but user can specify a [`Config`].
424pub async fn client_async_with_config<R, S>(
425    request: R,
426    stream: S,
427    config: impl Into<Config>,
428) -> Result<(WebSocketStream<S>, tungstenite::handshake::client::Response), WsError>
429where
430    R: IntoClientRequest,
431    S: AsyncRead + AsyncWrite,
432{
433    let config = config.into();
434    let sync_stream =
435        SyncStream::with_limits(config.buffer_size_base, config.buffer_size_limit, stream);
436    let mut handshake_result =
437        tungstenite::client::client_with_config(request, sync_stream, config.websocket);
438
439    loop {
440        match handshake_result {
441            Ok((mut websocket, response)) => {
442                // Ensure any remaining data is flushed
443                websocket
444                    .get_mut()
445                    .flush_write_buf()
446                    .await
447                    .map_err(WsError::Io)?;
448                return Ok((WebSocketStream { inner: websocket }, response));
449            }
450            Err(HandshakeError::Interrupted(mut mid_handshake)) => {
451                let sync_stream = mid_handshake.get_mut().get_mut();
452
453                // For handshake: always try both operations
454                sync_stream.flush_write_buf().await.map_err(WsError::Io)?;
455
456                sync_stream.fill_read_buf().await.map_err(WsError::Io)?;
457
458                handshake_result = mid_handshake.handshake();
459            }
460            Err(HandshakeError::Failure(error)) => {
461                return Err(error);
462            }
463        }
464    }
465}