1#![cfg_attr(docsrs, feature(doc_cfg))]
13#![allow(unused_features)]
14#![warn(missing_docs)]
15#![deny(rustdoc::broken_intra_doc_links)]
16#![doc(
17 html_logo_url = "https://github.com/compio-rs/compio-logo/raw/refs/heads/master/generated/colored-bold.svg"
18)]
19#![doc(
20 html_favicon_url = "https://github.com/compio-rs/compio-logo/raw/refs/heads/master/generated/colored-bold.svg"
21)]
22
23use std::{
24 pin::Pin,
25 task::{Context, Poll, ready},
26};
27
28use compio_buf::IntoInner;
29use compio_io::{AsyncRead, AsyncWrite, compat::AsyncStream, util::Splittable};
30use compio_tls::{MaybeTlsStream, TlsStream};
31use futures_util::{Sink, SinkExt, Stream, StreamExt, stream::FusedStream};
32use pin_project_lite::pin_project;
33use tungstenite::{
34 Error as WsError, Message,
35 client::IntoClientRequest,
36 handshake::server::{Callback, NoCallback},
37 protocol::{CloseFrame, Role, WebSocketConfig},
38};
39
40#[cfg(feature = "connect")]
41mod tls;
42#[cfg(feature = "connect")]
43pub use tls::*;
44pub use tungstenite;
45
46pub struct Config {
57 websocket: Option<WebSocketConfig>,
59
60 buffer_size_base: usize,
62
63 buffer_size_limit: usize,
65
66 disable_nagle: bool,
69}
70
71impl Config {
72 const DEFAULT_BUF_SIZE: usize = 128 * 1024;
74 const DEFAULT_MAX_BUFFER: usize = 64 * 1024 * 1024;
76
77 pub fn new() -> Self {
79 Self {
80 websocket: None,
81 buffer_size_base: Self::DEFAULT_BUF_SIZE,
82 buffer_size_limit: Self::DEFAULT_MAX_BUFFER,
83 disable_nagle: false,
84 }
85 }
86
87 pub fn websocket_config(&self) -> Option<&WebSocketConfig> {
89 self.websocket.as_ref()
90 }
91
92 pub fn buffer_size_base(&self) -> usize {
94 self.buffer_size_base
95 }
96
97 pub fn buffer_size_limit(&self) -> usize {
99 self.buffer_size_limit
100 }
101
102 pub fn with_buffer_size_base(mut self, size: usize) -> Self {
106 self.buffer_size_base = size;
107 self
108 }
109
110 pub fn with_buffer_size_limit(mut self, size: usize) -> Self {
114 self.buffer_size_limit = size;
115 self
116 }
117
118 pub fn with_buffer_sizes(mut self, base: usize, limit: usize) -> Self {
122 self.buffer_size_base = base;
123 self.buffer_size_limit = limit;
124 self
125 }
126
127 pub fn disable_nagle(mut self, disable: bool) -> Self {
132 self.disable_nagle = disable;
133 self
134 }
135}
136
137impl Default for Config {
138 fn default() -> Self {
139 Self::new()
140 }
141}
142
143impl From<WebSocketConfig> for Config {
144 fn from(config: WebSocketConfig) -> Self {
145 Self {
146 websocket: Some(config),
147 ..Default::default()
148 }
149 }
150}
151
152impl From<Option<WebSocketConfig>> for Config {
153 fn from(config: Option<WebSocketConfig>) -> Self {
154 Self {
155 websocket: config,
156 ..Default::default()
157 }
158 }
159}
160
161mod private {
162 use super::*;
163
164 pub trait Sealed<S>
165 where
166 S: Splittable,
167 {
168 }
169
170 impl<S: Splittable> Sealed<S> for S {}
171 impl<S: Splittable> Sealed<S> for AsyncStream<S> {}
172 impl<S: Splittable> Sealed<S> for MaybeTlsStream<S> {}
173 impl<S: Splittable> Sealed<S> for TlsStream<S> {}
174}
175
176pub trait IntoMaybeTlsStream<S>: private::Sealed<S>
178where
179 S: Splittable,
180{
181 fn into_maybe_tls_stream(self, capacity: usize, max_buffer_size: usize) -> MaybeTlsStream<S>;
183}
184
185impl<S: Splittable> IntoMaybeTlsStream<S> for S {
186 fn into_maybe_tls_stream(self, capacity: usize, max_buffer_size: usize) -> MaybeTlsStream<S> {
187 MaybeTlsStream::new_plain_compat(AsyncStream::with_limits(capacity, max_buffer_size, self))
188 }
189}
190
191impl<S: Splittable> IntoMaybeTlsStream<S> for AsyncStream<S> {
192 fn into_maybe_tls_stream(self, _: usize, _: usize) -> MaybeTlsStream<S> {
193 MaybeTlsStream::new_plain_compat(self)
194 }
195}
196
197impl<S: Splittable> IntoMaybeTlsStream<S> for MaybeTlsStream<S> {
198 fn into_maybe_tls_stream(self, _: usize, _: usize) -> MaybeTlsStream<S> {
199 self
200 }
201}
202
203impl<S: Splittable> IntoMaybeTlsStream<S> for TlsStream<S> {
204 fn into_maybe_tls_stream(self, _: usize, _: usize) -> MaybeTlsStream<S> {
205 MaybeTlsStream::new_tls(self)
206 }
207}
208
209pin_project! {
210 #[derive(Debug)]
212 pub struct WebSocketStream<S: Splittable> {
213 #[pin]
214 inner: async_tungstenite::WebSocketStream<MaybeTlsStream<S>>,
215 next_item: Option<Option<Result<Message, WsError>>>,
216 }
217}
218
219impl<S: Splittable + 'static> WebSocketStream<S>
220where
221 S::ReadHalf: AsyncRead + Unpin,
222 S::WriteHalf: AsyncWrite + Unpin,
223{
224 pub fn get_ref(&self) -> &MaybeTlsStream<S> {
226 self.inner.get_ref()
227 }
228
229 pub fn get_mut(&mut self) -> &mut MaybeTlsStream<S> {
231 self.inner.get_mut()
232 }
233
234 pub async fn from_raw_socket<T: IntoMaybeTlsStream<S>>(
241 stream: T,
242 role: Role,
243 config: impl Into<Config>,
244 ) -> Self {
245 let config = config.into();
246
247 Self::from_inner(
248 async_tungstenite::WebSocketStream::from_raw_socket(
249 stream.into_maybe_tls_stream(config.buffer_size_base, config.buffer_size_limit),
250 role,
251 config.websocket,
252 )
253 .await,
254 )
255 }
256
257 pub async fn from_partially_read<T: IntoMaybeTlsStream<S>>(
264 stream: T,
265 part: Vec<u8>,
266 role: Role,
267 config: impl Into<Config>,
268 ) -> Self {
269 let config = config.into();
270
271 Self::from_inner(
272 async_tungstenite::WebSocketStream::from_partially_read(
273 stream.into_maybe_tls_stream(config.buffer_size_base, config.buffer_size_limit),
274 part,
275 role,
276 config.websocket,
277 )
278 .await,
279 )
280 }
281
282 fn from_inner(inner: async_tungstenite::WebSocketStream<MaybeTlsStream<S>>) -> Self {
283 WebSocketStream {
284 inner,
285 next_item: None,
286 }
287 }
288
289 pub async fn send(&mut self, message: Message) -> Result<(), WsError> {
291 SinkExt::send(self, message).await
292 }
293
294 pub async fn read(&mut self) -> Result<Message, WsError> {
296 self.next()
297 .await
298 .unwrap_or_else(|| Err(WsError::ConnectionClosed))
299 }
300
301 pub async fn flush(&mut self) -> Result<(), WsError> {
303 SinkExt::flush(self).await
304 }
305
306 pub async fn close(&mut self, close_frame: Option<CloseFrame>) -> Result<(), WsError> {
308 self.send(Message::Close(close_frame)).await
309 }
310}
311
312impl<S: Splittable> IntoInner for WebSocketStream<S> {
313 type Inner = MaybeTlsStream<S>;
314
315 fn into_inner(self) -> Self::Inner {
316 self.inner.into_inner()
317 }
318}
319
320impl<S: Splittable + 'static> Sink<Message> for WebSocketStream<S>
321where
322 S::ReadHalf: AsyncRead + Unpin,
323 S::WriteHalf: AsyncWrite + Unpin,
324{
325 type Error = WsError;
326
327 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
328 self.project().inner.poll_ready(cx)
329 }
330
331 fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
332 self.project().inner.start_send(item)
333 }
334
335 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
336 ready!(self.as_mut().project().inner.poll_flush(cx))?;
337 ready!(futures_util::AsyncWrite::poll_flush(
338 Pin::new(self.project().inner.get_mut().get_mut()),
339 cx
340 ))?;
341 Poll::Ready(Ok(()))
342 }
343
344 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
345 self.project().inner.poll_close(cx)
346 }
347}
348
349impl<S: Splittable + 'static> Stream for WebSocketStream<S>
350where
351 S::ReadHalf: AsyncRead + Unpin,
352 S::WriteHalf: AsyncWrite + Unpin,
353{
354 type Item = Result<Message, WsError>;
355
356 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
357 let mut this = self.project();
358 loop {
359 if this.next_item.is_some() {
360 ready!(this.inner.as_mut().poll_flush(cx))?;
361 ready!(futures_util::AsyncWrite::poll_flush(
362 Pin::new(this.inner.get_mut().get_mut()),
363 cx
364 ))?;
365 break Poll::Ready(this.next_item.take().expect("next_item should be Some"));
366 } else {
367 let item = ready!(this.inner.as_mut().poll_next(cx));
368 *this.next_item = Some(item);
369 }
370 }
371 }
372}
373
374impl<S: Splittable + 'static> FusedStream for WebSocketStream<S>
375where
376 S::ReadHalf: AsyncRead + Unpin,
377 S::WriteHalf: AsyncWrite + Unpin,
378{
379 fn is_terminated(&self) -> bool {
380 self.inner.is_terminated()
381 }
382}
383
384pub async fn accept_async<S, T>(stream: T) -> Result<WebSocketStream<S>, WsError>
395where
396 S: Splittable + 'static,
397 <S as Splittable>::ReadHalf: AsyncRead + Unpin,
398 <S as Splittable>::WriteHalf: AsyncWrite + Unpin,
399 T: IntoMaybeTlsStream<S>,
400{
401 accept_hdr_async(stream, NoCallback).await
402}
403
404pub async fn accept_async_with_config<S, T>(
406 stream: T,
407 config: impl Into<Config>,
408) -> Result<WebSocketStream<S>, WsError>
409where
410 S: Splittable + 'static,
411 <S as Splittable>::ReadHalf: AsyncRead + Unpin,
412 <S as Splittable>::WriteHalf: AsyncWrite + Unpin,
413 T: IntoMaybeTlsStream<S>,
414{
415 accept_hdr_with_config_async(stream, NoCallback, config).await
416}
417
418pub async fn accept_hdr_async<S, T, C>(
424 stream: T,
425 callback: C,
426) -> Result<WebSocketStream<S>, WsError>
427where
428 S: Splittable + 'static,
429 T: IntoMaybeTlsStream<S>,
430 C: Callback + Unpin,
431 <S as Splittable>::ReadHalf: AsyncRead + Unpin,
432 <S as Splittable>::WriteHalf: AsyncWrite + Unpin,
433{
434 accept_hdr_with_config_async(stream, callback, None).await
435}
436
437pub async fn accept_hdr_with_config_async<S, T, C>(
439 stream: T,
440 callback: C,
441 config: impl Into<Config>,
442) -> Result<WebSocketStream<S>, WsError>
443where
444 S: Splittable + 'static,
445 T: IntoMaybeTlsStream<S>,
446 C: Callback + Unpin,
447 <S as Splittable>::ReadHalf: AsyncRead + Unpin,
448 <S as Splittable>::WriteHalf: AsyncWrite + Unpin,
449{
450 let config = config.into();
451 let inner = async_tungstenite::accept_hdr_async_with_config(
452 stream.into_maybe_tls_stream(config.buffer_size_base, config.buffer_size_limit),
453 callback,
454 config.websocket,
455 )
456 .await?;
457 Ok(WebSocketStream::from_inner(inner))
458}
459
460pub async fn client_async<R, S, T>(
474 request: R,
475 stream: T,
476) -> Result<(WebSocketStream<S>, tungstenite::handshake::client::Response), WsError>
477where
478 R: IntoClientRequest + Unpin,
479 S: Splittable + 'static,
480 <S as Splittable>::ReadHalf: AsyncRead + Unpin,
481 <S as Splittable>::WriteHalf: AsyncWrite + Unpin,
482 T: IntoMaybeTlsStream<S>,
483{
484 client_async_with_config(request, stream, None).await
485}
486
487pub async fn client_async_with_config<R, S, T>(
489 request: R,
490 stream: T,
491 config: impl Into<Config>,
492) -> Result<(WebSocketStream<S>, tungstenite::handshake::client::Response), WsError>
493where
494 R: IntoClientRequest + Unpin,
495 S: Splittable + 'static,
496 <S as Splittable>::ReadHalf: AsyncRead + Unpin,
497 <S as Splittable>::WriteHalf: AsyncWrite + Unpin,
498 T: IntoMaybeTlsStream<S>,
499{
500 let config = config.into();
501 let (inner, response) = async_tungstenite::client_async_with_config(
502 request,
503 stream.into_maybe_tls_stream(config.buffer_size_base, config.buffer_size_limit),
504 config.websocket,
505 )
506 .await?;
507 Ok((WebSocketStream::from_inner(inner), response))
508}