1use std::{
2 marker::PhantomPinned,
3 ops::Deref,
4 pin::Pin,
5 sync::Arc,
6 task::{Context, Poll, Wake, Waker, ready},
7};
8
9use compio_io::{AsyncRead, AsyncWrite, compat::SyncStream};
10use futures_util::{Sink, Stream};
11use pin_project_lite::pin_project;
12use tungstenite::{Message, WebSocket};
13
14use crate::WsError;
15
/// Type-erased, heap-pinned future producing `T`.
///
/// The trait object is explicitly `'static` (the default for `Box<dyn _>`),
/// so a future that borrows from the stream can only be stored here once the
/// borrow has been given a `'static` lifetime.
type PinBoxFuture<T> = Pin<Box<dyn Future<Output = T> + 'static>>;
17
/// State of the flush state machine driven by `poll_flush_impl`.
enum Flushing {
    /// No flush in progress; the next poll calls `WebSocket::flush`.
    None,
    /// `WebSocket::flush` reported `WouldBlock`: drive the buffered write
    /// (`poll_flush_write_buf`), then call `WebSocket::flush` again.
    WouldBlock,
    /// `WebSocket::flush` succeeded (or reported `ConnectionClosed`); one more
    /// `poll_flush_write_buf` finishes the flush and the state resets to `None`.
    Flushed,
}
23
/// State of the close state machine driven by `Sink::poll_close`.
enum Closing {
    /// The next poll calls `WebSocket::close(None)`.
    None,
    /// `close` reported `WouldBlock`: drive `poll_flush_write_buf`. A result
    /// of `0` switches to `WouldBlockFill`; anything else retries `close`.
    WouldBlockFlush,
    /// Drive `poll_fill_read_buf`, then retry `close`.
    WouldBlockFill,
    /// `close` succeeded (or reported `ConnectionClosed`); finish with
    /// `poll_flush_impl`, then reset to `None`.
    Closed,
}
30
/// State of the read state machine driven by `Stream::poll_next`.
enum Reading {
    /// The next poll calls `WebSocket::read`.
    None,
    /// `read` produced this message or error; it is yielded after a flush
    /// attempt.
    AfterRead(Result<Message, WsError>),
    /// `read` reported `WouldBlock`: drive `poll_fill_read_buf`, then retry
    /// the read.
    WouldBlock,
}
36
37pin_project! {
38 pub struct CompatWebSocketStream<S> {
40 #[pin]
41 inner: WebSocket<SyncStream<S>>,
42 read_future: Option<PinBoxFuture<Result<usize, std::io::Error>>>,
43 write_future: Option<PinBoxFuture<Result<usize, std::io::Error>>>,
44 ready_waker: Option<Waker>,
45 flush_waker: Option<Waker>,
46 close_waker: Option<Waker>,
47 read_waker: Option<Waker>,
48 flushing: Flushing,
49 closing: Closing,
50 reading: Reading,
51 #[pin]
53 _p: PhantomPinned,
54 }
55}
56
57impl<S> CompatWebSocketStream<S> {
58 pub(super) fn new(stream: WebSocket<SyncStream<S>>) -> Self {
59 Self {
60 inner: stream,
61 read_future: None,
62 write_future: None,
63 ready_waker: None,
64 flush_waker: None,
65 close_waker: None,
66 read_waker: None,
67 flushing: Flushing::None,
68 closing: Closing::None,
69 reading: Reading::None,
70 _p: PhantomPinned,
71 }
72 }
73}
74
75impl<S> Deref for CompatWebSocketStream<S> {
76 type Target = WebSocket<SyncStream<S>>;
77
78 fn deref(&self) -> &Self::Target {
79 &self.inner
80 }
81}
82
/// Poll the boxed future cached in the `Option` place `$f`, building it from
/// `$e` when the slot is empty.
///
/// - Ready: the macro evaluates to the future's output and `$f` is left
///   `None`.
/// - Pending: the future is put back into `$f` and the *enclosing function*
///   returns `Poll::Pending`, so this macro can only be used in functions
///   that return `Poll<_>`.
///
/// `$e` is evaluated only when a new future has to be created.
macro_rules! poll_future {
    ($f:expr, $cx:expr, $e:expr) => {{
        let mut pinned = if let Some(existing) = $f.take() {
            existing
        } else {
            Box::pin($e)
        };
        if let Poll::Ready(output) = pinned.as_mut().poll($cx) {
            output
        } else {
            $f.replace(pinned);
            return Poll::Pending;
        }
    }};
}
99
/// Re-borrow `t` with a `'static` lifetime.
///
/// # Safety
///
/// The borrow checker no longer tracks the returned reference. The caller
/// must guarantee that `*t` stays alive and at the same address for as long
/// as the returned reference (or anything that captured it) is used, and that
/// no other access to `*t` overlaps with uses of the returned reference.
unsafe fn extend_lifetime<T>(t: &mut T) -> &'static mut T {
    let raw: *mut T = t;
    // SAFETY: liveness, address stability and exclusivity of `*raw` for the
    // extended lifetime are the caller's obligation (see `# Safety`).
    unsafe { &mut *raw }
}
103
impl<S: AsyncRead + AsyncWrite + Unpin + 'static> CompatWebSocketStream<S>
where
    for<'a> &'a S: AsyncRead + AsyncWrite,
{
    /// Drive (creating it if needed) the cached future that runs
    /// `SyncStream::flush_write_buf`, pushing the bytes buffered by
    /// tungstenite to the underlying stream.
    ///
    /// Called from `poll_ready`, `poll_flush_impl` and `poll_close`. The
    /// future is polled with a fan-out waker that wakes every task currently
    /// parked in `ready_waker`, `flush_waker`, `close_waker` and `read_waker`,
    /// so whichever of them is waiting on this write gets woken. A pending
    /// future stays in `write_future` and is resumed by the next call.
    ///
    /// Returns the `usize` produced by `flush_write_buf` (NOTE(review):
    /// presumably the number of bytes written — confirm against compio), with
    /// I/O errors wrapped as `WsError::Io`.
    fn poll_flush_write_buf(self: Pin<&mut Self>) -> Poll<Result<usize, WsError>> {
        let this = self.project();
        // SAFETY: `self` is pinned and the type is `!Unpin` (`PhantomPinned`),
        // so `inner` keeps its address while `self` is alive; the extended
        // reference is captured only by the future cached in `write_future`.
        // NOTE(review): soundness also relies on that future being dropped
        // before `inner`, and on no other `&mut` access to the `SyncStream`
        // (the direct `inner.flush()` / `read()` / `close()` calls on the
        // `WebSocket`) overlapping a pending future — confirm.
        let inner: &'static mut SyncStream<S> =
            unsafe { extend_lifetime(this.inner.get_mut().get_mut()) };
        let arr = WakerArray([
            this.ready_waker.as_ref().cloned(),
            this.flush_waker.as_ref().cloned(),
            this.close_waker.as_ref().cloned(),
            this.read_waker.as_ref().cloned(),
        ]);
        let waker = Waker::from(Arc::new(arr));
        let cx = &mut Context::from_waker(&waker);
        let res = poll_future!(this.write_future, cx, inner.flush_write_buf());
        Poll::Ready(res.map_err(WsError::Io))
    }

    /// Drive (creating it if needed) the cached future that runs
    /// `SyncStream::fill_read_buf`, reading more bytes from the underlying
    /// stream into the buffer tungstenite reads from.
    ///
    /// Called from `poll_close` and `poll_next`, so the fan-out waker only
    /// wakes `close_waker` and `read_waker`. A pending future stays in
    /// `read_future` and is resumed by the next call from either path.
    fn poll_fill_read_buf(self: Pin<&mut Self>) -> Poll<Result<usize, WsError>> {
        let this = self.project();
        // SAFETY: same argument as in `poll_flush_write_buf`, with the
        // extended reference captured only by the future in `read_future`.
        // NOTE(review): the same drop-order and aliasing caveats apply.
        let inner: &'static mut SyncStream<S> =
            unsafe { extend_lifetime(this.inner.get_mut().get_mut()) };
        let arr = WakerArray([
            this.close_waker.as_ref().cloned(),
            this.read_waker.as_ref().cloned(),
        ]);
        let waker = Waker::from(Arc::new(arr));
        let cx = &mut Context::from_waker(&waker);
        let res = poll_future!(this.read_future, cx, inner.fill_read_buf());
        Poll::Ready(res.map_err(WsError::Io))
    }

    /// Flush state machine shared by `Sink::poll_flush`, `Sink::poll_close`
    /// (from `Closing::Closed`) and `Stream::poll_next` (after a read).
    ///
    /// `Flushing::None` asks tungstenite to flush. `WouldBlock` drives the
    /// buffered write and retries. `Flushed` drives the buffered write one
    /// last time, then resets to `None`, clears `flush_waker` and returns
    /// `Ok(())`. `ConnectionClosed` from tungstenite is treated like a
    /// successful flush. This function registers no waker itself; callers
    /// park theirs first and are reached through the fan-out waker in
    /// `poll_flush_write_buf`.
    fn poll_flush_impl(mut self: Pin<&mut Self>) -> Poll<Result<(), WsError>> {
        loop {
            let mut this = self.as_mut().project();
            match this.flushing {
                Flushing::None => {
                    *this.flushing = match this.inner.flush() {
                        Ok(()) => Flushing::Flushed,
                        Err(WsError::Io(e)) if e.kind() == std::io::ErrorKind::WouldBlock => {
                            Flushing::WouldBlock
                        }
                        Err(WsError::ConnectionClosed) => Flushing::Flushed,
                        Err(e) => return Poll::Ready(Err(e)),
                    }
                }
                Flushing::WouldBlock => {
                    // Push out what is buffered, then let tungstenite try again.
                    ready!(self.as_mut().poll_flush_write_buf())?;
                    *self.as_mut().project().flushing = Flushing::None
                }
                Flushing::Flushed => {
                    // tungstenite is done; make sure the buffered bytes reach the
                    // underlying stream before reporting success.
                    ready!(self.as_mut().poll_flush_write_buf())?;
                    let this = self.as_mut().project();
                    *this.flushing = Flushing::None;
                    this.flush_waker.take();
                    return Poll::Ready(Ok(()));
                }
            }
        }
    }
}
185
/// Park `waker` in `waker_slot`.
///
/// If the slot already holds a waker that would wake the same task
/// (`Waker::will_wake`), it is kept, which avoids a needless clone.
fn replace_waker(waker_slot: &mut Option<Waker>, waker: &Waker) {
    if let Some(current) = waker_slot {
        if current.will_wake(waker) {
            return;
        }
    }
    *waker_slot = Some(waker.clone());
}
191
impl<S: AsyncRead + AsyncWrite + Unpin + 'static> Sink<Message> for CompatWebSocketStream<S>
where
    for<'a> &'a S: AsyncRead + AsyncWrite,
{
    type Error = tungstenite::Error;

    /// The sink is ready once any buffered write left pending by an earlier
    /// poll has completed.
    ///
    /// Only an existing `write_future` is driven here; when there is none the
    /// sink is immediately ready. While waiting, the caller's waker is parked
    /// in `ready_waker`, and it is cleared once the write completes.
    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        if self.write_future.is_some() {
            replace_waker(self.as_mut().project().ready_waker, cx.waker());
            ready!(self.as_mut().poll_flush_write_buf())?;
            self.as_mut().project().ready_waker.take();
        }
        Poll::Ready(Ok(()))
    }

    /// Hand `item` to tungstenite's `WebSocket::write`.
    ///
    /// A `WouldBlock` I/O error is reported as success (NOTE(review): assumes
    /// tungstenite has already queued the frame when it reports `WouldBlock` —
    /// confirm for the tungstenite version in use); the bytes are pushed out
    /// by `poll_flush`. Any other error is returned unchanged.
    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        match self.project().inner.write(item) {
            Ok(()) => Ok(()),
            Err(WsError::Io(e)) if e.kind() == std::io::ErrorKind::WouldBlock => Ok(()),
            Err(e) => Err(e),
        }
    }

    /// Park the caller in `flush_waker` and run the shared flush state
    /// machine (`poll_flush_impl`).
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        replace_waker(self.as_mut().project().flush_waker, cx.waker());
        self.poll_flush_impl()
    }

    /// Close the WebSocket and flush everything out.
    ///
    /// State machine over `closing`:
    /// - `None`: call `WebSocket::close(None)`. Success or `ConnectionClosed`
    ///   goes to `Closed`; `WouldBlock` goes to `WouldBlockFlush`.
    /// - `WouldBlockFlush`: drive the buffered write. A result of `0` moves to
    ///   `WouldBlockFill`; otherwise `close` is retried.
    /// - `WouldBlockFill`: drive a read-buffer refill, then retry `close`.
    /// - `Closed`: finish with `poll_flush_impl`, reset to `None`, clear
    ///   `close_waker` and return `Ok(())`.
    ///
    /// The caller's waker is parked in `close_waker` for the duration.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        replace_waker(self.as_mut().project().close_waker, cx.waker());
        loop {
            let mut this = self.as_mut().project();
            match this.closing {
                Closing::None => {
                    *this.closing = match this.inner.close(None) {
                        Ok(()) => Closing::Closed,
                        Err(WsError::Io(e)) if e.kind() == std::io::ErrorKind::WouldBlock => {
                            Closing::WouldBlockFlush
                        }
                        Err(WsError::ConnectionClosed) => Closing::Closed,
                        Err(e) => return Poll::Ready(Err(e)),
                    }
                }
                Closing::WouldBlockFlush => {
                    let flushed = ready!(self.as_mut().poll_flush_write_buf())?;
                    // Nothing was written: fall back to filling the read buffer
                    // before retrying `close`.
                    *self.as_mut().project().closing = if flushed == 0 {
                        Closing::WouldBlockFill
                    } else {
                        Closing::None
                    }
                }
                Closing::WouldBlockFill => {
                    ready!(self.as_mut().poll_fill_read_buf())?;
                    *self.as_mut().project().closing = Closing::None;
                }
                Closing::Closed => {
                    ready!(self.as_mut().poll_flush_impl())?;
                    let this = self.as_mut().project();
                    *this.closing = Closing::None;
                    this.close_waker.take();
                    return Poll::Ready(Ok(()));
                }
            }
        }
    }
}
258
259impl<S: AsyncRead + AsyncWrite + Unpin + 'static> Stream for CompatWebSocketStream<S>
260where
261 for<'a> &'a S: AsyncRead + AsyncWrite,
262{
263 type Item = Result<Message, WsError>;
264
265 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
266 replace_waker(self.as_mut().project().read_waker, cx.waker());
267 loop {
268 let mut this = self.as_mut().project();
269 match std::mem::replace(this.reading, Reading::None) {
270 Reading::None => {
271 *this.reading = match this.inner.read() {
272 Ok(msg) => Reading::AfterRead(Ok(msg)),
273 Err(WsError::Io(e)) if e.kind() == std::io::ErrorKind::WouldBlock => {
274 Reading::WouldBlock
275 }
276 Err(WsError::AlreadyClosed | WsError::ConnectionClosed) => {
277 return Poll::Ready(None);
278 }
279 Err(e) => Reading::AfterRead(Err(e)),
280 }
281 }
282 Reading::WouldBlock => {
283 ready!(self.as_mut().poll_fill_read_buf())?;
284 }
285 Reading::AfterRead(res) => {
286 let res = match self.as_mut().poll_flush_impl() {
287 Poll::Pending => res,
288 Poll::Ready(Ok(())) => res,
289 Poll::Ready(Err(e)) => {
290 if let Err(ori_e) = res {
291 Err(ori_e)
292 } else {
293 Err(e)
294 }
295 }
296 };
297 self.as_mut().project().read_waker.take();
298 return Poll::Ready(Some(res));
299 }
300 }
301 }
302 }
303}
304
/// Fan-out waker over up to `N` optional wakers.
///
/// Waking it calls `wake_by_ref` on every `Some` entry, in array order;
/// `None` slots are skipped.
struct WakerArray<const N: usize>([Option<Waker>; N]);

impl<const N: usize> Wake for WakerArray<N> {
    fn wake(self: Arc<Self>) {
        for waker in self.0.iter().flatten() {
            waker.wake_by_ref();
        }
    }
}