1use compio_buf::{BufResult, IntoInner, IoBuf, IoVectoredBuf};
2
3use crate::{AsyncWrite, AsyncWriteAt, IoResult, framed, util::Splittable};
4
/// Generates a pair of scalar-writing methods for the type `$t`.
///
/// * `$t` – the scalar type; it also becomes part of the generated method
///   names (`write_<t>` and `write_<t>_le`).
/// * `$be` / `$le` – names of the methods on `$t` that yield its big-endian
///   and little-endian byte arrays.
///
/// Each generated method copies the encoded bytes into a fixed-capacity
/// `ArrayVec`, hands it to `write_all`, and returns only the I/O result; the
/// temporary buffer is dropped.
macro_rules! write_scalar {
    ($t:ty, $be:ident, $le:ident) => {
        ::paste::paste! {
            #[doc = concat!("Write a big endian `", stringify!($t), "` into the underlying writer.")]
            async fn [< write_ $t >](&mut self, num: $t) -> IoResult<()> {
                use ::compio_buf::arrayvec::ArrayVec;

                // The encoding occupies exactly as many bytes as the scalar itself.
                const LEN: usize = ::std::mem::size_of::<$t>();
                let encoded = ArrayVec::<u8, LEN>::from(num.$be());
                // Field 0 of the `BufResult` is the I/O outcome; the buffer is not needed.
                self.write_all(encoded).await.0
            }

            #[doc = concat!("Write a little endian `", stringify!($t), "` into the underlying writer.")]
            async fn [< write_ $t _le >](&mut self, num: $t) -> IoResult<()> {
                use ::compio_buf::arrayvec::ArrayVec;

                // The encoding occupies exactly as many bytes as the scalar itself.
                const LEN: usize = ::std::mem::size_of::<$t>();
                let encoded = ArrayVec::<u8, LEN>::from(num.$le());
                // Field 0 of the `BufResult` is the I/O outcome; the buffer is not needed.
                self.write_all(encoded).await.0
            }
        }
    };
}
33
/// Repeats a write operation until `$len` bytes have been accepted.
///
/// * `$buf` – the caller's mutable buffer variable. `$expr_expr` is expected
///   to consume it on every attempt; the macro stores the buffer handed back
///   by the attempt into it again.
/// * `$len` – total number of bytes to write; evaluated once, before looping.
/// * `$needle` – name of the progress counter declared by the macro (starts
///   at 0); `$expr_expr` may refer to it.
/// * `$expr_expr` – expression producing the future of a single attempt. The
///   awaited value must offer `into_inner()` yielding a `BufResult` of the
///   byte count and the original buffer.
///
/// The expansion is a sequence of statements that always `return`s from the
/// enclosing function:
/// * `Ok(())` with the buffer once the counter reaches the target length,
/// * a `WriteZero` error if an attempt reports 0 bytes written,
/// * the first error whose kind is not `Interrupted` (those are retried).
macro_rules! loop_write_all {
    ($buf:ident, $len:expr, $needle:ident,loop $expr_expr:expr) => {
        let target_len = $len;
        let mut $needle = 0;

        while $needle < target_len {
            // Split the attempt into its outcome and the buffer handed back.
            let BufResult(outcome, returned) = $expr_expr.await.into_inner();
            match outcome {
                // No progress is possible: report it instead of spinning forever.
                Ok(0) => {
                    let stalled = ::std::io::Error::new(
                        ::std::io::ErrorKind::WriteZero,
                        "failed to write whole buffer",
                    );
                    return BufResult(Err(stalled), returned);
                }
                Ok(written) => {
                    $needle += written;
                    $buf = returned;
                }
                // Interrupted attempts are simply retried with the same offset.
                Err(err) if err.kind() == ::std::io::ErrorKind::Interrupted => {
                    $buf = returned;
                }
                Err(err) => return BufResult(Err(err), returned),
            }
        }

        return BufResult(Ok(()), $buf);
    };
}
65
/// Performs one write using the first non-empty buffer of a vectored buffer.
///
/// * `$buf` – the vectored buffer; consumed through `owned_iter()`.
/// * `$iter` – name given to the iterator positioned on the current buffer,
///   so that the write expression can refer to it.
/// * `$write_expr` – expression producing the future that writes `$iter`;
///   its awaited value is converted with `into_inner()` and returned.
///
/// Every path of the expansion `return`s from the enclosing function. If
/// `owned_iter()` hands the buffer back, or every buffer has length 0, the
/// result is `Ok(0)` together with the original buffer and nothing is written.
macro_rules! loop_write_vectored {
    ($buf:ident, $iter:ident, $write_expr:expr) => {{
        let mut $iter = match $buf.owned_iter() {
            Ok(first) => first,
            Err(untouched) => return BufResult(Ok(0), untouched),
        };

        // Skip over empty buffers; running out of buffers means nothing to write.
        while $iter.buf_len() == 0 {
            $iter = match $iter.next() {
                Ok(following) => following,
                Err(untouched) => return BufResult(Ok(0), untouched),
            };
        }

        return $write_expr.await.into_inner()
    }};
}
85
86pub trait AsyncWriteExt: AsyncWrite {
90 fn by_ref(&mut self) -> &mut Self
95 where
96 Self: Sized,
97 {
98 self
99 }
100
101 async fn write_all<T: IoBuf>(&mut self, mut buf: T) -> BufResult<(), T> {
103 loop_write_all!(
104 buf,
105 buf.buf_len(),
106 needle,
107 loop self.write(buf.slice(needle..))
108 );
109 }
110
111 async fn write_vectored_all<T: IoVectoredBuf>(&mut self, mut buf: T) -> BufResult<(), T> {
115 let len = buf.total_len();
116 loop_write_all!(buf, len, needle, loop self.write_vectored(buf.slice(needle)));
117 }
118
119 fn framed<T, C, F>(
122 self,
123 codec: C,
124 framer: F,
125 ) -> framed::Framed<Self::ReadHalf, Self::WriteHalf, C, F, T, T>
126 where
127 Self: Splittable + Sized,
128 {
129 framed::Framed::new(codec, framer).with_duplex(self)
130 }
131
132 #[cfg(feature = "bytes")]
135 fn bytes(self) -> framed::BytesFramed<Self::ReadHalf, Self::WriteHalf>
136 where
137 Self: Splittable + Sized,
138 {
139 framed::BytesFramed::new_bytes().with_duplex(self)
140 }
141
142 fn write_only(self) -> WriteOnly<Self>
161 where
162 Self: Sized,
163 {
164 WriteOnly(self)
165 }
166
167 write_scalar!(u8, to_be_bytes, to_le_bytes);
168 write_scalar!(u16, to_be_bytes, to_le_bytes);
169 write_scalar!(u32, to_be_bytes, to_le_bytes);
170 write_scalar!(u64, to_be_bytes, to_le_bytes);
171 write_scalar!(u128, to_be_bytes, to_le_bytes);
172 write_scalar!(i8, to_be_bytes, to_le_bytes);
173 write_scalar!(i16, to_be_bytes, to_le_bytes);
174 write_scalar!(i32, to_be_bytes, to_le_bytes);
175 write_scalar!(i64, to_be_bytes, to_le_bytes);
176 write_scalar!(i128, to_be_bytes, to_le_bytes);
177 write_scalar!(f32, to_be_bytes, to_le_bytes);
178 write_scalar!(f64, to_be_bytes, to_le_bytes);
179}
180
181impl<A: AsyncWrite + ?Sized> AsyncWriteExt for A {}
182
183pub trait AsyncWriteAtExt: AsyncWriteAt {
187 async fn write_all_at<T: IoBuf>(&mut self, mut buf: T, pos: u64) -> BufResult<(), T> {
190 loop_write_all!(
191 buf,
192 buf.buf_len(),
193 needle,
194 loop self.write_at(buf.slice(needle..), pos + needle as u64)
195 );
196 }
197
198 async fn write_vectored_all_at<T: IoVectoredBuf>(
201 &mut self,
202 mut buf: T,
203 pos: u64,
204 ) -> BufResult<(), T> {
205 let len = buf.total_len();
206 loop_write_all!(buf, len, needle, loop self.write_vectored_at(buf.slice(needle), pos + needle as u64));
207 }
208}
209
210impl<A: AsyncWriteAt + ?Sized> AsyncWriteAtExt for A {}
211
212pub struct WriteOnly<W>(pub W);
218
219impl<W: AsyncWrite> AsyncWrite for WriteOnly<W> {
220 async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
221 self.0.write(buf).await
222 }
223
224 async fn flush(&mut self) -> IoResult<()> {
225 self.0.flush().await
226 }
227
228 async fn shutdown(&mut self) -> IoResult<()> {
229 self.0.shutdown().await
230 }
231}
232
233impl<W> Splittable for WriteOnly<W> {
234 type ReadHalf = ();
235 type WriteHalf = W;
236
237 fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
238 ((), self.0)
239 }
240}