Skip to main content

compio_ws/
compat.rs

1use std::{
2    marker::PhantomPinned,
3    ops::Deref,
4    pin::Pin,
5    sync::Arc,
6    task::{Context, Poll, Wake, Waker, ready},
7};
8
9use compio_io::{AsyncRead, AsyncWrite, compat::SyncStream};
10use futures_util::{Sink, Stream};
11use pin_project_lite::pin_project;
12use tungstenite::{Message, WebSocket};
13
14use crate::WsError;
15
/// A pinned, heap-allocated, type-erased future that resolves to `T`.
///
/// No `Send` bound is attached; the trait object's lifetime is `'static`
/// (spelled out here, it is also the default for `Box<dyn Trait>`).
type PinBoxFuture<T> = Pin<Box<dyn Future<Output = T> + 'static>>;
17
/// Progress of `poll_flush_impl`, kept across `Poll::Pending`.
enum Flushing {
    /// Next step: call tungstenite's synchronous `WebSocket::flush`.
    None,
    /// `flush` reported `WouldBlock`; the write buffer must be drained
    /// asynchronously before `flush` is retried.
    WouldBlock,
    /// `flush` succeeded (or the connection was already closed); one final
    /// asynchronous drain of the write buffer remains before reporting success.
    Flushed,
}
23
/// Progress of `Sink::poll_close`, kept across `Poll::Pending`.
enum Closing {
    /// Next step: call tungstenite's synchronous `WebSocket::close(None)`.
    None,
    /// `close` reported `WouldBlock`; drain the write buffer next.
    WouldBlockFlush,
    /// Draining the write buffer moved zero bytes; refill the read buffer
    /// before retrying `close`.
    WouldBlockFill,
    /// `close` succeeded (or the connection was already closed); finish with a
    /// full flush.
    Closed,
}
30
31enum Reading {
32    None,
33    AfterRead(Result<Message, WsError>),
34    WouldBlock,
35}
36
37pin_project! {
38    /// A [`futures_util`] compatible WebSocket stream.
39    pub struct CompatWebSocketStream<S> {
40        #[pin]
41        inner: WebSocket<SyncStream<S>>,
42        read_future: Option<PinBoxFuture<Result<usize, std::io::Error>>>,
43        write_future: Option<PinBoxFuture<Result<usize, std::io::Error>>>,
44        ready_waker: Option<Waker>,
45        flush_waker: Option<Waker>,
46        close_waker: Option<Waker>,
47        read_waker: Option<Waker>,
48        flushing: Flushing,
49        closing: Closing,
50        reading: Reading,
51        // This is a self-referential struct, so we need to prevent it from being `Unpin`.
52        #[pin]
53        _p: PhantomPinned,
54    }
55}
56
57impl<S> CompatWebSocketStream<S> {
58    pub(super) fn new(stream: WebSocket<SyncStream<S>>) -> Self {
59        Self {
60            inner: stream,
61            read_future: None,
62            write_future: None,
63            ready_waker: None,
64            flush_waker: None,
65            close_waker: None,
66            read_waker: None,
67            flushing: Flushing::None,
68            closing: Closing::None,
69            reading: Reading::None,
70            _p: PhantomPinned,
71        }
72    }
73}
74
75impl<S> Deref for CompatWebSocketStream<S> {
76    type Target = WebSocket<SyncStream<S>>;
77
78    fn deref(&self) -> &Self::Target {
79        &self.inner
80    }
81}
82
/// Polls the boxed future parked in `$slot`, creating it from `$make` when the
/// slot is empty.
///
/// - On `Poll::Pending` the future is put back into `$slot` and the
///   *enclosing function* returns `Poll::Pending`.
/// - On `Poll::Ready(out)` the slot is left empty and the macro evaluates to
///   `out`.
///
/// `$make` is only evaluated when no future is already parked.
macro_rules! poll_future {
    ($slot:expr, $cx:expr, $make:expr) => {{
        let mut parked = match $slot.take() {
            None => Box::pin($make),
            Some(existing) => existing,
        };
        let ::std::task::Poll::Ready(out) =
            ::std::future::Future::poll(parked.as_mut(), $cx)
        else {
            // Keep the in-flight future so the next poll resumes it.
            $slot.replace(parked);
            return ::std::task::Poll::Pending;
        };
        out
    }};
}
99
/// Re-borrows `t` with an unbounded (`'static`) lifetime.
///
/// # Safety
///
/// The caller must guarantee that:
/// - the referent outlives every use of the returned reference, and
/// - no other access to `*t` overlaps with uses of the returned reference
///   (it is still a unique `&mut`).
unsafe fn extend_lifetime<T>(t: &mut T) -> &'static mut T {
    let raw: *mut T = t;
    // SAFETY: upheld by the caller per the `# Safety` section above.
    unsafe { &mut *raw }
}
103
104impl<S: AsyncRead + AsyncWrite + Unpin + 'static> CompatWebSocketStream<S>
105where
106    for<'a> &'a S: AsyncRead + AsyncWrite,
107{
108    fn poll_flush_write_buf(self: Pin<&mut Self>) -> Poll<Result<usize, WsError>> {
109        let this = self.project();
110        // SAFETY:
111        // - The future won't live longer than the stream.
112        // - The stream is `Unpin`, and is internally mutable.
113        // - The future only accesses the corresponding buffer and fields.
114        //   - No access overlap between the futures.
115        // - The future is polled immediately after creation, so it takes the ownership
116        //   of the inner buffer.
117        // - The sync methods of `SyncStream` check if the inner buffer is already
118        //   borrowed, and returns `WouldBlock` if it is.
119        let inner: &'static mut SyncStream<S> =
120            unsafe { extend_lifetime(this.inner.get_mut().get_mut()) };
121        let arr = WakerArray([
122            this.ready_waker.as_ref().cloned(),
123            this.flush_waker.as_ref().cloned(),
124            this.close_waker.as_ref().cloned(),
125            this.read_waker.as_ref().cloned(),
126        ]);
127        let waker = Waker::from(Arc::new(arr));
128        let cx = &mut Context::from_waker(&waker);
129        let res = poll_future!(this.write_future, cx, inner.flush_write_buf());
130        Poll::Ready(res.map_err(WsError::Io))
131    }
132
133    fn poll_fill_read_buf(self: Pin<&mut Self>) -> Poll<Result<usize, WsError>> {
134        let this = self.project();
135        // SAFETY:
136        // - The future won't live longer than the stream.
137        // - The stream is `Unpin`, and is internally mutable.
138        // - The future only accesses the corresponding buffer and fields.
139        //   - No access overlap between the futures.
140        // - The future is polled immediately after creation, so it takes the ownership
141        //   of the inner buffer.
142        // - The sync methods of `SyncStream` check if the inner buffer is already
143        //   borrowed, and returns `WouldBlock` if it is.
144        let inner: &'static mut SyncStream<S> =
145            unsafe { extend_lifetime(this.inner.get_mut().get_mut()) };
146        let arr = WakerArray([
147            this.close_waker.as_ref().cloned(),
148            this.read_waker.as_ref().cloned(),
149        ]);
150        let waker = Waker::from(Arc::new(arr));
151        let cx = &mut Context::from_waker(&waker);
152        let res = poll_future!(this.read_future, cx, inner.fill_read_buf());
153        Poll::Ready(res.map_err(WsError::Io))
154    }
155
156    fn poll_flush_impl(mut self: Pin<&mut Self>) -> Poll<Result<(), WsError>> {
157        loop {
158            let mut this = self.as_mut().project();
159            match this.flushing {
160                Flushing::None => {
161                    *this.flushing = match this.inner.flush() {
162                        Ok(()) => Flushing::Flushed,
163                        Err(WsError::Io(e)) if e.kind() == std::io::ErrorKind::WouldBlock => {
164                            Flushing::WouldBlock
165                        }
166                        Err(WsError::ConnectionClosed) => Flushing::Flushed,
167                        Err(e) => return Poll::Ready(Err(e)),
168                    }
169                }
170                Flushing::WouldBlock => {
171                    ready!(self.as_mut().poll_flush_write_buf())?;
172                    *self.as_mut().project().flushing = Flushing::None
173                }
174                Flushing::Flushed => {
175                    ready!(self.as_mut().poll_flush_write_buf())?;
176                    let this = self.as_mut().project();
177                    *this.flushing = Flushing::None;
178                    this.flush_waker.take();
179                    return Poll::Ready(Ok(()));
180                }
181            }
182        }
183    }
184}
185
/// Stores a clone of `waker` in `waker_slot`, unless the slot already holds a
/// waker that `will_wake` the same task (avoids a needless clone).
fn replace_waker(waker_slot: &mut Option<Waker>, waker: &Waker) {
    let up_to_date = matches!(waker_slot, Some(current) if current.will_wake(waker));
    if !up_to_date {
        *waker_slot = Some(waker.clone());
    }
}
191
192impl<S: AsyncRead + AsyncWrite + Unpin + 'static> Sink<Message> for CompatWebSocketStream<S>
193where
194    for<'a> &'a S: AsyncRead + AsyncWrite,
195{
196    type Error = tungstenite::Error;
197
198    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
199        if self.write_future.is_some() {
200            replace_waker(self.as_mut().project().ready_waker, cx.waker());
201            ready!(self.as_mut().poll_flush_write_buf())?;
202            self.as_mut().project().ready_waker.take();
203        }
204        Poll::Ready(Ok(()))
205    }
206
207    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
208        match self.project().inner.write(item) {
209            Ok(()) => Ok(()),
210            Err(WsError::Io(e)) if e.kind() == std::io::ErrorKind::WouldBlock => Ok(()),
211            Err(e) => Err(e),
212        }
213    }
214
215    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
216        replace_waker(self.as_mut().project().flush_waker, cx.waker());
217        self.poll_flush_impl()
218    }
219
220    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
221        replace_waker(self.as_mut().project().close_waker, cx.waker());
222        loop {
223            let mut this = self.as_mut().project();
224            match this.closing {
225                Closing::None => {
226                    *this.closing = match this.inner.close(None) {
227                        Ok(()) => Closing::Closed,
228                        Err(WsError::Io(e)) if e.kind() == std::io::ErrorKind::WouldBlock => {
229                            Closing::WouldBlockFlush
230                        }
231                        Err(WsError::ConnectionClosed) => Closing::Closed,
232                        Err(e) => return Poll::Ready(Err(e)),
233                    }
234                }
235                Closing::WouldBlockFlush => {
236                    let flushed = ready!(self.as_mut().poll_flush_write_buf())?;
237                    *self.as_mut().project().closing = if flushed == 0 {
238                        Closing::WouldBlockFill
239                    } else {
240                        Closing::None
241                    }
242                }
243                Closing::WouldBlockFill => {
244                    ready!(self.as_mut().poll_fill_read_buf())?;
245                    *self.as_mut().project().closing = Closing::None;
246                }
247                Closing::Closed => {
248                    ready!(self.as_mut().poll_flush_impl())?;
249                    let this = self.as_mut().project();
250                    *this.closing = Closing::None;
251                    this.close_waker.take();
252                    return Poll::Ready(Ok(()));
253                }
254            }
255        }
256    }
257}
258
259impl<S: AsyncRead + AsyncWrite + Unpin + 'static> Stream for CompatWebSocketStream<S>
260where
261    for<'a> &'a S: AsyncRead + AsyncWrite,
262{
263    type Item = Result<Message, WsError>;
264
265    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
266        replace_waker(self.as_mut().project().read_waker, cx.waker());
267        loop {
268            let mut this = self.as_mut().project();
269            match std::mem::replace(this.reading, Reading::None) {
270                Reading::None => {
271                    *this.reading = match this.inner.read() {
272                        Ok(msg) => Reading::AfterRead(Ok(msg)),
273                        Err(WsError::Io(e)) if e.kind() == std::io::ErrorKind::WouldBlock => {
274                            Reading::WouldBlock
275                        }
276                        Err(WsError::AlreadyClosed | WsError::ConnectionClosed) => {
277                            return Poll::Ready(None);
278                        }
279                        Err(e) => Reading::AfterRead(Err(e)),
280                    }
281                }
282                Reading::WouldBlock => {
283                    ready!(self.as_mut().poll_fill_read_buf())?;
284                }
285                Reading::AfterRead(res) => {
286                    let res = match self.as_mut().poll_flush_impl() {
287                        Poll::Pending => res,
288                        Poll::Ready(Ok(())) => res,
289                        Poll::Ready(Err(e)) => {
290                            if let Err(ori_e) = res {
291                                Err(ori_e)
292                            } else {
293                                Err(e)
294                            }
295                        }
296                    };
297                    self.as_mut().project().read_waker.take();
298                    return Poll::Ready(Some(res));
299                }
300            }
301        }
302    }
303}
304
/// A [`Wake`] implementation that fans a single wake-up out to several wakers.
///
/// Empty (`None`) slots are skipped.
struct WakerArray<const N: usize>([Option<Waker>; N]);

impl<const N: usize> Wake for WakerArray<N> {
    fn wake(self: Arc<Self>) {
        Self::wake_by_ref(&self);
    }

    /// Wakes every stored waker without consuming the `Arc`.
    ///
    /// Overridden because the default `wake_by_ref` clones the `Arc` just to
    /// call `wake`, an avoidable refcount round-trip.
    fn wake_by_ref(self: &Arc<Self>) {
        for waker in self.0.iter().flatten() {
            waker.wake_by_ref();
        }
    }
}