Skip to main content

compio_dispatcher/
lib.rs

1//! Multithreading dispatcher.
2
3#![allow(unused_features)]
4#![warn(missing_docs)]
5#![deny(rustdoc::broken_intra_doc_links)]
6#![doc(
7    html_logo_url = "https://github.com/compio-rs/compio-logo/raw/refs/heads/master/generated/colored-bold.svg"
8)]
9#![doc(
10    html_favicon_url = "https://github.com/compio-rs/compio-logo/raw/refs/heads/master/generated/colored-bold.svg"
11)]
12
13use std::{
14    collections::HashSet,
15    future::Future,
16    io,
17    num::NonZeroUsize,
18    panic::resume_unwind,
19    thread::{JoinHandle, available_parallelism},
20};
21
22use compio_driver::{AsyncifyPool, DispatchError, Dispatchable, ProactorBuilder};
23use compio_runtime::{JoinHandle as CompioJoinHandle, Runtime};
24use flume::{Sender, unbounded};
25use futures_channel::oneshot;
26
27type Spawning = Box<dyn Spawnable + Send>;
28
29trait Spawnable {
30    fn spawn(self: Box<Self>, handle: &Runtime) -> CompioJoinHandle<()>;
31}
32
33/// Concrete type for the closure we're sending to worker threads
34struct Concrete<F, R> {
35    callback: oneshot::Sender<R>,
36    func: F,
37}
38
39impl<F, R> Concrete<F, R> {
40    pub fn new(func: F) -> (Self, oneshot::Receiver<R>) {
41        let (tx, rx) = oneshot::channel();
42        (Self { callback: tx, func }, rx)
43    }
44}
45
46impl<F, Fut, R> Spawnable for Concrete<F, R>
47where
48    F: FnOnce() -> Fut + Send + 'static,
49    Fut: Future<Output = R>,
50    R: Send + 'static,
51{
52    fn spawn(self: Box<Self>, handle: &Runtime) -> CompioJoinHandle<()> {
53        let Concrete { callback, func } = *self;
54        handle.spawn(async move {
55            let res = func().await;
56            callback.send(res).ok();
57        })
58    }
59}
60
61impl<F, R> Dispatchable for Concrete<F, R>
62where
63    F: FnOnce() -> R + Send + 'static,
64    R: Send + 'static,
65{
66    fn run(self: Box<Self>) {
67        let Concrete { callback, func } = *self;
68        let res = func();
69        callback.send(res).ok();
70    }
71}
72
73/// The dispatcher. It manages the threads and dispatches the tasks.
74#[derive(Debug)]
75pub struct Dispatcher {
76    sender: Sender<Spawning>,
77    threads: Vec<JoinHandle<()>>,
78    pool: AsyncifyPool,
79}
80
81impl Dispatcher {
82    /// Create the dispatcher with specified number of threads.
83    pub(crate) fn new_impl(builder: DispatcherBuilder) -> io::Result<Self> {
84        let DispatcherBuilder {
85            nthreads,
86            concurrent,
87            stack_size,
88            mut thread_affinity,
89            mut names,
90            mut proactor_builder,
91        } = builder;
92        proactor_builder.force_reuse_thread_pool();
93        let pool = proactor_builder.create_or_get_thread_pool();
94        let (sender, receiver) = unbounded::<Spawning>();
95
96        let threads = (0..nthreads)
97            .map({
98                |index| {
99                    let proactor_builder = proactor_builder.clone();
100                    let receiver = receiver.clone();
101
102                    let thread_builder = std::thread::Builder::new();
103                    let thread_builder = if let Some(s) = stack_size {
104                        thread_builder.stack_size(s)
105                    } else {
106                        thread_builder
107                    };
108                    let thread_builder = if let Some(f) = &mut names {
109                        thread_builder.name(f(index))
110                    } else {
111                        thread_builder
112                    };
113
114                    let cpus = if let Some(f) = &mut thread_affinity {
115                        f(index)
116                    } else {
117                        HashSet::new()
118                    };
119                    thread_builder.spawn(move || {
120                        Runtime::builder()
121                            .with_proactor(proactor_builder)
122                            .thread_affinity(cpus)
123                            .build()
124                            .expect("cannot create compio runtime")
125                            .block_on(async move {
126                                while let Ok(f) = receiver.recv_async().await {
127                                    let task = Runtime::with_current(|rt| f.spawn(rt));
128                                    if concurrent {
129                                        task.detach()
130                                    } else {
131                                        task.await.ok();
132                                    }
133                                }
134                            });
135                    })
136                }
137            })
138            .collect::<io::Result<Vec<_>>>()?;
139
140        Ok(Self {
141            sender,
142            threads,
143            pool,
144        })
145    }
146
147    /// Create the dispatcher with default config.
148    pub fn new() -> io::Result<Self> {
149        Self::builder().build()
150    }
151
152    /// Create a builder to build a dispatcher.
153    pub fn builder() -> DispatcherBuilder {
154        DispatcherBuilder::default()
155    }
156
157    /// Dispatch a task to the threads
158    ///
159    /// The provided `f` should be [`Send`] because it will be send to another
160    /// thread before calling. The returned [`Future`] need not to be [`Send`]
161    /// because it will be executed on only one thread.
162    ///
163    /// # Error
164    ///
165    /// If all threads have panicked, this method will return an error with the
166    /// sent closure.
167    pub fn dispatch<Fn, Fut, R>(&self, f: Fn) -> Result<oneshot::Receiver<R>, DispatchError<Fn>>
168    where
169        Fn: (FnOnce() -> Fut) + Send + 'static,
170        Fut: Future<Output = R> + 'static,
171        R: Send + 'static,
172    {
173        let (concrete, rx) = Concrete::new(f);
174
175        match self.sender.send(Box::new(concrete)) {
176            Ok(_) => Ok(rx),
177            Err(err) => {
178                // SAFETY: We know the dispatchable we sent has type `Concrete<Fn, R>`
179                let recovered =
180                    unsafe { Box::from_raw(Box::into_raw(err.0) as *mut Concrete<Fn, R>) };
181                Err(DispatchError(recovered.func))
182            }
183        }
184    }
185
186    /// Dispatch a blocking task to the threads.
187    ///
188    /// Blocking pool of the dispatcher will be obtained from the proactor
189    /// builder. So any configuration of the proactor's blocking pool will be
190    /// applied to the dispatcher.
191    ///
192    /// # Error
193    ///
194    /// If all threads are busy and the thread pool is full, this method will
195    /// return an error with the original closure. The limit can be configured
196    /// with [`DispatcherBuilder::proactor_builder`] and
197    /// [`ProactorBuilder::thread_pool_limit`].
198    pub fn dispatch_blocking<Fn, R>(&self, f: Fn) -> Result<oneshot::Receiver<R>, DispatchError<Fn>>
199    where
200        Fn: FnOnce() -> R + Send + 'static,
201        R: Send + 'static,
202    {
203        let (concrete, rx) = Concrete::new(f);
204
205        self.pool
206            .dispatch(concrete)
207            .map_err(|e| DispatchError(e.0.func))?;
208
209        Ok(rx)
210    }
211
212    /// Stop the dispatcher and wait for the threads to complete. If there is a
213    /// thread panicked, this method will resume the panic.
214    pub async fn join(self) -> io::Result<()> {
215        drop(self.sender);
216        let (tx, rx) = oneshot::channel::<Vec<_>>();
217        if let Err(f) = self.pool.dispatch({
218            move || {
219                let results = self
220                    .threads
221                    .into_iter()
222                    .map(|thread| thread.join())
223                    .collect();
224                tx.send(results).ok();
225            }
226        }) {
227            std::thread::spawn(f.0);
228        }
229        let results = rx
230            .await
231            .map_err(|_| io::Error::other("the join task cancelled unexpectedly"))?;
232        for res in results {
233            res.unwrap_or_else(|e| resume_unwind(e));
234        }
235        Ok(())
236    }
237}
238
239/// A builder for [`Dispatcher`].
240pub struct DispatcherBuilder {
241    nthreads: usize,
242    concurrent: bool,
243    stack_size: Option<usize>,
244    thread_affinity: Option<Box<dyn FnMut(usize) -> HashSet<usize>>>,
245    names: Option<Box<dyn FnMut(usize) -> String>>,
246    proactor_builder: ProactorBuilder,
247}
248
249impl DispatcherBuilder {
250    /// Create a builder with default settings.
251    pub fn new() -> Self {
252        Self {
253            nthreads: available_parallelism().map(|n| n.get()).unwrap_or(1),
254            concurrent: true,
255            stack_size: None,
256            thread_affinity: None,
257            names: None,
258            proactor_builder: ProactorBuilder::new(),
259        }
260    }
261
262    /// If execute tasks concurrently. Default to be `true`.
263    ///
264    /// When set to `false`, tasks are executed sequentially without any
265    /// concurrency within the thread.
266    pub fn concurrent(mut self, concurrent: bool) -> Self {
267        self.concurrent = concurrent;
268        self
269    }
270
271    /// Set the number of worker threads of the dispatcher. The default value is
272    /// the CPU number. If the CPU number could not be retrieved, the
273    /// default value is 1.
274    pub fn worker_threads(mut self, nthreads: NonZeroUsize) -> Self {
275        self.nthreads = nthreads.get();
276        self
277    }
278
279    /// Set the size of stack of the worker threads.
280    pub fn stack_size(mut self, s: usize) -> Self {
281        self.stack_size = Some(s);
282        self
283    }
284
285    /// Set the thread affinity for the dispatcher.
286    pub fn thread_affinity(mut self, f: impl FnMut(usize) -> HashSet<usize> + 'static) -> Self {
287        self.thread_affinity = Some(Box::new(f));
288        self
289    }
290
291    /// Provide a function to assign names to the worker threads.
292    pub fn thread_names(mut self, f: impl (FnMut(usize) -> String) + 'static) -> Self {
293        self.names = Some(Box::new(f) as _);
294        self
295    }
296
297    /// Set the proactor builder for the inner runtimes.
298    pub fn proactor_builder(mut self, builder: ProactorBuilder) -> Self {
299        self.proactor_builder = builder;
300        self
301    }
302
303    /// Build the [`Dispatcher`].
304    pub fn build(self) -> io::Result<Dispatcher> {
305        Dispatcher::new_impl(self)
306    }
307}
308
309impl Default for DispatcherBuilder {
310    fn default() -> Self {
311        Self::new()
312    }
313}