Skip to main content

compio_net/
tcp.rs

1use std::{
2    future::Future,
3    io,
4    net::SocketAddr,
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
10use compio_driver::impl_raw_fd;
11use compio_io::{AsyncRead, AsyncReadManaged, AsyncWrite, util::Splittable};
12use compio_runtime::{BorrowedBuffer, BufferPool, fd::PollFd};
13use futures_util::{Stream, StreamExt, stream::FusedStream};
14use socket2::{Protocol, SockAddr, Socket as Socket2, Type};
15
16use crate::{
17    Incoming, OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, SocketOpts, ToSocketAddrsAsync,
18    WriteHalf,
19};
20
21/// A TCP socket server, listening for connections.
22///
23/// You can accept a new connection by using the
24/// [`accept`](`TcpListener::accept`) method.
25///
26/// # Examples
27///
28/// ```
29/// use std::net::SocketAddr;
30///
31/// use compio_io::{AsyncReadExt, AsyncWriteExt};
32/// use compio_net::{TcpListener, TcpStream};
33/// use socket2::SockAddr;
34///
35/// # compio_runtime::Runtime::new().unwrap().block_on(async move {
36/// let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
37///
38/// let addr = listener.local_addr().unwrap();
39///
40/// let tx_fut = TcpStream::connect(&addr);
41///
42/// let rx_fut = listener.accept();
43///
44/// let (mut tx, (mut rx, _)) = futures_util::try_join!(tx_fut, rx_fut).unwrap();
45///
46/// tx.write_all("test").await.0.unwrap();
47///
48/// let (_, buf) = rx.read_exact(Vec::with_capacity(4)).await.unwrap();
49///
50/// assert_eq!(buf, b"test");
51/// # });
52/// ```
53#[derive(Debug, Clone)]
54pub struct TcpListener {
55    inner: Socket,
56}
57
58impl TcpListener {
59    /// Creates a new `TcpListener`, which will be bound to the specified
60    /// address.
61    ///
62    /// The returned listener is ready for accepting connections.
63    ///
64    /// Binding with a port number of 0 will request that the OS assigns a port
65    /// to this listener.
66    ///
67    /// It enables the `SO_REUSEADDR` option by default.
68    pub async fn bind(addr: impl ToSocketAddrsAsync) -> io::Result<Self> {
69        Self::bind_with_options(addr, &SocketOpts::default().reuse_address(true)).await
70    }
71
72    /// Creates a new `TcpListener`, which will be bound to the specified
73    /// address using `SocketOpts`.
74    ///
75    /// The returned listener is ready for accepting connections.
76    ///
77    /// Binding with a port number of 0 will request that the OS assigns a port
78    /// to this listener.
79    pub async fn bind_with_options(
80        addr: impl ToSocketAddrsAsync,
81        options: &SocketOpts,
82    ) -> io::Result<Self> {
83        super::each_addr(addr, |addr| async move {
84            let sa = SockAddr::from(addr);
85            let socket = Socket::new(sa.domain(), Type::STREAM, Some(Protocol::TCP)).await?;
86            options.setup_socket(&socket)?;
87            socket.socket.bind(&sa)?;
88            socket.listen(128)?;
89            Ok(Self { inner: socket })
90        })
91        .await
92    }
93
94    /// Creates new TcpListener from a [`std::net::TcpListener`].
95    pub fn from_std(stream: std::net::TcpListener) -> io::Result<Self> {
96        Ok(Self {
97            inner: Socket::from_socket2(Socket2::from(stream))?,
98        })
99    }
100
101    /// Close the socket. If the returned future is dropped before polling, the
102    /// socket won't be closed.
103    pub fn close(self) -> impl Future<Output = io::Result<()>> {
104        self.inner.close()
105    }
106
107    /// Accepts a new incoming connection from this listener.
108    ///
109    /// This function will yield once a new TCP connection is established. When
110    /// established, the corresponding [`TcpStream`] and the remote peer's
111    /// address will be returned.
112    pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
113        self.accept_with_options(&SocketOpts::default()).await
114    }
115
116    /// Accepts a new incoming connection from this listener, and sets options.
117    ///
118    /// This function will yield once a new TCP connection is established. When
119    /// established, the corresponding [`TcpStream`] and the remote peer's
120    /// address will be returned.
121    pub async fn accept_with_options(
122        &self,
123        options: &SocketOpts,
124    ) -> io::Result<(TcpStream, SocketAddr)> {
125        let (socket, addr) = self.inner.accept().await?;
126        options.setup_socket(&socket)?;
127        let stream = TcpStream { inner: socket };
128        Ok((stream, addr.as_socket().expect("should be SocketAddr")))
129    }
130
131    /// Returns a stream of incoming connections to this listener.
132    pub fn incoming(&self) -> TcpIncoming<'_> {
133        self.incoming_with_options(&SocketOpts::default())
134    }
135
136    /// Returns a stream of incoming connections to this listener, and sets
137    /// options for each accepted connection.
138    pub fn incoming_with_options<'a>(&'a self, options: &SocketOpts) -> TcpIncoming<'a> {
139        TcpIncoming {
140            inner: self.inner.incoming(),
141            opts: *options,
142        }
143    }
144
145    /// Returns the local address that this listener is bound to.
146    ///
147    /// This can be useful, for example, when binding to port 0 to
148    /// figure out which port was actually bound.
149    ///
150    /// # Examples
151    ///
152    /// ```
153    /// use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
154    ///
155    /// use compio_net::TcpListener;
156    /// use socket2::SockAddr;
157    ///
158    /// # compio_runtime::Runtime::new().unwrap().block_on(async {
159    /// let listener = TcpListener::bind("127.0.0.1:8080").await.unwrap();
160    ///
161    /// let addr = listener.local_addr().expect("Couldn't get local address");
162    /// assert_eq!(
163    ///     addr,
164    ///     SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 8080))
165    /// );
166    /// # });
167    /// ```
168    pub fn local_addr(&self) -> io::Result<SocketAddr> {
169        self.inner
170            .local_addr()
171            .map(|addr| addr.as_socket().expect("should be SocketAddr"))
172    }
173}
174
// NOTE(review): presumably implements the platform raw-handle traits for
// `TcpListener` (e.g. `AsRawFd` / `AsRawSocket`) by delegating to the
// `socket2::Socket` stored at `self.inner.socket`; confirm against
// `compio_driver::impl_raw_fd`.
impl_raw_fd!(TcpListener, socket2::Socket, inner, socket);
176
177/// A stream of incoming TCP connections.
178pub struct TcpIncoming<'a> {
179    inner: Incoming<'a>,
180    opts: SocketOpts,
181}
182
183impl Stream for TcpIncoming<'_> {
184    type Item = io::Result<TcpStream>;
185
186    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
187        let this = self.get_mut();
188        this.inner.poll_next_unpin(cx).map(|res| {
189            res.map(|res| {
190                let socket = res?;
191                this.opts.setup_socket(&socket)?;
192                Ok(TcpStream { inner: socket })
193            })
194        })
195    }
196}
197
198impl FusedStream for TcpIncoming<'_> {
199    fn is_terminated(&self) -> bool {
200        self.inner.is_terminated()
201    }
202}
203
204/// A TCP stream between a local and a remote socket.
205///
206/// A TCP stream can either be created by connecting to an endpoint, via the
207/// `connect` method, or by accepting a connection from a listener.
208///
209/// # Examples
210///
211/// ```no_run
212/// use std::net::SocketAddr;
213///
214/// use compio_io::AsyncWrite;
215/// use compio_net::TcpStream;
216///
217/// # compio_runtime::Runtime::new().unwrap().block_on(async {
218/// // Connect to a peer
219/// let mut stream = TcpStream::connect("127.0.0.1:8080").await.unwrap();
220///
221/// // Write some data.
222/// stream.write("hello world!").await.unwrap();
223/// # })
224/// ```
225#[derive(Debug, Clone)]
226pub struct TcpStream {
227    inner: Socket,
228}
229
230impl TcpStream {
231    /// Opens a TCP connection to a remote host.
232    pub async fn connect(addr: impl ToSocketAddrsAsync) -> io::Result<Self> {
233        Self::connect_with_options(addr, &SocketOpts::default()).await
234    }
235
236    /// Opens a TCP connection to a remote host using `SocketOpts`.
237    pub async fn connect_with_options(
238        addr: impl ToSocketAddrsAsync,
239        options: &SocketOpts,
240    ) -> io::Result<Self> {
241        use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
242
243        super::each_addr(addr, |addr| async move {
244            let addr2 = SockAddr::from(addr);
245            let socket = if cfg!(windows) {
246                let bind_addr = if addr.is_ipv4() {
247                    SockAddr::from(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
248                } else if addr.is_ipv6() {
249                    SockAddr::from(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0))
250                } else {
251                    return Err(io::Error::new(
252                        io::ErrorKind::AddrNotAvailable,
253                        "Unsupported address domain.",
254                    ));
255                };
256                Socket::bind(&bind_addr, Type::STREAM, Some(Protocol::TCP)).await?
257            } else {
258                Socket::new(addr2.domain(), Type::STREAM, Some(Protocol::TCP)).await?
259            };
260            options.setup_socket(&socket)?;
261            socket.connect_async(&addr2).await?;
262            Ok(Self { inner: socket })
263        })
264        .await
265    }
266
267    /// Bind to `bind_addr` then opens a TCP connection to a remote host.
268    pub async fn bind_and_connect(
269        bind_addr: SocketAddr,
270        addr: impl ToSocketAddrsAsync,
271    ) -> io::Result<Self> {
272        Self::bind_and_connect_with_options(bind_addr, addr, &SocketOpts::default()).await
273    }
274
275    /// Bind to `bind_addr` then opens a TCP connection to a remote host using
276    /// `SocketOpts`.
277    pub async fn bind_and_connect_with_options(
278        bind_addr: SocketAddr,
279        addr: impl ToSocketAddrsAsync,
280        options: &SocketOpts,
281    ) -> io::Result<Self> {
282        super::each_addr(addr, |addr| async move {
283            let addr = SockAddr::from(addr);
284            let bind_addr = SockAddr::from(bind_addr);
285
286            let socket = Socket::bind(&bind_addr, Type::STREAM, Some(Protocol::TCP)).await?;
287            options.setup_socket(&socket)?;
288            socket.connect_async(&addr).await?;
289            Ok(Self { inner: socket })
290        })
291        .await
292    }
293
294    /// Creates new TcpStream from a [`std::net::TcpStream`].
295    pub fn from_std(stream: std::net::TcpStream) -> io::Result<Self> {
296        Ok(Self {
297            inner: Socket::from_socket2(Socket2::from(stream))?,
298        })
299    }
300
301    /// Close the socket. If the returned future is dropped before polling, the
302    /// socket won't be closed.
303    pub fn close(self) -> impl Future<Output = io::Result<()>> {
304        self.inner.close()
305    }
306
307    /// Returns the socket address of the remote peer of this TCP connection.
308    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
309        self.inner
310            .peer_addr()
311            .map(|addr| addr.as_socket().expect("should be SocketAddr"))
312    }
313
314    /// Returns the socket address of the local half of this TCP connection.
315    pub fn local_addr(&self) -> io::Result<SocketAddr> {
316        self.inner
317            .local_addr()
318            .map(|addr| addr.as_socket().expect("should be SocketAddr"))
319    }
320
321    /// Splits a [`TcpStream`] into a read half and a write half, which can be
322    /// used to read and write the stream concurrently.
323    ///
324    /// This method is more efficient than
325    /// [`into_split`](TcpStream::into_split), but the halves cannot
326    /// be moved into independently spawned tasks.
327    pub fn split(&self) -> (ReadHalf<'_, Self>, WriteHalf<'_, Self>) {
328        crate::split(self)
329    }
330
331    /// Splits a [`TcpStream`] into a read half and a write half, which can be
332    /// used to read and write the stream concurrently.
333    ///
334    /// Unlike [`split`](TcpStream::split), the owned halves can be moved to
335    /// separate tasks, however this comes at the cost of a heap allocation.
336    pub fn into_split(self) -> (OwnedReadHalf<Self>, OwnedWriteHalf<Self>) {
337        crate::into_split(self)
338    }
339
340    /// Create [`PollFd`] from inner socket.
341    pub fn to_poll_fd(&self) -> io::Result<PollFd<Socket2>> {
342        self.inner.to_poll_fd()
343    }
344
345    /// Create [`PollFd`] from inner socket.
346    pub fn into_poll_fd(self) -> io::Result<PollFd<Socket2>> {
347        self.inner.into_poll_fd()
348    }
349
350    /// Gets the value of the `TCP_NODELAY` option on this socket.
351    ///
352    /// For more information about this option, see
353    /// [`TcpStream::set_nodelay`].
354    pub fn nodelay(&self) -> io::Result<bool> {
355        self.inner.socket.tcp_nodelay()
356    }
357
358    /// Sets the value of the TCP_NODELAY option on this socket.
359    ///
360    /// If set, this option disables the Nagle algorithm. This means
361    /// that segments are always sent as soon as possible, even if
362    /// there is only a small amount of data. When not set, data is
363    /// buffered until there is a sufficient amount to send out,
364    /// thereby avoiding the frequent sending of small packets.
365    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
366        self.inner.socket.set_tcp_nodelay(nodelay)
367    }
368
369    /// Sends out-of-band data on this socket.
370    ///
371    /// Out-of-band data is sent with the `MSG_OOB` flag.
372    pub async fn send_out_of_band<T: IoBuf>(&self, buf: T) -> BufResult<usize, T> {
373        #[cfg(unix)]
374        use libc::MSG_OOB;
375        #[cfg(windows)]
376        use windows_sys::Win32::Networking::WinSock::MSG_OOB;
377
378        self.inner.send(buf, MSG_OOB).await
379    }
380
381    /// Sends data using [zero-copy send](https://man7.org/linux/man-pages/man3/io_uring_prep_send_zc.3.html).
382    ///
383    /// If the underlying platform doesn't support zero-copy send, it will fall
384    /// back to normal send.
385    pub async fn send_zerocopy<T: IoBuf>(
386        &self,
387        buf: T,
388        flags: i32,
389    ) -> BufResult<usize, impl Future<Output = T> + use<T>> {
390        self.inner.send_zerocopy(buf, flags).await
391    }
392
393    /// Sends vectorized data using [zero-copy send](https://man7.org/linux/man-pages/man3/io_uring_prep_send_zc.3.html).
394    ///
395    /// If the underlying platform doesn't support zero-copy send, it will fall
396    /// back to normal send.
397    pub async fn send_zerocopy_vectored<T: IoVectoredBuf>(
398        &self,
399        buf: T,
400        flags: i32,
401    ) -> BufResult<usize, impl Future<Output = T> + use<T>> {
402        self.inner.send_zerocopy_vectored(buf, flags).await
403    }
404}
405
406impl AsyncRead for TcpStream {
407    #[inline]
408    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
409        (&*self).read(buf).await
410    }
411
412    #[inline]
413    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
414        (&*self).read_vectored(buf).await
415    }
416}
417
418impl AsyncRead for &TcpStream {
419    #[inline]
420    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
421        self.inner.recv(buf, 0).await
422    }
423
424    #[inline]
425    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
426        self.inner.recv_vectored(buf, 0).await
427    }
428}
429
430impl AsyncReadManaged for TcpStream {
431    type Buffer<'a> = BorrowedBuffer<'a>;
432    type BufferPool = BufferPool;
433
434    async fn read_managed<'a>(
435        &mut self,
436        buffer_pool: &'a Self::BufferPool,
437        len: usize,
438    ) -> io::Result<Self::Buffer<'a>> {
439        (&*self).read_managed(buffer_pool, len).await
440    }
441}
442
443impl AsyncReadManaged for &TcpStream {
444    type Buffer<'a> = BorrowedBuffer<'a>;
445    type BufferPool = BufferPool;
446
447    async fn read_managed<'a>(
448        &mut self,
449        buffer_pool: &'a Self::BufferPool,
450        len: usize,
451    ) -> io::Result<Self::Buffer<'a>> {
452        self.inner.recv_managed(buffer_pool, len as _, 0).await
453    }
454}
455
456impl AsyncWrite for TcpStream {
457    #[inline]
458    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
459        (&*self).write(buf).await
460    }
461
462    #[inline]
463    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
464        (&*self).write_vectored(buf).await
465    }
466
467    #[inline]
468    async fn flush(&mut self) -> io::Result<()> {
469        (&*self).flush().await
470    }
471
472    #[inline]
473    async fn shutdown(&mut self) -> io::Result<()> {
474        (&*self).shutdown().await
475    }
476}
477
478impl AsyncWrite for &TcpStream {
479    #[inline]
480    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
481        let BufResult(res, fut) = self.send_zerocopy(buf, 0).await;
482        BufResult(res, fut.await)
483    }
484
485    #[inline]
486    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
487        let BufResult(res, fut) = self.send_zerocopy_vectored(buf, 0).await;
488        BufResult(res, fut.await)
489    }
490
491    #[inline]
492    async fn flush(&mut self) -> io::Result<()> {
493        Ok(())
494    }
495
496    #[inline]
497    async fn shutdown(&mut self) -> io::Result<()> {
498        self.inner.shutdown().await
499    }
500}
501
502impl Splittable for TcpStream {
503    type ReadHalf = OwnedReadHalf<Self>;
504    type WriteHalf = OwnedWriteHalf<Self>;
505
506    fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
507        crate::into_split(self)
508    }
509}
510
511impl<'a> Splittable for &'a TcpStream {
512    type ReadHalf = ReadHalf<'a, TcpStream>;
513    type WriteHalf = WriteHalf<'a, TcpStream>;
514
515    fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
516        crate::split(self)
517    }
518}
519
// NOTE(review): presumably implements the platform raw-handle traits for
// `TcpStream` (e.g. `AsRawFd` / `AsRawSocket`) by delegating to the
// `socket2::Socket` stored at `self.inner.socket`; confirm against
// `compio_driver::impl_raw_fd`.
impl_raw_fd!(TcpStream, socket2::Socket, inner, socket);