Skip to content

Commit 14f9d18

Browse files
Add TcpOrUnixStream and SocketAddrOrUnixSocketAddr, allow using SOCKS5 over unix-domain sockets
1 parent 4e7463d commit 14f9d18

File tree

2 files changed

+382
-22
lines changed

2 files changed

+382
-22
lines changed

src/lib.rs

Lines changed: 330 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,343 @@ extern crate libc;
1010
extern crate winapi;
1111

1212
use std::io;
13-
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs};
13+
use std::net::{
14+
Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, TcpStream, ToSocketAddrs,
15+
};
16+
#[cfg(unix)]
17+
use std::os::unix::net::{UnixStream, SocketAddr as UnixSocketAddr};
18+
#[cfg(any(target_os = "linux", target_os = "android"))]
19+
use std::os::linux::net::SocketAddrExt;
20+
use std::hash::{Hash, Hasher};
1421
use std::vec;
1522

16-
pub use v4::{Socks4Stream, Socks4Listener};
17-
pub use v5::{Socks5Stream, Socks5Listener, Socks5Datagram};
23+
pub use v4::{Socks4Listener, Socks4Stream};
24+
pub use v5::{Socks5Datagram, Socks5Listener, Socks5Stream};
1825

1926
mod v4;
2027
mod v5;
2128
mod writev;
2229

30+
/// Either a [`SocketAddr`], or, under unix, [`UnixSocketAddr`]
31+
///
32+
/// If `#[cfg(unix)]`, this can hold an internet socket address *or* a unix-domain socket address.
33+
///
34+
/// Otherwise, this can only hold an internet socket address.
35+
#[derive(Clone, Debug)]
36+
pub enum SocketAddrOrUnixSocketAddr {
37+
/// The internet address.
38+
SocketAddr(SocketAddr),
39+
/// The unix-domain address.
40+
#[cfg(unix)]
41+
UnixSocketAddr(UnixSocketAddr),
42+
}
43+
44+
impl PartialEq for SocketAddrOrUnixSocketAddr {
45+
fn eq(&self, other: &Self) -> bool {
46+
match (self, other) {
47+
#[cfg(unix)]
48+
(Self::SocketAddr(_), Self::UnixSocketAddr(_)) => false,
49+
#[cfg(unix)]
50+
(Self::UnixSocketAddr(_), Self::SocketAddr(_)) => false,
51+
(Self::SocketAddr(l), Self::SocketAddr(r)) => l == r,
52+
#[cfg(unix)]
53+
(Self::UnixSocketAddr(l), Self::UnixSocketAddr(r)) => {
54+
if (l.is_unnamed() && r.is_unnamed()) || (l.as_pathname() == r.as_pathname()) {
55+
return true;
56+
}
57+
#[cfg(any(target_os = "linux", target_os = "android"))]
58+
if l.as_abstract_name() == r.as_abstract_name() {
59+
return true;
60+
}
61+
false
62+
},
63+
}
64+
}
65+
}
66+
67+
impl Hash for SocketAddrOrUnixSocketAddr {
68+
fn hash<H: Hasher>(&self, state: &mut H) {
69+
match self {
70+
Self::SocketAddr(a) => a.hash(state),
71+
#[cfg(unix)]
72+
Self::UnixSocketAddr(a) => {
73+
a.is_unnamed().hash(state);
74+
a.as_pathname().hash(state);
75+
#[cfg(any(target_os = "linux", target_os = "android"))]
76+
a.as_abstract_name().hash(state);
77+
}
78+
}
79+
}
80+
}
81+
82+
impl Into<SocketAddrOrUnixSocketAddr> for &SocketAddrOrUnixSocketAddr {
83+
fn into(self) -> SocketAddrOrUnixSocketAddr {
84+
self.clone() // no allocations, this struct is effectively Copy
85+
}
86+
}
87+
impl Into<SocketAddrOrUnixSocketAddr> for &mut SocketAddrOrUnixSocketAddr {
88+
fn into(self) -> SocketAddrOrUnixSocketAddr {
89+
self.clone() // no allocations, this struct is effectively Copy
90+
}
91+
}
92+
impl Into<SocketAddrOrUnixSocketAddr> for SocketAddr {
93+
fn into(self) -> SocketAddrOrUnixSocketAddr {
94+
SocketAddrOrUnixSocketAddr::SocketAddr(self)
95+
}
96+
}
97+
impl Into<SocketAddrOrUnixSocketAddr> for &SocketAddr {
98+
fn into(self) -> SocketAddrOrUnixSocketAddr {
99+
SocketAddrOrUnixSocketAddr::SocketAddr(self.clone())
100+
}
101+
}
102+
#[cfg(unix)]
103+
impl Into<SocketAddrOrUnixSocketAddr> for UnixSocketAddr {
104+
fn into(self) -> SocketAddrOrUnixSocketAddr {
105+
SocketAddrOrUnixSocketAddr::UnixSocketAddr(self)
106+
}
107+
}
108+
109+
/// Either a [`TcpStream`], or, under unix, [`UnixStream`]
110+
///
111+
/// If `#[cfg(unix)]`, this can hold an internet socket *or* a unix-domain socket.
112+
///
113+
/// Otherwise, this can only hold an internet socket.
114+
#[derive(Debug)]
115+
pub enum TcpOrUnixStream {
116+
/// The internet socket.
117+
Tcp(TcpStream),
118+
#[cfg(unix)]
119+
/// The unix-domain socket.
120+
Unix(UnixStream),
121+
}
122+
123+
macro_rules! fwd {
124+
($self:expr, $fun:ident) => {
125+
match $self {
126+
TcpOrUnixStream::Tcp(ref mut s) => s.$fun(),
127+
#[cfg(unix)]
128+
TcpOrUnixStream::Unix(ref mut s) => s.$fun(),
129+
}
130+
};
131+
($self:expr, $fun:ident, $arg:expr) => {
132+
match $self {
133+
TcpOrUnixStream::Tcp(ref mut s) => s.$fun($arg),
134+
#[cfg(unix)]
135+
TcpOrUnixStream::Unix(ref mut s) => s.$fun($arg),
136+
}
137+
}
138+
}
139+
140+
macro_rules! fwd_ref {
141+
($self:expr, $fun:ident) => {
142+
match $self {
143+
TcpOrUnixStream::Tcp(s) => (&mut &*s).$fun(),
144+
#[cfg(unix)]
145+
TcpOrUnixStream::Unix(s) => (&mut &*s).$fun(),
146+
}
147+
};
148+
($self:expr, $fun:ident, $arg:expr) => {
149+
match $self {
150+
TcpOrUnixStream::Tcp(s) => (&mut &*s).$fun($arg),
151+
#[cfg(unix)]
152+
TcpOrUnixStream::Unix(s) => (&mut &*s).$fun($arg),
153+
}
154+
}
155+
}
156+
157+
macro_rules! fwd_move {
158+
($self:expr, $fun:ident) => {
159+
match $self {
160+
TcpOrUnixStream::Tcp(s) => s.$fun(),
161+
#[cfg(unix)]
162+
TcpOrUnixStream::Unix(s) => s.$fun(),
163+
}
164+
};
165+
($self:expr, $fun:ident, $arg:expr) => {
166+
match $self {
167+
TcpOrUnixStream::Tcp(s) => s.$fun($arg),
168+
#[cfg(unix)]
169+
TcpOrUnixStream::Unix(s) => s.$fun($arg),
170+
}
171+
}
172+
}
173+
174+
impl TcpOrUnixStream {
175+
/// [`TcpStream::connect`] or [`UnixStream::connect_addr`]
176+
pub fn connect<T: Into<SocketAddrOrUnixSocketAddr>>(addr: T) -> std::io::Result<TcpOrUnixStream> {
177+
match addr.into() {
178+
SocketAddrOrUnixSocketAddr::SocketAddr(s) => TcpStream::connect(s).map(TcpOrUnixStream::Tcp),
179+
#[cfg(unix)]
180+
SocketAddrOrUnixSocketAddr::UnixSocketAddr(s) => UnixStream::connect_addr(&s).map(TcpOrUnixStream::Unix),
181+
}
182+
}
183+
184+
/// [`TcpStream::connect`] with [`ToSocketAddrs`]
185+
pub fn connect_tsa<T: ToSocketAddrs>(addr: T) -> std::io::Result<TcpOrUnixStream> {
186+
TcpStream::connect(addr).map(TcpOrUnixStream::Tcp)
187+
}
188+
189+
/// [`TcpStream::local_addr`] or [`UnixStream::local_addr`]
190+
pub fn local_addr(&self) -> std::io::Result<SocketAddrOrUnixSocketAddr> {
191+
match self {
192+
TcpOrUnixStream::Tcp(s) => s.local_addr().map(SocketAddrOrUnixSocketAddr::SocketAddr),
193+
#[cfg(unix)]
194+
TcpOrUnixStream::Unix(s) => s.local_addr().map(SocketAddrOrUnixSocketAddr::UnixSocketAddr),
195+
}
196+
}
197+
198+
/// [`TcpStream::peer_addr`] or [`UnixStream::peer_addr`]
199+
pub fn peer_addr(&self) -> std::io::Result<SocketAddrOrUnixSocketAddr> {
200+
match self {
201+
TcpOrUnixStream::Tcp(s) => s.peer_addr().map(SocketAddrOrUnixSocketAddr::SocketAddr),
202+
#[cfg(unix)]
203+
TcpOrUnixStream::Unix(s) => s.peer_addr().map(SocketAddrOrUnixSocketAddr::UnixSocketAddr),
204+
}
205+
}
206+
207+
/// [`TcpStream::read_timeout`] or [`UnixStream::read_timeout`]
208+
pub fn read_timeout(&self) -> std::io::Result<Option<std::time::Duration>> {
209+
fwd_ref!(self, read_timeout)
210+
}
211+
212+
/// [`TcpStream::set_nonblocking`] or [`UnixStream::set_nonblocking`]
213+
pub fn set_nonblocking(&self, nonblocking: bool) -> std::io::Result<()> {
214+
fwd_ref!(self, set_nonblocking, nonblocking)
215+
}
216+
217+
/// [`TcpStream::set_read_timeout`] or [`UnixStream::set_read_timeout`]
218+
pub fn set_read_timeout(&self, timeout: Option<std::time::Duration>) -> std::io::Result<()> {
219+
fwd_ref!(self, set_read_timeout, timeout)
220+
}
221+
222+
/// [`TcpStream::set_write_timeout`] or [`UnixStream::set_write_timeout`]
223+
pub fn set_write_timeout(&self, timeout: Option<std::time::Duration>) -> std::io::Result<()> {
224+
fwd_ref!(self, set_write_timeout, timeout)
225+
}
226+
227+
/// [`TcpStream::shutdown`] or [`UnixStream::shutdown`]
228+
pub fn shutdown(&self, how: std::net::Shutdown) -> std::io::Result<()> {
229+
fwd_ref!(self, shutdown, how)
230+
}
231+
232+
/// [`TcpStream::take_error`] or [`UnixStream::take_error`]
233+
pub fn take_error(&self) -> std::io::Result<Option<std::io::Error>> {
234+
fwd_ref!(self, take_error)
235+
}
236+
237+
/// [`TcpStream::try_clone`] or [`UnixStream::try_clone`]
238+
pub fn try_clone(&self) -> std::io::Result<TcpOrUnixStream> {
239+
match self {
240+
TcpOrUnixStream::Tcp(s) => s.try_clone().map(TcpOrUnixStream::Tcp),
241+
#[cfg(unix)]
242+
TcpOrUnixStream::Unix(s) => s.try_clone().map(TcpOrUnixStream::Unix),
243+
}
244+
}
245+
246+
/// [`TcpStream::write_timeout`] or [`UnixStream::write_timeout`]
247+
pub fn write_timeout(&self) -> std::io::Result<Option<std::time::Duration>> {
248+
fwd_ref!(self, write_timeout)
249+
}
250+
}
251+
252+
impl io::Read for TcpOrUnixStream {
253+
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
254+
fwd!(self, read, buf)
255+
}
256+
fn read_vectored(&mut self, bufs: &mut [std::io::IoSliceMut<'_>]) -> std::io::Result<usize> {
257+
fwd!(self, read_vectored, bufs)
258+
}
259+
fn read_to_end(&mut self, buf: &mut Vec<u8>) -> std::io::Result<usize> {
260+
fwd!(self, read_to_end, buf)
261+
}
262+
fn read_to_string(&mut self, buf: &mut String) -> std::io::Result<usize> {
263+
fwd!(self, read_to_string, buf)
264+
}
265+
fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
266+
fwd!(self, read_exact, buf)
267+
}
268+
}
269+
270+
impl<'a> io::Read for &'a TcpOrUnixStream {
271+
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
272+
fwd_ref!(self, read, buf)
273+
}
274+
fn read_vectored(&mut self, bufs: &mut [std::io::IoSliceMut<'_>]) -> std::io::Result<usize> {
275+
fwd_ref!(self, read_vectored, bufs)
276+
}
277+
fn read_to_end(&mut self, buf: &mut Vec<u8>) -> std::io::Result<usize> {
278+
fwd_ref!(self, read_to_end, buf)
279+
}
280+
fn read_to_string(&mut self, buf: &mut String) -> std::io::Result<usize> {
281+
fwd_ref!(self, read_to_string, buf)
282+
}
283+
fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
284+
fwd_ref!(self, read_exact, buf)
285+
}
286+
}
287+
288+
impl<'a> io::Write for &'a TcpOrUnixStream {
289+
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
290+
fwd_ref!(self, write, buf)
291+
}
292+
fn write_vectored(&mut self, bufs: &[std::io::IoSlice<'_>]) -> std::io::Result<usize> {
293+
fwd_ref!(self, write_vectored, bufs)
294+
}
295+
fn flush(&mut self) -> std::io::Result<()> {
296+
fwd_ref!(self, flush)
297+
}
298+
fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
299+
fwd_ref!(self, write_all, buf)
300+
}
301+
fn write_fmt(&mut self, fmt: std::fmt::Arguments<'_>) -> std::io::Result<()> {
302+
fwd_ref!(self, write_fmt, fmt)
303+
}
304+
}
305+
306+
impl io::Write for TcpOrUnixStream {
307+
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
308+
fwd!(self, write, buf)
309+
}
310+
fn write_vectored(&mut self, bufs: &[std::io::IoSlice<'_>]) -> std::io::Result<usize> {
311+
fwd!(self, write_vectored, bufs)
312+
}
313+
fn flush(&mut self) -> std::io::Result<()> {
314+
fwd!(self, flush)
315+
}
316+
fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
317+
fwd!(self, write_all, buf)
318+
}
319+
fn write_fmt(&mut self, fmt: std::fmt::Arguments<'_>) -> std::io::Result<()> {
320+
fwd!(self, write_fmt, fmt)
321+
}
322+
}
323+
324+
#[cfg(unix)]
325+
impl std::os::fd::IntoRawFd for TcpOrUnixStream {
326+
fn into_raw_fd(self) -> std::os::fd::RawFd {
327+
fwd_move!(self, into_raw_fd)
328+
}
329+
}
330+
331+
#[cfg(windows)]
332+
impl std::os::windows::io::IntoRawSocket for TcpOrUnixStream {
333+
fn into_raw_socket(self) -> std::os::windows::io::RawSocket {
334+
fwd_move!(self, into_raw_fd)
335+
}
336+
}
337+
338+
impl Into<TcpOrUnixStream> for TcpStream {
339+
fn into(self) -> TcpOrUnixStream {
340+
TcpOrUnixStream::Tcp(self)
341+
}
342+
}
343+
#[cfg(unix)]
344+
impl Into<TcpOrUnixStream> for UnixStream {
345+
fn into(self) -> TcpOrUnixStream {
346+
TcpOrUnixStream::Unix(self)
347+
}
348+
}
349+
23350
/// A description of a connection target.
24351
#[derive(Debug, Clone)]
25352
pub enum TargetAddr {

0 commit comments

Comments
 (0)