@@ -10,16 +10,343 @@ extern crate libc;
10
10
extern crate winapi;
11
11
12
12
use std:: io;
13
- use std:: net:: { Ipv4Addr , Ipv6Addr , SocketAddr , SocketAddrV4 , SocketAddrV6 , ToSocketAddrs } ;
13
+ use std:: net:: {
14
+ Ipv4Addr , Ipv6Addr , SocketAddr , SocketAddrV4 , SocketAddrV6 , TcpStream , ToSocketAddrs ,
15
+ } ;
16
+ #[ cfg( unix) ]
17
+ use std:: os:: unix:: net:: { UnixStream , SocketAddr as UnixSocketAddr } ;
18
+ #[ cfg( any( target_os = "linux" , target_os = "android" ) ) ]
19
+ use std:: os:: linux:: net:: SocketAddrExt ;
20
+ use std:: hash:: { Hash , Hasher } ;
14
21
use std:: vec;
15
22
16
- pub use v4:: { Socks4Stream , Socks4Listener } ;
17
- pub use v5:: { Socks5Stream , Socks5Listener , Socks5Datagram } ;
23
+ pub use v4:: { Socks4Listener , Socks4Stream } ;
24
+ pub use v5:: { Socks5Datagram , Socks5Listener , Socks5Stream } ;
18
25
19
26
mod v4;
20
27
mod v5;
21
28
mod writev;
22
29
30
+ /// Either a [`SocketAddr`], or, under unix, [`UnixSocketAddr`]
31
+ ///
32
+ /// If `#[cfg(unix)]`, this can hold an internet socket address *or* a unix-domain socket address.
33
+ ///
34
+ /// Otherwise, this can only hold an internet socket address.
35
+ #[ derive( Clone , Debug ) ]
36
+ pub enum SocketAddrOrUnixSocketAddr {
37
+ /// The internet address.
38
+ SocketAddr ( SocketAddr ) ,
39
+ /// The unix-domain address.
40
+ #[ cfg( unix) ]
41
+ UnixSocketAddr ( UnixSocketAddr ) ,
42
+ }
43
+
44
+ impl PartialEq for SocketAddrOrUnixSocketAddr {
45
+ fn eq ( & self , other : & Self ) -> bool {
46
+ match ( self , other) {
47
+ #[ cfg( unix) ]
48
+ ( Self :: SocketAddr ( _) , Self :: UnixSocketAddr ( _) ) => false ,
49
+ #[ cfg( unix) ]
50
+ ( Self :: UnixSocketAddr ( _) , Self :: SocketAddr ( _) ) => false ,
51
+ ( Self :: SocketAddr ( l) , Self :: SocketAddr ( r) ) => l == r,
52
+ #[ cfg( unix) ]
53
+ ( Self :: UnixSocketAddr ( l) , Self :: UnixSocketAddr ( r) ) => {
54
+ if ( l. is_unnamed ( ) && r. is_unnamed ( ) ) || ( l. as_pathname ( ) == r. as_pathname ( ) ) {
55
+ return true ;
56
+ }
57
+ #[ cfg( any( target_os = "linux" , target_os = "android" ) ) ]
58
+ if l. as_abstract_name ( ) == r. as_abstract_name ( ) {
59
+ return true ;
60
+ }
61
+ false
62
+ } ,
63
+ }
64
+ }
65
+ }
66
+
67
+ impl Hash for SocketAddrOrUnixSocketAddr {
68
+ fn hash < H : Hasher > ( & self , state : & mut H ) {
69
+ match self {
70
+ Self :: SocketAddr ( a) => a. hash ( state) ,
71
+ #[ cfg( unix) ]
72
+ Self :: UnixSocketAddr ( a) => {
73
+ a. is_unnamed ( ) . hash ( state) ;
74
+ a. as_pathname ( ) . hash ( state) ;
75
+ #[ cfg( any( target_os = "linux" , target_os = "android" ) ) ]
76
+ a. as_abstract_name ( ) . hash ( state) ;
77
+ }
78
+ }
79
+ }
80
+ }
81
+
82
+ impl Into < SocketAddrOrUnixSocketAddr > for & SocketAddrOrUnixSocketAddr {
83
+ fn into ( self ) -> SocketAddrOrUnixSocketAddr {
84
+ self . clone ( ) // no allocations, this struct is effectively Copy
85
+ }
86
+ }
87
+ impl Into < SocketAddrOrUnixSocketAddr > for & mut SocketAddrOrUnixSocketAddr {
88
+ fn into ( self ) -> SocketAddrOrUnixSocketAddr {
89
+ self . clone ( ) // no allocations, this struct is effectively Copy
90
+ }
91
+ }
92
+ impl Into < SocketAddrOrUnixSocketAddr > for SocketAddr {
93
+ fn into ( self ) -> SocketAddrOrUnixSocketAddr {
94
+ SocketAddrOrUnixSocketAddr :: SocketAddr ( self )
95
+ }
96
+ }
97
+ impl Into < SocketAddrOrUnixSocketAddr > for & SocketAddr {
98
+ fn into ( self ) -> SocketAddrOrUnixSocketAddr {
99
+ SocketAddrOrUnixSocketAddr :: SocketAddr ( self . clone ( ) )
100
+ }
101
+ }
102
+ #[ cfg( unix) ]
103
+ impl Into < SocketAddrOrUnixSocketAddr > for UnixSocketAddr {
104
+ fn into ( self ) -> SocketAddrOrUnixSocketAddr {
105
+ SocketAddrOrUnixSocketAddr :: UnixSocketAddr ( self )
106
+ }
107
+ }
108
+
109
+ /// Either a [`TcpStream`], or, under unix, [`UnixStream`]
110
+ ///
111
+ /// If `#[cfg(unix)]`, this can hold an internet socket *or* a unix-domain socket.
112
+ ///
113
+ /// Otherwise, this can only hold an internet socket.
114
+ #[ derive( Debug ) ]
115
+ pub enum TcpOrUnixStream {
116
+ /// The internet socket.
117
+ Tcp ( TcpStream ) ,
118
+ #[ cfg( unix) ]
119
+ /// The unix-domain socket.
120
+ Unix ( UnixStream ) ,
121
+ }
122
+
123
+ macro_rules! fwd {
124
+ ( $self: expr, $fun: ident) => {
125
+ match $self {
126
+ TcpOrUnixStream :: Tcp ( ref mut s) => s. $fun( ) ,
127
+ #[ cfg( unix) ]
128
+ TcpOrUnixStream :: Unix ( ref mut s) => s. $fun( ) ,
129
+ }
130
+ } ;
131
+ ( $self: expr, $fun: ident, $arg: expr) => {
132
+ match $self {
133
+ TcpOrUnixStream :: Tcp ( ref mut s) => s. $fun( $arg) ,
134
+ #[ cfg( unix) ]
135
+ TcpOrUnixStream :: Unix ( ref mut s) => s. $fun( $arg) ,
136
+ }
137
+ }
138
+ }
139
+
140
+ macro_rules! fwd_ref {
141
+ ( $self: expr, $fun: ident) => {
142
+ match $self {
143
+ TcpOrUnixStream :: Tcp ( s) => ( & mut & * s) . $fun( ) ,
144
+ #[ cfg( unix) ]
145
+ TcpOrUnixStream :: Unix ( s) => ( & mut & * s) . $fun( ) ,
146
+ }
147
+ } ;
148
+ ( $self: expr, $fun: ident, $arg: expr) => {
149
+ match $self {
150
+ TcpOrUnixStream :: Tcp ( s) => ( & mut & * s) . $fun( $arg) ,
151
+ #[ cfg( unix) ]
152
+ TcpOrUnixStream :: Unix ( s) => ( & mut & * s) . $fun( $arg) ,
153
+ }
154
+ }
155
+ }
156
+
157
+ macro_rules! fwd_move {
158
+ ( $self: expr, $fun: ident) => {
159
+ match $self {
160
+ TcpOrUnixStream :: Tcp ( s) => s. $fun( ) ,
161
+ #[ cfg( unix) ]
162
+ TcpOrUnixStream :: Unix ( s) => s. $fun( ) ,
163
+ }
164
+ } ;
165
+ ( $self: expr, $fun: ident, $arg: expr) => {
166
+ match $self {
167
+ TcpOrUnixStream :: Tcp ( s) => s. $fun( $arg) ,
168
+ #[ cfg( unix) ]
169
+ TcpOrUnixStream :: Unix ( s) => s. $fun( $arg) ,
170
+ }
171
+ }
172
+ }
173
+
174
+ impl TcpOrUnixStream {
175
+ /// [`TcpStream::connect`] or [`UnixStream::connect_addr`]
176
+ pub fn connect < T : Into < SocketAddrOrUnixSocketAddr > > ( addr : T ) -> std:: io:: Result < TcpOrUnixStream > {
177
+ match addr. into ( ) {
178
+ SocketAddrOrUnixSocketAddr :: SocketAddr ( s) => TcpStream :: connect ( s) . map ( TcpOrUnixStream :: Tcp ) ,
179
+ #[ cfg( unix) ]
180
+ SocketAddrOrUnixSocketAddr :: UnixSocketAddr ( s) => UnixStream :: connect_addr ( & s) . map ( TcpOrUnixStream :: Unix ) ,
181
+ }
182
+ }
183
+
184
+ /// [`TcpStream::connect`] with [`ToSocketAddrs`]
185
+ pub fn connect_tsa < T : ToSocketAddrs > ( addr : T ) -> std:: io:: Result < TcpOrUnixStream > {
186
+ TcpStream :: connect ( addr) . map ( TcpOrUnixStream :: Tcp )
187
+ }
188
+
189
+ /// [`TcpStream::local_addr`] or [`UnixStream::local_addr`]
190
+ pub fn local_addr ( & self ) -> std:: io:: Result < SocketAddrOrUnixSocketAddr > {
191
+ match self {
192
+ TcpOrUnixStream :: Tcp ( s) => s. local_addr ( ) . map ( SocketAddrOrUnixSocketAddr :: SocketAddr ) ,
193
+ #[ cfg( unix) ]
194
+ TcpOrUnixStream :: Unix ( s) => s. local_addr ( ) . map ( SocketAddrOrUnixSocketAddr :: UnixSocketAddr ) ,
195
+ }
196
+ }
197
+
198
+ /// [`TcpStream::peer_addr`] or [`UnixStream::peer_addr`]
199
+ pub fn peer_addr ( & self ) -> std:: io:: Result < SocketAddrOrUnixSocketAddr > {
200
+ match self {
201
+ TcpOrUnixStream :: Tcp ( s) => s. peer_addr ( ) . map ( SocketAddrOrUnixSocketAddr :: SocketAddr ) ,
202
+ #[ cfg( unix) ]
203
+ TcpOrUnixStream :: Unix ( s) => s. peer_addr ( ) . map ( SocketAddrOrUnixSocketAddr :: UnixSocketAddr ) ,
204
+ }
205
+ }
206
+
207
+ /// [`TcpStream::read_timeout`] or [`UnixStream::read_timeout`]
208
+ pub fn read_timeout ( & self ) -> std:: io:: Result < Option < std:: time:: Duration > > {
209
+ fwd_ref ! ( self , read_timeout)
210
+ }
211
+
212
+ /// [`TcpStream::set_nonblocking`] or [`UnixStream::set_nonblocking`]
213
+ pub fn set_nonblocking ( & self , nonblocking : bool ) -> std:: io:: Result < ( ) > {
214
+ fwd_ref ! ( self , set_nonblocking, nonblocking)
215
+ }
216
+
217
+ /// [`TcpStream::set_read_timeout`] or [`UnixStream::set_read_timeout`]
218
+ pub fn set_read_timeout ( & self , timeout : Option < std:: time:: Duration > ) -> std:: io:: Result < ( ) > {
219
+ fwd_ref ! ( self , set_read_timeout, timeout)
220
+ }
221
+
222
+ /// [`TcpStream::set_write_timeout`] or [`UnixStream::set_write_timeout`]
223
+ pub fn set_write_timeout ( & self , timeout : Option < std:: time:: Duration > ) -> std:: io:: Result < ( ) > {
224
+ fwd_ref ! ( self , set_write_timeout, timeout)
225
+ }
226
+
227
+ /// [`TcpStream::shutdown`] or [`UnixStream::shutdown`]
228
+ pub fn shutdown ( & self , how : std:: net:: Shutdown ) -> std:: io:: Result < ( ) > {
229
+ fwd_ref ! ( self , shutdown, how)
230
+ }
231
+
232
+ /// [`TcpStream::take_error`] or [`UnixStream::take_error`]
233
+ pub fn take_error ( & self ) -> std:: io:: Result < Option < std:: io:: Error > > {
234
+ fwd_ref ! ( self , take_error)
235
+ }
236
+
237
+ /// [`TcpStream::try_clone`] or [`UnixStream::try_clone`]
238
+ pub fn try_clone ( & self ) -> std:: io:: Result < TcpOrUnixStream > {
239
+ match self {
240
+ TcpOrUnixStream :: Tcp ( s) => s. try_clone ( ) . map ( TcpOrUnixStream :: Tcp ) ,
241
+ #[ cfg( unix) ]
242
+ TcpOrUnixStream :: Unix ( s) => s. try_clone ( ) . map ( TcpOrUnixStream :: Unix ) ,
243
+ }
244
+ }
245
+
246
+ /// [`TcpStream::write_timeout`] or [`UnixStream::write_timeout`]
247
+ pub fn write_timeout ( & self ) -> std:: io:: Result < Option < std:: time:: Duration > > {
248
+ fwd_ref ! ( self , write_timeout)
249
+ }
250
+ }
251
+
252
+ impl io:: Read for TcpOrUnixStream {
253
+ fn read ( & mut self , buf : & mut [ u8 ] ) -> std:: io:: Result < usize > {
254
+ fwd ! ( self , read, buf)
255
+ }
256
+ fn read_vectored ( & mut self , bufs : & mut [ std:: io:: IoSliceMut < ' _ > ] ) -> std:: io:: Result < usize > {
257
+ fwd ! ( self , read_vectored, bufs)
258
+ }
259
+ fn read_to_end ( & mut self , buf : & mut Vec < u8 > ) -> std:: io:: Result < usize > {
260
+ fwd ! ( self , read_to_end, buf)
261
+ }
262
+ fn read_to_string ( & mut self , buf : & mut String ) -> std:: io:: Result < usize > {
263
+ fwd ! ( self , read_to_string, buf)
264
+ }
265
+ fn read_exact ( & mut self , buf : & mut [ u8 ] ) -> std:: io:: Result < ( ) > {
266
+ fwd ! ( self , read_exact, buf)
267
+ }
268
+ }
269
+
270
+ impl < ' a > io:: Read for & ' a TcpOrUnixStream {
271
+ fn read ( & mut self , buf : & mut [ u8 ] ) -> std:: io:: Result < usize > {
272
+ fwd_ref ! ( self , read, buf)
273
+ }
274
+ fn read_vectored ( & mut self , bufs : & mut [ std:: io:: IoSliceMut < ' _ > ] ) -> std:: io:: Result < usize > {
275
+ fwd_ref ! ( self , read_vectored, bufs)
276
+ }
277
+ fn read_to_end ( & mut self , buf : & mut Vec < u8 > ) -> std:: io:: Result < usize > {
278
+ fwd_ref ! ( self , read_to_end, buf)
279
+ }
280
+ fn read_to_string ( & mut self , buf : & mut String ) -> std:: io:: Result < usize > {
281
+ fwd_ref ! ( self , read_to_string, buf)
282
+ }
283
+ fn read_exact ( & mut self , buf : & mut [ u8 ] ) -> std:: io:: Result < ( ) > {
284
+ fwd_ref ! ( self , read_exact, buf)
285
+ }
286
+ }
287
+
288
+ impl < ' a > io:: Write for & ' a TcpOrUnixStream {
289
+ fn write ( & mut self , buf : & [ u8 ] ) -> std:: io:: Result < usize > {
290
+ fwd_ref ! ( self , write, buf)
291
+ }
292
+ fn write_vectored ( & mut self , bufs : & [ std:: io:: IoSlice < ' _ > ] ) -> std:: io:: Result < usize > {
293
+ fwd_ref ! ( self , write_vectored, bufs)
294
+ }
295
+ fn flush ( & mut self ) -> std:: io:: Result < ( ) > {
296
+ fwd_ref ! ( self , flush)
297
+ }
298
+ fn write_all ( & mut self , buf : & [ u8 ] ) -> std:: io:: Result < ( ) > {
299
+ fwd_ref ! ( self , write_all, buf)
300
+ }
301
+ fn write_fmt ( & mut self , fmt : std:: fmt:: Arguments < ' _ > ) -> std:: io:: Result < ( ) > {
302
+ fwd_ref ! ( self , write_fmt, fmt)
303
+ }
304
+ }
305
+
306
+ impl io:: Write for TcpOrUnixStream {
307
+ fn write ( & mut self , buf : & [ u8 ] ) -> std:: io:: Result < usize > {
308
+ fwd ! ( self , write, buf)
309
+ }
310
+ fn write_vectored ( & mut self , bufs : & [ std:: io:: IoSlice < ' _ > ] ) -> std:: io:: Result < usize > {
311
+ fwd ! ( self , write_vectored, bufs)
312
+ }
313
+ fn flush ( & mut self ) -> std:: io:: Result < ( ) > {
314
+ fwd ! ( self , flush)
315
+ }
316
+ fn write_all ( & mut self , buf : & [ u8 ] ) -> std:: io:: Result < ( ) > {
317
+ fwd ! ( self , write_all, buf)
318
+ }
319
+ fn write_fmt ( & mut self , fmt : std:: fmt:: Arguments < ' _ > ) -> std:: io:: Result < ( ) > {
320
+ fwd ! ( self , write_fmt, fmt)
321
+ }
322
+ }
323
+
324
+ #[ cfg( unix) ]
325
+ impl std:: os:: fd:: IntoRawFd for TcpOrUnixStream {
326
+ fn into_raw_fd ( self ) -> std:: os:: fd:: RawFd {
327
+ fwd_move ! ( self , into_raw_fd)
328
+ }
329
+ }
330
+
331
+ #[ cfg( windows) ]
332
+ impl std:: os:: windows:: io:: IntoRawSocket for TcpOrUnixStream {
333
+ fn into_raw_socket ( self ) -> std:: os:: windows:: io:: RawSocket {
334
+ fwd_move ! ( self , into_raw_fd)
335
+ }
336
+ }
337
+
338
+ impl Into < TcpOrUnixStream > for TcpStream {
339
+ fn into ( self ) -> TcpOrUnixStream {
340
+ TcpOrUnixStream :: Tcp ( self )
341
+ }
342
+ }
343
+ #[ cfg( unix) ]
344
+ impl Into < TcpOrUnixStream > for UnixStream {
345
+ fn into ( self ) -> TcpOrUnixStream {
346
+ TcpOrUnixStream :: Unix ( self )
347
+ }
348
+ }
349
+
23
350
/// A description of a connection target.
24
351
#[ derive( Debug , Clone ) ]
25
352
pub enum TargetAddr {
0 commit comments