diff --git a/src/net/mod.rs b/src/net/mod.rs index 7d714ca00..41d81a2d4 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -32,8 +32,13 @@ pub use self::tcp::{TcpListener, TcpStream}; mod udp; #[cfg(not(target_os = "wasi"))] pub use self::udp::UdpSocket; - -#[cfg(unix)] +#[cfg(not(target_os = "wasi"))] mod uds; +#[cfg(not(target_os = "wasi"))] +pub use self::uds::{SocketAddr, UnixListener, UnixStream}; + +#[cfg(not(target_os = "wasi"))] +pub(crate) use self::uds::AddressKind; + #[cfg(unix)] -pub use self::uds::{SocketAddr, UnixDatagram, UnixListener, UnixStream}; +pub use self::uds::UnixDatagram; diff --git a/src/net/uds/addr.rs b/src/net/uds/addr.rs new file mode 100644 index 000000000..9fb4c9c88 --- /dev/null +++ b/src/net/uds/addr.rs @@ -0,0 +1,97 @@ +use crate::sys; +use std::path::Path; +use std::{ascii, fmt}; + +/// An address associated with a `mio` specific Unix socket. +/// +/// This is implemented instead of imported from [`net::SocketAddr`] because +/// there is no way to create a [`net::SocketAddr`]. One must be returned by +/// [`accept`], so this is returned instead. +/// +/// [`net::SocketAddr`]: std::os::unix::net::SocketAddr +/// [`accept`]: #method.accept +pub struct SocketAddr { + inner: sys::SocketAddr, +} + +struct AsciiEscaped<'a>(&'a [u8]); + +pub(crate) enum AddressKind<'a> { + Unnamed, + Pathname(&'a Path), + Abstract(&'a [u8]), +} + +impl SocketAddr { + pub(crate) fn new(inner: sys::SocketAddr) -> Self { + SocketAddr { inner } + } + + fn address(&self) -> AddressKind<'_> { + self.inner.address() + } +} + +cfg_os_poll! { + impl SocketAddr { + /// Returns `true` if the address is unnamed. + /// + /// Documentation reflected in [`SocketAddr`] + /// + /// [`SocketAddr`]: std::os::unix::net::SocketAddr + pub fn is_unnamed(&self) -> bool { + matches!(self.address(), AddressKind::Unnamed) + } + + /// Returns the contents of this address if it is a `pathname` address. + /// + /// Documentation reflected in [`SocketAddr`] + /// + /// [`SocketAddr`]: std::os::unix::net::SocketAddr + pub fn as_pathname(&self) -> Option<&Path> { + if let AddressKind::Pathname(path) = self.address() { + Some(path) + } else { + None + } + } + + /// Returns the contents of this address if it is an abstract namespace + /// without the leading null byte. + // Link to std::os::unix::net::SocketAddr pending + // https://github.com/rust-lang/rust/issues/85410. + pub fn as_abstract_namespace(&self) -> Option<&[u8]> { + if let AddressKind::Abstract(path) = self.address() { + Some(path) + } else { + None + } + } + } +} + +impl fmt::Debug for SocketAddr { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "{:?}", self.address()) + } +} + +impl fmt::Debug for AddressKind<'_> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AddressKind::Unnamed => write!(fmt, "(unnamed)"), + AddressKind::Abstract(name) => write!(fmt, "{} (abstract)", AsciiEscaped(name)), + AddressKind::Pathname(path) => write!(fmt, "{:?} (pathname)", path), + } + } +} + +impl<'a> fmt::Display for AsciiEscaped<'a> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "\"")?; + for byte in self.0.iter().cloned().flat_map(ascii::escape_default) { + write!(fmt, "{}", byte as char)?; + } + write!(fmt, "\"") + } +} diff --git a/src/net/uds/datagram.rs b/src/net/uds/datagram.rs index e963d6e2f..7bc1b7b1f 100644 --- a/src/net/uds/datagram.rs +++ b/src/net/uds/datagram.rs @@ -1,4 +1,5 @@ use crate::io_source::IoSource; +use crate::net::SocketAddr; use crate::{event, sys, Interest, Registry, Token}; use std::net::Shutdown; @@ -54,24 +55,25 @@ impl UnixDatagram { } /// Returns the address of this socket. - pub fn local_addr(&self) -> io::Result { - sys::uds::datagram::local_addr(&self.inner) + pub fn local_addr(&self) -> io::Result { + sys::uds::datagram::local_addr(&self.inner).map(SocketAddr::new) } /// Returns the address of this socket's peer. /// /// The `connect` method will connect the socket to a peer. - pub fn peer_addr(&self) -> io::Result { - sys::uds::datagram::peer_addr(&self.inner) + pub fn peer_addr(&self) -> io::Result { + sys::uds::datagram::peer_addr(&self.inner).map(SocketAddr::new) } /// Receives data from the socket. /// /// On success, returns the number of bytes read and the address from /// whence the data came. - pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, sys::SocketAddr)> { + pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { self.inner .do_io(|inner| sys::uds::datagram::recv_from(inner, buf)) + .map(|(nread, addr)| (nread, SocketAddr::new(addr))) } /// Receives data from the socket. diff --git a/src/net/uds/listener.rs b/src/net/uds/listener.rs index 37e8106d8..181806202 100644 --- a/src/net/uds/listener.rs +++ b/src/net/uds/listener.rs @@ -2,8 +2,14 @@ use crate::io_source::IoSource; use crate::net::{SocketAddr, UnixStream}; use crate::{event, sys, Interest, Registry, Token}; +#[cfg(windows)] +use crate::sys::windows::stdnet as net; +#[cfg(unix)] use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; +#[cfg(unix)] use std::os::unix::net; +#[cfg(windows)] +use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; use std::path::Path; use std::{fmt, io}; @@ -24,23 +30,34 @@ impl UnixListener { /// standard library in the Mio equivalent. The conversion assumes nothing /// about the underlying listener; it is left up to the user to set it in /// non-blocking mode. + #[cfg(unix)] + #[cfg_attr(docsrs, doc(cfg(unix)))] pub fn from_std(listener: net::UnixListener) -> UnixListener { UnixListener { inner: IoSource::new(listener), } } + #[cfg(windows)] + pub(crate) fn from_std(listener: net::UnixListener) -> UnixListener { + UnixListener { + inner: IoSource::new(listener), + } + } + /// Accepts a new incoming connection to this listener. /// /// The call is responsible for ensuring that the listening socket is in /// non-blocking mode. pub fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { - sys::uds::listener::accept(&self.inner) + self.inner + .do_io(sys::uds::listener::accept) + .map(|(stream, addr)| (stream, SocketAddr::new(addr))) } /// Returns the local socket address of this listener. - pub fn local_addr(&self) -> io::Result { - sys::uds::listener::local_addr(&self.inner) + pub fn local_addr(&self) -> io::Result { + sys::uds::listener::local_addr(&self.inner).map(SocketAddr::new) } /// Returns the value of the `SO_ERROR` option. @@ -79,18 +96,24 @@ impl fmt::Debug for UnixListener { } } +#[cfg(unix)] +#[cfg_attr(docsrs, doc(cfg(unix)))] impl IntoRawFd for UnixListener { fn into_raw_fd(self) -> RawFd { self.inner.into_inner().into_raw_fd() } } +#[cfg(unix)] +#[cfg_attr(docsrs, doc(cfg(unix)))] impl AsRawFd for UnixListener { fn as_raw_fd(&self) -> RawFd { self.inner.as_raw_fd() } } +#[cfg(unix)] +#[cfg_attr(docsrs, doc(cfg(unix)))] impl FromRawFd for UnixListener { /// Converts a `RawFd` to a `UnixListener`. /// @@ -102,3 +125,27 @@ impl FromRawFd for UnixListener { UnixListener::from_std(FromRawFd::from_raw_fd(fd)) } } + +#[cfg(windows)] +#[cfg_attr(docsrs, doc(cfg(windows)))] +impl IntoRawSocket for UnixListener { + fn into_raw_socket(self) -> RawSocket { + self.inner.into_inner().into_raw_socket() + } +} + +#[cfg(windows)] +#[cfg_attr(docsrs, doc(cfg(windows)))] +impl AsRawSocket for UnixListener { + fn as_raw_socket(&self) -> RawSocket { + self.inner.as_raw_socket() + } +} + +#[cfg(windows)] +#[cfg_attr(docsrs, doc(cfg(windows)))] +impl FromRawSocket for UnixListener { + unsafe fn from_raw_socket(sock: RawSocket) -> Self { + UnixListener::from_std(FromRawSocket::from_raw_socket(sock)) + } +} diff --git a/src/net/uds/mod.rs b/src/net/uds/mod.rs index 6b4ffdc43..2a12f965e 100644 --- a/src/net/uds/mod.rs +++ b/src/net/uds/mod.rs @@ -1,4 +1,7 @@ +#[cfg(unix)] mod datagram; +#[cfg(unix)] +#[cfg_attr(docsrs, doc(cfg(unix)))] pub use self::datagram::UnixDatagram; mod listener; @@ -7,4 +10,6 @@ pub use self::listener::UnixListener; mod stream; pub use self::stream::UnixStream; -pub use crate::sys::SocketAddr; +mod addr; +pub(crate) use self::addr::AddressKind; +pub use self::addr::SocketAddr; diff --git a/src/net/uds/stream.rs b/src/net/uds/stream.rs index b41ef9da3..0a04f035a 100644 --- a/src/net/uds/stream.rs +++ b/src/net/uds/stream.rs @@ -1,11 +1,18 @@ use crate::io_source::IoSource; +use crate::net::SocketAddr; use crate::{event, sys, Interest, Registry, Token}; +#[cfg(windows)] +use crate::sys::windows::stdnet as net; use std::fmt; use std::io::{self, IoSlice, IoSliceMut, Read, Write}; use std::net::Shutdown; +#[cfg(unix)] use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; +#[cfg(unix)] use std::os::unix::net; +#[cfg(windows)] +use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; use std::path::Path; /// A non-blocking Unix stream socket. @@ -34,15 +41,26 @@ impl UnixStream { /// The Unix stream here will not have `connect` called on it, so it /// should already be connected via some other means (be it manually, or /// the standard library). + #[cfg(unix)] + #[cfg_attr(docsrs, doc(cfg(unix)))] pub fn from_std(stream: net::UnixStream) -> UnixStream { UnixStream { inner: IoSource::new(stream), } } + #[cfg(windows)] + pub(crate) fn from_std(stream: net::UnixStream) -> UnixStream { + UnixStream { + inner: IoSource::new(stream), + } + } + /// Creates an unnamed pair of connected sockets. /// /// Returns two `UnixStream`s which are connected to each other. + #[cfg(unix)] + #[cfg_attr(docsrs, doc(cfg(unix)))] pub fn pair() -> io::Result<(UnixStream, UnixStream)> { sys::uds::stream::pair().map(|(stream1, stream2)| { (UnixStream::from_std(stream1), UnixStream::from_std(stream2)) @@ -50,13 +68,13 @@ impl UnixStream { } /// Returns the socket address of the local half of this connection. - pub fn local_addr(&self) -> io::Result { - sys::uds::stream::local_addr(&self.inner) + pub fn local_addr(&self) -> io::Result { + sys::uds::stream::local_addr(&self.inner).map(SocketAddr::new) } /// Returns the socket address of the remote half of this connection. - pub fn peer_addr(&self) -> io::Result { - sys::uds::stream::peer_addr(&self.inner) + pub fn peer_addr(&self) -> io::Result { + sys::uds::stream::peer_addr(&self.inner).map(SocketAddr::new) } /// Returns the value of the `SO_ERROR` option. @@ -86,7 +104,8 @@ impl UnixStream { /// /// # Examples /// - /// ``` + #[cfg_attr(unix, doc = "```")] + #[cfg_attr(windows, doc = "```ignore")] /// # use std::error::Error; /// # /// # fn main() -> Result<(), Box> { @@ -134,6 +153,83 @@ impl UnixStream { /// # Ok(()) /// # } /// ``` + /// + #[cfg_attr(windows, doc = "```")] + #[cfg_attr(unix, doc = "```ignore")] + /// # use std::error::Error; + /// # + /// # fn main() -> Result<(), Box> { + /// use std::io; + /// use std::os::windows::io::AsRawSocket; + /// use std::os::raw::c_int; + /// use mio::net::{UnixStream, UnixListener}; + /// use windows_sys::Win32::Networking::WinSock; + /// use std::convert::TryInto; + /// + /// let file_path = std::env::temp_dir().join("server.sock"); + /// # let _ = std::fs::remove_file(&file_path); + /// let server = UnixListener::bind(&file_path).unwrap(); + /// + /// let handle = std::thread::spawn(move || { + /// if let Ok((stream2, _)) = server.accept() { + /// // Wait until the stream is readable... + /// + /// // Read from the stream using a direct WinSock call, of course the + /// // `io::Read` implementation would be easier to use. + /// let mut buf = [0; 512]; + /// let n = stream2.try_io(|| { + /// let res = unsafe { + /// WinSock::recv( + /// stream2.as_raw_socket().try_into().unwrap(), + /// &mut buf as *mut _ as *mut _, + /// buf.len() as c_int, + /// 0 + /// ) + /// }; + /// if res != WinSock::SOCKET_ERROR { + /// Ok(res as usize) + /// } else { + /// // If EAGAIN or EWOULDBLOCK is set by WinSock::recv, the closure + /// // should return `WouldBlock` error. + /// Err(io::Error::last_os_error()) + /// } + /// }).unwrap(); + /// eprintln!("read {} bytes", n); + /// } + /// }); + /// + /// let stream1 = UnixStream::connect(&file_path).unwrap(); + /// + /// // Wait until the stream is writable... + /// + /// // Write to the stream using a direct WinSock call, of course the + /// // `io::Write` implementation would be easier to use. + /// let buf = b"hello"; + /// let n = stream1.try_io(|| { + /// let res = unsafe { + /// WinSock::send( + /// stream1.as_raw_socket().try_into().unwrap(), + /// &buf as *const _ as *const _, + /// buf.len() as c_int, + /// 0 + /// ) + /// }; + /// if res != WinSock::SOCKET_ERROR { + /// Ok(res as usize) + /// } else { + /// // If EAGAIN or EWOULDBLOCK is set by WinSock::send, the closure + /// // should return `WouldBlock` error. + /// Err(io::Error::from_raw_os_error(unsafe { + /// WinSock::WSAGetLastError() + /// })) + /// } + /// })?; + /// eprintln!("write {} bytes", n); + /// + /// # handle.join().unwrap(); + /// # Ok(()) + /// # } + /// ``` pub fn try_io(&self, f: F) -> io::Result where F: FnOnce() -> io::Result, @@ -220,18 +316,24 @@ impl fmt::Debug for UnixStream { } } +#[cfg(unix)] +#[cfg_attr(docsrs, doc(cfg(unix)))] impl IntoRawFd for UnixStream { fn into_raw_fd(self) -> RawFd { self.inner.into_inner().into_raw_fd() } } +#[cfg(unix)] +#[cfg_attr(docsrs, doc(cfg(unix)))] impl AsRawFd for UnixStream { fn as_raw_fd(&self) -> RawFd { self.inner.as_raw_fd() } } +#[cfg(unix)] +#[cfg_attr(docsrs, doc(cfg(unix)))] impl FromRawFd for UnixStream { /// Converts a `RawFd` to a `UnixStream`. /// @@ -243,3 +345,27 @@ impl FromRawFd for UnixStream { UnixStream::from_std(FromRawFd::from_raw_fd(fd)) } } + +#[cfg(windows)] +#[cfg_attr(docsrs, doc(cfg(windows)))] +impl IntoRawSocket for UnixStream { + fn into_raw_socket(self) -> RawSocket { + self.inner.into_inner().into_raw_socket() + } +} + +#[cfg(windows)] +#[cfg_attr(docsrs, doc(cfg(windows)))] +impl AsRawSocket for UnixStream { + fn as_raw_socket(&self) -> RawSocket { + self.inner.as_raw_socket() + } +} + +#[cfg(windows)] +#[cfg_attr(docsrs, doc(cfg(windows)))] +impl FromRawSocket for UnixStream { + unsafe fn from_raw_socket(sock: RawSocket) -> Self { + UnixStream::from_std(FromRawSocket::from_raw_socket(sock)) + } +} diff --git a/src/sys/mod.rs b/src/sys/mod.rs index 2a968b265..13b180c4c 100644 --- a/src/sys/mod.rs +++ b/src/sys/mod.rs @@ -59,7 +59,7 @@ cfg_os_poll! { #[cfg(windows)] cfg_os_poll! { - mod windows; + pub(crate) mod windows; pub use self::windows::*; } @@ -81,6 +81,16 @@ cfg_not_os_poll! { #[cfg(unix)] cfg_net! { - pub use self::unix::SocketAddr; + pub(crate) use self::unix::SocketAddr; + } + + #[cfg(windows)] + cfg_any_os_ext! { + pub(crate) mod windows; + } + + #[cfg(windows)] + cfg_net! { + pub(crate) use self::windows::SocketAddr; } } diff --git a/src/sys/shell/mod.rs b/src/sys/shell/mod.rs index 8a3175f76..c29bcc9f6 100644 --- a/src/sys/shell/mod.rs +++ b/src/sys/shell/mod.rs @@ -15,7 +15,6 @@ pub(crate) use self::waker::Waker; cfg_net! { pub(crate) mod tcp; pub(crate) mod udp; - #[cfg(unix)] pub(crate) mod uds; } diff --git a/src/sys/shell/uds.rs b/src/sys/shell/uds.rs index c18aca042..3aac1bd7a 100644 --- a/src/sys/shell/uds.rs +++ b/src/sys/shell/uds.rs @@ -1,5 +1,6 @@ +#[cfg(unix)] pub(crate) mod datagram { - use crate::net::SocketAddr; + use crate::sys::SocketAddr; use std::io; use std::os::unix::net; use std::path::Path; @@ -33,8 +34,12 @@ pub(crate) mod datagram { } pub(crate) mod listener { - use crate::net::{SocketAddr, UnixStream}; + use crate::net::UnixStream; + #[cfg(windows)] + use crate::sys::windows::stdnet as net; + use crate::sys::SocketAddr; use std::io; + #[cfg(unix)] use std::os::unix::net; use std::path::Path; @@ -52,8 +57,11 @@ pub(crate) mod listener { } pub(crate) mod stream { - use crate::net::SocketAddr; + #[cfg(windows)] + use crate::sys::windows::stdnet as net; + use crate::sys::SocketAddr; use std::io; + #[cfg(unix)] use std::os::unix::net; use std::path::Path; @@ -61,6 +69,7 @@ pub(crate) mod stream { os_required!() } + #[cfg(unix)] pub(crate) fn pair() -> io::Result<(net::UnixStream, net::UnixStream)> { os_required!() } diff --git a/src/sys/unix/mod.rs b/src/sys/unix/mod.rs index 231480a5d..b80bfa7d2 100644 --- a/src/sys/unix/mod.rs +++ b/src/sys/unix/mod.rs @@ -29,7 +29,7 @@ cfg_os_poll! { pub(crate) mod tcp; pub(crate) mod udp; pub(crate) mod uds; - pub use self::uds::SocketAddr; + pub(crate) use self::uds::SocketAddr; } cfg_io_source! { @@ -62,7 +62,7 @@ cfg_os_poll! { cfg_not_os_poll! { cfg_net! { mod uds; - pub use self::uds::SocketAddr; + pub(crate) use self::uds::SocketAddr; } cfg_any_os_ext! { diff --git a/src/sys/unix/uds/listener.rs b/src/sys/unix/uds/listener.rs index 79bd14ee0..0b13ab817 100644 --- a/src/sys/unix/uds/listener.rs +++ b/src/sys/unix/uds/listener.rs @@ -1,5 +1,6 @@ use super::socket_addr; -use crate::net::{SocketAddr, UnixStream}; +use super::SocketAddr; +use crate::net::UnixStream; use crate::sys::unix::net::new_socket; use std::os::unix::io::{AsRawFd, FromRawFd}; use std::os::unix::net; diff --git a/src/sys/unix/uds/mod.rs b/src/sys/unix/uds/mod.rs index 8e28a9573..d715e611e 100644 --- a/src/sys/unix/uds/mod.rs +++ b/src/sys/unix/uds/mod.rs @@ -1,5 +1,5 @@ mod socketaddr; -pub use self::socketaddr::SocketAddr; +pub(crate) use self::socketaddr::SocketAddr; /// Get the `sun_path` field offset of `sockaddr_un` for the target OS. /// diff --git a/src/sys/unix/uds/socketaddr.rs b/src/sys/unix/uds/socketaddr.rs index 4c7c41161..acdc8a662 100644 --- a/src/sys/unix/uds/socketaddr.rs +++ b/src/sys/unix/uds/socketaddr.rs @@ -1,32 +1,15 @@ use super::path_offset; +use crate::net::AddressKind; use std::ffi::OsStr; use std::os::unix::ffi::OsStrExt; -use std::path::Path; -use std::{ascii, fmt}; -/// An address associated with a `mio` specific Unix socket. -/// -/// This is implemented instead of imported from [`net::SocketAddr`] because -/// there is no way to create a [`net::SocketAddr`]. One must be returned by -/// [`accept`], so this is returned instead. -/// -/// [`net::SocketAddr`]: std::os::unix::net::SocketAddr -/// [`accept`]: #method.accept -pub struct SocketAddr { +pub(crate) struct SocketAddr { sockaddr: libc::sockaddr_un, socklen: libc::socklen_t, } -struct AsciiEscaped<'a>(&'a [u8]); - -enum AddressKind<'a> { - Unnamed, - Pathname(&'a Path), - Abstract(&'a [u8]), -} - impl SocketAddr { - fn address(&self) -> AddressKind<'_> { + pub(crate) fn address(&self) -> AddressKind<'_> { let offset = path_offset(&self.sockaddr); // Don't underflow in `len` below. if (self.socklen as usize) < offset { @@ -72,59 +55,5 @@ cfg_os_poll! { pub(crate) fn from_parts(sockaddr: libc::sockaddr_un, socklen: libc::socklen_t) -> SocketAddr { SocketAddr { sockaddr, socklen } } - - /// Returns `true` if the address is unnamed. - /// - /// Documentation reflected in [`SocketAddr`] - /// - /// [`SocketAddr`]: std::os::unix::net::SocketAddr - pub fn is_unnamed(&self) -> bool { - matches!(self.address(), AddressKind::Unnamed) - } - - /// Returns the contents of this address if it is a `pathname` address. - /// - /// Documentation reflected in [`SocketAddr`] - /// - /// [`SocketAddr`]: std::os::unix::net::SocketAddr - pub fn as_pathname(&self) -> Option<&Path> { - if let AddressKind::Pathname(path) = self.address() { - Some(path) - } else { - None - } - } - - /// Returns the contents of this address if it is an abstract namespace - /// without the leading null byte. - // Link to std::os::unix::net::SocketAddr pending - // https://github.com/rust-lang/rust/issues/85410. - pub fn as_abstract_namespace(&self) -> Option<&[u8]> { - if let AddressKind::Abstract(path) = self.address() { - Some(path) - } else { - None - } - } - } -} - -impl fmt::Debug for SocketAddr { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - match self.address() { - AddressKind::Unnamed => write!(fmt, "(unnamed)"), - AddressKind::Abstract(name) => write!(fmt, "{} (abstract)", AsciiEscaped(name)), - AddressKind::Pathname(path) => write!(fmt, "{:?} (pathname)", path), - } - } -} - -impl<'a> fmt::Display for AsciiEscaped<'a> { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(fmt, "\"")?; - for byte in self.0.iter().cloned().flat_map(ascii::escape_default) { - write!(fmt, "{}", byte as char)?; - } - write!(fmt, "\"") } } diff --git a/src/sys/windows/mod.rs b/src/sys/windows/mod.rs index f8b72fc49..07f7dda6c 100644 --- a/src/sys/windows/mod.rs +++ b/src/sys/windows/mod.rs @@ -1,151 +1,174 @@ -mod afd; +// Macro must be defined before any modules that uses them. +/// Helper macro to execute a system call that returns an `io::Result`. +#[allow(unused_macros)] +macro_rules! syscall { + ($fn: ident ( $($arg: expr),* $(,)* ), $err_test: path, $err_value: expr) => {{ + let res = unsafe { $fn($($arg, )*) }; + if $err_test(&res, &$err_value) { + Err(io::Error::last_os_error()) + } else { + Ok(res) + } + }}; +} -pub mod event; -pub use event::{Event, Events}; +/// Helper macro to execute a WinSock system call that returns an `io::Result`. +#[allow(unused_macros)] +macro_rules! wsa_syscall { + ($fn: ident ( $($arg: expr),* $(,)* ), $err_value: expr) => {{ + let res = unsafe { windows_sys::Win32::Networking::WinSock::$fn($($arg, )*) }; + if PartialEq::eq(&res, &$err_value) { + Err(std::io::Error::from_raw_os_error(unsafe { + windows_sys::Win32::Networking::WinSock::WSAGetLastError() + })) + } else { + Ok(res) + } + }}; +} -mod handle; -use handle::Handle; +cfg_net! { + pub(crate) mod stdnet; + pub(crate) mod uds; + pub(crate) use self::uds::SocketAddr; +} -mod io_status_block; -mod iocp; +cfg_os_poll! { + mod afd; -mod overlapped; -use overlapped::Overlapped; + pub mod event; + pub use event::{Event, Events}; -mod selector; -pub use selector::{Selector, SelectorInner, SockState}; + mod handle; + use handle::Handle; -// Macros must be defined before the modules that use them -cfg_net! { - /// Helper macro to execute a system call that returns an `io::Result`. - // - // Macro must be defined before any modules that uses them. - macro_rules! syscall { - ($fn: ident ( $($arg: expr),* $(,)* ), $err_test: path, $err_value: expr) => {{ - let res = unsafe { $fn($($arg, )*) }; - if $err_test(&res, &$err_value) { - Err(io::Error::last_os_error()) - } else { - Ok(res) - } - }}; - } + mod io_status_block; + mod iocp; - mod net; + mod overlapped; + use overlapped::Overlapped; - pub(crate) mod tcp; - pub(crate) mod udp; -} + mod selector; + pub use selector::{Selector, SelectorInner, SockState}; -cfg_os_ext! { - pub(crate) mod named_pipe; -} + // Macros must be defined before the modules that use them + cfg_net! { + mod net; -mod waker; -pub(crate) use waker::Waker; + pub(crate) mod tcp; + pub(crate) mod udp; + } -cfg_io_source! { - use std::io; - use std::os::windows::io::RawSocket; - use std::pin::Pin; - use std::sync::{Arc, Mutex}; + cfg_os_ext! { + pub(crate) mod named_pipe; + } - use crate::{Interest, Registry, Token}; + mod waker; + pub(crate) use waker::Waker; - struct InternalState { - selector: Arc, - token: Token, - interests: Interest, - sock_state: Pin>>, - } + cfg_io_source! { + use std::io; + use std::os::windows::io::RawSocket; + use std::pin::Pin; + use std::sync::{Arc, Mutex}; - impl Drop for InternalState { - fn drop(&mut self) { - let mut sock_state = self.sock_state.lock().unwrap(); - sock_state.mark_delete(); + use crate::{Interest, Registry, Token}; + + struct InternalState { + selector: Arc, + token: Token, + interests: Interest, + sock_state: Pin>>, } - } - pub struct IoSourceState { - // This is `None` if the socket has not yet been registered. - // - // We box the internal state to not increase the size on the stack as the - // type might move around a lot. - inner: Option>, - } + impl Drop for InternalState { + fn drop(&mut self) { + let mut sock_state = self.sock_state.lock().unwrap(); + sock_state.mark_delete(); + } + } - impl IoSourceState { - pub fn new() -> IoSourceState { - IoSourceState { inner: None } + pub struct IoSourceState { + // This is `None` if the socket has not yet been registered. + // + // We box the internal state to not increase the size on the stack as the + // type might move around a lot. + inner: Option>, } - pub fn do_io(&self, f: F, io: &T) -> io::Result - where - F: FnOnce(&T) -> io::Result, - { - let result = f(io); - if let Err(ref e) = result { - if e.kind() == io::ErrorKind::WouldBlock { - self.inner.as_ref().map_or(Ok(()), |state| { - state - .selector - .reregister(state.sock_state.clone(), state.token, state.interests) - })?; - } + impl IoSourceState { + pub fn new() -> IoSourceState { + IoSourceState { inner: None } } - result - } - pub fn register( - &mut self, - registry: &Registry, - token: Token, - interests: Interest, - socket: RawSocket, - ) -> io::Result<()> { - if self.inner.is_some() { - Err(io::ErrorKind::AlreadyExists.into()) - } else { - registry - .selector() - .register(socket, token, interests) - .map(|state| { - self.inner = Some(Box::new(state)); - }) + pub fn do_io(&self, f: F, io: &T) -> io::Result + where + F: FnOnce(&T) -> io::Result, + { + let result = f(io); + if let Err(ref e) = result { + if e.kind() == io::ErrorKind::WouldBlock { + self.inner.as_ref().map_or(Ok(()), |state| { + state + .selector + .reregister(state.sock_state.clone(), state.token, state.interests) + })?; + } + } + result } - } - pub fn reregister( - &mut self, - registry: &Registry, - token: Token, - interests: Interest, - ) -> io::Result<()> { - match self.inner.as_mut() { - Some(state) => { + pub fn register( + &mut self, + registry: &Registry, + token: Token, + interests: Interest, + socket: RawSocket, + ) -> io::Result<()> { + if self.inner.is_some() { + Err(io::ErrorKind::AlreadyExists.into()) + } else { registry .selector() - .reregister(state.sock_state.clone(), token, interests) - .map(|()| { - state.token = token; - state.interests = interests; + .register(socket, token, interests) + .map(|state| { + self.inner = Some(Box::new(state)); }) } - None => Err(io::ErrorKind::NotFound.into()), } - } - pub fn deregister(&mut self) -> io::Result<()> { - match self.inner.as_mut() { - Some(state) => { - { - let mut sock_state = state.sock_state.lock().unwrap(); - sock_state.mark_delete(); + pub fn reregister( + &mut self, + registry: &Registry, + token: Token, + interests: Interest, + ) -> io::Result<()> { + match self.inner.as_mut() { + Some(state) => { + registry + .selector() + .reregister(state.sock_state.clone(), token, interests) + .map(|()| { + state.token = token; + state.interests = interests; + }) + } + None => Err(io::ErrorKind::NotFound.into()), + } + } + + pub fn deregister(&mut self) -> io::Result<()> { + match self.inner.as_mut() { + Some(state) => { + { + let mut sock_state = state.sock_state.lock().unwrap(); + sock_state.mark_delete(); + } + self.inner = None; + Ok(()) } - self.inner = None; - Ok(()) + None => Err(io::ErrorKind::NotFound.into()), } - None => Err(io::ErrorKind::NotFound.into()), } } } diff --git a/src/sys/windows/net.rs b/src/sys/windows/net.rs index 102ba7979..d114da408 100644 --- a/src/sys/windows/net.rs +++ b/src/sys/windows/net.rs @@ -1,24 +1,12 @@ use std::io; use std::mem; use std::net::SocketAddr; -use std::sync::Once; use windows_sys::Win32::Networking::WinSock::{ ioctlsocket, socket, AF_INET, AF_INET6, FIONBIO, IN6_ADDR, IN6_ADDR_0, INVALID_SOCKET, IN_ADDR, IN_ADDR_0, SOCKADDR, SOCKADDR_IN, SOCKADDR_IN6, SOCKADDR_IN6_0, SOCKET, }; -/// Initialise the network stack for Windows. -pub(crate) fn init() { - static INIT: Once = Once::new(); - INIT.call_once(|| { - // Let standard library call `WSAStartup` for us, we can't do it - // ourselves because otherwise using any type in `std::net` would panic - // when it tries to call `WSAStartup` a second time. - drop(std::net::UdpSocket::bind("127.0.0.1:0")); - }); -} - /// Create a new non-blocking socket. pub(crate) fn new_ip_socket(addr: SocketAddr, socket_type: u16) -> io::Result { let domain = match addr { diff --git a/src/sys/windows/stdnet/addr.rs b/src/sys/windows/stdnet/addr.rs new file mode 100644 index 000000000..26b1fddde --- /dev/null +++ b/src/sys/windows/stdnet/addr.rs @@ -0,0 +1,124 @@ +use crate::net::AddressKind; +use std::os::raw::c_int; +use std::path::Path; +use std::{fmt, io, mem}; + +use windows_sys::Win32::Networking::WinSock::{sockaddr_un, SOCKADDR}; + +fn path_offset(addr: &sockaddr_un) -> usize { + // Work with an actual instance of the type since using a null pointer is UB + let base = addr as *const _ as usize; + let path = &addr.sun_path as *const _ as usize; + path - base +} + +cfg_os_poll! { + use windows_sys::Win32::Networking::WinSock::AF_UNIX; + pub(super) fn socket_addr(path: &Path) -> io::Result<(sockaddr_un, c_int)> { + let sockaddr = mem::MaybeUninit::::zeroed(); + + // This is safe to assume because a `sockaddr_un` filled with `0` + // bytes is properly initialized. + // + // `0` is a valid value for `sockaddr_un::sun_family`; it is + // `WinSock::AF_UNSPEC`. + // + // `[0; 108]` is a valid value for `sockaddr_un::sun_path`; it begins an + // abstract path. + let mut sockaddr = unsafe { sockaddr.assume_init() }; + sockaddr.sun_family = AF_UNIX; + + // Winsock2 expects 'sun_path' to be a Win32 UTF-8 file system path + let bytes = path.to_str().map(|s| s.as_bytes()).ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "path contains invalid characters", + ) + })?; + + if bytes.contains(&0) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "paths may not contain interior null bytes", + )); + } + + if bytes.len() >= sockaddr.sun_path.len() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "path must be shorter than SUN_LEN", + )); + } + + sockaddr.sun_path[..bytes.len()].copy_from_slice(bytes); + + let offset = path_offset(&sockaddr); + let mut socklen = offset + bytes.len(); + + match bytes.first() { + // The struct has already been zeroes so the null byte for pathname + // addresses is already there. + Some(&0) | None => {} + Some(_) => socklen += 1, + } + + Ok((sockaddr, socklen as c_int)) + } +} + +pub(crate) struct SocketAddr { + addr: sockaddr_un, + len: c_int, +} + +impl SocketAddr { + pub(crate) fn init(f: F) -> io::Result<(T, SocketAddr)> + where + F: FnOnce(*mut SOCKADDR, *mut c_int) -> io::Result, + { + let mut sockaddr = { + let sockaddr = mem::MaybeUninit::::zeroed(); + unsafe { sockaddr.assume_init() } + }; + + let mut len = mem::size_of::() as c_int; + let result = f(&mut sockaddr as *mut _ as *mut _, &mut len)?; + Ok(( + result, + SocketAddr { + addr: sockaddr, + len, + }, + )) + } + + pub(crate) fn new(f: F) -> io::Result + where + F: FnOnce(*mut SOCKADDR, *mut c_int) -> io::Result, + { + SocketAddr::init(f).map(|(_, addr)| addr) + } + + pub(crate) fn address(&self) -> AddressKind<'_> { + let len = self.len as usize - path_offset(&self.addr); + // sockaddr_un::sun_path on Windows is a Win32 UTF-8 file system path + + // macOS seems to return a len of 16 and a zeroed sun_path for unnamed addresses + if len == 0 { + AddressKind::Unnamed + } else if self.addr.sun_path[0] == 0 { + AddressKind::Abstract(&self.addr.sun_path[1..len]) + } else { + use std::ffi::CStr; + let pathname = + unsafe { CStr::from_bytes_with_nul_unchecked(&self.addr.sun_path[..len]) }; + AddressKind::Pathname(Path::new(pathname.to_str().unwrap())) + } + } +} + +impl fmt::Debug for SocketAddr { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "{:?}", self.address()) + } +} diff --git a/src/sys/windows/stdnet/listener.rs b/src/sys/windows/stdnet/listener.rs new file mode 100644 index 000000000..214167276 --- /dev/null +++ b/src/sys/windows/stdnet/listener.rs @@ -0,0 +1,83 @@ +use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; +use std::{fmt, io, mem}; + +use windows_sys::Win32::Networking::WinSock::SOCKET_ERROR; + +use super::{socket::Socket, SocketAddr}; + +pub(crate) struct UnixListener(Socket); + +impl UnixListener { + pub(crate) fn local_addr(&self) -> io::Result { + SocketAddr::new(|addr, len| { + wsa_syscall!( + getsockname(self.0.as_raw_socket() as _, addr, len), + SOCKET_ERROR + ) + }) + } + + pub(crate) fn take_error(&self) -> io::Result> { + self.0.take_error() + } +} + +cfg_os_poll! { + use std::path::Path; + + use super::{socket_addr, UnixStream}; + + impl UnixListener { + pub(crate) fn bind>(path: P) -> io::Result { + let inner = Socket::new()?; + let (addr, len) = socket_addr(path.as_ref())?; + + wsa_syscall!( + bind(inner.as_raw_socket() as _, &addr as *const _ as *const _, len as _), + SOCKET_ERROR + )?; + wsa_syscall!(listen(inner.as_raw_socket() as _, 1024), SOCKET_ERROR)?; + Ok(UnixListener(inner)) + } + + pub(crate) fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { + SocketAddr::init(|addr, len| self.0.accept(addr, len)) + .map(|(sock, addr)| (UnixStream(sock), addr)) + } + + pub(crate) fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.0.set_nonblocking(nonblocking) + } + } +} + +impl fmt::Debug for UnixListener { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut builder = fmt.debug_struct("UnixListener"); + builder.field("socket", &self.0.as_raw_socket()); + if let Ok(addr) = self.local_addr() { + builder.field("local", &addr); + } + builder.finish() + } +} + +impl AsRawSocket for UnixListener { + fn as_raw_socket(&self) -> RawSocket { + self.0.as_raw_socket() + } +} + +impl FromRawSocket for UnixListener { + unsafe fn from_raw_socket(sock: RawSocket) -> Self { + UnixListener(Socket::from_raw_socket(sock)) + } +} + +impl IntoRawSocket for UnixListener { + fn into_raw_socket(self) -> RawSocket { + let ret = self.0.as_raw_socket(); + mem::forget(self); + ret + } +} diff --git a/src/sys/windows/stdnet/mod.rs b/src/sys/windows/stdnet/mod.rs new file mode 100644 index 000000000..0eb5130d4 --- /dev/null +++ b/src/sys/windows/stdnet/mod.rs @@ -0,0 +1,26 @@ +//! Implementation of blocking UDS types for windows, mirrors std::os::unix::net. +mod addr; +mod listener; +mod socket; +mod stream; + +pub(crate) use self::addr::SocketAddr; +pub(crate) use self::listener::UnixListener; +pub(crate) use self::stream::UnixStream; + +cfg_os_poll! { + pub(self) use self::addr::socket_addr; + + use std::sync::Once; + + /// Initialise the network stack for Windows. + pub(crate) fn init() { + static INIT: Once = Once::new(); + INIT.call_once(|| { + // Let standard library call `WSAStartup` for us, we can't do it + // ourselves because otherwise using any type in `std::net` would panic + // when it tries to call `WSAStartup` a second time. + drop(std::net::UdpSocket::bind("127.0.0.1:0")); + }); + } +} diff --git a/src/sys/windows/stdnet/socket.rs b/src/sys/windows/stdnet/socket.rs new file mode 100644 index 000000000..9212c1e04 --- /dev/null +++ b/src/sys/windows/stdnet/socket.rs @@ -0,0 +1,186 @@ +use std::cmp::min; +use std::convert::TryInto; +use std::io::{self, IoSlice, IoSliceMut}; +use std::mem; +use std::net::Shutdown; +use std::os::raw::c_int; +use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; +use std::ptr; + +use windows_sys::Win32::Networking::WinSock::{self, closesocket, SOCKET, SOCKET_ERROR, WSABUF}; + +/// Maximum size of a buffer passed to system call like `recv` and `send`. +const MAX_BUF_LEN: usize = c_int::MAX as usize; + +#[derive(Debug)] +pub(crate) struct Socket(SOCKET); + +impl Socket { + pub fn recv(&self, buf: &mut [u8]) -> io::Result { + let ret = wsa_syscall!( + recv(self.0, buf.as_mut_ptr() as *mut _, buf.len() as c_int, 0,), + SOCKET_ERROR + )?; + Ok(ret as usize) + } + + pub fn recv_vectored(&self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { + let mut total = 0; + let mut flags: u32 = 0; + let bufs = unsafe { &mut *(bufs as *mut [IoSliceMut<'_>] as *mut [WSABUF]) }; + let res = wsa_syscall!( + WSARecv( + self.0, + bufs.as_mut_ptr().cast(), + min(bufs.len(), u32::MAX as usize) as u32, + &mut total, + &mut flags, + ptr::null_mut(), + None, + ), + SOCKET_ERROR + ); + match res { + Ok(_) => Ok(total as usize), + Err(ref err) if err.raw_os_error() == Some(WinSock::WSAESHUTDOWN as i32) => Ok(0), + Err(err) => Err(err), + } + } + + pub fn send(&self, buf: &[u8]) -> io::Result { + wsa_syscall!( + send( + self.0, + buf.as_ptr().cast(), + min(buf.len(), MAX_BUF_LEN) as c_int, + 0, + ), + SOCKET_ERROR + ) + .map(|n| n as usize) + } + + pub fn send_vectored(&self, bufs: &[IoSlice<'_>]) -> io::Result { + let mut total = 0; + wsa_syscall!( + WSASend( + self.0, + // FIXME: From the `WSASend` docs [1]: + // > For a Winsock application, once the WSASend function is called, + // > the system owns these buffers and the application may not + // > access them. + // + // So what we're doing is actually UB as `bufs` needs to be `&mut + // [IoSlice<'_>]`. + // + // See: https://github.com/rust-lang/socket2-rs/issues/129. + // + // [1] https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasend + bufs.as_ptr() as *mut _, + min(bufs.len(), u32::MAX as usize) as u32, + &mut total, + 0, + std::ptr::null_mut(), + None, + ), + SOCKET_ERROR + ) + .map(|_| total as usize) + } + + pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { + let how = match how { + Shutdown::Write => WinSock::SD_SEND, + Shutdown::Read => WinSock::SD_RECEIVE, + Shutdown::Both => WinSock::SD_BOTH, + }; + wsa_syscall!(shutdown(self.0, how.try_into().unwrap()), SOCKET_ERROR)?; + Ok(()) + } + + pub fn take_error(&self) -> io::Result> { + let mut val: mem::MaybeUninit = mem::MaybeUninit::uninit(); + let mut len = mem::size_of::() as i32; + wsa_syscall!( + getsockopt( + self.0 as _, + WinSock::SOL_SOCKET.try_into().unwrap(), + WinSock::SO_ERROR.try_into().unwrap(), + &mut val as *mut _ as *mut _, + &mut len, + ), + SOCKET_ERROR + )?; + assert_eq!(len as usize, mem::size_of::()); + let val = unsafe { val.assume_init() }; + if val == 0 { + Ok(None) + } else { + Ok(Some(io::Error::from_raw_os_error(val as i32))) + } + } +} + +cfg_os_poll! { + use windows_sys::Win32::Networking::WinSock::{INVALID_SOCKET, SOCKADDR}; + use super::init; + + impl Socket { + pub fn new() -> io::Result { + init(); + wsa_syscall!( + WSASocketW( + WinSock::AF_UNIX.into(), + WinSock::SOCK_STREAM.into(), + 0, + ptr::null_mut(), + 0, + WinSock::WSA_FLAG_OVERLAPPED | WinSock::WSA_FLAG_NO_HANDLE_INHERIT, + ), + INVALID_SOCKET + ).map(Socket) + } + + pub fn accept(&self, storage: *mut SOCKADDR, len: *mut c_int) -> io::Result { + // WinSock's accept returns a socket with the same properties as the listener. it is + // called on. In particular, the WSA_FLAG_NO_HANDLE_INHERIT will be inherited from the + // listener. + wsa_syscall!(accept(self.0, storage, len), INVALID_SOCKET).map(Socket) + } + + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + let mut nonblocking = if nonblocking { 1 } else { 0 }; + wsa_syscall!( + ioctlsocket(self.0, WinSock::FIONBIO, &mut nonblocking), + SOCKET_ERROR + )?; + Ok(()) + } + } +} + +impl Drop for Socket { + fn drop(&mut self) { + let _ = unsafe { closesocket(self.0) }; + } +} + +impl AsRawSocket for Socket { + fn as_raw_socket(&self) -> RawSocket { + self.0 as RawSocket + } +} + +impl FromRawSocket for Socket { + unsafe fn from_raw_socket(sock: RawSocket) -> Self { + Socket(sock as SOCKET) + } +} + +impl IntoRawSocket for Socket { + fn into_raw_socket(self) -> RawSocket { + let ret = self.0 as RawSocket; + mem::forget(self); + ret + } +} diff --git a/src/sys/windows/stdnet/stream.rs b/src/sys/windows/stdnet/stream.rs new file mode 100644 index 000000000..ce1da2f54 --- /dev/null +++ b/src/sys/windows/stdnet/stream.rs @@ -0,0 +1,151 @@ +use std::io::{self, IoSlice, IoSliceMut}; +use std::net::Shutdown; +use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; +use std::{fmt, mem}; + +use windows_sys::Win32::Networking::WinSock::SOCKET_ERROR; + +use super::{socket::Socket, SocketAddr}; + +pub(crate) struct UnixStream(pub(super) Socket); + +impl UnixStream { + pub(crate) fn local_addr(&self) -> io::Result { + SocketAddr::new(|addr, len| { + wsa_syscall!( + getsockname(self.0.as_raw_socket() as _, addr, len), + SOCKET_ERROR + ) + }) + } + + pub(crate) fn peer_addr(&self) -> io::Result { + SocketAddr::new(|addr, len| { + wsa_syscall!( + getpeername(self.0.as_raw_socket() as _, addr, len), + SOCKET_ERROR + ) + }) + } + + pub(crate) fn take_error(&self) -> io::Result> { + self.0.take_error() + } + + pub(crate) fn shutdown(&self, how: Shutdown) -> io::Result<()> { + self.0.shutdown(how) + } +} + +cfg_os_poll! { + use std::path::Path; + use windows_sys::Win32::Networking::WinSock::WSAEINPROGRESS; + use super::socket_addr; + + impl UnixStream { + pub(crate) fn connect>(path: P) -> io::Result { + let inner = Socket::new()?; + let (addr, len) = socket_addr(path.as_ref())?; + + match wsa_syscall!( + connect( + inner.as_raw_socket() as _, + &addr as *const _ as *const _, + len as i32, + ), + SOCKET_ERROR + ) { + Ok(_) => {} + Err(ref err) if err.raw_os_error() == Some(WSAEINPROGRESS) => {} + Err(e) => return Err(e), + } + Ok(UnixStream(inner)) + } + + pub(crate) fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.0.set_nonblocking(nonblocking) + } + } +} + +impl fmt::Debug for UnixStream { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut builder = fmt.debug_struct("UnixStream"); + builder.field("socket", &self.0.as_raw_socket()); + if let Ok(addr) = self.local_addr() { + builder.field("local", &addr); + } + if let Ok(addr) = self.peer_addr() { + builder.field("peer", &addr); + } + builder.finish() + } +} + +impl io::Read for UnixStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + io::Read::read(&mut &*self, buf) + } + + fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { + io::Read::read_vectored(&mut &*self, bufs) + } +} + +impl<'a> io::Read for &'a UnixStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.0.recv(buf) + } + + fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { + self.0.recv_vectored(bufs) + } +} + +impl io::Write for UnixStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + io::Write::write(&mut &*self, buf) + } + + fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { + io::Write::write_vectored(&mut &*self, bufs) + } + + fn flush(&mut self) -> io::Result<()> { + io::Write::flush(&mut &*self) + } +} + +impl<'a> io::Write for &'a UnixStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.send(buf) + } + + fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { + self.0.send_vectored(bufs) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +impl AsRawSocket for UnixStream { + fn as_raw_socket(&self) -> RawSocket { + self.0.as_raw_socket() + } +} + +impl FromRawSocket for UnixStream { + unsafe fn from_raw_socket(sock: RawSocket) -> Self { + UnixStream(Socket::from_raw_socket(sock)) + } +} + +impl IntoRawSocket for UnixStream { + fn into_raw_socket(self) -> RawSocket { + let ret = self.0.as_raw_socket(); + mem::forget(self); + ret + } +} diff --git a/src/sys/windows/tcp.rs b/src/sys/windows/tcp.rs index 533074be9..af0e25106 100644 --- a/src/sys/windows/tcp.rs +++ b/src/sys/windows/tcp.rs @@ -6,7 +6,8 @@ use windows_sys::Win32::Networking::WinSock::{ self, AF_INET, AF_INET6, SOCKET, SOCKET_ERROR, SOCK_STREAM, }; -use crate::sys::windows::net::{init, new_socket, socket_addr}; +use crate::sys::windows::net::{new_socket, socket_addr}; +use crate::sys::windows::stdnet::init; pub(crate) fn new_for_addr(address: SocketAddr) -> io::Result { init(); diff --git a/src/sys/windows/udp.rs b/src/sys/windows/udp.rs index 91516ccc2..213f2d329 100644 --- a/src/sys/windows/udp.rs +++ b/src/sys/windows/udp.rs @@ -4,7 +4,8 @@ use std::net::{self, SocketAddr}; use std::os::windows::io::{AsRawSocket, FromRawSocket}; use std::os::windows::raw::SOCKET as StdSocket; // windows-sys uses usize, stdlib uses u32/u64. -use crate::sys::windows::net::{init, new_ip_socket, socket_addr}; +use crate::sys::windows::net::{new_ip_socket, socket_addr}; +use crate::sys::windows::stdnet::init; use windows_sys::Win32::Networking::WinSock::{ bind as win_bind, closesocket, getsockopt, IPPROTO_IPV6, IPV6_V6ONLY, SOCKET_ERROR, SOCK_DGRAM, }; diff --git a/src/sys/windows/uds/listener.rs b/src/sys/windows/uds/listener.rs new file mode 100644 index 000000000..4ba4395e5 --- /dev/null +++ b/src/sys/windows/uds/listener.rs @@ -0,0 +1,23 @@ +use std::io; +use std::os::windows::io::AsRawSocket; +use std::path::Path; + +use super::SocketAddr; +use crate::net::UnixStream; +use crate::sys::windows::stdnet as net; + +pub(crate) fn bind(path: &Path) -> io::Result { + let listener = net::UnixListener::bind(path)?; + listener.set_nonblocking(true)?; + Ok(listener) +} + +pub(crate) fn accept(listener: &net::UnixListener) -> io::Result<(UnixStream, SocketAddr)> { + listener + .accept() + .map(|(stream, addr)| (UnixStream::from_std(stream), addr)) +} + +pub(crate) fn local_addr(listener: &net::UnixListener) -> io::Result { + super::local_addr(listener.as_raw_socket()) +} diff --git a/src/sys/windows/uds/mod.rs b/src/sys/windows/uds/mod.rs new file mode 100644 index 000000000..b99c01e42 --- /dev/null +++ b/src/sys/windows/uds/mod.rs @@ -0,0 +1,29 @@ +pub(crate) use super::stdnet::SocketAddr; + +cfg_os_poll! { + use std::convert::TryInto; + use windows_sys::Win32::Networking::WinSock::SOCKET_ERROR; + use std::os::windows::io::RawSocket; + use std::io; + + pub(crate) mod listener; + pub(crate) mod stream; + + pub(crate) fn local_addr(socket: RawSocket) -> io::Result { + SocketAddr::new(|sockaddr, socklen| { + wsa_syscall!( + getsockname(socket.try_into().unwrap(), sockaddr, socklen), + SOCKET_ERROR + ) + }) + } + + pub(crate) fn peer_addr(socket: RawSocket) -> io::Result { + SocketAddr::new(|sockaddr, socklen| { + wsa_syscall!( + getpeername(socket.try_into().unwrap(), sockaddr, socklen), + SOCKET_ERROR + ) + }) + } +} diff --git a/src/sys/windows/uds/stream.rs b/src/sys/windows/uds/stream.rs new file mode 100644 index 000000000..b02f32e8f --- /dev/null +++ b/src/sys/windows/uds/stream.rs @@ -0,0 +1,19 @@ +use super::SocketAddr; +use crate::sys::windows::stdnet as net; +use std::io; +use std::os::windows::io::AsRawSocket; +use std::path::Path; + +pub(crate) fn connect(path: &Path) -> io::Result { + let socket = net::UnixStream::connect(path)?; + socket.set_nonblocking(true)?; + Ok(socket) +} + +pub(crate) fn local_addr(socket: &net::UnixStream) -> io::Result { + super::local_addr(socket.as_raw_socket()) +} + +pub(crate) fn peer_addr(socket: &net::UnixStream) -> io::Result { + super::peer_addr(socket.as_raw_socket()) +} diff --git a/tests/unix_listener.rs b/tests/unix_listener.rs index 0aeda8153..c131497cc 100644 --- a/tests/unix_listener.rs +++ b/tests/unix_listener.rs @@ -1,8 +1,11 @@ -#![cfg(all(unix, feature = "os-poll", feature = "net"))] +#![cfg(all(feature = "os-poll", feature = "net"))] +#[cfg(windows)] +use mio::net; use mio::net::UnixListener; use mio::{Interest, Token}; use std::io::{self, Read}; +#[cfg(unix)] use std::os::unix::net; use std::path::{Path, PathBuf}; use std::sync::{Arc, Barrier}; @@ -30,6 +33,7 @@ fn unix_listener_smoke() { smoke_test(|path| UnixListener::bind(path), "unix_listener_smoke"); } +#[cfg(unix)] #[test] fn unix_listener_from_std() { smoke_test( diff --git a/tests/unix_stream.rs b/tests/unix_stream.rs index 79b7c3d4b..42eef6c53 100644 --- a/tests/unix_stream.rs +++ b/tests/unix_stream.rs @@ -1,14 +1,19 @@ -#![cfg(all(unix, feature = "os-poll", feature = "net"))] +#![cfg(all(feature = "os-poll", feature = "net"))] +#[cfg(windows)] +use mio::net; use mio::net::UnixStream; use mio::{Interest, Token}; use std::io::{self, IoSlice, IoSliceMut, Read, Write}; use std::net::Shutdown; +#[cfg(unix)] use std::os::unix::net; use std::path::Path; use std::sync::mpsc::channel; use std::sync::{Arc, Barrier}; use std::thread; +#[cfg(windows)] +use std::time::Duration; #[macro_use] mod util; @@ -24,6 +29,7 @@ const DATA1_LEN: usize = 16; const DATA2_LEN: usize = 14; const DEFAULT_BUF_SIZE: usize = 64; const TOKEN_1: Token = Token(0); +#[cfg(unix)] const TOKEN_2: Token = Token(1); #[test] @@ -77,6 +83,7 @@ fn unix_stream_connect() { handle.join().unwrap(); } +#[cfg(unix)] #[test] fn unix_stream_from_std() { smoke_test( @@ -91,6 +98,7 @@ fn unix_stream_from_std() { ) } +#[cfg(unix)] #[test] fn unix_stream_pair() { let (mut poll, mut events) = init_with_poll(); @@ -241,7 +249,13 @@ fn unix_stream_shutdown_write() { ); let err = stream.write(DATA2).unwrap_err(); + #[cfg(unix)] assert_eq!(err.kind(), io::ErrorKind::BrokenPipe); + #[cfg(windows)] + { + use windows_sys::Win32::Networking::WinSock::WSAESHUTDOWN; + assert_eq!(err.raw_os_error(), Some(WSAESHUTDOWN)); + } // Read should be ok let mut buf = [0; DEFAULT_BUF_SIZE]; @@ -304,8 +318,8 @@ fn unix_stream_shutdown_both() { let err = stream.write(DATA2).unwrap_err(); #[cfg(unix)] assert_eq!(err.kind(), io::ErrorKind::BrokenPipe); - #[cfg(window)] - assert_eq!(err.kind(), io::ErrorKind::ConnectionAbroted); + #[cfg(windows)] + assert_eq!(err.kind(), io::ErrorKind::ConnectionAborted); // Close the connection to allow the remote to shutdown drop(stream); @@ -445,6 +459,8 @@ where assert!(stream.take_error().unwrap().is_none()); + assert_would_block(stream.read(&mut buf)); + let bufs = [IoSlice::new(DATA1), IoSlice::new(DATA2)]; let wrote = stream.write_vectored(&bufs).unwrap(); assert_eq!(wrote, DATA1_LEN + DATA2_LEN); @@ -470,70 +486,107 @@ where handle.join().unwrap(); } -fn new_echo_listener( +#[cfg(windows)] +fn new_listener( connections: usize, test_name: &'static str, -) -> (thread::JoinHandle<()>, net::SocketAddr) { + handle_stream: F, +) -> (thread::JoinHandle<()>, net::SocketAddr) +where + F: Fn(net::UnixStream) + std::marker::Send + 'static, +{ let (addr_sender, addr_receiver) = channel(); let handle = thread::spawn(move || { let path = temp_file(test_name); - let listener = net::UnixListener::bind(path).unwrap(); + // We use mio's non-blocking listener here for windows, since there is no listener in std + // yet. We must be sure to poll before listener I/O. + let mut listener = net::UnixListener::bind(path).unwrap(); + let (mut poll, mut events) = init_with_poll(); + poll.registry() + .register(&mut listener, TOKEN_1, Interest::READABLE) + .unwrap(); + let local_addr = listener.local_addr().unwrap(); addr_sender.send(local_addr).unwrap(); for _ in 0..connections { - let (mut stream, _) = listener.accept().unwrap(); - - // On Linux based system it will cause a connection reset - // error when the reading side of the peer connection is - // shutdown, we don't consider it an actual here. - let (mut read, mut written) = (0, 0); - let mut buf = [0; DEFAULT_BUF_SIZE]; - loop { - let n = match stream.read(&mut buf) { - Ok(amount) => { - read += amount; - amount - } - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => continue, - Err(ref err) if err.kind() == io::ErrorKind::ConnectionReset => break, - Err(err) => panic!("{}", err), - }; - if n == 0 { - break; - } - match stream.write(&buf[..n]) { - Ok(amount) => written += amount, - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => continue, - Err(ref err) if err.kind() == io::ErrorKind::BrokenPipe => break, - Err(err) => panic!("{}", err), - }; - } - assert_eq!(read, written, "unequal reads and writes"); + poll.poll(&mut events, Some(Duration::from_millis(500))) + .unwrap(); + let (stream, _) = listener.accept().unwrap(); + assert_would_block(listener.accept()); + handle_stream(stream); } }); (handle, addr_receiver.recv().unwrap()) } -fn new_noop_listener( +#[cfg(unix)] +fn new_listener( connections: usize, - barrier: Arc, test_name: &'static str, -) -> (thread::JoinHandle<()>, net::SocketAddr) { - let (sender, receiver) = channel(); + handle_stream: F, +) -> (thread::JoinHandle<()>, net::SocketAddr) +where + F: Fn(net::UnixStream) + std::marker::Send + 'static, +{ + let (addr_sender, addr_receiver) = channel(); let handle = thread::spawn(move || { let path = temp_file(test_name); let listener = net::UnixListener::bind(path).unwrap(); let local_addr = listener.local_addr().unwrap(); - sender.send(local_addr).unwrap(); + addr_sender.send(local_addr).unwrap(); for _ in 0..connections { let (stream, _) = listener.accept().unwrap(); - barrier.wait(); - stream.shutdown(Shutdown::Write).unwrap(); - barrier.wait(); - drop(stream); + handle_stream(stream); } }); - (handle, receiver.recv().unwrap()) + (handle, addr_receiver.recv().unwrap()) +} + +fn new_echo_listener( + connections: usize, + test_name: &'static str, +) -> (thread::JoinHandle<()>, net::SocketAddr) { + new_listener(connections, test_name, |mut stream| { + // On Linux based system it will cause a connection reset + // error when the reading side of the peer connection is + // shutdown, we don't consider it an actual here. + let (mut read, mut written) = (0, 0); + let mut buf = [0; DEFAULT_BUF_SIZE]; + loop { + let n = match stream.read(&mut buf) { + Ok(amount) => { + read += amount; + amount + } + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => continue, + Err(ref err) if err.kind() == io::ErrorKind::ConnectionReset => break, + Err(err) => panic!("{}", err), + }; + if n == 0 { + break; + } + match stream.write(&buf[..n]) { + Ok(amount) => written += amount, + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => continue, + Err(ref err) if err.kind() == io::ErrorKind::BrokenPipe => break, + Err(err) => panic!("{}", err), + }; + } + assert_eq!(read, written, "unequal reads and writes"); + }) +} + +fn new_noop_listener( + connections: usize, + barrier: Arc, + test_name: &'static str, +) -> (thread::JoinHandle<()>, net::SocketAddr) { + new_listener(connections, test_name, move |stream| { + barrier.wait(); + stream.shutdown(Shutdown::Write).unwrap(); + barrier.wait(); + drop(stream); + }) }