1#![allow(unused_features)]
4#![warn(missing_docs)]
5#![deny(rustdoc::broken_intra_doc_links)]
6#![doc(
7 html_logo_url = "https://github.com/compio-rs/compio-logo/raw/refs/heads/master/generated/colored-bold.svg"
8)]
9#![doc(
10 html_favicon_url = "https://github.com/compio-rs/compio-logo/raw/refs/heads/master/generated/colored-bold.svg"
11)]
12
13use std::{
14 collections::HashSet,
15 future::Future,
16 io,
17 num::NonZeroUsize,
18 panic::resume_unwind,
19 thread::{JoinHandle, available_parallelism},
20};
21
22use compio_driver::{AsyncifyPool, DispatchError, Dispatchable, ProactorBuilder};
23use compio_runtime::{JoinHandle as CompioJoinHandle, Runtime};
24use flume::{Sender, unbounded};
25use futures_channel::oneshot;
26
27#[cfg(unix)]
28mod unix;
29
30type Spawning = Box<dyn Spawnable + Send>;
31
32trait Spawnable {
33 fn spawn(self: Box<Self>, handle: &Runtime) -> CompioJoinHandle<()>;
34}
35
/// A one-shot job paired with the sending half of the channel that carries
/// its result back to the caller.
///
/// Used both for async jobs (through its [`Spawnable`] impl) and for blocking
/// jobs (through its [`Dispatchable`] impl).
struct Concrete<F, R> {
    /// Delivers the job's output to the receiver handed out by `Concrete::new`.
    callback: oneshot::Sender<R>,
    /// The job itself.
    func: F,
}
41
42impl<F, R> Concrete<F, R> {
43 pub fn new(func: F) -> (Self, oneshot::Receiver<R>) {
44 let (tx, rx) = oneshot::channel();
45 (Self { callback: tx, func }, rx)
46 }
47}
48
49impl<F, Fut, R> Spawnable for Concrete<F, R>
50where
51 F: FnOnce() -> Fut + Send + 'static,
52 Fut: Future<Output = R>,
53 R: Send + 'static,
54{
55 fn spawn(self: Box<Self>, handle: &Runtime) -> CompioJoinHandle<()> {
56 let Concrete { callback, func } = *self;
57 handle.spawn(async move {
58 let res = func().await;
59 callback.send(res).ok();
60 })
61 }
62}
63
64impl<F, R> Dispatchable for Concrete<F, R>
65where
66 F: FnOnce() -> R + Send + 'static,
67 R: Send + 'static,
68{
69 fn run(self: Box<Self>) {
70 let Concrete { callback, func } = *self;
71 let res = func();
72 callback.send(res).ok();
73 }
74}
75
/// Dispatches tasks to a fixed set of worker threads.
///
/// Each worker thread drives its own compio [`Runtime`].
#[derive(Debug)]
pub struct Dispatcher {
    /// Sending half of the task queue shared by all worker threads.
    sender: Sender<Spawning>,
    /// Handles of the worker threads, joined by [`Dispatcher::join`].
    threads: Vec<JoinHandle<()>>,
    /// Thread pool used by [`Dispatcher::dispatch_blocking`] and by
    /// [`Dispatcher::join`].
    pool: AsyncifyPool,
}
83
84impl Dispatcher {
85 pub(crate) fn new_impl(builder: DispatcherBuilder) -> io::Result<Self> {
87 let DispatcherBuilder {
88 nthreads,
89 concurrent,
90 #[cfg(unix)]
91 block_signals,
92 stack_size,
93 mut thread_affinity,
94 mut names,
95 mut proactor_builder,
96 } = builder;
97 proactor_builder.force_reuse_thread_pool();
98 let pool = proactor_builder.create_or_get_thread_pool();
99 let (sender, receiver) = unbounded::<Spawning>();
100
101 #[cfg(unix)]
103 let _g = unix::mask_signal(block_signals)?;
104
105 let threads = (0..nthreads)
106 .map({
107 |index| {
108 let proactor_builder = proactor_builder.clone();
109 let receiver = receiver.clone();
110
111 let thread_builder = std::thread::Builder::new();
112 let thread_builder = if let Some(s) = stack_size {
113 thread_builder.stack_size(s)
114 } else {
115 thread_builder
116 };
117 let thread_builder = if let Some(f) = &mut names {
118 thread_builder.name(f(index))
119 } else {
120 thread_builder
121 };
122
123 let cpus = if let Some(f) = &mut thread_affinity {
124 f(index)
125 } else {
126 HashSet::new()
127 };
128 thread_builder.spawn(move || {
129 Runtime::builder()
130 .with_proactor(proactor_builder)
131 .thread_affinity(cpus)
132 .build()
133 .expect("cannot create compio runtime")
134 .block_on(async move {
135 while let Ok(f) = receiver.recv_async().await {
136 let task = Runtime::with_current(|rt| f.spawn(rt));
137 if concurrent {
138 task.detach()
139 } else {
140 task.await.ok();
141 }
142 }
143 });
144 })
145 }
146 })
147 .collect::<io::Result<Vec<_>>>()?;
148
149 Ok(Self {
150 sender,
151 threads,
152 pool,
153 })
154 }
155
156 pub fn new() -> io::Result<Self> {
158 Self::builder().build()
159 }
160
161 pub fn builder() -> DispatcherBuilder {
163 DispatcherBuilder::default()
164 }
165
166 pub fn dispatch<Fn, Fut, R>(&self, f: Fn) -> Result<oneshot::Receiver<R>, DispatchError<Fn>>
177 where
178 Fn: (FnOnce() -> Fut) + Send + 'static,
179 Fut: Future<Output = R> + 'static,
180 R: Send + 'static,
181 {
182 let (concrete, rx) = Concrete::new(f);
183
184 match self.sender.send(Box::new(concrete)) {
185 Ok(_) => Ok(rx),
186 Err(err) => {
187 let recovered =
189 unsafe { Box::from_raw(Box::into_raw(err.0) as *mut Concrete<Fn, R>) };
190 Err(DispatchError(recovered.func))
191 }
192 }
193 }
194
195 pub fn dispatch_blocking<Fn, R>(&self, f: Fn) -> Result<oneshot::Receiver<R>, DispatchError<Fn>>
208 where
209 Fn: FnOnce() -> R + Send + 'static,
210 R: Send + 'static,
211 {
212 let (concrete, rx) = Concrete::new(f);
213
214 self.pool
215 .dispatch(concrete)
216 .map_err(|e| DispatchError(e.0.func))?;
217
218 Ok(rx)
219 }
220
221 pub async fn join(self) -> io::Result<()> {
224 drop(self.sender);
225 let (tx, rx) = oneshot::channel::<Vec<_>>();
226 if let Err(f) = self.pool.dispatch({
227 move || {
228 let results = self
229 .threads
230 .into_iter()
231 .map(|thread| thread.join())
232 .collect();
233 tx.send(results).ok();
234 }
235 }) {
236 std::thread::spawn(f.0);
237 }
238 let results = rx
239 .await
240 .map_err(|_| io::Error::other("the join task cancelled unexpectedly"))?;
241 for res in results {
242 res.unwrap_or_else(|e| resume_unwind(e));
243 }
244 Ok(())
245 }
246}
247
/// Builder for a [`Dispatcher`].
pub struct DispatcherBuilder {
    /// Number of worker threads to spawn.
    nthreads: usize,
    /// Whether each worker detaches its tasks (`true`) or awaits them one at a
    /// time (`false`).
    concurrent: bool,
    /// Flag passed to the unix signal-masking step run before the workers are
    /// spawned.
    #[cfg(unix)]
    block_signals: bool,
    /// Optional stack size for each worker thread; `None` keeps the
    /// `std::thread::Builder` default.
    stack_size: Option<usize>,
    /// Maps a worker index to the CPU set passed to that worker runtime's
    /// `thread_affinity`.
    thread_affinity: Option<Box<dyn FnMut(usize) -> HashSet<usize>>>,
    /// Maps a worker index to that worker thread's name.
    names: Option<Box<dyn FnMut(usize) -> String>>,
    /// Proactor configuration; each worker runtime is built from a clone of it.
    proactor_builder: ProactorBuilder,
}
259
260impl DispatcherBuilder {
261 pub fn new() -> Self {
263 Self {
264 nthreads: available_parallelism().map(|n| n.get()).unwrap_or(1),
265 concurrent: true,
266 #[cfg(unix)]
267 block_signals: true,
268 stack_size: None,
269 thread_affinity: None,
270 names: None,
271 proactor_builder: ProactorBuilder::new(),
272 }
273 }
274
275 pub fn concurrent(mut self, concurrent: bool) -> Self {
280 self.concurrent = concurrent;
281 self
282 }
283
284 pub fn worker_threads(mut self, nthreads: NonZeroUsize) -> Self {
288 self.nthreads = nthreads.get();
289 self
290 }
291
292 pub fn stack_size(mut self, s: usize) -> Self {
294 self.stack_size = Some(s);
295 self
296 }
297
298 #[allow(unused)]
317 pub fn block_signals(mut self, block_signals: bool) -> Self {
318 #[cfg(unix)]
319 {
320 self.block_signals = block_signals;
321 }
322 self
323 }
324
325 pub fn thread_affinity(mut self, f: impl FnMut(usize) -> HashSet<usize> + 'static) -> Self {
327 self.thread_affinity = Some(Box::new(f));
328 self
329 }
330
331 pub fn thread_names(mut self, f: impl (FnMut(usize) -> String) + 'static) -> Self {
333 self.names = Some(Box::new(f) as _);
334 self
335 }
336
337 pub fn proactor_builder(mut self, builder: ProactorBuilder) -> Self {
339 self.proactor_builder = builder;
340 self
341 }
342
343 pub fn build(self) -> io::Result<Dispatcher> {
345 Dispatcher::new_impl(self)
346 }
347}
348
349impl Default for DispatcherBuilder {
350 fn default() -> Self {
351 Self::new()
352 }
353}