1#![warn(missing_docs)]
4#![deny(rustdoc::broken_intra_doc_links)]
5#![doc(
6 html_logo_url = "https://github.com/compio-rs/compio-logo/raw/refs/heads/master/generated/colored-bold.svg"
7)]
8#![doc(
9 html_favicon_url = "https://github.com/compio-rs/compio-logo/raw/refs/heads/master/generated/colored-bold.svg"
10)]
11
12use std::{
13 collections::HashSet,
14 future::Future,
15 io,
16 num::NonZeroUsize,
17 panic::resume_unwind,
18 thread::{JoinHandle, available_parallelism},
19};
20
21use compio_driver::{AsyncifyPool, DispatchError, Dispatchable, ProactorBuilder};
22use compio_runtime::{JoinHandle as CompioJoinHandle, Runtime};
23use flume::{Sender, unbounded};
24use futures_channel::oneshot;
25
26type Spawning = Box<dyn Spawnable + Send>;
27
28trait Spawnable {
29 fn spawn(self: Box<Self>, handle: &Runtime) -> CompioJoinHandle<()>;
30}
31
32struct Concrete<F, R> {
34 callback: oneshot::Sender<R>,
35 func: F,
36}
37
38impl<F, R> Concrete<F, R> {
39 pub fn new(func: F) -> (Self, oneshot::Receiver<R>) {
40 let (tx, rx) = oneshot::channel();
41 (Self { callback: tx, func }, rx)
42 }
43}
44
45impl<F, Fut, R> Spawnable for Concrete<F, R>
46where
47 F: FnOnce() -> Fut + Send + 'static,
48 Fut: Future<Output = R>,
49 R: Send + 'static,
50{
51 fn spawn(self: Box<Self>, handle: &Runtime) -> CompioJoinHandle<()> {
52 let Concrete { callback, func } = *self;
53 handle.spawn(async move {
54 let res = func().await;
55 callback.send(res).ok();
56 })
57 }
58}
59
60impl<F, R> Dispatchable for Concrete<F, R>
61where
62 F: FnOnce() -> R + Send + 'static,
63 R: Send + 'static,
64{
65 fn run(self: Box<Self>) {
66 let Concrete { callback, func } = *self;
67 let res = func();
68 callback.send(res).ok();
69 }
70}
71
72#[derive(Debug)]
74pub struct Dispatcher {
75 sender: Sender<Spawning>,
76 threads: Vec<JoinHandle<()>>,
77 pool: AsyncifyPool,
78}
79
80impl Dispatcher {
81 pub(crate) fn new_impl(builder: DispatcherBuilder) -> io::Result<Self> {
83 let DispatcherBuilder {
84 nthreads,
85 concurrent,
86 stack_size,
87 mut thread_affinity,
88 mut names,
89 mut proactor_builder,
90 } = builder;
91 proactor_builder.force_reuse_thread_pool();
92 let pool = proactor_builder.create_or_get_thread_pool();
93 let (sender, receiver) = unbounded::<Spawning>();
94
95 let threads = (0..nthreads)
96 .map({
97 |index| {
98 let proactor_builder = proactor_builder.clone();
99 let receiver = receiver.clone();
100
101 let thread_builder = std::thread::Builder::new();
102 let thread_builder = if let Some(s) = stack_size {
103 thread_builder.stack_size(s)
104 } else {
105 thread_builder
106 };
107 let thread_builder = if let Some(f) = &mut names {
108 thread_builder.name(f(index))
109 } else {
110 thread_builder
111 };
112
113 let cpus = if let Some(f) = &mut thread_affinity {
114 f(index)
115 } else {
116 HashSet::new()
117 };
118 thread_builder.spawn(move || {
119 Runtime::builder()
120 .with_proactor(proactor_builder)
121 .thread_affinity(cpus)
122 .build()
123 .expect("cannot create compio runtime")
124 .block_on(async move {
125 while let Ok(f) = receiver.recv_async().await {
126 let task = Runtime::with_current(|rt| f.spawn(rt));
127 if concurrent {
128 task.detach()
129 } else {
130 task.await.ok();
131 }
132 }
133 });
134 })
135 }
136 })
137 .collect::<io::Result<Vec<_>>>()?;
138 Ok(Self {
139 sender,
140 threads,
141 pool,
142 })
143 }
144
145 pub fn new() -> io::Result<Self> {
147 Self::builder().build()
148 }
149
150 pub fn builder() -> DispatcherBuilder {
152 DispatcherBuilder::default()
153 }
154
155 pub fn dispatch<Fn, Fut, R>(&self, f: Fn) -> Result<oneshot::Receiver<R>, DispatchError<Fn>>
166 where
167 Fn: (FnOnce() -> Fut) + Send + 'static,
168 Fut: Future<Output = R> + 'static,
169 R: Send + 'static,
170 {
171 let (concrete, rx) = Concrete::new(f);
172
173 match self.sender.send(Box::new(concrete)) {
174 Ok(_) => Ok(rx),
175 Err(err) => {
176 let recovered =
178 unsafe { Box::from_raw(Box::into_raw(err.0) as *mut Concrete<Fn, R>) };
179 Err(DispatchError(recovered.func))
180 }
181 }
182 }
183
184 pub fn dispatch_blocking<Fn, R>(&self, f: Fn) -> Result<oneshot::Receiver<R>, DispatchError<Fn>>
197 where
198 Fn: FnOnce() -> R + Send + 'static,
199 R: Send + 'static,
200 {
201 let (concrete, rx) = Concrete::new(f);
202
203 self.pool
204 .dispatch(concrete)
205 .map_err(|e| DispatchError(e.0.func))?;
206
207 Ok(rx)
208 }
209
210 pub async fn join(self) -> io::Result<()> {
213 drop(self.sender);
214 let (tx, rx) = oneshot::channel::<Vec<_>>();
215 if let Err(f) = self.pool.dispatch({
216 move || {
217 let results = self
218 .threads
219 .into_iter()
220 .map(|thread| thread.join())
221 .collect();
222 tx.send(results).ok();
223 }
224 }) {
225 std::thread::spawn(f.0);
226 }
227 let results = rx
228 .await
229 .map_err(|_| io::Error::other("the join task cancelled unexpectedly"))?;
230 for res in results {
231 res.unwrap_or_else(|e| resume_unwind(e));
232 }
233 Ok(())
234 }
235}
236
237pub struct DispatcherBuilder {
239 nthreads: usize,
240 concurrent: bool,
241 stack_size: Option<usize>,
242 thread_affinity: Option<Box<dyn FnMut(usize) -> HashSet<usize>>>,
243 names: Option<Box<dyn FnMut(usize) -> String>>,
244 proactor_builder: ProactorBuilder,
245}
246
247impl DispatcherBuilder {
248 pub fn new() -> Self {
250 Self {
251 nthreads: available_parallelism().map(|n| n.get()).unwrap_or(1),
252 concurrent: true,
253 stack_size: None,
254 thread_affinity: None,
255 names: None,
256 proactor_builder: ProactorBuilder::new(),
257 }
258 }
259
260 pub fn concurrent(mut self, concurrent: bool) -> Self {
265 self.concurrent = concurrent;
266 self
267 }
268
269 pub fn worker_threads(mut self, nthreads: NonZeroUsize) -> Self {
273 self.nthreads = nthreads.get();
274 self
275 }
276
277 pub fn stack_size(mut self, s: usize) -> Self {
279 self.stack_size = Some(s);
280 self
281 }
282
283 pub fn thread_affinity(mut self, f: impl FnMut(usize) -> HashSet<usize> + 'static) -> Self {
285 self.thread_affinity = Some(Box::new(f));
286 self
287 }
288
289 pub fn thread_names(mut self, f: impl (FnMut(usize) -> String) + 'static) -> Self {
291 self.names = Some(Box::new(f) as _);
292 self
293 }
294
295 pub fn proactor_builder(mut self, builder: ProactorBuilder) -> Self {
297 self.proactor_builder = builder;
298 self
299 }
300
301 pub fn build(self) -> io::Result<Dispatcher> {
303 Dispatcher::new_impl(self)
304 }
305}
306
307impl Default for DispatcherBuilder {
308 fn default() -> Self {
309 Self::new()
310 }
311}