Skip to main content

compio_executor/
join_handle.rs

1use std::{
2    error::Error,
3    fmt::Display,
4    io,
5    marker::PhantomData,
6    mem::ManuallyDrop,
7    panic::resume_unwind,
8    pin::Pin,
9    ptr,
10    task::{Context, Poll},
11};
12
13use compio_log::{instrument, trace};
14
15use crate::{Panic, task::Task};
16
17/// A handle that awaits the result of a task.
18///
19/// Dropping a [`JoinHandle`] will cancel the task. To run the task in the
20/// background, use [`JoinHandle::detach`].
21#[must_use = "Drop `JoinHandle` will cancel the task. Use `detach` to run it in background."]
22#[derive(Debug)]
23#[repr(transparent)]
24pub struct JoinHandle<T> {
25    task: Option<Task>,
26    _marker: PhantomData<T>,
27}
28
29/// If T is send, we can poll result from other thread
30unsafe impl<T: Send> Send for JoinHandle<T> {}
31
32/// JoinHandle does not expose any &self interface, so it's unconditionally
33/// Sync.
34unsafe impl<T> Sync for JoinHandle<T> {}
35
36impl<T> Unpin for JoinHandle<T> {}
37
38impl<T> JoinHandle<T> {
39    pub(crate) fn new(task: Task) -> Self {
40        Self {
41            task: Some(task),
42            _marker: PhantomData,
43        }
44    }
45
46    /// Cancel the task and wait for the result, if any.
47    pub async fn cancel(self) -> Option<T> {
48        self.task.as_ref()?.cancel(false);
49        self.await.ok()
50    }
51
52    /// Detach the task to let it run in the background.
53    pub fn detach(self) {
54        unsafe { ptr::drop_in_place(&raw mut ManuallyDrop::new(self).task) };
55    }
56}
57
58/// Task failed to execute to completion.
59#[derive(Debug)]
60pub enum JoinError {
61    /// The task was cancelled.
62    Cancelled,
63    /// The task panicked.
64    Panicked(Panic),
65}
66
/// Extension trait for re-raising a task panic captured in a [`JoinError`].
pub trait ResumeUnwind {
    /// Value produced when there is no panic to resume.
    type Output;

    /// Consumes `self`, resuming the panic if the task panicked and otherwise
    /// returning [`Self::Output`].
    fn resume_unwind(self) -> Self::Output;
}
75
76impl<T> ResumeUnwind for Result<T, JoinError> {
77    type Output = Option<T>;
78
79    fn resume_unwind(self) -> Self::Output {
80        match self {
81            Ok(res) => Some(res),
82            Err(JoinError::Cancelled) => None,
83            Err(JoinError::Panicked(e)) => resume_unwind(e),
84        }
85    }
86}
87
88impl JoinError {
89    /// Resume unwind if the task panicked, otherwise do nothing.
90    pub fn resume_unwind(self) {
91        if let JoinError::Panicked(e) = self {
92            resume_unwind(e)
93        }
94    }
95}
96
97impl Display for JoinError {
98    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99        match self {
100            JoinError::Cancelled => write!(f, "Task was cancelled"),
101            JoinError::Panicked(_) => write!(f, "Task has panicked"),
102        }
103    }
104}
105
106impl Error for JoinError {}
107
108impl From<JoinError> for io::Error {
109    fn from(e: JoinError) -> Self {
110        match e {
111            JoinError::Cancelled => io::Error::other("Task was cancelled"),
112            JoinError::Panicked(_) => io::Error::other("Task has panicked"),
113        }
114    }
115}
116
117impl<T> Future for JoinHandle<T> {
118    type Output = Result<T, JoinError>;
119
120    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
121        instrument!(compio_log::Level::TRACE, "JoinHandle::poll");
122
123        let task = self.task.as_ref().expect("Cannot poll after completion");
124
125        unsafe { task.poll(cx) }.map(|res| {
126            trace!("Poll ready");
127
128            self.task = None;
129
130            match res {
131                Some(Ok(res)) => Ok(res),
132                Some(Err(e)) => Err(JoinError::Panicked(e)),
133                None => Err(JoinError::Cancelled),
134            }
135        })
136    }
137}
138
139impl<T> Drop for JoinHandle<T> {
140    fn drop(&mut self) {
141        instrument!(compio_log::Level::TRACE, "JoinHandle::drop");
142
143        if let Some(task) = self.task.as_ref() {
144            task.cancel(true);
145        }
146    }
147}