Skip to main content

compio_io/compat/
sync_stream.rs

1use std::{
2    io::{self, BufRead, Read, Write},
3    mem::MaybeUninit,
4};
5
6use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut};
7
8use crate::{buffer::Buffer, util::DEFAULT_BUF_SIZE};
9
10/// A growable buffered stream adapter that bridges async I/O with sync traits.
11///
12/// # Buffer Growth Strategy
13///
14/// - **Read buffer**: Grows as needed to accommodate incoming data, up to
15///   `max_buffer_size`
16/// - **Write buffer**: Grows as needed for outgoing data, up to
17///   `max_buffer_size`
18/// - Both buffers shrink back to `base_capacity` when fully consumed and
19///   capacity exceeds 4x base
20///
21/// # Usage Pattern
22///
23/// The sync `Read` and `Write` implementations will return `WouldBlock` errors
24/// when buffers need servicing via the async methods:
25///
26/// - Call `fill_read_buf()` when `Read::read()` returns `WouldBlock`
27/// - Call `flush_write_buf()` when `Write::write()` returns `WouldBlock`
28///
29/// # Note on flush()
30///
31/// The `Write::flush()` method intentionally returns `Ok(())` without checking
32/// if there's buffered data. This is for compatibility with libraries like
33/// tungstenite that call `flush()` after every write. Actual flushing happens
34/// via the async `flush_write_buf()` method.
35#[derive(Debug)]
36pub struct SyncStream<S> {
37    inner: S,
38    read_buf: Buffer,
39    write_buf: Buffer,
40    eof: bool,
41    base_capacity: usize,
42    max_buffer_size: usize,
43}
44
45impl<S> SyncStream<S> {
46    // 64MiB max
47    pub(crate) const DEFAULT_MAX_BUFFER: usize = 64 * 1024 * 1024;
48
49    /// Creates a new `SyncStream` with default buffer sizes.
50    ///
51    /// - Base capacity: 8KiB
52    /// - Max buffer size: 64MiB
53    pub fn new(stream: S) -> Self {
54        Self::with_capacity(DEFAULT_BUF_SIZE, stream)
55    }
56
57    /// Creates a new `SyncStream` with a custom base capacity.
58    ///
59    /// The maximum buffer size defaults to 64MiB.
60    pub fn with_capacity(base_capacity: usize, stream: S) -> Self {
61        Self::with_limits(base_capacity, Self::DEFAULT_MAX_BUFFER, stream)
62    }
63
64    /// Creates a new `SyncStream` with custom base capacity and maximum
65    /// buffer size.
66    pub fn with_limits(base_capacity: usize, max_buffer_size: usize, stream: S) -> Self {
67        Self {
68            inner: stream,
69            read_buf: Buffer::with_capacity(base_capacity),
70            write_buf: Buffer::with_capacity(base_capacity),
71            eof: false,
72            base_capacity,
73            max_buffer_size,
74        }
75    }
76
77    pub(crate) fn with_limits2(
78        read_capacity: usize,
79        write_capacity: usize,
80        base_capacity: usize,
81        max_buffer_size: usize,
82        stream: S,
83    ) -> Self {
84        Self {
85            inner: stream,
86            read_buf: Buffer::with_capacity(read_capacity),
87            write_buf: Buffer::with_capacity(write_capacity),
88            eof: false,
89            base_capacity,
90            max_buffer_size,
91        }
92    }
93
94    /// Returns a reference to the underlying stream.
95    pub fn get_ref(&self) -> &S {
96        &self.inner
97    }
98
99    /// Returns a mutable reference to the underlying stream.
100    pub fn get_mut(&mut self) -> &mut S {
101        &mut self.inner
102    }
103
104    /// Consumes the `SyncStream`, returning the underlying stream.
105    pub fn into_inner(self) -> S {
106        self.inner
107    }
108
109    /// Returns `true` if the stream has reached EOF.
110    pub fn is_eof(&self) -> bool {
111        self.eof
112    }
113
114    /// Returns the available bytes in the read buffer.
115    fn available_read(&self) -> io::Result<&[u8]> {
116        if self.read_buf.has_inner() {
117            Ok(self.read_buf.buffer())
118        } else {
119            Err(would_block("the read buffer is in use"))
120        }
121    }
122
123    /// Marks `amt` bytes as consumed from the read buffer.
124    ///
125    /// Resets the buffer when all data is consumed and shrinks capacity
126    /// if it has grown significantly beyond the base capacity.
127    fn consume_read(&mut self, amt: usize) {
128        let all_done = self.read_buf.advance(amt);
129
130        // Shrink oversized buffers back to base capacity
131        if all_done {
132            self.read_buf
133                .compact_to(self.base_capacity, self.max_buffer_size);
134        }
135    }
136
137    /// Pull some bytes from this source into the specified buffer.
138    pub fn read_buf_uninit(&mut self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
139        let available = self.fill_buf()?;
140
141        let to_read = available.len().min(buf.len());
142        buf[..to_read].copy_from_slice(unsafe {
143            std::slice::from_raw_parts(available.as_ptr().cast(), to_read)
144        });
145        self.consume(to_read);
146
147        Ok(to_read)
148    }
149
150    /// Returns `true` if there is pending data in the write buffer that needs
151    /// to be flushed.
152    pub fn has_pending_write(&self) -> bool {
153        !self.write_buf.is_empty()
154    }
155}
156
157impl<S> Read for SyncStream<S> {
158    /// Reads data from the internal buffer.
159    ///
160    /// Returns `WouldBlock` if the buffer is empty and not at EOF,
161    /// indicating that `fill_read_buf()` should be called.
162    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
163        let mut slice = self.fill_buf()?;
164        slice.read(buf).inspect(|res| {
165            self.consume(*res);
166        })
167    }
168
169    #[cfg(feature = "read_buf")]
170    fn read_buf(&mut self, mut buf: io::BorrowedCursor<'_>) -> io::Result<()> {
171        let mut slice = self.fill_buf()?;
172        let old_written = buf.written();
173        slice.read_buf(buf.reborrow())?;
174        let len = buf.written() - old_written;
175        self.consume(len);
176        Ok(())
177    }
178}
179
180impl<S> BufRead for SyncStream<S> {
181    fn fill_buf(&mut self) -> io::Result<&[u8]> {
182        let available = self.available_read()?;
183
184        if available.is_empty() && !self.eof {
185            return Err(would_block("need to fill read buffer"));
186        }
187
188        Ok(available)
189    }
190
191    fn consume(&mut self, amt: usize) {
192        self.consume_read(amt);
193    }
194}
195
196impl<S> Write for SyncStream<S> {
197    /// Writes data to the internal buffer.
198    ///
199    /// Returns `WouldBlock` if the buffer needs flushing or has reached max
200    /// capacity. In the latter case, it may write partial data before
201    /// returning `WouldBlock`.
202    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
203        if !self.write_buf.has_inner() {
204            return Err(would_block("the write buffer is in use"));
205        }
206        // Check if we should flush first
207        if self.write_buf.need_flush() && !self.write_buf.is_empty() {
208            return Err(would_block("need to flush write buffer"));
209        }
210
211        let written = self.write_buf.with_sync(|mut inner| {
212            let res = (|| {
213                if inner.buf_len() + buf.len() > self.max_buffer_size {
214                    let space = self.max_buffer_size - inner.buf_len();
215                    if space == 0 {
216                        Err(would_block("write buffer full, need to flush"))
217                    } else {
218                        inner.extend_from_slice(&buf[..space])?;
219                        Ok(space)
220                    }
221                } else {
222                    inner.extend_from_slice(buf)?;
223                    Ok(buf.len())
224                }
225            })();
226            BufResult(res, inner)
227        })?;
228
229        Ok(written)
230    }
231
232    /// Returns `Ok(())` without checking for buffered data.
233    ///
234    /// **Important**: This does NOT actually flush data to the underlying
235    /// stream. This behavior is intentional for compatibility with
236    /// libraries like tungstenite that call `flush()` after every write
237    /// operation. The actual async flush happens when `flush_write_buf()`
238    /// is called.
239    ///
240    /// This prevents spurious errors in sync code that expects `flush()` to
241    /// succeed after successfully buffering data.
242    fn flush(&mut self) -> io::Result<()> {
243        Ok(())
244    }
245}
246
/// Builds an [`io::Error`] of kind `WouldBlock` carrying `msg` as its message.
fn would_block(msg: &str) -> io::Error {
    let kind = io::ErrorKind::WouldBlock;
    io::Error::new(kind, msg)
}
250
251impl<S: crate::AsyncRead> SyncStream<S> {
252    /// Fills the read buffer by reading from the underlying async stream.
253    ///
254    /// This method:
255    /// 1. Compacts the buffer if there's unconsumed data
256    /// 2. Ensures there's space for at least `base_capacity` more bytes
257    /// 3. Reads data from the underlying stream
258    /// 4. Returns the number of bytes read (0 indicates EOF)
259    ///
260    /// # Errors
261    ///
262    /// Returns an error if:
263    /// - The read buffer has reached `max_buffer_size`
264    /// - The underlying stream returns an error
265    pub async fn fill_read_buf(&mut self) -> io::Result<usize> {
266        if self.eof {
267            return Ok(0);
268        }
269
270        // Compact buffer, move unconsumed data to the front
271        self.read_buf
272            .compact_to(self.base_capacity, self.max_buffer_size);
273
274        let read = self
275            .read_buf
276            .with(|mut inner| async {
277                let current_len = inner.buf_len();
278
279                if current_len >= self.max_buffer_size {
280                    return BufResult(
281                        Err(io::Error::new(
282                            io::ErrorKind::OutOfMemory,
283                            format!("read buffer size limit ({}) exceeded", self.max_buffer_size),
284                        )),
285                        inner,
286                    );
287                }
288
289                let capacity = inner.buf_capacity();
290                let available_space = capacity - current_len;
291
292                // If target space is less than base capacity, grow the buffer.
293                let target_space = self.base_capacity;
294                if available_space < target_space {
295                    let new_capacity = current_len + target_space;
296                    let _ = inner.reserve_exact(new_capacity - capacity);
297                }
298
299                let len = inner.buf_len();
300                let read_slice = inner.slice(len..);
301                self.inner.read(read_slice).await.into_inner()
302            })
303            .await?;
304        if read == 0 {
305            self.eof = true;
306        }
307        Ok(read)
308    }
309}
310
311impl<S: crate::AsyncWrite> SyncStream<S> {
312    /// Flushes the write buffer to the underlying async stream.
313    ///
314    /// This method:
315    /// 1. Writes all buffered data to the underlying stream
316    /// 2. Calls `flush()` on the underlying stream
317    /// 3. Returns the total number of bytes flushed
318    ///
319    /// On error, any unwritten data remains in the buffer and can be retried.
320    ///
321    /// # Errors
322    ///
323    /// Returns an error if the underlying stream returns an error.
324    /// In this case, the buffer retains any data that wasn't successfully
325    /// written.
326    pub async fn flush_write_buf(&mut self) -> io::Result<usize> {
327        let flushed = self.write_buf.flush_to(&mut self.inner).await?;
328        self.write_buf
329            .compact_to(self.base_capacity, self.max_buffer_size);
330        self.inner.flush().await?;
331        Ok(flushed)
332    }
333}