From e294f2d318ac5eb3608f7393a13ce12c4b309d82 Mon Sep 17 00:00:00 2001 From: wangzhuowei Date: Thu, 7 Dec 2023 17:29:38 +0800 Subject: [PATCH 1/5] feat: add WithReadThreshold API --- connection_errors.go | 17 ++++--- connection_impl.go | 69 ++++++++++++++++--------- connection_onevent.go | 1 + connection_reactor.go | 43 ++++++++++++++-- connection_test.go | 110 ++++++++++++++++++++++++++++++++++++---- docs/guide/guide_cn.md | 20 ++++++++ docs/guide/guide_en.md | 24 +++++++++ eventloop.go | 30 +++++------ fd_operator.go | 25 ++++++--- mux/shard_queue_test.go | 45 +++++++--------- net_dialer.go | 21 +++++--- net_dialer_test.go | 8 +-- net_polldesc_test.go | 8 +-- net_sock.go | 38 +++++++------- net_tcpsock.go | 30 ++++++----- net_unixsock.go | 28 +++++----- netpoll_options.go | 21 +++++--- netpoll_test.go | 108 +++++++++++++++++++++++++++++++++++++++ nocopy.go | 6 +-- poll.go | 16 ++++-- poll_default_bsd.go | 20 ++++++++ poll_default_linux.go | 21 ++++++++ sys_exec.go | 6 ++- 23 files changed, 551 insertions(+), 164 deletions(-) diff --git a/connection_errors.go b/connection_errors.go index 1edfa21d..8509a0c1 100644 --- a/connection_errors.go +++ b/connection_errors.go @@ -36,6 +36,8 @@ const ( ErrEOF = syscall.Errno(0x106) // Write I/O buffer timeout, calling by Connection.Writer ErrWriteTimeout = syscall.Errno(0x107) + // The wait read size large than read threshold + ErrReadOutOfThreshold = syscall.Errno(0x108) ) const ErrnoMask = 0xFF @@ -110,11 +112,12 @@ func (e *exception) Temporary() bool { // Errors defined in netpoll var errnos = [...]string{ - ErrnoMask & ErrConnClosed: "connection has been closed", - ErrnoMask & ErrReadTimeout: "connection read timeout", - ErrnoMask & ErrDialTimeout: "dial wait timeout", - ErrnoMask & ErrDialNoDeadline: "dial no deadline", - ErrnoMask & ErrUnsupported: "netpoll dose not support", - ErrnoMask & ErrEOF: "EOF", - ErrnoMask & ErrWriteTimeout: "connection write timeout", + ErrnoMask & ErrConnClosed: "connection has been closed", + ErrnoMask & ErrReadTimeout: "connection read timeout", + ErrnoMask & ErrDialTimeout: "dial wait timeout", + ErrnoMask & ErrDialNoDeadline: "dial no deadline", + ErrnoMask & ErrUnsupported: "netpoll dose not support", + ErrnoMask & ErrEOF: "EOF", + ErrnoMask & ErrWriteTimeout: "connection write timeout", + ErrnoMask & ErrReadOutOfThreshold: "connection read size is out of threshold", } diff --git a/connection_impl.go b/connection_impl.go index 77212de6..291de144 100644 --- a/connection_impl.go +++ b/connection_impl.go @@ -33,20 +33,21 @@ type connection struct { netFD onEvent locker - operator *FDOperator - readTimeout time.Duration - readTimer *time.Timer - readTrigger chan error - waitReadSize int64 - writeTimeout time.Duration - writeTimer *time.Timer - writeTrigger chan error - inputBuffer *LinkBuffer - outputBuffer *LinkBuffer - outputBarrier *barrier - supportZeroCopy bool - maxSize int // The maximum size of data between two Release(). - bookSize int // The size of data that can be read at once. + operator *FDOperator + readTimeout time.Duration + readTimer *time.Timer + readTrigger chan error + waitReadSize int64 + writeTimeout time.Duration + writeTimer *time.Timer + writeTrigger chan error + inputBuffer *LinkBuffer + outputBuffer *LinkBuffer + outputBarrier *barrier + supportZeroCopy bool + maxSize int // The maximum size of data between two Release(). + bookSize int // The size of data that can be read at once. + readBufferThreshold int64 // The readBufferThreshold limit the size of connection inputBuffer. In bytes. } var ( @@ -94,6 +95,12 @@ func (c *connection) SetWriteTimeout(timeout time.Duration) error { return nil } +// SetReadBufferThreshold implements Connection. +func (c *connection) SetReadBufferThreshold(threshold int64) error { + c.readBufferThreshold = threshold + return nil +} + // ------------------------------------------ implement zero-copy reader ------------------------------------------ // Next implements Connection. @@ -394,28 +401,44 @@ func (c *connection) triggerWrite(err error) { // waitRead will wait full n bytes. func (c *connection) waitRead(n int) (err error) { if n <= c.inputBuffer.Len() { - return nil + goto CLEANUP } + // cannot wait read with an out of threshold size + if c.readBufferThreshold > 0 && int64(n) > c.readBufferThreshold { + // just return error and dont do cleanup + return Exception(ErrReadOutOfThreshold, "wait read") + } + atomic.StoreInt64(&c.waitReadSize, int64(n)) - defer atomic.StoreInt64(&c.waitReadSize, 0) if c.readTimeout > 0 { - return c.waitReadWithTimeout(n) + err = c.waitReadWithTimeout(n) + goto CLEANUP } // wait full n for c.inputBuffer.Len() < n { switch c.status(closing) { case poller: - return Exception(ErrEOF, "wait read") + err = Exception(ErrEOF, "wait read") case user: - return Exception(ErrConnClosed, "wait read") + err = Exception(ErrConnClosed, "wait read") default: err = <-c.readTrigger - if err != nil { - return err - } + } + if err != nil { + goto CLEANUP } } - return nil +CLEANUP: + atomic.StoreInt64(&c.waitReadSize, 0) + if c.readBufferThreshold > 0 && err == nil { + // only resume read when current read size could make newBufferSize < readBufferThreshold + bufferSize := int64(c.inputBuffer.Len()) + newBufferSize := bufferSize - int64(n) + if bufferSize >= c.readBufferThreshold && newBufferSize < c.readBufferThreshold { + c.resumeRead() + } + } + return err } // waitReadWithTimeout will wait full n bytes or until timeout. diff --git a/connection_onevent.go b/connection_onevent.go index 6f055f37..f2893fa1 100644 --- a/connection_onevent.go +++ b/connection_onevent.go @@ -103,6 +103,7 @@ func (c *connection) onPrepare(opts *options) (err error) { c.SetReadTimeout(opts.readTimeout) c.SetWriteTimeout(opts.writeTimeout) c.SetIdleTimeout(opts.idleTimeout) + c.SetReadBufferThreshold(opts.readBufferThreshold) // calling prepare first and then register. if opts.onPrepare != nil { diff --git a/connection_reactor.go b/connection_reactor.go index cd5d717c..eb5620ca 100644 --- a/connection_reactor.go +++ b/connection_reactor.go @@ -104,6 +104,11 @@ func (c *connection) inputAck(n int) (err error) { c.maxSize = mallocMax } + // trigger throttle + if c.readBufferThreshold > 0 && int64(length) >= c.readBufferThreshold { + c.pauseRead() + } + var needTrigger = true if length == n { // first start onRequest needTrigger = c.onRequest() @@ -117,7 +122,7 @@ func (c *connection) inputAck(n int) (err error) { // outputs implements FDOperator. func (c *connection) outputs(vs [][]byte) (rs [][]byte, supportZeroCopy bool) { if c.outputBuffer.IsEmpty() { - c.rw2r() + c.pauseWrite() return rs, c.supportZeroCopy } rs = c.outputBuffer.GetBytes(vs) @@ -131,13 +136,41 @@ func (c *connection) outputAck(n int) (err error) { c.outputBuffer.Release() } if c.outputBuffer.IsEmpty() { - c.rw2r() + c.pauseWrite() } return nil } -// rw2r removed the monitoring of write events. -func (c *connection) rw2r() { - c.operator.Control(PollRW2R) +// pauseWrite removed the monitoring of write events. +// pauseWrite used in poller +func (c *connection) pauseWrite() { + switch c.operator.getMode() { + case opreadwrite: + c.operator.Control(PollRW2R) + case opwrite: + c.operator.Control(PollW2Hup) + } c.triggerWrite(nil) } + +// pauseRead removed the monitoring of read events. +// pauseRead used in poller +func (c *connection) pauseRead() { + switch c.operator.getMode() { + case opread: + c.operator.Control(PollR2Hup) + case opreadwrite: + c.operator.Control(PollRW2W) + } +} + +// resumeRead add the monitoring of read events. +// resumeRead used by users +func (c *connection) resumeRead() { + switch c.operator.getMode() { + case ophup: + c.operator.Control(PollHup2R) + case opwrite: + c.operator.Control(PollW2RW) + } +} diff --git a/connection_test.go b/connection_test.go index 782e85c2..9a5c551a 100644 --- a/connection_test.go +++ b/connection_test.go @@ -499,18 +499,15 @@ func TestConnDetach(t *testing.T) { func TestParallelShortConnection(t *testing.T) { ln, err := createTestListener("tcp", ":12345") MustNil(t, err) - defer ln.Close() - var received int64 el, err := NewEventLoop(func(ctx context.Context, connection Connection) error { data, err := connection.Reader().Next(connection.Reader().Len()) - if err != nil { - return err - } + Assert(t, err == nil || errors.Is(err, ErrEOF)) atomic.AddInt64(&received, int64(len(data))) - //t.Logf("conn[%s] received: %d, active: %v", connection.RemoteAddr(), len(data), connection.IsActive()) + t.Logf("conn[%s] received: %d, active: %v", connection.RemoteAddr(), len(data), connection.IsActive()) return nil }) + defer el.Shutdown(context.Background()) go func() { el.Serve(ln) }() @@ -536,10 +533,11 @@ func TestParallelShortConnection(t *testing.T) { } wg.Wait() - for atomic.LoadInt64(&received) < int64(totalSize) { - t.Logf("received: %d, except: %d", atomic.LoadInt64(&received), totalSize) + start := time.Now() + for atomic.LoadInt64(&received) < int64(totalSize) && time.Now().Sub(start) < time.Second { time.Sleep(time.Millisecond * 100) } + Equal(t, atomic.LoadInt64(&received), int64(totalSize)) } func TestConnectionServerClose(t *testing.T) { @@ -643,8 +641,6 @@ func TestConnectionServerClose(t *testing.T) { func TestConnectionDailTimeoutAndClose(t *testing.T) { ln, err := createTestListener("tcp", ":12345") MustNil(t, err) - defer ln.Close() - el, err := NewEventLoop( func(ctx context.Context, connection Connection) error { _, err = connection.Reader().Next(connection.Reader().Len()) @@ -668,10 +664,102 @@ func TestConnectionDailTimeoutAndClose(t *testing.T) { go func() { defer wg.Done() conn, err := DialConnection("tcp", ":12345", time.Nanosecond) - Assert(t, err == nil || strings.Contains(err.Error(), "i/o timeout")) + Assert(t, err == nil || strings.Contains(err.Error(), "i/o timeout"), err) _ = conn }() } wg.Wait() } } + +func TestConnectionReadOutOfThreshold(t *testing.T) { + var readThreshold = 1024 * 100 + var readSize = readThreshold + 1 + var opts = &options{} + var wg sync.WaitGroup + wg.Add(1) + opts.onRequest = func(ctx context.Context, connection Connection) error { + if connection.Reader().Len() < readThreshold { + return nil + } + defer wg.Done() + // read throttled data + _, err := connection.Reader().Next(readSize) + Assert(t, errors.Is(err, ErrReadOutOfThreshold), err) + connection.Close() + return nil + } + + WithReadBufferThreshold(int64(readThreshold)).f(opts) + r, w := GetSysFdPairs() + rconn, wconn := &connection{}, &connection{} + rconn.init(&netFD{fd: r}, opts) + wconn.init(&netFD{fd: w}, opts) + + msg := make([]byte, readThreshold) + _, err := wconn.Writer().WriteBinary(msg) + MustNil(t, err) + err = wconn.Writer().Flush() + MustNil(t, err) + wg.Wait() +} + +func TestConnectionReadThreshold(t *testing.T) { + var readThreshold int64 = 1024 * 100 + var opts = &options{} + var wg sync.WaitGroup + var throttled int32 + wg.Add(1) + opts.onRequest = func(ctx context.Context, connection Connection) error { + if int64(connection.Reader().Len()) < readThreshold { + return nil + } + defer wg.Done() + + atomic.StoreInt32(&throttled, 1) + // check if no more read data when throttled + inbuffered := connection.Reader().Len() + t.Logf("Inbuffered: %d", inbuffered) + time.Sleep(time.Millisecond * 100) + Equal(t, inbuffered, connection.Reader().Len()) + + // read non-throttled data + buf, err := connection.Reader().Next(int(readThreshold)) + Equal(t, int64(len(buf)), readThreshold) + MustNil(t, err) + err = connection.Reader().Release() + MustNil(t, err) + t.Logf("read non-throttled data") + + // continue read throttled data + buf, err = connection.Reader().Next(5) + MustNil(t, err) + t.Logf("read throttled data: [%s]", buf) + Equal(t, len(buf), 5) + MustNil(t, err) + err = connection.Reader().Release() + MustNil(t, err) + Equal(t, connection.Reader().Len(), 0) + return nil + } + + WithReadBufferThreshold(readThreshold).f(opts) + r, w := GetSysFdPairs() + rconn, wconn := &connection{}, &connection{} + rconn.init(&netFD{fd: r}, opts) + wconn.init(&netFD{fd: w}, opts) + Assert(t, rconn.readBufferThreshold == readThreshold) + + msg := make([]byte, readThreshold) + _, err := wconn.Writer().WriteBinary(msg) + MustNil(t, err) + err = wconn.Writer().Flush() + MustNil(t, err) + _, err = wconn.Writer().WriteString("hello") + MustNil(t, err) + err = wconn.Writer().Flush() + MustNil(t, err) + t.Logf("flush final msg") + + wg.Wait() +} diff --git a/docs/guide/guide_cn.md b/docs/guide/guide_cn.md index f9b9a0db..bffa9b29 100644 --- a/docs/guide/guide_cn.md +++ b/docs/guide/guide_cn.md @@ -519,6 +519,26 @@ func callback(connection netpoll.Connection) error { } ``` +## 8. 如何配置连接的读取阈值大小 ? + +Netpoll 默认不会对端发送数据的读取速度有任何限制,每当连接有数据时,Netpoll 会尽可能快地将数据存放在自己的 buffer 中。但有时候可能用户不希望数据过快发送,或者是希望控制服务内存使用量,又或者业务 OnRequest 回调处理速度很慢需要限制发送方速度,此时可以使用 `WithReadThreshold` 来控制读取的最大阈值。 + +### Client 侧使用 + +``` +dialer := netpoll.NewDialer(netpoll.WithReadThreshold(1024 * 1024 * 1024 * 1)) // 1GB +conn, _ = dialer.DialConnection(network, address, timeout) +``` + +### Server 侧使用 + +``` +eventLoop, _ := netpoll.NewEventLoop( + handle, + netpoll.WithReadThreshold(1024 * 1024 * 1024 * 1), // 1GB +) +``` + # 注意事项 ## 1. 错误设置 NumLoops diff --git a/docs/guide/guide_en.md b/docs/guide/guide_en.md index 08c522f3..1cbbcd8d 100644 --- a/docs/guide/guide_en.md +++ b/docs/guide/guide_en.md @@ -558,6 +558,30 @@ func callback(connection netpoll.Connection) error { } ``` +## 8. How to configure the read threshold of the connection? + +By default, Netpoll does not place any limit on the reading speed of data sent by the end. +Whenever there have more data on the connection, Netpoll will read the data into its own buffer as quickly as possible. + +But sometimes users may not want data to be read too quickly, or they want to control the service memory usage, or the user's OnRequest callback processing data very slowly and need to control the peer's send speed. +In this case, you can use `WithReadThreshold` to control the maximum reading threshold. + +### Client side use + +``` +dialer := netpoll.NewDialer(netpoll.WithReadThreshold(1024 * 1024 * 1024 * 1)) // 1GB +conn, _ = dialer.DialConnection(network, address, timeout) +``` + +### Server side use + +``` +eventLoop, _ := netpoll.NewEventLoop( + handle, + netpoll.WithReadThreshold(1024 * 1024 * 1024 * 1), // 1GB +) +``` + # Attention ## 1. Wrong setting of NumLoops diff --git a/eventloop.go b/eventloop.go index c9a903c0..333e2833 100644 --- a/eventloop.go +++ b/eventloop.go @@ -54,27 +54,27 @@ type OnPrepare func(connection Connection) context.Context // // An example usage in TCP Proxy scenario: // -// func onConnect(ctx context.Context, upstream netpoll.Connection) context.Context { -// downstream, _ := netpoll.DialConnection("tcp", downstreamAddr, time.Second) -// return context.WithValue(ctx, downstreamKey, downstream) -// } -// func onRequest(ctx context.Context, upstream netpoll.Connection) error { -// downstream := ctx.Value(downstreamKey).(netpoll.Connection) -// } +// func onConnect(ctx context.Context, upstream netpoll.Connection) context.Context { +// downstream, _ := netpoll.DialConnection("tcp", downstreamAddr, time.Second) +// return context.WithValue(ctx, downstreamKey, downstream) +// } +// func onRequest(ctx context.Context, upstream netpoll.Connection) error { +// downstream := ctx.Value(downstreamKey).(netpoll.Connection) +// } type OnConnect func(ctx context.Context, connection Connection) context.Context // OnRequest defines the function for handling connection. When data is sent from the connection peer, // netpoll actively reads the data in LT mode and places it in the connection's input buffer. // Generally, OnRequest starts handling the data in the following way: // -// func OnRequest(ctx context, connection Connection) error { -// input := connection.Reader().Next(n) -// handling input data... -// send, _ := connection.Writer().Malloc(l) -// copy(send, output) -// connection.Flush() -// return nil -// } +// func OnRequest(ctx context, connection Connection) error { +// input := connection.Reader().Next(n) +// handling input data... +// send, _ := connection.Writer().Malloc(l) +// copy(send, output) +// connection.Flush() +// return nil +// } // // OnRequest will run in a separate goroutine and // it is guaranteed that there is one and only one OnRequest running at the same time. diff --git a/fd_operator.go b/fd_operator.go index 1ac843a9..b94e3825 100644 --- a/fd_operator.go +++ b/fd_operator.go @@ -19,6 +19,15 @@ import ( "sync/atomic" ) +const ( + opdetach int32 = -1 + _ int32 = 0 // default op mode, means nothing + opread int32 = 1 + opwrite int32 = 2 + opreadwrite int32 = 3 + ophup int32 = 4 +) + // FDOperator is a collection of operations on file descriptors. type FDOperator struct { // FD is file descriptor, poll will bind when register. @@ -42,8 +51,7 @@ type FDOperator struct { // poll is the registered location of the file descriptor. poll Poll - // protect only detach once - detached int32 + mode int32 // private, used by operatorCache next *FDOperator @@ -52,9 +60,6 @@ type FDOperator struct { } func (op *FDOperator) Control(event PollEvent) error { - if event == PollDetach && atomic.AddInt32(&op.detached, 1) > 1 { - return nil - } return op.poll.Control(op, event) } @@ -62,6 +67,14 @@ func (op *FDOperator) Free() { op.poll.Free(op) } +func (op *FDOperator) getMode() int32 { + return atomic.LoadInt32(&op.mode) +} + +func (op *FDOperator) setMode(mode int32) { + atomic.StoreInt32(&op.mode, mode) +} + func (op *FDOperator) do() (can bool) { return atomic.CompareAndSwapInt32(&op.state, 1, 2) } @@ -98,5 +111,5 @@ func (op *FDOperator) reset() { op.Inputs, op.InputAck = nil, nil op.Outputs, op.OutputAck = nil, nil op.poll = nil - op.detached = 0 + op.mode = 0 } diff --git a/mux/shard_queue_test.go b/mux/shard_queue_test.go index 7a595d21..b0d3f5b4 100644 --- a/mux/shard_queue_test.go +++ b/mux/shard_queue_test.go @@ -19,6 +19,7 @@ package mux import ( "net" + "sync" "testing" "time" @@ -26,28 +27,25 @@ import ( ) func TestShardQueue(t *testing.T) { - var svrConn net.Conn accepted := make(chan struct{}) network, address := "tcp", ":18888" ln, err := net.Listen("tcp", ":18888") MustNil(t, err) - stop := make(chan int, 1) - defer close(stop) + count, pkgsize := 16, 11 + var wg sync.WaitGroup + wg.Add(1) go func() { - var err error - for { - select { - case <-stop: - err = ln.Close() - MustNil(t, err) - return - default: - } - svrConn, err = ln.Accept() - MustNil(t, err) - accepted <- struct{}{} - } + defer wg.Done() + svrConn, err := ln.Accept() + MustNil(t, err) + accepted <- struct{}{} + + total := count * pkgsize + recv := make([]byte, total) + rn, err := svrConn.Read(recv) + MustNil(t, err) + Equal(t, rn, total) }() conn, err := netpoll.DialConnection(network, address, time.Second) @@ -56,8 +54,7 @@ func TestShardQueue(t *testing.T) { // test queue := NewShardQueue(4, conn) - count, pkgsize := 16, 11 - for i := 0; i < int(count); i++ { + for i := 0; i < count; i++ { var getter WriterGetter = func() (buf netpoll.Writer, isNil bool) { buf = netpoll.NewLinkBuffer(pkgsize) buf.Malloc(pkgsize) @@ -68,14 +65,8 @@ func TestShardQueue(t *testing.T) { err = queue.Close() MustNil(t, err) - total := count * pkgsize - recv := make([]byte, total) - rn, err := svrConn.Read(recv) - MustNil(t, err) - Equal(t, rn, total) -} -// TODO: need mock flush -func BenchmarkShardQueue(b *testing.B) { - b.Skip() + wg.Wait() + err = ln.Close() + MustNil(t, err) } diff --git a/net_dialer.go b/net_dialer.go index 4c4e8dd2..edcafca1 100644 --- a/net_dialer.go +++ b/net_dialer.go @@ -29,13 +29,22 @@ func DialConnection(network, address string, timeout time.Duration) (connection } // NewDialer only support TCP and unix socket now. -func NewDialer() Dialer { - return &dialer{} +func NewDialer(opts ...Option) Dialer { + d := new(dialer) + if len(opts) > 0 { + d.opts = new(options) + for _, opt := range opts { + opt.f(d.opts) + } + } + return d } var defaultDialer = NewDialer() -type dialer struct{} +type dialer struct { + opts *options +} // DialTimeout implements Dialer. func (d *dialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { @@ -59,7 +68,7 @@ func (d *dialer) DialConnection(network, address string, timeout time.Duration) raddr := &UnixAddr{ UnixAddr: net.UnixAddr{Name: address, Net: network}, } - return DialUnix(network, nil, raddr) + return dialUnix(network, nil, raddr, d.opts) default: return nil, net.UnknownNetworkError(network) } @@ -95,9 +104,9 @@ func (d *dialer) dialTCP(ctx context.Context, network, address string) (connecti tcpAddr.Port = portnum tcpAddr.Zone = ipaddr.Zone if ipaddr.IP != nil && ipaddr.IP.To4() == nil { - connection, err = DialTCP(ctx, "tcp6", nil, tcpAddr) + connection, err = dialTCP(ctx, "tcp6", nil, tcpAddr, d.opts) } else { - connection, err = DialTCP(ctx, "tcp", nil, tcpAddr) + connection, err = dialTCP(ctx, "tcp", nil, tcpAddr, d.opts) } if err == nil { return connection, nil diff --git a/net_dialer_test.go b/net_dialer_test.go index 7383fd0d..deca3889 100644 --- a/net_dialer_test.go +++ b/net_dialer_test.go @@ -38,15 +38,14 @@ func TestDialerTCP(t *testing.T) { ln, err := CreateListener("tcp", ":1234") MustNil(t, err) - stop := make(chan int, 1) - defer close(stop) - + stop := make(chan int) go func() { for { select { case <-stop: err := ln.Close() MustNil(t, err) + close(stop) return default: } @@ -61,6 +60,9 @@ func TestDialerTCP(t *testing.T) { MustNil(t, err) MustTrue(t, strings.HasPrefix(conn.LocalAddr().String(), "127.0.0.1:")) Equal(t, conn.RemoteAddr().String(), "127.0.0.1:1234") + + stop <- 0 + <-stop } func TestDialerUnix(t *testing.T) { diff --git a/net_polldesc_test.go b/net_polldesc_test.go index 40804b62..6f379167 100644 --- a/net_polldesc_test.go +++ b/net_polldesc_test.go @@ -30,15 +30,14 @@ func TestRuntimePoll(t *testing.T) { ln, err := CreateListener("tcp", ":1234") MustNil(t, err) - stop := make(chan int, 1) - defer close(stop) - + stop := make(chan int) go func() { for { select { case <-stop: err := ln.Close() MustNil(t, err) + close(stop) return default: } @@ -54,4 +53,7 @@ func TestRuntimePoll(t *testing.T) { MustNil(t, err) conn.Close() } + + stop <- 0 + <-stop } diff --git a/net_sock.go b/net_sock.go index a3d318c7..c6ec98e8 100644 --- a/net_sock.go +++ b/net_sock.go @@ -55,29 +55,29 @@ func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, soty // address family, both AF_INET and AF_INET6, and a wildcard address // like the following: // -// - A listen for a wildcard communication domain, "tcp" or -// "udp", with a wildcard address: If the platform supports -// both IPv6 and IPv4-mapped IPv6 communication capabilities, -// or does not support IPv4, we use a dual stack, AF_INET6 and -// IPV6_V6ONLY=0, wildcard address listen. The dual stack -// wildcard address listen may fall back to an IPv6-only, -// AF_INET6 and IPV6_V6ONLY=1, wildcard address listen. -// Otherwise we prefer an IPv4-only, AF_INET, wildcard address -// listen. +// - A listen for a wildcard communication domain, "tcp" or +// "udp", with a wildcard address: If the platform supports +// both IPv6 and IPv4-mapped IPv6 communication capabilities, +// or does not support IPv4, we use a dual stack, AF_INET6 and +// IPV6_V6ONLY=0, wildcard address listen. The dual stack +// wildcard address listen may fall back to an IPv6-only, +// AF_INET6 and IPV6_V6ONLY=1, wildcard address listen. +// Otherwise we prefer an IPv4-only, AF_INET, wildcard address +// listen. // -// - A listen for a wildcard communication domain, "tcp" or -// "udp", with an IPv4 wildcard address: same as above. +// - A listen for a wildcard communication domain, "tcp" or +// "udp", with an IPv4 wildcard address: same as above. // -// - A listen for a wildcard communication domain, "tcp" or -// "udp", with an IPv6 wildcard address: same as above. +// - A listen for a wildcard communication domain, "tcp" or +// "udp", with an IPv6 wildcard address: same as above. // -// - A listen for an IPv4 communication domain, "tcp4" or "udp4", -// with an IPv4 wildcard address: We use an IPv4-only, AF_INET, -// wildcard address listen. +// - A listen for an IPv4 communication domain, "tcp4" or "udp4", +// with an IPv4 wildcard address: We use an IPv4-only, AF_INET, +// wildcard address listen. // -// - A listen for an IPv6 communication domain, "tcp6" or "udp6", -// with an IPv6 wildcard address: We use an IPv6-only, AF_INET6 -// and IPV6_V6ONLY=1, wildcard address listen. +// - A listen for an IPv6 communication domain, "tcp6" or "udp6", +// with an IPv6 wildcard address: We use an IPv6-only, AF_INET6 +// and IPV6_V6ONLY=1, wildcard address listen. // // Otherwise guess: If the addresses are IPv4 then returns AF_INET, // or else returns AF_INET6. It also returns a boolean value what diff --git a/net_tcpsock.go b/net_tcpsock.go index 2c90634b..87fb84eb 100644 --- a/net_tcpsock.go +++ b/net_tcpsock.go @@ -138,23 +138,16 @@ type TCPConnection struct { } // newTCPConnection wraps *TCPConnection. -func newTCPConnection(conn Conn) (connection *TCPConnection, err error) { +func newTCPConnection(conn Conn, opts *options) (connection *TCPConnection, err error) { connection = &TCPConnection{} - err = connection.init(conn, nil) + err = connection.init(conn, opts) if err != nil { return nil, err } return connection, nil } -// DialTCP acts like Dial for TCP networks. -// -// The network must be a TCP network name; see func Dial for details. -// -// If laddr is nil, a local address is automatically chosen. -// If the IP field of raddr is nil or an unspecified IP address, the -// local system is assumed. -func DialTCP(ctx context.Context, network string, laddr, raddr *TCPAddr) (*TCPConnection, error) { +func dialTCP(ctx context.Context, network string, laddr, raddr *TCPAddr, opts *options) (*TCPConnection, error) { switch network { case "tcp", "tcp4", "tcp6": default: @@ -167,14 +160,25 @@ func DialTCP(ctx context.Context, network string, laddr, raddr *TCPAddr) (*TCPCo ctx = context.Background() } sd := &sysDialer{network: network, address: raddr.String()} - c, err := sd.dialTCP(ctx, laddr, raddr) + c, err := sd.dialTCP(ctx, laddr, raddr, opts) if err != nil { return nil, &net.OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err} } return c, nil } -func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConnection, error) { +// DialTCP acts like Dial for TCP networks. +// +// The network must be a TCP network name; see func Dial for details. +// +// If laddr is nil, a local address is automatically chosen. +// If the IP field of raddr is nil or an unspecified IP address, the +// local system is assumed. +func DialTCP(ctx context.Context, network string, laddr, raddr *TCPAddr) (*TCPConnection, error) { + return dialTCP(ctx, network, laddr, raddr, nil) +} + +func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr, opts *options) (*TCPConnection, error) { conn, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial") // TCP has a rarely used mechanism called a 'simultaneous connection' in @@ -211,7 +215,7 @@ func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPCo if err != nil { return nil, err } - return newTCPConnection(conn) + return newTCPConnection(conn, opts) } func selfConnect(conn *netFD, err error) bool { diff --git a/net_unixsock.go b/net_unixsock.go index c5213a1c..9564dbc7 100644 --- a/net_unixsock.go +++ b/net_unixsock.go @@ -74,41 +74,45 @@ type UnixConnection struct { } // newUnixConnection wraps UnixConnection. -func newUnixConnection(conn Conn) (connection *UnixConnection, err error) { +func newUnixConnection(conn Conn, opts *options) (connection *UnixConnection, err error) { connection = &UnixConnection{} - err = connection.init(conn, nil) + err = connection.init(conn, opts) if err != nil { return nil, err } return connection, nil } -// DialUnix acts like Dial for Unix networks. -// -// The network must be a Unix network name; see func Dial for details. -// -// If laddr is non-nil, it is used as the local address for the -// connection. -func DialUnix(network string, laddr, raddr *UnixAddr) (*UnixConnection, error) { +func dialUnix(network string, laddr, raddr *UnixAddr, opts *options) (*UnixConnection, error) { switch network { case "unix", "unixgram", "unixpacket": default: return nil, &net.OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: net.UnknownNetworkError(network)} } sd := &sysDialer{network: network, address: raddr.String()} - c, err := sd.dialUnix(context.Background(), laddr, raddr) + c, err := sd.dialUnix(context.Background(), laddr, raddr, opts) if err != nil { return nil, &net.OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err} } return c, nil } -func (sd *sysDialer) dialUnix(ctx context.Context, laddr, raddr *UnixAddr) (*UnixConnection, error) { +// DialUnix acts like Dial for Unix networks. +// +// The network must be a Unix network name; see func Dial for details. +// +// If laddr is non-nil, it is used as the local address for the +// connection. +func DialUnix(network string, laddr, raddr *UnixAddr) (*UnixConnection, error) { + return dialUnix(network, laddr, raddr, nil) +} + +func (sd *sysDialer) dialUnix(ctx context.Context, laddr, raddr *UnixAddr, opts *options) (*UnixConnection, error) { conn, err := unixSocket(ctx, sd.network, laddr, raddr, "dial") if err != nil { return nil, err } - return newUnixConnection(conn) + return newUnixConnection(conn, opts) } func unixSocket(ctx context.Context, network string, laddr, raddr sockaddr, mode string) (conn *netFD, err error) { diff --git a/netpoll_options.go b/netpoll_options.go index 023e574c..633b8f95 100644 --- a/netpoll_options.go +++ b/netpoll_options.go @@ -98,16 +98,25 @@ func WithIdleTimeout(timeout time.Duration) Option { }} } +// WithReadBufferThreshold sets the max read buffer threshold. +// If connection already read the threshold bytes data, it will stop read more data. +func WithReadBufferThreshold(threshold int64) Option { + return Option{func(op *options) { + op.readBufferThreshold = threshold + }} +} + // Option . type Option struct { f func(*options) } type options struct { - onPrepare OnPrepare - onConnect OnConnect - onRequest OnRequest - readTimeout time.Duration - writeTimeout time.Duration - idleTimeout time.Duration + onPrepare OnPrepare + onConnect OnConnect + onRequest OnRequest + readTimeout time.Duration + writeTimeout time.Duration + idleTimeout time.Duration + readBufferThreshold int64 // bytes } diff --git a/netpoll_test.go b/netpoll_test.go index 0467e879..c77c0cca 100644 --- a/netpoll_test.go +++ b/netpoll_test.go @@ -397,6 +397,114 @@ func TestClientWriteAndClose(t *testing.T) { MustNil(t, err) } +func TestReadThresholdOption(t *testing.T) { + /* + client => server: 102400 bytes + 5 bytes + server cached: 102400 bytes, and throttled + server read: 102400 bytes, and unthrottled + server cached: 5 bytes + server read: 5 bytes + server write: 102400 bytes + 5 bytes + client cached: 102400 bytes, and throttled + client read: 102400 bytes, and unthrottled + client cached: 5 bytes + client read: 5 bytes + */ + readThreshold := 1024 * 100 + trigger := make(chan struct{}) + msg1 := make([]byte, readThreshold) + msg2 := []byte("hello") + var wg sync.WaitGroup + + // server + ln, err := CreateListener("tcp", ":12345") + MustNil(t, err) + wg.Add(3) + svr, _ := NewEventLoop(func(ctx context.Context, connection Connection) error { + if connection.Reader().Len() < readThreshold { + return nil + } + go func() { + defer wg.Done() + // server write + t.Logf("server writing msg1") + _, err := connection.Writer().WriteBinary(msg1) + MustNil(t, err) + err = connection.Writer().Flush() + MustNil(t, err) + <-trigger + time.Sleep(time.Millisecond * 100) + t.Logf("server writing msg2") + _, err = connection.Writer().WriteBinary(msg2) + MustNil(t, err) + err = connection.Writer().Flush() + MustNil(t, err) + }() + + // server read + defer wg.Done() + t.Logf("server reading msg1") + trigger <- struct{}{} // let client send msg2 + time.Sleep(time.Millisecond * 100) // ensure client send msg2 + Equal(t, connection.Reader().Len(), readThreshold) + msg, err := connection.Reader().Next(readThreshold) + MustNil(t, err) + Equal(t, len(msg), readThreshold) + t.Logf("server reading msg2") + msg, err = connection.Reader().Next(5) + MustNil(t, err) + Equal(t, len(msg), 5) + + _, err = connection.Reader().Next(1) + Assert(t, errors.Is(err, ErrEOF)) + t.Logf("server closed") + return nil + }, WithReadBufferThreshold(int64(readThreshold))) + defer svr.Shutdown(context.Background()) + go func() { + svr.Serve(ln) + }() + time.Sleep(time.Millisecond * 100) + + // client write + dialer := NewDialer(WithReadBufferThreshold(int64(readThreshold))) + cli, err := dialer.DialConnection("tcp", "127.0.0.1:12345", time.Second) + MustNil(t, err) + go func() { + defer wg.Done() + t.Logf("client writing msg1") + _, err := cli.Writer().WriteBinary(msg1) + MustNil(t, err) + err = cli.Writer().Flush() + MustNil(t, err) + <-trigger + time.Sleep(time.Millisecond * 100) + t.Logf("client writing msg2") + _, err = cli.Writer().WriteBinary(msg2) + MustNil(t, err) + err = cli.Writer().Flush() + MustNil(t, err) + }() + + // client read + trigger <- struct{}{} // let server send msg2 + time.Sleep(time.Millisecond * 100) // ensure server send msg2 + Equal(t, cli.Reader().Len(), readThreshold) + t.Logf("client reading msg1") + msg, err := cli.Reader().Next(readThreshold) + MustNil(t, err) + Equal(t, len(msg), readThreshold) + t.Logf("client reading msg2") + msg, err = cli.Reader().Next(5) + MustNil(t, err) + Equal(t, len(msg), 5) + + err = cli.Close() + MustNil(t, err) + t.Logf("client closed") + wg.Wait() +} + func createTestListener(network, address string) (Listener, error) { for { ln, err := CreateListener(network, address) diff --git a/nocopy.go b/nocopy.go index 80df5f9b..47ad2c6c 100644 --- a/nocopy.go +++ b/nocopy.go @@ -108,9 +108,9 @@ type Reader interface { // The usage of the design is a two-step operation, first apply for a section of memory, // fill it and then submit. E.g: // -// var buf, _ = Malloc(n) -// buf = append(buf[:0], ...) -// Flush() +// var buf, _ = Malloc(n) +// buf = append(buf[:0], ...) +// Flush() // // Note that it is not recommended to submit self-managed buffers to Writer. // Since the writer is processed asynchronously, if the self-managed buffer is used and recycled after submission, diff --git a/poll.go b/poll.go index c494ffd6..649bf42f 100644 --- a/poll.go +++ b/poll.go @@ -59,8 +59,18 @@ const ( // PollR2RW is used to monitor writable for FDOperator, // which is only called when the socket write buffer is full. - PollR2RW PollEvent = 0x5 - + PollR2RW PollEvent = 0x4 // PollRW2R is used to remove the writable monitor of FDOperator, generally used with PollR2RW. - PollRW2R PollEvent = 0x6 + PollRW2R PollEvent = 0x5 + + // PollRW2W is used to remove the readable monitor of FDOperator. + PollRW2W PollEvent = 0x6 + // PollW2RW is used to add the readable monitor of FDOperator, generally used with PollRW2W. + PollW2RW PollEvent = 0x7 + PollW2Hup PollEvent = 0x8 + + // PollR2Hup is used to remove the readable monitor of FDOperator. + PollR2Hup PollEvent = 0x9 + // PollHup2R is used to add the readable monitor of FDOperator, generally used with PollR2Hup. + PollHup2R PollEvent = 0x10 ) diff --git a/poll_default_bsd.go b/poll_default_bsd.go index 9c8aa8c9..a6488cb2 100644 --- a/poll_default_bsd.go +++ b/poll_default_bsd.go @@ -182,11 +182,14 @@ func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error { switch event { case PollReadable: operator.inuse() + operator.setMode(opread) evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_ADD|syscall.EV_ENABLE case PollWritable: operator.inuse() + operator.setMode(opwrite) evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_ADD|syscall.EV_ENABLE case PollDetach: + operator.setMode(ophup) if operator.OnWrite != nil { // means WaitWrite finished evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_DELETE } else { @@ -194,9 +197,26 @@ func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error { } p.delOperator(operator) case PollR2RW: + operator.setMode(opreadwrite) evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_ADD|syscall.EV_ENABLE case PollRW2R: + operator.setMode(opread) evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_DELETE + case PollRW2W: + operator.setMode(opwrite) + evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_DELETE + case PollW2RW: + operator.setMode(opreadwrite) + evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_ADD|syscall.EV_ENABLE + case PollR2Hup: + operator.setMode(ophup) + evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_DELETE + case PollW2Hup: + operator.setMode(ophup) + evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_DELETE + case PollHup2R: + operator.setMode(opread) + evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_ADD|syscall.EV_ENABLE } _, err := syscall.Kevent(p.fd, evs, nil, nil) return err diff --git a/poll_default_linux.go b/poll_default_linux.go index a0087ee0..72737370 100644 --- a/poll_default_linux.go +++ b/poll_default_linux.go @@ -244,16 +244,37 @@ func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error { switch event { case PollReadable: // server accept a new connection and wait read operator.inuse() + operator.setMode(opread) op, evt.events = syscall.EPOLL_CTL_ADD, syscall.EPOLLIN|syscall.EPOLLRDHUP|syscall.EPOLLERR case PollWritable: // client create a new connection and wait connect finished operator.inuse() + operator.setMode(opwrite) op, evt.events = syscall.EPOLL_CTL_ADD, EPOLLET|syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR case PollDetach: // deregister + if operator.getMode() == opdetach { + // protect only detach once + return nil + } + operator.setMode(opdetach) p.delOperator(operator) op, evt.events = syscall.EPOLL_CTL_DEL, syscall.EPOLLIN|syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR case PollR2RW: // connection wait read/write + operator.setMode(opreadwrite) op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLIN|syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR case PollRW2R: // connection wait read + operator.setMode(opread) + op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLIN|syscall.EPOLLRDHUP|syscall.EPOLLERR + case PollRW2W: + operator.setMode(opwrite) + op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR + case PollW2RW: + operator.setMode(opreadwrite) + op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLIN|syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR + case PollR2Hup, PollW2Hup: + operator.setMode(ophup) + op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLRDHUP|syscall.EPOLLERR + case PollHup2R: + operator.setMode(opread) op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLIN|syscall.EPOLLRDHUP|syscall.EPOLLERR } return EpollCtl(p.fd, op, operator.FD, &evt) diff --git a/sys_exec.go b/sys_exec.go index 1c8e40e4..8a7c5784 100644 --- a/sys_exec.go +++ b/sys_exec.go @@ -94,8 +94,10 @@ func readv(fd int, bs [][]byte, ivs []syscall.Iovec) (n int, err error) { } // TODO: read from sysconf(_SC_IOV_MAX)? The Linux default is -// 1024 and this seems conservative enough for now. Darwin's -// UIO_MAXIOV also seems to be 1024. +// +// 1024 and this seems conservative enough for now. Darwin's +// UIO_MAXIOV also seems to be 1024. +// // iovecs limit length to 2GB(2^31) func iovecs(bs [][]byte, ivs []syscall.Iovec) (iovLen int) { totalLen := 0 From 3c5842f526d91b74286096bb433c9fa834f3ef7a Mon Sep 17 00:00:00 2001 From: wangzhuowei Date: Fri, 22 Dec 2023 10:55:31 +0800 Subject: [PATCH 2/5] fix: add throttled check when connection closed by peer --- connection_reactor.go | 18 +++++++++--- connection_test.go | 62 +++++++++++++++++++++++++++++++++++++++ fd_operator.go | 4 ++- netpoll_test.go | 67 +++++++++++++++++++++++++++++++++++++++++++ poll_default_bsd.go | 3 +- poll_default_linux.go | 3 +- 6 files changed, 150 insertions(+), 7 deletions(-) diff --git a/connection_reactor.go b/connection_reactor.go index eb5620ca..8fd582d4 100644 --- a/connection_reactor.go +++ b/connection_reactor.go @@ -156,21 +156,31 @@ func (c *connection) pauseWrite() { // pauseRead removed the monitoring of read events. // pauseRead used in poller func (c *connection) pauseRead() { + // Note that the poller ensure that every fd should read all left data in socket buffer before detach it. + // So the operator mode should never be ophup. + var changeTo PollEvent switch c.operator.getMode() { case opread: - c.operator.Control(PollR2Hup) + changeTo = PollR2Hup case opreadwrite: - c.operator.Control(PollRW2W) + changeTo = PollRW2W + } + if changeTo > 0 && atomic.CompareAndSwapInt32(&c.operator.throttled, 0, 1) { + c.operator.Control(changeTo) } } // resumeRead add the monitoring of read events. // resumeRead used by users func (c *connection) resumeRead() { + var changeTo PollEvent switch c.operator.getMode() { case ophup: - c.operator.Control(PollHup2R) + changeTo = PollHup2R case opwrite: - c.operator.Control(PollW2RW) + changeTo = PollW2RW + } + if changeTo > 0 && atomic.CompareAndSwapInt32(&c.operator.throttled, 1, 0) { + c.operator.Control(changeTo) } } diff --git a/connection_test.go b/connection_test.go index 9a5c551a..094ce5b6 100644 --- a/connection_test.go +++ b/connection_test.go @@ -763,3 +763,65 @@ func TestConnectionReadThreshold(t *testing.T) { wg.Wait() } + +func TestConnectionReadThresholdWithClosed(t *testing.T) { + var readThreshold int64 = 1024 * 100 + var opts = &options{} + var trigger = make(chan struct{}) + opts.onRequest = func(ctx context.Context, connection Connection) error { + if int64(connection.Reader().Len()) < readThreshold { + return nil + } + Equal(t, connection.Reader().Len(), int(readThreshold)) + trigger <- struct{}{} // let client send final msg and close + <-trigger // wait for client send and close + + // read non-throttled data + buf, err := connection.Reader().Next(int(readThreshold)) + Equal(t, int64(len(buf)), readThreshold) + MustNil(t, err) + err = connection.Reader().Release() + MustNil(t, err) + t.Logf("read non-throttled data") + + // continue read throttled data + buf, err = connection.Reader().Next(5) + MustNil(t, err) + t.Logf("read throttled data: [%s]", buf) + Equal(t, len(buf), 5) + MustNil(t, err) + err = connection.Reader().Release() + MustNil(t, err) + Equal(t, connection.Reader().Len(), 0) + + _, err = connection.Reader().Next(1) + Assert(t, errors.Is(err, ErrEOF)) + trigger <- struct{}{} + return nil + } + + WithReadBufferThreshold(readThreshold).f(opts) + r, w := GetSysFdPairs() + rconn, wconn := &connection{}, &connection{} + rconn.init(&netFD{fd: r}, opts) + wconn.init(&netFD{fd: w}, opts) + Assert(t, rconn.readBufferThreshold == readThreshold) + + msg := make([]byte, readThreshold) + _, err := wconn.Writer().WriteBinary(msg) + MustNil(t, err) + err = wconn.Writer().Flush() + MustNil(t, err) + + <-trigger + _, err = wconn.Writer().WriteString("hello") + MustNil(t, err) + err = wconn.Writer().Flush() + MustNil(t, err) + t.Logf("flush final msg") + err = wconn.Close() + MustNil(t, err) + trigger <- struct{}{} + + <-trigger +} diff --git a/fd_operator.go b/fd_operator.go index b94e3825..89dae80f 100644 --- a/fd_operator.go +++ b/fd_operator.go @@ -51,7 +51,8 @@ type FDOperator struct { // poll is the registered location of the file descriptor. poll Poll - mode int32 + mode int32 + throttled int32 // private, used by operatorCache next *FDOperator @@ -112,4 +113,5 @@ func (op *FDOperator) reset() { op.Outputs, op.OutputAck = nil, nil op.poll = nil op.mode = 0 + op.throttled = 0 } diff --git a/netpoll_test.go b/netpoll_test.go index c77c0cca..85a0bef7 100644 --- a/netpoll_test.go +++ b/netpoll_test.go @@ -505,6 +505,73 @@ func TestReadThresholdOption(t *testing.T) { wg.Wait() } +func TestReadThresholdClosed(t *testing.T) { + /* + client => server: 102400 bytes + 5 bytes + client => server: close connection + server cached: 102400 bytes, and throttled + server read: 102400 bytes, and unthrottled + server cached: 5 bytes + server read: 5 bytes + */ + readThreshold := 1024 * 100 + trigger := make(chan struct{}) + msg1 := make([]byte, readThreshold) + msg2 := []byte("hello") + + // server + ln, err := CreateListener("tcp", ":12345") + MustNil(t, err) + svr, _ := NewEventLoop(func(ctx context.Context, connection Connection) error { + if connection.Reader().Len() < readThreshold { + return nil + } + // server read + t.Logf("server reading msg1") + trigger <- struct{}{} // let client send msg2 + <-trigger // ensure client send msg2 and closed + total := 0 + for { + msg, err := connection.Reader().Next(1) + total += len(msg) + if errors.Is(err, ErrEOF) { + break + } + _ = msg + } + Equal(t, total, readThreshold+5) + close(trigger) + return nil + }, WithReadBufferThreshold(int64(readThreshold))) + defer svr.Shutdown(context.Background()) + go func() { + svr.Serve(ln) + }() + time.Sleep(time.Millisecond * 100) + + // client write + dialer := NewDialer(WithReadBufferThreshold(int64(readThreshold))) + cli, err := dialer.DialConnection("tcp", "127.0.0.1:12345", time.Second) + MustNil(t, err) + t.Logf("client writing msg1") + _, err = cli.Writer().WriteBinary(msg1) + MustNil(t, err) + err = cli.Writer().Flush() + MustNil(t, err) + <-trigger + time.Sleep(time.Millisecond * 100) + t.Logf("client writing msg2") + _, err = cli.Writer().WriteBinary(msg2) + MustNil(t, err) + err = cli.Writer().Flush() + MustNil(t, err) + err = cli.Close() + MustNil(t, err) + t.Logf("client closed") + trigger <- struct{}{} + <-trigger +} + func createTestListener(network, address string) (Listener, error) { for { ln, err := CreateListener(network, address) diff --git a/poll_default_bsd.go b/poll_default_bsd.go index a6488cb2..43c5c579 100644 --- a/poll_default_bsd.go +++ b/poll_default_bsd.go @@ -115,7 +115,8 @@ func (p *defaultPoll) Wait() error { } } if triggerHup { - if triggerRead && operator.Inputs != nil { + // if peer closed with throttled state, we should ensure we read all left data to avoid data loss + if (triggerRead || atomic.LoadInt32(&operator.throttled) > 0) && operator.Inputs != nil { var leftRead int // read all left data if peer send and close if leftRead, err = readall(operator, barriers[i]); err != nil && !errors.Is(err, ErrEOF) { diff --git a/poll_default_linux.go b/poll_default_linux.go index 72737370..e4c8312b 100644 --- a/poll_default_linux.go +++ b/poll_default_linux.go @@ -168,7 +168,8 @@ func (p *defaultPoll) handler(events []epollevent) (closed bool) { } } if triggerHup { - if triggerRead && operator.Inputs != nil { + // if peer closed with throttled state, we should ensure we read all left data to avoid data loss + if (triggerRead || atomic.LoadInt32(&operator.throttled) > 0) && operator.Inputs != nil { // read all left data if peer send and close var leftRead int // read all left data if peer send and close From 8b331dd31c7c96a18e8768fa7cea8e05249089df Mon Sep 17 00:00:00 2001 From: wangzhuowei Date: Thu, 4 Jan 2024 11:21:16 +0800 Subject: [PATCH 3/5] chore: rename error --- connection_errors.go | 18 +++++++++--------- connection_impl.go | 2 +- connection_test.go | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/connection_errors.go b/connection_errors.go index 8509a0c1..2b17ada7 100644 --- a/connection_errors.go +++ b/connection_errors.go @@ -37,7 +37,7 @@ const ( // Write I/O buffer timeout, calling by Connection.Writer ErrWriteTimeout = syscall.Errno(0x107) // The wait read size large than read threshold - ErrReadOutOfThreshold = syscall.Errno(0x108) + ErrReadExceedThreshold = syscall.Errno(0x108) ) const ErrnoMask = 0xFF @@ -112,12 +112,12 @@ func (e *exception) Temporary() bool { // Errors defined in netpoll var errnos = [...]string{ - ErrnoMask & ErrConnClosed: "connection has been closed", - ErrnoMask & ErrReadTimeout: "connection read timeout", - ErrnoMask & ErrDialTimeout: "dial wait timeout", - ErrnoMask & ErrDialNoDeadline: "dial no deadline", - ErrnoMask & ErrUnsupported: "netpoll dose not support", - ErrnoMask & ErrEOF: "EOF", - ErrnoMask & ErrWriteTimeout: "connection write timeout", - ErrnoMask & ErrReadOutOfThreshold: "connection read size is out of threshold", + ErrnoMask & ErrConnClosed: "connection has been closed", + ErrnoMask & ErrReadTimeout: "connection read timeout", + ErrnoMask & ErrDialTimeout: "dial wait timeout", + ErrnoMask & ErrDialNoDeadline: "dial no deadline", + ErrnoMask & ErrUnsupported: "netpoll dose not support", + ErrnoMask & ErrEOF: "EOF", + ErrnoMask & ErrWriteTimeout: "connection write timeout", + ErrnoMask & ErrReadExceedThreshold: "connection read size exceeds the threshold", } diff --git a/connection_impl.go b/connection_impl.go index 291de144..585ec4ff 100644 --- a/connection_impl.go +++ b/connection_impl.go @@ -406,7 +406,7 @@ func (c *connection) waitRead(n int) (err error) { // cannot wait read with an out of threshold size if c.readBufferThreshold > 0 && int64(n) > c.readBufferThreshold { // just return error and dont do cleanup - return Exception(ErrReadOutOfThreshold, "wait read") + return Exception(ErrReadExceedThreshold, "wait read") } atomic.StoreInt64(&c.waitReadSize, int64(n)) diff --git a/connection_test.go b/connection_test.go index 094ce5b6..528f29d3 100644 --- a/connection_test.go +++ b/connection_test.go @@ -685,7 +685,7 @@ func TestConnectionReadOutOfThreshold(t *testing.T) { defer wg.Done() // read throttled data _, err := connection.Reader().Next(readSize) - Assert(t, errors.Is(err, ErrReadOutOfThreshold), err) + Assert(t, errors.Is(err, ErrReadExceedThreshold), err) connection.Close() return nil } From a9d69755946a76a40f35e5fea299348e5766f4ac Mon Sep 17 00:00:00 2001 From: wangzhuowei Date: Mon, 8 Jan 2024 11:23:33 +0800 Subject: [PATCH 4/5] fix: trigger write event only when flush with read throttled --- connection_impl.go | 13 ++++++++----- poll.go | 4 +++- poll_default_bsd.go | 3 +++ poll_default_linux.go | 3 +++ 4 files changed, 17 insertions(+), 6 deletions(-) diff --git a/connection_impl.go b/connection_impl.go index 585ec4ff..43ea5cb9 100644 --- a/connection_impl.go +++ b/connection_impl.go @@ -415,7 +415,7 @@ func (c *connection) waitRead(n int) (err error) { goto CLEANUP } // wait full n - for c.inputBuffer.Len() < n { + for c.inputBuffer.Len() < n && err == nil { switch c.status(closing) { case poller: err = Exception(ErrEOF, "wait read") @@ -424,9 +424,6 @@ func (c *connection) waitRead(n int) (err error) { default: err = <-c.readTrigger } - if err != nil { - goto CLEANUP - } } CLEANUP: atomic.StoreInt64(&c.waitReadSize, 0) @@ -506,7 +503,13 @@ func (c *connection) flush() error { if c.outputBuffer.IsEmpty() { return nil } - err = c.operator.Control(PollR2RW) + if c.operator.getMode() == ophup { + // triggered read throttled, so here shouldn't trigger read event again + err = c.operator.Control(PollHup2W) + } else { + err = c.operator.Control(PollR2RW) + } + c.operator.done() if err != nil { return Exception(err, "when flush") } diff --git a/poll.go b/poll.go index 649bf42f..ace07133 100644 --- a/poll.go +++ b/poll.go @@ -72,5 +72,7 @@ const ( // PollR2Hup is used to remove the readable monitor of FDOperator. PollR2Hup PollEvent = 0x9 // PollHup2R is used to add the readable monitor of FDOperator, generally used with PollR2Hup. - PollHup2R PollEvent = 0x10 + PollHup2R PollEvent = 0xA + // PollHup2W is used to add the writeable monitor of FDOperator. + PollHup2W PollEvent = 0xB ) diff --git a/poll_default_bsd.go b/poll_default_bsd.go index 43c5c579..8fda9c35 100644 --- a/poll_default_bsd.go +++ b/poll_default_bsd.go @@ -218,6 +218,9 @@ func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error { case PollHup2R: operator.setMode(opread) evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_ADD|syscall.EV_ENABLE + case PollHup2W: + operator.setMode(opwrite) + evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_ADD|syscall.EV_ENABLE } _, err := syscall.Kevent(p.fd, evs, nil, nil) return err diff --git a/poll_default_linux.go b/poll_default_linux.go index e4c8312b..f51e10c6 100644 --- a/poll_default_linux.go +++ b/poll_default_linux.go @@ -277,6 +277,9 @@ func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error { case PollHup2R: operator.setMode(opread) op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLIN|syscall.EPOLLRDHUP|syscall.EPOLLERR + case PollHup2W: + operator.setMode(opwrite) + op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR } return EpollCtl(p.fd, op, operator.FD, &evt) } From b04e0b7a1ecbf556dca7443d7342291c5c009b74 Mon Sep 17 00:00:00 2001 From: wangzhuowei Date: Wed, 10 Jan 2024 11:23:54 +0800 Subject: [PATCH 5/5] fix: race condition --- connection_impl.go | 17 +++----- connection_reactor.go | 58 ++++++++++++++-------------- connection_test.go | 16 ++------ docs/guide/guide_cn.md | 6 +-- docs/guide/guide_en.md | 6 +-- fd_operator.go | 27 +++---------- netpoll_test.go | 5 +-- poll.go | 22 ++++------- poll_default_bsd.go | 22 +---------- poll_default_bsd_test.go | 83 ++++++++++++++++++++++++++++++++++++++++ poll_default_linux.go | 23 +---------- 11 files changed, 143 insertions(+), 142 deletions(-) create mode 100644 poll_default_bsd_test.go diff --git a/connection_impl.go b/connection_impl.go index 43ea5cb9..57dab9ab 100644 --- a/connection_impl.go +++ b/connection_impl.go @@ -503,17 +503,10 @@ func (c *connection) flush() error { if c.outputBuffer.IsEmpty() { return nil } - if c.operator.getMode() == ophup { - // triggered read throttled, so here shouldn't trigger read event again - err = c.operator.Control(PollHup2W) - } else { - err = c.operator.Control(PollR2RW) - } - c.operator.done() - if err != nil { - return Exception(err, "when flush") - } + // no need to check if resume write successfully + // if resume failed, the connection will be triggered triggerWrite(err), and waitFlush will return err + c.resumeWrite() return c.waitFlush() } @@ -546,8 +539,8 @@ func (c *connection) waitFlush() (err error) { default: } // if timeout, remove write event from poller - // we cannot flush it again, since we don't if the poller is still process outputBuffer - c.operator.Control(PollRW2R) + // we cannot flush it again, since we don't know if the poller is still processing outputBuffer + c.pauseWrite() return Exception(ErrWriteTimeout, c.remoteAddr.String()) } } diff --git a/connection_reactor.go b/connection_reactor.go index 8fd582d4..3ac022d7 100644 --- a/connection_reactor.go +++ b/connection_reactor.go @@ -80,6 +80,12 @@ func (c *connection) closeBuffer() { // inputs implements FDOperator. func (c *connection) inputs(vs [][]byte) (rs [][]byte) { + // trigger throttle + if c.readBufferThreshold > 0 && int64(c.inputBuffer.Len()) >= c.readBufferThreshold { + c.pauseRead() + return + } + vs[0] = c.inputBuffer.book(c.bookSize, c.maxSize) return vs[:1] } @@ -123,6 +129,7 @@ func (c *connection) inputAck(n int) (err error) { func (c *connection) outputs(vs [][]byte) (rs [][]byte, supportZeroCopy bool) { if c.outputBuffer.IsEmpty() { c.pauseWrite() + c.triggerWrite(nil) return rs, c.supportZeroCopy } rs = c.outputBuffer.GetBytes(vs) @@ -137,50 +144,43 @@ func (c *connection) outputAck(n int) (err error) { } if c.outputBuffer.IsEmpty() { c.pauseWrite() + c.triggerWrite(nil) } return nil } +/* The race description of operator event monitoring +- Pause operation will remove old event monitor of operator +- Resume operation will add new event monitor of operator +- Only poller could use Pause to remove event monitor, and poller already hold the op.do() locker +- Only user could use Resume, and user's operation maybe compete with poller's operation +- If competition happen, because of all resume operation will monitor all events, it's safe to do that with a race condition. + * If resume first and pause latter, poller will monitor the accurate events it needs. + * If pause first and resume latter, poller will monitor the duplicate events which will be removed after next poller triggered. + And poller will ensure to remove the duplicate events. +- If there is no readBufferThreshold option, the code path will be more simple and efficient. +*/ + // pauseWrite removed the monitoring of write events. // pauseWrite used in poller func (c *connection) pauseWrite() { - switch c.operator.getMode() { - case opreadwrite: - c.operator.Control(PollRW2R) - case opwrite: - c.operator.Control(PollW2Hup) - } - c.triggerWrite(nil) + c.operator.Control(PollRW2R) +} + +// resumeWrite add the monitoring of write events. +// resumeWrite used by users +func (c *connection) resumeWrite() { + c.operator.Control(PollR2RW) } // pauseRead removed the monitoring of read events. // pauseRead used in poller func (c *connection) pauseRead() { - // Note that the poller ensure that every fd should read all left data in socket buffer before detach it. - // So the operator mode should never be ophup. - var changeTo PollEvent - switch c.operator.getMode() { - case opread: - changeTo = PollR2Hup - case opreadwrite: - changeTo = PollRW2W - } - if changeTo > 0 && atomic.CompareAndSwapInt32(&c.operator.throttled, 0, 1) { - c.operator.Control(changeTo) - } + c.operator.Control(PollRW2W) } // resumeRead add the monitoring of read events. // resumeRead used by users func (c *connection) resumeRead() { - var changeTo PollEvent - switch c.operator.getMode() { - case ophup: - changeTo = PollHup2R - case opwrite: - changeTo = PollW2RW - } - if changeTo > 0 && atomic.CompareAndSwapInt32(&c.operator.throttled, 1, 0) { - c.operator.Control(changeTo) - } + c.operator.Control(PollW2RW) } diff --git a/connection_test.go b/connection_test.go index 528f29d3..5b2b5c8b 100644 --- a/connection_test.go +++ b/connection_test.go @@ -784,18 +784,10 @@ func TestConnectionReadThresholdWithClosed(t *testing.T) { MustNil(t, err) t.Logf("read non-throttled data") - // continue read throttled data - buf, err = connection.Reader().Next(5) - MustNil(t, err) - t.Logf("read throttled data: [%s]", buf) - Equal(t, len(buf), 5) - MustNil(t, err) - err = connection.Reader().Release() - MustNil(t, err) - Equal(t, connection.Reader().Len(), 0) - - _, err = connection.Reader().Next(1) - Assert(t, errors.Is(err, ErrEOF)) + // continue read throttled data with EOF + for !errors.Is(err, ErrEOF) { + buf, err = connection.Reader().Next(1) + } trigger <- struct{}{} return nil } diff --git a/docs/guide/guide_cn.md b/docs/guide/guide_cn.md index bffa9b29..d18d3ad7 100644 --- a/docs/guide/guide_cn.md +++ b/docs/guide/guide_cn.md @@ -521,12 +521,12 @@ func callback(connection netpoll.Connection) error { ## 8. 如何配置连接的读取阈值大小 ? -Netpoll 默认不会对端发送数据的读取速度有任何限制,每当连接有数据时,Netpoll 会尽可能快地将数据存放在自己的 buffer 中。但有时候可能用户不希望数据过快发送,或者是希望控制服务内存使用量,又或者业务 OnRequest 回调处理速度很慢需要限制发送方速度,此时可以使用 `WithReadThreshold` 来控制读取的最大阈值。 +Netpoll 默认不会对端发送数据的读取速度有任何限制,每当连接有数据时,Netpoll 会尽可能快地将数据存放在自己的 buffer 中。但有时候可能用户不希望数据过快发送,或者是希望控制服务内存使用量,又或者业务 OnRequest 回调处理速度很慢需要限制发送方速度,此时可以使用 `WithReadBufferThreshold` 来控制读取的最大阈值。 ### Client 侧使用 ``` -dialer := netpoll.NewDialer(netpoll.WithReadThreshold(1024 * 1024 * 1024 * 1)) // 1GB +dialer := netpoll.NewDialer(netpoll.WithReadBufferThreshold(1024 * 1024 * 1024 * 1)) // 1GB conn, _ = dialer.DialConnection(network, address, timeout) ``` @@ -535,7 +535,7 @@ conn, _ = dialer.DialConnection(network, address, timeout) ``` eventLoop, _ := netpoll.NewEventLoop( handle, - netpoll.WithReadThreshold(1024 * 1024 * 1024 * 1), // 1GB + netpoll.WithReadBufferThreshold(1024 * 1024 * 1024 * 1), // 1GB ) ``` diff --git a/docs/guide/guide_en.md b/docs/guide/guide_en.md index 1cbbcd8d..260fe92d 100644 --- a/docs/guide/guide_en.md +++ b/docs/guide/guide_en.md @@ -564,12 +564,12 @@ By default, Netpoll does not place any limit on the reading speed of data sent b Whenever there have more data on the connection, Netpoll will read the data into its own buffer as quickly as possible. But sometimes users may not want data to be read too quickly, or they want to control the service memory usage, or the user's OnRequest callback processing data very slowly and need to control the peer's send speed. -In this case, you can use `WithReadThreshold` to control the maximum reading threshold. +In this case, you can use `WithReadBufferThreshold` to control the maximum reading threshold. ### Client side use ``` -dialer := netpoll.NewDialer(netpoll.WithReadThreshold(1024 * 1024 * 1024 * 1)) // 1GB +dialer := netpoll.NewDialer(netpoll.WithReadBufferThreshold(1024 * 1024 * 1024 * 1)) // 1GB conn, _ = dialer.DialConnection(network, address, timeout) ``` @@ -578,7 +578,7 @@ conn, _ = dialer.DialConnection(network, address, timeout) ``` eventLoop, _ := netpoll.NewEventLoop( handle, - netpoll.WithReadThreshold(1024 * 1024 * 1024 * 1), // 1GB + netpoll.WithReadBufferThreshold(1024 * 1024 * 1024 * 1), // 1GB ) ``` diff --git a/fd_operator.go b/fd_operator.go index 89dae80f..1ac843a9 100644 --- a/fd_operator.go +++ b/fd_operator.go @@ -19,15 +19,6 @@ import ( "sync/atomic" ) -const ( - opdetach int32 = -1 - _ int32 = 0 // default op mode, means nothing - opread int32 = 1 - opwrite int32 = 2 - opreadwrite int32 = 3 - ophup int32 = 4 -) - // FDOperator is a collection of operations on file descriptors. type FDOperator struct { // FD is file descriptor, poll will bind when register. @@ -51,8 +42,8 @@ type FDOperator struct { // poll is the registered location of the file descriptor. poll Poll - mode int32 - throttled int32 + // protect only detach once + detached int32 // private, used by operatorCache next *FDOperator @@ -61,6 +52,9 @@ type FDOperator struct { } func (op *FDOperator) Control(event PollEvent) error { + if event == PollDetach && atomic.AddInt32(&op.detached, 1) > 1 { + return nil + } return op.poll.Control(op, event) } @@ -68,14 +62,6 @@ func (op *FDOperator) Free() { op.poll.Free(op) } -func (op *FDOperator) getMode() int32 { - return atomic.LoadInt32(&op.mode) -} - -func (op *FDOperator) setMode(mode int32) { - atomic.StoreInt32(&op.mode, mode) -} - func (op *FDOperator) do() (can bool) { return atomic.CompareAndSwapInt32(&op.state, 1, 2) } @@ -112,6 +98,5 @@ func (op *FDOperator) reset() { op.Inputs, op.InputAck = nil, nil op.Outputs, op.OutputAck = nil, nil op.poll = nil - op.mode = 0 - op.throttled = 0 + op.detached = 0 } diff --git a/netpoll_test.go b/netpoll_test.go index 85a0bef7..9ad6c3da 100644 --- a/netpoll_test.go +++ b/netpoll_test.go @@ -456,7 +456,7 @@ func TestReadThresholdOption(t *testing.T) { Equal(t, len(msg), 5) _, err = connection.Reader().Next(1) - Assert(t, errors.Is(err, ErrEOF)) + Assert(t, errors.Is(err, ErrEOF), err) t.Logf("server closed") return nil }, WithReadBufferThreshold(int64(readThreshold))) @@ -530,16 +530,13 @@ func TestReadThresholdClosed(t *testing.T) { t.Logf("server reading msg1") trigger <- struct{}{} // let client send msg2 <-trigger // ensure client send msg2 and closed - total := 0 for { msg, err := connection.Reader().Next(1) - total += len(msg) if errors.Is(err, ErrEOF) { break } _ = msg } - Equal(t, total, readThreshold+5) close(trigger) return nil }, WithReadBufferThreshold(int64(readThreshold))) diff --git a/poll.go b/poll.go index ace07133..915b1f9d 100644 --- a/poll.go +++ b/poll.go @@ -48,31 +48,23 @@ type PollEvent int const ( // PollReadable is used to monitor whether the FDOperator registered by // listener and connection is readable or closed. - PollReadable PollEvent = 0x1 + PollReadable PollEvent = iota + 1 // PollWritable is used to monitor whether the FDOperator created by the dialer is writable or closed. // ET mode must be used (still need to poll hup after being writable) - PollWritable PollEvent = 0x2 + PollWritable // PollDetach is used to remove the FDOperator from poll. - PollDetach PollEvent = 0x3 + PollDetach // PollR2RW is used to monitor writable for FDOperator, // which is only called when the socket write buffer is full. - PollR2RW PollEvent = 0x4 + PollR2RW // PollRW2R is used to remove the writable monitor of FDOperator, generally used with PollR2RW. - PollRW2R PollEvent = 0x5 + PollRW2R // PollRW2W is used to remove the readable monitor of FDOperator. - PollRW2W PollEvent = 0x6 + PollRW2W // PollW2RW is used to add the readable monitor of FDOperator, generally used with PollRW2W. - PollW2RW PollEvent = 0x7 - PollW2Hup PollEvent = 0x8 - - // PollR2Hup is used to remove the readable monitor of FDOperator. - PollR2Hup PollEvent = 0x9 - // PollHup2R is used to add the readable monitor of FDOperator, generally used with PollR2Hup. - PollHup2R PollEvent = 0xA - // PollHup2W is used to add the writeable monitor of FDOperator. - PollHup2W PollEvent = 0xB + PollW2RW ) diff --git a/poll_default_bsd.go b/poll_default_bsd.go index 8fda9c35..33c0b52b 100644 --- a/poll_default_bsd.go +++ b/poll_default_bsd.go @@ -115,8 +115,7 @@ func (p *defaultPoll) Wait() error { } } if triggerHup { - // if peer closed with throttled state, we should ensure we read all left data to avoid data loss - if (triggerRead || atomic.LoadInt32(&operator.throttled) > 0) && operator.Inputs != nil { + if triggerRead && operator.Inputs != nil { var leftRead int // read all left data if peer send and close if leftRead, err = readall(operator, barriers[i]); err != nil && !errors.Is(err, ErrEOF) { @@ -183,14 +182,11 @@ func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error { switch event { case PollReadable: operator.inuse() - operator.setMode(opread) evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_ADD|syscall.EV_ENABLE case PollWritable: operator.inuse() - operator.setMode(opwrite) evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_ADD|syscall.EV_ENABLE case PollDetach: - operator.setMode(ophup) if operator.OnWrite != nil { // means WaitWrite finished evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_DELETE } else { @@ -198,29 +194,13 @@ func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error { } p.delOperator(operator) case PollR2RW: - operator.setMode(opreadwrite) evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_ADD|syscall.EV_ENABLE case PollRW2R: - operator.setMode(opread) evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_DELETE case PollRW2W: - operator.setMode(opwrite) evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_DELETE case PollW2RW: - operator.setMode(opreadwrite) evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_ADD|syscall.EV_ENABLE - case PollR2Hup: - operator.setMode(ophup) - evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_DELETE - case PollW2Hup: - operator.setMode(ophup) - evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_DELETE - case PollHup2R: - operator.setMode(opread) - evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_ADD|syscall.EV_ENABLE - case PollHup2W: - operator.setMode(opwrite) - evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_ADD|syscall.EV_ENABLE } _, err := syscall.Kevent(p.fd, evs, nil, nil) return err diff --git a/poll_default_bsd_test.go b/poll_default_bsd_test.go new file mode 100644 index 00000000..92185f9e --- /dev/null +++ b/poll_default_bsd_test.go @@ -0,0 +1,83 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build darwin +// +build darwin + +package netpoll + +import ( + "syscall" + "testing" +) + +func TestKqueueEvent(t *testing.T) { + kqfd, err := syscall.Kqueue() + defer syscall.Close(kqfd) + _, err = syscall.Kevent(kqfd, []syscall.Kevent_t{{ + Ident: 0, + Filter: syscall.EVFILT_USER, + Flags: syscall.EV_ADD | syscall.EV_CLEAR, + }}, nil, nil) + MustNil(t, err) + + rfd, wfd := GetSysFdPairs() + defer syscall.Close(rfd) + defer syscall.Close(wfd) + + // add read event + changes := make([]syscall.Kevent_t, 1) + changes[0].Ident = uint64(rfd) + changes[0].Filter = syscall.EVFILT_READ + changes[0].Flags = syscall.EV_ADD + _, err = syscall.Kevent(kqfd, changes, nil, nil) + MustNil(t, err) + + // write + send := []byte("hello") + recv := make([]byte, 5) + _, err = syscall.Write(wfd, send) + MustNil(t, err) + + // check readable + events := make([]syscall.Kevent_t, 128) + n, err := syscall.Kevent(kqfd, nil, events, nil) + MustNil(t, err) + Equal(t, n, 1) + Assert(t, events[0].Filter == syscall.EVFILT_READ) + // read + _, err = syscall.Read(rfd, recv) + MustNil(t, err) + Equal(t, string(recv), string(send)) + + // delete read + changes[0].Ident = uint64(rfd) + changes[0].Filter = syscall.EVFILT_READ + changes[0].Flags = syscall.EV_DELETE + _, err = syscall.Kevent(kqfd, changes, nil, nil) + MustNil(t, err) + + // write + _, err = syscall.Write(wfd, send) + MustNil(t, err) + + // check readable + n, err = syscall.Kevent(kqfd, nil, events, &syscall.Timespec{Sec: 1}) + MustNil(t, err) + Equal(t, n, 0) + // read + _, err = syscall.Read(rfd, recv) + MustNil(t, err) + Equal(t, string(recv), string(send)) +} diff --git a/poll_default_linux.go b/poll_default_linux.go index f51e10c6..8637a577 100644 --- a/poll_default_linux.go +++ b/poll_default_linux.go @@ -168,8 +168,7 @@ func (p *defaultPoll) handler(events []epollevent) (closed bool) { } } if triggerHup { - // if peer closed with throttled state, we should ensure we read all left data to avoid data loss - if (triggerRead || atomic.LoadInt32(&operator.throttled) > 0) && operator.Inputs != nil { + if triggerRead && operator.Inputs != nil { // read all left data if peer send and close var leftRead int // read all left data if peer send and close @@ -245,41 +244,21 @@ func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error { switch event { case PollReadable: // server accept a new connection and wait read operator.inuse() - operator.setMode(opread) op, evt.events = syscall.EPOLL_CTL_ADD, syscall.EPOLLIN|syscall.EPOLLRDHUP|syscall.EPOLLERR case PollWritable: // client create a new connection and wait connect finished operator.inuse() - operator.setMode(opwrite) op, evt.events = syscall.EPOLL_CTL_ADD, EPOLLET|syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR case PollDetach: // deregister - if operator.getMode() == opdetach { - // protect only detach once - return nil - } - operator.setMode(opdetach) p.delOperator(operator) op, evt.events = syscall.EPOLL_CTL_DEL, syscall.EPOLLIN|syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR case PollR2RW: // connection wait read/write - operator.setMode(opreadwrite) op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLIN|syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR case PollRW2R: // connection wait read - operator.setMode(opread) op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLIN|syscall.EPOLLRDHUP|syscall.EPOLLERR case PollRW2W: - operator.setMode(opwrite) op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR case PollW2RW: - operator.setMode(opreadwrite) op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLIN|syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR - case PollR2Hup, PollW2Hup: - operator.setMode(ophup) - op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLRDHUP|syscall.EPOLLERR - case PollHup2R: - operator.setMode(opread) - op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLIN|syscall.EPOLLRDHUP|syscall.EPOLLERR - case PollHup2W: - operator.setMode(opwrite) - op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR } return EpollCtl(p.fd, op, operator.FD, &evt) }