Skip to main content

compio_io/compat/
sync_stream.rs

1use std::{
2    io::{self, BufRead, Read, Write},
3    mem::MaybeUninit,
4};
5
6use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut};
7
8use crate::{
9    buffer::Buffer,
10    util::{DEFAULT_BUF_SIZE, Splittable},
11};
12
/// Default upper bound on the number of bytes a sync buffer may hold (64 MiB).
pub(crate) const DEFAULT_MAX_BUFFER: usize = 64 << 20;
15
16#[derive(Debug)]
17struct SyncReadBuf {
18    buf: Buffer,
19    eof: bool,
20    base_capacity: usize,
21    max_buffer_size: usize,
22}
23
24impl SyncReadBuf {
25    pub fn new(start_capacity: usize, base_capacity: usize, max_buffer_size: usize) -> Self {
26        Self {
27            buf: Buffer::with_capacity(start_capacity),
28            eof: false,
29            base_capacity,
30            max_buffer_size,
31        }
32    }
33
34    pub fn is_eof(&self) -> bool {
35        self.eof
36    }
37
38    pub fn into_inner(mut self) -> Vec<u8> {
39        if self.buf.has_inner() {
40            let slice = self.buf.take_inner();
41            let begin = slice.begin();
42            let mut vec = slice.into_inner();
43            if begin > 0 {
44                vec.drain(..begin);
45            }
46            vec
47        } else {
48            Vec::new()
49        }
50    }
51
52    /// Returns the available bytes in the read buffer.
53    fn available_read(&self) -> io::Result<&[u8]> {
54        if self.buf.has_inner() {
55            Ok(self.buf.buffer())
56        } else {
57            Err(would_block("the read buffer is in use"))
58        }
59    }
60
61    /// Marks `amt` bytes as consumed from the read buffer.
62    ///
63    /// Resets the buffer when all data is consumed and shrinks capacity
64    /// if it has grown significantly beyond the base capacity.
65    pub fn consume(&mut self, amt: usize) {
66        let all_done = self.buf.advance(amt);
67
68        // Shrink oversized buffers back to base capacity
69        if all_done {
70            self.buf
71                .compact_to(self.base_capacity, self.max_buffer_size);
72        }
73    }
74
75    pub fn read_buf_uninit(&mut self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
76        let available = self.fill_buf()?;
77
78        let to_read = available.len().min(buf.len());
79        buf[..to_read].copy_from_slice(unsafe {
80            std::slice::from_raw_parts(available.as_ptr().cast(), to_read)
81        });
82        self.consume(to_read);
83
84        Ok(to_read)
85    }
86
87    pub fn fill_buf(&mut self) -> io::Result<&[u8]> {
88        let available = self.available_read()?;
89
90        if available.is_empty() && !self.eof {
91            return Err(would_block("need to fill read buffer"));
92        }
93
94        Ok(available)
95    }
96
97    pub fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
98        let mut slice = self.fill_buf()?;
99        slice.read(buf).inspect(|res| {
100            self.consume(*res);
101        })
102    }
103
104    #[cfg(feature = "read_buf")]
105    pub fn read_buf(&mut self, mut buf: io::BorrowedCursor<'_>) -> io::Result<()> {
106        let mut slice = self.fill_buf()?;
107        let old_written = buf.written();
108        slice.read_buf(buf.reborrow())?;
109        let len = buf.written() - old_written;
110        self.consume(len);
111        Ok(())
112    }
113
114    pub async fn fill_read_buf<S: crate::AsyncRead>(
115        &mut self,
116        stream: &mut S,
117    ) -> io::Result<usize> {
118        if self.eof {
119            return Ok(0);
120        }
121
122        // Compact buffer, move unconsumed data to the front
123        self.buf
124            .compact_to(self.base_capacity, self.max_buffer_size);
125
126        let read = self
127            .buf
128            .with(|mut inner| async {
129                let current_len = inner.buf_len();
130
131                if current_len >= self.max_buffer_size {
132                    return BufResult(
133                        Err(io::Error::new(
134                            io::ErrorKind::OutOfMemory,
135                            format!("read buffer size limit ({}) exceeded", self.max_buffer_size),
136                        )),
137                        inner,
138                    );
139                }
140
141                let capacity = inner.buf_capacity();
142                let available_space = capacity - current_len;
143
144                // If target space is less than base capacity, grow the buffer.
145                let target_space = self.base_capacity;
146                if available_space < target_space {
147                    let new_capacity = current_len + target_space;
148                    let _ = inner.reserve_exact(new_capacity - capacity);
149                }
150
151                let len = inner.buf_len();
152                let read_slice = inner.slice(len..);
153                stream.read(read_slice).await.into_inner()
154            })
155            .await?;
156        if read == 0 {
157            self.eof = true;
158        }
159        Ok(read)
160    }
161}
162
163#[derive(Debug)]
164struct SyncWriteBuf {
165    buf: Buffer,
166    base_capacity: usize,
167    max_buffer_size: usize,
168}
169
170impl SyncWriteBuf {
171    pub fn new(start_capacity: usize, base_capacity: usize, max_buffer_size: usize) -> Self {
172        Self {
173            buf: Buffer::with_capacity(start_capacity),
174            base_capacity,
175            max_buffer_size,
176        }
177    }
178
179    pub fn has_pending_write(&self) -> bool {
180        !self.buf.is_empty()
181    }
182
183    pub fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
184        if !self.buf.has_inner() {
185            return Err(would_block("the write buffer is in use"));
186        }
187        // Check if we should flush first
188        if self.buf.need_flush() && !self.buf.is_empty() {
189            return Err(would_block("need to flush write buffer"));
190        }
191
192        let written = self.buf.with_sync(|mut inner| {
193            let res = (|| {
194                if inner.buf_len() + buf.len() > self.max_buffer_size {
195                    let space = self.max_buffer_size - inner.buf_len();
196                    if space == 0 {
197                        Err(would_block("write buffer full, need to flush"))
198                    } else {
199                        inner.extend_from_slice(&buf[..space])?;
200                        Ok(space)
201                    }
202                } else {
203                    inner.extend_from_slice(buf)?;
204                    Ok(buf.len())
205                }
206            })();
207            BufResult(res, inner)
208        })?;
209
210        Ok(written)
211    }
212
213    pub async fn flush_write_buf<S: crate::AsyncWrite>(
214        &mut self,
215        stream: &mut S,
216    ) -> io::Result<usize> {
217        let flushed = self.buf.flush_to(stream).await?;
218        self.buf
219            .compact_to(self.base_capacity, self.max_buffer_size);
220        stream.flush().await?;
221        Ok(flushed)
222    }
223}
224
225/// A growable buffered stream adapter that bridges async I/O with sync traits.
226///
227/// # Buffer Growth Strategy
228///
229/// - **Read buffer**: Grows as needed to accommodate incoming data, up to
230///   `max_buffer_size`
231/// - **Write buffer**: Grows as needed for outgoing data, up to
232///   `max_buffer_size`
233/// - Both buffers shrink back to `base_capacity` when fully consumed and
234///   capacity exceeds 4x base
235///
236/// # Usage Pattern
237///
238/// The sync `Read` and `Write` implementations will return `WouldBlock` errors
239/// when buffers need servicing via the async methods:
240///
241/// - Call `fill_read_buf()` when `Read::read()` returns `WouldBlock`
242/// - Call `flush_write_buf()` when `Write::write()` returns `WouldBlock`
243///
244/// # Note on flush()
245///
246/// The `Write::flush()` method intentionally returns `Ok(())` without checking
247/// if there's buffered data. This is for compatibility with libraries like
248/// tungstenite that call `flush()` after every write. Actual flushing happens
249/// via the async `flush_write_buf()` method.
250#[derive(Debug)]
251pub struct SyncStream<S> {
252    inner: S,
253    read_buf: SyncReadBuf,
254    write_buf: SyncWriteBuf,
255}
256
257/// Read half of a [`SyncStream`] after splitting.
258#[derive(Debug)]
259pub struct SyncStreamReadHalf<S> {
260    inner: S,
261    read_buf: SyncReadBuf,
262}
263
264/// Write half of a [`SyncStream`] after splitting.
265#[derive(Debug)]
266pub struct SyncStreamWriteHalf<S> {
267    inner: S,
268    write_buf: SyncWriteBuf,
269}
270
271impl<S> SyncStream<S> {
272    /// Creates a new `SyncStream` with default buffer sizes.
273    ///
274    /// - Base capacity: 8KiB
275    /// - Max buffer size: 64MiB
276    pub fn new(stream: S) -> Self {
277        Self::with_capacity(DEFAULT_BUF_SIZE, stream)
278    }
279
280    /// Creates a new `SyncStream` with a custom base capacity.
281    ///
282    /// The maximum buffer size defaults to 64MiB.
283    pub fn with_capacity(base_capacity: usize, stream: S) -> Self {
284        Self::with_limits(base_capacity, DEFAULT_MAX_BUFFER, stream)
285    }
286
287    /// Creates a new `SyncStream` with custom base capacity and maximum
288    /// buffer size.
289    pub fn with_limits(base_capacity: usize, max_buffer_size: usize, stream: S) -> Self {
290        Self {
291            inner: stream,
292            read_buf: SyncReadBuf::new(base_capacity, base_capacity, max_buffer_size),
293            write_buf: SyncWriteBuf::new(base_capacity, base_capacity, max_buffer_size),
294        }
295    }
296
297    /// Returns a reference to the underlying stream.
298    pub fn get_ref(&self) -> &S {
299        &self.inner
300    }
301
302    /// Returns a mutable reference to the underlying stream.
303    pub fn get_mut(&mut self) -> &mut S {
304        &mut self.inner
305    }
306
307    /// Consumes the `SyncStream`, returning the underlying stream.
308    ///
309    /// Any buffered data is discarded. Use [`into_parts`](Self::into_parts)
310    /// if you need to preserve unread data.
311    pub fn into_inner(self) -> S {
312        self.inner
313    }
314
315    /// Consumes the `SyncStream`, returning the underlying stream and any
316    /// unread buffered data.
317    ///
318    /// If the read buffer is currently lent to an IO operation, the returned
319    /// `Vec` will be empty.
320    pub fn into_parts(self) -> (S, Vec<u8>) {
321        let remaining = self.read_buf.into_inner();
322        (self.inner, remaining)
323    }
324
325    /// Returns `true` if the stream has reached EOF.
326    pub fn is_eof(&self) -> bool {
327        self.read_buf.is_eof()
328    }
329
330    /// Pull some bytes from this source into the specified buffer.
331    pub fn read_buf_uninit(&mut self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
332        self.read_buf.read_buf_uninit(buf)
333    }
334
335    /// Returns `true` if there is pending data in the write buffer that needs
336    /// to be flushed.
337    pub fn has_pending_write(&self) -> bool {
338        self.write_buf.has_pending_write()
339    }
340}
341
342impl<S> SyncStreamReadHalf<S> {
343    pub(crate) fn with_limits(base_capacity: usize, max_buffer_size: usize, stream: S) -> Self {
344        Self {
345            inner: stream,
346            read_buf: SyncReadBuf::new(base_capacity, base_capacity, max_buffer_size),
347        }
348    }
349
350    /// Returns a reference to the underlying stream.
351    pub fn get_ref(&self) -> &S {
352        &self.inner
353    }
354
355    /// Returns a mutable reference to the underlying stream.
356    pub fn get_mut(&mut self) -> &mut S {
357        &mut self.inner
358    }
359
360    /// Consumes the `SyncStreamReadHalf`, returning the underlying stream.
361    pub fn into_inner(self) -> S {
362        self.inner
363    }
364
365    /// Consumes the `SyncStream`, returning the underlying stream and any
366    /// unread buffered data.
367    ///
368    /// If the read buffer is currently lent to an IO operation, the returned
369    /// `Vec` will be empty.
370    pub fn into_parts(self) -> (S, Vec<u8>) {
371        let remaining = self.read_buf.into_inner();
372        (self.inner, remaining)
373    }
374
375    /// Returns `true` if the stream has reached EOF.
376    pub fn is_eof(&self) -> bool {
377        self.read_buf.is_eof()
378    }
379
380    /// Pull some bytes from this source into the specified buffer.
381    pub fn read_buf_uninit(&mut self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
382        self.read_buf.read_buf_uninit(buf)
383    }
384}
385
386impl<S> SyncStreamWriteHalf<S> {
387    pub(crate) fn with_limits(base_capacity: usize, max_buffer_size: usize, stream: S) -> Self {
388        Self {
389            inner: stream,
390            write_buf: SyncWriteBuf::new(base_capacity, base_capacity, max_buffer_size),
391        }
392    }
393
394    /// Returns a reference to the underlying stream.
395    pub fn get_ref(&self) -> &S {
396        &self.inner
397    }
398
399    /// Returns a mutable reference to the underlying stream.
400    pub fn get_mut(&mut self) -> &mut S {
401        &mut self.inner
402    }
403
404    /// Consumes the `SyncStreamWriteHalf`, returning the underlying stream.
405    pub fn into_inner(self) -> S {
406        self.inner
407    }
408
409    /// Returns `true` if there is pending data in the write buffer that needs
410    /// to be flushed.
411    pub fn has_pending_write(&self) -> bool {
412        self.write_buf.has_pending_write()
413    }
414}
415
416impl<S> Read for SyncStream<S> {
417    /// Reads data from the internal buffer.
418    ///
419    /// Returns `WouldBlock` if the buffer is empty and not at EOF,
420    /// indicating that `fill_read_buf()` should be called.
421    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
422        self.read_buf.read(buf)
423    }
424
425    #[cfg(feature = "read_buf")]
426    fn read_buf(&mut self, buf: io::BorrowedCursor<'_>) -> io::Result<()> {
427        self.read_buf.read_buf(buf)
428    }
429}
430
431impl<S> BufRead for SyncStream<S> {
432    fn fill_buf(&mut self) -> io::Result<&[u8]> {
433        self.read_buf.fill_buf()
434    }
435
436    fn consume(&mut self, amt: usize) {
437        self.read_buf.consume(amt);
438    }
439}
440
441impl<S> Read for SyncStreamReadHalf<S> {
442    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
443        self.read_buf.read(buf)
444    }
445
446    #[cfg(feature = "read_buf")]
447    fn read_buf(&mut self, buf: io::BorrowedCursor<'_>) -> io::Result<()> {
448        self.read_buf.read_buf(buf)
449    }
450}
451
452impl<S> BufRead for SyncStreamReadHalf<S> {
453    fn fill_buf(&mut self) -> io::Result<&[u8]> {
454        self.read_buf.fill_buf()
455    }
456
457    fn consume(&mut self, amt: usize) {
458        self.read_buf.consume(amt);
459    }
460}
461
462impl<S> Write for SyncStream<S> {
463    /// Writes data to the internal buffer.
464    ///
465    /// Returns `WouldBlock` if the buffer needs flushing or has reached max
466    /// capacity. In the latter case, it may write partial data before
467    /// returning `WouldBlock`.
468    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
469        self.write_buf.write(buf)
470    }
471
472    /// Returns `Ok(())` without checking for buffered data.
473    ///
474    /// **Important**: This does NOT actually flush data to the underlying
475    /// stream. This behavior is intentional for compatibility with
476    /// libraries like tungstenite that call `flush()` after every write
477    /// operation. The actual async flush happens when `flush_write_buf()`
478    /// is called.
479    ///
480    /// This prevents spurious errors in sync code that expects `flush()` to
481    /// succeed after successfully buffering data.
482    fn flush(&mut self) -> io::Result<()> {
483        Ok(())
484    }
485}
486
487impl<S> Write for SyncStreamWriteHalf<S> {
488    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
489        self.write_buf.write(buf)
490    }
491
492    fn flush(&mut self) -> io::Result<()> {
493        Ok(())
494    }
495}
496
/// Builds an [`io::ErrorKind::WouldBlock`] error carrying `msg`.
fn would_block(msg: &str) -> io::Error {
    let kind = io::ErrorKind::WouldBlock;
    io::Error::new(kind, msg)
}
500
501impl<S: crate::AsyncRead> SyncStream<S> {
502    /// Fills the read buffer by reading from the underlying async stream.
503    ///
504    /// This method:
505    /// 1. Compacts the buffer if there's unconsumed data
506    /// 2. Ensures there's space for at least `base_capacity` more bytes
507    /// 3. Reads data from the underlying stream
508    /// 4. Returns the number of bytes read (0 indicates EOF)
509    ///
510    /// # Errors
511    ///
512    /// Returns an error if:
513    /// - The read buffer has reached `max_buffer_size`
514    /// - The underlying stream returns an error
515    pub async fn fill_read_buf(&mut self) -> io::Result<usize> {
516        self.read_buf.fill_read_buf(&mut self.inner).await
517    }
518}
519
520impl<S: crate::AsyncRead> SyncStreamReadHalf<S> {
521    /// See [`SyncStream::fill_read_buf`].
522    pub async fn fill_read_buf(&mut self) -> io::Result<usize> {
523        self.read_buf.fill_read_buf(&mut self.inner).await
524    }
525}
526
527impl<S: crate::AsyncWrite> SyncStream<S> {
528    /// Flushes the write buffer to the underlying async stream.
529    ///
530    /// This method:
531    /// 1. Writes all buffered data to the underlying stream
532    /// 2. Calls `flush()` on the underlying stream
533    /// 3. Returns the total number of bytes flushed
534    ///
535    /// On error, any unwritten data remains in the buffer and can be retried.
536    ///
537    /// # Errors
538    ///
539    /// Returns an error if the underlying stream returns an error.
540    /// In this case, the buffer retains any data that wasn't successfully
541    /// written.
542    pub async fn flush_write_buf(&mut self) -> io::Result<usize> {
543        self.write_buf.flush_write_buf(&mut self.inner).await
544    }
545}
546
547impl<S: crate::AsyncWrite> SyncStreamWriteHalf<S> {
548    /// See [`SyncStream::flush_write_buf`].
549    pub async fn flush_write_buf(&mut self) -> io::Result<usize> {
550        self.write_buf.flush_write_buf(&mut self.inner).await
551    }
552}
553
554impl<S: Splittable> Splittable for SyncStream<S> {
555    type ReadHalf = SyncStreamReadHalf<S::ReadHalf>;
556    type WriteHalf = SyncStreamWriteHalf<S::WriteHalf>;
557
558    fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
559        let (r, w) = self.inner.split();
560        let read_half = SyncStreamReadHalf {
561            inner: r,
562            read_buf: self.read_buf,
563        };
564        let write_half = SyncStreamWriteHalf {
565            inner: w,
566            write_buf: self.write_buf,
567        };
568        (read_half, write_half)
569    }
570}