1use std::{
2 io,
3 task::{Context, Poll},
4};
5
6use compio_buf::{BufResult, IoBuf, bytes::Bytes};
7use compio_io::AsyncWrite;
8use futures_util::{future::poll_fn, ready};
9use quinn_proto::{ClosedStream, FinishError, StreamId, VarInt, Written};
10use thiserror::Error;
11
12use crate::{ConnectionError, ConnectionInner, StoppedError, sync::shared::Shared};
13
/// Send half of a QUIC stream.
///
/// Holds a shared handle to the connection plus the id of the stream it
/// refers to. Operations acquire the connection state (`state()` /
/// `try_state()`) and forward to the protocol-level send stream of that id.
#[derive(Debug)]
pub struct SendStream {
    // Shared connection state this stream belongs to.
    conn: Shared<ConnectionInner>,
    // Identifier of this stream within the connection.
    stream: StreamId,
    // Whether the stream was opened as 0-RTT; when `check_0rtt()` later fails,
    // writes/`stopped` report `ZeroRttRejected` and `reset`/drop become no-ops.
    is_0rtt: bool,
}
37
38impl SendStream {
39 pub(crate) fn new(conn: Shared<ConnectionInner>, stream: StreamId, is_0rtt: bool) -> Self {
40 Self {
41 conn,
42 stream,
43 is_0rtt,
44 }
45 }
46
47 pub fn id(&self) -> StreamId {
49 self.stream
50 }
51
52 pub fn finish(&mut self) -> Result<(), ClosedStream> {
69 let mut state = self.conn.state();
70 match state.conn.send_stream(self.stream).finish() {
71 Ok(()) => {
72 state.wake();
73 Ok(())
74 }
75 Err(FinishError::ClosedStream) => Err(ClosedStream::default()),
76 Err(FinishError::Stopped(_)) => Ok(()),
79 }
80 }
81
82 pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
94 let mut state = self.conn.state();
95 if self.is_0rtt && !state.check_0rtt() {
96 return Ok(());
97 }
98 state.conn.send_stream(self.stream).reset(error_code)?;
99 state.wake();
100 Ok(())
101 }
102
103 pub fn set_priority(&self, priority: i32) -> Result<(), ClosedStream> {
112 self.conn
113 .state()
114 .conn
115 .send_stream(self.stream)
116 .set_priority(priority)
117 }
118
119 pub fn priority(&self) -> Result<i32, ClosedStream> {
121 self.conn.state().conn.send_stream(self.stream).priority()
122 }
123
124 pub async fn stopped(&mut self) -> Result<Option<VarInt>, StoppedError> {
138 poll_fn(|cx| {
139 let mut state = self.conn.state();
140 if self.is_0rtt && !state.check_0rtt() {
141 return Poll::Ready(Err(StoppedError::ZeroRttRejected));
142 }
143 match state.conn.send_stream(self.stream).stopped() {
144 Err(_) => Poll::Ready(Ok(None)),
145 Ok(Some(error_code)) => Poll::Ready(Ok(Some(error_code))),
146 Ok(None) => {
147 if let Some(e) = &state.error {
148 return Poll::Ready(Err(e.clone().into()));
149 }
150 state.stopped.insert(self.stream, cx.waker().clone());
151 Poll::Pending
152 }
153 }
154 })
155 .await
156 }
157
158 fn execute_poll_write<F, R>(&mut self, cx: &mut Context, f: F) -> Poll<Result<R, WriteError>>
159 where
160 F: FnOnce(quinn_proto::SendStream) -> Result<R, quinn_proto::WriteError>,
161 {
162 let mut state = self.conn.try_state()?;
163 if self.is_0rtt && !state.check_0rtt() {
164 return Poll::Ready(Err(WriteError::ZeroRttRejected));
165 }
166 match f(state.conn.send_stream(self.stream)) {
167 Ok(r) => {
168 state.wake();
169 Poll::Ready(Ok(r))
170 }
171 Err(e) => match e.try_into() {
172 Ok(e) => Poll::Ready(Err(e)),
173 Err(()) => {
174 state.writable.insert(self.stream, cx.waker().clone());
175 Poll::Pending
176 }
177 },
178 }
179 }
180
181 pub async fn write_chunks(&mut self, bufs: &mut [Bytes]) -> Result<Written, WriteError> {
189 poll_fn(|cx| self.execute_poll_write(cx, |mut stream| stream.write_chunks(bufs))).await
190 }
191
192 pub async fn write_all_chunks(&mut self, bufs: &mut [Bytes]) -> Result<(), WriteError> {
196 let mut chunks = 0;
197 poll_fn(|cx| {
198 loop {
199 if chunks == bufs.len() {
200 return Poll::Ready(Ok(()));
201 }
202 let written = ready!(self.execute_poll_write(cx, |mut stream| {
203 stream.write_chunks(&mut bufs[chunks..])
204 }))?;
205 chunks += written.chunks;
206 }
207 })
208 .await
209 }
210
211 #[cfg(feature = "io-compat")]
213 pub fn into_compat(self) -> CompatSendStream {
214 CompatSendStream(self)
215 }
216}
217
218impl Drop for SendStream {
219 fn drop(&mut self) {
220 let mut state = self.conn.state();
221
222 state.stopped.remove(&self.stream);
224 state.writable.remove(&self.stream);
225
226 if state.error.is_some() || (self.is_0rtt && !state.check_0rtt()) {
227 return;
228 }
229 match state.conn.send_stream(self.stream).finish() {
230 Ok(()) => state.wake(),
231 Err(FinishError::Stopped(reason)) => {
232 if state.conn.send_stream(self.stream).reset(reason).is_ok() {
233 state.wake();
234 }
235 }
236 Err(FinishError::ClosedStream) => {}
238 }
239 }
240}
241
/// Errors that can occur when writing to a [`SendStream`].
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum WriteError {
    /// Sending was stopped by the peer with the given error code.
    #[error("sending stopped by peer: error {0}")]
    Stopped(VarInt),
    /// The connection was lost; carries the connection's error.
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
    /// The stream is already closed.
    #[error("closed stream")]
    ClosedStream,
    /// The stream was opened as 0-RTT and `check_0rtt()` reports that 0-RTT
    /// was rejected.
    #[error("0-RTT rejected")]
    ZeroRttRejected,
    /// (`h3` feature) `send_data` was called while previously queued data had
    /// not yet been fully written out by `poll_ready`.
    #[cfg(feature = "h3")]
    #[error("stream not ready")]
    NotReady,
}
270
271impl TryFrom<quinn_proto::WriteError> for WriteError {
272 type Error = ();
273
274 fn try_from(value: quinn_proto::WriteError) -> Result<Self, Self::Error> {
275 use quinn_proto::WriteError::*;
276 match value {
277 Stopped(e) => Ok(Self::Stopped(e)),
278 ClosedStream => Ok(Self::ClosedStream),
279 Blocked => Err(()),
280 }
281 }
282}
283
284impl From<StoppedError> for WriteError {
285 fn from(x: StoppedError) -> Self {
286 match x {
287 StoppedError::ConnectionLost(e) => Self::ConnectionLost(e),
288 StoppedError::ZeroRttRejected => Self::ZeroRttRejected,
289 }
290 }
291}
292
293impl From<WriteError> for io::Error {
294 fn from(x: WriteError) -> Self {
295 use WriteError::*;
296 let kind = match x {
297 Stopped(_) | ZeroRttRejected => io::ErrorKind::ConnectionReset,
298 ConnectionLost(_) | ClosedStream => io::ErrorKind::NotConnected,
299 #[cfg(feature = "h3")]
300 NotReady => io::ErrorKind::Other,
301 };
302 Self::new(kind, x)
303 }
304}
305
306impl AsyncWrite for SendStream {
307 async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
308 let res =
309 poll_fn(|cx| self.execute_poll_write(cx, |mut stream| stream.write(buf.as_init())))
310 .await
311 .map_err(Into::into);
312 BufResult(res, buf)
313 }
314
315 async fn flush(&mut self) -> io::Result<()> {
316 Ok(())
317 }
318
319 async fn shutdown(&mut self) -> io::Result<()> {
320 self.finish()?;
321 Ok(())
322 }
323}
324
#[cfg(feature = "io-compat")]
mod compat {
    use std::{
        ops::{Deref, DerefMut},
        pin::Pin,
    };

    use compio_buf::IntoInner;

    use super::*;

    /// Wrapper around [`SendStream`] implementing `futures_util::AsyncWrite`
    /// and offering slice-based write helpers.
    pub struct CompatSendStream(pub(super) SendStream);

    impl CompatSendStream {
        /// Writes bytes from `buf`, waiting while the stream is blocked.
        ///
        /// Returns the number of bytes written, which may be only part of `buf`.
        pub async fn write(&mut self, buf: &[u8]) -> Result<usize, WriteError> {
            let inner = &mut self.0;
            poll_fn(|cx| inner.execute_poll_write(cx, |mut stream| stream.write(buf))).await
        }

        /// Writes all of `buf`, waiting for send capacity as needed.
        pub async fn write_all(&mut self, buf: &[u8]) -> Result<(), WriteError> {
            let inner = &mut self.0;
            let mut rest = buf;
            poll_fn(|cx| {
                while !rest.is_empty() {
                    let n = ready!(
                        inner.execute_poll_write(cx, |mut stream| stream.write(rest))
                    )?;
                    rest = &rest[n..];
                }
                Poll::Ready(Ok(()))
            })
            .await
        }
    }

    impl IntoInner for CompatSendStream {
        type Inner = SendStream;

        fn into_inner(self) -> Self::Inner {
            let CompatSendStream(inner) = self;
            inner
        }
    }

    impl Deref for CompatSendStream {
        type Target = SendStream;

        fn deref(&self) -> &Self::Target {
            &self.0
        }
    }

    impl DerefMut for CompatSendStream {
        fn deref_mut(&mut self) -> &mut Self::Target {
            &mut self.0
        }
    }

    impl futures_util::AsyncWrite for CompatSendStream {
        /// Writes from `buf` through the wrapped stream, mapping errors to `io::Error`.
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let this = Pin::into_inner(self);
            this.0
                .execute_poll_write(cx, |mut stream| stream.write(buf))
                .map_err(io::Error::from)
        }

        /// Nothing is buffered at this layer, so flushing is immediately done.
        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        /// Finishes the send side of the wrapped stream.
        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            let result = Pin::into_inner(self).0.finish();
            Poll::Ready(result.map_err(io::Error::from))
        }
    }
}
414
415#[cfg(feature = "io-compat")]
416pub use compat::CompatSendStream;
417
#[cfg(feature = "h3")]
pub(crate) mod h3_impl {
    use compio_buf::bytes::Buf;
    use h3::quic::{self, StreamErrorIncoming, WriteBuf};

    use super::*;

    /// Maps send-side errors onto the error type `h3` expects from a stream.
    ///
    /// Variants without a dedicated `h3` counterpart are boxed into `Unknown`.
    impl From<WriteError> for StreamErrorIncoming {
        fn from(e: WriteError) -> Self {
            use WriteError::*;
            match e {
                Stopped(code) => Self::StreamTerminated {
                    error_code: code.into_inner(),
                },
                ConnectionLost(e) => Self::ConnectionErrorIncoming {
                    connection_error: e.into(),
                },

                e => Self::Unknown(Box::new(e)),
            }
        }
    }

    /// Send half of a QUIC stream adapted to `h3`'s [`quic::SendStream`] API.
    ///
    /// Data handed over via `send_data` is queued in `buf` and written out by
    /// `poll_ready`; at most one such buffer is pending at a time.
    pub struct SendStream<B> {
        inner: super::SendStream,
        // Framed data queued by `send_data` that has not been fully written yet.
        buf: Option<WriteBuf<B>>,
    }

    impl<B> SendStream<B> {
        /// Creates an adapter over stream `stream` of `conn` with nothing queued.
        pub(crate) fn new(conn: Shared<ConnectionInner>, stream: StreamId, is_0rtt: bool) -> Self {
            Self {
                inner: super::SendStream::new(conn, stream, is_0rtt),
                buf: None,
            }
        }
    }

    impl<B> quic::SendStream<B> for SendStream<B>
    where
        B: Buf,
    {
        /// Writes out any data queued by `send_data`; ready once nothing is pending.
        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
            if let Some(data) = &mut self.buf {
                while data.has_remaining() {
                    let n = ready!(
                        self.inner
                            .execute_poll_write(cx, |mut stream| stream.write(data.chunk()))
                    )?;
                    data.advance(n);
                }
            }
            // Either nothing was queued or the queued data is fully written.
            self.buf = None;
            Poll::Ready(Ok(()))
        }

        /// Queues `data` to be written by the next `poll_ready` calls.
        ///
        /// Fails with `NotReady` if previously queued data is still pending.
        fn send_data<T: Into<WriteBuf<B>>>(&mut self, data: T) -> Result<(), StreamErrorIncoming> {
            if self.buf.is_some() {
                return Err(WriteError::NotReady.into());
            }
            self.buf = Some(data.into());
            Ok(())
        }

        /// Finishes the send side; a closed stream is reported as `ClosedStream`.
        fn poll_finish(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
            Poll::Ready(
                self.inner
                    .finish()
                    .map_err(|_| WriteError::ClosedStream.into()),
            )
        }

        /// Resets the stream; codes that do not fit a `VarInt` are clamped to
        /// `VarInt::MAX`, and a closed stream is ignored.
        fn reset(&mut self, reset_code: u64) {
            self.inner
                .reset(reset_code.try_into().unwrap_or(VarInt::MAX))
                .ok();
        }

        /// Returns this stream's id in `h3`'s representation.
        fn send_id(&self) -> quic::StreamId {
            u64::from(self.inner.stream)
                .try_into()
                .expect("invalid stream id")
        }
    }

    impl<B> quic::SendStreamUnframed<B> for SendStream<B>
    where
        B: Buf,
    {
        /// Writes directly from `buf`, bypassing the `send_data` queue.
        ///
        /// Advances `buf` by the number of bytes written and returns that count.
        fn poll_send<D: Buf>(
            &mut self,
            cx: &mut Context<'_>,
            buf: &mut D,
        ) -> Poll<Result<usize, StreamErrorIncoming>> {
            // The stream is "not ready" while framed data from `send_data` is still
            // queued; writing now would place these bytes ahead of it. Calling this
            // in that state is a bug in the caller. (The previous assertion checked
            // `is_some()`, which fired on every correct call in debug builds.)
            debug_assert!(
                self.buf.is_none(),
                "poll_send called while send stream is not ready"
            );

            let n = ready!(
                self.inner
                    .execute_poll_write(cx, |mut stream| stream.write(buf.chunk()))
            )?;
            buf.advance(n);
            Poll::Ready(Ok(n))
        }
    }
}