1use std::{
2 future::Future,
3 io,
4 path::Path,
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
10use compio_driver::impl_raw_fd;
11use compio_io::{AsyncRead, AsyncReadManaged, AsyncWrite, util::Splittable};
12use compio_runtime::{BorrowedBuffer, BufferPool, fd::PollFd};
13use futures_util::{Stream, StreamExt, stream::FusedStream};
14use socket2::{SockAddr, Socket as Socket2, Type};
15
16use crate::{Incoming, OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, SocketOpts, WriteHalf};
17
/// A Unix-domain stream socket server that listens for incoming connections.
///
/// Create one with [`UnixListener::bind`] (or [`UnixListener::bind_addr`] /
/// [`UnixListener::bind_with_options`]) and obtain connections with
/// [`UnixListener::accept`] or [`UnixListener::incoming`].
#[derive(Debug, Clone)]
pub struct UnixListener {
    /// The underlying listening socket.
    inner: Socket,
}
50
51impl UnixListener {
52 pub async fn bind(path: impl AsRef<Path>) -> io::Result<Self> {
56 Self::bind_addr(&SockAddr::unix(path)?).await
57 }
58
59 pub async fn bind_addr(addr: &SockAddr) -> io::Result<Self> {
63 Self::bind_with_options(addr, &SocketOpts::default()).await
64 }
65
66 pub async fn bind_with_options(addr: &SockAddr, opts: &SocketOpts) -> io::Result<Self> {
70 if !addr.is_unix() {
71 return Err(io::Error::new(
72 io::ErrorKind::InvalidInput,
73 "addr is not unix socket address",
74 ));
75 }
76
77 let socket = Socket::bind(addr, Type::STREAM, None).await?;
78 opts.setup_socket(&socket)?;
79 socket.listen(1024)?;
80 Ok(UnixListener { inner: socket })
81 }
82
83 #[cfg(unix)]
84 pub fn from_std(stream: std::os::unix::net::UnixListener) -> io::Result<Self> {
86 Ok(Self {
87 inner: Socket::from_socket2(Socket2::from(stream))?,
88 })
89 }
90
91 pub fn close(self) -> impl Future<Output = io::Result<()>> {
94 self.inner.close()
95 }
96
97 pub async fn accept(&self) -> io::Result<(UnixStream, SockAddr)> {
103 let (socket, addr) = self.inner.accept().await?;
104 let stream = UnixStream { inner: socket };
105 Ok((stream, addr))
106 }
107
108 pub async fn accept_with_options(
114 &self,
115 options: &SocketOpts,
116 ) -> io::Result<(UnixStream, SockAddr)> {
117 let (socket, addr) = self.inner.accept().await?;
118 options.setup_socket(&socket)?;
119 let stream = UnixStream { inner: socket };
120 Ok((stream, addr))
121 }
122
123 pub fn incoming(&self) -> UnixIncoming<'_> {
125 self.incoming_with_options(&SocketOpts::default())
126 }
127
128 pub fn incoming_with_options<'a>(&'a self, options: &SocketOpts) -> UnixIncoming<'a> {
131 UnixIncoming {
132 inner: self.inner.incoming(),
133 opts: *options,
134 }
135 }
136
137 pub fn local_addr(&self) -> io::Result<SockAddr> {
139 self.inner.local_addr()
140 }
141}
142
// Raw descriptor/handle access for `UnixListener`, delegated to its `inner`
// field. NOTE(review): the exact set of traits generated is defined by
// `compio_driver::impl_raw_fd` (presumably the raw fd traits on unix and the
// raw socket traits on Windows). Confirm there.
impl_raw_fd!(UnixListener, socket2::Socket, inner, socket);
144
/// A [`Stream`] of connections accepted by a [`UnixListener`].
///
/// Returned by [`UnixListener::incoming`] and
/// [`UnixListener::incoming_with_options`]. Every accepted socket is configured
/// with the stored [`SocketOpts`] before being yielded as a [`UnixStream`].
pub struct UnixIncoming<'a> {
    /// Accept stream borrowed from the listener's socket.
    inner: Incoming<'a>,
    /// Options applied to each accepted socket.
    opts: SocketOpts,
}
150
151impl Stream for UnixIncoming<'_> {
152 type Item = io::Result<UnixStream>;
153
154 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
155 let this = self.get_mut();
156 this.inner.poll_next_unpin(cx).map(|res| {
157 res.map(|res| {
158 let socket = res?;
159 this.opts.setup_socket(&socket)?;
160 Ok(UnixStream { inner: socket })
161 })
162 })
163 }
164}
165
166impl FusedStream for UnixIncoming<'_> {
167 fn is_terminated(&self) -> bool {
168 self.inner.is_terminated()
169 }
170}
171
/// A connected Unix-domain stream socket.
///
/// Obtain one from [`UnixStream::connect`] (or its `_addr` / `_with_options`
/// variants) or from [`UnixListener::accept`]. [`AsyncRead`] and
/// [`AsyncWrite`] are implemented both for `UnixStream` and for
/// `&UnixStream`, so a shared reference can also be used for I/O.
#[derive(Debug, Clone)]
pub struct UnixStream {
    /// The underlying connected socket.
    inner: Socket,
}
195
196impl UnixStream {
197 pub async fn connect(path: impl AsRef<Path>) -> io::Result<Self> {
201 Self::connect_addr(&SockAddr::unix(path)?).await
202 }
203
204 pub async fn connect_addr(addr: &SockAddr) -> io::Result<Self> {
208 Self::connect_with_options(addr, &SocketOpts::default()).await
209 }
210
211 pub async fn connect_with_options(addr: &SockAddr, options: &SocketOpts) -> io::Result<Self> {
216 if !addr.is_unix() {
217 return Err(io::Error::new(
218 io::ErrorKind::InvalidInput,
219 "addr is not unix socket address",
220 ));
221 }
222
223 #[cfg(windows)]
224 let socket = {
225 let new_addr = empty_unix_socket();
226 Socket::bind(&new_addr, Type::STREAM, None).await?
227 };
228 #[cfg(unix)]
229 let socket = {
230 use socket2::Domain;
231 Socket::new(Domain::UNIX, Type::STREAM, None).await?
232 };
233 options.setup_socket(&socket)?;
234 socket.connect_async(addr).await?;
235 let unix_stream = UnixStream { inner: socket };
236 Ok(unix_stream)
237 }
238
239 #[cfg(unix)]
240 pub fn from_std(stream: std::os::unix::net::UnixStream) -> io::Result<Self> {
242 Ok(Self {
243 inner: Socket::from_socket2(Socket2::from(stream))?,
244 })
245 }
246
247 pub fn close(self) -> impl Future<Output = io::Result<()>> {
250 self.inner.close()
251 }
252
253 pub fn peer_addr(&self) -> io::Result<SockAddr> {
255 #[allow(unused_mut)]
256 let mut addr = self.inner.peer_addr()?;
257 #[cfg(windows)]
258 {
259 fix_unix_socket_length(&mut addr);
260 }
261 Ok(addr)
262 }
263
264 pub fn local_addr(&self) -> io::Result<SockAddr> {
266 self.inner.local_addr()
267 }
268
269 pub fn split(&self) -> (ReadHalf<'_, Self>, WriteHalf<'_, Self>) {
276 crate::split(self)
277 }
278
279 pub fn into_split(self) -> (OwnedReadHalf<Self>, OwnedWriteHalf<Self>) {
285 crate::into_split(self)
286 }
287
288 pub fn to_poll_fd(&self) -> io::Result<PollFd<Socket2>> {
290 self.inner.to_poll_fd()
291 }
292
293 pub fn into_poll_fd(self) -> io::Result<PollFd<Socket2>> {
295 self.inner.into_poll_fd()
296 }
297}
298
299impl AsyncRead for UnixStream {
300 #[inline]
301 async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
302 (&*self).read(buf).await
303 }
304
305 #[inline]
306 async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
307 (&*self).read_vectored(buf).await
308 }
309}
310
311impl AsyncRead for &UnixStream {
312 #[inline]
313 async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
314 self.inner.recv(buf, 0).await
315 }
316
317 #[inline]
318 async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
319 self.inner.recv_vectored(buf, 0).await
320 }
321}
322
323impl AsyncReadManaged for UnixStream {
324 type Buffer<'a> = BorrowedBuffer<'a>;
325 type BufferPool = BufferPool;
326
327 async fn read_managed<'a>(
328 &mut self,
329 buffer_pool: &'a Self::BufferPool,
330 len: usize,
331 ) -> io::Result<Self::Buffer<'a>> {
332 (&*self).read_managed(buffer_pool, len).await
333 }
334}
335
336impl AsyncReadManaged for &UnixStream {
337 type Buffer<'a> = BorrowedBuffer<'a>;
338 type BufferPool = BufferPool;
339
340 async fn read_managed<'a>(
341 &mut self,
342 buffer_pool: &'a Self::BufferPool,
343 len: usize,
344 ) -> io::Result<Self::Buffer<'a>> {
345 self.inner.recv_managed(buffer_pool, len as _, 0).await
346 }
347}
348
349impl AsyncWrite for UnixStream {
350 #[inline]
351 async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
352 (&*self).write(buf).await
353 }
354
355 #[inline]
356 async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
357 (&*self).write_vectored(buf).await
358 }
359
360 #[inline]
361 async fn flush(&mut self) -> io::Result<()> {
362 (&*self).flush().await
363 }
364
365 #[inline]
366 async fn shutdown(&mut self) -> io::Result<()> {
367 (&*self).shutdown().await
368 }
369}
370
371impl AsyncWrite for &UnixStream {
372 #[inline]
373 async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
374 self.inner.send(buf, 0).await
375 }
376
377 #[inline]
378 async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
379 self.inner.send_vectored(buf, 0).await
380 }
381
382 #[inline]
383 async fn flush(&mut self) -> io::Result<()> {
384 Ok(())
385 }
386
387 #[inline]
388 async fn shutdown(&mut self) -> io::Result<()> {
389 self.inner.shutdown().await
390 }
391}
392
393impl Splittable for UnixStream {
394 type ReadHalf = OwnedReadHalf<Self>;
395 type WriteHalf = OwnedWriteHalf<Self>;
396
397 fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
398 crate::into_split(self)
399 }
400}
401
402impl<'a> Splittable for &'a UnixStream {
403 type ReadHalf = ReadHalf<'a, UnixStream>;
404 type WriteHalf = WriteHalf<'a, UnixStream>;
405
406 fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
407 crate::split(self)
408 }
409}
410
// Raw descriptor/handle access for `UnixStream`, delegated to its `inner`
// field. NOTE(review): the exact set of traits generated is defined by
// `compio_driver::impl_raw_fd` (presumably the raw fd traits on unix and the
// raw socket traits on Windows). Confirm there.
impl_raw_fd!(UnixStream, socket2::Socket, inner, socket);
412
/// Builds an unnamed Windows `AF_UNIX` socket address.
///
/// The address has `sun_family` set to `AF_UNIX`, an all-zero `sun_path`, and
/// a length covering the bytes before `sun_path` plus a single NUL byte.
/// `UnixStream::connect_with_options` binds its client socket to this address
/// on Windows before connecting.
#[cfg(windows)]
#[inline]
fn empty_unix_socket() -> SockAddr {
    use windows_sys::Win32::Networking::WinSock::{AF_UNIX, SOCKADDR_UN};

    // An empty path: everything preceding `sun_path` plus the terminating NUL.
    // Derived from the struct layout instead of the former hard-coded `3`.
    let empty_len = std::mem::offset_of!(SOCKADDR_UN, sun_path) + 1;

    // SAFETY: the closure fully initializes a `SOCKADDR_UN` at the start of the
    // storage provided by `try_init`. It reports a length (`empty_len`) that is
    // no larger than `size_of::<SOCKADDR_UN>()`, so every counted byte is
    // initialized.
    unsafe {
        SockAddr::try_init(|addr, len| {
            let addr: *mut SOCKADDR_UN = addr.cast();
            std::ptr::write(
                addr,
                SOCKADDR_UN {
                    sun_family: AF_UNIX,
                    sun_path: [0; 108],
                },
            );
            std::ptr::write(len, empty_len as _);
            Ok(())
        })
    }
    // The closure above never returns `Err`, so `try_init` cannot fail.
    .expect("initializing an empty AF_UNIX address cannot fail")
    .1
}
437
/// Recomputes the stored length of a Windows `AF_UNIX` socket address.
///
/// `UnixStream::peer_addr` passes the address it obtains from the socket
/// through this function on Windows. The length is set to the bytes preceding
/// `sun_path` plus the path up to and including its first NUL byte. If
/// `sun_path` contains no NUL, the full `SOCKADDR_UN` size is used.
#[cfg(windows)]
#[inline]
fn fix_unix_socket_length(addr: &mut SockAddr) {
    use windows_sys::Win32::Networking::WinSock::SOCKADDR_UN;

    // SAFETY: socket2 backs `SockAddr` with a `sockaddr_storage`, which is at
    // least as large and as strictly aligned as `SOCKADDR_UN`.
    let unix_addr: &SOCKADDR_UN = unsafe { &*addr.as_ptr().cast() };
    // SAFETY: `sun_path` is an array of byte-sized C chars. Viewing it as `u8`
    // with the same element count stays inside that array.
    let sun_path = unsafe {
        std::slice::from_raw_parts(
            unix_addr.sun_path.as_ptr() as *const u8,
            unix_addr.sun_path.len(),
        )
    };
    // Number of bytes before `sun_path` (the `sun_family` field). This is
    // taken from the struct layout instead of the former hard-coded `2`.
    let path_offset = std::mem::offset_of!(SOCKADDR_UN, sun_path);
    let addr_len = match std::ffi::CStr::from_bytes_until_nul(sun_path) {
        Ok(path) => path_offset + path.to_bytes_with_nul().len(),
        Err(_) => std::mem::size_of::<SOCKADDR_UN>(),
    };
    // SAFETY: `addr_len` never exceeds `size_of::<SOCKADDR_UN>()`. It is either
    // exactly that size, or the offset of `sun_path` plus at most
    // `sun_path.len()` bytes. It therefore stays within the address storage.
    unsafe {
        addr.set_length(addr_len as _);
    }
}
461}