1use std::{
2 future::Future,
3 io,
4 path::Path,
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
10use compio_driver::{
11 BufferRef, SharedFd, impl_raw_fd,
12 op::{RecvFlags, RecvMsgMultiResult, SendMsgZc, SendVectoredZc, SendZc},
13};
14use compio_io::{
15 AsyncRead, AsyncReadManaged, AsyncReadMulti, AsyncWrite, AsyncWriteZerocopy,
16 ancillary::{
17 AsyncReadAncillary, AsyncReadAncillaryManaged, AsyncReadAncillaryMulti,
18 AsyncWriteAncillary, AsyncWriteAncillaryZerocopy,
19 },
20 util::Splittable,
21};
22use compio_runtime::fd::PollFd;
23use futures_util::{Stream, StreamExt, stream::FusedStream};
24use socket2::{Domain, SockAddr, Socket as Socket2, Type};
25
26use crate::{Extract, Incoming, MSG_NOSIGNAL, ReadHalf, Socket, WriteHalf, Zerocopy};
27
28#[derive(Debug, Clone)]
57pub struct UnixListener {
58 inner: Socket,
59}
60
61impl UnixListener {
62 pub async fn bind(path: impl AsRef<Path>) -> io::Result<Self> {
65 Self::bind_addr(&SockAddr::unix(path)?).await
66 }
67
68 pub async fn bind_addr(addr: &SockAddr) -> io::Result<Self> {
74 if !addr.is_unix() {
75 return Err(io::Error::new(
76 io::ErrorKind::InvalidInput,
77 "addr is not unix socket address",
78 ));
79 }
80
81 let socket = Socket::new(addr.domain(), Type::STREAM, None).await?;
82 socket.bind(addr).await?;
83 socket.listen(1024).await?;
84 Ok(UnixListener { inner: socket })
85 }
86
87 #[cfg(unix)]
88 pub fn from_std(stream: std::os::unix::net::UnixListener) -> io::Result<Self> {
90 Ok(Self {
91 inner: Socket::from_socket2(Socket2::from(stream))?,
92 })
93 }
94
95 pub fn close(self) -> impl Future<Output = io::Result<()>> {
102 self.inner.close()
103 }
104
105 pub async fn accept(&self) -> io::Result<(UnixStream, SockAddr)> {
111 let (socket, addr) = self.inner.accept().await?;
112 let stream = UnixStream { inner: socket };
113 Ok((stream, addr))
114 }
115
116 #[cfg(windows)]
119 pub async fn accept_with(&self, sock: UnixSocket) -> io::Result<(UnixStream, SockAddr)> {
120 let (socket, addr) = self.inner.accept_with(sock.inner).await?;
121 let stream = UnixStream { inner: socket };
122 Ok((stream, addr))
123 }
124
125 pub fn incoming(&self) -> UnixIncoming<'_> {
127 UnixIncoming {
128 inner: self.inner.incoming(),
129 }
130 }
131
132 pub fn local_addr(&self) -> io::Result<SockAddr> {
134 self.inner.local_addr()
135 }
136
137 pub fn take_error(&self) -> io::Result<Option<io::Error>> {
139 self.inner.socket.take_error()
140 }
141}
142
143impl_raw_fd!(UnixListener, socket2::Socket, inner, socket);
144
145pub struct UnixIncoming<'a> {
147 inner: Incoming<'a>,
148}
149
150impl Stream for UnixIncoming<'_> {
151 type Item = io::Result<UnixStream>;
152
153 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
154 let this = self.get_mut();
155 this.inner.poll_next_unpin(cx).map(|res| {
156 res.map(|res| {
157 let socket = res?;
158 Ok(UnixStream { inner: socket })
159 })
160 })
161 }
162}
163
164impl FusedStream for UnixIncoming<'_> {
165 fn is_terminated(&self) -> bool {
166 self.inner.is_terminated()
167 }
168}
169
170#[derive(Debug, Clone)]
190pub struct UnixStream {
191 inner: Socket,
192}
193
194impl UnixStream {
195 pub async fn connect(path: impl AsRef<Path>) -> io::Result<Self> {
198 Self::connect_addr(&SockAddr::unix(path)?).await
199 }
200
201 pub async fn connect_addr(addr: &SockAddr) -> io::Result<Self> {
208 if !addr.is_unix() {
209 return Err(io::Error::new(
210 io::ErrorKind::InvalidInput,
211 "addr is not unix socket address",
212 ));
213 }
214 let socket = Socket::new(Domain::UNIX, Type::STREAM, None).await?;
215 #[cfg(windows)]
216 {
217 let new_addr = empty_unix_socket();
218 socket.bind(&new_addr).await?
219 }
220 socket.connect_async(addr).await?;
221 let unix_stream = UnixStream { inner: socket };
222 Ok(unix_stream)
223 }
224
225 #[cfg(unix)]
227 pub fn from_std(stream: std::os::unix::net::UnixStream) -> io::Result<Self> {
228 Ok(Self {
229 inner: Socket::from_socket2(Socket2::from(stream))?,
230 })
231 }
232
233 pub fn close(self) -> impl Future<Output = io::Result<()>> {
240 self.inner.close()
241 }
242
243 pub fn peer_addr(&self) -> io::Result<SockAddr> {
245 #[allow(unused_mut)]
246 let mut addr = self.inner.peer_addr()?;
247 #[cfg(windows)]
248 {
249 fix_unix_socket_length(&mut addr);
250 }
251 Ok(addr)
252 }
253
254 pub fn local_addr(&self) -> io::Result<SockAddr> {
256 self.inner.local_addr()
257 }
258
259 pub fn take_error(&self) -> io::Result<Option<io::Error>> {
261 self.inner.socket.take_error()
262 }
263
264 pub fn split(&self) -> (ReadHalf<'_, Self>, WriteHalf<'_, Self>) {
271 crate::split(self)
272 }
273
274 pub fn into_split(self) -> (Self, Self) {
280 (self.clone(), self)
281 }
282
283 pub fn to_poll_fd(&self) -> io::Result<PollFd<Socket2>> {
285 self.inner.to_poll_fd()
286 }
287
288 pub fn into_poll_fd(self) -> io::Result<PollFd<Socket2>> {
290 self.inner.into_poll_fd()
291 }
292
293 #[cfg(windows)]
298 pub async fn disconnect(self) -> io::Result<UnixSocket> {
299 self.inner.disconnect().await?;
300 Ok(UnixSocket { inner: self.inner })
301 }
302}
303
304impl AsyncRead for UnixStream {
305 #[inline]
306 async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
307 (&*self).read(buf).await
308 }
309
310 #[inline]
311 async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
312 (&*self).read_vectored(buf).await
313 }
314}
315
316impl AsyncRead for &UnixStream {
317 #[inline]
318 async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
319 self.inner.recv(buf, RecvFlags::empty()).await
320 }
321
322 #[inline]
323 async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
324 self.inner.recv_vectored(buf, RecvFlags::empty()).await
325 }
326}
327
328impl AsyncReadManaged for UnixStream {
329 type Buffer = BufferRef;
330
331 async fn read_managed(&mut self, len: usize) -> io::Result<Option<Self::Buffer>> {
332 (&*self).read_managed(len).await
333 }
334}
335
336impl AsyncReadManaged for &UnixStream {
337 type Buffer = BufferRef;
338
339 async fn read_managed(&mut self, len: usize) -> io::Result<Option<Self::Buffer>> {
340 self.inner.recv_managed(len, RecvFlags::empty()).await
341 }
342}
343
344impl AsyncReadMulti for UnixStream {
345 fn read_multi(&mut self, len: usize) -> impl Stream<Item = io::Result<Self::Buffer>> {
346 self.inner.recv_multi(len, RecvFlags::empty())
347 }
348}
349
350impl AsyncReadMulti for &UnixStream {
351 fn read_multi(&mut self, len: usize) -> impl Stream<Item = io::Result<Self::Buffer>> {
352 self.inner.recv_multi(len, RecvFlags::empty())
353 }
354}
355
356impl AsyncReadAncillary for UnixStream {
357 #[inline]
358 async fn read_with_ancillary<T: IoBufMut, C: IoBufMut>(
359 &mut self,
360 buffer: T,
361 control: C,
362 ) -> BufResult<(usize, usize), (T, C)> {
363 (&*self).read_with_ancillary(buffer, control).await
364 }
365
366 #[inline]
367 async fn read_vectored_with_ancillary<T: IoVectoredBufMut, C: IoBufMut>(
368 &mut self,
369 buffer: T,
370 control: C,
371 ) -> BufResult<(usize, usize), (T, C)> {
372 (&*self).read_vectored_with_ancillary(buffer, control).await
373 }
374}
375
376impl AsyncReadAncillary for &UnixStream {
377 #[inline]
378 async fn read_with_ancillary<T: IoBufMut, C: IoBufMut>(
379 &mut self,
380 buffer: T,
381 control: C,
382 ) -> BufResult<(usize, usize), (T, C)> {
383 self.inner
384 .recv_msg(buffer, control, RecvFlags::empty())
385 .await
386 .map_res(|(res, len, _addr)| (res, len))
387 }
388
389 #[inline]
390 async fn read_vectored_with_ancillary<T: IoVectoredBufMut, C: IoBufMut>(
391 &mut self,
392 buffer: T,
393 control: C,
394 ) -> BufResult<(usize, usize), (T, C)> {
395 self.inner
396 .recv_msg_vectored(buffer, control, RecvFlags::empty())
397 .await
398 .map_res(|(res, len, _addr)| (res, len))
399 }
400}
401
402impl AsyncReadAncillaryManaged for UnixStream {
403 #[inline]
404 async fn read_managed_with_ancillary<C: IoBufMut>(
405 &mut self,
406 len: usize,
407 control: C,
408 ) -> io::Result<Option<(Self::Buffer, C)>> {
409 (&*self).read_managed_with_ancillary(len, control).await
410 }
411}
412
413impl AsyncReadAncillaryManaged for &UnixStream {
414 #[inline]
415 async fn read_managed_with_ancillary<C: IoBufMut>(
416 &mut self,
417 len: usize,
418 control: C,
419 ) -> io::Result<Option<(Self::Buffer, C)>> {
420 self.inner
421 .recv_msg_managed(len, control, RecvFlags::empty())
422 .await
423 .map(|res| res.map(|(res, len, _addr)| (res, len)))
424 }
425}
426
427impl AsyncReadAncillaryMulti for UnixStream {
428 type Return = RecvMsgMultiResult;
429
430 #[inline]
431 fn read_multi_with_ancillary(
432 &mut self,
433 control_len: usize,
434 ) -> impl Stream<Item = io::Result<Self::Return>> {
435 self.inner.recv_msg_multi(control_len, RecvFlags::empty())
436 }
437}
438
439impl AsyncReadAncillaryMulti for &UnixStream {
440 type Return = RecvMsgMultiResult;
441
442 #[inline]
443 fn read_multi_with_ancillary(
444 &mut self,
445 control_len: usize,
446 ) -> impl Stream<Item = io::Result<Self::Return>> {
447 self.inner.recv_msg_multi(control_len, RecvFlags::empty())
448 }
449}
450
451impl AsyncWrite for UnixStream {
452 #[inline]
453 async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
454 (&*self).write(buf).await
455 }
456
457 #[inline]
458 async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
459 (&*self).write_vectored(buf).await
460 }
461
462 #[inline]
463 async fn flush(&mut self) -> io::Result<()> {
464 (&*self).flush().await
465 }
466
467 #[inline]
468 async fn shutdown(&mut self) -> io::Result<()> {
469 (&*self).shutdown().await
470 }
471}
472
473impl AsyncWrite for &UnixStream {
474 #[inline]
475 async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
476 self.inner.send(buf, MSG_NOSIGNAL).await
477 }
478
479 #[inline]
480 async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
481 self.inner.send_vectored(buf, MSG_NOSIGNAL).await
482 }
483
484 #[inline]
485 async fn flush(&mut self) -> io::Result<()> {
486 Ok(())
487 }
488
489 #[inline]
490 async fn shutdown(&mut self) -> io::Result<()> {
491 self.inner.shutdown().await
492 }
493}
494
495impl AsyncWriteZerocopy for UnixStream {
496 type BufferReadyFuture<T: IoBuf> = Zerocopy<SendZc<T, SharedFd<Socket2>>>;
497 type VectoredBufferReadyFuture<T: IoVectoredBuf> =
498 Zerocopy<SendVectoredZc<T, SharedFd<Socket2>>>;
499
500 async fn write_zerocopy<T: IoBuf>(
501 &mut self,
502 buf: T,
503 ) -> BufResult<usize, Self::BufferReadyFuture<T>> {
504 self.inner.send_zerocopy(buf, MSG_NOSIGNAL).await
505 }
506
507 async fn write_zerocopy_vectored<T: IoVectoredBuf>(
508 &mut self,
509 buf: T,
510 ) -> BufResult<usize, Self::VectoredBufferReadyFuture<T>> {
511 self.inner.send_zerocopy_vectored(buf, MSG_NOSIGNAL).await
512 }
513}
514
515impl AsyncWriteZerocopy for &UnixStream {
516 type BufferReadyFuture<T: IoBuf> = Zerocopy<SendZc<T, SharedFd<Socket2>>>;
517 type VectoredBufferReadyFuture<T: IoVectoredBuf> =
518 Zerocopy<SendVectoredZc<T, SharedFd<Socket2>>>;
519
520 async fn write_zerocopy<T: IoBuf>(
521 &mut self,
522 buf: T,
523 ) -> BufResult<usize, Self::BufferReadyFuture<T>> {
524 self.inner.send_zerocopy(buf, MSG_NOSIGNAL).await
525 }
526
527 async fn write_zerocopy_vectored<T: IoVectoredBuf>(
528 &mut self,
529 buf: T,
530 ) -> BufResult<usize, Self::VectoredBufferReadyFuture<T>> {
531 self.inner.send_zerocopy_vectored(buf, MSG_NOSIGNAL).await
532 }
533}
534
535impl AsyncWriteAncillaryZerocopy for UnixStream {
536 type BufferReadyFuture<T: IoBuf, C: IoBuf> =
537 Extract<Zerocopy<SendMsgZc<[T; 1], C, SharedFd<Socket2>>>, T, C>;
538 type VectoredBufferReadyFuture<T: IoVectoredBuf, C: IoBuf> =
539 Zerocopy<SendMsgZc<T, C, SharedFd<Socket2>>>;
540
541 async fn write_zerocopy_with_ancillary<T: IoBuf, C: IoBuf>(
542 &mut self,
543 buf: T,
544 control: C,
545 ) -> BufResult<usize, Self::BufferReadyFuture<T, C>> {
546 self.inner
547 .send_msg_zerocopy(buf, control, None, MSG_NOSIGNAL)
548 .await
549 }
550
551 async fn write_zerocopy_vectored_with_ancillary<T: IoVectoredBuf, C: IoBuf>(
552 &mut self,
553 buf: T,
554 control: C,
555 ) -> BufResult<usize, Self::VectoredBufferReadyFuture<T, C>> {
556 self.inner
557 .send_msg_zerocopy_vectored(buf, control, None, MSG_NOSIGNAL)
558 .await
559 }
560}
561
562impl AsyncWriteAncillaryZerocopy for &UnixStream {
563 type BufferReadyFuture<T: IoBuf, C: IoBuf> =
564 Extract<Zerocopy<SendMsgZc<[T; 1], C, SharedFd<Socket2>>>, T, C>;
565 type VectoredBufferReadyFuture<T: IoVectoredBuf, C: IoBuf> =
566 Zerocopy<SendMsgZc<T, C, SharedFd<Socket2>>>;
567
568 async fn write_zerocopy_with_ancillary<T: IoBuf, C: IoBuf>(
569 &mut self,
570 buf: T,
571 control: C,
572 ) -> BufResult<usize, Self::BufferReadyFuture<T, C>> {
573 self.inner
574 .send_msg_zerocopy(buf, control, None, MSG_NOSIGNAL)
575 .await
576 }
577
578 async fn write_zerocopy_vectored_with_ancillary<T: IoVectoredBuf, C: IoBuf>(
579 &mut self,
580 buf: T,
581 control: C,
582 ) -> BufResult<usize, Self::VectoredBufferReadyFuture<T, C>> {
583 self.inner
584 .send_msg_zerocopy_vectored(buf, control, None, MSG_NOSIGNAL)
585 .await
586 }
587}
588
589impl AsyncWriteAncillary for UnixStream {
639 #[inline]
640 async fn write_with_ancillary<T: IoBuf, C: IoBuf>(
641 &mut self,
642 buffer: T,
643 control: C,
644 ) -> BufResult<usize, (T, C)> {
645 (&*self).write_with_ancillary(buffer, control).await
646 }
647
648 #[inline]
649 async fn write_vectored_with_ancillary<T: IoVectoredBuf, C: IoBuf>(
650 &mut self,
651 buffer: T,
652 control: C,
653 ) -> BufResult<usize, (T, C)> {
654 (&*self)
655 .write_vectored_with_ancillary(buffer, control)
656 .await
657 }
658}
659
660impl AsyncWriteAncillary for &UnixStream {
661 #[inline]
662 async fn write_with_ancillary<T: IoBuf, C: IoBuf>(
663 &mut self,
664 buffer: T,
665 control: C,
666 ) -> BufResult<usize, (T, C)> {
667 self.inner
668 .send_msg(buffer, control, None, MSG_NOSIGNAL)
669 .await
670 }
671
672 #[inline]
673 async fn write_vectored_with_ancillary<T: IoVectoredBuf, C: IoBuf>(
674 &mut self,
675 buffer: T,
676 control: C,
677 ) -> BufResult<usize, (T, C)> {
678 self.inner
679 .send_msg_vectored(buffer, control, None, MSG_NOSIGNAL)
680 .await
681 }
682}
683
684impl Splittable for UnixStream {
685 type ReadHalf = Self;
686 type WriteHalf = Self;
687
688 fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
689 self.into_split()
690 }
691}
692
693impl<'a> Splittable for &'a UnixStream {
694 type ReadHalf = ReadHalf<'a, UnixStream>;
695 type WriteHalf = WriteHalf<'a, UnixStream>;
696
697 fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
698 crate::split(self)
699 }
700}
701
702impl<'a> Splittable for &'a mut UnixStream {
703 type ReadHalf = ReadHalf<'a, UnixStream>;
704 type WriteHalf = WriteHalf<'a, UnixStream>;
705
706 fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
707 crate::split(self)
708 }
709}
710
711impl_raw_fd!(UnixStream, socket2::Socket, inner, socket);
712
713#[derive(Debug)]
716pub struct UnixSocket {
717 inner: Socket,
718}
719
720impl UnixSocket {
721 pub async fn new_stream() -> io::Result<UnixSocket> {
723 UnixSocket::new(socket2::Type::STREAM).await
724 }
725
726 async fn new(ty: socket2::Type) -> io::Result<UnixSocket> {
727 let inner = Socket::new(socket2::Domain::UNIX, ty, None).await?;
728 Ok(UnixSocket { inner })
729 }
730
731 pub fn local_addr(&self) -> io::Result<SockAddr> {
733 self.inner.local_addr()
734 }
735
736 pub fn take_error(&self) -> io::Result<Option<io::Error>> {
738 self.inner.socket.take_error()
739 }
740
741 pub async fn bind(&self, path: impl AsRef<Path>) -> io::Result<()> {
743 self.bind_addr(&SockAddr::unix(path)?).await
744 }
745
746 pub async fn bind_addr(&self, addr: &SockAddr) -> io::Result<()> {
748 if !addr.is_unix() {
749 return Err(io::Error::new(
750 io::ErrorKind::InvalidInput,
751 "addr is not unix socket address",
752 ));
753 }
754 self.inner.bind(addr).await
755 }
756
757 pub async fn listen(self, backlog: i32) -> io::Result<UnixListener> {
764 self.inner.listen(backlog).await?;
765 Ok(UnixListener { inner: self.inner })
766 }
767
768 pub async fn connect(self, path: impl AsRef<Path>) -> io::Result<UnixStream> {
773 self.connect_addr(&SockAddr::unix(path)?).await
774 }
775
776 pub async fn connect_addr(self, addr: &SockAddr) -> io::Result<UnixStream> {
786 if !addr.is_unix() {
787 return Err(io::Error::new(
788 io::ErrorKind::InvalidInput,
789 "addr is not unix socket address",
790 ));
791 }
792 self.inner.connect_async(addr).await?;
793 Ok(UnixStream { inner: self.inner })
794 }
795}
796
797impl_raw_fd!(UnixSocket, socket2::Socket, inner, socket);
798
/// Builds an unnamed (empty-path) `AF_UNIX` socket address.
///
/// Used by `UnixStream::connect_addr` on Windows to bind the socket before
/// connecting. The address length covers the `sun_family` field plus the
/// single NUL terminator of the empty `sun_path`.
#[cfg(windows)]
#[inline]
fn empty_unix_socket() -> SockAddr {
    use windows_sys::Win32::Networking::WinSock::{AF_UNIX, SOCKADDR_UN};

    // SAFETY: `SockAddr::try_init` hands the closure address storage sized for
    // any socket address, which includes `SOCKADDR_UN`. We write a fully
    // initialised `SOCKADDR_UN` whose family is `AF_UNIX`. The length we store
    // is smaller than `size_of::<SOCKADDR_UN>()`.
    let (_, addr) = unsafe {
        SockAddr::try_init(|storage, len| {
            let unix_addr = SOCKADDR_UN {
                sun_family: AF_UNIX,
                sun_path: [0; 108],
            };
            // `sun_family` plus the NUL byte of the empty path. This is 3 bytes
            // with the 2-byte `ADDRESS_FAMILY`. It is derived from the field so
            // the constant cannot drift from the struct layout.
            let addr_len = std::mem::size_of_val(&unix_addr.sun_family) + 1;
            std::ptr::write(storage.cast::<SOCKADDR_UN>(), unix_addr);
            std::ptr::write(len, addr_len as _);
            Ok(())
        })
    }
    // The closure above always returns `Ok`, so this cannot fail.
    .expect("initialising an empty unix socket address never fails");
    addr
}
823
/// Recomputes the length of a Windows `AF_UNIX` address from its path.
///
/// The new length is the size of `sun_family` plus the NUL-terminated
/// `sun_path`, terminator included. If `sun_path` contains no NUL byte, the
/// length becomes `size_of::<SOCKADDR_UN>()`.
///
/// Called by `UnixStream::peer_addr` on Windows. NOTE(review): presumably the
/// OS-reported peer length does not reflect the path length — confirm.
#[cfg(windows)]
#[inline]
fn fix_unix_socket_length(addr: &mut SockAddr) {
    use windows_sys::Win32::Networking::WinSock::SOCKADDR_UN;

    // SAFETY: `SockAddr` storage is large enough and suitably aligned for any
    // socket address, so viewing it as a `SOCKADDR_UN` stays in bounds.
    let unix_addr: &SOCKADDR_UN = unsafe { &*addr.as_ptr().cast() };
    // `sun_family` is the first member, so its size is the offset of
    // `sun_path`. This replaces a hard-coded `2`.
    let family_len = std::mem::size_of_val(&unix_addr.sun_family);
    // SAFETY: `sun_path` elements are single bytes (i8/u8), so viewing the array
    // as `u8` with the same pointer and element count covers the same memory.
    let sun_path = unsafe {
        std::slice::from_raw_parts(
            unix_addr.sun_path.as_ptr().cast::<u8>(),
            unix_addr.sun_path.len(),
        )
    };
    let addr_len = match std::ffi::CStr::from_bytes_until_nul(sun_path) {
        Ok(path) => family_len + path.to_bytes_with_nul().len(),
        Err(_) => std::mem::size_of::<SOCKADDR_UN>(),
    };
    // SAFETY: `addr_len` never exceeds `size_of::<SOCKADDR_UN>()`, which fits in
    // the address storage; the covered bytes were initialised when `addr` was
    // produced.
    unsafe {
        addr.set_length(addr_len as _);
    }
}
847}