Skip to main content

compio_net/
unix.rs

1use std::{
2    future::Future,
3    io,
4    path::Path,
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
10use compio_driver::impl_raw_fd;
11use compio_io::{AsyncRead, AsyncReadManaged, AsyncWrite, util::Splittable};
12use compio_runtime::{BorrowedBuffer, BufferPool, fd::PollFd};
13use futures_util::{Stream, StreamExt, stream::FusedStream};
14use socket2::{SockAddr, Socket as Socket2, Type};
15
16use crate::{Incoming, OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, SocketOpts, WriteHalf};
17
18/// A Unix socket server, listening for connections.
19///
20/// You can accept a new connection by using the [`UnixListener::accept`]
21/// method.
22///
23/// # Examples
24///
25/// ```
26/// use compio_io::{AsyncReadExt, AsyncWriteExt};
27/// use compio_net::{UnixListener, UnixStream};
28/// use tempfile::tempdir;
29///
30/// let dir = tempdir().unwrap();
31/// let sock_file = dir.path().join("unix-server.sock");
32///
33/// # compio_runtime::Runtime::new().unwrap().block_on(async move {
34/// let listener = UnixListener::bind(&sock_file).await.unwrap();
35///
36/// let (mut tx, (mut rx, _)) =
37///     futures_util::try_join!(UnixStream::connect(&sock_file), listener.accept()).unwrap();
38///
39/// tx.write_all("test").await.0.unwrap();
40///
41/// let (_, buf) = rx.read_exact(Vec::with_capacity(4)).await.unwrap();
42///
43/// assert_eq!(buf, b"test");
44/// # });
45/// ```
46#[derive(Debug, Clone)]
47pub struct UnixListener {
48    inner: Socket,
49}
50
51impl UnixListener {
52    /// Creates a new [`UnixListener`], which will be bound to the specified
53    /// file path. The file path cannot yet exist, and will be cleaned up
54    /// upon dropping [`UnixListener`].
55    pub async fn bind(path: impl AsRef<Path>) -> io::Result<Self> {
56        Self::bind_addr(&SockAddr::unix(path)?).await
57    }
58
59    /// Creates a new [`UnixListener`] with [`SockAddr`], which will be bound to
60    /// the specified file path. The file path cannot yet exist, and will be
61    /// cleaned up upon dropping [`UnixListener`].
62    pub async fn bind_addr(addr: &SockAddr) -> io::Result<Self> {
63        Self::bind_with_options(addr, &SocketOpts::default()).await
64    }
65
66    /// Creates a new [`UnixListener`] with [`SockAddr`] and [`SocketOpts`],
67    /// which will be bound to the specified file path. The file path cannot
68    /// yet exist, and will be cleaned up upon dropping [`UnixListener`].
69    pub async fn bind_with_options(addr: &SockAddr, opts: &SocketOpts) -> io::Result<Self> {
70        if !addr.is_unix() {
71            return Err(io::Error::new(
72                io::ErrorKind::InvalidInput,
73                "addr is not unix socket address",
74            ));
75        }
76
77        let socket = Socket::bind(addr, Type::STREAM, None).await?;
78        opts.setup_socket(&socket)?;
79        socket.listen(1024)?;
80        Ok(UnixListener { inner: socket })
81    }
82
83    #[cfg(unix)]
84    /// Creates new UnixListener from a [`std::os::unix::net::UnixListener`].
85    pub fn from_std(stream: std::os::unix::net::UnixListener) -> io::Result<Self> {
86        Ok(Self {
87            inner: Socket::from_socket2(Socket2::from(stream))?,
88        })
89    }
90
91    /// Close the socket. If the returned future is dropped before polling, the
92    /// socket won't be closed.
93    pub fn close(self) -> impl Future<Output = io::Result<()>> {
94        self.inner.close()
95    }
96
97    /// Accepts a new incoming connection from this listener.
98    ///
99    /// This function will yield once a new Unix domain socket connection
100    /// is established. When established, the corresponding [`UnixStream`] and
101    /// will be returned.
102    pub async fn accept(&self) -> io::Result<(UnixStream, SockAddr)> {
103        let (socket, addr) = self.inner.accept().await?;
104        let stream = UnixStream { inner: socket };
105        Ok((stream, addr))
106    }
107
108    /// Accepts a new incoming connection from this listener, and sets options.
109    ///
110    /// This function will yield once a new Unix domain socket connection
111    /// is established. When established, the corresponding [`UnixStream`] and
112    /// will be returned.
113    pub async fn accept_with_options(
114        &self,
115        options: &SocketOpts,
116    ) -> io::Result<(UnixStream, SockAddr)> {
117        let (socket, addr) = self.inner.accept().await?;
118        options.setup_socket(&socket)?;
119        let stream = UnixStream { inner: socket };
120        Ok((stream, addr))
121    }
122
123    /// Returns a stream of incoming connections to this listener.
124    pub fn incoming(&self) -> UnixIncoming<'_> {
125        self.incoming_with_options(&SocketOpts::default())
126    }
127
128    /// Returns a stream of incoming connections to this listener, and sets
129    /// options for each accepted connection.
130    pub fn incoming_with_options<'a>(&'a self, options: &SocketOpts) -> UnixIncoming<'a> {
131        UnixIncoming {
132            inner: self.inner.incoming(),
133            opts: *options,
134        }
135    }
136
137    /// Returns the local address that this listener is bound to.
138    pub fn local_addr(&self) -> io::Result<SockAddr> {
139        self.inner.local_addr()
140    }
141}
142
143impl_raw_fd!(UnixListener, socket2::Socket, inner, socket);
144
145/// A stream of incoming Unix connections.
146pub struct UnixIncoming<'a> {
147    inner: Incoming<'a>,
148    opts: SocketOpts,
149}
150
151impl Stream for UnixIncoming<'_> {
152    type Item = io::Result<UnixStream>;
153
154    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
155        let this = self.get_mut();
156        this.inner.poll_next_unpin(cx).map(|res| {
157            res.map(|res| {
158                let socket = res?;
159                this.opts.setup_socket(&socket)?;
160                Ok(UnixStream { inner: socket })
161            })
162        })
163    }
164}
165
166impl FusedStream for UnixIncoming<'_> {
167    fn is_terminated(&self) -> bool {
168        self.inner.is_terminated()
169    }
170}
171
172/// A Unix stream between two local sockets on Windows & WSL.
173///
174/// A Unix stream can either be created by connecting to an endpoint, via the
175/// `connect` method, or by accepting a connection from a listener.
176///
177/// # Examples
178///
179/// ```no_run
180/// use compio_io::AsyncWrite;
181/// use compio_net::UnixStream;
182///
183/// # compio_runtime::Runtime::new().unwrap().block_on(async {
184/// // Connect to a peer
185/// let mut stream = UnixStream::connect("unix-server.sock").await.unwrap();
186///
187/// // Write some data.
188/// stream.write("hello world!").await.unwrap();
189/// # })
190/// ```
191#[derive(Debug, Clone)]
192pub struct UnixStream {
193    inner: Socket,
194}
195
196impl UnixStream {
197    /// Opens a Unix connection to the specified file path. There must be a
198    /// [`UnixListener`] or equivalent listening on the corresponding Unix
199    /// domain socket to successfully connect and return a [`UnixStream`].
200    pub async fn connect(path: impl AsRef<Path>) -> io::Result<Self> {
201        Self::connect_addr(&SockAddr::unix(path)?).await
202    }
203
204    /// Opens a Unix connection to the specified address. There must be a
205    /// [`UnixListener`] or equivalent listening on the corresponding Unix
206    /// domain socket to successfully connect and return a [`UnixStream`].
207    pub async fn connect_addr(addr: &SockAddr) -> io::Result<Self> {
208        Self::connect_with_options(addr, &SocketOpts::default()).await
209    }
210
211    /// Opens a Unix connection to the specified address with [`SocketOpts`].
212    /// There must be a [`UnixListener`] or equivalent listening on the
213    /// corresponding Unix domain socket to successfully connect and return
214    /// a [`UnixStream`].
215    pub async fn connect_with_options(addr: &SockAddr, options: &SocketOpts) -> io::Result<Self> {
216        if !addr.is_unix() {
217            return Err(io::Error::new(
218                io::ErrorKind::InvalidInput,
219                "addr is not unix socket address",
220            ));
221        }
222
223        #[cfg(windows)]
224        let socket = {
225            let new_addr = empty_unix_socket();
226            Socket::bind(&new_addr, Type::STREAM, None).await?
227        };
228        #[cfg(unix)]
229        let socket = {
230            use socket2::Domain;
231            Socket::new(Domain::UNIX, Type::STREAM, None).await?
232        };
233        options.setup_socket(&socket)?;
234        socket.connect_async(addr).await?;
235        let unix_stream = UnixStream { inner: socket };
236        Ok(unix_stream)
237    }
238
239    #[cfg(unix)]
240    /// Creates new UnixStream from a [`std::os::unix::net::UnixStream`].
241    pub fn from_std(stream: std::os::unix::net::UnixStream) -> io::Result<Self> {
242        Ok(Self {
243            inner: Socket::from_socket2(Socket2::from(stream))?,
244        })
245    }
246
247    /// Close the socket. If the returned future is dropped before polling, the
248    /// socket won't be closed.
249    pub fn close(self) -> impl Future<Output = io::Result<()>> {
250        self.inner.close()
251    }
252
253    /// Returns the socket path of the remote peer of this connection.
254    pub fn peer_addr(&self) -> io::Result<SockAddr> {
255        #[allow(unused_mut)]
256        let mut addr = self.inner.peer_addr()?;
257        #[cfg(windows)]
258        {
259            fix_unix_socket_length(&mut addr);
260        }
261        Ok(addr)
262    }
263
264    /// Returns the socket path of the local half of this connection.
265    pub fn local_addr(&self) -> io::Result<SockAddr> {
266        self.inner.local_addr()
267    }
268
269    /// Splits a [`UnixStream`] into a read half and a write half, which can be
270    /// used to read and write the stream concurrently.
271    ///
272    /// This method is more efficient than
273    /// [`into_split`](UnixStream::into_split), but the halves cannot
274    /// be moved into independently spawned tasks.
275    pub fn split(&self) -> (ReadHalf<'_, Self>, WriteHalf<'_, Self>) {
276        crate::split(self)
277    }
278
279    /// Splits a [`UnixStream`] into a read half and a write half, which can be
280    /// used to read and write the stream concurrently.
281    ///
282    /// Unlike [`split`](UnixStream::split), the owned halves can be moved to
283    /// separate tasks, however this comes at the cost of a heap allocation.
284    pub fn into_split(self) -> (OwnedReadHalf<Self>, OwnedWriteHalf<Self>) {
285        crate::into_split(self)
286    }
287
288    /// Create [`PollFd`] from inner socket.
289    pub fn to_poll_fd(&self) -> io::Result<PollFd<Socket2>> {
290        self.inner.to_poll_fd()
291    }
292
293    /// Create [`PollFd`] from inner socket.
294    pub fn into_poll_fd(self) -> io::Result<PollFd<Socket2>> {
295        self.inner.into_poll_fd()
296    }
297}
298
299impl AsyncRead for UnixStream {
300    #[inline]
301    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
302        (&*self).read(buf).await
303    }
304
305    #[inline]
306    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
307        (&*self).read_vectored(buf).await
308    }
309}
310
311impl AsyncRead for &UnixStream {
312    #[inline]
313    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
314        self.inner.recv(buf, 0).await
315    }
316
317    #[inline]
318    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
319        self.inner.recv_vectored(buf, 0).await
320    }
321}
322
323impl AsyncReadManaged for UnixStream {
324    type Buffer<'a> = BorrowedBuffer<'a>;
325    type BufferPool = BufferPool;
326
327    async fn read_managed<'a>(
328        &mut self,
329        buffer_pool: &'a Self::BufferPool,
330        len: usize,
331    ) -> io::Result<Self::Buffer<'a>> {
332        (&*self).read_managed(buffer_pool, len).await
333    }
334}
335
336impl AsyncReadManaged for &UnixStream {
337    type Buffer<'a> = BorrowedBuffer<'a>;
338    type BufferPool = BufferPool;
339
340    async fn read_managed<'a>(
341        &mut self,
342        buffer_pool: &'a Self::BufferPool,
343        len: usize,
344    ) -> io::Result<Self::Buffer<'a>> {
345        self.inner.recv_managed(buffer_pool, len as _, 0).await
346    }
347}
348
349impl AsyncWrite for UnixStream {
350    #[inline]
351    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
352        (&*self).write(buf).await
353    }
354
355    #[inline]
356    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
357        (&*self).write_vectored(buf).await
358    }
359
360    #[inline]
361    async fn flush(&mut self) -> io::Result<()> {
362        (&*self).flush().await
363    }
364
365    #[inline]
366    async fn shutdown(&mut self) -> io::Result<()> {
367        (&*self).shutdown().await
368    }
369}
370
371impl AsyncWrite for &UnixStream {
372    #[inline]
373    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
374        self.inner.send(buf, 0).await
375    }
376
377    #[inline]
378    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
379        self.inner.send_vectored(buf, 0).await
380    }
381
382    #[inline]
383    async fn flush(&mut self) -> io::Result<()> {
384        Ok(())
385    }
386
387    #[inline]
388    async fn shutdown(&mut self) -> io::Result<()> {
389        self.inner.shutdown().await
390    }
391}
392
393impl Splittable for UnixStream {
394    type ReadHalf = OwnedReadHalf<Self>;
395    type WriteHalf = OwnedWriteHalf<Self>;
396
397    fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
398        crate::into_split(self)
399    }
400}
401
402impl<'a> Splittable for &'a UnixStream {
403    type ReadHalf = ReadHalf<'a, UnixStream>;
404    type WriteHalf = WriteHalf<'a, UnixStream>;
405
406    fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
407        crate::split(self)
408    }
409}
410
411impl_raw_fd!(UnixStream, socket2::Socket, inner, socket);
412
/// Builds an unnamed `AF_UNIX` address (an all-zero `sun_path`).
///
/// Used on Windows to bind the client socket before it is connected in
/// `UnixStream::connect_with_options`.
#[cfg(windows)]
#[inline]
fn empty_unix_socket() -> SockAddr {
    use windows_sys::Win32::Networking::WinSock::{AF_UNIX, SOCKADDR_UN};

    // SAFETY: the length is correct
    // The reported length of 3 covers `sun_family` (2 bytes) plus the single
    // NUL byte at the start of `sun_path`, i.e. an empty path.
    unsafe {
        SockAddr::try_init(|addr, len| {
            let addr: *mut SOCKADDR_UN = addr.cast();
            std::ptr::write(
                addr,
                SOCKADDR_UN {
                    sun_family: AF_UNIX,
                    sun_path: [0; 108],
                },
            );
            std::ptr::write(len, 3);
            Ok(())
        })
    }
    // it is always Ok: the init closure above never returns an error.
    .unwrap()
    .1
}
437
/// Trims a Unix socket address reported by Windows to its real length.
///
/// The peer address returned after `ConnectEx` is buggy: its length covers
/// bytes that do not belong to the address. A Unix path cannot contain `\0`
/// before its end, so the first NUL in `sun_path` marks where the path stops,
/// and the length is recomputed from it. If `sun_path` contains no NUL, the
/// full `SOCKADDR_UN` size is used.
///
/// `addr` is expected to be an `AF_UNIX` address; callers pass the peer
/// address of a `UnixStream`.
#[cfg(windows)]
#[inline]
fn fix_unix_socket_length(addr: &mut SockAddr) {
    use windows_sys::Win32::Networking::WinSock::SOCKADDR_UN;

    // SAFETY: `addr` is the address of a Unix domain socket, and its backing
    // storage is large enough to be read as a `SOCKADDR_UN`.
    let unix_addr: &SOCKADDR_UN = unsafe { &*addr.as_ptr().cast() };
    // SAFETY: `sun_path` is an array of byte-sized integers, so viewing it as
    // `u8`s covers exactly the same, fully initialised memory.
    let sun_path = unsafe {
        std::slice::from_raw_parts(
            unix_addr.sun_path.as_ptr() as *const u8,
            unix_addr.sun_path.len(),
        )
    };
    // `sun_path` directly follows `sun_family` in `SOCKADDR_UN`.
    let family_len = std::mem::size_of_val(&unix_addr.sun_family);
    let addr_len = match std::ffi::CStr::from_bytes_until_nul(sun_path) {
        Ok(path) => family_len + path.to_bytes_with_nul().len(),
        Err(_) => std::mem::size_of::<SOCKADDR_UN>(),
    };
    // SAFETY: `addr_len` never exceeds `size_of::<SOCKADDR_UN>()`, which fits
    // inside the address storage read above.
    unsafe {
        addr.set_length(addr_len as _);
    }
}