Skip to main content

compio_ws/
lib.rs

1//! WebSocket support based on [`tungstenite`].
2//!
3//! This library is an implementation of WebSocket handshakes and streams for
4//! compio. It is based on the tungstenite crate which implements all required
5//! WebSocket protocol logic. This crate brings compio support / compio
6//! integration to it.
7//!
8//! Each WebSocket stream implements message reading and writing.
9//!
10//! [`tungstenite`]: https://docs.rs/tungstenite
11
12#![cfg_attr(docsrs, feature(doc_cfg))]
13#![allow(unused_features)]
14#![warn(missing_docs)]
15#![deny(rustdoc::broken_intra_doc_links)]
16#![doc(
17    html_logo_url = "https://github.com/compio-rs/compio-logo/raw/refs/heads/master/generated/colored-bold.svg"
18)]
19#![doc(
20    html_favicon_url = "https://github.com/compio-rs/compio-logo/raw/refs/heads/master/generated/colored-bold.svg"
21)]
22
23use std::{
24    pin::Pin,
25    task::{Context, Poll, ready},
26};
27
28use compio_buf::IntoInner;
29use compio_io::{AsyncRead, AsyncWrite, compat::AsyncStream, util::Splittable};
30use compio_tls::{MaybeTlsStream, TlsStream};
31use futures_util::{Sink, SinkExt, Stream, StreamExt, stream::FusedStream};
32use pin_project_lite::pin_project;
33use tungstenite::{
34    Error as WsError, Message,
35    client::IntoClientRequest,
36    handshake::server::{Callback, NoCallback},
37    protocol::{CloseFrame, Role, WebSocketConfig},
38};
39
40#[cfg(feature = "connect")]
41mod tls;
42#[cfg(feature = "connect")]
43pub use tls::*;
44pub use tungstenite;
45
46/// Configuration for compio-ws.
47///
48/// ## API Interface
49///
50/// `_with_config` functions in this crate accept `impl Into<Config>`, so
51/// following are all valid:
52/// - [`Config`]
53/// - [`WebSocketConfig`] (use custom WebSocket config with default remaining
54///   settings)
55/// - [`None`] (use default value)
56pub struct Config {
57    /// WebSocket configuration from tungstenite.
58    websocket: Option<WebSocketConfig>,
59
60    /// Base buffer size
61    buffer_size_base: usize,
62
63    /// Maximum buffer size
64    buffer_size_limit: usize,
65
66    /// Disable Nagle's algorithm. This only affects
67    /// [`connect_async_with_config()`] and [`connect_async_tls_with_config()`].
68    disable_nagle: bool,
69}
70
71impl Config {
72    // 128 KiB, see <https://github.com/compio-rs/compio/pull/532>.
73    const DEFAULT_BUF_SIZE: usize = 128 * 1024;
74    // 64 MiB, the same as [`SyncStream`].
75    const DEFAULT_MAX_BUFFER: usize = 64 * 1024 * 1024;
76
77    /// Creates a new `Config` with default settings.
78    pub fn new() -> Self {
79        Self {
80            websocket: None,
81            buffer_size_base: Self::DEFAULT_BUF_SIZE,
82            buffer_size_limit: Self::DEFAULT_MAX_BUFFER,
83            disable_nagle: false,
84        }
85    }
86
87    /// Get the WebSocket configuration.
88    pub fn websocket_config(&self) -> Option<&WebSocketConfig> {
89        self.websocket.as_ref()
90    }
91
92    /// Get the base buffer size.
93    pub fn buffer_size_base(&self) -> usize {
94        self.buffer_size_base
95    }
96
97    /// Get the maximum buffer size.
98    pub fn buffer_size_limit(&self) -> usize {
99        self.buffer_size_limit
100    }
101
102    /// Set custom base buffer size.
103    ///
104    /// Default to 128 KiB.
105    pub fn with_buffer_size_base(mut self, size: usize) -> Self {
106        self.buffer_size_base = size;
107        self
108    }
109
110    /// Set custom maximum buffer size.
111    ///
112    /// Default to 64 MiB.
113    pub fn with_buffer_size_limit(mut self, size: usize) -> Self {
114        self.buffer_size_limit = size;
115        self
116    }
117
118    /// Set custom buffer sizes.
119    ///
120    /// Default to 128 KiB for base and 64 MiB for limit.
121    pub fn with_buffer_sizes(mut self, base: usize, limit: usize) -> Self {
122        self.buffer_size_base = base;
123        self.buffer_size_limit = limit;
124        self
125    }
126
127    /// Disable Nagle's algorithm, i.e. `set_nodelay(true)`.
128    ///
129    /// Default to `false`. If you don't know what the Nagle's algorithm is,
130    /// better leave it to `false`.
131    pub fn disable_nagle(mut self, disable: bool) -> Self {
132        self.disable_nagle = disable;
133        self
134    }
135}
136
137impl Default for Config {
138    fn default() -> Self {
139        Self::new()
140    }
141}
142
143impl From<WebSocketConfig> for Config {
144    fn from(config: WebSocketConfig) -> Self {
145        Self {
146            websocket: Some(config),
147            ..Default::default()
148        }
149    }
150}
151
152impl From<Option<WebSocketConfig>> for Config {
153    fn from(config: Option<WebSocketConfig>) -> Self {
154        Self {
155            websocket: config,
156            ..Default::default()
157        }
158    }
159}
160
161mod private {
162    use super::*;
163
164    pub trait Sealed<S>
165    where
166        S: Splittable,
167    {
168    }
169
170    impl<S: Splittable> Sealed<S> for S {}
171    impl<S: Splittable> Sealed<S> for AsyncStream<S> {}
172    impl<S: Splittable> Sealed<S> for MaybeTlsStream<S> {}
173    impl<S: Splittable> Sealed<S> for TlsStream<S> {}
174}
175
176/// Create [`MaybeTlsStream`] with capacity and buffer size limit.
177pub trait IntoMaybeTlsStream<S>: private::Sealed<S>
178where
179    S: Splittable,
180{
181    /// Create [`MaybeTlsStream`] with capacity and buffer size limit.
182    fn into_maybe_tls_stream(self, capacity: usize, max_buffer_size: usize) -> MaybeTlsStream<S>;
183}
184
185impl<S: Splittable> IntoMaybeTlsStream<S> for S {
186    fn into_maybe_tls_stream(self, capacity: usize, max_buffer_size: usize) -> MaybeTlsStream<S> {
187        MaybeTlsStream::new_plain_compat(AsyncStream::with_limits(capacity, max_buffer_size, self))
188    }
189}
190
191impl<S: Splittable> IntoMaybeTlsStream<S> for AsyncStream<S> {
192    fn into_maybe_tls_stream(self, _: usize, _: usize) -> MaybeTlsStream<S> {
193        MaybeTlsStream::new_plain_compat(self)
194    }
195}
196
197impl<S: Splittable> IntoMaybeTlsStream<S> for MaybeTlsStream<S> {
198    fn into_maybe_tls_stream(self, _: usize, _: usize) -> MaybeTlsStream<S> {
199        self
200    }
201}
202
203impl<S: Splittable> IntoMaybeTlsStream<S> for TlsStream<S> {
204    fn into_maybe_tls_stream(self, _: usize, _: usize) -> MaybeTlsStream<S> {
205        MaybeTlsStream::new_tls(self)
206    }
207}
208
pin_project! {
    /// A WebSocket stream that works with compio.
    ///
    /// It implements `Sink<Message>` for writing and `Stream` for reading.
    /// Every message yielded by the `Stream` side is held back until pending
    /// writes (on the WebSocket layer and on the buffered compio stream) have
    /// been flushed.
    #[derive(Debug)]
    pub struct WebSocketStream<S: Splittable> {
        // The async-tungstenite stream that implements the protocol, running
        // over the (possibly TLS) compio stream.
        #[pin]
        inner: async_tungstenite::WebSocketStream<MaybeTlsStream<S>>,
        // Item already taken from `inner` by `poll_next` but not yet returned
        // because flushing is still in progress. `Some(None)` records that
        // `inner` reported end of stream.
        next_item: Option<Option<Result<Message, WsError>>>,
    }
}
218
219impl<S: Splittable + 'static> WebSocketStream<S>
220where
221    S::ReadHalf: AsyncRead + Unpin,
222    S::WriteHalf: AsyncWrite + Unpin,
223{
224    /// Get a reference to the underlying stream.
225    pub fn get_ref(&self) -> &MaybeTlsStream<S> {
226        self.inner.get_ref()
227    }
228
229    /// Get a mutable reference to the underlying stream.
230    pub fn get_mut(&mut self) -> &mut MaybeTlsStream<S> {
231        self.inner.get_mut()
232    }
233
234    /// Convert a raw socket into a [`WebSocketStream`] without performing a
235    /// handshake.
236    ///
237    /// `disable_nagle` will be ignored since the socket is already connected
238    /// and the user can set `nodelay` on the socket directly before calling
239    /// this function if needed.
240    pub async fn from_raw_socket<T: IntoMaybeTlsStream<S>>(
241        stream: T,
242        role: Role,
243        config: impl Into<Config>,
244    ) -> Self {
245        let config = config.into();
246
247        Self::from_inner(
248            async_tungstenite::WebSocketStream::from_raw_socket(
249                stream.into_maybe_tls_stream(config.buffer_size_base, config.buffer_size_limit),
250                role,
251                config.websocket,
252            )
253            .await,
254        )
255    }
256
257    /// Convert a raw socket into a [`WebSocketStream`] without performing a
258    /// handshake.
259    ///
260    /// `disable_nagle` will be ignored since the socket is already connected
261    /// and the user can set `nodelay` on the socket directly before calling
262    /// this function if needed.
263    pub async fn from_partially_read<T: IntoMaybeTlsStream<S>>(
264        stream: T,
265        part: Vec<u8>,
266        role: Role,
267        config: impl Into<Config>,
268    ) -> Self {
269        let config = config.into();
270
271        Self::from_inner(
272            async_tungstenite::WebSocketStream::from_partially_read(
273                stream.into_maybe_tls_stream(config.buffer_size_base, config.buffer_size_limit),
274                part,
275                role,
276                config.websocket,
277            )
278            .await,
279        )
280    }
281
282    fn from_inner(inner: async_tungstenite::WebSocketStream<MaybeTlsStream<S>>) -> Self {
283        WebSocketStream {
284            inner,
285            next_item: None,
286        }
287    }
288
289    /// Send a message on the WebSocket stream.
290    pub async fn send(&mut self, message: Message) -> Result<(), WsError> {
291        SinkExt::send(self, message).await
292    }
293
294    /// Read a message from the WebSocket stream.
295    pub async fn read(&mut self) -> Result<Message, WsError> {
296        self.next()
297            .await
298            .unwrap_or_else(|| Err(WsError::ConnectionClosed))
299    }
300
301    /// Flush the WebSocket stream.
302    pub async fn flush(&mut self) -> Result<(), WsError> {
303        SinkExt::flush(self).await
304    }
305
306    /// Close the WebSocket connection.
307    pub async fn close(&mut self, close_frame: Option<CloseFrame>) -> Result<(), WsError> {
308        self.send(Message::Close(close_frame)).await
309    }
310}
311
312impl<S: Splittable> IntoInner for WebSocketStream<S> {
313    type Inner = MaybeTlsStream<S>;
314
315    fn into_inner(self) -> Self::Inner {
316        self.inner.into_inner()
317    }
318}
319
320impl<S: Splittable + 'static> Sink<Message> for WebSocketStream<S>
321where
322    S::ReadHalf: AsyncRead + Unpin,
323    S::WriteHalf: AsyncWrite + Unpin,
324{
325    type Error = WsError;
326
327    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
328        self.project().inner.poll_ready(cx)
329    }
330
331    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
332        self.project().inner.start_send(item)
333    }
334
335    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
336        ready!(self.as_mut().project().inner.poll_flush(cx))?;
337        ready!(futures_util::AsyncWrite::poll_flush(
338            Pin::new(self.project().inner.get_mut().get_mut()),
339            cx
340        ))?;
341        Poll::Ready(Ok(()))
342    }
343
344    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
345        self.project().inner.poll_close(cx)
346    }
347}
348
349impl<S: Splittable + 'static> Stream for WebSocketStream<S>
350where
351    S::ReadHalf: AsyncRead + Unpin,
352    S::WriteHalf: AsyncWrite + Unpin,
353{
354    type Item = Result<Message, WsError>;
355
356    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
357        let mut this = self.project();
358        loop {
359            if this.next_item.is_some() {
360                ready!(this.inner.as_mut().poll_flush(cx))?;
361                ready!(futures_util::AsyncWrite::poll_flush(
362                    Pin::new(this.inner.get_mut().get_mut()),
363                    cx
364                ))?;
365                break Poll::Ready(this.next_item.take().expect("next_item should be Some"));
366            } else {
367                let item = ready!(this.inner.as_mut().poll_next(cx));
368                *this.next_item = Some(item);
369            }
370        }
371    }
372}
373
374impl<S: Splittable + 'static> FusedStream for WebSocketStream<S>
375where
376    S::ReadHalf: AsyncRead + Unpin,
377    S::WriteHalf: AsyncWrite + Unpin,
378{
379    fn is_terminated(&self) -> bool {
380        self.inner.is_terminated()
381    }
382}
383
384/// Accepts a new WebSocket connection with the provided stream.
385///
386/// This function will internally create a handshake representation and returns
387/// a future representing the resolution of the WebSocket handshake. The
388/// returned future will resolve to either [`WebSocketStream<S>`] or [`WsError`]
389/// depending on if it's successful or not.
390///
391/// This is typically used after a socket has been accepted from a
392/// `TcpListener`. That socket is then passed to this function to perform
393/// the server half of accepting a client's websocket connection.
394pub async fn accept_async<S, T>(stream: T) -> Result<WebSocketStream<S>, WsError>
395where
396    S: Splittable + 'static,
397    <S as Splittable>::ReadHalf: AsyncRead + Unpin,
398    <S as Splittable>::WriteHalf: AsyncWrite + Unpin,
399    T: IntoMaybeTlsStream<S>,
400{
401    accept_hdr_async(stream, NoCallback).await
402}
403
404/// Similar to [`accept_async()`] but user can specify a [`Config`].
405pub async fn accept_async_with_config<S, T>(
406    stream: T,
407    config: impl Into<Config>,
408) -> Result<WebSocketStream<S>, WsError>
409where
410    S: Splittable + 'static,
411    <S as Splittable>::ReadHalf: AsyncRead + Unpin,
412    <S as Splittable>::WriteHalf: AsyncWrite + Unpin,
413    T: IntoMaybeTlsStream<S>,
414{
415    accept_hdr_with_config_async(stream, NoCallback, config).await
416}
417
418/// Accepts a new WebSocket connection with the provided stream.
419///
420/// This function does the same as [`accept_async()`] but accepts an extra
421/// callback for header processing. The callback receives headers of the
422/// incoming requests and is able to add extra headers to the reply.
423pub async fn accept_hdr_async<S, T, C>(
424    stream: T,
425    callback: C,
426) -> Result<WebSocketStream<S>, WsError>
427where
428    S: Splittable + 'static,
429    T: IntoMaybeTlsStream<S>,
430    C: Callback + Unpin,
431    <S as Splittable>::ReadHalf: AsyncRead + Unpin,
432    <S as Splittable>::WriteHalf: AsyncWrite + Unpin,
433{
434    accept_hdr_with_config_async(stream, callback, None).await
435}
436
437/// Similar to [`accept_hdr_async()`] but user can specify a [`Config`].
438pub async fn accept_hdr_with_config_async<S, T, C>(
439    stream: T,
440    callback: C,
441    config: impl Into<Config>,
442) -> Result<WebSocketStream<S>, WsError>
443where
444    S: Splittable + 'static,
445    T: IntoMaybeTlsStream<S>,
446    C: Callback + Unpin,
447    <S as Splittable>::ReadHalf: AsyncRead + Unpin,
448    <S as Splittable>::WriteHalf: AsyncWrite + Unpin,
449{
450    let config = config.into();
451    let inner = async_tungstenite::accept_hdr_async_with_config(
452        stream.into_maybe_tls_stream(config.buffer_size_base, config.buffer_size_limit),
453        callback,
454        config.websocket,
455    )
456    .await?;
457    Ok(WebSocketStream::from_inner(inner))
458}
459
460/// Creates a WebSocket handshake from a request and a stream.
461///
462/// For convenience, the user may call this with a url string, a URL,
463/// or a `Request`. Calling with `Request` allows the user to add
464/// a WebSocket protocol or other custom headers.
465///
466/// Internally, this creates a handshake representation and returns
467/// a future representing the resolution of the WebSocket handshake. The
468/// returned future will resolve to either [`WebSocketStream<S>`] or [`WsError`]
469/// depending on whether the handshake is successful.
470///
471/// This is typically used for clients who have already established, for
472/// example, a TCP connection to the remote server.
473pub async fn client_async<R, S, T>(
474    request: R,
475    stream: T,
476) -> Result<(WebSocketStream<S>, tungstenite::handshake::client::Response), WsError>
477where
478    R: IntoClientRequest + Unpin,
479    S: Splittable + 'static,
480    <S as Splittable>::ReadHalf: AsyncRead + Unpin,
481    <S as Splittable>::WriteHalf: AsyncWrite + Unpin,
482    T: IntoMaybeTlsStream<S>,
483{
484    client_async_with_config(request, stream, None).await
485}
486
487/// Similar to [`client_async()`] but user can specify a [`Config`].
488pub async fn client_async_with_config<R, S, T>(
489    request: R,
490    stream: T,
491    config: impl Into<Config>,
492) -> Result<(WebSocketStream<S>, tungstenite::handshake::client::Response), WsError>
493where
494    R: IntoClientRequest + Unpin,
495    S: Splittable + 'static,
496    <S as Splittable>::ReadHalf: AsyncRead + Unpin,
497    <S as Splittable>::WriteHalf: AsyncWrite + Unpin,
498    T: IntoMaybeTlsStream<S>,
499{
500    let config = config.into();
501    let (inner, response) = async_tungstenite::client_async_with_config(
502        request,
503        stream.into_maybe_tls_stream(config.buffer_size_base, config.buffer_size_limit),
504        config.websocket,
505    )
506    .await?;
507    Ok((WebSocketStream::from_inner(inner), response))
508}