Skip to main content

compio_io\read/
ext.rs

1#[cfg(feature = "allocator_api")]
2use std::alloc::Allocator;
3use std::{io, io::ErrorKind};
4
5use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBufMut, Uninit, t_alloc};
6
7use crate::{
8    AsyncRead, AsyncReadAt, IoResult, framed,
9    util::{Splittable, Take},
10};
11
/// Generates a pair of scalar-reading methods for the type `$t`:
/// `read_<t>` (decoded with `$t::$be`) and `read_<t>_le` (decoded with
/// `$t::$le`).
///
/// Both methods read exactly `size_of::<$t>()` bytes through `read_exact`
/// into a stack-allocated `ArrayVec` and hand the resulting byte array to the
/// given conversion function.
macro_rules! read_scalar {
    ($t:ty, $be:ident, $le:ident) => {
        ::paste::paste! {
            #[doc = concat!("Read a big endian `", stringify!($t), "` from the underlying reader.")]
            async fn [< read_ $t >](&mut self) -> IoResult<$t> {
                use ::compio_buf::{arrayvec::ArrayVec, BufResult};

                const WIDTH: usize = ::std::mem::size_of::<$t>();
                let scratch = ArrayVec::<u8, WIDTH>::new();
                let BufResult(outcome, filled) = self.read_exact(scratch).await;
                outcome?;
                // SAFETY: `read_exact` only reports success once the whole
                // capacity (`WIDTH` bytes) of the buffer has been read into.
                let raw = unsafe { filled.into_inner_unchecked() };
                Ok($t::$be(raw))
            }

            #[doc = concat!("Read a little endian `", stringify!($t), "` from the underlying reader.")]
            async fn [< read_ $t _le >](&mut self) -> IoResult<$t> {
                use ::compio_buf::{arrayvec::ArrayVec, BufResult};

                const WIDTH: usize = ::std::mem::size_of::<$t>();
                let scratch = ArrayVec::<u8, WIDTH>::new();
                let BufResult(outcome, filled) = self.read_exact(scratch).await;
                outcome?;
                // SAFETY: `read_exact` only reports success once the whole
                // capacity (`WIDTH` bytes) of the buffer has been read into.
                let raw = unsafe { filled.into_inner_unchecked() };
                Ok($t::$le(raw))
            }
        }
    };
}
40
/// Shared loop body for the `read_*exact*` helpers.
///
/// Expands to statements that repeatedly evaluate `$read_expr` (awaiting it
/// and unwrapping the buffer with `into_inner`) until `$tracker`, the running
/// byte count, reaches `$len`. The expansion always `return`s from the
/// enclosing function:
///
/// * `Ok(())` with `$buf` once `$tracker >= $len`;
/// * an `UnexpectedEof` error if a read yields `Ok(0)` before that;
/// * the original error for any failure other than `Interrupted`, which is
///   retried.
macro_rules! loop_read_exact {
    ($buf:ident, $len:expr, $tracker:ident,loop $read_expr:expr) => {
        let mut $tracker = 0;
        let target = $len;

        while $tracker < target {
            let BufResult(outcome, returned) = $read_expr.await.into_inner();
            match outcome {
                // The source ran dry before the requested length was reached.
                Ok(0) => {
                    return BufResult(
                        Err(::std::io::Error::new(
                            ::std::io::ErrorKind::UnexpectedEof,
                            "failed to fill whole buffer",
                        )),
                        returned,
                    );
                }
                Ok(n) => {
                    $tracker += n;
                    $buf = returned;
                }
                // Interrupted reads are retried with the same buffer.
                Err(e) if e.kind() == ::std::io::ErrorKind::Interrupted => {
                    $buf = returned;
                }
                Err(e) => return BufResult(Err(e), returned),
            }
        }
        return BufResult(Ok(()), $buf)
    };
}
71
/// Shared body for performing a single read into a vectored buffer.
///
/// Walks the iterator obtained from `$buf.owned_iter()` (bound to `$iter`),
/// skipping entries whose `buf_capacity()` is zero, and returns the result of
/// `$read_expr` for the first entry that has capacity. If `owned_iter()` or a
/// later `next()` yields `Err`, the buffer carried by that `Err` is returned
/// together with `Ok(0)`.
macro_rules! loop_read_vectored {
    ($buf:ident, $iter:ident, $read_expr:expr) => {{
        let mut $iter = match $buf.owned_iter() {
            Ok(first) => first,
            Err(original) => return BufResult(Ok(0), original),
        };

        loop {
            // Only issue the read once we are positioned on an entry that
            // reports a non-zero capacity.
            if $iter.buf_capacity() > 0 {
                return $read_expr.await.into_inner();
            }

            $iter = match $iter.next() {
                Ok(following) => following,
                Err(original) => return BufResult(Ok(0), original),
            };
        }
    }};
}
92
/// Shared loop body for the `read_to_end*` helpers.
///
/// Evaluates to `BufResult(Ok(total), $buf)` once `$read_expr` yields `Ok(0)`,
/// where `total` is `$tracker` (declared with type `$tracker_ty`) converted to
/// `usize`. Before every read, 32 more bytes are reserved when `$buf` has no
/// spare capacity. `Interrupted` errors are retried; any other error returns
/// from the enclosing function immediately.
macro_rules! loop_read_to_end {
    ($buf:ident, $tracker:ident : $tracker_ty:ty,loop $read_expr:expr) => {{
        let mut $tracker: $tracker_ty = 0;
        loop {
            // Make sure there is room for the next read.
            if $buf.len() == $buf.capacity() {
                $buf.reserve(32);
            }
            let BufResult(outcome, returned) = $read_expr.await.into_inner();
            match outcome {
                // A zero-length read is treated as end of input.
                Ok(0) => {
                    $buf = returned;
                    break;
                }
                Ok(count) => {
                    $tracker += count as $tracker_ty;
                    $buf = returned;
                }
                Err(e) if e.kind() == ::std::io::ErrorKind::Interrupted => {
                    $buf = returned;
                }
                Err(e) => return BufResult(Err(e), returned),
            }
        }
        BufResult(Ok($tracker as usize), $buf)
    }};
}
118
119#[inline]
120fn after_read_to_string(res: io::Result<usize>, buf: Vec<u8>) -> BufResult<usize, String> {
121    match res {
122        Err(err) => {
123            // we have to clear the read bytes if it is not valid utf8 bytes
124            let buf = String::from_utf8(buf).unwrap_or_else(|err| {
125                let mut buf = err.into_bytes();
126                buf.clear();
127
128                // SAFETY: the buffer is empty
129                unsafe { String::from_utf8_unchecked(buf) }
130            });
131
132            BufResult(Err(err), buf)
133        }
134        Ok(n) => match String::from_utf8(buf) {
135            Err(err) => BufResult(
136                Err(std::io::Error::new(ErrorKind::InvalidData, err)),
137                String::new(),
138            ),
139            Ok(data) => BufResult(Ok(n), data),
140        },
141    }
142}
143
144/// Implemented as an extension trait, adding utility methods to all
145/// [`AsyncRead`] types. Callers will tend to import this trait instead of
146/// [`AsyncRead`].
147pub trait AsyncReadExt: AsyncRead {
148    /// Creates a "by reference" adaptor for this instance of [`AsyncRead`].
149    ///
150    /// The returned adapter also implements [`AsyncRead`] and will simply
151    /// borrow this current reader.
152    fn by_ref(&mut self) -> &mut Self
153    where
154        Self: Sized,
155    {
156        self
157    }
158
159    /// Same as [`AsyncRead::read`], but it appends data to the end of the
160    /// buffer; in other words, it read to the beginning of the uninitialized
161    /// area.
162    async fn append<T: IoBufMut>(&mut self, buf: T) -> BufResult<usize, T> {
163        self.read(buf.uninit()).await.map_buffer(Uninit::into_inner)
164    }
165
166    /// Read the exact number of bytes required to fill the buf.
167    async fn read_exact<T: IoBufMut>(&mut self, mut buf: T) -> BufResult<(), T> {
168        loop_read_exact!(buf, buf.buf_capacity(), read, loop self.read(buf.slice(read..)));
169    }
170
171    /// Read all bytes as [`String`] until underlying reader reaches `EOF`.
172    async fn read_to_string(&mut self, buf: String) -> BufResult<usize, String> {
173        let BufResult(res, buf) = self.read_to_end(buf.into_bytes()).await;
174        after_read_to_string(res, buf)
175    }
176
177    /// Read all bytes until underlying reader reaches `EOF`.
178    async fn read_to_end<#[cfg(feature = "allocator_api")] A: Allocator + 'static>(
179        &mut self,
180        mut buf: t_alloc!(Vec, u8, A),
181    ) -> BufResult<usize, t_alloc!(Vec, u8, A)> {
182        loop_read_to_end!(buf, total: usize, loop self.read(buf.slice(total..)))
183    }
184
185    /// Read the exact number of bytes required to fill the vectored buf.
186    async fn read_vectored_exact<T: IoVectoredBufMut>(&mut self, mut buf: T) -> BufResult<(), T> {
187        let len = buf.total_capacity();
188        loop_read_exact!(buf, len, read, loop self.read_vectored(buf.slice_mut(read)));
189    }
190
191    /// Create a [`framed::Framed`] reader/writer with the given codec and
192    /// framer.
193    fn framed<T, C, F>(
194        self,
195        codec: C,
196        framer: F,
197    ) -> framed::Framed<Self::ReadHalf, Self::WriteHalf, C, F, T, T>
198    where
199        Self: Splittable + Sized,
200    {
201        framed::Framed::new(codec, framer).with_duplex(self)
202    }
203
204    /// Convenience method to create a [`framed::BytesFramed`] reader/writter
205    /// out of a splittable.
206    #[cfg(feature = "bytes")]
207    fn bytes(self) -> framed::BytesFramed<Self::ReadHalf, Self::WriteHalf>
208    where
209        Self: Splittable + Sized,
210    {
211        framed::BytesFramed::new_bytes().with_duplex(self)
212    }
213
214    /// Create a [`Splittable`] that uses `Self` as [`ReadHalf`] and `()` as
215    /// [`WriteHalf`].
216    ///
217    /// This is useful for creating framed sink with only a reader,
218    /// using the [`AsyncReadExt::framed`] or [`AsyncReadExt::bytes`]
219    /// method, which require a [`Splittable`] to work.
220    ///
221    /// # Examples
222    ///
223    /// ```rust,ignore
224    /// use compio_io::{AsyncReadExt, framed::BytesFramed};
225    ///
226    /// let mut file_bytes = file.read_only().bytes();
227    /// while let Some(Ok(bytes)) = file_bytes.next().await {
228    ///     // process bytes
229    /// }
230    /// ```
231    ///
232    /// [`ReadHalf`]: Splittable::ReadHalf
233    /// [`WriteHalf`]: Splittable::WriteHalf
234    fn read_only(self) -> ReadOnly<Self>
235    where
236        Self: Sized,
237    {
238        ReadOnly(self)
239    }
240
241    /// Creates an adaptor which reads at most `limit` bytes from it.
242    ///
243    /// This function returns a new instance of `AsyncRead` which will read
244    /// at most `limit` bytes, after which it will always return EOF
245    /// (`Ok(0)`). Any read errors will not count towards the number of
246    /// bytes read and future calls to [`read()`] may succeed.
247    ///
248    /// [`read()`]: AsyncRead::read
249    fn take(self, limit: u64) -> Take<Self>
250    where
251        Self: Sized,
252    {
253        Take::new(self, limit)
254    }
255
256    read_scalar!(u8, from_be_bytes, from_le_bytes);
257    read_scalar!(u16, from_be_bytes, from_le_bytes);
258    read_scalar!(u32, from_be_bytes, from_le_bytes);
259    read_scalar!(u64, from_be_bytes, from_le_bytes);
260    read_scalar!(u128, from_be_bytes, from_le_bytes);
261    read_scalar!(i8, from_be_bytes, from_le_bytes);
262    read_scalar!(i16, from_be_bytes, from_le_bytes);
263    read_scalar!(i32, from_be_bytes, from_le_bytes);
264    read_scalar!(i64, from_be_bytes, from_le_bytes);
265    read_scalar!(i128, from_be_bytes, from_le_bytes);
266    read_scalar!(f32, from_be_bytes, from_le_bytes);
267    read_scalar!(f64, from_be_bytes, from_le_bytes);
268}
269
270impl<A: AsyncRead + ?Sized> AsyncReadExt for A {}
271
272/// Implemented as an extension trait, adding utility methods to all
273/// [`AsyncReadAt`] types. Callers will tend to import this trait instead of
274/// [`AsyncReadAt`].
275pub trait AsyncReadAtExt: AsyncReadAt {
276    /// Read the exact number of bytes required to fill `buffer`.
277    ///
278    /// This function reads as many bytes as necessary to completely fill the
279    /// uninitialized space of specified `buffer`.
280    ///
281    /// # Errors
282    ///
283    /// If this function encounters an "end of file" before completely filling
284    /// the buffer, it returns an error of the kind
285    /// [`ErrorKind::UnexpectedEof`]. The contents of `buffer` are unspecified
286    /// in this case.
287    ///
288    /// If any other read error is encountered then this function immediately
289    /// returns. The contents of `buffer` are unspecified in this case.
290    ///
291    /// If this function returns an error, it is unspecified how many bytes it
292    /// has read, but it will never read more than would be necessary to
293    /// completely fill the buffer.
294    ///
295    /// [`ErrorKind::UnexpectedEof`]: std::io::ErrorKind::UnexpectedEof
296    async fn read_exact_at<T: IoBufMut>(&self, mut buf: T, pos: u64) -> BufResult<(), T> {
297        loop_read_exact!(
298            buf,
299            buf.buf_capacity(),
300            read,
301            loop self.read_at(buf.slice(read..), pos + read as u64)
302        );
303    }
304
305    /// Read all bytes as [`String`] until EOF in this source, placing them into
306    /// `buffer`.
307    async fn read_to_string_at(&mut self, buf: String, pos: u64) -> BufResult<usize, String> {
308        let BufResult(res, buf) = self.read_to_end_at(buf.into_bytes(), pos).await;
309        after_read_to_string(res, buf)
310    }
311
312    /// Read all bytes until EOF in this source, placing them into `buffer`.
313    ///
314    /// All bytes read from this source will be appended to the specified buffer
315    /// `buffer`. This function will continuously call [`read_at()`] to append
316    /// more data to `buffer` until [`read_at()`] returns [`Ok(0)`].
317    ///
318    /// If successful, this function will return the total number of bytes read.
319    ///
320    /// [`Ok(0)`]: Ok
321    /// [`read_at()`]: AsyncReadAt::read_at
322    async fn read_to_end_at<#[cfg(feature = "allocator_api")] A: Allocator + 'static>(
323        &self,
324        mut buffer: t_alloc!(Vec, u8, A),
325        pos: u64,
326    ) -> BufResult<usize, t_alloc!(Vec, u8, A)> {
327        loop_read_to_end!(buffer, total: u64, loop self.read_at(buffer.slice(total as usize..), pos + total))
328    }
329
330    /// Like [`AsyncReadExt::read_vectored_exact`], expect that it reads at a
331    /// specified position.
332    async fn read_vectored_exact_at<T: IoVectoredBufMut>(
333        &self,
334        mut buf: T,
335        pos: u64,
336    ) -> BufResult<(), T> {
337        let len = buf.total_capacity();
338        loop_read_exact!(buf, len, read, loop self.read_vectored_at(buf.slice_mut(read), pos + read as u64));
339    }
340}
341
342impl<A: AsyncReadAt + ?Sized> AsyncReadAtExt for A {}
343
344/// An adaptor which implements [`Splittable`] for any [`AsyncRead`], with the
345/// write half being `()`.
346///
347/// This can be used to create a framed stream with only a reader, using
348/// the [`AsyncReadExt::framed`] or [`AsyncReadExt::bytes`] method.
349pub struct ReadOnly<R>(pub R);
350
351impl<R: AsyncRead> AsyncRead for ReadOnly<R> {
352    async fn read<T: IoBufMut>(&mut self, buf: T) -> BufResult<usize, T> {
353        self.0.read(buf).await
354    }
355
356    async fn read_vectored<T: IoVectoredBufMut>(&mut self, buf: T) -> BufResult<usize, T> {
357        self.0.read_vectored(buf).await
358    }
359}
360
361impl<R> Splittable for ReadOnly<R> {
362    type ReadHalf = R;
363    type WriteHalf = ();
364
365    fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
366        (self.0, ())
367    }
368}