compio_ws/
lib.rs

1//! WebSocket support based on [`tungstenite`].
2//!
3//! This library is an implementation of WebSocket handshakes and streams for
4//! compio. It is based on the tungstenite crate which implements all required
5//! WebSocket protocol logic. This crate brings compio support / compio
6//! integration to it.
7//!
8//! Each WebSocket stream implements message reading and writing.
9//!
10//! [`tungstenite`]: https://docs.rs/tungstenite
11
12#![cfg_attr(docsrs, feature(doc_cfg))]
13#![warn(missing_docs)]
14#![deny(rustdoc::broken_intra_doc_links)]
15#![doc(
16    html_logo_url = "https://github.com/compio-rs/compio-logo/raw/refs/heads/master/generated/colored-bold.svg"
17)]
18#![doc(
19    html_favicon_url = "https://github.com/compio-rs/compio-logo/raw/refs/heads/master/generated/colored-bold.svg"
20)]
21
22use std::io::ErrorKind;
23
24use compio_buf::IntoInner;
25use compio_io::{AsyncRead, AsyncWrite, compat::SyncStream};
26use tungstenite::{
27    Error as WsError, HandshakeError, Message, WebSocket,
28    client::IntoClientRequest,
29    handshake::server::{Callback, NoCallback},
30    protocol::{CloseFrame, WebSocketConfig},
31};
32
33mod tls;
34pub use tls::*;
35pub use tungstenite;
36
37/// Configuration for compio-ws.
38///
39/// ## API Interface
40///
41/// `_with_config` functions in this crate accept `impl Into<Config>`, so
42/// following are all valid:
43/// - [`Config`]
44/// - [`WebSocketConfig`] (use custom WebSocket config with default remaining
45///   settings)
46/// - [`None`] (use default value)
47pub struct Config {
48    /// WebSocket configuration from tungstenite.
49    websocket: Option<WebSocketConfig>,
50
51    /// Base buffer size
52    buffer_size_base: usize,
53
54    /// Maximum buffer size
55    buffer_size_limit: usize,
56
57    /// Disable Nagle's algorithm. This only affects
58    /// [`connect_async_with_config()`] and [`connect_async_tls_with_config()`].
59    disable_nagle: bool,
60}
61
62impl Config {
63    // 128 KiB, see <https://github.com/compio-rs/compio/pull/532>.
64    const DEFAULT_BUF_SIZE: usize = 128 * 1024;
65    // 64 MiB, the same as [`SyncStream`].
66    const DEFAULT_MAX_BUFFER: usize = 64 * 1024 * 1024;
67
68    /// Creates a new `Config` with default settings.
69    pub fn new() -> Self {
70        Self {
71            websocket: None,
72            buffer_size_base: Self::DEFAULT_BUF_SIZE,
73            buffer_size_limit: Self::DEFAULT_MAX_BUFFER,
74            disable_nagle: false,
75        }
76    }
77
78    /// Get the WebSocket configuration.
79    pub fn websocket_config(&self) -> Option<&WebSocketConfig> {
80        self.websocket.as_ref()
81    }
82
83    /// Get the base buffer size.
84    pub fn buffer_size_base(&self) -> usize {
85        self.buffer_size_base
86    }
87
88    /// Get the maximum buffer size.
89    pub fn buffer_size_limit(&self) -> usize {
90        self.buffer_size_limit
91    }
92
93    /// Set custom base buffer size.
94    ///
95    /// Default to 128 KiB.
96    pub fn with_buffer_size_base(mut self, size: usize) -> Self {
97        self.buffer_size_base = size;
98        self
99    }
100
101    /// Set custom maximum buffer size.
102    ///
103    /// Default to 64 MiB.
104    pub fn with_buffer_size_limit(mut self, size: usize) -> Self {
105        self.buffer_size_limit = size;
106        self
107    }
108
109    /// Set custom buffer sizes.
110    ///
111    /// Default to 128 KiB for base and 64 MiB for limit.
112    pub fn with_buffer_sizes(mut self, base: usize, limit: usize) -> Self {
113        self.buffer_size_base = base;
114        self.buffer_size_limit = limit;
115        self
116    }
117
118    /// Disable Nagle's algorithm, i.e. `set_nodelay(true)`.
119    ///
120    /// Default to `false`. If you don't know what the Nagle's algorithm is,
121    /// better leave it to `false`.
122    pub fn disable_nagle(mut self, disable: bool) -> Self {
123        self.disable_nagle = disable;
124        self
125    }
126}
127
128impl Default for Config {
129    fn default() -> Self {
130        Self::new()
131    }
132}
133
134impl From<WebSocketConfig> for Config {
135    fn from(config: WebSocketConfig) -> Self {
136        Self {
137            websocket: Some(config),
138            ..Default::default()
139        }
140    }
141}
142
143impl From<Option<WebSocketConfig>> for Config {
144    fn from(config: Option<WebSocketConfig>) -> Self {
145        Self {
146            websocket: config,
147            ..Default::default()
148        }
149    }
150}
151
152/// A WebSocket stream that works with compio.
153#[derive(Debug)]
154pub struct WebSocketStream<S> {
155    inner: WebSocket<SyncStream<S>>,
156}
157
158impl<S> WebSocketStream<S>
159where
160    S: AsyncRead + AsyncWrite,
161{
162    /// Send a message on the WebSocket stream.
163    pub async fn send(&mut self, message: Message) -> Result<(), WsError> {
164        // Send the message - this buffers it
165        // Since CompioStream::flush() now returns Ok, this should succeed on first try
166        self.inner.send(message)?;
167
168        // flush the buffer to the network
169        self.flush().await
170    }
171
172    /// Read a message from the WebSocket stream.
173    pub async fn read(&mut self) -> Result<Message, WsError> {
174        loop {
175            match self.inner.read() {
176                Ok(msg) => {
177                    self.flush().await?;
178                    return Ok(msg);
179                }
180                Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
181                    // Need more data - fill the read buffer
182                    self.inner
183                        .get_mut()
184                        .fill_read_buf()
185                        .await
186                        .map_err(WsError::Io)?;
187                }
188                Err(e) => {
189                    let _ = self.flush().await;
190                    return Err(e);
191                }
192            }
193        }
194    }
195
196    /// Flush the WebSocket stream.
197    pub async fn flush(&mut self) -> Result<(), WsError> {
198        loop {
199            match self.inner.flush() {
200                Ok(()) => break,
201                Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
202                    self.inner
203                        .get_mut()
204                        .flush_write_buf()
205                        .await
206                        .map_err(WsError::Io)?;
207                }
208                Err(WsError::ConnectionClosed) => break,
209                Err(e) => return Err(e),
210            }
211        }
212        self.inner
213            .get_mut()
214            .flush_write_buf()
215            .await
216            .map_err(WsError::Io)?;
217        Ok(())
218    }
219
220    /// Close the WebSocket connection.
221    pub async fn close(&mut self, close_frame: Option<CloseFrame>) -> Result<(), WsError> {
222        loop {
223            match self.inner.close(close_frame.clone()) {
224                Ok(()) => break,
225                Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
226                    let sync_stream = self.inner.get_mut();
227
228                    let flushed = sync_stream.flush_write_buf().await.map_err(WsError::Io)?;
229
230                    if flushed == 0 {
231                        sync_stream.fill_read_buf().await.map_err(WsError::Io)?;
232                    }
233                }
234                Err(WsError::ConnectionClosed) => break,
235                Err(e) => return Err(e),
236            }
237        }
238        self.flush().await
239    }
240
241    /// Get a reference to the underlying stream.
242    pub fn get_ref(&self) -> &S {
243        self.inner.get_ref().get_ref()
244    }
245
246    /// Get a mutable reference to the underlying stream.
247    pub fn get_mut(&mut self) -> &mut S {
248        self.inner.get_mut().get_mut()
249    }
250}
251
252impl<S> IntoInner for WebSocketStream<S> {
253    type Inner = WebSocket<SyncStream<S>>;
254
255    fn into_inner(self) -> Self::Inner {
256        self.inner
257    }
258}
259
260/// Accepts a new WebSocket connection with the provided stream.
261///
262/// This function will internally create a handshake representation and returns
263/// a future representing the resolution of the WebSocket handshake. The
264/// returned future will resolve to either [`WebSocketStream<S>`] or [`WsError`]
265/// depending on if it's successful or not.
266///
267/// This is typically used after a socket has been accepted from a
268/// `TcpListener`. That socket is then passed to this function to perform
269/// the server half of accepting a client's websocket connection.
270pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
271where
272    S: AsyncRead + AsyncWrite,
273{
274    accept_hdr_async(stream, NoCallback).await
275}
276
277/// Similar to [`accept_async()`] but user can specify a [`Config`].
278pub async fn accept_async_with_config<S>(
279    stream: S,
280    config: impl Into<Config>,
281) -> Result<WebSocketStream<S>, WsError>
282where
283    S: AsyncRead + AsyncWrite,
284{
285    accept_hdr_with_config_async(stream, NoCallback, config).await
286}
287/// Accepts a new WebSocket connection with the provided stream.
288///
289/// This function does the same as [`accept_async()`] but accepts an extra
290/// callback for header processing. The callback receives headers of the
291/// incoming requests and is able to add extra headers to the reply.
292pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
293where
294    S: AsyncRead + AsyncWrite,
295    C: Callback,
296{
297    accept_hdr_with_config_async(stream, callback, None).await
298}
299
300/// Similar to [`accept_hdr_async()`] but user can specify a [`Config`].
301pub async fn accept_hdr_with_config_async<S, C>(
302    stream: S,
303    callback: C,
304    config: impl Into<Config>,
305) -> Result<WebSocketStream<S>, WsError>
306where
307    S: AsyncRead + AsyncWrite,
308    C: Callback,
309{
310    let config = config.into();
311    let sync_stream =
312        SyncStream::with_limits(config.buffer_size_base, config.buffer_size_limit, stream);
313    let mut handshake_result =
314        tungstenite::accept_hdr_with_config(sync_stream, callback, config.websocket);
315
316    loop {
317        match handshake_result {
318            Ok(mut websocket) => {
319                websocket
320                    .get_mut()
321                    .flush_write_buf()
322                    .await
323                    .map_err(WsError::Io)?;
324                return Ok(WebSocketStream { inner: websocket });
325            }
326            Err(HandshakeError::Interrupted(mut mid_handshake)) => {
327                let sync_stream = mid_handshake.get_mut().get_mut();
328
329                sync_stream.flush_write_buf().await.map_err(WsError::Io)?;
330
331                sync_stream.fill_read_buf().await.map_err(WsError::Io)?;
332
333                handshake_result = mid_handshake.handshake();
334            }
335            Err(HandshakeError::Failure(error)) => {
336                return Err(error);
337            }
338        }
339    }
340}
341
342/// Creates a WebSocket handshake from a request and a stream.
343///
344/// For convenience, the user may call this with a url string, a URL,
345/// or a `Request`. Calling with `Request` allows the user to add
346/// a WebSocket protocol or other custom headers.
347///
348/// Internally, this creates a handshake representation and returns
349/// a future representing the resolution of the WebSocket handshake. The
350/// returned future will resolve to either [`WebSocketStream<S>`] or [`WsError`]
351/// depending on whether the handshake is successful.
352///
353/// This is typically used for clients who have already established, for
354/// example, a TCP connection to the remote server.
355pub async fn client_async<R, S>(
356    request: R,
357    stream: S,
358) -> Result<(WebSocketStream<S>, tungstenite::handshake::client::Response), WsError>
359where
360    R: IntoClientRequest,
361    S: AsyncRead + AsyncWrite,
362{
363    client_async_with_config(request, stream, None).await
364}
365
366/// Similar to [`client_async()`] but user can specify a [`Config`].
367pub async fn client_async_with_config<R, S>(
368    request: R,
369    stream: S,
370    config: impl Into<Config>,
371) -> Result<(WebSocketStream<S>, tungstenite::handshake::client::Response), WsError>
372where
373    R: IntoClientRequest,
374    S: AsyncRead + AsyncWrite,
375{
376    let config = config.into();
377    let sync_stream =
378        SyncStream::with_limits(config.buffer_size_base, config.buffer_size_limit, stream);
379    let mut handshake_result =
380        tungstenite::client::client_with_config(request, sync_stream, config.websocket);
381
382    loop {
383        match handshake_result {
384            Ok((mut websocket, response)) => {
385                // Ensure any remaining data is flushed
386                websocket
387                    .get_mut()
388                    .flush_write_buf()
389                    .await
390                    .map_err(WsError::Io)?;
391                return Ok((WebSocketStream { inner: websocket }, response));
392            }
393            Err(HandshakeError::Interrupted(mut mid_handshake)) => {
394                let sync_stream = mid_handshake.get_mut().get_mut();
395
396                // For handshake: always try both operations
397                sync_stream.flush_write_buf().await.map_err(WsError::Io)?;
398
399                sync_stream.fill_read_buf().await.map_err(WsError::Io)?;
400
401                handshake_result = mid_handshake.handshake();
402            }
403            Err(HandshakeError::Failure(error)) => {
404                return Err(error);
405            }
406        }
407    }
408}