compio_net/
unix.rs

1use std::{future::Future, io, path::Path};
2
3use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
4use compio_driver::impl_raw_fd;
5use compio_io::{AsyncRead, AsyncReadManaged, AsyncWrite, util::Splittable};
6use compio_runtime::{BorrowedBuffer, BufferPool};
7use socket2::{SockAddr, Socket as Socket2, Type};
8
9use crate::{OwnedReadHalf, OwnedWriteHalf, PollFd, ReadHalf, Socket, SocketOpts, WriteHalf};
10
/// A Unix socket server, listening for connections.
///
/// New connections are obtained with the [`UnixListener::accept`] method,
/// which yields a [`UnixStream`] together with the peer's address.
///
/// # Examples
///
/// ```
/// use compio_io::{AsyncReadExt, AsyncWriteExt};
/// use compio_net::{UnixListener, UnixStream};
/// use tempfile::tempdir;
///
/// let dir = tempdir().unwrap();
/// let sock_file = dir.path().join("unix-server.sock");
///
/// # compio_runtime::Runtime::new().unwrap().block_on(async move {
/// let listener = UnixListener::bind(&sock_file).await.unwrap();
///
/// let (mut tx, (mut rx, _)) =
///     futures_util::try_join!(UnixStream::connect(&sock_file), listener.accept()).unwrap();
///
/// tx.write_all("test").await.0.unwrap();
///
/// let (_, buf) = rx.read_exact(Vec::with_capacity(4)).await.unwrap();
///
/// assert_eq!(buf, b"test");
/// # });
/// ```
#[derive(Debug, Clone)]
pub struct UnixListener {
    // Crate-level socket wrapper; every listener method delegates to it.
    inner: Socket,
}
43
44impl UnixListener {
45    /// Creates a new [`UnixListener`], which will be bound to the specified
46    /// file path. The file path cannot yet exist, and will be cleaned up
47    /// upon dropping [`UnixListener`].
48    pub async fn bind(path: impl AsRef<Path>) -> io::Result<Self> {
49        Self::bind_addr(&SockAddr::unix(path)?).await
50    }
51
52    /// Creates a new [`UnixListener`] with [`SockAddr`], which will be bound to
53    /// the specified file path. The file path cannot yet exist, and will be
54    /// cleaned up upon dropping [`UnixListener`].
55    pub async fn bind_addr(addr: &SockAddr) -> io::Result<Self> {
56        Self::bind_with_options(addr, &SocketOpts::default()).await
57    }
58
59    /// Creates a new [`UnixListener`] with [`SockAddr`] and [`SocketOpts`],
60    /// which will be bound to the specified file path. The file path cannot
61    /// yet exist, and will be cleaned up upon dropping [`UnixListener`].
62    pub async fn bind_with_options(addr: &SockAddr, opts: &SocketOpts) -> io::Result<Self> {
63        if !addr.is_unix() {
64            return Err(io::Error::new(
65                io::ErrorKind::InvalidInput,
66                "addr is not unix socket address",
67            ));
68        }
69
70        let socket = Socket::bind(addr, Type::STREAM, None).await?;
71        opts.setup_socket(&socket)?;
72        socket.listen(1024)?;
73        Ok(UnixListener { inner: socket })
74    }
75
76    #[cfg(unix)]
77    /// Creates new UnixListener from a [`std::os::unix::net::UnixListener`].
78    pub fn from_std(stream: std::os::unix::net::UnixListener) -> io::Result<Self> {
79        Ok(Self {
80            inner: Socket::from_socket2(Socket2::from(stream))?,
81        })
82    }
83
84    /// Close the socket. If the returned future is dropped before polling, the
85    /// socket won't be closed.
86    pub fn close(self) -> impl Future<Output = io::Result<()>> {
87        self.inner.close()
88    }
89
90    /// Accepts a new incoming connection from this listener.
91    ///
92    /// This function will yield once a new Unix domain socket connection
93    /// is established. When established, the corresponding [`UnixStream`] and
94    /// will be returned.
95    pub async fn accept(&self) -> io::Result<(UnixStream, SockAddr)> {
96        let (socket, addr) = self.inner.accept().await?;
97        let stream = UnixStream { inner: socket };
98        Ok((stream, addr))
99    }
100
101    /// Accepts a new incoming connection from this listener, and sets options.
102    ///
103    /// This function will yield once a new Unix domain socket connection
104    /// is established. When established, the corresponding [`UnixStream`] and
105    /// will be returned.
106    pub async fn accept_with_options(
107        &self,
108        options: &SocketOpts,
109    ) -> io::Result<(UnixStream, SockAddr)> {
110        let (socket, addr) = self.inner.accept().await?;
111        options.setup_socket(&socket)?;
112        let stream = UnixStream { inner: socket };
113        Ok((stream, addr))
114    }
115
116    /// Returns the local address that this listener is bound to.
117    pub fn local_addr(&self) -> io::Result<SockAddr> {
118        self.inner.local_addr()
119    }
120}
121
// Expands `compio_driver::impl_raw_fd!` for `UnixListener`, naming `socket2::Socket`
// as the raw type and `inner` as the field to go through.
// NOTE(review): the exact set of traits generated is defined by the macro in
// `compio_driver`, not in this file — see its documentation.
impl_raw_fd!(UnixListener, socket2::Socket, inner, socket);
123
/// A Unix domain socket stream between two local endpoints.
///
/// The type is compiled for both Unix and Windows targets; on Windows it
/// uses `AF_UNIX` sockets.
///
/// A Unix stream can either be created by connecting to an endpoint, via the
/// [`UnixStream::connect`] method, or by accepting a connection from a
/// [`UnixListener`].
///
/// # Examples
///
/// ```no_run
/// use compio_io::AsyncWrite;
/// use compio_net::UnixStream;
///
/// # compio_runtime::Runtime::new().unwrap().block_on(async {
/// // Connect to a peer
/// let mut stream = UnixStream::connect("unix-server.sock").await.unwrap();
///
/// // Write some data.
/// stream.write("hello world!").await.unwrap();
/// # })
/// ```
#[derive(Debug, Clone)]
pub struct UnixStream {
    // Crate-level socket wrapper; every stream method delegates to it.
    inner: Socket,
}
147
148impl UnixStream {
149    /// Opens a Unix connection to the specified file path. There must be a
150    /// [`UnixListener`] or equivalent listening on the corresponding Unix
151    /// domain socket to successfully connect and return a [`UnixStream`].
152    pub async fn connect(path: impl AsRef<Path>) -> io::Result<Self> {
153        Self::connect_addr(&SockAddr::unix(path)?).await
154    }
155
156    /// Opens a Unix connection to the specified address. There must be a
157    /// [`UnixListener`] or equivalent listening on the corresponding Unix
158    /// domain socket to successfully connect and return a [`UnixStream`].
159    pub async fn connect_addr(addr: &SockAddr) -> io::Result<Self> {
160        Self::connect_with_options(addr, &SocketOpts::default()).await
161    }
162
163    /// Opens a Unix connection to the specified address with [`SocketOpts`].
164    /// There must be a [`UnixListener`] or equivalent listening on the
165    /// corresponding Unix domain socket to successfully connect and return
166    /// a [`UnixStream`].
167    pub async fn connect_with_options(addr: &SockAddr, options: &SocketOpts) -> io::Result<Self> {
168        if !addr.is_unix() {
169            return Err(io::Error::new(
170                io::ErrorKind::InvalidInput,
171                "addr is not unix socket address",
172            ));
173        }
174
175        #[cfg(windows)]
176        let socket = {
177            let new_addr = empty_unix_socket();
178            Socket::bind(&new_addr, Type::STREAM, None).await?
179        };
180        #[cfg(unix)]
181        let socket = {
182            use socket2::Domain;
183            Socket::new(Domain::UNIX, Type::STREAM, None).await?
184        };
185        options.setup_socket(&socket)?;
186        socket.connect_async(addr).await?;
187        let unix_stream = UnixStream { inner: socket };
188        Ok(unix_stream)
189    }
190
191    #[cfg(unix)]
192    /// Creates new UnixStream from a [`std::os::unix::net::UnixStream`].
193    pub fn from_std(stream: std::os::unix::net::UnixStream) -> io::Result<Self> {
194        Ok(Self {
195            inner: Socket::from_socket2(Socket2::from(stream))?,
196        })
197    }
198
199    /// Close the socket. If the returned future is dropped before polling, the
200    /// socket won't be closed.
201    pub fn close(self) -> impl Future<Output = io::Result<()>> {
202        self.inner.close()
203    }
204
205    /// Returns the socket path of the remote peer of this connection.
206    pub fn peer_addr(&self) -> io::Result<SockAddr> {
207        #[allow(unused_mut)]
208        let mut addr = self.inner.peer_addr()?;
209        #[cfg(windows)]
210        {
211            fix_unix_socket_length(&mut addr);
212        }
213        Ok(addr)
214    }
215
216    /// Returns the socket path of the local half of this connection.
217    pub fn local_addr(&self) -> io::Result<SockAddr> {
218        self.inner.local_addr()
219    }
220
221    /// Splits a [`UnixStream`] into a read half and a write half, which can be
222    /// used to read and write the stream concurrently.
223    ///
224    /// This method is more efficient than
225    /// [`into_split`](UnixStream::into_split), but the halves cannot
226    /// be moved into independently spawned tasks.
227    pub fn split(&self) -> (ReadHalf<'_, Self>, WriteHalf<'_, Self>) {
228        crate::split(self)
229    }
230
231    /// Splits a [`UnixStream`] into a read half and a write half, which can be
232    /// used to read and write the stream concurrently.
233    ///
234    /// Unlike [`split`](UnixStream::split), the owned halves can be moved to
235    /// separate tasks, however this comes at the cost of a heap allocation.
236    pub fn into_split(self) -> (OwnedReadHalf<Self>, OwnedWriteHalf<Self>) {
237        crate::into_split(self)
238    }
239
240    /// Create [`PollFd`] from inner socket.
241    pub fn to_poll_fd(&self) -> io::Result<PollFd<Socket2>> {
242        self.inner.to_poll_fd()
243    }
244
245    /// Create [`PollFd`] from inner socket.
246    pub fn into_poll_fd(self) -> io::Result<PollFd<Socket2>> {
247        self.inner.into_poll_fd()
248    }
249}
250
251impl AsyncRead for UnixStream {
252    #[inline]
253    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
254        (&*self).read(buf).await
255    }
256
257    #[inline]
258    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
259        (&*self).read_vectored(buf).await
260    }
261}
262
263impl AsyncRead for &UnixStream {
264    #[inline]
265    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
266        self.inner.recv(buf, 0).await
267    }
268
269    #[inline]
270    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
271        self.inner.recv_vectored(buf, 0).await
272    }
273}
274
275impl AsyncReadManaged for UnixStream {
276    type Buffer<'a> = BorrowedBuffer<'a>;
277    type BufferPool = BufferPool;
278
279    async fn read_managed<'a>(
280        &mut self,
281        buffer_pool: &'a Self::BufferPool,
282        len: usize,
283    ) -> io::Result<Self::Buffer<'a>> {
284        (&*self).read_managed(buffer_pool, len).await
285    }
286}
287
288impl AsyncReadManaged for &UnixStream {
289    type Buffer<'a> = BorrowedBuffer<'a>;
290    type BufferPool = BufferPool;
291
292    async fn read_managed<'a>(
293        &mut self,
294        buffer_pool: &'a Self::BufferPool,
295        len: usize,
296    ) -> io::Result<Self::Buffer<'a>> {
297        self.inner.recv_managed(buffer_pool, len as _, 0).await
298    }
299}
300
301impl AsyncWrite for UnixStream {
302    #[inline]
303    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
304        (&*self).write(buf).await
305    }
306
307    #[inline]
308    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
309        (&*self).write_vectored(buf).await
310    }
311
312    #[inline]
313    async fn flush(&mut self) -> io::Result<()> {
314        (&*self).flush().await
315    }
316
317    #[inline]
318    async fn shutdown(&mut self) -> io::Result<()> {
319        (&*self).shutdown().await
320    }
321}
322
323impl AsyncWrite for &UnixStream {
324    #[inline]
325    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
326        self.inner.send(buf, 0).await
327    }
328
329    #[inline]
330    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
331        self.inner.send_vectored(buf, 0).await
332    }
333
334    #[inline]
335    async fn flush(&mut self) -> io::Result<()> {
336        Ok(())
337    }
338
339    #[inline]
340    async fn shutdown(&mut self) -> io::Result<()> {
341        self.inner.shutdown().await
342    }
343}
344
345impl Splittable for UnixStream {
346    type ReadHalf = OwnedReadHalf<Self>;
347    type WriteHalf = OwnedWriteHalf<Self>;
348
349    fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
350        crate::into_split(self)
351    }
352}
353
354impl<'a> Splittable for &'a UnixStream {
355    type ReadHalf = ReadHalf<'a, UnixStream>;
356    type WriteHalf = WriteHalf<'a, UnixStream>;
357
358    fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
359        crate::split(self)
360    }
361}
362
// Expands `compio_driver::impl_raw_fd!` for `UnixStream`, naming `socket2::Socket`
// as the raw type and `inner` as the field to go through.
// NOTE(review): the exact set of traits generated is defined by the macro in
// `compio_driver`, not in this file — see its documentation.
impl_raw_fd!(UnixStream, socket2::Socket, inner, socket);
364
/// Builds a Unix [`SockAddr`] whose path is empty (all zero bytes).
///
/// Used on Windows to bind a socket before connecting it.
#[cfg(windows)]
#[inline]
fn empty_unix_socket() -> SockAddr {
    use windows_sys::Win32::Networking::WinSock::{AF_UNIX, SOCKADDR_UN};

    // SAFETY: the storage handed to the closure is fully initialised with a
    // `SOCKADDR_UN` value, and the length is correct.
    let initialized = unsafe {
        SockAddr::try_init(|storage, len| {
            storage.cast::<SOCKADDR_UN>().write(SOCKADDR_UN {
                sun_family: AF_UNIX,
                sun_path: [0; 108],
            });
            // 3 matches the accounting in `fix_unix_socket_length`: 2 bytes
            // ahead of the path plus a single NUL path byte.
            len.write(3);
            Ok(())
        })
    };
    // it is always Ok
    initialized.unwrap().1
}
389
// The peer address returned after ConnectEx is buggy: it contains bytes that
// do not belong to the address. Luckily a Unix path should not contain `\0`
// before its end, so the first NUL byte tells us where the path stops.
/// Shrinks the recorded length of `addr` so it covers only the bytes up to
/// and including the first NUL of its path.
///
/// When the path holds no NUL byte at all, the length is set to the full
/// size of `SOCKADDR_UN`.
#[cfg(windows)]
#[inline]
fn fix_unix_socket_length(addr: &mut SockAddr) {
    use windows_sys::Win32::Networking::WinSock::SOCKADDR_UN;

    // Bytes counted ahead of the path; presumably the size of `sun_family`.
    const FAMILY_LEN: usize = 2;

    let raw: *const SOCKADDR_UN = addr.as_ptr().cast();
    // SAFETY: cannot construct non-unix socket address in safe way.
    let unix_addr = unsafe { &*raw };
    // SAFETY: pointer and length both come from the same `sun_path` array; the
    // cast only reinterprets its elements as bytes.
    let path_bytes = unsafe {
        std::slice::from_raw_parts(
            unix_addr.sun_path.as_ptr() as *const u8,
            unix_addr.sun_path.len(),
        )
    };
    let fixed_len = std::ffi::CStr::from_bytes_until_nul(path_bytes).map_or(
        std::mem::size_of::<SOCKADDR_UN>(),
        |path| path.to_bytes_with_nul().len() + FAMILY_LEN,
    );
    // SAFETY: `fixed_len` never exceeds `size_of::<SOCKADDR_UN>()`, which the
    // address storage is already being read as above.
    unsafe {
        addr.set_length(fixed_len as _);
    }
}
411        addr.set_length(addr_len as _);
412    }
413}