Skip to content

Commit 8919175

Browse files
Add TcpOrUnixStream and SocketAddrOrUnixSocketAddr, allow using SOCKS5 over unix-domain sockets
1 parent 4e7463d commit 8919175

File tree

2 files changed

+355
-22
lines changed

2 files changed

+355
-22
lines changed

src/lib.rs

Lines changed: 299 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,312 @@ extern crate libc;
1010
extern crate winapi;
1111

1212
use std::io;
13-
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs};
13+
use std::net::{
14+
Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, TcpStream, ToSocketAddrs,
15+
};
16+
#[cfg(unix)]
17+
use std::os::unix::net::{UnixStream, SocketAddr as UnixSocketAddr};
18+
#[cfg(any(target_os = "linux", target_os = "android"))]
19+
use std::os::linux::net::SocketAddrExt;
20+
use std::hash::{Hash, Hasher};
1421
use std::vec;
1522

16-
pub use v4::{Socks4Stream, Socks4Listener};
17-
pub use v5::{Socks5Stream, Socks5Listener, Socks5Datagram};
23+
pub use v4::{Socks4Listener, Socks4Stream};
24+
pub use v5::{Socks5Datagram, Socks5Listener, Socks5Stream};
1825

1926
mod v4;
2027
mod v5;
2128
mod writev;
2229

30+
/// Either a [`SocketAddr`], or, under unix, [`UnixSocketAddr`]
31+
///
32+
/// If `#[cfg(unix)]`, this can hold an internet socket address *or* a unix-domain socket address.
33+
///
34+
/// Otherwise, this can only hold an internet socket address.
35+
#[derive(Clone, Debug)]
36+
pub enum SocketAddrOrUnixSocketAddr {
37+
/// The internet address.
38+
SocketAddr(SocketAddr),
39+
/// The unix-domain address.
40+
#[cfg(unix)]
41+
UnixSocketAddr(UnixSocketAddr),
42+
}
43+
44+
impl PartialEq for SocketAddrOrUnixSocketAddr {
45+
fn eq(&self, other: &Self) -> bool {
46+
match (self, other) {
47+
#[cfg(unix)]
48+
(Self::SocketAddr(_), Self::UnixSocketAddr(_)) => false,
49+
#[cfg(unix)]
50+
(Self::UnixSocketAddr(_), Self::SocketAddr(_)) => false,
51+
(Self::SocketAddr(l), Self::SocketAddr(r)) => l == r,
52+
#[cfg(unix)]
53+
(Self::UnixSocketAddr(l), Self::UnixSocketAddr(r)) => {
54+
if (l.is_unnamed() && r.is_unnamed()) || (l.as_pathname() == r.as_pathname()) {
55+
return true;
56+
}
57+
#[cfg(any(target_os = "linux", target_os = "android"))]
58+
if l.as_abstract_name() == r.as_abstract_name() {
59+
return true;
60+
}
61+
false
62+
},
63+
}
64+
}
65+
}
66+
67+
impl Hash for SocketAddrOrUnixSocketAddr {
68+
fn hash<H: Hasher>(&self, state: &mut H) {
69+
match self {
70+
Self::SocketAddr(a) => a.hash(state),
71+
#[cfg(unix)]
72+
Self::UnixSocketAddr(a) => {
73+
a.is_unnamed().hash(state);
74+
a.as_pathname().hash(state);
75+
#[cfg(any(target_os = "linux", target_os = "android"))]
76+
a.as_abstract_name().hash(state);
77+
}
78+
}
79+
}
80+
}
81+
82+
impl Into<SocketAddrOrUnixSocketAddr> for &SocketAddrOrUnixSocketAddr {
83+
fn into(self) -> SocketAddrOrUnixSocketAddr {
84+
self.clone() // no allocations, this struct is effectively Copy
85+
}
86+
}
87+
impl Into<SocketAddrOrUnixSocketAddr> for &mut SocketAddrOrUnixSocketAddr {
88+
fn into(self) -> SocketAddrOrUnixSocketAddr {
89+
self.clone() // no allocations, this struct is effectively Copy
90+
}
91+
}
92+
impl Into<SocketAddrOrUnixSocketAddr> for SocketAddr {
93+
fn into(self) -> SocketAddrOrUnixSocketAddr {
94+
SocketAddrOrUnixSocketAddr::SocketAddr(self)
95+
}
96+
}
97+
impl Into<SocketAddrOrUnixSocketAddr> for &SocketAddr {
98+
fn into(self) -> SocketAddrOrUnixSocketAddr {
99+
SocketAddrOrUnixSocketAddr::SocketAddr(self.clone())
100+
}
101+
}
102+
#[cfg(unix)]
103+
impl Into<SocketAddrOrUnixSocketAddr> for UnixSocketAddr {
104+
fn into(self) -> SocketAddrOrUnixSocketAddr {
105+
SocketAddrOrUnixSocketAddr::UnixSocketAddr(self)
106+
}
107+
}
108+
109+
/// Either a [`TcpStream`], or, under unix, [`UnixStream`]
110+
///
111+
/// If `#[cfg(unix)]`, this can hold an internet socket *or* a unix-domain socket.
112+
///
113+
/// Otherwise, this can only hold an internet socket.
114+
#[derive(Debug)]
115+
pub enum TcpOrUnixStream {
116+
/// The internet socket.
117+
Tcp(TcpStream),
118+
#[cfg(unix)]
119+
/// The unix-domain socket.
120+
Unix(UnixStream),
121+
}
122+
123+
macro_rules! fwd {
124+
($self:expr, $fun:ident) => {
125+
match $self {
126+
TcpOrUnixStream::Tcp(ref mut s) => s.$fun(),
127+
#[cfg(unix)]
128+
TcpOrUnixStream::Unix(ref mut s) => s.$fun(),
129+
}
130+
};
131+
($self:expr, $fun:ident, $arg:expr) => {
132+
match $self {
133+
TcpOrUnixStream::Tcp(ref mut s) => s.$fun($arg),
134+
#[cfg(unix)]
135+
TcpOrUnixStream::Unix(ref mut s) => s.$fun($arg),
136+
}
137+
}
138+
}
139+
140+
macro_rules! fwd_ref {
141+
($self:expr, $fun:ident) => {
142+
match $self {
143+
TcpOrUnixStream::Tcp(s) => (&mut &*s).$fun(),
144+
#[cfg(unix)]
145+
TcpOrUnixStream::Unix(s) => (&mut &*s).$fun(),
146+
}
147+
};
148+
($self:expr, $fun:ident, $arg:expr) => {
149+
match $self {
150+
TcpOrUnixStream::Tcp(s) => (&mut &*s).$fun($arg),
151+
#[cfg(unix)]
152+
TcpOrUnixStream::Unix(s) => (&mut &*s).$fun($arg),
153+
}
154+
}
155+
}
156+
157+
impl TcpOrUnixStream {
158+
/// [`TcpStream::connect`] or [`UnixStream::connect_addr`]
159+
pub fn connect<T: Into<SocketAddrOrUnixSocketAddr>>(addr: T) -> std::io::Result<TcpOrUnixStream> {
160+
match addr.into() {
161+
SocketAddrOrUnixSocketAddr::SocketAddr(s) => TcpStream::connect(s).map(TcpOrUnixStream::Tcp),
162+
#[cfg(unix)]
163+
SocketAddrOrUnixSocketAddr::UnixSocketAddr(s) => UnixStream::connect_addr(&s).map(TcpOrUnixStream::Unix),
164+
}
165+
}
166+
167+
/// [`TcpStream::connect`] with [`ToSocketAddrs`]
168+
pub fn connect_tsa<T: ToSocketAddrs>(addr: T) -> std::io::Result<TcpOrUnixStream> {
169+
TcpStream::connect(addr).map(TcpOrUnixStream::Tcp)
170+
}
171+
172+
/// [`TcpStream::local_addr`] or [`UnixStream::local_addr`]
173+
pub fn local_addr(&self) -> std::io::Result<SocketAddrOrUnixSocketAddr> {
174+
match self {
175+
TcpOrUnixStream::Tcp(s) => s.local_addr().map(SocketAddrOrUnixSocketAddr::SocketAddr),
176+
#[cfg(unix)]
177+
TcpOrUnixStream::Unix(s) => s.local_addr().map(SocketAddrOrUnixSocketAddr::UnixSocketAddr),
178+
}
179+
}
180+
181+
/// [`TcpStream::peer_addr`] or [`UnixStream::peer_addr`]
182+
pub fn peer_addr(&self) -> std::io::Result<SocketAddrOrUnixSocketAddr> {
183+
match self {
184+
TcpOrUnixStream::Tcp(s) => s.peer_addr().map(SocketAddrOrUnixSocketAddr::SocketAddr),
185+
#[cfg(unix)]
186+
TcpOrUnixStream::Unix(s) => s.peer_addr().map(SocketAddrOrUnixSocketAddr::UnixSocketAddr),
187+
}
188+
}
189+
190+
/// [`TcpStream::read_timeout`] or [`UnixStream::read_timeout`]
191+
pub fn read_timeout(&self) -> std::io::Result<Option<std::time::Duration>> {
192+
fwd_ref!(self, read_timeout)
193+
}
194+
195+
/// [`TcpStream::set_nonblocking`] or [`UnixStream::set_nonblocking`]
196+
pub fn set_nonblocking(&self, nonblocking: bool) -> std::io::Result<()> {
197+
fwd_ref!(self, set_nonblocking, nonblocking)
198+
}
199+
200+
/// [`TcpStream::set_read_timeout`] or [`UnixStream::set_read_timeout`]
201+
pub fn set_read_timeout(&self, timeout: Option<std::time::Duration>) -> std::io::Result<()> {
202+
fwd_ref!(self, set_read_timeout, timeout)
203+
}
204+
205+
/// [`TcpStream::set_write_timeout`] or [`UnixStream::set_write_timeout`]
206+
pub fn set_write_timeout(&self, timeout: Option<std::time::Duration>) -> std::io::Result<()> {
207+
fwd_ref!(self, set_write_timeout, timeout)
208+
}
209+
210+
/// [`TcpStream::shutdown`] or [`UnixStream::shutdown`]
211+
pub fn shutdown(&self, how: std::net::Shutdown) -> std::io::Result<()> {
212+
fwd_ref!(self, shutdown, how)
213+
}
214+
215+
/// [`TcpStream::take_error`] or [`UnixStream::take_error`]
216+
pub fn take_error(&self) -> std::io::Result<Option<std::io::Error>> {
217+
fwd_ref!(self, take_error)
218+
}
219+
220+
/// [`TcpStream::try_clone`] or [`UnixStream::try_clone`]
221+
pub fn try_clone(&self) -> std::io::Result<TcpOrUnixStream> {
222+
match self {
223+
TcpOrUnixStream::Tcp(s) => s.try_clone().map(TcpOrUnixStream::Tcp),
224+
#[cfg(unix)]
225+
TcpOrUnixStream::Unix(s) => s.try_clone().map(TcpOrUnixStream::Unix),
226+
}
227+
}
228+
229+
/// [`TcpStream::write_timeout`] or [`UnixStream::write_timeout`]
230+
pub fn write_timeout(&self) -> std::io::Result<Option<std::time::Duration>> {
231+
fwd_ref!(self, write_timeout)
232+
}
233+
}
234+
235+
impl io::Read for TcpOrUnixStream {
236+
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
237+
fwd!(self, read, buf)
238+
}
239+
fn read_vectored(&mut self, bufs: &mut [std::io::IoSliceMut<'_>]) -> std::io::Result<usize> {
240+
fwd!(self, read_vectored, bufs)
241+
}
242+
fn read_to_end(&mut self, buf: &mut Vec<u8>) -> std::io::Result<usize> {
243+
fwd!(self, read_to_end, buf)
244+
}
245+
fn read_to_string(&mut self, buf: &mut String) -> std::io::Result<usize> {
246+
fwd!(self, read_to_string, buf)
247+
}
248+
fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
249+
fwd!(self, read_exact, buf)
250+
}
251+
}
252+
253+
impl<'a> io::Read for &'a TcpOrUnixStream {
254+
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
255+
fwd_ref!(self, read, buf)
256+
}
257+
fn read_vectored(&mut self, bufs: &mut [std::io::IoSliceMut<'_>]) -> std::io::Result<usize> {
258+
fwd_ref!(self, read_vectored, bufs)
259+
}
260+
fn read_to_end(&mut self, buf: &mut Vec<u8>) -> std::io::Result<usize> {
261+
fwd_ref!(self, read_to_end, buf)
262+
}
263+
fn read_to_string(&mut self, buf: &mut String) -> std::io::Result<usize> {
264+
fwd_ref!(self, read_to_string, buf)
265+
}
266+
fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
267+
fwd_ref!(self, read_exact, buf)
268+
}
269+
}
270+
271+
impl<'a> io::Write for &'a TcpOrUnixStream {
272+
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
273+
fwd_ref!(self, write, buf)
274+
}
275+
fn write_vectored(&mut self, bufs: &[std::io::IoSlice<'_>]) -> std::io::Result<usize> {
276+
fwd_ref!(self, write_vectored, bufs)
277+
}
278+
fn flush(&mut self) -> std::io::Result<()> {
279+
fwd_ref!(self, flush)
280+
}
281+
fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
282+
fwd_ref!(self, write_all, buf)
283+
}
284+
fn write_fmt(&mut self, fmt: std::fmt::Arguments<'_>) -> std::io::Result<()> {
285+
fwd_ref!(self, write_fmt, fmt)
286+
}
287+
}
288+
289+
impl io::Write for TcpOrUnixStream {
290+
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
291+
fwd!(self, write, buf)
292+
}
293+
fn write_vectored(&mut self, bufs: &[std::io::IoSlice<'_>]) -> std::io::Result<usize> {
294+
fwd!(self, write_vectored, bufs)
295+
}
296+
fn flush(&mut self) -> std::io::Result<()> {
297+
fwd!(self, flush)
298+
}
299+
fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
300+
fwd!(self, write_all, buf)
301+
}
302+
fn write_fmt(&mut self, fmt: std::fmt::Arguments<'_>) -> std::io::Result<()> {
303+
fwd!(self, write_fmt, fmt)
304+
}
305+
}
306+
307+
impl Into<TcpOrUnixStream> for TcpStream {
308+
fn into(self) -> TcpOrUnixStream {
309+
TcpOrUnixStream::Tcp(self)
310+
}
311+
}
312+
#[cfg(unix)]
313+
impl Into<TcpOrUnixStream> for UnixStream {
314+
fn into(self) -> TcpOrUnixStream {
315+
TcpOrUnixStream::Unix(self)
316+
}
317+
}
318+
23319
/// A description of a connection target.
24320
#[derive(Debug, Clone)]
25321
pub enum TargetAddr {

0 commit comments

Comments
 (0)