compio_driver/sys/iour/mod.rs

1#[cfg_attr(all(doc, docsrs), doc(cfg(all())))]
2#[allow(unused_imports)]
3pub use std::os::fd::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
4use std::{
5    collections::HashSet,
6    io,
7    marker::PhantomData,
8    os::fd::FromRawFd,
9    pin::Pin,
10    sync::Arc,
11    task::{Poll, Wake, Waker},
12    time::Duration,
13};
14
15use compio_buf::BufResult;
16use compio_log::{instrument, trace, warn};
17cfg_if::cfg_if! {
18    if #[cfg(feature = "io-uring-cqe32")] {
19        use io_uring::cqueue::Entry32 as CEntry;
20    } else {
21        use io_uring::cqueue::Entry as CEntry;
22    }
23}
24cfg_if::cfg_if! {
25    if #[cfg(feature = "io-uring-sqe128")] {
26        use io_uring::squeue::Entry128 as SEntry;
27    } else {
28        use io_uring::squeue::Entry as SEntry;
29    }
30}
31use flume::{Receiver, Sender};
32use io_uring::{
33    IoUring,
34    cqueue::more,
35    opcode::{AsyncCancel, PollAdd},
36    types::{Fd, SubmitArgs, Timespec},
37};
38use slab::Slab;
39
40use crate::{
41    AsyncifyPool, BufferPool, DriverType, Entry, ProactorBuilder,
42    key::{BorrowedKey, ErasedKey, RefExt},
43    syscall,
44};
45
46mod extra;
47pub(in crate::sys) use extra::Extra;
48pub(crate) mod op;
49pub(crate) use op::take_buffer;
50
51pub(crate) fn is_op_supported(code: u8) -> bool {
52    #[cfg(feature = "once_cell_try")]
53    use std::sync::OnceLock;
54
55    #[cfg(not(feature = "once_cell_try"))]
56    use once_cell::sync::OnceCell as OnceLock;
57
58    static PROBE: OnceLock<io_uring::Probe> = OnceLock::new();
59
60    PROBE
61        .get_or_try_init(|| {
62            let mut probe = io_uring::Probe::new();
63
64            io_uring::IoUring::new(2)?
65                .submitter()
66                .register_probe(&mut probe)?;
67
68            std::io::Result::Ok(probe)
69        })
70        .map(|probe| probe.is_supported(code))
71        .unwrap_or_default()
72}
73
/// The created entry of [`OpCode`].
pub enum OpEntry {
    /// This operation creates an io-uring submission entry.
    Submission(io_uring::squeue::Entry),
    #[cfg(feature = "io-uring-sqe128")]
    /// This operation creates a 128-bit io-uring submission entry.
    Submission128(io_uring::squeue::Entry128),
    /// This operation is a blocking one.
    Blocking,
}
84
85impl OpEntry {
86    fn personality(self, personality: Option<u16>) -> Self {
87        let Some(personality) = personality else {
88            return self;
89        };
90
91        match self {
92            Self::Submission(entry) => Self::Submission(entry.personality(personality)),
93            #[cfg(feature = "io-uring-sqe128")]
94            Self::Submission128(entry) => Self::Submission128(entry.personality(personality)),
95            Self::Blocking => Self::Blocking,
96        }
97    }
98}
99
impl From<io_uring::squeue::Entry> for OpEntry {
    /// Wrap a plain 64-byte submission entry.
    fn from(value: io_uring::squeue::Entry) -> Self {
        Self::Submission(value)
    }
}
105
#[cfg(feature = "io-uring-sqe128")]
impl From<io_uring::squeue::Entry128> for OpEntry {
    /// Wrap a 128-byte submission entry (SQE128 rings only).
    fn from(value: io_uring::squeue::Entry128) -> Self {
        Self::Submission128(value)
    }
}
112
/// Abstraction of io-uring operations.
///
/// # Safety
///
/// The returned Entry from `create_entry` must be valid until the operation is
/// completed.
pub unsafe trait OpCode {
    /// Create submission entry.
    fn create_entry(self: Pin<&mut Self>) -> OpEntry;

    /// Create submission entry for fallback. This method will only be called if
    /// `create_entry` returns an entry with unsupported opcode.
    ///
    /// The default implementation falls back to a blocking call.
    fn create_entry_fallback(self: Pin<&mut Self>) -> OpEntry {
        OpEntry::Blocking
    }

    /// Call the operation in a blocking way. This method will be called if
    /// * [`create_entry`] returns [`OpEntry::Blocking`].
    /// * [`create_entry`] returns an entry with unsupported opcode, and
    ///   [`create_entry_fallback`] returns [`OpEntry::Blocking`].
    /// * [`create_entry`] and [`create_entry_fallback`] both return an entry
    ///   with unsupported opcode.
    ///
    /// The default implementation panics: purely asynchronous operations are
    /// never driven through this path.
    ///
    /// [`create_entry`]: OpCode::create_entry
    /// [`create_entry_fallback`]: OpCode::create_entry_fallback
    fn call_blocking(self: Pin<&mut Self>) -> io::Result<usize> {
        unreachable!("this operation is asynchronous")
    }

    /// Set the result when it completes.
    /// The operation stores the result and is responsible to release it if the
    /// operation is cancelled.
    ///
    /// The default implementation is a no-op for operations that do not need
    /// to retain their result.
    ///
    /// # Safety
    ///
    /// The params must be the result coming from this operation.
    unsafe fn set_result(self: Pin<&mut Self>, _: &io::Result<usize>, _: &crate::Extra) {}

    /// Push a multishot result to the inner queue.
    ///
    /// The default implementation panics: only multishot operations receive
    /// this call.
    ///
    /// # Safety
    ///
    /// The params must be the result coming from this operation.
    unsafe fn push_multishot(self: Pin<&mut Self>, _: io::Result<usize>, _: crate::Extra) {
        unreachable!("this operation is not multishot")
    }

    /// Pop a multishot result from the inner queue.
    ///
    /// The default implementation panics: only multishot operations receive
    /// this call.
    fn pop_multishot(self: Pin<&mut Self>) -> Option<BufResult<usize, crate::sys::Extra>> {
        unreachable!("this operation is not multishot")
    }
}
165
166pub use OpCode as IourOpCode;
167
/// Low-level driver of io-uring.
pub(crate) struct Driver {
    /// The io-uring instance, parameterized by the configured SQE/CQE sizes.
    inner: IoUring<SEntry, CEntry>,
    /// Eventfd-based wakeup handle, armed via a multishot `PollAdd`.
    notifier: Notifier,
    /// Thread pool that runs blocking fallback operations.
    pool: AsyncifyPool,
    /// Producer side for results of blocking operations.
    completed_tx: Sender<Entry>,
    /// Consumer side drained in `poll_blocking`.
    completed_rx: Receiver<Entry>,
    /// Allocator for buffer-ring group ids; the slab index is the group id.
    buffer_group_ids: Slab<()>,
    /// Whether the notifier `PollAdd` must be (re-)submitted on next poll.
    need_push_notifier: bool,
    /// Keys leaked via `into_raw()` into io_uring user_data, freed on drop.
    in_flight: HashSet<usize>,
    _p: PhantomData<ErasedKey>,
}
181
impl Driver {
    /// `user_data` sentinel for `AsyncCancel` SQEs; their CQEs carry no key.
    const CANCEL: u64 = u64::MAX;
    /// `user_data` sentinel for the eventfd notifier `PollAdd` SQE.
    const NOTIFY: u64 = u64::MAX - 1;

    /// Create the driver: verify the opcodes requested in `op_flags` are
    /// supported, build the ring from the builder's options (sqpoll,
    /// coop_taskrun, taskrun_flag, capacity, eventfd), and set up the wakeup
    /// notifier plus the channel for blocking-operation results.
    pub fn new(builder: &ProactorBuilder) -> io::Result<Self> {
        instrument!(compio_log::Level::TRACE, "new", ?builder);
        trace!("new iour driver");
        // if op_flags is empty, this loop will not run
        for code in builder.op_flags.get_codes() {
            if !is_op_supported(code) {
                return Err(io::Error::new(
                    io::ErrorKind::Unsupported,
                    format!("io-uring does not support opcode {code:?}({code})"),
                ));
            }
        }
        let notifier = Notifier::new()?;
        let mut io_uring_builder = IoUring::builder();
        if let Some(sqpoll_idle) = builder.sqpoll_idle {
            io_uring_builder.setup_sqpoll(sqpoll_idle.as_millis() as _);
        }
        if builder.coop_taskrun {
            io_uring_builder.setup_coop_taskrun();
        }
        if builder.taskrun_flag {
            io_uring_builder.setup_taskrun_flag();
        }

        let inner = io_uring_builder.build(builder.capacity)?;

        let submitter = inner.submitter();

        if let Some(fd) = builder.eventfd {
            submitter.register_eventfd(fd)?;
        }

        let (completed_tx, completed_rx) = flume::unbounded();

        Ok(Self {
            inner,
            notifier,
            completed_tx,
            completed_rx,
            pool: builder.create_or_get_thread_pool(),
            buffer_group_ids: Slab::new(),
            // Arm the notifier on the first `poll`.
            need_push_notifier: true,
            in_flight: HashSet::new(),
            _p: PhantomData,
        })
    }

    /// This backend always reports itself as io-uring.
    pub fn driver_type(&self) -> DriverType {
        DriverType::IoUring
    }

    /// Downcast helper: always `Some` on this backend.
    #[allow(dead_code)]
    pub fn as_iour(&self) -> Option<&Self> {
        Some(self)
    }

    /// Register fixed files with the ring.
    pub fn register_files(&self, fds: &[RawFd]) -> io::Result<()> {
        self.inner.submitter().register_files(fds)?;
        Ok(())
    }

    /// Unregister all fixed files from the ring.
    pub fn unregister_files(&self) -> io::Result<()> {
        self.inner.submitter().unregister_files()?;
        Ok(())
    }

    /// Register a personality with the ring, returning its id.
    pub fn register_personality(&self) -> io::Result<u16> {
        self.inner.submitter().register_personality()
    }

    /// Unregister a previously registered personality.
    pub fn unregister_personality(&self, personality: u16) -> io::Result<()> {
        self.inner.submitter().unregister_personality(personality)
    }

    // Auto means that it choose to wait or not automatically.
    fn submit_auto(&mut self, timeout: Option<Duration>) -> io::Result<()> {
        instrument!(compio_log::Level::TRACE, "submit_auto", ?timeout);

        // when taskrun is true, there are completed cqes wait to handle, no need to
        // block the submit
        let want_sqe = if self.inner.submission().taskrun() {
            0
        } else {
            1
        };

        let res = {
            // Last part of submission queue, wait till timeout.
            if let Some(duration) = timeout {
                let timespec = timespec(duration);
                let args = SubmitArgs::new().timespec(&timespec);
                self.inner.submitter().submit_with_args(want_sqe, &args)
            } else {
                self.inner.submit_and_wait(want_sqe)
            }
        };
        trace!("submit result: {res:?}");
        match res {
            Ok(_) => {
                // Submit succeeded but nothing completed: treat the wait as a
                // timeout so the caller can decide whether to retry.
                if self.inner.completion().is_empty() {
                    Err(io::ErrorKind::TimedOut.into())
                } else {
                    Ok(())
                }
            }
            Err(e) => match e.raw_os_error() {
                // ETIME: the kernel-side timespec expired.
                Some(libc::ETIME) => Err(io::ErrorKind::TimedOut.into()),
                // EBUSY/EAGAIN: the ring could not accept the submit right
                // now; surfaced as Interrupted so callers retry.
                Some(libc::EBUSY) | Some(libc::EAGAIN) => Err(io::ErrorKind::Interrupted.into()),
                _ => Err(e),
            },
        }
    }

    /// Deliver results of finished blocking operations to their waiters.
    fn poll_blocking(&mut self) {
        while let Ok(entry) = self.completed_rx.try_recv() {
            entry.notify();
        }
    }

    /// Drain pending completions (blocking-channel results and CQEs).
    /// Returns whether any CQE was present before draining.
    fn poll_entries(&mut self) -> bool {
        self.poll_blocking();

        let mut cqueue = self.inner.completion();
        cqueue.sync();
        let has_entry = !cqueue.is_empty();
        for entry in cqueue {
            match entry.user_data() {
                // Completion of an `AsyncCancel`: nothing to do.
                Self::CANCEL => {}
                Self::NOTIFY => {
                    let flags = entry.flags();
                    // Without the MORE flag the multishot poll is finished and
                    // must be re-armed on the next `poll`.
                    if !more(flags) {
                        self.need_push_notifier = true;
                    }
                    self.notifier.clear().expect("cannot clear notifier");
                }
                key => {
                    let flags = entry.flags();
                    if more(flags) {
                        // Multishot CQE: the key stays leaked in the ring; we
                        // only borrow it to push this intermediate result.
                        let key = unsafe { BorrowedKey::from_raw(key as _) };
                        let mut key = key.borrow();
                        #[allow(clippy::useless_conversion)]
                        let mut extra: crate::sys::Extra = Extra::new().into();
                        extra.set_flags(entry.flags());
                        unsafe {
                            key.pinned_op()
                                .push_multishot(create_result(entry.result()), extra);
                        }
                        key.wake_by_ref();
                    } else {
                        // Final CQE: reclaim key ownership and wake the waiter.
                        self.in_flight.remove(&(key as usize));
                        create_entry(entry).notify()
                    }
                }
            }
        }
        has_entry
    }

    /// The default per-operation extra data for this backend.
    pub(in crate::sys) fn default_extra(&self) -> Extra {
        Extra::new()
    }

    /// io-uring needs no per-fd attachment; always succeeds.
    pub fn attach(&mut self, _fd: RawFd) -> io::Result<()> {
        Ok(())
    }

    /// Request asynchronous cancellation of the operation behind `key`.
    /// Best-effort: failure to queue the cancel entry is only logged.
    pub fn cancel(&mut self, key: ErasedKey) {
        instrument!(compio_log::Level::TRACE, "cancel", ?key);
        trace!("cancel RawOp");
        unsafe {
            #[allow(clippy::useless_conversion)]
            if self
                .inner
                .submission()
                .push(
                    &AsyncCancel::new(key.as_raw() as _)
                        .build()
                        .user_data(Self::CANCEL)
                        .into(),
                )
                .is_err()
            {
                warn!("could not push AsyncCancel entry");
            }
        }
    }

    /// Tag `entry` with the key as user_data, submit it, then leak the key
    /// into the ring. Ownership is recovered from the final CQE in
    /// `poll_entries`, or in `drop` if the completion is never observed.
    fn push_raw_with_key(&mut self, entry: SEntry, key: ErasedKey) -> io::Result<()> {
        let user_data = key.as_raw();
        let entry = entry.user_data(user_data as _);
        self.push_raw(entry)?; // if push failed, do not leak the key. Drop it upon return.
        self.in_flight.insert(user_data);
        key.into_raw();
        Ok(())
    }

    /// Push an entry into the submission queue, flushing completions and
    /// submitting to make room when the queue is full.
    fn push_raw(&mut self, entry: SEntry) -> io::Result<()> {
        loop {
            let mut squeue = self.inner.submission();
            match unsafe { squeue.push(&entry) } {
                Ok(()) => {
                    squeue.sync();
                    break Ok(());
                }
                Err(_) => {
                    // Queue full: release the squeue borrow, drain and submit,
                    // then retry. TimedOut/Interrupted just mean "try again".
                    drop(squeue);
                    self.poll_entries();
                    match self.submit_auto(Some(Duration::ZERO)) {
                        Ok(()) => {}
                        Err(e)
                            if matches!(
                                e.kind(),
                                io::ErrorKind::TimedOut | io::ErrorKind::Interrupted
                            ) => {}
                        Err(e) => return Err(e),
                    }
                }
            }
        }
    }

    /// Submit the operation behind `key`. If its opcode is unsupported, the
    /// fallback entry is tried once; if that is also unsupported (or either
    /// returns `Blocking`) the operation runs on the blocking pool.
    /// Always returns `Pending`: the result arrives through a completion.
    pub fn push(&mut self, key: ErasedKey) -> Poll<io::Result<usize>> {
        instrument!(compio_log::Level::TRACE, "push", ?key);
        let personality = key.borrow().extra().as_iour().get_personality();
        let mut op_entry = key
            .borrow()
            .pinned_op()
            .create_entry()
            .personality(personality);
        // Ensures the fallback entry is requested at most once.
        let mut has_fallbacked = false;
        trace!(?personality, "push Key");
        loop {
            match op_entry {
                OpEntry::Submission(entry) => {
                    if is_op_supported(entry.get_opcode() as _) {
                        #[allow(clippy::useless_conversion)]
                        self.push_raw_with_key(entry.into(), key)?;
                    } else if !has_fallbacked {
                        op_entry = key
                            .borrow()
                            .pinned_op()
                            .create_entry_fallback()
                            .personality(personality);
                        has_fallbacked = true;
                        continue;
                    } else {
                        self.push_blocking(key);
                    }
                }
                #[cfg(feature = "io-uring-sqe128")]
                OpEntry::Submission128(entry) => {
                    if is_op_supported(entry.get_opcode() as _) {
                        self.push_raw_with_key(entry, key)?;
                    } else if !has_fallbacked {
                        op_entry = key
                            .borrow()
                            .pinned_op()
                            .create_entry_fallback()
                            .personality(personality);
                        has_fallbacked = true;
                        continue;
                    } else {
                        self.push_blocking(key);
                    }
                }
                OpEntry::Blocking => self.push_blocking(key),
            }
            break;
        }
        Poll::Pending
    }

    /// Dispatch the operation to the thread pool; the result comes back over
    /// the `completed` channel and the driver is woken via the notifier.
    fn push_blocking(&mut self, key: ErasedKey) {
        let waker = self.waker();
        let completed = self.completed_tx.clone();
        // SAFETY: we're submitting into the driver, so it's safe to freeze here.
        let mut key = unsafe { key.freeze() };
        let mut closure = move || {
            let res = key.pinned_op().call_blocking();
            let _ = completed.send(Entry::new(key.into_inner(), res));
            waker.wake();
        };
        // Retry dispatch until the pool accepts the closure.
        while let Err(e) = self.pool.dispatch(closure) {
            closure = e.0;
            // do something to avoid busy loop
            self.poll_blocking();
            std::thread::yield_now();
        }
        self.poll_blocking();
    }

    /// Drive the ring once: (re-)arm the notifier if needed, then process
    /// completions, submitting and waiting up to `timeout` if none are ready.
    pub fn poll(&mut self, timeout: Option<Duration>) -> io::Result<()> {
        instrument!(compio_log::Level::TRACE, "poll", ?timeout);
        // Anyway we need to submit once, no matter if there are entries in squeue.
        trace!("start polling");

        if self.need_push_notifier {
            #[allow(clippy::useless_conversion)]
            self.push_raw(
                PollAdd::new(Fd(self.notifier.as_raw_fd()), libc::POLLIN as _)
                    .multi(true)
                    .build()
                    .user_data(Self::NOTIFY)
                    .into(),
            )?;
            self.need_push_notifier = false;
        }

        if !self.poll_entries() {
            self.submit_auto(timeout)?;
            self.poll_entries();
        }

        Ok(())
    }

    /// A waker that notifies this driver's eventfd.
    pub fn waker(&self) -> Waker {
        self.notifier.waker()
    }

    /// Create a registered buffer ring with `buffer_len` buffers of
    /// `buffer_size` bytes each; the slab index becomes the buffer group id.
    pub fn create_buffer_pool(
        &mut self,
        buffer_len: u16,
        buffer_size: usize,
    ) -> io::Result<BufferPool> {
        let buffer_group = self.buffer_group_ids.insert(());
        // Buffer group ids are u16 on the kernel side.
        if buffer_group > u16::MAX as usize {
            self.buffer_group_ids.remove(buffer_group);

            return Err(io::Error::new(
                io::ErrorKind::OutOfMemory,
                "too many buffer pool allocated",
            ));
        }

        let buf_ring = io_uring_buf_ring::IoUringBufRing::new_with_flags(
            &self.inner,
            buffer_len,
            buffer_group as _,
            buffer_size,
            0,
        )?;

        #[cfg(fusion)]
        {
            Ok(BufferPool::new_io_uring(crate::IoUringBufferPool::new(
                buf_ring,
            )))
        }
        #[cfg(not(fusion))]
        {
            Ok(BufferPool::new(buf_ring))
        }
    }

    /// # Safety
    ///
    /// caller must make sure release the buffer pool with correct driver
    pub unsafe fn release_buffer_pool(&mut self, buffer_pool: BufferPool) -> io::Result<()> {
        #[cfg(fusion)]
        let buffer_pool = buffer_pool.into_io_uring();

        let buffer_group = buffer_pool.buffer_group();
        // FIXME: should we drop it directly if `into_inner` fails?
        unsafe {
            buffer_pool
                .into_inner()
                .expect("operations not completed")
                .release(&self.inner)?
        };
        // Return the group id to the slab for reuse.
        self.buffer_group_ids.remove(buffer_group as _);

        Ok(())
    }

    /// Pop one buffered result from a multishot operation, if any.
    pub fn pop_multishot(
        &mut self,
        key: &ErasedKey,
    ) -> Option<BufResult<usize, crate::sys::Extra>> {
        key.borrow().pinned_op().pop_multishot()
    }
}
568
impl AsRawFd for Driver {
    /// The raw fd of the io-uring instance itself.
    fn as_raw_fd(&self) -> RawFd {
        self.inner.as_raw_fd()
    }
}
574
impl Drop for Driver {
    fn drop(&mut self) {
        // Drain completed CQEs first to avoid double-free.
        let mut cqueue = self.inner.completion();
        cqueue.sync();
        for entry in cqueue {
            match entry.user_data() {
                // Sentinel entries never carry a key.
                Self::CANCEL | Self::NOTIFY => {}
                key => {
                    self.in_flight.remove(&(key as usize));
                    // SAFETY: this user_data was produced by `ErasedKey::into_raw`
                    // in `push_raw_with_key` and has not been reclaimed yet.
                    drop(unsafe { ErasedKey::from_raw(key as _) });
                }
            }
        }

        // Free remaining in-flight keys.
        for user_data in self.in_flight.drain() {
            // SAFETY: every entry in `in_flight` is a key leaked in
            // `push_raw_with_key` whose final completion was never observed.
            drop(unsafe { ErasedKey::from_raw(user_data) });
        }
    }
}
596
597fn create_entry(cq_entry: CEntry) -> Entry {
598    let result = cq_entry.result();
599    let result = create_result(result);
600    let key = unsafe { ErasedKey::from_raw(cq_entry.user_data() as _) };
601    let mut entry = Entry::new(key, result);
602    entry.set_flags(cq_entry.flags());
603
604    entry
605}
606
607fn create_result(result: i32) -> io::Result<usize> {
608    if result < 0 {
609        let result = if result == -libc::ECANCELED {
610            libc::ETIMEDOUT
611        } else {
612            -result
613        };
614        Err(io::Error::from_raw_os_error(result))
615    } else {
616        Ok(result as _)
617    }
618}
619
620fn timespec(duration: std::time::Duration) -> Timespec {
621    Timespec::new()
622        .sec(duration.as_secs())
623        .nsec(duration.subsec_nanos())
624}
625
/// Eventfd-backed wakeup handle owned by the driver.
#[derive(Debug)]
struct Notifier {
    // Shared so that `waker()` can hand out cheap clones as `Waker`s.
    notify: Arc<Notify>,
}
630
631impl Notifier {
632    /// Create a new notifier.
633    fn new() -> io::Result<Self> {
634        let fd = syscall!(libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK))?;
635        let fd = unsafe { OwnedFd::from_raw_fd(fd) };
636        Ok(Self {
637            notify: Arc::new(Notify::new(fd)),
638        })
639    }
640
641    pub fn clear(&self) -> io::Result<()> {
642        loop {
643            let mut buffer = [0u64];
644            let res = syscall!(libc::read(
645                self.as_raw_fd(),
646                buffer.as_mut_ptr().cast(),
647                std::mem::size_of::<u64>()
648            ));
649            match res {
650                Ok(len) => {
651                    debug_assert_eq!(len, std::mem::size_of::<u64>() as _);
652                    break Ok(());
653                }
654                // Clear the next time:)
655                Err(e) if e.kind() == io::ErrorKind::WouldBlock => break Ok(()),
656                // Just like read_exact
657                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
658                Err(e) => break Err(e),
659            }
660        }
661    }
662
663    pub fn waker(&self) -> Waker {
664        Waker::from(self.notify.clone())
665    }
666}
667
impl AsRawFd for Notifier {
    /// The raw fd of the underlying eventfd.
    fn as_raw_fd(&self) -> RawFd {
        self.notify.fd.as_raw_fd()
    }
}
673
/// A notify handle to the inner driver.
#[derive(Debug)]
pub(crate) struct Notify {
    // The owned eventfd written to by `notify`.
    fd: OwnedFd,
}
679
680impl Notify {
681    pub(crate) fn new(fd: OwnedFd) -> Self {
682        Self { fd }
683    }
684
685    /// Notify the inner driver.
686    pub fn notify(&self) -> io::Result<()> {
687        let data = 1u64;
688        syscall!(libc::write(
689            self.fd.as_raw_fd(),
690            &data as *const _ as *const _,
691            std::mem::size_of::<u64>(),
692        ))?;
693        Ok(())
694    }
695}
696
697impl Wake for Notify {
698    fn wake(self: Arc<Self>) {
699        self.wake_by_ref();
700    }
701
702    fn wake_by_ref(self: &Arc<Self>) {
703        self.notify().ok();
704    }
705}