@@ -10,16 +10,312 @@ extern crate libc;
10
10
extern crate winapi;
11
11
12
12
use std:: io;
13
- use std:: net:: { Ipv4Addr , Ipv6Addr , SocketAddr , SocketAddrV4 , SocketAddrV6 , ToSocketAddrs } ;
13
+ use std:: net:: {
14
+ Ipv4Addr , Ipv6Addr , SocketAddr , SocketAddrV4 , SocketAddrV6 , TcpStream , ToSocketAddrs ,
15
+ } ;
16
+ #[ cfg( unix) ]
17
+ use std:: os:: unix:: net:: { UnixStream , SocketAddr as UnixSocketAddr } ;
18
+ #[ cfg( any( target_os = "linux" , target_os = "android" ) ) ]
19
+ use std:: os:: linux:: net:: SocketAddrExt ;
20
+ use std:: hash:: { Hash , Hasher } ;
14
21
use std:: vec;
15
22
16
- pub use v4:: { Socks4Stream , Socks4Listener } ;
17
- pub use v5:: { Socks5Stream , Socks5Listener , Socks5Datagram } ;
23
+ pub use v4:: { Socks4Listener , Socks4Stream } ;
24
+ pub use v5:: { Socks5Datagram , Socks5Listener , Socks5Stream } ;
18
25
19
26
mod v4;
20
27
mod v5;
21
28
mod writev;
22
29
30
+ /// Either a [`SocketAddr`], or, under unix, [`UnixSocketAddr`]
31
+ ///
32
+ /// If `#[cfg(unix)]`, this can hold an internet socket address *or* a unix-domain socket address.
33
+ ///
34
+ /// Otherwise, this can only hold an internet socket address.
35
+ #[ derive( Clone , Debug ) ]
36
+ pub enum SocketAddrOrUnixSocketAddr {
37
+ /// The internet address.
38
+ SocketAddr ( SocketAddr ) ,
39
+ /// The unix-domain address.
40
+ #[ cfg( unix) ]
41
+ UnixSocketAddr ( UnixSocketAddr ) ,
42
+ }
43
+
44
+ impl PartialEq for SocketAddrOrUnixSocketAddr {
45
+ fn eq ( & self , other : & Self ) -> bool {
46
+ match ( self , other) {
47
+ #[ cfg( unix) ]
48
+ ( Self :: SocketAddr ( _) , Self :: UnixSocketAddr ( _) ) => false ,
49
+ #[ cfg( unix) ]
50
+ ( Self :: UnixSocketAddr ( _) , Self :: SocketAddr ( _) ) => false ,
51
+ ( Self :: SocketAddr ( l) , Self :: SocketAddr ( r) ) => l == r,
52
+ #[ cfg( unix) ]
53
+ ( Self :: UnixSocketAddr ( l) , Self :: UnixSocketAddr ( r) ) => {
54
+ if ( l. is_unnamed ( ) && r. is_unnamed ( ) ) || ( l. as_pathname ( ) == r. as_pathname ( ) ) {
55
+ return true ;
56
+ }
57
+ #[ cfg( any( target_os = "linux" , target_os = "android" ) ) ]
58
+ if l. as_abstract_name ( ) == r. as_abstract_name ( ) {
59
+ return true ;
60
+ }
61
+ false
62
+ } ,
63
+ }
64
+ }
65
+ }
66
+
67
+ impl Hash for SocketAddrOrUnixSocketAddr {
68
+ fn hash < H : Hasher > ( & self , state : & mut H ) {
69
+ match self {
70
+ Self :: SocketAddr ( a) => a. hash ( state) ,
71
+ #[ cfg( unix) ]
72
+ Self :: UnixSocketAddr ( a) => {
73
+ a. is_unnamed ( ) . hash ( state) ;
74
+ a. as_pathname ( ) . hash ( state) ;
75
+ #[ cfg( any( target_os = "linux" , target_os = "android" ) ) ]
76
+ a. as_abstract_name ( ) . hash ( state) ;
77
+ }
78
+ }
79
+ }
80
+ }
81
+
82
+ impl Into < SocketAddrOrUnixSocketAddr > for & SocketAddrOrUnixSocketAddr {
83
+ fn into ( self ) -> SocketAddrOrUnixSocketAddr {
84
+ self . clone ( ) // no allocations, this struct is effectively Copy
85
+ }
86
+ }
87
+ impl Into < SocketAddrOrUnixSocketAddr > for & mut SocketAddrOrUnixSocketAddr {
88
+ fn into ( self ) -> SocketAddrOrUnixSocketAddr {
89
+ self . clone ( ) // no allocations, this struct is effectively Copy
90
+ }
91
+ }
92
+ impl Into < SocketAddrOrUnixSocketAddr > for SocketAddr {
93
+ fn into ( self ) -> SocketAddrOrUnixSocketAddr {
94
+ SocketAddrOrUnixSocketAddr :: SocketAddr ( self )
95
+ }
96
+ }
97
+ impl Into < SocketAddrOrUnixSocketAddr > for & SocketAddr {
98
+ fn into ( self ) -> SocketAddrOrUnixSocketAddr {
99
+ SocketAddrOrUnixSocketAddr :: SocketAddr ( self . clone ( ) )
100
+ }
101
+ }
102
+ #[ cfg( unix) ]
103
+ impl Into < SocketAddrOrUnixSocketAddr > for UnixSocketAddr {
104
+ fn into ( self ) -> SocketAddrOrUnixSocketAddr {
105
+ SocketAddrOrUnixSocketAddr :: UnixSocketAddr ( self )
106
+ }
107
+ }
108
+
109
+ /// Either a [`TcpStream`], or, under unix, [`UnixStream`]
110
+ ///
111
+ /// If `#[cfg(unix)]`, this can hold an internet socket *or* a unix-domain socket.
112
+ ///
113
+ /// Otherwise, this can only hold an internet socket.
114
+ #[ derive( Debug ) ]
115
+ pub enum TcpOrUnixStream {
116
+ /// The internet socket.
117
+ Tcp ( TcpStream ) ,
118
+ #[ cfg( unix) ]
119
+ /// The unix-domain socket.
120
+ Unix ( UnixStream ) ,
121
+ }
122
+
123
+ macro_rules! fwd {
124
+ ( $self: expr, $fun: ident) => {
125
+ match $self {
126
+ TcpOrUnixStream :: Tcp ( ref mut s) => s. $fun( ) ,
127
+ #[ cfg( unix) ]
128
+ TcpOrUnixStream :: Unix ( ref mut s) => s. $fun( ) ,
129
+ }
130
+ } ;
131
+ ( $self: expr, $fun: ident, $arg: expr) => {
132
+ match $self {
133
+ TcpOrUnixStream :: Tcp ( ref mut s) => s. $fun( $arg) ,
134
+ #[ cfg( unix) ]
135
+ TcpOrUnixStream :: Unix ( ref mut s) => s. $fun( $arg) ,
136
+ }
137
+ }
138
+ }
139
+
140
+ macro_rules! fwd_ref {
141
+ ( $self: expr, $fun: ident) => {
142
+ match $self {
143
+ TcpOrUnixStream :: Tcp ( s) => ( & mut & * s) . $fun( ) ,
144
+ #[ cfg( unix) ]
145
+ TcpOrUnixStream :: Unix ( s) => ( & mut & * s) . $fun( ) ,
146
+ }
147
+ } ;
148
+ ( $self: expr, $fun: ident, $arg: expr) => {
149
+ match $self {
150
+ TcpOrUnixStream :: Tcp ( s) => ( & mut & * s) . $fun( $arg) ,
151
+ #[ cfg( unix) ]
152
+ TcpOrUnixStream :: Unix ( s) => ( & mut & * s) . $fun( $arg) ,
153
+ }
154
+ }
155
+ }
156
+
157
+ impl TcpOrUnixStream {
158
+ /// [`TcpStream::connect`] or [`UnixStream::connect_addr`]
159
+ pub fn connect < T : Into < SocketAddrOrUnixSocketAddr > > ( addr : T ) -> std:: io:: Result < TcpOrUnixStream > {
160
+ match addr. into ( ) {
161
+ SocketAddrOrUnixSocketAddr :: SocketAddr ( s) => TcpStream :: connect ( s) . map ( TcpOrUnixStream :: Tcp ) ,
162
+ #[ cfg( unix) ]
163
+ SocketAddrOrUnixSocketAddr :: UnixSocketAddr ( s) => UnixStream :: connect_addr ( & s) . map ( TcpOrUnixStream :: Unix ) ,
164
+ }
165
+ }
166
+
167
+ /// [`TcpStream::connect`] with [`ToSocketAddrs`]
168
+ pub fn connect_tsa < T : ToSocketAddrs > ( addr : T ) -> std:: io:: Result < TcpOrUnixStream > {
169
+ TcpStream :: connect ( addr) . map ( TcpOrUnixStream :: Tcp )
170
+ }
171
+
172
+ /// [`TcpStream::local_addr`] or [`UnixStream::local_addr`]
173
+ pub fn local_addr ( & self ) -> std:: io:: Result < SocketAddrOrUnixSocketAddr > {
174
+ match self {
175
+ TcpOrUnixStream :: Tcp ( s) => s. local_addr ( ) . map ( SocketAddrOrUnixSocketAddr :: SocketAddr ) ,
176
+ #[ cfg( unix) ]
177
+ TcpOrUnixStream :: Unix ( s) => s. local_addr ( ) . map ( SocketAddrOrUnixSocketAddr :: UnixSocketAddr ) ,
178
+ }
179
+ }
180
+
181
+ /// [`TcpStream::peer_addr`] or [`UnixStream::peer_addr`]
182
+ pub fn peer_addr ( & self ) -> std:: io:: Result < SocketAddrOrUnixSocketAddr > {
183
+ match self {
184
+ TcpOrUnixStream :: Tcp ( s) => s. peer_addr ( ) . map ( SocketAddrOrUnixSocketAddr :: SocketAddr ) ,
185
+ #[ cfg( unix) ]
186
+ TcpOrUnixStream :: Unix ( s) => s. peer_addr ( ) . map ( SocketAddrOrUnixSocketAddr :: UnixSocketAddr ) ,
187
+ }
188
+ }
189
+
190
+ /// [`TcpStream::read_timeout`] or [`UnixStream::read_timeout`]
191
+ pub fn read_timeout ( & self ) -> std:: io:: Result < Option < std:: time:: Duration > > {
192
+ fwd_ref ! ( self , read_timeout)
193
+ }
194
+
195
+ /// [`TcpStream::set_nonblocking`] or [`UnixStream::set_nonblocking`]
196
+ pub fn set_nonblocking ( & self , nonblocking : bool ) -> std:: io:: Result < ( ) > {
197
+ fwd_ref ! ( self , set_nonblocking, nonblocking)
198
+ }
199
+
200
+ /// [`TcpStream::set_read_timeout`] or [`UnixStream::set_read_timeout`]
201
+ pub fn set_read_timeout ( & self , timeout : Option < std:: time:: Duration > ) -> std:: io:: Result < ( ) > {
202
+ fwd_ref ! ( self , set_read_timeout, timeout)
203
+ }
204
+
205
+ /// [`TcpStream::set_write_timeout`] or [`UnixStream::set_write_timeout`]
206
+ pub fn set_write_timeout ( & self , timeout : Option < std:: time:: Duration > ) -> std:: io:: Result < ( ) > {
207
+ fwd_ref ! ( self , set_write_timeout, timeout)
208
+ }
209
+
210
+ /// [`TcpStream::shutdown`] or [`UnixStream::shutdown`]
211
+ pub fn shutdown ( & self , how : std:: net:: Shutdown ) -> std:: io:: Result < ( ) > {
212
+ fwd_ref ! ( self , shutdown, how)
213
+ }
214
+
215
+ /// [`TcpStream::take_error`] or [`UnixStream::take_error`]
216
+ pub fn take_error ( & self ) -> std:: io:: Result < Option < std:: io:: Error > > {
217
+ fwd_ref ! ( self , take_error)
218
+ }
219
+
220
+ /// [`TcpStream::try_clone`] or [`UnixStream::try_clone`]
221
+ pub fn try_clone ( & self ) -> std:: io:: Result < TcpOrUnixStream > {
222
+ match self {
223
+ TcpOrUnixStream :: Tcp ( s) => s. try_clone ( ) . map ( TcpOrUnixStream :: Tcp ) ,
224
+ #[ cfg( unix) ]
225
+ TcpOrUnixStream :: Unix ( s) => s. try_clone ( ) . map ( TcpOrUnixStream :: Unix ) ,
226
+ }
227
+ }
228
+
229
+ /// [`TcpStream::write_timeout`] or [`UnixStream::write_timeout`]
230
+ pub fn write_timeout ( & self ) -> std:: io:: Result < Option < std:: time:: Duration > > {
231
+ fwd_ref ! ( self , write_timeout)
232
+ }
233
+ }
234
+
235
+ impl io:: Read for TcpOrUnixStream {
236
+ fn read ( & mut self , buf : & mut [ u8 ] ) -> std:: io:: Result < usize > {
237
+ fwd ! ( self , read, buf)
238
+ }
239
+ fn read_vectored ( & mut self , bufs : & mut [ std:: io:: IoSliceMut < ' _ > ] ) -> std:: io:: Result < usize > {
240
+ fwd ! ( self , read_vectored, bufs)
241
+ }
242
+ fn read_to_end ( & mut self , buf : & mut Vec < u8 > ) -> std:: io:: Result < usize > {
243
+ fwd ! ( self , read_to_end, buf)
244
+ }
245
+ fn read_to_string ( & mut self , buf : & mut String ) -> std:: io:: Result < usize > {
246
+ fwd ! ( self , read_to_string, buf)
247
+ }
248
+ fn read_exact ( & mut self , buf : & mut [ u8 ] ) -> std:: io:: Result < ( ) > {
249
+ fwd ! ( self , read_exact, buf)
250
+ }
251
+ }
252
+
253
+ impl < ' a > io:: Read for & ' a TcpOrUnixStream {
254
+ fn read ( & mut self , buf : & mut [ u8 ] ) -> std:: io:: Result < usize > {
255
+ fwd_ref ! ( self , read, buf)
256
+ }
257
+ fn read_vectored ( & mut self , bufs : & mut [ std:: io:: IoSliceMut < ' _ > ] ) -> std:: io:: Result < usize > {
258
+ fwd_ref ! ( self , read_vectored, bufs)
259
+ }
260
+ fn read_to_end ( & mut self , buf : & mut Vec < u8 > ) -> std:: io:: Result < usize > {
261
+ fwd_ref ! ( self , read_to_end, buf)
262
+ }
263
+ fn read_to_string ( & mut self , buf : & mut String ) -> std:: io:: Result < usize > {
264
+ fwd_ref ! ( self , read_to_string, buf)
265
+ }
266
+ fn read_exact ( & mut self , buf : & mut [ u8 ] ) -> std:: io:: Result < ( ) > {
267
+ fwd_ref ! ( self , read_exact, buf)
268
+ }
269
+ }
270
+
271
+ impl < ' a > io:: Write for & ' a TcpOrUnixStream {
272
+ fn write ( & mut self , buf : & [ u8 ] ) -> std:: io:: Result < usize > {
273
+ fwd_ref ! ( self , write, buf)
274
+ }
275
+ fn write_vectored ( & mut self , bufs : & [ std:: io:: IoSlice < ' _ > ] ) -> std:: io:: Result < usize > {
276
+ fwd_ref ! ( self , write_vectored, bufs)
277
+ }
278
+ fn flush ( & mut self ) -> std:: io:: Result < ( ) > {
279
+ fwd_ref ! ( self , flush)
280
+ }
281
+ fn write_all ( & mut self , buf : & [ u8 ] ) -> std:: io:: Result < ( ) > {
282
+ fwd_ref ! ( self , write_all, buf)
283
+ }
284
+ fn write_fmt ( & mut self , fmt : std:: fmt:: Arguments < ' _ > ) -> std:: io:: Result < ( ) > {
285
+ fwd_ref ! ( self , write_fmt, fmt)
286
+ }
287
+ }
288
+
289
+ impl io:: Write for TcpOrUnixStream {
290
+ fn write ( & mut self , buf : & [ u8 ] ) -> std:: io:: Result < usize > {
291
+ fwd ! ( self , write, buf)
292
+ }
293
+ fn write_vectored ( & mut self , bufs : & [ std:: io:: IoSlice < ' _ > ] ) -> std:: io:: Result < usize > {
294
+ fwd ! ( self , write_vectored, bufs)
295
+ }
296
+ fn flush ( & mut self ) -> std:: io:: Result < ( ) > {
297
+ fwd ! ( self , flush)
298
+ }
299
+ fn write_all ( & mut self , buf : & [ u8 ] ) -> std:: io:: Result < ( ) > {
300
+ fwd ! ( self , write_all, buf)
301
+ }
302
+ fn write_fmt ( & mut self , fmt : std:: fmt:: Arguments < ' _ > ) -> std:: io:: Result < ( ) > {
303
+ fwd ! ( self , write_fmt, fmt)
304
+ }
305
+ }
306
+
307
+ impl Into < TcpOrUnixStream > for TcpStream {
308
+ fn into ( self ) -> TcpOrUnixStream {
309
+ TcpOrUnixStream :: Tcp ( self )
310
+ }
311
+ }
312
+ #[ cfg( unix) ]
313
+ impl Into < TcpOrUnixStream > for UnixStream {
314
+ fn into ( self ) -> TcpOrUnixStream {
315
+ TcpOrUnixStream :: Unix ( self )
316
+ }
317
+ }
318
+
23
319
/// A description of a connection target.
24
320
#[ derive( Debug , Clone ) ]
25
321
pub enum TargetAddr {
0 commit comments