Skip to main content

compio_runtime/runtime/
mod.rs

1use std::{
2    any::Any,
3    cell::{Cell, RefCell},
4    collections::HashSet,
5    fmt::Debug,
6    future::Future,
7    io,
8    ops::Deref,
9    panic::AssertUnwindSafe,
10    rc::Rc,
11    sync::Arc,
12    task::{Context, Poll, Waker},
13    time::Duration,
14};
15
16use async_task::Task;
17use compio_buf::IntoInner;
18use compio_driver::{
19    AsRawFd, Cancel, DriverType, Extra, Key, OpCode, Proactor, ProactorBuilder, PushEntry, RawFd,
20    op::Asyncify,
21};
22use compio_log::{debug, instrument};
23use futures_util::FutureExt;
24
25mod future;
26pub use future::*;
27
28mod stream;
29pub use stream::*;
30
31#[cfg(feature = "time")]
32pub(crate) mod time;
33
34mod buffer_pool;
35pub use buffer_pool::*;
36
37mod scheduler;
38
39mod opt_waker;
40pub use opt_waker::OptWaker;
41
42#[cfg(feature = "time")]
43use crate::runtime::time::{TimerFuture, TimerKey, TimerRuntime};
44use crate::{BufResult, affinity::bind_to_cpu_set, runtime::scheduler::Scheduler};
45
46scoped_tls::scoped_thread_local!(static CURRENT_RUNTIME: Runtime);
47
48/// Type alias for `Task<Result<T, Box<dyn Any + Send>>>`, which resolves to an
49/// `Err` when the spawned future panicked.
50pub type JoinHandle<T> = Task<Result<T, Box<dyn Any + Send>>>;
51
thread_local! {
    /// Per-thread counter from which each newly built runtime takes its id.
    static RUNTIME_ID: Cell<u64> = const { Cell::new(0u64) };
}
55
/// Diverging helper used wherever a runtime is required but none has been
/// entered on the current thread.
///
/// Marked `#[cold]` because it is only reached on misuse.
#[cold]
fn not_in_compio_runtime() -> ! {
    panic!("not in a compio runtime");
}
60
61/// Inner structure of [`Runtime`].
62pub struct RuntimeInner {
63    driver: RefCell<Proactor>,
64    scheduler: Scheduler,
65    #[cfg(feature = "time")]
66    timer_runtime: RefCell<TimerRuntime>,
67    // Runtime id is used to check if the buffer pool is belonged to this runtime or not.
68    // Without this, if user enable `io-uring-buf-ring` feature then:
69    // 1. Create a buffer pool at runtime1
70    // 3. Create another runtime2, then use the exists buffer pool in runtime2, it may cause
71    // - io-uring report error if the buffer group id is not registered
72    // - buffer pool will return a wrong buffer which the buffer's data is uninit, that will cause
73    //   UB
74    id: u64,
75}
76
77/// The async runtime of compio.
78///
79/// It is a thread-local runtime, meaning it cannot be sent to other threads.
80#[derive(Clone)]
81pub struct Runtime(Rc<RuntimeInner>);
82
83impl Debug for Runtime {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        let mut s = f.debug_struct("Runtime");
86        s.field("driver", &"...").field("scheduler", &"...");
87        #[cfg(feature = "time")]
88        s.field("timer_runtime", &"...");
89        s.field("id", &self.id).finish()
90    }
91}
92
93impl Deref for Runtime {
94    type Target = RuntimeInner;
95
96    fn deref(&self) -> &Self::Target {
97        &self.0
98    }
99}
100
101impl Runtime {
102    /// Create [`Runtime`] with default config.
103    pub fn new() -> io::Result<Self> {
104        Self::builder().build()
105    }
106
107    /// Create a builder for [`Runtime`].
108    pub fn builder() -> RuntimeBuilder {
109        RuntimeBuilder::new()
110    }
111
112    /// The current driver type.
113    pub fn driver_type(&self) -> DriverType {
114        self.driver.borrow().driver_type()
115    }
116
117    /// Try to perform a function on the current runtime, and if no runtime is
118    /// running, return the function back.
119    pub fn try_with_current<T, F: FnOnce(&Self) -> T>(f: F) -> Result<T, F> {
120        if CURRENT_RUNTIME.is_set() {
121            Ok(CURRENT_RUNTIME.with(f))
122        } else {
123            Err(f)
124        }
125    }
126
127    /// Perform a function on the current runtime.
128    ///
129    /// ## Panics
130    ///
131    /// This method will panic if there is no running [`Runtime`].
132    pub fn with_current<T, F: FnOnce(&Self) -> T>(f: F) -> T {
133        if CURRENT_RUNTIME.is_set() {
134            CURRENT_RUNTIME.with(f)
135        } else {
136            not_in_compio_runtime()
137        }
138    }
139
140    /// Try to get the current runtime, and if no runtime is running, return
141    /// `None`.
142    pub fn try_current() -> Option<Self> {
143        if CURRENT_RUNTIME.is_set() {
144            Some(CURRENT_RUNTIME.with(|r| r.clone()))
145        } else {
146            None
147        }
148    }
149
150    /// Get the current runtime.
151    ///
152    /// # Panics
153    ///
154    /// This method will panic if there is no running [`Runtime`].
155    pub fn current() -> Self {
156        if CURRENT_RUNTIME.is_set() {
157            CURRENT_RUNTIME.with(|r| r.clone())
158        } else {
159            not_in_compio_runtime()
160        }
161    }
162
163    /// Set this runtime as current runtime, and perform a function in the
164    /// current scope.
165    pub fn enter<T, F: FnOnce() -> T>(&self, f: F) -> T {
166        CURRENT_RUNTIME.set(self, f)
167    }
168
169    fn spawn_impl<F: Future + 'static>(&self, future: F) -> Task<F::Output> {
170        unsafe { self.spawn_unchecked(future) }
171    }
172
173    /// Low level API to control the runtime.
174    ///
175    /// Spawns a new asynchronous task, returning a [`Task`] for it.
176    ///
177    /// # Safety
178    ///
179    /// Borrowed variables must outlive the future.
180    pub unsafe fn spawn_unchecked<F: Future>(&self, future: F) -> Task<F::Output> {
181        let waker = self.waker();
182        unsafe { self.scheduler.spawn_unchecked(future, waker) }
183    }
184
185    /// Low level API to control the runtime.
186    ///
187    /// Run the scheduled tasks.
188    ///
189    /// The return value indicates whether there are still tasks in the queue.
190    pub fn run(&self) -> bool {
191        self.scheduler.run()
192    }
193
194    /// Low level API to control the runtime.
195    ///
196    /// Create a waker that always notifies the runtime when woken.
197    pub fn waker(&self) -> Waker {
198        self.driver.borrow().waker()
199    }
200
201    /// Low level API to control the runtime.
202    ///
203    /// Create an optimized waker that only notifies the runtime when woken
204    /// from another thread, or when `notify-always` is enabled.
205    pub fn opt_waker(&self) -> Arc<OptWaker> {
206        OptWaker::new(self.waker())
207    }
208
209    /// Block on the future till it completes.
210    pub fn block_on<F: Future>(&self, future: F) -> F::Output {
211        self.enter(|| {
212            let opt_waker = self.opt_waker();
213            let waker = Waker::from(opt_waker.clone());
214            let mut context = Context::from_waker(&waker);
215            let mut future = std::pin::pin!(future);
216            loop {
217                if let Poll::Ready(result) = future.as_mut().poll(&mut context) {
218                    self.run();
219                    return result;
220                }
221                // We always want to reset the waker here.
222                let remaining_tasks = self.run() | opt_waker.reset();
223                if remaining_tasks {
224                    self.poll_with(Some(Duration::ZERO));
225                } else {
226                    self.poll();
227                }
228            }
229        })
230    }
231
232    /// Spawns a new asynchronous task, returning a [`Task`] for it.
233    ///
234    /// Spawning a task enables the task to execute concurrently to other tasks.
235    /// There is no guarantee that a spawned task will execute to completion.
236    pub fn spawn<F: Future + 'static>(&self, future: F) -> JoinHandle<F::Output> {
237        self.spawn_impl(AssertUnwindSafe(future).catch_unwind())
238    }
239
240    /// Spawns a blocking task in a new thread, and wait for it.
241    ///
242    /// The task will not be cancelled even if the future is dropped.
243    pub fn spawn_blocking<T: Send + 'static>(
244        &self,
245        f: impl (FnOnce() -> T) + Send + 'static,
246    ) -> JoinHandle<T> {
247        let op = Asyncify::new(move || {
248            let res = std::panic::catch_unwind(AssertUnwindSafe(f));
249            BufResult(Ok(0), res)
250        });
251        // It is safe and sound to use `submit` here because the task is spawned
252        // immediately.
253        self.spawn_impl(self.submit(op).map(|res| res.1.into_inner()))
254    }
255
256    /// Attach a raw file descriptor/handle/socket to the runtime.
257    ///
258    /// You only need this when authoring your own high-level APIs. High-level
259    /// resources in this crate are attached automatically.
260    pub fn attach(&self, fd: RawFd) -> io::Result<()> {
261        self.driver.borrow_mut().attach(fd)
262    }
263
264    fn submit_raw<T: OpCode + 'static>(
265        &self,
266        op: T,
267        extra: Option<Extra>,
268    ) -> PushEntry<Key<T>, BufResult<usize, T>> {
269        let mut this = self.driver.borrow_mut();
270        match extra {
271            Some(e) => this.push_with_extra(op, e),
272            None => this.push(op),
273        }
274    }
275
276    fn default_extra(&self) -> Extra {
277        self.driver.borrow().default_extra()
278    }
279
280    /// Submit an operation to the runtime.
281    ///
282    /// You only need this when authoring your own [`OpCode`].
283    pub fn submit<T: OpCode + 'static>(&self, op: T) -> Submit<T> {
284        Submit::new(self.clone(), op)
285    }
286
287    /// Submit a multishot operation to the runtime.
288    ///
289    /// You only need this when authoring your own [`OpCode`].
290    pub fn submit_multi<T: OpCode + 'static>(&self, op: T) -> SubmitMulti<T> {
291        SubmitMulti::new(self.clone(), op)
292    }
293
294    pub(crate) fn cancel<T: OpCode>(&self, key: Key<T>) {
295        self.driver.borrow_mut().cancel(key);
296    }
297
298    pub(crate) fn register_cancel<T: OpCode>(&self, key: &Key<T>) -> Cancel {
299        self.driver.borrow_mut().register_cancel(key)
300    }
301
302    pub(crate) fn cancel_token(&self, token: Cancel) -> bool {
303        self.driver.borrow_mut().cancel_token(token)
304    }
305
306    #[cfg(feature = "time")]
307    pub(crate) fn cancel_timer(&self, key: &TimerKey) {
308        self.timer_runtime.borrow_mut().cancel(key);
309    }
310
311    pub(crate) fn poll_task<T: OpCode>(
312        &self,
313        waker: &Waker,
314        key: Key<T>,
315    ) -> PushEntry<Key<T>, BufResult<usize, T>> {
316        instrument!(compio_log::Level::DEBUG, "poll_task", ?key);
317        let mut driver = self.driver.borrow_mut();
318        driver.pop(key).map_pending(|k| {
319            driver.update_waker(&k, waker);
320            k
321        })
322    }
323
324    pub(crate) fn poll_task_with_extra<T: OpCode>(
325        &self,
326        waker: &Waker,
327        key: Key<T>,
328    ) -> PushEntry<Key<T>, (BufResult<usize, T>, Extra)> {
329        instrument!(compio_log::Level::DEBUG, "poll_task_with_extra", ?key);
330        let mut driver = self.driver.borrow_mut();
331        driver.pop_with_extra(key).map_pending(|k| {
332            driver.update_waker(&k, waker);
333            k
334        })
335    }
336
337    pub(crate) fn poll_multishot<T: OpCode>(
338        &self,
339        waker: &Waker,
340        key: &Key<T>,
341    ) -> Option<BufResult<usize, Extra>> {
342        instrument!(compio_log::Level::DEBUG, "poll_multishot", ?key);
343        let mut driver = self.driver.borrow_mut();
344        if let Some(res) = driver.pop_multishot(key) {
345            return Some(res);
346        }
347        driver.update_waker(key, waker);
348        None
349    }
350
351    #[cfg(feature = "time")]
352    pub(crate) fn poll_timer(&self, cx: &mut Context, key: &TimerKey) -> Poll<()> {
353        instrument!(compio_log::Level::DEBUG, "poll_timer", ?cx, ?key);
354        let mut timer_runtime = self.timer_runtime.borrow_mut();
355        if timer_runtime.is_completed(key) {
356            debug!("ready");
357            Poll::Ready(())
358        } else {
359            debug!("pending");
360            timer_runtime.update_waker(key, cx.waker());
361            Poll::Pending
362        }
363    }
364
365    /// Low level API to control the runtime.
366    ///
367    /// Get the timeout value to be passed to [`Proactor::poll`].
368    pub fn current_timeout(&self) -> Option<Duration> {
369        #[cfg(not(feature = "time"))]
370        let timeout = None;
371        #[cfg(feature = "time")]
372        let timeout = self.timer_runtime.borrow().min_timeout();
373        timeout
374    }
375
376    /// Low level API to control the runtime.
377    ///
378    /// Poll the inner proactor. It is equal to calling [`Runtime::poll_with`]
379    /// with [`Runtime::current_timeout`].
380    pub fn poll(&self) {
381        instrument!(compio_log::Level::DEBUG, "poll");
382        let timeout = self.current_timeout();
383        debug!("timeout: {:?}", timeout);
384        self.poll_with(timeout)
385    }
386
387    /// Low level API to control the runtime.
388    ///
389    /// Poll the inner proactor with a custom timeout.
390    pub fn poll_with(&self, timeout: Option<Duration>) {
391        instrument!(compio_log::Level::DEBUG, "poll_with");
392
393        let mut driver = self.driver.borrow_mut();
394        match driver.poll(timeout) {
395            Ok(()) => {}
396            Err(e) => match e.kind() {
397                io::ErrorKind::TimedOut | io::ErrorKind::Interrupted => {
398                    debug!("expected error: {e}");
399                }
400                _ => panic!("{e:?}"),
401            },
402        }
403        #[cfg(feature = "time")]
404        self.timer_runtime.borrow_mut().wake();
405    }
406
407    pub(crate) fn create_buffer_pool(
408        &self,
409        buffer_len: u16,
410        buffer_size: usize,
411    ) -> io::Result<compio_driver::BufferPool> {
412        self.driver
413            .borrow_mut()
414            .create_buffer_pool(buffer_len, buffer_size)
415    }
416
417    pub(crate) unsafe fn release_buffer_pool(
418        &self,
419        buffer_pool: compio_driver::BufferPool,
420    ) -> io::Result<()> {
421        unsafe { self.driver.borrow_mut().release_buffer_pool(buffer_pool) }
422    }
423
424    pub(crate) fn id(&self) -> u64 {
425        self.id
426    }
427
428    /// Register file descriptors for fixed-file operations.
429    ///
430    /// This is only supported on io-uring driver, and will return an
431    /// [`Unsupported`] io error on all other drivers.
432    ///
433    /// [`Unsupported`]: std::io::ErrorKind::Unsupported
434    pub fn register_files(&self, fds: &[RawFd]) -> io::Result<()> {
435        self.driver.borrow_mut().register_files(fds)
436    }
437
438    /// Unregister previously registered file descriptors.
439    ///
440    /// This is only supported on io-uring driver, and will return an
441    /// [`Unsupported`] io error on all other drivers.
442    ///
443    /// [`Unsupported`]: std::io::ErrorKind::Unsupported
444    pub fn unregister_files(&self) -> io::Result<()> {
445        self.driver.borrow_mut().unregister_files()
446    }
447
448    /// Register the personality for the runtime.
449    ///
450    /// This is only supported on io-uring driver, and will return an
451    /// [`Unsupported`] io error on all other drivers.
452    ///
453    /// The returned personality can be used with `FutureExt::with_personality`
454    /// if the `future-combinator` feature is turned on.
455    ///
456    /// [`Unsupported`]: std::io::ErrorKind::Unsupported
457    pub fn register_personality(&self) -> io::Result<u16> {
458        self.driver.borrow_mut().register_personality()
459    }
460
461    /// Unregister the given personality for the runtime.
462    ///
463    /// This is only supported on io-uring driver, and will return an
464    /// [`Unsupported`] io error on all other drivers.
465    ///
466    /// [`Unsupported`]: std::io::ErrorKind::Unsupported
467    pub fn unregister_personality(&self, personality: u16) -> io::Result<()> {
468        self.driver.borrow_mut().unregister_personality(personality)
469    }
470}
471
472impl Drop for Runtime {
473    fn drop(&mut self) {
474        // this is not the last runtime reference, no need to clear
475        if Rc::strong_count(&self.0) > 1 {
476            return;
477        }
478
479        self.enter(|| {
480            self.scheduler.clear();
481        })
482    }
483}
484
485impl AsRawFd for Runtime {
486    fn as_raw_fd(&self) -> RawFd {
487        self.driver.borrow().as_raw_fd()
488    }
489}
490
491#[cfg(feature = "criterion")]
492impl criterion::async_executor::AsyncExecutor for Runtime {
493    fn block_on<T>(&self, future: impl Future<Output = T>) -> T {
494        self.block_on(future)
495    }
496}
497
498#[cfg(feature = "criterion")]
499impl criterion::async_executor::AsyncExecutor for &Runtime {
500    fn block_on<T>(&self, future: impl Future<Output = T>) -> T {
501        (**self).block_on(future)
502    }
503}
504
505/// Builder for [`Runtime`].
506#[derive(Debug, Clone)]
507pub struct RuntimeBuilder {
508    proactor_builder: ProactorBuilder,
509    thread_affinity: HashSet<usize>,
510    event_interval: usize,
511}
512
513impl Default for RuntimeBuilder {
514    fn default() -> Self {
515        Self::new()
516    }
517}
518
519impl RuntimeBuilder {
520    /// Create the builder with default config.
521    pub fn new() -> Self {
522        Self {
523            proactor_builder: ProactorBuilder::new(),
524            event_interval: 61,
525            thread_affinity: HashSet::new(),
526        }
527    }
528
529    /// Replace proactor builder.
530    pub fn with_proactor(&mut self, builder: ProactorBuilder) -> &mut Self {
531        self.proactor_builder = builder;
532        self
533    }
534
535    /// Sets the thread affinity for the runtime.
536    pub fn thread_affinity(&mut self, cpus: HashSet<usize>) -> &mut Self {
537        self.thread_affinity = cpus;
538        self
539    }
540
541    /// Sets the number of scheduler ticks after which the scheduler will poll
542    /// for external events (timers, I/O, and so on).
543    ///
544    /// A scheduler “tick” roughly corresponds to one poll invocation on a task.
545    pub fn event_interval(&mut self, val: usize) -> &mut Self {
546        self.event_interval = val;
547        self
548    }
549
550    /// Build [`Runtime`].
551    pub fn build(&self) -> io::Result<Runtime> {
552        let RuntimeBuilder {
553            proactor_builder,
554            thread_affinity,
555            event_interval,
556        } = self;
557        let id = RUNTIME_ID.get();
558        RUNTIME_ID.set(id + 1);
559        if !thread_affinity.is_empty() {
560            bind_to_cpu_set(thread_affinity);
561        }
562        let inner = RuntimeInner {
563            driver: RefCell::new(proactor_builder.build()?),
564            scheduler: Scheduler::new(*event_interval),
565            #[cfg(feature = "time")]
566            timer_runtime: RefCell::new(TimerRuntime::new()),
567            id,
568        };
569        Ok(Runtime(Rc::new(inner)))
570    }
571}
572
573/// Spawns a new asynchronous task, returning a [`Task`] for it.
574///
575/// Spawning a task enables the task to execute concurrently to other tasks.
576/// There is no guarantee that a spawned task will execute to completion.
577///
578/// ```
579/// # compio_runtime::Runtime::new().unwrap().block_on(async {
580/// let task = compio_runtime::spawn(async {
581///     println!("Hello from a spawned task!");
582///     42
583/// });
584///
585/// assert_eq!(
586///     task.await.unwrap_or_else(|e| std::panic::resume_unwind(e)),
587///     42
588/// );
589/// # })
590/// ```
591///
592/// ## Panics
593///
594/// This method doesn't create runtime. It tries to obtain the current runtime
595/// by [`Runtime::with_current`].
596pub fn spawn<F: Future + 'static>(future: F) -> JoinHandle<F::Output> {
597    Runtime::with_current(|r| r.spawn(future))
598}
599
600/// Spawns a blocking task in a new thread, and wait for it.
601///
602/// The task will not be cancelled even if the future is dropped.
603///
604/// ## Panics
605///
606/// This method doesn't create runtime. It tries to obtain the current runtime
607/// by [`Runtime::with_current`].
608pub fn spawn_blocking<T: Send + 'static>(
609    f: impl (FnOnce() -> T) + Send + 'static,
610) -> JoinHandle<T> {
611    Runtime::with_current(|r| r.spawn_blocking(f))
612}
613
614/// Submit an operation to the current runtime, and return a future for it.
615///
616/// ## Panics
617///
618/// This method doesn't create runtime and will panic if it's not within a
619/// runtime. It tries to obtain the current runtime with
620/// [`Runtime::with_current`].
621pub fn submit<T: OpCode + 'static>(op: T) -> Submit<T> {
622    Runtime::with_current(|r| r.submit(op))
623}
624
625/// Submit a multishot operation to the current runtime, and return a stream for
626/// it.
627///
628/// ## Panics
629///
630/// This method doesn't create runtime and will panic if it's not within a
631/// runtime. It tries to obtain the current runtime with
632/// [`Runtime::with_current`].
633pub fn submit_multi<T: OpCode + 'static>(op: T) -> SubmitMulti<T> {
634    Runtime::with_current(|r| r.submit_multi(op))
635}
636
637/// Register file descriptors for fixed-file operations with the current
638/// runtime's io_uring instance.
639///
640/// This only works on `io_uring` driver. It will return an [`Unsupported`]
641/// error on other drivers.
642///
643/// ## Panics
644///
645/// This method doesn't create runtime. It tries to obtain the current runtime
646/// by [`Runtime::with_current`].
647///
648/// [`Unsupported`]: std::io::ErrorKind::Unsupported
649pub fn register_files(fds: &[RawFd]) -> io::Result<()> {
650    Runtime::with_current(|r| r.register_files(fds))
651}
652
653/// Unregister previously registered file descriptors from the current
654/// runtime's io_uring instance.
655///
656/// This only works on `io_uring` driver. It will return an [`Unsupported`]
657/// error on other drivers.
658///
659/// ## Panics
660///
661/// This method doesn't create runtime. It tries to obtain the current runtime
662/// by [`Runtime::with_current`].
663///
664/// [`Unsupported`]: std::io::ErrorKind::Unsupported
665pub fn unregister_files() -> io::Result<()> {
666    Runtime::with_current(|r| r.unregister_files())
667}
668
/// Waits on the current runtime until `instant`.
///
/// Returns immediately when the timer runtime hands back no key.
/// NOTE(review): presumably `insert` yields `None` when the deadline has
/// already passed; confirm against `TimerRuntime::insert`.
#[cfg(feature = "time")]
pub(crate) async fn create_timer(instant: std::time::Instant) {
    let Some(key) =
        Runtime::with_current(|rt| rt.timer_runtime.borrow_mut().insert(instant))
    else {
        return;
    };
    TimerFuture::new(key).await
}