Skip to main content

compio_driver/
asyncify.rs

1use std::{
2    fmt,
3    sync::{
4        Arc,
5        atomic::{AtomicUsize, Ordering},
6    },
7    time::Duration,
8};
9
10use flume::{Receiver, Sender, TrySendError, bounded};
11
/// An error that may be emitted when all worker threads are busy.
///
/// It simply contains the dispatchable value with a convenient [`fmt::Debug`]
/// and [`fmt::Display`] implementation. The dispatchable is handed back
/// untouched, so the caller decides what to do with it.
#[derive(Copy, Clone, PartialEq, Eq)]
pub struct DispatchError<T>(pub T);

impl<T> DispatchError<T> {
    /// Consume the error, yielding the dispatchable that failed to be sent.
    pub fn into_inner(self) -> T {
        self.0
    }
}

impl<T> fmt::Debug for DispatchError<T> {
    /// Formats as `DispatchError(..)`; the payload is elided because `T` is
    /// not required to implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Write the text verbatim. Calling `.fmt(f)` on the literal here would
        // resolve to `<str as Debug>::fmt` (the trait being implemented is in
        // scope) and wrap the output in quotes like a string literal.
        f.write_str("DispatchError(..)")
    }
}

impl<T> fmt::Display for DispatchError<T> {
    /// Formats as `all threads are busy`, honouring width/alignment flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt("all threads are busy", f)
    }
}

impl<T> std::error::Error for DispatchError<T> {}
39
/// A type-erased dispatchable as it travels through the pool's channel.
type BoxedDispatchable = Box<dyn Dispatchable + Send>;

/// A trait for dispatching a closure.
///
/// [`AsyncifyPool`] relies on it to run closures in other threads.
///
/// A blanket implementation covers every `FnOnce() + Send + 'static`
/// closure; any other type that is `Send` and `'static` may implement it too.
pub trait Dispatchable: Send + 'static {
    /// Run the dispatchable, consuming it.
    fn run(self: Box<Self>);
}

impl<F: FnOnce() + Send + 'static> Dispatchable for F {
    fn run(self: Box<Self>) {
        // Move the closure out of its box, then invoke it exactly once.
        let closure = *self;
        closure();
    }
}
61
/// Releases one unit of a shared worker counter when dropped.
struct CounterGuard(Arc<AtomicUsize>);

impl Drop for CounterGuard {
    fn drop(&mut self) {
        let Self(counter) = self;
        counter.fetch_sub(1, Ordering::AcqRel);
    }
}
69
70fn worker(
71    receiver: Receiver<BoxedDispatchable>,
72    counter: Arc<AtomicUsize>,
73    timeout: Duration,
74) -> impl FnOnce() {
75    move || {
76        counter.fetch_add(1, Ordering::AcqRel);
77        let _guard = CounterGuard(counter);
78        while let Ok(f) = receiver.recv_timeout(timeout) {
79            f.run();
80        }
81    }
82}
83
84/// A thread pool to perform blocking operations in other threads.
85#[derive(Debug, Clone)]
86pub struct AsyncifyPool {
87    sender: Sender<BoxedDispatchable>,
88    receiver: Receiver<BoxedDispatchable>,
89    counter: Arc<AtomicUsize>,
90    thread_limit: usize,
91    recv_timeout: Duration,
92}
93
94impl AsyncifyPool {
95    /// Create [`AsyncifyPool`] with thread number limit and channel receive
96    /// timeout.
97    pub fn new(thread_limit: usize, recv_timeout: Duration) -> Self {
98        let (sender, receiver) = bounded(0);
99        Self {
100            sender,
101            receiver,
102            counter: Arc::new(AtomicUsize::new(0)),
103            thread_limit,
104            recv_timeout,
105        }
106    }
107
108    /// Send a dispatchable, usually a closure, to another thread. Usually the
109    /// user should not use it. When all threads are busy and thread number
110    /// limit has been reached, it will return an error with the original
111    /// dispatchable.
112    pub fn dispatch<D: Dispatchable>(&self, f: D) -> Result<(), DispatchError<D>> {
113        match self.sender.try_send(Box::new(f) as BoxedDispatchable) {
114            Ok(_) => Ok(()),
115            Err(e) => match e {
116                TrySendError::Full(f) => {
117                    if self.thread_limit == 0 {
118                        panic!("the thread pool is needed but no worker thread is running");
119                    } else if self.counter.load(Ordering::Acquire) >= self.thread_limit {
120                        // SAFETY: we can ensure the type
121                        Err(DispatchError(*unsafe {
122                            Box::from_raw(Box::into_raw(f).cast())
123                        }))
124                    } else {
125                        std::thread::spawn(worker(
126                            self.receiver.clone(),
127                            self.counter.clone(),
128                            self.recv_timeout,
129                        ));
130                        self.sender.send(f).expect("the channel should not be full");
131                        Ok(())
132                    }
133                }
134                TrySendError::Disconnected(_) => {
135                    unreachable!("receiver should not all disconnected")
136                }
137            },
138        }
139    }
140}