Skip to main content

compio_signal/
linux.rs

1//! Linux-specific types for signal handling.
2
3use std::{
4    cell::RefCell, collections::HashMap, io, mem::MaybeUninit, os::fd::FromRawFd, ptr::null_mut,
5    thread_local,
6};
7
8use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, SetLen};
9use compio_driver::{OwnedFd, SharedFd, op::Read, syscall};
10
thread_local! {
    /// Per-thread reference counts of active listeners, keyed by signal number.
    ///
    /// A signal is blocked in this thread's mask while its count is non-zero.
    static REG_MAP: RefCell<HashMap<i32, usize>> = RefCell::default();
}
14
15fn sigset(sig: i32) -> io::Result<libc::sigset_t> {
16    let mut set: MaybeUninit<libc::sigset_t> = MaybeUninit::uninit();
17    syscall!(libc::sigemptyset(set.as_mut_ptr()))?;
18    syscall!(libc::sigaddset(set.as_mut_ptr(), sig))?;
19    // SAFETY: sigemptyset initializes the set.
20    Ok(unsafe { set.assume_init() })
21}
22
23fn register_signal(sig: i32) -> io::Result<libc::sigset_t> {
24    REG_MAP.with_borrow_mut(|map| {
25        let count = map.entry(sig).or_default();
26        let set = sigset(sig)?;
27        if *count == 0 {
28            syscall!(libc::pthread_sigmask(libc::SIG_BLOCK, &set, null_mut()))?;
29        }
30        *count += 1;
31        Ok(set)
32    })
33}
34
35fn unregister_signal(sig: i32) -> io::Result<libc::sigset_t> {
36    REG_MAP.with_borrow_mut(|map| {
37        let count = map.entry(sig).or_default();
38        if *count > 0 {
39            *count -= 1;
40        }
41        let set = sigset(sig)?;
42        if *count == 0 {
43            syscall!(libc::pthread_sigmask(libc::SIG_UNBLOCK, &set, null_mut()))?;
44        }
45        Ok(set)
46    })
47}
48
49/// Represents a listener to unix signal event.
50#[derive(Debug)]
51struct SignalFd {
52    fd: SharedFd<OwnedFd>,
53    sig: i32,
54}
55
56impl SignalFd {
57    fn new(sig: i32) -> io::Result<Self> {
58        let set = register_signal(sig)?;
59        let mut flag = libc::SFD_CLOEXEC;
60        if compio_runtime::Runtime::with_current(|r| r.driver_type()).is_polling() {
61            flag |= libc::SFD_NONBLOCK;
62        }
63        let fd = syscall!(libc::signalfd(-1, &set, flag))?;
64        let fd = unsafe { OwnedFd::from_raw_fd(fd) };
65        Ok(Self {
66            fd: SharedFd::new(fd),
67            sig,
68        })
69    }
70
71    async fn wait(self) -> io::Result<()> {
72        const INFO_SIZE: usize = std::mem::size_of::<libc::signalfd_siginfo>();
73
74        struct SignalInfo(MaybeUninit<libc::signalfd_siginfo>);
75
76        impl IoBuf for SignalInfo {
77            fn as_init(&self) -> &[u8] {
78                let ptr = self.0.as_ptr() as *const u8;
79                unsafe { std::slice::from_raw_parts(ptr, INFO_SIZE) }
80            }
81        }
82
83        impl IoBufMut for SignalInfo {
84            fn as_uninit(&mut self) -> &mut [MaybeUninit<u8>] {
85                let ptr = self.0.as_mut_ptr() as *mut _;
86                unsafe { std::slice::from_raw_parts_mut(ptr, INFO_SIZE) }
87            }
88        }
89
90        impl SetLen for SignalInfo {
91            unsafe fn set_len(&mut self, len: usize) {
92                debug_assert!(len <= INFO_SIZE)
93            }
94        }
95
96        let info = SignalInfo(MaybeUninit::<libc::signalfd_siginfo>::uninit());
97        let op = Read::new(self.fd.clone(), info);
98        let BufResult(res, op) = compio_runtime::submit(op).await;
99        let len = res?;
100        debug_assert_eq!(len, INFO_SIZE);
101        let info = op.into_inner();
102        let info = unsafe { info.0.assume_init() };
103        debug_assert_eq!(info.ssi_signo, self.sig as u32);
104        Ok(())
105    }
106}
107
108impl Drop for SignalFd {
109    fn drop(&mut self) {
110        unregister_signal(self.sig).ok();
111    }
112}
113
114/// Creates a new listener which will receive notifications when the current
115/// process receives the specified signal.
116///
117/// It sets the signal mask of the current thread.
118pub async fn signal(sig: i32) -> io::Result<()> {
119    let fd = SignalFd::new(sig)?;
120    fd.wait().await?;
121    Ok(())
122}