1#![cfg_attr(docsrs, feature(doc_cfg))]
13#![allow(unused_features)]
14#![warn(missing_docs)]
15#![deny(rustdoc::broken_intra_doc_links)]
16#![doc(
17 html_logo_url = "https://github.com/compio-rs/compio-logo/raw/refs/heads/master/generated/colored-bold.svg"
18)]
19#![doc(
20 html_favicon_url = "https://github.com/compio-rs/compio-logo/raw/refs/heads/master/generated/colored-bold.svg"
21)]
22
23use std::io::ErrorKind;
24
25use compio_buf::IntoInner;
26use compio_io::{AsyncRead, AsyncWrite, compat::SyncStream};
27use tungstenite::{
28 Error as WsError, HandshakeError, Message, WebSocket,
29 client::IntoClientRequest,
30 handshake::server::{Callback, NoCallback},
31 protocol::{CloseFrame, Role, WebSocketConfig},
32};
33
34mod tls;
35#[cfg(feature = "io-compat")]
36pub use compat::CompatWebSocketStream;
37pub use tls::*;
38pub use tungstenite;
39#[cfg(feature = "io-compat")]
40mod compat;
41
42pub struct Config {
53 websocket: Option<WebSocketConfig>,
55
56 buffer_size_base: usize,
58
59 buffer_size_limit: usize,
61
62 disable_nagle: bool,
65}
66
67impl Config {
68 const DEFAULT_BUF_SIZE: usize = 128 * 1024;
70 const DEFAULT_MAX_BUFFER: usize = 64 * 1024 * 1024;
72
73 pub fn new() -> Self {
75 Self {
76 websocket: None,
77 buffer_size_base: Self::DEFAULT_BUF_SIZE,
78 buffer_size_limit: Self::DEFAULT_MAX_BUFFER,
79 disable_nagle: false,
80 }
81 }
82
83 pub fn websocket_config(&self) -> Option<&WebSocketConfig> {
85 self.websocket.as_ref()
86 }
87
88 pub fn buffer_size_base(&self) -> usize {
90 self.buffer_size_base
91 }
92
93 pub fn buffer_size_limit(&self) -> usize {
95 self.buffer_size_limit
96 }
97
98 pub fn with_buffer_size_base(mut self, size: usize) -> Self {
102 self.buffer_size_base = size;
103 self
104 }
105
106 pub fn with_buffer_size_limit(mut self, size: usize) -> Self {
110 self.buffer_size_limit = size;
111 self
112 }
113
114 pub fn with_buffer_sizes(mut self, base: usize, limit: usize) -> Self {
118 self.buffer_size_base = base;
119 self.buffer_size_limit = limit;
120 self
121 }
122
123 pub fn disable_nagle(mut self, disable: bool) -> Self {
128 self.disable_nagle = disable;
129 self
130 }
131}
132
133impl Default for Config {
134 fn default() -> Self {
135 Self::new()
136 }
137}
138
139impl From<WebSocketConfig> for Config {
140 fn from(config: WebSocketConfig) -> Self {
141 Self {
142 websocket: Some(config),
143 ..Default::default()
144 }
145 }
146}
147
148impl From<Option<WebSocketConfig>> for Config {
149 fn from(config: Option<WebSocketConfig>) -> Self {
150 Self {
151 websocket: config,
152 ..Default::default()
153 }
154 }
155}
156
/// Asynchronous WebSocket stream: a tungstenite `WebSocket` driven over a
/// `SyncStream` that wraps the underlying stream `S`.
#[derive(Debug)]
pub struct WebSocketStream<S> {
    // The synchronous tungstenite state machine; its I/O goes through the
    // `SyncStream` adapter around `S`.
    inner: WebSocket<SyncStream<S>>,
}
162
163impl<S> WebSocketStream<S> {
164 pub fn get_ref(&self) -> &S {
166 self.inner.get_ref().get_ref()
167 }
168
169 pub fn get_mut(&mut self) -> &mut S {
171 self.inner.get_mut().get_mut()
172 }
173}
174
175impl<S> WebSocketStream<S>
176where
177 S: AsyncRead + AsyncWrite,
178{
179 pub async fn from_raw_socket(stream: S, role: Role, config: impl Into<Config>) -> Self {
186 let config = config.into();
187 let sync_stream =
188 SyncStream::with_limits(config.buffer_size_base, config.buffer_size_limit, stream);
189
190 WebSocketStream {
191 inner: WebSocket::from_raw_socket(sync_stream, role, config.websocket),
192 }
193 }
194
195 pub async fn from_partially_read(
202 stream: S,
203 part: Vec<u8>,
204 role: Role,
205 config: impl Into<Config>,
206 ) -> Self {
207 let config = config.into();
208 let sync_stream =
209 SyncStream::with_limits(config.buffer_size_base, config.buffer_size_limit, stream);
210
211 WebSocketStream {
212 inner: WebSocket::from_partially_read(sync_stream, part, role, config.websocket),
213 }
214 }
215
216 pub async fn send(&mut self, message: Message) -> Result<(), WsError> {
218 match self.inner.write(message) {
219 Ok(()) => {}
220 Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {}
221 Err(e) => return Err(e),
222 }
223 self.flush().await
225 }
226
227 pub async fn read(&mut self) -> Result<Message, WsError> {
229 loop {
230 match self.inner.read() {
231 Ok(msg) => {
232 self.flush().await?;
233 return Ok(msg);
234 }
235 Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
236 self.fill_read_buf().await?;
238 }
239 Err(e) => {
240 let _ = self.flush().await;
241 return Err(e);
242 }
243 }
244 }
245 }
246
247 pub async fn flush(&mut self) -> Result<(), WsError> {
249 loop {
250 match self.inner.flush() {
251 Ok(()) => break,
252 Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
253 self.flush_write_buf().await?;
254 }
255 Err(WsError::ConnectionClosed) => break,
256 Err(e) => return Err(e),
257 }
258 }
259 self.flush_write_buf().await?;
260 Ok(())
261 }
262
263 pub async fn close(&mut self, close_frame: Option<CloseFrame>) -> Result<(), WsError> {
265 loop {
266 match self.inner.close(close_frame.clone()) {
267 Ok(()) => break,
268 Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
269 let flushed = self.flush_write_buf().await?;
270 if flushed == 0 {
271 self.fill_read_buf().await?;
272 }
273 }
274 Err(WsError::ConnectionClosed) => break,
275 Err(e) => return Err(e),
276 }
277 }
278 self.flush().await
279 }
280
281 pub(crate) async fn flush_write_buf(&mut self) -> Result<usize, WsError> {
282 self.inner
283 .get_mut()
284 .flush_write_buf()
285 .await
286 .map_err(WsError::Io)
287 }
288
289 pub(crate) async fn fill_read_buf(&mut self) -> Result<usize, WsError> {
290 self.inner
291 .get_mut()
292 .fill_read_buf()
293 .await
294 .map_err(WsError::Io)
295 }
296
297 #[cfg(feature = "io-compat")]
299 pub fn into_compat(self) -> CompatWebSocketStream<S>
300 where
302 for<'a> &'a S: AsyncRead + AsyncWrite,
303 S: Unpin,
304 {
305 CompatWebSocketStream::new(self.inner)
306 }
307}
308
309impl<S> IntoInner for WebSocketStream<S> {
310 type Inner = WebSocket<SyncStream<S>>;
311
312 fn into_inner(self) -> Self::Inner {
313 self.inner
314 }
315}
316
317pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
328where
329 S: AsyncRead + AsyncWrite,
330{
331 accept_hdr_async(stream, NoCallback).await
332}
333
334pub async fn accept_async_with_config<S>(
336 stream: S,
337 config: impl Into<Config>,
338) -> Result<WebSocketStream<S>, WsError>
339where
340 S: AsyncRead + AsyncWrite,
341{
342 accept_hdr_with_config_async(stream, NoCallback, config).await
343}
344pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
350where
351 S: AsyncRead + AsyncWrite,
352 C: Callback,
353{
354 accept_hdr_with_config_async(stream, callback, None).await
355}
356
357pub async fn accept_hdr_with_config_async<S, C>(
359 stream: S,
360 callback: C,
361 config: impl Into<Config>,
362) -> Result<WebSocketStream<S>, WsError>
363where
364 S: AsyncRead + AsyncWrite,
365 C: Callback,
366{
367 let config = config.into();
368 let sync_stream =
369 SyncStream::with_limits(config.buffer_size_base, config.buffer_size_limit, stream);
370 let mut handshake_result =
371 tungstenite::accept_hdr_with_config(sync_stream, callback, config.websocket);
372
373 loop {
374 match handshake_result {
375 Ok(mut websocket) => {
376 websocket
377 .get_mut()
378 .flush_write_buf()
379 .await
380 .map_err(WsError::Io)?;
381 return Ok(WebSocketStream { inner: websocket });
382 }
383 Err(HandshakeError::Interrupted(mut mid_handshake)) => {
384 let sync_stream = mid_handshake.get_mut().get_mut();
385
386 sync_stream.flush_write_buf().await.map_err(WsError::Io)?;
387
388 sync_stream.fill_read_buf().await.map_err(WsError::Io)?;
389
390 handshake_result = mid_handshake.handshake();
391 }
392 Err(HandshakeError::Failure(error)) => {
393 return Err(error);
394 }
395 }
396 }
397}
398
399pub async fn client_async<R, S>(
413 request: R,
414 stream: S,
415) -> Result<(WebSocketStream<S>, tungstenite::handshake::client::Response), WsError>
416where
417 R: IntoClientRequest,
418 S: AsyncRead + AsyncWrite,
419{
420 client_async_with_config(request, stream, None).await
421}
422
423pub async fn client_async_with_config<R, S>(
425 request: R,
426 stream: S,
427 config: impl Into<Config>,
428) -> Result<(WebSocketStream<S>, tungstenite::handshake::client::Response), WsError>
429where
430 R: IntoClientRequest,
431 S: AsyncRead + AsyncWrite,
432{
433 let config = config.into();
434 let sync_stream =
435 SyncStream::with_limits(config.buffer_size_base, config.buffer_size_limit, stream);
436 let mut handshake_result =
437 tungstenite::client::client_with_config(request, sync_stream, config.websocket);
438
439 loop {
440 match handshake_result {
441 Ok((mut websocket, response)) => {
442 websocket
444 .get_mut()
445 .flush_write_buf()
446 .await
447 .map_err(WsError::Io)?;
448 return Ok((WebSocketStream { inner: websocket }, response));
449 }
450 Err(HandshakeError::Interrupted(mut mid_handshake)) => {
451 let sync_stream = mid_handshake.get_mut().get_mut();
452
453 sync_stream.flush_write_buf().await.map_err(WsError::Io)?;
455
456 sync_stream.fill_read_buf().await.map_err(WsError::Io)?;
457
458 handshake_result = mid_handshake.handshake();
459 }
460 Err(HandshakeError::Failure(error)) => {
461 return Err(error);
462 }
463 }
464 }
465}