1use std::{
4 cell::RefCell, collections::HashMap, io, mem::MaybeUninit, os::fd::FromRawFd, ptr::null_mut,
5 thread_local,
6};
7
8use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, SetLen};
9use compio_driver::{OwnedFd, SharedFd, op::Read, syscall};
10
thread_local! {
    /// Per-thread count of outstanding registrations, keyed by signal number.
    ///
    /// The thread's signal mask is only touched when a count moves between
    /// zero and non-zero.
    static REG_MAP: RefCell<HashMap<i32, usize>> = RefCell::default();
}
14
15fn sigset(sig: i32) -> io::Result<libc::sigset_t> {
16 let mut set: MaybeUninit<libc::sigset_t> = MaybeUninit::uninit();
17 syscall!(libc::sigemptyset(set.as_mut_ptr()))?;
18 syscall!(libc::sigaddset(set.as_mut_ptr(), sig))?;
19 Ok(unsafe { set.assume_init() })
21}
22
23fn register_signal(sig: i32) -> io::Result<libc::sigset_t> {
24 REG_MAP.with_borrow_mut(|map| {
25 let count = map.entry(sig).or_default();
26 let set = sigset(sig)?;
27 if *count == 0 {
28 syscall!(libc::pthread_sigmask(libc::SIG_BLOCK, &set, null_mut()))?;
29 }
30 *count += 1;
31 Ok(set)
32 })
33}
34
35fn unregister_signal(sig: i32) -> io::Result<libc::sigset_t> {
36 REG_MAP.with_borrow_mut(|map| {
37 let count = map.entry(sig).or_default();
38 if *count > 0 {
39 *count -= 1;
40 }
41 let set = sigset(sig)?;
42 if *count == 0 {
43 syscall!(libc::pthread_sigmask(libc::SIG_UNBLOCK, &set, null_mut()))?;
44 }
45 Ok(set)
46 })
47}
48
/// A `signalfd` descriptor for a single signal.
///
/// It is paired with the registration (made in `SignalFd::new`) that blocks
/// that signal on the creating thread. Dropping it releases that registration.
#[derive(Debug)]
struct SignalFd {
    /// The signalfd descriptor. It is cloned into a `Read` op while waiting.
    fd: SharedFd<OwnedFd>,
    /// The signal number this fd was created for.
    sig: i32,
}
55
56impl SignalFd {
57 fn new(sig: i32) -> io::Result<Self> {
58 let set = register_signal(sig)?;
59 let mut flag = libc::SFD_CLOEXEC;
60 if compio_runtime::Runtime::with_current(|r| r.driver_type()).is_polling() {
61 flag |= libc::SFD_NONBLOCK;
62 }
63 let fd = syscall!(libc::signalfd(-1, &set, flag))?;
64 let fd = unsafe { OwnedFd::from_raw_fd(fd) };
65 Ok(Self {
66 fd: SharedFd::new(fd),
67 sig,
68 })
69 }
70
71 async fn wait(self) -> io::Result<()> {
72 const INFO_SIZE: usize = std::mem::size_of::<libc::signalfd_siginfo>();
73
74 struct SignalInfo(MaybeUninit<libc::signalfd_siginfo>);
75
76 impl IoBuf for SignalInfo {
77 fn as_init(&self) -> &[u8] {
78 let ptr = self.0.as_ptr() as *const u8;
79 unsafe { std::slice::from_raw_parts(ptr, INFO_SIZE) }
80 }
81 }
82
83 impl IoBufMut for SignalInfo {
84 fn as_uninit(&mut self) -> &mut [MaybeUninit<u8>] {
85 let ptr = self.0.as_mut_ptr() as *mut _;
86 unsafe { std::slice::from_raw_parts_mut(ptr, INFO_SIZE) }
87 }
88 }
89
90 impl SetLen for SignalInfo {
91 unsafe fn set_len(&mut self, len: usize) {
92 debug_assert!(len <= INFO_SIZE)
93 }
94 }
95
96 let info = SignalInfo(MaybeUninit::<libc::signalfd_siginfo>::uninit());
97 let op = Read::new(self.fd.clone(), info);
98 let BufResult(res, op) = compio_runtime::submit(op).await;
99 let len = res?;
100 debug_assert_eq!(len, INFO_SIZE);
101 let info = op.into_inner();
102 let info = unsafe { info.0.assume_init() };
103 debug_assert_eq!(info.ssi_signo, self.sig as u32);
104 Ok(())
105 }
106}
107
108impl Drop for SignalFd {
109 fn drop(&mut self) {
110 unregister_signal(self.sig).ok();
111 }
112}
113
114pub async fn signal(sig: i32) -> io::Result<()> {
119 let fd = SignalFd::new(sig)?;
120 fd.wait().await?;
121 Ok(())
122}