// compio_driver/asyncify.rs
use std::{
    fmt,
    sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    },
    time::Duration,
};

use flume::{Receiver, Sender, TrySendError, bounded};

/// Error returned by [`AsyncifyPool::dispatch`] when every worker thread is
/// busy and the thread limit has been reached.
///
/// It carries the dispatchable that could not be sent, so ownership of it is
/// handed back to the caller (see [`DispatchError::into_inner`]).
#[derive(Copy, Clone, PartialEq, Eq)]
pub struct DispatchError<T>(pub T);

impl<T> DispatchError<T> {
    /// Consumes the error, returning the dispatchable that failed to be sent.
    pub fn into_inner(self) -> T {
        self.0
    }
}

impl<T> fmt::Debug for DispatchError<T> {
    // The payload is never printed, so `T: Debug` is not required.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `write_str` instead of `"...".fmt(f)`: inside a `Debug` impl the latter
        // resolves to `<str as Debug>::fmt`, which would wrap the text in quotes.
        f.write_str("DispatchError(..)")
    }
}

impl<T> fmt::Display for DispatchError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Delegates to `<str as Display>::fmt`, so width/alignment flags apply.
        "all threads are busy".fmt(f)
    }
}

impl<T> std::error::Error for DispatchError<T> {}

/// A type-erased job, as stored in the pool's channel.
type BoxedDispatchable = Box<dyn Dispatchable + Send>;

/// A unit of work that can be run on a worker thread.
///
/// Implemented for every `FnOnce() + Send + 'static` closure.
pub trait Dispatchable: Send + 'static {
    /// Runs the job, consuming it.
    ///
    /// Takes `self: Box<Self>` so the method can be called through a boxed
    /// trait object while still moving the job out of its box.
    fn run(self: Box<Self>);
}

impl<F> Dispatchable for F
where
    F: FnOnce() + Send + 'static,
{
    fn run(self: Box<Self>) {
        (*self)()
    }
}

/// Decrements the shared counter by one when dropped.
///
/// Dropping happens on normal scope exit and during unwinding alike, so the
/// decrement is not skipped if the owner panics.
struct CounterGuard(Arc<AtomicUsize>);

impl Drop for CounterGuard {
    fn drop(&mut self) {
        self.0.fetch_sub(1, Ordering::AcqRel);
    }
}

70fn worker(
71 receiver: Receiver<BoxedDispatchable>,
72 counter: Arc<AtomicUsize>,
73 timeout: Duration,
74) -> impl FnOnce() {
75 move || {
76 counter.fetch_add(1, Ordering::AcqRel);
77 let _guard = CounterGuard(counter);
78 while let Ok(f) = receiver.recv_timeout(timeout) {
79 f.run();
80 }
81 }
82}
83
84#[derive(Debug, Clone)]
86pub struct AsyncifyPool {
87 sender: Sender<BoxedDispatchable>,
88 receiver: Receiver<BoxedDispatchable>,
89 counter: Arc<AtomicUsize>,
90 thread_limit: usize,
91 recv_timeout: Duration,
92}
93
94impl AsyncifyPool {
95 pub fn new(thread_limit: usize, recv_timeout: Duration) -> Self {
98 let (sender, receiver) = bounded(0);
99 Self {
100 sender,
101 receiver,
102 counter: Arc::new(AtomicUsize::new(0)),
103 thread_limit,
104 recv_timeout,
105 }
106 }
107
108 pub fn dispatch<D: Dispatchable>(&self, f: D) -> Result<(), DispatchError<D>> {
113 match self.sender.try_send(Box::new(f) as BoxedDispatchable) {
114 Ok(_) => Ok(()),
115 Err(e) => match e {
116 TrySendError::Full(f) => {
117 if self.thread_limit == 0 {
118 panic!("the thread pool is needed but no worker thread is running");
119 } else if self.counter.load(Ordering::Acquire) >= self.thread_limit {
120 Err(DispatchError(*unsafe {
122 Box::from_raw(Box::into_raw(f).cast())
123 }))
124 } else {
125 std::thread::spawn(worker(
126 self.receiver.clone(),
127 self.counter.clone(),
128 self.recv_timeout,
129 ));
130 self.sender.send(f).expect("the channel should not be full");
131 Ok(())
132 }
133 }
134 TrySendError::Disconnected(_) => {
135 unreachable!("receiver should not all disconnected")
136 }
137 },
138 }
139 }
140}