// compio_dispatcher — lib.rs

1//! Multithreading dispatcher.
2
3#![allow(unused_features)]
4#![warn(missing_docs)]
5#![deny(rustdoc::broken_intra_doc_links)]
6#![doc(
7    html_logo_url = "https://github.com/compio-rs/compio-logo/raw/refs/heads/master/generated/colored-bold.svg"
8)]
9#![doc(
10    html_favicon_url = "https://github.com/compio-rs/compio-logo/raw/refs/heads/master/generated/colored-bold.svg"
11)]
12
13use std::{
14    collections::HashSet,
15    future::Future,
16    io,
17    num::NonZeroUsize,
18    panic::resume_unwind,
19    thread::{JoinHandle, available_parallelism},
20};
21
22use compio_driver::{AsyncifyPool, DispatchError, Dispatchable, ProactorBuilder};
23use compio_runtime::{JoinHandle as CompioJoinHandle, Runtime};
24use flume::{Sender, unbounded};
25use futures_channel::oneshot;
26
27#[cfg(unix)]
28mod unix;
29
/// Boxed, type-erased task sent over the channel to a worker thread.
type Spawning = Box<dyn Spawnable + Send>;

/// Object-safe bridge that lets a boxed closure be spawned onto a worker's
/// [`Runtime`] without the dispatcher knowing its concrete type.
trait Spawnable {
    /// Consume the box and spawn the wrapped work on `handle`.
    fn spawn(self: Box<Self>, handle: &Runtime) -> CompioJoinHandle<()>;
}
35
/// Concrete type for the closure we're sending to worker threads
struct Concrete<F, R> {
    /// One-shot sender used to deliver the task's result back to the caller.
    callback: oneshot::Sender<R>,
    /// The user-provided closure (async via [`Spawnable`], blocking via
    /// [`Dispatchable`], depending on which impl is used).
    func: F,
}
41
42impl<F, R> Concrete<F, R> {
43    pub fn new(func: F) -> (Self, oneshot::Receiver<R>) {
44        let (tx, rx) = oneshot::channel();
45        (Self { callback: tx, func }, rx)
46    }
47}
48
49impl<F, Fut, R> Spawnable for Concrete<F, R>
50where
51    F: FnOnce() -> Fut + Send + 'static,
52    Fut: Future<Output = R>,
53    R: Send + 'static,
54{
55    fn spawn(self: Box<Self>, handle: &Runtime) -> CompioJoinHandle<()> {
56        let Concrete { callback, func } = *self;
57        handle.spawn(async move {
58            let res = func().await;
59            callback.send(res).ok();
60        })
61    }
62}
63
64impl<F, R> Dispatchable for Concrete<F, R>
65where
66    F: FnOnce() -> R + Send + 'static,
67    R: Send + 'static,
68{
69    fn run(self: Box<Self>) {
70        let Concrete { callback, func } = *self;
71        let res = func();
72        callback.send(res).ok();
73    }
74}
75
/// The dispatcher. It manages the threads and dispatches the tasks.
#[derive(Debug)]
pub struct Dispatcher {
    /// Sending half of the task channel; dropping it (in [`Dispatcher::join`])
    /// closes the channel and lets the workers' receive loops end.
    sender: Sender<Spawning>,
    /// Join handles of the spawned worker threads.
    threads: Vec<JoinHandle<()>>,
    /// Blocking-task pool shared with the proactor builder.
    pool: AsyncifyPool,
}
83
impl Dispatcher {
    /// Create the dispatcher with specified number of threads.
    pub(crate) fn new_impl(builder: DispatcherBuilder) -> io::Result<Self> {
        let DispatcherBuilder {
            nthreads,
            concurrent,
            #[cfg(unix)]
            block_signals,
            stack_size,
            mut thread_affinity,
            mut names,
            mut proactor_builder,
        } = builder;
        // Share one blocking pool between the dispatcher and the workers'
        // proactors; `dispatch_blocking` and `join` rely on this pool.
        proactor_builder.force_reuse_thread_pool();
        let pool = proactor_builder.create_or_get_thread_pool();
        let (sender, receiver) = unbounded::<Spawning>();

        // Block standard signals before spawning workers.
        // NOTE(review): `_g` looks like an RAII guard that restores the
        // previous mask when this function returns — confirm in `unix` module.
        #[cfg(unix)]
        let _g = unix::mask_signal(block_signals)?;

        let threads = (0..nthreads)
            .map({
                |index| {
                    // Each worker gets its own proactor and a clone of the
                    // shared (multi-consumer) task receiver.
                    let proactor_builder = proactor_builder.clone();
                    let receiver = receiver.clone();

                    let thread_builder = std::thread::Builder::new();
                    let thread_builder = if let Some(s) = stack_size {
                        thread_builder.stack_size(s)
                    } else {
                        thread_builder
                    };
                    let thread_builder = if let Some(f) = &mut names {
                        thread_builder.name(f(index))
                    } else {
                        thread_builder
                    };

                    // Empty set when no affinity callback was configured.
                    let cpus = if let Some(f) = &mut thread_affinity {
                        f(index)
                    } else {
                        HashSet::new()
                    };
                    thread_builder.spawn(move || {
                        Runtime::builder()
                            .with_proactor(proactor_builder)
                            .thread_affinity(cpus)
                            .build()
                            .expect("cannot create compio runtime")
                            .block_on(async move {
                                // Loop ends when all senders are dropped
                                // (i.e. when `join` drops `self.sender`).
                                while let Ok(f) = receiver.recv_async().await {
                                    let task = Runtime::with_current(|rt| f.spawn(rt));
                                    if concurrent {
                                        // Let tasks interleave on this worker.
                                        task.detach()
                                    } else {
                                        // Sequential mode: finish each task
                                        // before receiving the next.
                                        task.await.ok();
                                    }
                                }
                            });
                    })
                }
            })
            .collect::<io::Result<Vec<_>>>()?;

        Ok(Self {
            sender,
            threads,
            pool,
        })
    }

    /// Create the dispatcher with default config.
    pub fn new() -> io::Result<Self> {
        Self::builder().build()
    }

    /// Create a builder to build a dispatcher.
    pub fn builder() -> DispatcherBuilder {
        DispatcherBuilder::default()
    }

    /// Dispatch a task to the threads
    ///
    /// The provided `f` should be [`Send`] because it will be sent to another
    /// thread before calling. The returned [`Future`] need not be [`Send`]
    /// because it will be executed on only one thread.
    ///
    /// # Error
    ///
    /// If all threads have panicked, this method will return an error with the
    /// sent closure.
    pub fn dispatch<Fn, Fut, R>(&self, f: Fn) -> Result<oneshot::Receiver<R>, DispatchError<Fn>>
    where
        Fn: (FnOnce() -> Fut) + Send + 'static,
        Fut: Future<Output = R> + 'static,
        R: Send + 'static,
    {
        let (concrete, rx) = Concrete::new(f);

        match self.sender.send(Box::new(concrete)) {
            Ok(_) => Ok(rx),
            Err(err) => {
                // Send only fails when every receiver is gone, i.e. all
                // worker threads have exited; recover the closure for the
                // caller instead of dropping it.
                // SAFETY: We know the dispatchable we sent has type `Concrete<Fn, R>`
                let recovered =
                    unsafe { Box::from_raw(Box::into_raw(err.0) as *mut Concrete<Fn, R>) };
                Err(DispatchError(recovered.func))
            }
        }
    }

    /// Dispatch a blocking task to the threads.
    ///
    /// Blocking pool of the dispatcher will be obtained from the proactor
    /// builder. So any configuration of the proactor's blocking pool will be
    /// applied to the dispatcher.
    ///
    /// # Error
    ///
    /// If all threads are busy and the thread pool is full, this method will
    /// return an error with the original closure. The limit can be configured
    /// with [`DispatcherBuilder::proactor_builder`] and
    /// [`ProactorBuilder::thread_pool_limit`].
    pub fn dispatch_blocking<Fn, R>(&self, f: Fn) -> Result<oneshot::Receiver<R>, DispatchError<Fn>>
    where
        Fn: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        let (concrete, rx) = Concrete::new(f);

        self.pool
            .dispatch(concrete)
            .map_err(|e| DispatchError(e.0.func))?;

        Ok(rx)
    }

    /// Stop the dispatcher and wait for the threads to complete. If there is a
    /// thread panicked, this method will resume the panic.
    pub async fn join(self) -> io::Result<()> {
        // Close the channel so every worker's `recv_async` fails and its
        // loop ends.
        drop(self.sender);
        let (tx, rx) = oneshot::channel::<Vec<_>>();
        // Joining OS threads blocks, so run it on the blocking pool rather
        // than on the current async thread.
        if let Err(f) = self.pool.dispatch({
            move || {
                let results = self
                    .threads
                    .into_iter()
                    .map(|thread| thread.join())
                    .collect();
                tx.send(results).ok();
            }
        }) {
            // Pool rejected the task (e.g. full); fall back to a dedicated
            // thread so `join` still makes progress.
            std::thread::spawn(f.0);
        }
        let results = rx
            .await
            .map_err(|_| io::Error::other("the join task cancelled unexpectedly"))?;
        for res in results {
            // Propagate any worker panic to the caller.
            res.unwrap_or_else(|e| resume_unwind(e));
        }
        Ok(())
    }
}
247
/// A builder for [`Dispatcher`].
pub struct DispatcherBuilder {
    /// Number of worker threads to spawn.
    nthreads: usize,
    /// Whether tasks on one worker thread may run concurrently.
    concurrent: bool,
    /// Whether to mask standard signals on worker threads (Unix only).
    #[cfg(unix)]
    block_signals: bool,
    /// Optional stack size for the worker threads.
    stack_size: Option<usize>,
    /// Optional per-thread CPU affinity callback (thread index -> CPU set).
    thread_affinity: Option<Box<dyn FnMut(usize) -> HashSet<usize>>>,
    /// Optional per-thread name callback (thread index -> thread name).
    names: Option<Box<dyn FnMut(usize) -> String>>,
    /// Builder for each worker's proactor/runtime.
    proactor_builder: ProactorBuilder,
}
259
260impl DispatcherBuilder {
261    /// Create a builder with default settings.
262    pub fn new() -> Self {
263        Self {
264            nthreads: available_parallelism().map(|n| n.get()).unwrap_or(1),
265            concurrent: true,
266            #[cfg(unix)]
267            block_signals: true,
268            stack_size: None,
269            thread_affinity: None,
270            names: None,
271            proactor_builder: ProactorBuilder::new(),
272        }
273    }
274
275    /// If execute tasks concurrently. Default to be `true`.
276    ///
277    /// When set to `false`, tasks are executed sequentially without any
278    /// concurrency within the thread.
279    pub fn concurrent(mut self, concurrent: bool) -> Self {
280        self.concurrent = concurrent;
281        self
282    }
283
284    /// Set the number of worker threads of the dispatcher. The default value is
285    /// the CPU number. If the CPU number could not be retrieved, the
286    /// default value is 1.
287    pub fn worker_threads(mut self, nthreads: NonZeroUsize) -> Self {
288        self.nthreads = nthreads.get();
289        self
290    }
291
292    /// Set the size of stack of the worker threads.
293    pub fn stack_size(mut self, s: usize) -> Self {
294        self.stack_size = Some(s);
295        self
296    }
297
298    /// Block standard signals on worker threads.
299    ///
300    /// Default to `true`. When enabled, `SIGINT`, `SIGTERM`, `SIGQUIT`,
301    /// `SIGHUP`, `SIGUSR1`, `SIGUSR2`, and `SIGPIPE` are masked on worker
302    /// threads.
303    ///
304    /// This option only has effect on Unix systems. On non-Unix systems, this
305    /// method does nothing.
306    ///
307    /// On Unix systems, when [`Dispatcher`] spawns worker threads, they inherit
308    /// the parent thread's signal mask. By default, SIGINT (Ctrl-C) and other
309    /// signals can be delivered to any thread in the process. If a worker
310    /// thread receives the signal before the async signal handler is polled on
311    /// the main thread, the default signal handler runs (terminating the
312    /// process) instead of the compio signal handler.
313    ///
314    /// This is a well-known issue in multi-threaded Unix applications and
315    /// requires explicit signal masking.
316    #[allow(unused)]
317    pub fn block_signals(mut self, block_signals: bool) -> Self {
318        #[cfg(unix)]
319        {
320            self.block_signals = block_signals;
321        }
322        self
323    }
324
325    /// Set the thread affinity for the dispatcher.
326    pub fn thread_affinity(mut self, f: impl FnMut(usize) -> HashSet<usize> + 'static) -> Self {
327        self.thread_affinity = Some(Box::new(f));
328        self
329    }
330
331    /// Provide a function to assign names to the worker threads.
332    pub fn thread_names(mut self, f: impl (FnMut(usize) -> String) + 'static) -> Self {
333        self.names = Some(Box::new(f) as _);
334        self
335    }
336
337    /// Set the proactor builder for the inner runtimes.
338    pub fn proactor_builder(mut self, builder: ProactorBuilder) -> Self {
339        self.proactor_builder = builder;
340        self
341    }
342
343    /// Build the [`Dispatcher`].
344    pub fn build(self) -> io::Result<Dispatcher> {
345        Dispatcher::new_impl(self)
346    }
347}
348
349impl Default for DispatcherBuilder {
350    fn default() -> Self {
351        Self::new()
352    }
353}