diff --git a/src/lib.rs b/src/lib.rs index ea19a8f..429974c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,16 +10,357 @@ extern crate libc; extern crate winapi; use std::io; -use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs}; +use std::net::{ + Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, TcpStream, ToSocketAddrs, +}; +#[cfg(unix)] +use std::os::unix::net::{UnixStream, SocketAddr as UnixSocketAddr}; +#[cfg(any(target_os = "linux", target_os = "android"))] +use std::os::linux::net::SocketAddrExt; +use std::hash::{Hash, Hasher}; use std::vec; -pub use v4::{Socks4Stream, Socks4Listener}; -pub use v5::{Socks5Stream, Socks5Listener, Socks5Datagram}; +pub use v4::{Socks4Listener, Socks4Stream}; +pub use v5::{Socks5Datagram, Socks5Listener, Socks5Stream}; mod v4; mod v5; mod writev; +/// Either a [`SocketAddr`], or, under unix, [`UnixSocketAddr`] +/// +/// If `#[cfg(unix)]`, this can hold an internet socket address *or* a unix-domain socket address. +/// +/// Otherwise, this can only hold an internet socket address. +#[derive(Clone, Debug)] +pub enum SocketAddrOrUnixSocketAddr { + /// The internet address. + SocketAddr(SocketAddr), + /// The unix-domain address. + #[cfg(unix)] + UnixSocketAddr(UnixSocketAddr), +} + +impl PartialEq for SocketAddrOrUnixSocketAddr { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + #[cfg(unix)] + (Self::SocketAddr(_), Self::UnixSocketAddr(_)) => false, + #[cfg(unix)] + (Self::UnixSocketAddr(_), Self::SocketAddr(_)) => false, + (Self::SocketAddr(l), Self::SocketAddr(r)) => l == r, + #[cfg(unix)] + (Self::UnixSocketAddr(l), Self::UnixSocketAddr(r)) => { + if (l.is_unnamed() && r.is_unnamed()) || (l.as_pathname() == r.as_pathname()) { + return true; + } + #[cfg(any(target_os = "linux", target_os = "android"))] + if l.as_abstract_name() == r.as_abstract_name() { + return true; + } + false + }, + } + } +} + +impl Hash for SocketAddrOrUnixSocketAddr { + fn hash(&self, state: &mut H) { + match self { + Self::SocketAddr(a) => a.hash(state), + #[cfg(unix)] + Self::UnixSocketAddr(a) => { + a.is_unnamed().hash(state); + a.as_pathname().hash(state); + #[cfg(any(target_os = "linux", target_os = "android"))] + a.as_abstract_name().hash(state); + } + } + } +} + +impl Into for &SocketAddrOrUnixSocketAddr { + fn into(self) -> SocketAddrOrUnixSocketAddr { + self.clone() // no allocations, this struct is effectively Copy + } +} +impl Into for &mut SocketAddrOrUnixSocketAddr { + fn into(self) -> SocketAddrOrUnixSocketAddr { + self.clone() // no allocations, this struct is effectively Copy + } +} +impl Into for SocketAddr { + fn into(self) -> SocketAddrOrUnixSocketAddr { + SocketAddrOrUnixSocketAddr::SocketAddr(self) + } +} +impl Into for &SocketAddr { + fn into(self) -> SocketAddrOrUnixSocketAddr { + SocketAddrOrUnixSocketAddr::SocketAddr(self.clone()) + } +} +#[cfg(unix)] +impl Into for UnixSocketAddr { + fn into(self) -> SocketAddrOrUnixSocketAddr { + SocketAddrOrUnixSocketAddr::UnixSocketAddr(self) + } +} + +/// Either a [`TcpStream`], or, under unix, [`UnixStream`] +/// +/// If `#[cfg(unix)]`, this can hold an internet socket *or* a unix-domain socket. +/// +/// Otherwise, this can only hold an internet socket. +#[derive(Debug)] +pub enum TcpOrUnixStream { + /// The internet socket. + Tcp(TcpStream), + #[cfg(unix)] + /// The unix-domain socket. + Unix(UnixStream), +} + +macro_rules! fwd { + ($self:expr, $fun:ident) => { + match $self { + TcpOrUnixStream::Tcp(ref mut s) => s.$fun(), + #[cfg(unix)] + TcpOrUnixStream::Unix(ref mut s) => s.$fun(), + } + }; + ($self:expr, $fun:ident, $arg:expr) => { + match $self { + TcpOrUnixStream::Tcp(ref mut s) => s.$fun($arg), + #[cfg(unix)] + TcpOrUnixStream::Unix(ref mut s) => s.$fun($arg), + } + } +} + +macro_rules! fwd_ref { + ($self:expr, $fun:ident) => { + match $self { + TcpOrUnixStream::Tcp(s) => (&mut &*s).$fun(), + #[cfg(unix)] + TcpOrUnixStream::Unix(s) => (&mut &*s).$fun(), + } + }; + ($self:expr, $fun:ident, $arg:expr) => { + match $self { + TcpOrUnixStream::Tcp(s) => (&mut &*s).$fun($arg), + #[cfg(unix)] + TcpOrUnixStream::Unix(s) => (&mut &*s).$fun($arg), + } + } +} + +macro_rules! fwd_move { + ($self:expr, $fun:ident) => { + match $self { + TcpOrUnixStream::Tcp(s) => s.$fun(), + #[cfg(unix)] + TcpOrUnixStream::Unix(s) => s.$fun(), + } + }; + ($self:expr, $fun:ident, $arg:expr) => { + match $self { + TcpOrUnixStream::Tcp(s) => s.$fun($arg), + #[cfg(unix)] + TcpOrUnixStream::Unix(s) => s.$fun($arg), + } + } +} + +impl TcpOrUnixStream { + /// [`TcpStream::connect`] or [`UnixStream::connect_addr`] + pub fn connect>(addr: T) -> std::io::Result { + match addr.into() { + SocketAddrOrUnixSocketAddr::SocketAddr(s) => TcpStream::connect(s).map(TcpOrUnixStream::Tcp), + #[cfg(unix)] + SocketAddrOrUnixSocketAddr::UnixSocketAddr(s) => UnixStream::connect_addr(&s).map(TcpOrUnixStream::Unix), + } + } + + /// [`TcpStream::connect`] with [`ToSocketAddrs`] + pub fn connect_tsa(addr: T) -> std::io::Result { + TcpStream::connect(addr).map(TcpOrUnixStream::Tcp) + } + + /// [`TcpStream::local_addr`] or [`UnixStream::local_addr`] + pub fn local_addr(&self) -> std::io::Result { + match self { + TcpOrUnixStream::Tcp(s) => s.local_addr().map(SocketAddrOrUnixSocketAddr::SocketAddr), + #[cfg(unix)] + TcpOrUnixStream::Unix(s) => s.local_addr().map(SocketAddrOrUnixSocketAddr::UnixSocketAddr), + } + } + + /// [`TcpStream::peer_addr`] or [`UnixStream::peer_addr`] + pub fn peer_addr(&self) -> std::io::Result { + match self { + TcpOrUnixStream::Tcp(s) => s.peer_addr().map(SocketAddrOrUnixSocketAddr::SocketAddr), + #[cfg(unix)] + TcpOrUnixStream::Unix(s) => s.peer_addr().map(SocketAddrOrUnixSocketAddr::UnixSocketAddr), + } + } + + /// [`TcpStream::read_timeout`] or [`UnixStream::read_timeout`] + pub fn read_timeout(&self) -> std::io::Result> { + fwd_ref!(self, read_timeout) + } + + /// [`TcpStream::set_nonblocking`] or [`UnixStream::set_nonblocking`] + pub fn set_nonblocking(&self, nonblocking: bool) -> std::io::Result<()> { + fwd_ref!(self, set_nonblocking, nonblocking) + } + + /// [`TcpStream::set_read_timeout`] or [`UnixStream::set_read_timeout`] + pub fn set_read_timeout(&self, timeout: Option) -> std::io::Result<()> { + fwd_ref!(self, set_read_timeout, timeout) + } + + /// [`TcpStream::set_write_timeout`] or [`UnixStream::set_write_timeout`] + pub fn set_write_timeout(&self, timeout: Option) -> std::io::Result<()> { + fwd_ref!(self, set_write_timeout, timeout) + } + + /// [`TcpStream::shutdown`] or [`UnixStream::shutdown`] + pub fn shutdown(&self, how: std::net::Shutdown) -> std::io::Result<()> { + fwd_ref!(self, shutdown, how) + } + + /// [`TcpStream::take_error`] or [`UnixStream::take_error`] + pub fn take_error(&self) -> std::io::Result> { + fwd_ref!(self, take_error) + } + + /// [`TcpStream::try_clone`] or [`UnixStream::try_clone`] + pub fn try_clone(&self) -> std::io::Result { + match self { + TcpOrUnixStream::Tcp(s) => s.try_clone().map(TcpOrUnixStream::Tcp), + #[cfg(unix)] + TcpOrUnixStream::Unix(s) => s.try_clone().map(TcpOrUnixStream::Unix), + } + } + + /// [`TcpStream::write_timeout`] or [`UnixStream::write_timeout`] + pub fn write_timeout(&self) -> std::io::Result> { + fwd_ref!(self, write_timeout) + } +} + +impl io::Read for TcpOrUnixStream { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + fwd!(self, read, buf) + } + fn read_vectored(&mut self, bufs: &mut [std::io::IoSliceMut<'_>]) -> std::io::Result { + fwd!(self, read_vectored, bufs) + } + fn read_to_end(&mut self, buf: &mut Vec) -> std::io::Result { + fwd!(self, read_to_end, buf) + } + fn read_to_string(&mut self, buf: &mut String) -> std::io::Result { + fwd!(self, read_to_string, buf) + } + fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> { + fwd!(self, read_exact, buf) + } +} + +impl<'a> io::Read for &'a TcpOrUnixStream { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + fwd_ref!(self, read, buf) + } + fn read_vectored(&mut self, bufs: &mut [std::io::IoSliceMut<'_>]) -> std::io::Result { + fwd_ref!(self, read_vectored, bufs) + } + fn read_to_end(&mut self, buf: &mut Vec) -> std::io::Result { + fwd_ref!(self, read_to_end, buf) + } + fn read_to_string(&mut self, buf: &mut String) -> std::io::Result { + fwd_ref!(self, read_to_string, buf) + } + fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> { + fwd_ref!(self, read_exact, buf) + } +} + +impl<'a> io::Write for &'a TcpOrUnixStream { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + fwd_ref!(self, write, buf) + } + fn write_vectored(&mut self, bufs: &[std::io::IoSlice<'_>]) -> std::io::Result { + fwd_ref!(self, write_vectored, bufs) + } + fn flush(&mut self) -> std::io::Result<()> { + fwd_ref!(self, flush) + } + fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> { + fwd_ref!(self, write_all, buf) + } + fn write_fmt(&mut self, fmt: std::fmt::Arguments<'_>) -> std::io::Result<()> { + fwd_ref!(self, write_fmt, fmt) + } +} + +impl io::Write for TcpOrUnixStream { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + fwd!(self, write, buf) + } + fn write_vectored(&mut self, bufs: &[std::io::IoSlice<'_>]) -> std::io::Result { + fwd!(self, write_vectored, bufs) + } + fn flush(&mut self) -> std::io::Result<()> { + fwd!(self, flush) + } + fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> { + fwd!(self, write_all, buf) + } + fn write_fmt(&mut self, fmt: std::fmt::Arguments<'_>) -> std::io::Result<()> { + fwd!(self, write_fmt, fmt) + } +} + +#[cfg(unix)] +impl std::os::fd::AsRawFd for TcpOrUnixStream { + fn as_raw_fd(&self) -> std::os::fd::RawFd { + fwd_ref!(self, as_raw_fd) + } +} + +#[cfg(windows)] +impl std::os::windows::io::AsRawSocket for TcpOrUnixStream { + fn as_raw_socket(&self) -> std::os::windows::io::RawSocket { + fwd_ref!(self, as_raw_socket) + } +} + +#[cfg(unix)] +impl std::os::fd::IntoRawFd for TcpOrUnixStream { + fn into_raw_fd(self) -> std::os::fd::RawFd { + fwd_move!(self, into_raw_fd) + } +} + +#[cfg(windows)] +impl std::os::windows::io::IntoRawSocket for TcpOrUnixStream { + fn into_raw_socket(self) -> std::os::windows::io::RawSocket { + fwd_move!(self, into_raw_socket) + } +} + +impl Into for TcpStream { + fn into(self) -> TcpOrUnixStream { + TcpOrUnixStream::Tcp(self) + } +} +#[cfg(unix)] +impl Into for UnixStream { + fn into(self) -> TcpOrUnixStream { + TcpOrUnixStream::Unix(self) + } +} + /// A description of a connection target. #[derive(Debug, Clone)] pub enum TargetAddr { diff --git a/src/v5.rs b/src/v5.rs index 4de4ac2..7a7d5f8 100644 --- a/src/v5.rs +++ b/src/v5.rs @@ -1,11 +1,11 @@ use byteorder::{ReadBytesExt, WriteBytesExt, BigEndian}; use std::cmp; use std::io::{self, Read, Write}; -use std::net::{SocketAddr, ToSocketAddrs, SocketAddrV4, SocketAddrV6, TcpStream, Ipv4Addr, +use std::net::{SocketAddr, ToSocketAddrs, SocketAddrV4, SocketAddrV6, Ipv4Addr, Ipv6Addr, UdpSocket}; use std::ptr; -use {ToTargetAddr, TargetAddr}; +use {TcpOrUnixStream, SocketAddrOrUnixSocketAddr, ToTargetAddr, TargetAddr}; use writev::WritevExt; const MAX_ADDR_LEN: usize = 260; @@ -37,7 +37,7 @@ fn read_addr(socket: &mut R) -> io::Result { } } -fn read_response(socket: &mut TcpStream) -> io::Result { +fn read_response(socket: &mut TcpOrUnixStream) -> io::Result { if socket.read_u8()? != 5 { return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid response version")); @@ -117,7 +117,7 @@ impl<'a> Authentication<'a> { /// A SOCKS5 client. #[derive(Debug)] pub struct Socks5Stream { - socket: TcpStream, + socket: TcpOrUnixStream, proxy_addr: TargetAddr, } @@ -127,7 +127,7 @@ impl Socks5Stream { where T: ToSocketAddrs, U: ToTargetAddr { - Self::connect_raw(1, proxy, target, &Authentication::None) + Self::connect_raw(1, TcpOrUnixStream::connect_tsa(proxy)?, target, &Authentication::None) } /// Connects to a target server through a SOCKS5 proxy using given @@ -137,15 +137,30 @@ impl Socks5Stream { U: ToTargetAddr { let auth = Authentication::Password { username, password }; - Self::connect_raw(1, proxy, target, &auth) + Self::connect_raw(1, TcpOrUnixStream::connect_tsa(proxy)?, target, &auth) } - fn connect_raw(command: u8, proxy: T, target: U, auth: &Authentication) -> io::Result - where T: ToSocketAddrs, + /// Connects to a target server through a SOCKS5 proxy. + pub fn connect_either(proxy: T, target: U) -> io::Result + where T: Into, + U: ToTargetAddr + { + Self::connect_raw(1, TcpOrUnixStream::connect(proxy)?, target, &Authentication::None) + } + + /// Connects to a target server through a SOCKS5 proxy using given + /// username and password. + pub fn connect_either_with_password(proxy: T, target: U, username: &str, password: &str) -> io::Result + where T: Into, U: ToTargetAddr { - let mut socket = TcpStream::connect(proxy)?; + let auth = Authentication::Password { username, password }; + Self::connect_raw(1, TcpOrUnixStream::connect(proxy)?, target, &auth) + } + fn connect_raw(command: u8, mut socket: TcpOrUnixStream, target: U, auth: &Authentication) -> io::Result + where U: ToTargetAddr + { let target = target.to_target_addr()?; let packet_len = if auth.is_no_auth() { 3 } else { 4 }; @@ -196,7 +211,7 @@ impl Socks5Stream { }) } - fn password_authentication(socket: &mut TcpStream, username: &str, password: &str) -> io::Result<()> { + fn password_authentication(socket: &mut TcpOrUnixStream, username: &str, password: &str) -> io::Result<()> { if username.len() < 1 || username.len() > 255 { return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid username")) }; @@ -231,18 +246,18 @@ impl Socks5Stream { &self.proxy_addr } - /// Returns a shared reference to the inner `TcpStream`. - pub fn get_ref(&self) -> &TcpStream { + /// Returns a shared reference to the inner `TcpOrUnixStream`. + pub fn get_ref(&self) -> &TcpOrUnixStream { &self.socket } - /// Returns a mutable reference to the inner `TcpStream`. - pub fn get_mut(&mut self) -> &mut TcpStream { + /// Returns a mutable reference to the inner `TcpOrUnixStream`. + pub fn get_mut(&mut self) -> &mut TcpOrUnixStream { &mut self.socket } - /// Consumes the `Socks5Stream`, returning the inner `TcpStream`. - pub fn into_inner(self) -> TcpStream { + /// Consumes the `Socks5Stream`, returning the inner `TcpOrUnixStream`. + pub fn into_inner(self) -> TcpOrUnixStream { self.socket } } @@ -292,7 +307,7 @@ impl Socks5Listener { where T: ToSocketAddrs, U: ToTargetAddr { - Socks5Stream::connect_raw(2, proxy, target, &Authentication::None).map(Socks5Listener) + Socks5Stream::connect_raw(2, TcpOrUnixStream::connect_tsa(proxy)?, target, &Authentication::None).map(Socks5Listener) } /// Initiates a BIND request to the specified proxy using given username /// and password. @@ -304,7 +319,25 @@ impl Socks5Listener { U: ToTargetAddr { let auth = Authentication::Password { username, password }; - Socks5Stream::connect_raw(2, proxy, target, &auth).map(Socks5Listener) + Socks5Stream::connect_raw(2, TcpOrUnixStream::connect_tsa(proxy)?, target, &auth).map(Socks5Listener) + } + + /// Connects to a target server through a SOCKS5 proxy. + pub fn bind_either(proxy: T, target: U) -> io::Result + where T: Into, + U: ToTargetAddr + { + Socks5Stream::connect_raw(2, TcpOrUnixStream::connect(proxy)?, target, &Authentication::None).map(Socks5Listener) + } + + /// Connects to a target server through a SOCKS5 proxy using given + /// username and password. + pub fn bind_either_with_password(proxy: T, target: U, username: &str, password: &str) -> io::Result + where T: Into, + U: ToTargetAddr + { + let auth = Authentication::Password { username, password }; + Socks5Stream::connect_raw(2, TcpOrUnixStream::connect(proxy)?, target, &auth).map(Socks5Listener) } /// The address of the proxy-side TCP listener. @@ -360,7 +393,7 @@ impl Socks5Datagram { // we don't know what our IP is from the perspective of the proxy, so // don't try to pass `addr` in here. let dst = TargetAddr::Ip(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0))); - let stream = Socks5Stream::connect_raw(3, proxy, dst, auth)?; + let stream = Socks5Stream::connect_raw(3, TcpOrUnixStream::connect_tsa(proxy)?, dst, auth)?; let socket = UdpSocket::bind(addr)?; socket.connect(&stream.proxy_addr)?;