1use std::{
2 io,
3 task::{Context, Poll},
4};
5
6use compio_buf::{BufResult, IoBuf, bytes::Bytes};
7use compio_io::AsyncWrite;
8use futures_util::{future::poll_fn, ready};
9use quinn_proto::{ClosedStream, FinishError, StreamId, VarInt, Written};
10use thiserror::Error;
11
12use crate::{ConnectionError, ConnectionInner, sync::shared::Shared};
13
14#[derive(Debug)]
32pub struct SendStream {
33 conn: Shared<ConnectionInner>,
34 stream: StreamId,
35 is_0rtt: bool,
36}
37
38impl SendStream {
39 pub(crate) fn new(conn: Shared<ConnectionInner>, stream: StreamId, is_0rtt: bool) -> Self {
40 Self {
41 conn,
42 stream,
43 is_0rtt,
44 }
45 }
46
47 pub fn id(&self) -> StreamId {
49 self.stream
50 }
51
52 pub fn finish(&mut self) -> Result<(), ClosedStream> {
69 let mut state = self.conn.state();
70 match state.conn.send_stream(self.stream).finish() {
71 Ok(()) => {
72 state.wake();
73 Ok(())
74 }
75 Err(FinishError::ClosedStream) => Err(ClosedStream::default()),
76 Err(FinishError::Stopped(_)) => Ok(()),
79 }
80 }
81
82 pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
94 let mut state = self.conn.state();
95 if self.is_0rtt && !state.check_0rtt() {
96 return Ok(());
97 }
98 state.conn.send_stream(self.stream).reset(error_code)?;
99 state.wake();
100 Ok(())
101 }
102
103 pub fn set_priority(&self, priority: i32) -> Result<(), ClosedStream> {
112 self.conn
113 .state()
114 .conn
115 .send_stream(self.stream)
116 .set_priority(priority)
117 }
118
119 pub fn priority(&self) -> Result<i32, ClosedStream> {
121 self.conn.state().conn.send_stream(self.stream).priority()
122 }
123
124 pub async fn stopped(&mut self) -> Result<Option<VarInt>, StoppedError> {
138 poll_fn(|cx| {
139 let mut state = self.conn.state();
140 if self.is_0rtt && !state.check_0rtt() {
141 return Poll::Ready(Err(StoppedError::ZeroRttRejected));
142 }
143 match state.conn.send_stream(self.stream).stopped() {
144 Err(_) => Poll::Ready(Ok(None)),
145 Ok(Some(error_code)) => Poll::Ready(Ok(Some(error_code))),
146 Ok(None) => {
147 if let Some(e) = &state.error {
148 return Poll::Ready(Err(e.clone().into()));
149 }
150 state.stopped.insert(self.stream, cx.waker().clone());
151 Poll::Pending
152 }
153 }
154 })
155 .await
156 }
157
158 fn execute_poll_write<F, R>(&mut self, cx: &mut Context, f: F) -> Poll<Result<R, WriteError>>
159 where
160 F: FnOnce(quinn_proto::SendStream) -> Result<R, quinn_proto::WriteError>,
161 {
162 let mut state = self.conn.try_state()?;
163 if self.is_0rtt && !state.check_0rtt() {
164 return Poll::Ready(Err(WriteError::ZeroRttRejected));
165 }
166 match f(state.conn.send_stream(self.stream)) {
167 Ok(r) => {
168 state.wake();
169 Poll::Ready(Ok(r))
170 }
171 Err(e) => match e.try_into() {
172 Ok(e) => Poll::Ready(Err(e)),
173 Err(()) => {
174 state.writable.insert(self.stream, cx.waker().clone());
175 Poll::Pending
176 }
177 },
178 }
179 }
180
181 pub async fn write_chunks(&mut self, bufs: &mut [Bytes]) -> Result<Written, WriteError> {
189 poll_fn(|cx| self.execute_poll_write(cx, |mut stream| stream.write_chunks(bufs))).await
190 }
191
192 pub async fn write_all_chunks(&mut self, bufs: &mut [Bytes]) -> Result<(), WriteError> {
196 let mut chunks = 0;
197 poll_fn(|cx| {
198 loop {
199 if chunks == bufs.len() {
200 return Poll::Ready(Ok(()));
201 }
202 let written = ready!(self.execute_poll_write(cx, |mut stream| {
203 stream.write_chunks(&mut bufs[chunks..])
204 }))?;
205 chunks += written.chunks;
206 }
207 })
208 .await
209 }
210
211 #[cfg(feature = "io-compat")]
213 pub fn into_compat(self) -> CompatSendStream {
214 CompatSendStream(self)
215 }
216}
217
218impl Drop for SendStream {
219 fn drop(&mut self) {
220 let mut state = self.conn.state();
221
222 state.stopped.remove(&self.stream);
224 state.writable.remove(&self.stream);
225
226 if state.error.is_some() || (self.is_0rtt && !state.check_0rtt()) {
227 return;
228 }
229 match state.conn.send_stream(self.stream).finish() {
230 Ok(()) => state.wake(),
231 Err(FinishError::Stopped(reason)) => {
232 if state.conn.send_stream(self.stream).reset(reason).is_ok() {
233 state.wake();
234 }
235 }
236 Err(FinishError::ClosedStream) => {}
238 }
239 }
240}
241
242#[derive(Debug, Error, Clone, PartialEq, Eq)]
244pub enum WriteError {
245 #[error("sending stopped by peer: error {0}")]
249 Stopped(VarInt),
250 #[error("connection lost")]
252 ConnectionLost(#[from] ConnectionError),
253 #[error("closed stream")]
255 ClosedStream,
256 #[error("0-RTT rejected")]
263 ZeroRttRejected,
264 #[cfg(feature = "h3")]
267 #[error("stream not ready")]
268 NotReady,
269}
270
271impl TryFrom<quinn_proto::WriteError> for WriteError {
272 type Error = ();
273
274 fn try_from(value: quinn_proto::WriteError) -> Result<Self, Self::Error> {
275 use quinn_proto::WriteError::*;
276 match value {
277 Stopped(e) => Ok(Self::Stopped(e)),
278 ClosedStream => Ok(Self::ClosedStream),
279 Blocked => Err(()),
280 }
281 }
282}
283
284impl From<StoppedError> for WriteError {
285 fn from(x: StoppedError) -> Self {
286 match x {
287 StoppedError::ConnectionLost(e) => Self::ConnectionLost(e),
288 StoppedError::ZeroRttRejected => Self::ZeroRttRejected,
289 }
290 }
291}
292
293impl From<WriteError> for io::Error {
294 fn from(x: WriteError) -> Self {
295 use WriteError::*;
296 let kind = match x {
297 Stopped(_) | ZeroRttRejected => io::ErrorKind::ConnectionReset,
298 ConnectionLost(_) | ClosedStream => io::ErrorKind::NotConnected,
299 #[cfg(feature = "h3")]
300 NotReady => io::ErrorKind::Other,
301 };
302 Self::new(kind, x)
303 }
304}
305
306#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)]
308pub enum StoppedError {
309 #[error("connection lost")]
311 ConnectionLost(#[from] ConnectionError),
312 #[error("0-RTT rejected")]
319 ZeroRttRejected,
320}
321
322impl From<StoppedError> for io::Error {
323 fn from(x: StoppedError) -> Self {
324 use StoppedError::*;
325 let kind = match x {
326 ZeroRttRejected => io::ErrorKind::ConnectionReset,
327 ConnectionLost(_) => io::ErrorKind::NotConnected,
328 };
329 Self::new(kind, x)
330 }
331}
332
333impl AsyncWrite for SendStream {
334 async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
335 let res =
336 poll_fn(|cx| self.execute_poll_write(cx, |mut stream| stream.write(buf.as_init())))
337 .await
338 .map_err(Into::into);
339 BufResult(res, buf)
340 }
341
342 async fn flush(&mut self) -> io::Result<()> {
343 Ok(())
344 }
345
346 async fn shutdown(&mut self) -> io::Result<()> {
347 self.finish()?;
348 Ok(())
349 }
350}
351
#[cfg(feature = "io-compat")]
mod compat {
    use std::{
        ops::{Deref, DerefMut},
        pin::Pin,
    };

    use compio_buf::IntoInner;

    use super::*;

    /// Adapter exposing a [`SendStream`] through `futures_util::AsyncWrite` and
    /// slice-based write helpers.
    pub struct CompatSendStream(pub(super) SendStream);

    impl CompatSendStream {
        /// Writes bytes from `buf` and returns how many the stream accepted.
        ///
        /// # Errors
        ///
        /// Returns a [`WriteError`] when the stream or connection can no longer accept data.
        pub async fn write(&mut self, buf: &[u8]) -> Result<usize, WriteError> {
            let inner = &mut self.0;
            poll_fn(|cx| inner.execute_poll_write(cx, |mut stream| stream.write(buf))).await
        }

        /// Writes all of `buf`, waiting for flow-control credit as needed.
        ///
        /// # Errors
        ///
        /// Returns the first [`WriteError`] encountered; a prefix of `buf` may already be sent.
        pub async fn write_all(&mut self, buf: &[u8]) -> Result<(), WriteError> {
            let inner = &mut self.0;
            let mut offset = 0;
            poll_fn(|cx| {
                while offset != buf.len() {
                    let accepted = ready!(
                        inner.execute_poll_write(cx, |mut stream| stream.write(&buf[offset..]))
                    )?;
                    offset += accepted;
                }
                Poll::Ready(Ok(()))
            })
            .await
        }
    }

    impl IntoInner for CompatSendStream {
        type Inner = SendStream;

        /// Unwraps the adapter, returning the wrapped stream.
        fn into_inner(self) -> Self::Inner {
            let Self(inner) = self;
            inner
        }
    }

    impl Deref for CompatSendStream {
        type Target = SendStream;

        fn deref(&self) -> &Self::Target {
            let Self(inner) = self;
            inner
        }
    }

    impl DerefMut for CompatSendStream {
        fn deref_mut(&mut self) -> &mut Self::Target {
            let Self(inner) = self;
            inner
        }
    }

    impl futures_util::AsyncWrite for CompatSendStream {
        /// Writes as many bytes of `buf` as the stream accepts, or parks while blocked.
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let this = self.get_mut();
            this.0
                .execute_poll_write(cx, |mut stream| stream.write(buf))
                .map_err(io::Error::from)
        }

        /// Always ready: there is no adapter-side buffer to flush.
        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        /// Finishes the stream via `SendStream::finish`.
        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(self.get_mut().finish().map_err(io::Error::from))
        }
    }
}
441
442#[cfg(feature = "io-compat")]
443pub use compat::CompatSendStream;
444
#[cfg(feature = "h3")]
pub(crate) mod h3_impl {
    use compio_buf::bytes::Buf;
    use h3::quic::{self, StreamErrorIncoming, WriteBuf};

    use super::*;

    impl From<WriteError> for StreamErrorIncoming {
        fn from(e: WriteError) -> Self {
            use WriteError::*;
            match e {
                Stopped(code) => Self::StreamTerminated {
                    error_code: code.into_inner(),
                },
                ConnectionLost(e) => Self::ConnectionErrorIncoming {
                    connection_error: e.into(),
                },
                // The remaining variants have no dedicated h3 counterpart.
                e => Self::Unknown(Box::new(e)),
            }
        }
    }

    /// Adapter implementing h3's `quic::SendStream` on top of [`super::SendStream`].
    ///
    /// Data handed to `send_data` is parked in `buf` and written out by `poll_ready`;
    /// at most one such buffer is pending at a time.
    pub struct SendStream<B> {
        inner: super::SendStream,
        /// Data queued by `send_data` that `poll_ready` has not fully written yet.
        buf: Option<WriteBuf<B>>,
    }

    impl<B> SendStream<B> {
        pub(crate) fn new(conn: Shared<ConnectionInner>, stream: StreamId, is_0rtt: bool) -> Self {
            Self {
                inner: super::SendStream::new(conn, stream, is_0rtt),
                buf: None,
            }
        }
    }

    impl<B> quic::SendStream<B> for SendStream<B>
    where
        B: Buf,
    {
        /// Drives any data queued by `send_data` into the QUIC stream.
        ///
        /// Resolves to `Ok(())` once nothing is queued. On error the queued data is kept,
        /// so a later call retries from where it stopped.
        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
            if let Some(data) = &mut self.buf {
                while data.has_remaining() {
                    let n = ready!(
                        self.inner
                            .execute_poll_write(cx, |mut stream| stream.write(data.chunk()))
                    )?;
                    data.advance(n);
                }
            }
            self.buf = None;
            Poll::Ready(Ok(()))
        }

        /// Queues `data` to be written by subsequent `poll_ready` calls.
        ///
        /// Fails with `WriteError::NotReady` while earlier data is still queued.
        fn send_data<T: Into<WriteBuf<B>>>(&mut self, data: T) -> Result<(), StreamErrorIncoming> {
            if self.buf.is_some() {
                return Err(WriteError::NotReady.into());
            }
            self.buf = Some(data.into());
            Ok(())
        }

        /// Finishes the underlying stream; a `finish` failure is reported as `ClosedStream`.
        fn poll_finish(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
            Poll::Ready(
                self.inner
                    .finish()
                    .map_err(|_| WriteError::ClosedStream.into()),
            )
        }

        /// Resets the stream; codes too large for a `VarInt` are clamped to `VarInt::MAX`.
        /// Failures (e.g. an already closed stream) are ignored.
        fn reset(&mut self, reset_code: u64) {
            self.inner
                .reset(reset_code.try_into().unwrap_or(VarInt::MAX))
                .ok();
        }

        fn send_id(&self) -> quic::StreamId {
            // QUIC stream IDs are 62-bit values (RFC 9000), so this conversion is
            // expected never to fail.
            u64::from(self.inner.stream)
                .try_into()
                .expect("QUIC stream ID exceeds the varint range")
        }
    }

    impl<B> quic::SendStreamUnframed<B> for SendStream<B>
    where
        B: Buf,
    {
        /// Writes as much of `buf` as the stream accepts and advances `buf` accordingly.
        ///
        /// Must only be called while the stream is ready, i.e. no `send_data` payload
        /// is still queued. Otherwise unframed bytes would overtake the queued data.
        fn poll_send<D: Buf>(
            &mut self,
            cx: &mut Context<'_>,
            buf: &mut D,
        ) -> Poll<Result<usize, StreamErrorIncoming>> {
            // A pending `buf` means `poll_ready` has not drained it yet: that is the
            // "not ready" state this assertion guards against. The previous check was
            // inverted and fired on every legitimate call instead.
            debug_assert!(
                self.buf.is_none(),
                "poll_send called while send stream is not ready"
            );

            let n = ready!(
                self.inner
                    .execute_poll_write(cx, |mut stream| stream.write(buf.chunk()))
            )?;
            buf.advance(n);
            Poll::Ready(Ok(n))
        }
    }
}