Skip to main content

compio_io/ancillary/
sys.rs

1use std::{mem::MaybeUninit, slice};
2
3#[cfg(unix)]
4use libc::{CMSG_DATA, CMSG_FIRSTHDR, CMSG_LEN, CMSG_NXTHDR, msghdr};
5#[cfg(unix)]
6pub use libc::{CMSG_SPACE, cmsghdr};
7#[cfg(windows)]
8pub use windows_sys::Win32::Networking::WinSock::CMSGHDR as cmsghdr;
9#[cfg(windows)]
10use windows_sys::Win32::Networking::WinSock::{self, IN_PKTINFO, IN6_PKTINFO};
11
12use super::{AncillaryData, CodecError, copy_from_bytes, copy_to_bytes};
13
#[cfg(windows)]
#[allow(non_snake_case)]
mod windows_macros {
    //! Winsock counterparts of the `CMSG_*` helpers that the Unix build takes
    //! from `libc`, operating on `CMSGHDR` / `WSAMSG`.

    use std::ptr::null_mut;

    use windows_sys::Win32::Networking::WinSock::{CMSGHDR, WSABUF, WSAMSG};

    /// Rounds `length` up to the next multiple of `CMSGHDR`'s alignment.
    const fn CMSG_ALIGN(length: usize) -> usize {
        (length + align_of::<CMSGHDR>() - 1) & !(align_of::<CMSGHDR>() - 1)
    }

    /// Distance in bytes from the start of a control message to its payload.
    const WSA_CMSGDATA_OFFSET: usize = CMSG_ALIGN(size_of::<CMSGHDR>());

    /// Returns a pointer to the payload that follows the header at `cmsg`.
    ///
    /// # Safety
    ///
    /// `cmsg` must point into an allocation that extends at least one
    /// `CMSGHDR` past `cmsg`, as required by pointer `offset`.
    pub unsafe fn CMSG_DATA(cmsg: *const CMSGHDR) -> *mut u8 {
        unsafe { cmsg.offset(1) as *mut u8 }
    }

    /// Number of buffer bytes one control message with a `length`-byte payload
    /// occupies, including trailing alignment padding.
    ///
    /// # Safety
    ///
    /// The body performs no unsafe operation; the function is declared
    /// `unsafe` presumably to match the signature of the `libc` equivalent.
    pub const unsafe fn CMSG_SPACE(length: usize) -> usize {
        WSA_CMSGDATA_OFFSET + CMSG_ALIGN(length)
    }

    /// Value to store in `cmsg_len` for a `length`-byte payload: header plus
    /// payload, without trailing padding.
    ///
    /// # Safety
    ///
    /// The body performs no unsafe operation; the function is declared
    /// `unsafe` presumably to match the signature of the `libc` equivalent.
    pub const unsafe fn CMSG_LEN(length: usize) -> usize {
        WSA_CMSGDATA_OFFSET + length
    }

    /// Returns the first control message header of `msg`, or null when the
    /// control buffer is too small to hold a single `CMSGHDR`.
    ///
    /// # Safety
    ///
    /// `msg` must be valid to dereference.
    pub unsafe fn CMSG_FIRSTHDR(msg: *const WSAMSG) -> *mut CMSGHDR {
        unsafe {
            if (*msg).Control.len as usize >= size_of::<CMSGHDR>() {
                (*msg).Control.buf as _
            } else {
                null_mut()
            }
        }
    }

    /// Returns the header following `cmsg` inside `msg`'s control buffer, the
    /// first header when `cmsg` is null, or null when no further header fits.
    ///
    /// A header whose `cmsg_len` is smaller than a `CMSGHDR` or larger than
    /// the whole control buffer is treated as the end of the list.
    ///
    /// # Safety
    ///
    /// `msg` must be valid to dereference, and `cmsg` must be null or valid to
    /// dereference.
    pub unsafe fn CMSG_NXTHDR(msg: *const WSAMSG, cmsg: *const CMSGHDR) -> *mut CMSGHDR {
        if cmsg.is_null() {
            return unsafe { CMSG_FIRSTHDR(msg) };
        }

        let (cmsg_len, control_start, control_len) = unsafe {
            (
                (*cmsg).cmsg_len,
                (*msg).Control.buf as usize,
                (*msg).Control.len as usize,
            )
        };

        // A length below the header size (e.g. 0 in a zeroed buffer) would make
        // the "next" header coincide with, or overlap, the current one and the
        // caller would never reach the end. A length above the buffer size can
        // never describe a message inside it.
        if cmsg_len < size_of::<CMSGHDR>() || cmsg_len > control_len {
            return null_mut();
        }

        let advance = CMSG_ALIGN(cmsg_len);
        let control_end = control_start + control_len;
        // The next header must fit entirely inside the control buffer.
        let next_fits = (cmsg as usize)
            .checked_add(advance)
            .and_then(|next| next.checked_add(size_of::<CMSGHDR>()))
            .is_some_and(|next_end| next_end <= control_end);

        if next_fits {
            // Derive the result from `cmsg` instead of casting an integer back
            // to a pointer; the resulting address is the same.
            cmsg.cast_mut()
                .cast::<u8>()
                .wrapping_add(advance)
                .cast::<CMSGHDR>()
        } else {
            null_mut()
        }
    }

    /// Builds a `WSAMSG` whose control buffer is `ptr` / `len`; every other
    /// field is zeroed. `len` is converted with `as` to the type of
    /// `WSABUF::len`.
    pub fn msghdr_from_raw(ptr: *const u8, len: usize) -> WSAMSG {
        WSAMSG {
            Control: WSABUF {
                len: len as _,
                buf: ptr as _,
            },
            ..unsafe { std::mem::zeroed() }
        }
    }
}
76
77#[cfg(windows)]
78pub use windows_macros::CMSG_SPACE;
79#[cfg(windows)]
80use windows_macros::{CMSG_DATA, CMSG_FIRSTHDR, CMSG_LEN, CMSG_NXTHDR, msghdr_from_raw};
81
82#[cfg(unix)]
83fn msghdr_from_raw(ptr: *const u8, len: usize) -> msghdr {
84    let mut msg: msghdr = unsafe { std::mem::zeroed() };
85    msg.msg_control = ptr as _;
86    msg.msg_controllen = len as _;
87    msg
88}
89
90pub(crate) struct CMsgRef<'a>(&'a cmsghdr);
91
92impl CMsgRef<'_> {
93    pub(crate) fn level(&self) -> i32 {
94        self.0.cmsg_level as _
95    }
96
97    pub(crate) fn ty(&self) -> i32 {
98        self.0.cmsg_type as _
99    }
100
101    pub(crate) fn len(&self) -> usize {
102        self.0.cmsg_len as _
103    }
104
105    pub(crate) fn decode_data<T: AncillaryData>(&self) -> Result<T, CodecError> {
106        let data_ptr = unsafe { CMSG_DATA(self.0) } as *const u8;
107        let buffer = unsafe { slice::from_raw_parts(data_ptr, self.len()) };
108        T::decode(buffer)
109    }
110}
111
112pub(crate) struct CMsgMut<'a>(&'a mut cmsghdr);
113
114impl CMsgMut<'_> {
115    pub(crate) fn set_level(&mut self, level: i32) {
116        self.0.cmsg_level = level as _;
117    }
118
119    pub(crate) fn set_ty(&mut self, ty: i32) {
120        self.0.cmsg_type = ty as _;
121    }
122
123    pub(crate) fn encode_data<T: AncillaryData>(&mut self, value: &T) -> Result<usize, CodecError> {
124        self.0.cmsg_len = unsafe { CMSG_LEN(T::SIZE as _) } as _;
125        let data_ptr = unsafe { CMSG_DATA(self.0) } as *mut MaybeUninit<u8>;
126        let buffer = unsafe { slice::from_raw_parts_mut(data_ptr, T::SIZE) };
127        value.encode(buffer)?;
128        Ok(unsafe { CMSG_SPACE(T::SIZE as _) } as _)
129    }
130}
131
132pub(crate) struct CMsgIter {
133    len: usize,
134    offset: Option<usize>,
135}
136
137impl CMsgIter {
138    pub(crate) fn new(ptr: *const u8, len: usize) -> Self {
139        assert!(len >= unsafe { CMSG_SPACE(0) as _ }, "buffer too short");
140        assert!(ptr.cast::<cmsghdr>().is_aligned(), "misaligned buffer");
141
142        let msg = msghdr_from_raw(ptr.cast_mut(), len);
143        let first_cmsg = unsafe { CMSG_FIRSTHDR(&msg) };
144
145        let offset = if first_cmsg.is_null() {
146            None
147        } else {
148            Some(first_cmsg.addr() - ptr.addr())
149        };
150        Self { len, offset }
151    }
152
153    pub(crate) unsafe fn current<'a>(&self, ptr: *const u8) -> Option<CMsgRef<'a>> {
154        self.offset
155            .and_then(|offset| unsafe { ptr.add(offset).cast::<cmsghdr>().as_ref() })
156            .map(CMsgRef)
157    }
158
159    pub(crate) unsafe fn next(&mut self, ptr: *const u8) {
160        if let Some(offset) = self.offset {
161            let msg = msghdr_from_raw(ptr, self.len);
162            let next_cmsg = unsafe { CMSG_NXTHDR(&msg, ptr.add(offset).cast()) };
163            if next_cmsg.is_null() {
164                self.offset = None;
165            } else {
166                self.offset = Some(next_cmsg.addr() - ptr.addr());
167            }
168        }
169    }
170
171    pub(crate) unsafe fn current_mut<'a>(&self, ptr: *mut u8) -> Option<CMsgMut<'a>> {
172        self.offset
173            .and_then(|offset| unsafe { ptr.add(offset).cast::<cmsghdr>().as_mut() })
174            .map(CMsgMut)
175    }
176
177    pub(crate) fn is_space_enough(&self, space: usize) -> bool {
178        if let Some(offset) = self.offset {
179            #[allow(clippy::unnecessary_cast)]
180            let space = unsafe { CMSG_SPACE(space as _) } as usize;
181            offset + space <= self.len
182        } else {
183            false
184        }
185    }
186}
187
188#[cfg(unix)]
189impl AncillaryData for libc::in_addr {
190    fn encode(&self, buffer: &mut [MaybeUninit<u8>]) -> Result<(), CodecError> {
191        unsafe { copy_to_bytes(self, buffer) }
192    }
193
194    fn decode(buffer: &[u8]) -> Result<Self, CodecError> {
195        unsafe { copy_from_bytes(buffer) }
196    }
197}
198
199#[cfg(any(target_os = "linux", target_os = "android"))]
200impl AncillaryData for libc::in_pktinfo {
201    fn encode(&self, buffer: &mut [MaybeUninit<u8>]) -> Result<(), CodecError> {
202        let mut pktinfo: libc::in_pktinfo = unsafe { std::mem::zeroed() };
203        pktinfo.ipi_ifindex = self.ipi_ifindex;
204        pktinfo.ipi_spec_dst.s_addr = self.ipi_spec_dst.s_addr;
205        pktinfo.ipi_addr.s_addr = self.ipi_addr.s_addr;
206        unsafe { copy_to_bytes(&pktinfo, buffer) }
207    }
208
209    fn decode(buffer: &[u8]) -> Result<Self, CodecError> {
210        let pktinfo: libc::in_pktinfo = unsafe { copy_from_bytes(buffer) }?;
211        Ok(libc::in_pktinfo {
212            ipi_ifindex: pktinfo.ipi_ifindex,
213            ipi_spec_dst: libc::in_addr {
214                s_addr: pktinfo.ipi_spec_dst.s_addr,
215            },
216            ipi_addr: libc::in_addr {
217                s_addr: pktinfo.ipi_addr.s_addr,
218            },
219        })
220    }
221}
222
223#[cfg(unix)]
224impl AncillaryData for libc::in6_pktinfo {
225    fn encode(&self, buffer: &mut [MaybeUninit<u8>]) -> Result<(), CodecError> {
226        let mut pktinfo: libc::in6_pktinfo = unsafe { std::mem::zeroed() };
227        pktinfo.ipi6_ifindex = self.ipi6_ifindex;
228        pktinfo.ipi6_addr.s6_addr = self.ipi6_addr.s6_addr;
229        unsafe { copy_to_bytes(&pktinfo, buffer) }
230    }
231
232    fn decode(buffer: &[u8]) -> Result<Self, CodecError> {
233        let pktinfo: libc::in6_pktinfo = unsafe { copy_from_bytes(buffer) }?;
234        Ok(libc::in6_pktinfo {
235            ipi6_ifindex: pktinfo.ipi6_ifindex,
236            ipi6_addr: libc::in6_addr {
237                s6_addr: pktinfo.ipi6_addr.s6_addr,
238            },
239        })
240    }
241}
242
/// Winsock `IN_PKTINFO` is copied field by field through a fresh value on both
/// the encode and the decode path; the address is accessed through the
/// `S_un.S_addr` union member.
#[cfg(windows)]
impl AncillaryData for IN_PKTINFO {
    /// Copies the address and interface index into a zero-initialised
    /// `IN_PKTINFO` and writes that value to `buffer` via `copy_to_bytes`.
    fn encode(&self, buffer: &mut [MaybeUninit<u8>]) -> Result<(), CodecError> {
        // Start from all-zero bytes and fill in only the named fields
        // (presumably so that no unspecified bytes reach the buffer).
        let mut scratch = unsafe { std::mem::zeroed::<IN_PKTINFO>() };
        scratch.ipi_ifindex = self.ipi_ifindex;
        unsafe {
            scratch.ipi_addr.S_un.S_addr = self.ipi_addr.S_un.S_addr;
        }
        unsafe { copy_to_bytes(&scratch, buffer) }
    }

    /// Reads an `IN_PKTINFO` from `buffer` via `copy_from_bytes` and rebuilds
    /// the result from its individual fields.
    fn decode(buffer: &[u8]) -> Result<Self, CodecError> {
        let raw: IN_PKTINFO = unsafe { copy_from_bytes(buffer) }?;
        let s_addr = unsafe { raw.ipi_addr.S_un.S_addr };
        let ipi_addr = WinSock::IN_ADDR {
            S_un: WinSock::IN_ADDR_0 { S_addr: s_addr },
        };
        Ok(IN_PKTINFO {
            ipi_addr,
            ipi_ifindex: raw.ipi_ifindex,
        })
    }
}
266
/// Winsock `IN6_PKTINFO` is copied field by field through a fresh value on
/// both the encode and the decode path; the address is accessed through the
/// `u.Byte` union member.
#[cfg(windows)]
impl AncillaryData for IN6_PKTINFO {
    /// Copies the address bytes and interface index into a zero-initialised
    /// `IN6_PKTINFO` and writes that value to `buffer` via `copy_to_bytes`.
    fn encode(&self, buffer: &mut [MaybeUninit<u8>]) -> Result<(), CodecError> {
        // Start from all-zero bytes and fill in only the named fields
        // (presumably so that no unspecified bytes reach the buffer).
        let mut scratch = unsafe { std::mem::zeroed::<IN6_PKTINFO>() };
        scratch.ipi6_ifindex = self.ipi6_ifindex;
        unsafe {
            scratch.ipi6_addr.u.Byte = self.ipi6_addr.u.Byte;
        }
        unsafe { copy_to_bytes(&scratch, buffer) }
    }

    /// Reads an `IN6_PKTINFO` from `buffer` via `copy_from_bytes` and rebuilds
    /// the result from its individual fields.
    fn decode(buffer: &[u8]) -> Result<Self, CodecError> {
        let raw: IN6_PKTINFO = unsafe { copy_from_bytes(buffer) }?;
        let bytes = unsafe { raw.ipi6_addr.u.Byte };
        let ipi6_addr = WinSock::IN6_ADDR {
            u: WinSock::IN6_ADDR_0 { Byte: bytes },
        };
        Ok(IN6_PKTINFO {
            ipi6_addr,
            ipi6_ifindex: raw.ipi6_ifindex,
        })
    }
}