Skip to main content

compio_driver/sys/op/socket/iocp.rs

1use rustix::net::RecvFlags;
2use windows_sys::Win32::{
3    Networking::WinSock::{
4        LPFN_ACCEPTEX, LPFN_CONNECTEX, LPFN_DISCONNECTEX, LPFN_GETACCEPTEXSOCKADDRS,
5        LPFN_WSARECVMSG, SO_UPDATE_ACCEPT_CONTEXT, SO_UPDATE_CONNECT_CONTEXT, SOCKADDR,
6        SOCKADDR_STORAGE, SOL_SOCKET, TF_REUSE_SOCKET, WSAID_ACCEPTEX, WSAID_CONNECTEX,
7        WSAID_DISCONNECTEX, WSAID_GETACCEPTEXSOCKADDRS, WSAID_WSARECVMSG, WSAMSG, WSARecv,
8        WSARecvFrom, WSASend, WSASendMsg, WSASendTo, closesocket, setsockopt, socklen_t,
9    },
10    System::IO::OVERLAPPED,
11};
12
13use crate::{OpCode, OpType, sys::op::*};
14
15static ACCEPT_EX: OnceLock<LPFN_ACCEPTEX> = OnceLock::new();
16static GET_ADDRS: OnceLock<LPFN_GETACCEPTEXSOCKADDRS> = OnceLock::new();
17
18const ACCEPT_ADDR_BUFFER_SIZE: usize = std::mem::size_of::<SOCKADDR_STORAGE>() + 16;
19const ACCEPT_BUFFER_SIZE: usize = ACCEPT_ADDR_BUFFER_SIZE * 2;
20
21unsafe impl OpCode for CloseSocket {
22    type Control = ();
23
24    fn op_type(&self, _: &Self::Control) -> OpType {
25        OpType::Blocking
26    }
27
28    unsafe fn operate(&mut self, _: &mut (), _optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
29        Poll::Ready(Ok(
30            syscall!(SOCKET, closesocket(self.fd.as_fd().as_raw_fd() as _))? as _,
31        ))
32    }
33}
34
35/// Accept a connection.
36pub struct Accept<S, SA> {
37    pub(crate) fd: S,
38    pub(crate) accept_fd: SA,
39    pub(crate) buffer: [u8; ACCEPT_BUFFER_SIZE],
40}
41
42impl<S, SA> Accept<S, SA> {
43    /// Create [`Accept`]. `accept_fd` should not be bound.
44    pub fn new(fd: S, accept_fd: SA) -> Self {
45        Self {
46            fd,
47            accept_fd,
48            buffer: [0u8; ACCEPT_BUFFER_SIZE],
49        }
50    }
51}
52
53impl<S: AsFd, SA: AsFd> Accept<S, SA> {
54    /// Update accept context.
55    pub fn update_context(&self) -> io::Result<()> {
56        let fd = self.fd.as_fd().as_raw_fd();
57        syscall!(
58            SOCKET,
59            setsockopt(
60                self.accept_fd.as_fd().as_raw_fd() as _,
61                SOL_SOCKET,
62                SO_UPDATE_ACCEPT_CONTEXT,
63                &fd as *const _ as _,
64                std::mem::size_of_val(&fd) as _,
65            )
66        )?;
67        Ok(())
68    }
69
70    /// Get the remote address from the inner buffer.
71    pub fn into_addr(self) -> io::Result<(SA, SockAddr)> {
72        let get_addrs_fn = GET_ADDRS
73            .get_or_try_init(|| {
74                get_wsa_fn(self.fd.as_fd().as_raw_fd(), WSAID_GETACCEPTEXSOCKADDRS)
75            })?
76            .ok_or_else(|| {
77                io::Error::new(
78                    io::ErrorKind::Unsupported,
79                    "cannot retrieve GetAcceptExSockAddrs",
80                )
81            })?;
82        let mut local_addr: *mut SOCKADDR = null_mut();
83        let mut local_addr_len = 0;
84        let mut remote_addr: *mut SOCKADDR = null_mut();
85        let mut remote_addr_len = 0;
86        unsafe {
87            get_addrs_fn(
88                &self.buffer as *const _ as *const _,
89                0,
90                ACCEPT_ADDR_BUFFER_SIZE as _,
91                ACCEPT_ADDR_BUFFER_SIZE as _,
92                &mut local_addr,
93                &mut local_addr_len,
94                &mut remote_addr,
95                &mut remote_addr_len,
96            );
97        }
98        Ok((self.accept_fd, unsafe {
99            SockAddr::new(
100                // SAFETY: the buffer is large enough to hold the address
101                std::mem::transmute::<SOCKADDR_STORAGE, SockAddrStorage>(read_unaligned(
102                    remote_addr.cast::<SOCKADDR_STORAGE>(),
103                )),
104                remote_addr_len,
105            )
106        }))
107    }
108}
109
110unsafe impl<S: AsFd, SA: AsFd> OpCode for Accept<S, SA> {
111    type Control = ();
112
113    unsafe fn operate(&mut self, _: &mut (), optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
114        let accept_fn = ACCEPT_EX
115            .get_or_try_init(|| get_wsa_fn(self.fd.as_fd().as_raw_fd(), WSAID_ACCEPTEX))?
116            .ok_or_else(|| {
117                io::Error::new(io::ErrorKind::Unsupported, "cannot retrieve AcceptEx")
118            })?;
119        let mut received = 0;
120        let res = unsafe {
121            accept_fn(
122                self.fd.as_fd().as_raw_fd() as _,
123                self.accept_fd.as_fd().as_raw_fd() as _,
124                self.buffer.sys_slice_mut().ptr() as _,
125                0,
126                ACCEPT_ADDR_BUFFER_SIZE as _,
127                ACCEPT_ADDR_BUFFER_SIZE as _,
128                &mut received,
129                optr,
130            )
131        };
132        win32_result(res, received)
133    }
134
135    fn cancel(&mut self, _: &mut (), optr: *mut OVERLAPPED) -> io::Result<()> {
136        cancel(self.fd.as_fd().as_raw_fd(), optr)
137    }
138}
139
140static CONNECT_EX: OnceLock<LPFN_CONNECTEX> = OnceLock::new();
141
142impl<S: AsFd> Connect<S> {
143    /// Update connect context.
144    pub fn update_context(&self) -> io::Result<()> {
145        syscall!(
146            SOCKET,
147            setsockopt(
148                self.fd.as_fd().as_raw_fd() as _,
149                SOL_SOCKET,
150                SO_UPDATE_CONNECT_CONTEXT,
151                null(),
152                0,
153            )
154        )?;
155        Ok(())
156    }
157}
158
159unsafe impl<S: AsFd> OpCode for Connect<S> {
160    type Control = ();
161
162    unsafe fn operate(&mut self, _: &mut (), optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
163        let connect_fn = CONNECT_EX
164            .get_or_try_init(|| get_wsa_fn(self.fd.as_fd().as_raw_fd(), WSAID_CONNECTEX))?
165            .ok_or_else(|| {
166                io::Error::new(io::ErrorKind::Unsupported, "cannot retrieve ConnectEx")
167            })?;
168        let mut sent = 0;
169        let res = unsafe {
170            connect_fn(
171                self.fd.as_fd().as_raw_fd() as _,
172                self.addr.as_ptr().cast(),
173                self.addr.len(),
174                null(),
175                0,
176                &mut sent,
177                optr,
178            )
179        };
180        win32_result(res, sent)
181    }
182
183    fn cancel(&mut self, _: &mut (), optr: *mut OVERLAPPED) -> io::Result<()> {
184        cancel(self.fd.as_fd().as_raw_fd(), optr)
185    }
186}
187
188/// Disconnect a connected socket and reuse it for another connection.
189pub struct Disconnect<S> {
190    pub(crate) fd: S,
191}
192
193impl<S> Disconnect<S> {
194    /// Create [`Disconnect`].
195    pub fn new(fd: S) -> Self {
196        Self { fd }
197    }
198}
199
200static DISCONNECT_EX: OnceLock<LPFN_DISCONNECTEX> = OnceLock::new();
201
202unsafe impl<S: AsFd> OpCode for Disconnect<S> {
203    type Control = ();
204
205    unsafe fn operate(&mut self, _: &mut (), optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
206        let disconnect_fn = DISCONNECT_EX
207            .get_or_try_init(|| get_wsa_fn(self.fd.as_fd().as_raw_fd(), WSAID_DISCONNECTEX))?
208            .ok_or_else(|| {
209                io::Error::new(io::ErrorKind::Unsupported, "cannot retrieve DisconnectEx")
210            })?;
211        let res =
212            unsafe { disconnect_fn(self.fd.as_fd().as_raw_fd() as _, optr, TF_REUSE_SOCKET, 0) };
213        win32_result(res, 0)
214    }
215
216    fn cancel(&mut self, _: &mut (), optr: *mut OVERLAPPED) -> io::Result<()> {
217        cancel(self.fd.as_fd().as_raw_fd(), optr)
218    }
219}
220
221#[derive(Default)]
222#[doc(hidden)]
223pub struct RecvControl {
224    pub(crate) slice: SysSlice,
225}
226
227unsafe impl<T: IoBufMut, S: AsFd> OpCode for Recv<T, S> {
228    type Control = RecvControl;
229
230    unsafe fn init(&mut self, ctrl: &mut Self::Control) {
231        ctrl.slice = self.buffer.sys_slice_mut();
232    }
233
234    unsafe fn operate(
235        &mut self,
236        control: &mut Self::Control,
237        optr: *mut OVERLAPPED,
238    ) -> Poll<io::Result<usize>> {
239        let fd = self.fd.as_fd().as_raw_fd();
240        let mut flags = self.flags.bits() as _;
241        let mut received = 0;
242        let res = unsafe {
243            WSARecv(
244                fd as _,
245                &raw const control.slice as _,
246                1,
247                &mut received,
248                &mut flags,
249                optr,
250                None,
251            )
252        };
253        winsock_result(res, received)
254    }
255
256    fn cancel(&mut self, _: &mut Self::Control, optr: *mut OVERLAPPED) -> io::Result<()> {
257        cancel(self.fd.as_fd().as_raw_fd(), optr)
258    }
259}
260
261#[derive(Default)]
262#[doc(hidden)]
263pub struct RecvVectoredControl {
264    pub(crate) slices: Vec<SysSlice>,
265}
266
267unsafe impl<T: IoVectoredBufMut, S: AsFd> OpCode for RecvVectored<T, S> {
268    type Control = RecvVectoredControl;
269
270    unsafe fn init(&mut self, ctrl: &mut Self::Control) {
271        ctrl.slices = self.buffer.sys_slices_mut();
272    }
273
274    unsafe fn operate(
275        &mut self,
276        control: &mut Self::Control,
277        optr: *mut OVERLAPPED,
278    ) -> Poll<io::Result<usize>> {
279        let fd = self.fd.as_fd().as_raw_fd();
280        let mut flags = self.flags.bits() as _;
281        let mut received = 0;
282        let res = unsafe {
283            WSARecv(
284                fd as _,
285                control.slices.as_ptr() as _,
286                control.slices.len() as _,
287                &mut received,
288                &mut flags,
289                optr,
290                None,
291            )
292        };
293        winsock_result(res, received)
294    }
295
296    fn cancel(&mut self, _: &mut Self::Control, optr: *mut OVERLAPPED) -> io::Result<()> {
297        cancel(self.fd.as_fd().as_raw_fd(), optr)
298    }
299}
300
301#[derive(Default)]
302#[doc(hidden)]
303pub struct SendControl {
304    pub(crate) slice: SysSlice,
305}
306
307unsafe impl<T: IoBuf, S: AsFd> OpCode for Send<T, S> {
308    type Control = SendControl;
309
310    unsafe fn init(&mut self, ctrl: &mut Self::Control) {
311        ctrl.slice = self.buffer.sys_slice();
312    }
313
314    unsafe fn operate(
315        &mut self,
316        control: &mut Self::Control,
317        optr: *mut OVERLAPPED,
318    ) -> Poll<io::Result<usize>> {
319        let mut sent = 0;
320        let res = unsafe {
321            WSASend(
322                self.fd.as_fd().as_raw_fd() as _,
323                (&raw const control.slice).cast(),
324                1,
325                &mut sent,
326                self.flags.bits() as _,
327                optr,
328                None,
329            )
330        };
331        winsock_result(res, sent)
332    }
333
334    fn cancel(&mut self, _: &mut Self::Control, optr: *mut OVERLAPPED) -> io::Result<()> {
335        cancel(self.fd.as_fd().as_raw_fd(), optr)
336    }
337}
338
339#[derive(Default)]
340#[doc(hidden)]
341pub struct SendVectoredControl {
342    pub(crate) slices: Vec<SysSlice>,
343}
344
345unsafe impl<T: IoVectoredBuf, S: AsFd> OpCode for SendVectored<T, S> {
346    type Control = SendVectoredControl;
347
348    unsafe fn init(&mut self, ctrl: &mut Self::Control) {
349        ctrl.slices = self.buffer.sys_slices();
350    }
351
352    unsafe fn operate(
353        &mut self,
354        control: &mut Self::Control,
355        optr: *mut OVERLAPPED,
356    ) -> Poll<io::Result<usize>> {
357        let mut sent = 0;
358        let res = unsafe {
359            WSASend(
360                self.fd.as_fd().as_raw_fd() as _,
361                control.slices.as_ptr() as _,
362                control.slices.len() as _,
363                &mut sent,
364                self.flags.bits() as _,
365                optr,
366                None,
367            )
368        };
369        winsock_result(res, sent)
370    }
371
372    fn cancel(&mut self, _: &mut Self::Control, optr: *mut OVERLAPPED) -> io::Result<()> {
373        cancel(self.fd.as_fd().as_raw_fd(), optr)
374    }
375}
376
377#[derive(Default)]
378#[doc(hidden)]
379pub struct RecvFromControl {
380    pub(crate) slice: SysSlice,
381}
382
383unsafe impl<T: IoBufMut, S: AsFd> OpCode for RecvFrom<T, S> {
384    type Control = RecvFromControl;
385
386    unsafe fn init(&mut self, ctrl: &mut Self::Control) {
387        ctrl.slice = self.buffer.sys_slice_mut();
388    }
389
390    unsafe fn operate(
391        &mut self,
392        control: &mut Self::Control,
393        optr: *mut OVERLAPPED,
394    ) -> Poll<io::Result<usize>> {
395        let fd = self.header.fd.as_fd().as_raw_fd();
396        let mut flags = self.header.flags.bits() as _;
397        let mut received = 0;
398        let res = unsafe {
399            WSARecvFrom(
400                fd as _,
401                (&raw const control.slice).cast(),
402                1,
403                &mut received,
404                &mut flags,
405                &raw mut self.header.addr as *mut SOCKADDR,
406                &raw mut self.header.addr_len,
407                optr,
408                None,
409            )
410        };
411        winsock_result(res, received)
412    }
413
414    fn cancel(&mut self, _: &mut Self::Control, optr: *mut OVERLAPPED) -> io::Result<()> {
415        cancel(self.header.fd.as_fd().as_raw_fd(), optr)
416    }
417}
418
419#[derive(Default)]
420#[doc(hidden)]
421pub struct RecvFromVectoredControl {
422    pub(crate) slices: Vec<SysSlice>,
423}
424
425unsafe impl<T: IoVectoredBufMut, S: AsFd> OpCode for RecvFromVectored<T, S> {
426    type Control = RecvFromVectoredControl;
427
428    unsafe fn init(&mut self, ctrl: &mut Self::Control) {
429        ctrl.slices = self.buffer.sys_slices_mut();
430    }
431
432    unsafe fn operate(
433        &mut self,
434        control: &mut Self::Control,
435        optr: *mut OVERLAPPED,
436    ) -> Poll<io::Result<usize>> {
437        let fd = self.header.fd.as_fd().as_raw_fd();
438        let mut flags = self.header.flags.bits() as _;
439        let mut received = 0;
440        let res = unsafe {
441            WSARecvFrom(
442                fd as _,
443                control.slices.as_ptr() as _,
444                control.slices.len() as _,
445                &mut received,
446                &mut flags,
447                &raw mut self.header.addr as *mut SOCKADDR,
448                &raw mut self.header.addr_len,
449                optr,
450                None,
451            )
452        };
453        winsock_result(res, received)
454    }
455
456    fn cancel(&mut self, _: &mut Self::Control, optr: *mut OVERLAPPED) -> io::Result<()> {
457        cancel(self.header.fd.as_fd().as_raw_fd(), optr)
458    }
459}
460
461#[derive(Default)]
462#[doc(hidden)]
463pub struct SendToControl {
464    pub(crate) slice: SysSlice,
465}
466
467unsafe impl<T: IoBuf, S: AsFd> OpCode for SendTo<T, S> {
468    type Control = SendToControl;
469
470    unsafe fn init(&mut self, ctrl: &mut Self::Control) {
471        ctrl.slice = self.buffer.sys_slice();
472    }
473
474    unsafe fn operate(
475        &mut self,
476        control: &mut Self::Control,
477        optr: *mut OVERLAPPED,
478    ) -> Poll<io::Result<usize>> {
479        let mut sent = 0;
480        let res = unsafe {
481            WSASendTo(
482                self.header.fd.as_fd().as_raw_fd() as _,
483                (&raw const control.slice).cast(),
484                1,
485                &mut sent,
486                self.header.flags.bits() as _,
487                self.header.addr.as_ptr().cast(),
488                self.header.addr.len(),
489                optr,
490                None,
491            )
492        };
493        winsock_result(res, sent)
494    }
495
496    fn cancel(&mut self, _: &mut Self::Control, optr: *mut OVERLAPPED) -> io::Result<()> {
497        cancel(self.header.fd.as_fd().as_raw_fd(), optr)
498    }
499}
500
501#[derive(Default)]
502#[doc(hidden)]
503pub struct SendToVectoredControl {
504    pub(crate) slices: Vec<SysSlice>,
505}
506
507unsafe impl<T: IoVectoredBuf, S: AsFd> OpCode for SendToVectored<T, S> {
508    type Control = SendToVectoredControl;
509
510    unsafe fn init(&mut self, ctrl: &mut Self::Control) {
511        ctrl.slices = self.buffer.sys_slices();
512    }
513
514    unsafe fn operate(
515        &mut self,
516        control: &mut Self::Control,
517        optr: *mut OVERLAPPED,
518    ) -> Poll<io::Result<usize>> {
519        let mut sent = 0;
520        let res = unsafe {
521            WSASendTo(
522                self.header.fd.as_fd().as_raw_fd() as _,
523                control.slices.as_ptr() as _,
524                control.slices.len() as _,
525                &mut sent,
526                self.header.flags.bits() as _,
527                self.header.addr.as_ptr().cast(),
528                self.header.addr.len(),
529                optr,
530                None,
531            )
532        };
533        winsock_result(res, sent)
534    }
535
536    fn cancel(&mut self, _: &mut Self::Control, optr: *mut OVERLAPPED) -> io::Result<()> {
537        cancel(self.header.fd.as_fd().as_raw_fd(), optr)
538    }
539}
540
541static WSA_RECVMSG: OnceLock<LPFN_WSARECVMSG> = OnceLock::new();
542
543#[derive(Default)]
544#[doc(hidden)]
545pub struct RecvMsgControl {
546    msg: WSAMSG,
547    #[allow(dead_code)]
548    slices: Vec<SysSlice>,
549}
550
551unsafe impl<T: IoVectoredBufMut, C: IoBufMut, S: AsFd> OpCode for RecvMsg<T, C, S> {
552    type Control = RecvMsgControl;
553
554    unsafe fn init(&mut self, ctrl: &mut Self::Control) {
555        ctrl.slices = self.buffer.sys_slices_mut();
556        ctrl.msg.dwFlags = self.header.flags.bits() as _;
557        ctrl.msg.name = &raw mut self.header.addr as _;
558        ctrl.msg.namelen = self.header.addr.size_of() as _;
559        ctrl.msg.lpBuffers = ctrl.slices.as_mut_ptr() as _;
560        ctrl.msg.dwBufferCount = ctrl.slices.len() as _;
561        ctrl.msg.Control = self.control.sys_slice_mut().into_inner();
562    }
563
564    unsafe fn operate(
565        &mut self,
566        control: &mut RecvMsgControl,
567        optr: *mut OVERLAPPED,
568    ) -> Poll<io::Result<usize>> {
569        let recvmsg_fn = WSA_RECVMSG
570            .get_or_try_init(|| get_wsa_fn(self.header.fd.as_fd().as_raw_fd(), WSAID_WSARECVMSG))?
571            .ok_or_else(|| {
572                io::Error::new(io::ErrorKind::Unsupported, "cannot retrieve WSARecvMsg")
573            })?;
574
575        let mut received = 0;
576        let res = unsafe {
577            recvmsg_fn(
578                self.header.fd.as_fd().as_raw_fd() as _,
579                &raw mut control.msg,
580                &mut received,
581                optr,
582                None,
583            )
584        };
585        winsock_result(res, received)
586    }
587
588    fn cancel(&mut self, _: &mut Self::Control, optr: *mut OVERLAPPED) -> io::Result<()> {
589        cancel(self.header.fd.as_fd().as_raw_fd(), optr)
590    }
591
592    unsafe fn set_result(
593        &mut self,
594        control: &mut Self::Control,
595        _: &io::Result<usize>,
596        _: &crate::Extra,
597    ) {
598        self.header.flags = RecvFlags::from_bits_retain(control.msg.dwFlags);
599        self.header.addr_len = control.msg.namelen as socklen_t;
600        self.control_len = control.msg.Control.len as _;
601    }
602}
603
604#[derive(Default)]
605#[doc(hidden)]
606pub struct SendMsgControl {
607    msg: WSAMSG,
608    #[allow(dead_code)]
609    slices: Vec<SysSlice>,
610}
611
612unsafe impl<T: IoVectoredBuf, C: IoBuf, S: AsFd> OpCode for SendMsg<T, C, S> {
613    type Control = SendMsgControl;
614
615    unsafe fn init(&mut self, ctrl: &mut Self::Control) {
616        ctrl.slices = self.buffer.sys_slices();
617        let control = if self.control.buf_len() == 0 {
618            SysSlice::null()
619        } else {
620            self.control.sys_slice()
621        };
622
623        ctrl.msg.lpBuffers = ctrl.slices.as_ptr() as _;
624        ctrl.msg.dwBufferCount = ctrl.slices.len() as _;
625        ctrl.msg.Control = control.into_inner();
626        if let Some(addr) = &self.addr {
627            ctrl.msg.name = addr.as_ptr() as _;
628            ctrl.msg.namelen = addr.len() as _;
629        }
630    }
631
632    unsafe fn operate(
633        &mut self,
634        control: &mut Self::Control,
635        optr: *mut OVERLAPPED,
636    ) -> Poll<io::Result<usize>> {
637        let mut sent = 0;
638        let res = unsafe {
639            WSASendMsg(
640                self.fd.as_fd().as_raw_fd() as _,
641                &raw mut control.msg,
642                self.flags.bits() as _,
643                &mut sent,
644                optr,
645                None,
646            )
647        };
648        winsock_result(res, sent)
649    }
650
651    fn cancel(&mut self, _: &mut Self::Control, optr: *mut OVERLAPPED) -> io::Result<()> {
652        cancel(self.fd.as_fd().as_raw_fd(), optr)
653    }
654}