1#![cfg_attr(docsrs, feature(doc_cfg))]
13#![warn(missing_docs)]
14#![deny(rustdoc::broken_intra_doc_links)]
15#![doc(
16 html_logo_url = "https://github.com/compio-rs/compio-logo/raw/refs/heads/master/generated/colored-bold.svg"
17)]
18#![doc(
19 html_favicon_url = "https://github.com/compio-rs/compio-logo/raw/refs/heads/master/generated/colored-bold.svg"
20)]
21
22use std::io::ErrorKind;
23
24use compio_buf::IntoInner;
25use compio_io::{AsyncRead, AsyncWrite, compat::SyncStream};
26use tungstenite::{
27 Error as WsError, HandshakeError, Message, WebSocket,
28 client::IntoClientRequest,
29 handshake::server::{Callback, NoCallback},
30 protocol::{CloseFrame, WebSocketConfig},
31};
32
33mod tls;
34pub use tls::*;
35pub use tungstenite;
36
37pub struct Config {
48 websocket: Option<WebSocketConfig>,
50
51 buffer_size_base: usize,
53
54 buffer_size_limit: usize,
56
57 disable_nagle: bool,
60}
61
62impl Config {
63 const DEFAULT_BUF_SIZE: usize = 128 * 1024;
65 const DEFAULT_MAX_BUFFER: usize = 64 * 1024 * 1024;
67
68 pub fn new() -> Self {
70 Self {
71 websocket: None,
72 buffer_size_base: Self::DEFAULT_BUF_SIZE,
73 buffer_size_limit: Self::DEFAULT_MAX_BUFFER,
74 disable_nagle: false,
75 }
76 }
77
78 pub fn websocket_config(&self) -> Option<&WebSocketConfig> {
80 self.websocket.as_ref()
81 }
82
83 pub fn buffer_size_base(&self) -> usize {
85 self.buffer_size_base
86 }
87
88 pub fn buffer_size_limit(&self) -> usize {
90 self.buffer_size_limit
91 }
92
93 pub fn with_buffer_size_base(mut self, size: usize) -> Self {
97 self.buffer_size_base = size;
98 self
99 }
100
101 pub fn with_buffer_size_limit(mut self, size: usize) -> Self {
105 self.buffer_size_limit = size;
106 self
107 }
108
109 pub fn with_buffer_sizes(mut self, base: usize, limit: usize) -> Self {
113 self.buffer_size_base = base;
114 self.buffer_size_limit = limit;
115 self
116 }
117
118 pub fn disable_nagle(mut self, disable: bool) -> Self {
123 self.disable_nagle = disable;
124 self
125 }
126}
127
128impl Default for Config {
129 fn default() -> Self {
130 Self::new()
131 }
132}
133
134impl From<WebSocketConfig> for Config {
135 fn from(config: WebSocketConfig) -> Self {
136 Self {
137 websocket: Some(config),
138 ..Default::default()
139 }
140 }
141}
142
143impl From<Option<WebSocketConfig>> for Config {
144 fn from(config: Option<WebSocketConfig>) -> Self {
145 Self {
146 websocket: config,
147 ..Default::default()
148 }
149 }
150}
151
/// An asynchronous WebSocket stream over a compio stream `S`.
///
/// It wraps tungstenite's synchronous [`WebSocket`] state machine, running it
/// over a [`SyncStream`] adapter. Whenever tungstenite reports that it would
/// block, the adapter's buffers are filled or flushed asynchronously.
///
/// Values are produced by this crate's handshake functions, such as
/// [`accept_async`] and [`client_async`].
#[derive(Debug)]
pub struct WebSocketStream<S> {
    // tungstenite protocol state over the buffered sync adapter around `S`.
    inner: WebSocket<SyncStream<S>>,
}
157
158impl<S> WebSocketStream<S>
159where
160 S: AsyncRead + AsyncWrite,
161{
162 pub async fn send(&mut self, message: Message) -> Result<(), WsError> {
164 self.inner.send(message)?;
167
168 self.flush().await
170 }
171
172 pub async fn read(&mut self) -> Result<Message, WsError> {
174 loop {
175 match self.inner.read() {
176 Ok(msg) => {
177 self.flush().await?;
178 return Ok(msg);
179 }
180 Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
181 self.inner
183 .get_mut()
184 .fill_read_buf()
185 .await
186 .map_err(WsError::Io)?;
187 }
188 Err(e) => {
189 let _ = self.flush().await;
190 return Err(e);
191 }
192 }
193 }
194 }
195
196 pub async fn flush(&mut self) -> Result<(), WsError> {
198 loop {
199 match self.inner.flush() {
200 Ok(()) => break,
201 Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
202 self.inner
203 .get_mut()
204 .flush_write_buf()
205 .await
206 .map_err(WsError::Io)?;
207 }
208 Err(WsError::ConnectionClosed) => break,
209 Err(e) => return Err(e),
210 }
211 }
212 self.inner
213 .get_mut()
214 .flush_write_buf()
215 .await
216 .map_err(WsError::Io)?;
217 Ok(())
218 }
219
220 pub async fn close(&mut self, close_frame: Option<CloseFrame>) -> Result<(), WsError> {
222 loop {
223 match self.inner.close(close_frame.clone()) {
224 Ok(()) => break,
225 Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
226 let sync_stream = self.inner.get_mut();
227
228 let flushed = sync_stream.flush_write_buf().await.map_err(WsError::Io)?;
229
230 if flushed == 0 {
231 sync_stream.fill_read_buf().await.map_err(WsError::Io)?;
232 }
233 }
234 Err(WsError::ConnectionClosed) => break,
235 Err(e) => return Err(e),
236 }
237 }
238 self.flush().await
239 }
240
241 pub fn get_ref(&self) -> &S {
243 self.inner.get_ref().get_ref()
244 }
245
246 pub fn get_mut(&mut self) -> &mut S {
248 self.inner.get_mut().get_mut()
249 }
250}
251
252impl<S> IntoInner for WebSocketStream<S> {
253 type Inner = WebSocket<SyncStream<S>>;
254
255 fn into_inner(self) -> Self::Inner {
256 self.inner
257 }
258}
259
260pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
271where
272 S: AsyncRead + AsyncWrite,
273{
274 accept_hdr_async(stream, NoCallback).await
275}
276
277pub async fn accept_async_with_config<S>(
279 stream: S,
280 config: impl Into<Config>,
281) -> Result<WebSocketStream<S>, WsError>
282where
283 S: AsyncRead + AsyncWrite,
284{
285 accept_hdr_with_config_async(stream, NoCallback, config).await
286}
287pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
293where
294 S: AsyncRead + AsyncWrite,
295 C: Callback,
296{
297 accept_hdr_with_config_async(stream, callback, None).await
298}
299
300pub async fn accept_hdr_with_config_async<S, C>(
302 stream: S,
303 callback: C,
304 config: impl Into<Config>,
305) -> Result<WebSocketStream<S>, WsError>
306where
307 S: AsyncRead + AsyncWrite,
308 C: Callback,
309{
310 let config = config.into();
311 let sync_stream =
312 SyncStream::with_limits(config.buffer_size_base, config.buffer_size_limit, stream);
313 let mut handshake_result =
314 tungstenite::accept_hdr_with_config(sync_stream, callback, config.websocket);
315
316 loop {
317 match handshake_result {
318 Ok(mut websocket) => {
319 websocket
320 .get_mut()
321 .flush_write_buf()
322 .await
323 .map_err(WsError::Io)?;
324 return Ok(WebSocketStream { inner: websocket });
325 }
326 Err(HandshakeError::Interrupted(mut mid_handshake)) => {
327 let sync_stream = mid_handshake.get_mut().get_mut();
328
329 sync_stream.flush_write_buf().await.map_err(WsError::Io)?;
330
331 sync_stream.fill_read_buf().await.map_err(WsError::Io)?;
332
333 handshake_result = mid_handshake.handshake();
334 }
335 Err(HandshakeError::Failure(error)) => {
336 return Err(error);
337 }
338 }
339 }
340}
341
342pub async fn client_async<R, S>(
356 request: R,
357 stream: S,
358) -> Result<(WebSocketStream<S>, tungstenite::handshake::client::Response), WsError>
359where
360 R: IntoClientRequest,
361 S: AsyncRead + AsyncWrite,
362{
363 client_async_with_config(request, stream, None).await
364}
365
366pub async fn client_async_with_config<R, S>(
368 request: R,
369 stream: S,
370 config: impl Into<Config>,
371) -> Result<(WebSocketStream<S>, tungstenite::handshake::client::Response), WsError>
372where
373 R: IntoClientRequest,
374 S: AsyncRead + AsyncWrite,
375{
376 let config = config.into();
377 let sync_stream =
378 SyncStream::with_limits(config.buffer_size_base, config.buffer_size_limit, stream);
379 let mut handshake_result =
380 tungstenite::client::client_with_config(request, sync_stream, config.websocket);
381
382 loop {
383 match handshake_result {
384 Ok((mut websocket, response)) => {
385 websocket
387 .get_mut()
388 .flush_write_buf()
389 .await
390 .map_err(WsError::Io)?;
391 return Ok((WebSocketStream { inner: websocket }, response));
392 }
393 Err(HandshakeError::Interrupted(mut mid_handshake)) => {
394 let sync_stream = mid_handshake.get_mut().get_mut();
395
396 sync_stream.flush_write_buf().await.map_err(WsError::Io)?;
398
399 sync_stream.fill_read_buf().await.map_err(WsError::Io)?;
400
401 handshake_result = mid_handshake.handshake();
402 }
403 Err(HandshakeError::Failure(error)) => {
404 return Err(error);
405 }
406 }
407 }
408}