// compio_net/resolve/mod.rs

1cfg_select! {
2    windows => {
3        #[path = "windows.rs"]
4        mod sys;
5    }
6    unix => {
7        #[path = "unix.rs"]
8        mod sys;
9    }
10    _ => {}
11}
12
13use std::{
14    future::{Future, Ready, ready},
15    io,
16    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
17};
18
19use compio_buf::{BufResult, buf_try};
20use either::Either;
21pub use sys::resolve_sock_addrs;
22
/// A trait for objects which can be converted or resolved to one or more
/// [`SocketAddr`] values asynchronously.
///
/// This is the async counterpart of [`std::net::ToSocketAddrs`].
///
/// # Cancel safety
///
/// All implementations of [`ToSocketAddrsAsync`] in this crate are safe to
/// cancel by dropping the future. The glibc implementation may leak its
/// control block if the lookup has not completed when the future is dropped.
// The `async_fn_in_trait` lint warns that callers cannot require the returned
// future to be `Send`; that trade-off is accepted here on purpose.
#[allow(async_fn_in_trait)]
pub trait ToSocketAddrsAsync {
    /// Iterator over the resolved addresses.
    ///
    /// See [`std::net::ToSocketAddrs::Iter`].
    type Iter: Iterator<Item = SocketAddr>;

    /// Converts or resolves `self` into an iterator of socket addresses.
    ///
    /// See [`std::net::ToSocketAddrs::to_socket_addrs`].
    ///
    /// # Errors
    ///
    /// Implementations return an error when `self` cannot be interpreted as an
    /// address or when name resolution fails.
    async fn to_socket_addrs_async(&self) -> io::Result<Self::Iter>;
}
41
42macro_rules! impl_to_socket_addrs_async {
43    ($($t:ty),* $(,)?) => {
44        $(
45            impl ToSocketAddrsAsync for $t {
46                type Iter = std::iter::Once<SocketAddr>;
47
48                async fn to_socket_addrs_async(&self) -> io::Result<Self::Iter> {
49                    Ok(std::iter::once(SocketAddr::from(*self)))
50                }
51            }
52        )*
53    }
54}
55
56impl_to_socket_addrs_async![
57    SocketAddr,
58    SocketAddrV4,
59    SocketAddrV6,
60    (IpAddr, u16),
61    (Ipv4Addr, u16),
62    (Ipv6Addr, u16),
63];
64
65impl ToSocketAddrsAsync for (&str, u16) {
66    type Iter = Either<std::iter::Once<SocketAddr>, std::vec::IntoIter<SocketAddr>>;
67
68    async fn to_socket_addrs_async(&self) -> io::Result<Self::Iter> {
69        let (host, port) = self;
70        if let Ok(addr) = host.parse::<Ipv4Addr>() {
71            return Ok(Either::Left(std::iter::once(SocketAddr::from((
72                addr, *port,
73            )))));
74        }
75        if let Ok(addr) = host.parse::<Ipv6Addr>() {
76            return Ok(Either::Left(std::iter::once(SocketAddr::from((
77                addr, *port,
78            )))));
79        }
80
81        resolve_sock_addrs(host, *port).await.map(Either::Right)
82    }
83}
84
85impl ToSocketAddrsAsync for (String, u16) {
86    type Iter = <(&'static str, u16) as ToSocketAddrsAsync>::Iter;
87
88    async fn to_socket_addrs_async(&self) -> io::Result<Self::Iter> {
89        (&*self.0, self.1).to_socket_addrs_async().await
90    }
91}
92
93impl ToSocketAddrsAsync for str {
94    type Iter = <(&'static str, u16) as ToSocketAddrsAsync>::Iter;
95
96    async fn to_socket_addrs_async(&self) -> io::Result<Self::Iter> {
97        if let Ok(addr) = self.parse::<SocketAddr>() {
98            return Ok(Either::Left(std::iter::once(addr)));
99        }
100
101        let (host, port_str) = self
102            .rsplit_once(':')
103            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "invalid socket address"))?;
104        let port: u16 = port_str
105            .parse()
106            .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid port value"))?;
107        (host, port).to_socket_addrs_async().await
108    }
109}
110
111impl ToSocketAddrsAsync for String {
112    type Iter = <(&'static str, u16) as ToSocketAddrsAsync>::Iter;
113
114    async fn to_socket_addrs_async(&self) -> io::Result<Self::Iter> {
115        self.as_str().to_socket_addrs_async().await
116    }
117}
118
119impl<'a> ToSocketAddrsAsync for &'a [SocketAddr] {
120    type Iter = std::iter::Copied<std::slice::Iter<'a, SocketAddr>>;
121
122    async fn to_socket_addrs_async(&self) -> io::Result<Self::Iter> {
123        Ok(self.iter().copied())
124    }
125}
126
127impl<T: ToSocketAddrsAsync + ?Sized> ToSocketAddrsAsync for &T {
128    type Iter = T::Iter;
129
130    async fn to_socket_addrs_async(&self) -> io::Result<Self::Iter> {
131        (**self).to_socket_addrs_async().await
132    }
133}
134
135pub async fn each_addr<T, F: Future<Output = io::Result<T>>>(
136    addr: impl ToSocketAddrsAsync,
137    f: impl Fn(SocketAddr) -> F,
138) -> io::Result<T> {
139    let addrs = addr.to_socket_addrs_async().await?;
140    let mut last_err = None;
141    for addr in addrs {
142        match f(addr).await {
143            Ok(l) => return Ok(l),
144            Err(e) => last_err = Some(e),
145        }
146    }
147    Err(last_err.unwrap_or_else(|| {
148        io::Error::new(
149            io::ErrorKind::InvalidInput,
150            "could not resolve to any addresses",
151        )
152    }))
153}
154
155pub async fn first_addr_buf<T, B, F: Future<Output = BufResult<T, B>>>(
156    addr: impl ToSocketAddrsAsync,
157    buffer: B,
158    f: impl FnOnce(SocketAddr, B) -> F,
159) -> BufResult<T, B> {
160    let (mut addrs, buffer) = buf_try!(addr.to_socket_addrs_async().await, buffer);
161    if let Some(addr) = addrs.next() {
162        let (res, buffer) = buf_try!(f(addr, buffer).await);
163        BufResult(Ok(res), buffer)
164    } else {
165        BufResult(
166            Err(io::Error::new(
167                io::ErrorKind::InvalidInput,
168                "could not operate on first address",
169            )),
170            buffer,
171        )
172    }
173}
174
175pub async fn first_addr_buf_zerocopy<B, F1, F2>(
176    addr: impl ToSocketAddrsAsync,
177    buffer: B,
178    f: impl FnOnce(SocketAddr, B) -> F1,
179) -> BufResult<usize, Either<Ready<B>, F2>>
180where
181    F1: Future<Output = BufResult<usize, F2>>,
182    F2: Future<Output = B>,
183{
184    fn ret<T, F>(fut: T) -> Either<Ready<T>, F> {
185        Either::Left(ready(fut))
186    }
187
188    let mut addrs = match addr.to_socket_addrs_async().await {
189        Ok(addrs) => addrs,
190        Err(e) => return BufResult(Err(e), ret(buffer)),
191    };
192    if let Some(addr) = addrs.next() {
193        let BufResult(res, fut) = f(addr, buffer).await;
194        BufResult(res, Either::Right(fut))
195    } else {
196        BufResult(
197            Err(io::Error::new(
198                io::ErrorKind::InvalidInput,
199                "could not operate on first address",
200            )),
201            ret(buffer),
202        )
203    }
204}