compio_net/
tcp.rs

1use std::{future::Future, io, net::SocketAddr};
2
3use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
4use compio_driver::impl_raw_fd;
5use compio_io::{AsyncRead, AsyncReadManaged, AsyncWrite, util::Splittable};
6use compio_runtime::{BorrowedBuffer, BufferPool};
7use socket2::{Protocol, SockAddr, Socket as Socket2, Type};
8
9use crate::{
10    OwnedReadHalf, OwnedWriteHalf, PollFd, ReadHalf, Socket, SocketOpts, ToSocketAddrsAsync,
11    WriteHalf,
12};
13
14/// A TCP socket server, listening for connections.
15///
16/// You can accept a new connection by using the
17/// [`accept`](`TcpListener::accept`) method.
18///
19/// # Examples
20///
21/// ```
22/// use std::net::SocketAddr;
23///
24/// use compio_io::{AsyncReadExt, AsyncWriteExt};
25/// use compio_net::{TcpListener, TcpStream};
26/// use socket2::SockAddr;
27///
28/// # compio_runtime::Runtime::new().unwrap().block_on(async move {
29/// let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
30///
31/// let addr = listener.local_addr().unwrap();
32///
33/// let tx_fut = TcpStream::connect(&addr);
34///
35/// let rx_fut = listener.accept();
36///
37/// let (mut tx, (mut rx, _)) = futures_util::try_join!(tx_fut, rx_fut).unwrap();
38///
39/// tx.write_all("test").await.0.unwrap();
40///
41/// let (_, buf) = rx.read_exact(Vec::with_capacity(4)).await.unwrap();
42///
43/// assert_eq!(buf, b"test");
44/// # });
45/// ```
46#[derive(Debug, Clone)]
47pub struct TcpListener {
48    inner: Socket,
49}
50
51impl TcpListener {
52    /// Creates a new `TcpListener`, which will be bound to the specified
53    /// address.
54    ///
55    /// The returned listener is ready for accepting connections.
56    ///
57    /// Binding with a port number of 0 will request that the OS assigns a port
58    /// to this listener.
59    ///
60    /// It enables the `SO_REUSEADDR` option by default.
61    pub async fn bind(addr: impl ToSocketAddrsAsync) -> io::Result<Self> {
62        Self::bind_with_options(addr, &SocketOpts::default().reuse_address(true)).await
63    }
64
65    /// Creates a new `TcpListener`, which will be bound to the specified
66    /// address using `SocketOpts`.
67    ///
68    /// The returned listener is ready for accepting connections.
69    ///
70    /// Binding with a port number of 0 will request that the OS assigns a port
71    /// to this listener.
72    pub async fn bind_with_options(
73        addr: impl ToSocketAddrsAsync,
74        options: &SocketOpts,
75    ) -> io::Result<Self> {
76        super::each_addr(addr, |addr| async move {
77            let sa = SockAddr::from(addr);
78            let socket = Socket::new(sa.domain(), Type::STREAM, Some(Protocol::TCP)).await?;
79            options.setup_socket(&socket)?;
80            socket.socket.bind(&sa)?;
81            socket.listen(128)?;
82            Ok(Self { inner: socket })
83        })
84        .await
85    }
86
87    /// Creates new TcpListener from a [`std::net::TcpListener`].
88    pub fn from_std(stream: std::net::TcpListener) -> io::Result<Self> {
89        Ok(Self {
90            inner: Socket::from_socket2(Socket2::from(stream))?,
91        })
92    }
93
94    /// Close the socket. If the returned future is dropped before polling, the
95    /// socket won't be closed.
96    pub fn close(self) -> impl Future<Output = io::Result<()>> {
97        self.inner.close()
98    }
99
100    /// Accepts a new incoming connection from this listener.
101    ///
102    /// This function will yield once a new TCP connection is established. When
103    /// established, the corresponding [`TcpStream`] and the remote peer's
104    /// address will be returned.
105    pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
106        self.accept_with_options(&SocketOpts::default()).await
107    }
108
109    /// Accepts a new incoming connection from this listener, and sets options.
110    ///
111    /// This function will yield once a new TCP connection is established. When
112    /// established, the corresponding [`TcpStream`] and the remote peer's
113    /// address will be returned.
114    pub async fn accept_with_options(
115        &self,
116        options: &SocketOpts,
117    ) -> io::Result<(TcpStream, SocketAddr)> {
118        let (socket, addr) = self.inner.accept().await?;
119        options.setup_socket(&socket)?;
120        let stream = TcpStream { inner: socket };
121        Ok((stream, addr.as_socket().expect("should be SocketAddr")))
122    }
123
124    /// Returns the local address that this listener is bound to.
125    ///
126    /// This can be useful, for example, when binding to port 0 to
127    /// figure out which port was actually bound.
128    ///
129    /// # Examples
130    ///
131    /// ```
132    /// use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
133    ///
134    /// use compio_net::TcpListener;
135    /// use socket2::SockAddr;
136    ///
137    /// # compio_runtime::Runtime::new().unwrap().block_on(async {
138    /// let listener = TcpListener::bind("127.0.0.1:8080").await.unwrap();
139    ///
140    /// let addr = listener.local_addr().expect("Couldn't get local address");
141    /// assert_eq!(
142    ///     addr,
143    ///     SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 8080))
144    /// );
145    /// # });
146    /// ```
147    pub fn local_addr(&self) -> io::Result<SocketAddr> {
148        self.inner
149            .local_addr()
150            .map(|addr| addr.as_socket().expect("should be SocketAddr"))
151    }
152}
153
// Generates the raw file-descriptor / socket-handle trait impls for
// `TcpListener` via `compio_driver::impl_raw_fd!`. The arguments appear to name
// the wrapped `socket2::Socket` reached through `self.inner.socket` —
// NOTE(review): confirm the exact traits produced against the macro definition.
impl_raw_fd!(TcpListener, socket2::Socket, inner, socket);
155
156/// A TCP stream between a local and a remote socket.
157///
158/// A TCP stream can either be created by connecting to an endpoint, via the
159/// `connect` method, or by accepting a connection from a listener.
160///
161/// # Examples
162///
163/// ```no_run
164/// use std::net::SocketAddr;
165///
166/// use compio_io::AsyncWrite;
167/// use compio_net::TcpStream;
168///
169/// # compio_runtime::Runtime::new().unwrap().block_on(async {
170/// // Connect to a peer
171/// let mut stream = TcpStream::connect("127.0.0.1:8080").await.unwrap();
172///
173/// // Write some data.
174/// stream.write("hello world!").await.unwrap();
175/// # })
176/// ```
177#[derive(Debug, Clone)]
178pub struct TcpStream {
179    inner: Socket,
180}
181
182impl TcpStream {
183    /// Opens a TCP connection to a remote host.
184    pub async fn connect(addr: impl ToSocketAddrsAsync) -> io::Result<Self> {
185        Self::connect_with_options(addr, &SocketOpts::default()).await
186    }
187
188    /// Opens a TCP connection to a remote host using `SocketOpts`.
189    pub async fn connect_with_options(
190        addr: impl ToSocketAddrsAsync,
191        options: &SocketOpts,
192    ) -> io::Result<Self> {
193        use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
194
195        super::each_addr(addr, |addr| async move {
196            let addr2 = SockAddr::from(addr);
197            let socket = if cfg!(windows) {
198                let bind_addr = if addr.is_ipv4() {
199                    SockAddr::from(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
200                } else if addr.is_ipv6() {
201                    SockAddr::from(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0))
202                } else {
203                    return Err(io::Error::new(
204                        io::ErrorKind::AddrNotAvailable,
205                        "Unsupported address domain.",
206                    ));
207                };
208                Socket::bind(&bind_addr, Type::STREAM, Some(Protocol::TCP)).await?
209            } else {
210                Socket::new(addr2.domain(), Type::STREAM, Some(Protocol::TCP)).await?
211            };
212            options.setup_socket(&socket)?;
213            socket.connect_async(&addr2).await?;
214            Ok(Self { inner: socket })
215        })
216        .await
217    }
218
219    /// Bind to `bind_addr` then opens a TCP connection to a remote host.
220    pub async fn bind_and_connect(
221        bind_addr: SocketAddr,
222        addr: impl ToSocketAddrsAsync,
223    ) -> io::Result<Self> {
224        Self::bind_and_connect_with_options(bind_addr, addr, &SocketOpts::default()).await
225    }
226
227    /// Bind to `bind_addr` then opens a TCP connection to a remote host using
228    /// `SocketOpts`.
229    pub async fn bind_and_connect_with_options(
230        bind_addr: SocketAddr,
231        addr: impl ToSocketAddrsAsync,
232        options: &SocketOpts,
233    ) -> io::Result<Self> {
234        super::each_addr(addr, |addr| async move {
235            let addr = SockAddr::from(addr);
236            let bind_addr = SockAddr::from(bind_addr);
237
238            let socket = Socket::bind(&bind_addr, Type::STREAM, Some(Protocol::TCP)).await?;
239            options.setup_socket(&socket)?;
240            socket.connect_async(&addr).await?;
241            Ok(Self { inner: socket })
242        })
243        .await
244    }
245
246    /// Creates new TcpStream from a [`std::net::TcpStream`].
247    pub fn from_std(stream: std::net::TcpStream) -> io::Result<Self> {
248        Ok(Self {
249            inner: Socket::from_socket2(Socket2::from(stream))?,
250        })
251    }
252
253    /// Close the socket. If the returned future is dropped before polling, the
254    /// socket won't be closed.
255    pub fn close(self) -> impl Future<Output = io::Result<()>> {
256        self.inner.close()
257    }
258
259    /// Returns the socket address of the remote peer of this TCP connection.
260    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
261        self.inner
262            .peer_addr()
263            .map(|addr| addr.as_socket().expect("should be SocketAddr"))
264    }
265
266    /// Returns the socket address of the local half of this TCP connection.
267    pub fn local_addr(&self) -> io::Result<SocketAddr> {
268        self.inner
269            .local_addr()
270            .map(|addr| addr.as_socket().expect("should be SocketAddr"))
271    }
272
273    /// Splits a [`TcpStream`] into a read half and a write half, which can be
274    /// used to read and write the stream concurrently.
275    ///
276    /// This method is more efficient than
277    /// [`into_split`](TcpStream::into_split), but the halves cannot
278    /// be moved into independently spawned tasks.
279    pub fn split(&self) -> (ReadHalf<'_, Self>, WriteHalf<'_, Self>) {
280        crate::split(self)
281    }
282
283    /// Splits a [`TcpStream`] into a read half and a write half, which can be
284    /// used to read and write the stream concurrently.
285    ///
286    /// Unlike [`split`](TcpStream::split), the owned halves can be moved to
287    /// separate tasks, however this comes at the cost of a heap allocation.
288    pub fn into_split(self) -> (OwnedReadHalf<Self>, OwnedWriteHalf<Self>) {
289        crate::into_split(self)
290    }
291
292    /// Create [`PollFd`] from inner socket.
293    pub fn to_poll_fd(&self) -> io::Result<PollFd<Socket2>> {
294        self.inner.to_poll_fd()
295    }
296
297    /// Create [`PollFd`] from inner socket.
298    pub fn into_poll_fd(self) -> io::Result<PollFd<Socket2>> {
299        self.inner.into_poll_fd()
300    }
301
302    /// Gets the value of the `TCP_NODELAY` option on this socket.
303    ///
304    /// For more information about this option, see
305    /// [`TcpStream::set_nodelay`].
306    pub fn nodelay(&self) -> io::Result<bool> {
307        self.inner.socket.tcp_nodelay()
308    }
309
310    /// Sets the value of the TCP_NODELAY option on this socket.
311    ///
312    /// If set, this option disables the Nagle algorithm. This means
313    /// that segments are always sent as soon as possible, even if
314    /// there is only a small amount of data. When not set, data is
315    /// buffered until there is a sufficient amount to send out,
316    /// thereby avoiding the frequent sending of small packets.
317    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
318        self.inner.socket.set_tcp_nodelay(nodelay)
319    }
320
321    /// Sends out-of-band data on this socket.
322    ///
323    /// Out-of-band data is sent with the `MSG_OOB` flag.
324    pub async fn send_out_of_band<T: IoBuf>(&self, buf: T) -> BufResult<usize, T> {
325        #[cfg(unix)]
326        use libc::MSG_OOB;
327        #[cfg(windows)]
328        use windows_sys::Win32::Networking::WinSock::MSG_OOB;
329
330        self.inner.send(buf, MSG_OOB).await
331    }
332}
333
334impl AsyncRead for TcpStream {
335    #[inline]
336    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
337        (&*self).read(buf).await
338    }
339
340    #[inline]
341    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
342        (&*self).read_vectored(buf).await
343    }
344}
345
346impl AsyncRead for &TcpStream {
347    #[inline]
348    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
349        self.inner.recv(buf, 0).await
350    }
351
352    #[inline]
353    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
354        self.inner.recv_vectored(buf, 0).await
355    }
356}
357
358impl AsyncReadManaged for TcpStream {
359    type Buffer<'a> = BorrowedBuffer<'a>;
360    type BufferPool = BufferPool;
361
362    async fn read_managed<'a>(
363        &mut self,
364        buffer_pool: &'a Self::BufferPool,
365        len: usize,
366    ) -> io::Result<Self::Buffer<'a>> {
367        (&*self).read_managed(buffer_pool, len).await
368    }
369}
370
371impl AsyncReadManaged for &TcpStream {
372    type Buffer<'a> = BorrowedBuffer<'a>;
373    type BufferPool = BufferPool;
374
375    async fn read_managed<'a>(
376        &mut self,
377        buffer_pool: &'a Self::BufferPool,
378        len: usize,
379    ) -> io::Result<Self::Buffer<'a>> {
380        self.inner.recv_managed(buffer_pool, len as _, 0).await
381    }
382}
383
384impl AsyncWrite for TcpStream {
385    #[inline]
386    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
387        (&*self).write(buf).await
388    }
389
390    #[inline]
391    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
392        (&*self).write_vectored(buf).await
393    }
394
395    #[inline]
396    async fn flush(&mut self) -> io::Result<()> {
397        (&*self).flush().await
398    }
399
400    #[inline]
401    async fn shutdown(&mut self) -> io::Result<()> {
402        (&*self).shutdown().await
403    }
404}
405
406impl AsyncWrite for &TcpStream {
407    #[inline]
408    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
409        self.inner.send(buf, 0).await
410    }
411
412    #[inline]
413    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
414        self.inner.send_vectored(buf, 0).await
415    }
416
417    #[inline]
418    async fn flush(&mut self) -> io::Result<()> {
419        Ok(())
420    }
421
422    #[inline]
423    async fn shutdown(&mut self) -> io::Result<()> {
424        self.inner.shutdown().await
425    }
426}
427
428impl Splittable for TcpStream {
429    type ReadHalf = OwnedReadHalf<Self>;
430    type WriteHalf = OwnedWriteHalf<Self>;
431
432    fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
433        crate::into_split(self)
434    }
435}
436
437impl<'a> Splittable for &'a TcpStream {
438    type ReadHalf = ReadHalf<'a, TcpStream>;
439    type WriteHalf = WriteHalf<'a, TcpStream>;
440
441    fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
442        crate::split(self)
443    }
444}
445
// Generates the raw file-descriptor / socket-handle trait impls for
// `TcpStream` via `compio_driver::impl_raw_fd!`. The arguments appear to name
// the wrapped `socket2::Socket` reached through `self.inner.socket` —
// NOTE(review): confirm the exact traits produced against the macro definition.
impl_raw_fd!(TcpStream, socket2::Socket, inner, socket);