compio_dispatcher/
lib.rs

1//! Multithreading dispatcher.
2
3#![warn(missing_docs)]
4#![deny(rustdoc::broken_intra_doc_links)]
5#![doc(
6    html_logo_url = "https://github.com/compio-rs/compio-logo/raw/refs/heads/master/generated/colored-bold.svg"
7)]
8#![doc(
9    html_favicon_url = "https://github.com/compio-rs/compio-logo/raw/refs/heads/master/generated/colored-bold.svg"
10)]
11
12use std::{
13    collections::HashSet,
14    future::Future,
15    io,
16    num::NonZeroUsize,
17    panic::resume_unwind,
18    thread::{JoinHandle, available_parallelism},
19};
20
21use compio_driver::{AsyncifyPool, DispatchError, Dispatchable, ProactorBuilder};
22use compio_runtime::{JoinHandle as CompioJoinHandle, Runtime};
23use flume::{Sender, unbounded};
24use futures_channel::oneshot;
25
26type Spawning = Box<dyn Spawnable + Send>;
27
28trait Spawnable {
29    fn spawn(self: Box<Self>, handle: &Runtime) -> CompioJoinHandle<()>;
30}
31
32/// Concrete type for the closure we're sending to worker threads
33struct Concrete<F, R> {
34    callback: oneshot::Sender<R>,
35    func: F,
36}
37
38impl<F, R> Concrete<F, R> {
39    pub fn new(func: F) -> (Self, oneshot::Receiver<R>) {
40        let (tx, rx) = oneshot::channel();
41        (Self { callback: tx, func }, rx)
42    }
43}
44
45impl<F, Fut, R> Spawnable for Concrete<F, R>
46where
47    F: FnOnce() -> Fut + Send + 'static,
48    Fut: Future<Output = R>,
49    R: Send + 'static,
50{
51    fn spawn(self: Box<Self>, handle: &Runtime) -> CompioJoinHandle<()> {
52        let Concrete { callback, func } = *self;
53        handle.spawn(async move {
54            let res = func().await;
55            callback.send(res).ok();
56        })
57    }
58}
59
60impl<F, R> Dispatchable for Concrete<F, R>
61where
62    F: FnOnce() -> R + Send + 'static,
63    R: Send + 'static,
64{
65    fn run(self: Box<Self>) {
66        let Concrete { callback, func } = *self;
67        let res = func();
68        callback.send(res).ok();
69    }
70}
71
/// The dispatcher. It manages the threads and dispatches the tasks.
#[derive(Debug)]
pub struct Dispatcher {
    /// Sending half of the task channel; every worker thread holds a clone of
    /// the matching receiver.
    sender: Sender<Spawning>,
    /// Handles of the worker threads, joined by `join`.
    threads: Vec<JoinHandle<()>>,
    /// Blocking pool used by `dispatch_blocking` and to join the workers.
    pool: AsyncifyPool,
}
79
80impl Dispatcher {
81    /// Create the dispatcher with specified number of threads.
82    pub(crate) fn new_impl(builder: DispatcherBuilder) -> io::Result<Self> {
83        let DispatcherBuilder {
84            nthreads,
85            concurrent,
86            stack_size,
87            mut thread_affinity,
88            mut names,
89            mut proactor_builder,
90        } = builder;
91        proactor_builder.force_reuse_thread_pool();
92        let pool = proactor_builder.create_or_get_thread_pool();
93        let (sender, receiver) = unbounded::<Spawning>();
94
95        let threads = (0..nthreads)
96            .map({
97                |index| {
98                    let proactor_builder = proactor_builder.clone();
99                    let receiver = receiver.clone();
100
101                    let thread_builder = std::thread::Builder::new();
102                    let thread_builder = if let Some(s) = stack_size {
103                        thread_builder.stack_size(s)
104                    } else {
105                        thread_builder
106                    };
107                    let thread_builder = if let Some(f) = &mut names {
108                        thread_builder.name(f(index))
109                    } else {
110                        thread_builder
111                    };
112
113                    let cpus = if let Some(f) = &mut thread_affinity {
114                        f(index)
115                    } else {
116                        HashSet::new()
117                    };
118                    thread_builder.spawn(move || {
119                        Runtime::builder()
120                            .with_proactor(proactor_builder)
121                            .thread_affinity(cpus)
122                            .build()
123                            .expect("cannot create compio runtime")
124                            .block_on(async move {
125                                while let Ok(f) = receiver.recv_async().await {
126                                    let task = Runtime::with_current(|rt| f.spawn(rt));
127                                    if concurrent {
128                                        task.detach()
129                                    } else {
130                                        task.await.ok();
131                                    }
132                                }
133                            });
134                    })
135                }
136            })
137            .collect::<io::Result<Vec<_>>>()?;
138        Ok(Self {
139            sender,
140            threads,
141            pool,
142        })
143    }
144
145    /// Create the dispatcher with default config.
146    pub fn new() -> io::Result<Self> {
147        Self::builder().build()
148    }
149
150    /// Create a builder to build a dispatcher.
151    pub fn builder() -> DispatcherBuilder {
152        DispatcherBuilder::default()
153    }
154
155    /// Dispatch a task to the threads
156    ///
157    /// The provided `f` should be [`Send`] because it will be send to another
158    /// thread before calling. The returned [`Future`] need not to be [`Send`]
159    /// because it will be executed on only one thread.
160    ///
161    /// # Error
162    ///
163    /// If all threads have panicked, this method will return an error with the
164    /// sent closure.
165    pub fn dispatch<Fn, Fut, R>(&self, f: Fn) -> Result<oneshot::Receiver<R>, DispatchError<Fn>>
166    where
167        Fn: (FnOnce() -> Fut) + Send + 'static,
168        Fut: Future<Output = R> + 'static,
169        R: Send + 'static,
170    {
171        let (concrete, rx) = Concrete::new(f);
172
173        match self.sender.send(Box::new(concrete)) {
174            Ok(_) => Ok(rx),
175            Err(err) => {
176                // SAFETY: We know the dispatchable we sent has type `Concrete<Fn, R>`
177                let recovered =
178                    unsafe { Box::from_raw(Box::into_raw(err.0) as *mut Concrete<Fn, R>) };
179                Err(DispatchError(recovered.func))
180            }
181        }
182    }
183
184    /// Dispatch a blocking task to the threads.
185    ///
186    /// Blocking pool of the dispatcher will be obtained from the proactor
187    /// builder. So any configuration of the proactor's blocking pool will be
188    /// applied to the dispatcher.
189    ///
190    /// # Error
191    ///
192    /// If all threads are busy and the thread pool is full, this method will
193    /// return an error with the original closure. The limit can be configured
194    /// with [`DispatcherBuilder::proactor_builder`] and
195    /// [`ProactorBuilder::thread_pool_limit`].
196    pub fn dispatch_blocking<Fn, R>(&self, f: Fn) -> Result<oneshot::Receiver<R>, DispatchError<Fn>>
197    where
198        Fn: FnOnce() -> R + Send + 'static,
199        R: Send + 'static,
200    {
201        let (concrete, rx) = Concrete::new(f);
202
203        self.pool
204            .dispatch(concrete)
205            .map_err(|e| DispatchError(e.0.func))?;
206
207        Ok(rx)
208    }
209
210    /// Stop the dispatcher and wait for the threads to complete. If there is a
211    /// thread panicked, this method will resume the panic.
212    pub async fn join(self) -> io::Result<()> {
213        drop(self.sender);
214        let (tx, rx) = oneshot::channel::<Vec<_>>();
215        if let Err(f) = self.pool.dispatch({
216            move || {
217                let results = self
218                    .threads
219                    .into_iter()
220                    .map(|thread| thread.join())
221                    .collect();
222                tx.send(results).ok();
223            }
224        }) {
225            std::thread::spawn(f.0);
226        }
227        let results = rx
228            .await
229            .map_err(|_| io::Error::other("the join task cancelled unexpectedly"))?;
230        for res in results {
231            res.unwrap_or_else(|e| resume_unwind(e));
232        }
233        Ok(())
234    }
235}
236
237/// A builder for [`Dispatcher`].
238pub struct DispatcherBuilder {
239    nthreads: usize,
240    concurrent: bool,
241    stack_size: Option<usize>,
242    thread_affinity: Option<Box<dyn FnMut(usize) -> HashSet<usize>>>,
243    names: Option<Box<dyn FnMut(usize) -> String>>,
244    proactor_builder: ProactorBuilder,
245}
246
247impl DispatcherBuilder {
248    /// Create a builder with default settings.
249    pub fn new() -> Self {
250        Self {
251            nthreads: available_parallelism().map(|n| n.get()).unwrap_or(1),
252            concurrent: true,
253            stack_size: None,
254            thread_affinity: None,
255            names: None,
256            proactor_builder: ProactorBuilder::new(),
257        }
258    }
259
260    /// If execute tasks concurrently. Default to be `true`.
261    ///
262    /// When set to `false`, tasks are executed sequentially without any
263    /// concurrency within the thread.
264    pub fn concurrent(mut self, concurrent: bool) -> Self {
265        self.concurrent = concurrent;
266        self
267    }
268
269    /// Set the number of worker threads of the dispatcher. The default value is
270    /// the CPU number. If the CPU number could not be retrieved, the
271    /// default value is 1.
272    pub fn worker_threads(mut self, nthreads: NonZeroUsize) -> Self {
273        self.nthreads = nthreads.get();
274        self
275    }
276
277    /// Set the size of stack of the worker threads.
278    pub fn stack_size(mut self, s: usize) -> Self {
279        self.stack_size = Some(s);
280        self
281    }
282
283    /// Set the thread affinity for the dispatcher.
284    pub fn thread_affinity(mut self, f: impl FnMut(usize) -> HashSet<usize> + 'static) -> Self {
285        self.thread_affinity = Some(Box::new(f));
286        self
287    }
288
289    /// Provide a function to assign names to the worker threads.
290    pub fn thread_names(mut self, f: impl (FnMut(usize) -> String) + 'static) -> Self {
291        self.names = Some(Box::new(f) as _);
292        self
293    }
294
295    /// Set the proactor builder for the inner runtimes.
296    pub fn proactor_builder(mut self, builder: ProactorBuilder) -> Self {
297        self.proactor_builder = builder;
298        self
299    }
300
301    /// Build the [`Dispatcher`].
302    pub fn build(self) -> io::Result<Dispatcher> {
303        Dispatcher::new_impl(self)
304    }
305}
306
307impl Default for DispatcherBuilder {
308    fn default() -> Self {
309        Self::new()
310    }
311}