1use std::{
2 future::Future,
3 io,
4 net::SocketAddr,
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
10use compio_driver::impl_raw_fd;
11use compio_io::{AsyncRead, AsyncReadManaged, AsyncWrite, util::Splittable};
12use compio_runtime::{BorrowedBuffer, BufferPool, fd::PollFd};
13use futures_util::{Stream, StreamExt, stream::FusedStream};
14use socket2::{Protocol, SockAddr, Socket as Socket2, Type};
15
16use crate::{
17 Incoming, OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, SocketOpts, ToSocketAddrsAsync,
18 WriteHalf,
19};
20
21#[derive(Debug, Clone)]
54pub struct TcpListener {
55 inner: Socket,
56}
57
58impl TcpListener {
59 pub async fn bind(addr: impl ToSocketAddrsAsync) -> io::Result<Self> {
69 Self::bind_with_options(addr, &SocketOpts::default().reuse_address(true)).await
70 }
71
72 pub async fn bind_with_options(
80 addr: impl ToSocketAddrsAsync,
81 options: &SocketOpts,
82 ) -> io::Result<Self> {
83 super::each_addr(addr, |addr| async move {
84 let sa = SockAddr::from(addr);
85 let socket = Socket::new(sa.domain(), Type::STREAM, Some(Protocol::TCP)).await?;
86 options.setup_socket(&socket)?;
87 socket.socket.bind(&sa)?;
88 socket.listen(128)?;
89 Ok(Self { inner: socket })
90 })
91 .await
92 }
93
94 pub fn from_std(stream: std::net::TcpListener) -> io::Result<Self> {
96 Ok(Self {
97 inner: Socket::from_socket2(Socket2::from(stream))?,
98 })
99 }
100
101 pub fn close(self) -> impl Future<Output = io::Result<()>> {
104 self.inner.close()
105 }
106
107 pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
113 self.accept_with_options(&SocketOpts::default()).await
114 }
115
116 pub async fn accept_with_options(
122 &self,
123 options: &SocketOpts,
124 ) -> io::Result<(TcpStream, SocketAddr)> {
125 let (socket, addr) = self.inner.accept().await?;
126 options.setup_socket(&socket)?;
127 let stream = TcpStream { inner: socket };
128 Ok((stream, addr.as_socket().expect("should be SocketAddr")))
129 }
130
131 pub fn incoming(&self) -> TcpIncoming<'_> {
133 self.incoming_with_options(&SocketOpts::default())
134 }
135
136 pub fn incoming_with_options<'a>(&'a self, options: &SocketOpts) -> TcpIncoming<'a> {
139 TcpIncoming {
140 inner: self.inner.incoming(),
141 opts: *options,
142 }
143 }
144
145 pub fn local_addr(&self) -> io::Result<SocketAddr> {
169 self.inner
170 .local_addr()
171 .map(|addr| addr.as_socket().expect("should be SocketAddr"))
172 }
173}
174
// Generates raw fd/socket access for `TcpListener`, delegating to the
// `socket2::Socket` reached via the `inner` field's `socket` field.
// NOTE(review): the exact traits implemented are defined by
// `compio_driver::impl_raw_fd!` — confirm there.
impl_raw_fd!(TcpListener, socket2::Socket, inner, socket);
176
177pub struct TcpIncoming<'a> {
179 inner: Incoming<'a>,
180 opts: SocketOpts,
181}
182
183impl Stream for TcpIncoming<'_> {
184 type Item = io::Result<TcpStream>;
185
186 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
187 let this = self.get_mut();
188 this.inner.poll_next_unpin(cx).map(|res| {
189 res.map(|res| {
190 let socket = res?;
191 this.opts.setup_socket(&socket)?;
192 Ok(TcpStream { inner: socket })
193 })
194 })
195 }
196}
197
198impl FusedStream for TcpIncoming<'_> {
199 fn is_terminated(&self) -> bool {
200 self.inner.is_terminated()
201 }
202}
203
204#[derive(Debug, Clone)]
226pub struct TcpStream {
227 inner: Socket,
228}
229
230impl TcpStream {
231 pub async fn connect(addr: impl ToSocketAddrsAsync) -> io::Result<Self> {
233 Self::connect_with_options(addr, &SocketOpts::default()).await
234 }
235
236 pub async fn connect_with_options(
238 addr: impl ToSocketAddrsAsync,
239 options: &SocketOpts,
240 ) -> io::Result<Self> {
241 use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
242
243 super::each_addr(addr, |addr| async move {
244 let addr2 = SockAddr::from(addr);
245 let socket = if cfg!(windows) {
246 let bind_addr = if addr.is_ipv4() {
247 SockAddr::from(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
248 } else if addr.is_ipv6() {
249 SockAddr::from(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0))
250 } else {
251 return Err(io::Error::new(
252 io::ErrorKind::AddrNotAvailable,
253 "Unsupported address domain.",
254 ));
255 };
256 Socket::bind(&bind_addr, Type::STREAM, Some(Protocol::TCP)).await?
257 } else {
258 Socket::new(addr2.domain(), Type::STREAM, Some(Protocol::TCP)).await?
259 };
260 options.setup_socket(&socket)?;
261 socket.connect_async(&addr2).await?;
262 Ok(Self { inner: socket })
263 })
264 .await
265 }
266
267 pub async fn bind_and_connect(
269 bind_addr: SocketAddr,
270 addr: impl ToSocketAddrsAsync,
271 ) -> io::Result<Self> {
272 Self::bind_and_connect_with_options(bind_addr, addr, &SocketOpts::default()).await
273 }
274
275 pub async fn bind_and_connect_with_options(
278 bind_addr: SocketAddr,
279 addr: impl ToSocketAddrsAsync,
280 options: &SocketOpts,
281 ) -> io::Result<Self> {
282 super::each_addr(addr, |addr| async move {
283 let addr = SockAddr::from(addr);
284 let bind_addr = SockAddr::from(bind_addr);
285
286 let socket = Socket::bind(&bind_addr, Type::STREAM, Some(Protocol::TCP)).await?;
287 options.setup_socket(&socket)?;
288 socket.connect_async(&addr).await?;
289 Ok(Self { inner: socket })
290 })
291 .await
292 }
293
294 pub fn from_std(stream: std::net::TcpStream) -> io::Result<Self> {
296 Ok(Self {
297 inner: Socket::from_socket2(Socket2::from(stream))?,
298 })
299 }
300
301 pub fn close(self) -> impl Future<Output = io::Result<()>> {
304 self.inner.close()
305 }
306
307 pub fn peer_addr(&self) -> io::Result<SocketAddr> {
309 self.inner
310 .peer_addr()
311 .map(|addr| addr.as_socket().expect("should be SocketAddr"))
312 }
313
314 pub fn local_addr(&self) -> io::Result<SocketAddr> {
316 self.inner
317 .local_addr()
318 .map(|addr| addr.as_socket().expect("should be SocketAddr"))
319 }
320
321 pub fn split(&self) -> (ReadHalf<'_, Self>, WriteHalf<'_, Self>) {
328 crate::split(self)
329 }
330
331 pub fn into_split(self) -> (OwnedReadHalf<Self>, OwnedWriteHalf<Self>) {
337 crate::into_split(self)
338 }
339
340 pub fn to_poll_fd(&self) -> io::Result<PollFd<Socket2>> {
342 self.inner.to_poll_fd()
343 }
344
345 pub fn into_poll_fd(self) -> io::Result<PollFd<Socket2>> {
347 self.inner.into_poll_fd()
348 }
349
350 pub fn nodelay(&self) -> io::Result<bool> {
355 self.inner.socket.tcp_nodelay()
356 }
357
358 pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
366 self.inner.socket.set_tcp_nodelay(nodelay)
367 }
368
369 pub async fn send_out_of_band<T: IoBuf>(&self, buf: T) -> BufResult<usize, T> {
373 #[cfg(unix)]
374 use libc::MSG_OOB;
375 #[cfg(windows)]
376 use windows_sys::Win32::Networking::WinSock::MSG_OOB;
377
378 self.inner.send(buf, MSG_OOB).await
379 }
380
381 pub async fn send_zerocopy<T: IoBuf>(
386 &self,
387 buf: T,
388 flags: i32,
389 ) -> BufResult<usize, impl Future<Output = T> + use<T>> {
390 self.inner.send_zerocopy(buf, flags).await
391 }
392
393 pub async fn send_zerocopy_vectored<T: IoVectoredBuf>(
398 &self,
399 buf: T,
400 flags: i32,
401 ) -> BufResult<usize, impl Future<Output = T> + use<T>> {
402 self.inner.send_zerocopy_vectored(buf, flags).await
403 }
404}
405
406impl AsyncRead for TcpStream {
407 #[inline]
408 async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
409 (&*self).read(buf).await
410 }
411
412 #[inline]
413 async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
414 (&*self).read_vectored(buf).await
415 }
416}
417
418impl AsyncRead for &TcpStream {
419 #[inline]
420 async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
421 self.inner.recv(buf, 0).await
422 }
423
424 #[inline]
425 async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
426 self.inner.recv_vectored(buf, 0).await
427 }
428}
429
430impl AsyncReadManaged for TcpStream {
431 type Buffer<'a> = BorrowedBuffer<'a>;
432 type BufferPool = BufferPool;
433
434 async fn read_managed<'a>(
435 &mut self,
436 buffer_pool: &'a Self::BufferPool,
437 len: usize,
438 ) -> io::Result<Self::Buffer<'a>> {
439 (&*self).read_managed(buffer_pool, len).await
440 }
441}
442
443impl AsyncReadManaged for &TcpStream {
444 type Buffer<'a> = BorrowedBuffer<'a>;
445 type BufferPool = BufferPool;
446
447 async fn read_managed<'a>(
448 &mut self,
449 buffer_pool: &'a Self::BufferPool,
450 len: usize,
451 ) -> io::Result<Self::Buffer<'a>> {
452 self.inner.recv_managed(buffer_pool, len as _, 0).await
453 }
454}
455
456impl AsyncWrite for TcpStream {
457 #[inline]
458 async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
459 (&*self).write(buf).await
460 }
461
462 #[inline]
463 async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
464 (&*self).write_vectored(buf).await
465 }
466
467 #[inline]
468 async fn flush(&mut self) -> io::Result<()> {
469 (&*self).flush().await
470 }
471
472 #[inline]
473 async fn shutdown(&mut self) -> io::Result<()> {
474 (&*self).shutdown().await
475 }
476}
477
478impl AsyncWrite for &TcpStream {
479 #[inline]
480 async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
481 let BufResult(res, fut) = self.send_zerocopy(buf, 0).await;
482 BufResult(res, fut.await)
483 }
484
485 #[inline]
486 async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
487 let BufResult(res, fut) = self.send_zerocopy_vectored(buf, 0).await;
488 BufResult(res, fut.await)
489 }
490
491 #[inline]
492 async fn flush(&mut self) -> io::Result<()> {
493 Ok(())
494 }
495
496 #[inline]
497 async fn shutdown(&mut self) -> io::Result<()> {
498 self.inner.shutdown().await
499 }
500}
501
502impl Splittable for TcpStream {
503 type ReadHalf = OwnedReadHalf<Self>;
504 type WriteHalf = OwnedWriteHalf<Self>;
505
506 fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
507 crate::into_split(self)
508 }
509}
510
511impl<'a> Splittable for &'a TcpStream {
512 type ReadHalf = ReadHalf<'a, TcpStream>;
513 type WriteHalf = WriteHalf<'a, TcpStream>;
514
515 fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
516 crate::split(self)
517 }
518}
519
// Generates raw fd/socket access for `TcpStream`, delegating to the
// `socket2::Socket` reached via the `inner` field's `socket` field.
// NOTE(review): the exact traits implemented are defined by
// `compio_driver::impl_raw_fd!` — confirm there.
impl_raw_fd!(TcpStream, socket2::Socket, inner, socket);