Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 74 additions & 11 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"golang.org/x/net/bpf"
)

// A BufferAllocationFunc returns the byte slice that a Socket reads incoming
// netlink data into, or an error if no buffer can be provided.
//
// It is accepted by ExecuteBuffer and ReceiveBuffer. Execute and Receive pass
// a nil BufferAllocationFunc, leaving the choice of buffer to the Socket
// implementation.
type BufferAllocationFunc func() ([]byte, error)

// A Conn is a connection to netlink. A Conn can be used to send and
// receives messages to and from netlink.
//
Expand Down Expand Up @@ -53,6 +55,7 @@ type Socket interface {
Send(m Message) error
SendMessages(m []Message) error
Receive() ([]Message, error)
ReceiveBuffer(fn BufferAllocationFunc) ([]Message, error)
}

// Dial dials a connection to netlink, using the specified netlink family.
Expand Down Expand Up @@ -114,17 +117,20 @@ func (c *Conn) Close() error {
return newOpError("close", c.sock.Close())
}

// Execute sends a single Message to netlink using Send, receives one or more
// ExecuteBuffer sends a single Message to netlink using Send, receives one or more
// replies using Receive, and then checks the validity of the replies against
// the request using Validate.
//
// ExecuteBuffer uses the BufferAllocationFunc to execute buffer allocation
// for data coming from the underlying socket.
//
// ExecuteBuffer acquires a lock for the duration of the function call which
// blocks concurrent calls to Send, SendMessages, and Receive, in order to
// ensure consistency between netlink request/reply messages.
//
// See the documentation of Send, Receive, and Validate for details about
// each function.
func (c *Conn) Execute(m Message) ([]Message, error) {
func (c *Conn) ExecuteBuffer(m Message, fn BufferAllocationFunc) ([]Message, error) {
// Acquire the write lock and invoke the internal implementations of Send
// and Receive which require the lock already be held.
c.mu.Lock()
Expand All @@ -135,7 +141,7 @@ func (c *Conn) Execute(m Message) ([]Message, error) {
return nil, err
}

res, err := c.lockedReceive()
res, err := c.lockedReceive(fn)
if err != nil {
return nil, err
}
Expand All @@ -147,6 +153,20 @@ func (c *Conn) Execute(m Message) ([]Message, error) {
return res, nil
}

// Execute sends a single Message to netlink using Send, receives one or more
// replies using Receive, and then checks the validity of the replies against
// the request using Validate.
//
// Execute is shorthand for ExecuteBuffer with a nil BufferAllocationFunc, so
// it holds the same lock for the duration of the call, blocking concurrent
// calls to Send, SendMessages, and Receive in order to keep netlink
// request/reply messages consistent.
//
// See the documentation of Send, Receive, and Validate for details about
// each function.
func (c *Conn) Execute(m Message) ([]Message, error) {
	// A nil allocation function leaves buffer allocation to the Socket.
	var alloc BufferAllocationFunc
	return c.ExecuteBuffer(m, alloc)
}

// SendMessages sends multiple Messages to netlink. The handling of
// a Header's Length, Sequence and PID fields is the same as when
// calling Send.
Expand Down Expand Up @@ -218,24 +238,36 @@ func (c *Conn) lockedSend(m Message) (Message, error) {
return m, nil
}

// Receive receives one or more messages from netlink. Multi-part messages are
// handled transparently and returned as a single slice of Messages, with the
// ReceiveBuffer receives one or more messages from netlink. Multi-part messages
// are handled transparently and returned as a single slice of Messages, with the
// final empty "multi-part done" message removed.
//
// ReceiveBuffer uses BufferAllocationFunc to execute buffer allocation for the
// underlying socket receiving data.
//
// If any of the messages indicate a netlink error, that error will be returned.
func (c *Conn) Receive() ([]Message, error) {
func (c *Conn) ReceiveBuffer(fn BufferAllocationFunc) ([]Message, error) {
// Wait for any concurrent calls to Execute to finish before proceeding.
c.mu.RLock()
defer c.mu.RUnlock()

return c.lockedReceive()
return c.lockedReceive(fn)
}

// Receive receives one or more messages from netlink. Multi-part messages are
// handled transparently and returned as a single slice of Messages, with the
// final empty "multi-part done" message removed.
//
// Receive is shorthand for ReceiveBuffer with a nil BufferAllocationFunc.
//
// If any of the messages indicate a netlink error, that error will be returned.
func (c *Conn) Receive() ([]Message, error) {
	// A nil allocation function leaves buffer allocation to the Socket.
	var alloc BufferAllocationFunc
	return c.ReceiveBuffer(alloc)
}

// lockedReceive implements Receive, but must be called with c.mu acquired for reading.
// We rely on the kernel to deal with concurrent reads and writes to the netlink
// socket itself.
func (c *Conn) lockedReceive() ([]Message, error) {
msgs, err := c.receive()
func (c *Conn) lockedReceive(fn BufferAllocationFunc) ([]Message, error) {
msgs, err := c.receive(fn)
if err != nil {
c.debug(func(d *debugger) {
d.debugf(1, "recv: err: %v", err)
Expand Down Expand Up @@ -266,7 +298,7 @@ func (c *Conn) lockedReceive() ([]Message, error) {

// receive is the internal implementation of Conn.Receive, which can be called
// recursively to handle multi-part messages.
func (c *Conn) receive() ([]Message, error) {
func (c *Conn) receive(fn BufferAllocationFunc) ([]Message, error) {
// NB: All non-nil errors returned from this function *must* be of type
// OpError in order to maintain the appropriate contract with callers of
// this package.
Expand All @@ -276,7 +308,7 @@ func (c *Conn) receive() ([]Message, error) {

var res []Message
for {
msgs, err := c.sock.Receive()
msgs, err := c.sock.ReceiveBuffer(fn)
if err != nil {
return nil, newOpError("receive", err)
}
Expand Down Expand Up @@ -463,6 +495,37 @@ func (c *Conn) SetWriteBuffer(bytes int) error {
return newOpError("set-write-buffer", conn.SetWriteBuffer(bytes))
}

// A bufferGetter is a Socket that supports retrieving connection buffer sizes.
//
// Conn.ReadBuffer and Conn.WriteBuffer type-assert the underlying Socket to
// this interface and return a "not supported" error when the assertion fails.
type bufferGetter interface {
Socket
// ReadBuffer returns the size of the receive buffer.
ReadBuffer() (int, error)
// WriteBuffer returns the size of the transmit buffer.
WriteBuffer() (int, error)
}

// ReadBuffer retrieves the size of the operating system's receive buffer
// associated with the Conn.
//
// If the underlying Socket does not implement buffer size retrieval, ReadBuffer
// returns -1 and the error produced by notSupported. Otherwise it returns the
// Socket's reported size, with any error wrapped by newOpError under the
// "get-read-buffer" operation name.
func (c *Conn) ReadBuffer() (int, error) {
	bg, ok := c.sock.(bufferGetter)
	if !ok {
		return -1, notSupported("get-read-buffer")
	}

	size, err := bg.ReadBuffer()
	return size, newOpError("get-read-buffer", err)
}

// WriteBuffer retrieves the size of the operating system's transmit buffer
// associated with the Conn.
//
// If the underlying Socket does not implement buffer size retrieval,
// WriteBuffer returns -1 and the error produced by notSupported. Otherwise it
// returns the Socket's reported size, with any error wrapped by newOpError
// under the "get-write-buffer" operation name.
func (c *Conn) WriteBuffer() (int, error) {
	if bg, ok := c.sock.(bufferGetter); ok {
		size, err := bg.WriteBuffer()
		return size, newOpError("get-write-buffer", err)
	}

	return -1, notSupported("get-write-buffer")
}

// A syscallConner is a Socket that supports syscall.Conn.
type syscallConner interface {
Socket
Expand Down
102 changes: 81 additions & 21 deletions conn_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package netlink

import (
"errors"
"os"
"syscall"
"time"
Expand Down Expand Up @@ -116,30 +117,19 @@ func (c *conn) Send(m Message) error {
return c.s.Sendmsg(b, nil, sa, 0)
}

// Receive receives one or more Messages from netlink.
func (c *conn) Receive() ([]Message, error) {
b := make([]byte, os.Getpagesize())
for {
// Peek at the buffer to see how many bytes are available.
//
// TODO(mdlayher): deal with OOB message data if available, such as
// when PacketInfo ConnOption is true.
n, _, _, _, err := c.s.Recvmsg(b, nil, unix.MSG_PEEK)
if err != nil {
return nil, err
}

// Break when we can read all messages
if n < len(b) {
break
}

// Double in size if not enough bytes
b = make([]byte, len(b)*2)
// ReceiveBuffer receives one or more Messages from netlink.
// The buffer for the socket read is obtained by calling fn; if fn is nil,
// the default c.bufferAllocation is used instead.
func (c *conn) ReceiveBuffer(fn BufferAllocationFunc) ([]Message, error) {
if fn == nil {
fn = c.bufferAllocation
}
b, err := fn()
if err != nil {
return nil, err
}

// Read out all available messages
n, _, _, _, err := c.s.Recvmsg(b, nil, 0)
n, _, _, _, err := c.recvENoBufsAware(b, nil, 0)
if err != nil {
return nil, err
}
Expand All @@ -162,6 +152,12 @@ func (c *conn) Receive() ([]Message, error) {
return msgs, nil
}

// Receive receives one or more Messages from netlink.
// It calls ReceiveBuffer with a nil allocation function, which makes
// ReceiveBuffer fall back to c.bufferAllocation.
func (c *conn) Receive() ([]Message, error) {
	var alloc BufferAllocationFunc
	return c.ReceiveBuffer(alloc)
}

// Close closes the connection.
func (c *conn) Close() error { return c.s.Close() }

Expand Down Expand Up @@ -209,9 +205,73 @@ func (c *conn) SetReadBuffer(bytes int) error { return c.s.SetReadBuffer(bytes)
// associated with the Conn.
func (c *conn) SetWriteBuffer(bytes int) error { return c.s.SetWriteBuffer(bytes) }

// ReadBuffer reports the size of the operating system's receive buffer
// associated with the Conn, as returned by the underlying socket.
func (c *conn) ReadBuffer() (int, error) {
	return c.s.ReadBuffer()
}

// WriteBuffer reports the size of the operating system's transmit buffer
// associated with the Conn, as returned by the underlying socket.
func (c *conn) WriteBuffer() (int, error) {
	return c.s.WriteBuffer()
}

// SyscallConn returns a raw network connection.
func (c *conn) SyscallConn() (syscall.RawConn, error) { return c.s.SyscallConn() }

// bufferAllocation is the default BufferAllocationFunc used by ReceiveBuffer
// when the caller supplies none.
//
// It starts with a page-sized buffer and peeks at the socket with MSG_PEEK.
// While the peek fills the buffer completely, the buffer is replaced by one
// twice as large and the peek is repeated. The first buffer that is not
// completely filled is returned.
//
// Peeks go through recvENoBufsAware, so an ENOBUFS failure is returned to the
// caller after the socket buffers have been extended for subsequent requests.
func (c *conn) bufferAllocation() ([]byte, error) {
	for size := os.Getpagesize(); ; size *= 2 {
		buf := make([]byte, size)

		// Peek at the socket to see how many bytes are available.
		//
		// TODO(mdlayher): deal with OOB message data if available, such as
		// when PacketInfo ConnOption is true.
		n, _, _, _, err := c.recvENoBufsAware(buf, nil, unix.MSG_PEEK)
		if err != nil {
			return nil, err
		}

		// A partially filled buffer means everything pending fits.
		if n < len(buf) {
			return buf, nil
		}
	}
}

// recvENoBufsAware wraps (*socket.Conn).Recvmsg and, when the call fails with
// ENOBUFS, doubles the socket's read and write buffer sizes so that subsequent
// receives are less likely to fail the same way.
//
// The results of Recvmsg are always returned unchanged, including the ENOBUFS
// error itself; the receive is not retried here.
//
// Growing the buffers is best-effort: a failure to query or resize one buffer
// is deliberately ignored (it must not mask the original receive error) and
// does not prevent the attempt on the other buffer.
func (c *conn) recvENoBufsAware(p, oob []byte, flags int) (n, oobn, recvflags int, from unix.Sockaddr, recvErr error) {
	n, oobn, recvflags, from, recvErr = c.s.Recvmsg(p, oob, flags)

	// errors.Is reports false for a nil error, so this also covers success.
	if !errors.Is(recvErr, syscall.ENOBUFS) {
		return n, oobn, recvflags, from, recvErr
	}

	if size, err := c.s.ReadBuffer(); err == nil {
		// Best effort: the caller still receives the ENOBUFS error below.
		_ = c.SetReadBuffer(size * 2)
	}

	if size, err := c.s.WriteBuffer(); err == nil {
		// Best effort: the caller still receives the ENOBUFS error below.
		_ = c.SetWriteBuffer(size * 2)
	}

	return n, oobn, recvflags, from, recvErr
}

// linuxOption converts a ConnOption to its Linux value.
func linuxOption(o ConnOption) (int, bool) {
switch o {
Expand Down
67 changes: 67 additions & 0 deletions conn_linux_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"os/exec"
"os/user"
"sync"
"syscall"
"testing"
"time"

Expand Down Expand Up @@ -342,6 +343,72 @@ func TestIntegrationConnConcurrentSerializeExecute(t *testing.T) {
}
}

// TestIntegrationConnExtendBuffers verifies that when Receive fails with
// ENOBUFS, the connection's read and write buffers end up larger than they
// were before the failed call.
//
// The test shrinks the read buffer, queues a batch of acknowledged requests
// to provoke ENOBUFS, and is skipped if the receive unexpectedly succeeds.
func TestIntegrationConnExtendBuffers(t *testing.T) {
	c, err := netlink.Dial(unix.NETLINK_GENERIC, nil)
	if err != nil {
		t.Fatalf("failed to dial netlink: %v", err)
	}
	defer c.Close()

	// Request a tiny receive buffer so the replies below overflow it.
	if err := c.SetReadBuffer(8); err != nil {
		t.Fatalf("failed resizing buffer: %v", err)
	}

	initRBuf, err := c.ReadBuffer()
	if err != nil {
		t.Fatalf("failed to fetch initial read buffer size: %v", err)
	}
	initWBuf, err := c.WriteBuffer()
	if err != nil {
		t.Fatalf("failed to fetch initial write buffer size: %v", err)
	}

	req := netlink.Message{
		Header: netlink.Header{
			Flags: netlink.Request | netlink.Acknowledge,
		},
	}
	const recordLen = 50
	msgs := make([]netlink.Message, recordLen)
	for i := range msgs {
		msgs[i] = req
	}
	if _, err := c.SendMessages(msgs); err != nil {
		t.Fatalf("failed to send message: %v", err)
	}

	_, err = c.Receive()
	if err == nil {
		t.Skipf("receive succeeded, buffer is large enough? buffer %d, messages %d", initRBuf, recordLen)
	}

	// Any failure other than ENOBUFS is unexpected.
	var syscallErr syscall.Errno
	if !errors.As(err, &syscallErr) {
		t.Fatalf("got error on receive: %v", err)
	}

	if !errors.Is(syscallErr, syscall.ENOBUFS) {
		t.Fatalf("syscall error is not ENOBUFS: %v", syscallErr)
	}

	rbuf, err := c.ReadBuffer()
	if err != nil {
		t.Fatalf("failed to fetch read buffer size: %v", err)
	}
	wbuf, err := c.WriteBuffer()
	if err != nil {
		t.Fatalf("failed to fetch write buffer size: %v", err)
	}

	if initRBuf >= rbuf {
		t.Errorf("current rbuf %d is not bigger than initial %d", rbuf, initRBuf)
	}

	if initWBuf >= wbuf {
		t.Errorf("current wbuf %d is not bigger than initial %d", wbuf, initWBuf)
	}
}

func TestIntegrationConnSetBuffersSyscallConn(t *testing.T) {
tests := []struct {
name string
Expand Down
5 changes: 4 additions & 1 deletion conn_others.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,7 @@ func newError(_ int) error { return errUnimplemented }
// Send is a stub that always returns errUnimplemented; SendMessages and
// Receive below behave the same way.
func (c *conn) Send(_ Message) error { return errUnimplemented }
func (c *conn) SendMessages(_ []Message) error { return errUnimplemented }
func (c *conn) Receive() ([]Message, error) { return nil, errUnimplemented }
func (c *conn) Close() error { return errUnimplemented }
// ReceiveBuffer is a stub that ignores its allocation function and always
// returns errUnimplemented, like the other conn methods in this file.
func (c *conn) ReceiveBuffer(_ BufferAllocationFunc) ([]Message, error) {
	return nil, errUnimplemented
}
func (c *conn) Close() error { return errUnimplemented }
Loading