diff --git a/connection_errors.go b/connection_errors.go index 1edfa21d..2b17ada7 100644 --- a/connection_errors.go +++ b/connection_errors.go @@ -36,6 +36,8 @@ const ( ErrEOF = syscall.Errno(0x106) // Write I/O buffer timeout, calling by Connection.Writer ErrWriteTimeout = syscall.Errno(0x107) + // The wait read size large than read threshold + ErrReadExceedThreshold = syscall.Errno(0x108) ) const ErrnoMask = 0xFF @@ -110,11 +112,12 @@ func (e *exception) Temporary() bool { // Errors defined in netpoll var errnos = [...]string{ - ErrnoMask & ErrConnClosed: "connection has been closed", - ErrnoMask & ErrReadTimeout: "connection read timeout", - ErrnoMask & ErrDialTimeout: "dial wait timeout", - ErrnoMask & ErrDialNoDeadline: "dial no deadline", - ErrnoMask & ErrUnsupported: "netpoll dose not support", - ErrnoMask & ErrEOF: "EOF", - ErrnoMask & ErrWriteTimeout: "connection write timeout", + ErrnoMask & ErrConnClosed: "connection has been closed", + ErrnoMask & ErrReadTimeout: "connection read timeout", + ErrnoMask & ErrDialTimeout: "dial wait timeout", + ErrnoMask & ErrDialNoDeadline: "dial no deadline", + ErrnoMask & ErrUnsupported: "netpoll dose not support", + ErrnoMask & ErrEOF: "EOF", + ErrnoMask & ErrWriteTimeout: "connection write timeout", + ErrnoMask & ErrReadExceedThreshold: "connection read size exceeds the threshold", } diff --git a/connection_impl.go b/connection_impl.go index b683b4df..a7af3101 100644 --- a/connection_impl.go +++ b/connection_impl.go @@ -33,21 +33,22 @@ type connection struct { netFD onEvent locker - operator *FDOperator - readTimeout time.Duration - readTimer *time.Timer - readTrigger chan error - waitReadSize int64 - writeTimeout time.Duration - writeTimer *time.Timer - writeTrigger chan error - inputBuffer *LinkBuffer - outputBuffer *LinkBuffer - outputBarrier *barrier - supportZeroCopy bool - maxSize int // The maximum size of data between two Release(). - bookSize int // The size of data that can be read at once. - state int32 // 0: not connected, 1: connected, 2: disconnected. Connection state should be changed sequentially. + operator *FDOperator + readTimeout time.Duration + readTimer *time.Timer + readTrigger chan error + waitReadSize int64 + writeTimeout time.Duration + writeTimer *time.Timer + writeTrigger chan error + inputBuffer *LinkBuffer + outputBuffer *LinkBuffer + outputBarrier *barrier + supportZeroCopy bool + maxSize int // The maximum size of data between two Release(). + bookSize int // The size of data that can be read at once. + state int32 // 0: not connected, 1: connected, 2: disconnected. Connection state should be changed sequentially. + readBufferThreshold int64 // The readBufferThreshold limit the size of connection inputBuffer. In bytes. } var ( @@ -95,6 +96,12 @@ func (c *connection) SetWriteTimeout(timeout time.Duration) error { return nil } +// SetReadBufferThreshold implements Connection. +func (c *connection) SetReadBufferThreshold(threshold int64) error { + c.readBufferThreshold = threshold + return nil +} + // ------------------------------------------ implement zero-copy reader ------------------------------------------ // Next implements Connection. @@ -396,28 +403,41 @@ func (c *connection) triggerWrite(err error) { // waitRead will wait full n bytes. func (c *connection) waitRead(n int) (err error) { if n <= c.inputBuffer.Len() { - return nil + goto CLEANUP } + // cannot wait read with an out of threshold size + if c.readBufferThreshold > 0 && int64(n) > c.readBufferThreshold { + // just return error and dont do cleanup + return Exception(ErrReadExceedThreshold, "wait read") + } + atomic.StoreInt64(&c.waitReadSize, int64(n)) - defer atomic.StoreInt64(&c.waitReadSize, 0) if c.readTimeout > 0 { - return c.waitReadWithTimeout(n) + err = c.waitReadWithTimeout(n) + goto CLEANUP } // wait full n - for c.inputBuffer.Len() < n { + for c.inputBuffer.Len() < n && err == nil { switch c.status(closing) { case poller: - return Exception(ErrEOF, "wait read") + err = Exception(ErrEOF, "wait read") case user: - return Exception(ErrConnClosed, "wait read") + err = Exception(ErrConnClosed, "wait read") default: err = <-c.readTrigger - if err != nil { - return err - } } } - return nil +CLEANUP: + atomic.StoreInt64(&c.waitReadSize, 0) + if c.readBufferThreshold > 0 && err == nil { + // only resume read when current read size could make newBufferSize < readBufferThreshold + bufferSize := int64(c.inputBuffer.Len()) + newBufferSize := bufferSize - int64(n) + if bufferSize >= c.readBufferThreshold && newBufferSize < c.readBufferThreshold { + c.resumeRead() + } + } + return err } // waitReadWithTimeout will wait full n bytes or until timeout. @@ -485,11 +505,10 @@ func (c *connection) flush() error { if c.outputBuffer.IsEmpty() { return nil } - err = c.operator.Control(PollR2RW) - if err != nil { - return Exception(err, "when flush") - } + // no need to check if resume write successfully + // if resume failed, the connection will be triggered triggerWrite(err), and waitFlush will return err + c.resumeWrite() return c.waitFlush() } @@ -522,8 +541,8 @@ func (c *connection) waitFlush() (err error) { default: } // if timeout, remove write event from poller - // we cannot flush it again, since we don't if the poller is still process outputBuffer - c.operator.Control(PollRW2R) + // we cannot flush it again, since we don't know if the poller is still processing outputBuffer + c.pauseWrite() return Exception(ErrWriteTimeout, c.remoteAddr.String()) } } diff --git a/connection_onevent.go b/connection_onevent.go index 35b7c001..9db2a9a5 100644 --- a/connection_onevent.go +++ b/connection_onevent.go @@ -113,6 +113,7 @@ func (c *connection) onPrepare(opts *options) (err error) { c.SetReadTimeout(opts.readTimeout) c.SetWriteTimeout(opts.writeTimeout) c.SetIdleTimeout(opts.idleTimeout) + c.SetReadBufferThreshold(opts.readBufferThreshold) // calling prepare first and then register. if opts.onPrepare != nil { diff --git a/connection_reactor.go b/connection_reactor.go index 25b4dec5..61b01230 100644 --- a/connection_reactor.go +++ b/connection_reactor.go @@ -84,6 +84,12 @@ func (c *connection) closeBuffer() { // inputs implements FDOperator. func (c *connection) inputs(vs [][]byte) (rs [][]byte) { + // trigger throttle + if c.readBufferThreshold > 0 && int64(c.inputBuffer.Len()) >= c.readBufferThreshold { + c.pauseRead() + return + } + vs[0] = c.inputBuffer.book(c.bookSize, c.maxSize) return vs[:1] } @@ -108,6 +114,11 @@ func (c *connection) inputAck(n int) (err error) { c.maxSize = mallocMax } + // trigger throttle + if c.readBufferThreshold > 0 && int64(length) >= c.readBufferThreshold { + c.pauseRead() + } + var needTrigger = true if length == n { // first start onRequest needTrigger = c.onRequest() @@ -121,7 +132,8 @@ func (c *connection) inputAck(n int) (err error) { // outputs implements FDOperator. func (c *connection) outputs(vs [][]byte) (rs [][]byte, supportZeroCopy bool) { if c.outputBuffer.IsEmpty() { - c.rw2r() + c.pauseWrite() + c.triggerWrite(nil) return rs, c.supportZeroCopy } rs = c.outputBuffer.GetBytes(vs) @@ -135,13 +147,44 @@ func (c *connection) outputAck(n int) (err error) { c.outputBuffer.Release() } if c.outputBuffer.IsEmpty() { - c.rw2r() + c.pauseWrite() + c.triggerWrite(nil) } return nil } -// rw2r removed the monitoring of write events. -func (c *connection) rw2r() { +/* The race description of operator event monitoring +- Pause operation will remove old event monitor of operator +- Resume operation will add new event monitor of operator +- Only poller could use Pause to remove event monitor, and poller already hold the op.do() locker +- Only user could use Resume, and user's operation maybe compete with poller's operation +- If competition happen, because of all resume operation will monitor all events, it's safe to do that with a race condition. + * If resume first and pause latter, poller will monitor the accurate events it needs. + * If pause first and resume latter, poller will monitor the duplicate events which will be removed after next poller triggered. + And poller will ensure to remove the duplicate events. +- If there is no readBufferThreshold option, the code path will be more simple and efficient. +*/ + +// pauseWrite removed the monitoring of write events. +// pauseWrite used in poller +func (c *connection) pauseWrite() { c.operator.Control(PollRW2R) - c.triggerWrite(nil) +} + +// resumeWrite add the monitoring of write events. +// resumeWrite used by users +func (c *connection) resumeWrite() { + c.operator.Control(PollR2RW) +} + +// pauseRead removed the monitoring of read events. +// pauseRead used in poller +func (c *connection) pauseRead() { + c.operator.Control(PollRW2W) +} + +// resumeRead add the monitoring of read events. +// resumeRead used by users +func (c *connection) resumeRead() { + c.operator.Control(PollW2RW) } diff --git a/connection_test.go b/connection_test.go index 548d98a2..90f7ea5d 100644 --- a/connection_test.go +++ b/connection_test.go @@ -499,8 +499,6 @@ func TestConnDetach(t *testing.T) { func TestParallelShortConnection(t *testing.T) { ln, err := createTestListener("tcp", ":12345") MustNil(t, err) - defer ln.Close() - var received int64 el, err := NewEventLoop(func(ctx context.Context, connection Connection) error { data, err := connection.Reader().Next(connection.Reader().Len()) @@ -511,6 +509,7 @@ func TestParallelShortConnection(t *testing.T) { //t.Logf("conn[%s] received: %d, active: %v", connection.RemoteAddr(), len(data), connection.IsActive()) return nil }) + defer el.Shutdown(context.Background()) go func() { el.Serve(ln) }() @@ -646,8 +645,6 @@ func TestConnectionServerClose(t *testing.T) { func TestConnectionDailTimeoutAndClose(t *testing.T) { ln, err := createTestListener("tcp", ":12345") MustNil(t, err) - defer ln.Close() - el, err := NewEventLoop( func(ctx context.Context, connection Connection) error { _, err = connection.Reader().Next(connection.Reader().Len()) @@ -671,10 +668,156 @@ func TestConnectionDailTimeoutAndClose(t *testing.T) { go func() { defer wg.Done() conn, err := DialConnection("tcp", ":12345", time.Nanosecond) - Assert(t, err == nil || strings.Contains(err.Error(), "i/o timeout")) + Assert(t, err == nil || strings.Contains(err.Error(), "i/o timeout"), err) _ = conn }() } wg.Wait() } } + +func TestConnectionReadOutOfThreshold(t *testing.T) { + var readThreshold = 1024 * 100 + var readSize = readThreshold + 1 + var opts = &options{} + var wg sync.WaitGroup + wg.Add(1) + opts.onRequest = func(ctx context.Context, connection Connection) error { + if connection.Reader().Len() < readThreshold { + return nil + } + defer wg.Done() + // read throttled data + _, err := connection.Reader().Next(readSize) + Assert(t, errors.Is(err, ErrReadExceedThreshold), err) + connection.Close() + return nil + } + + WithReadBufferThreshold(int64(readThreshold)).f(opts) + r, w := GetSysFdPairs() + rconn, wconn := &connection{}, &connection{} + rconn.init(&netFD{fd: r}, opts) + wconn.init(&netFD{fd: w}, opts) + + msg := make([]byte, readThreshold) + _, err := wconn.Writer().WriteBinary(msg) + MustNil(t, err) + err = wconn.Writer().Flush() + MustNil(t, err) + wg.Wait() +} + +func TestConnectionReadThreshold(t *testing.T) { + var readThreshold int64 = 1024 * 100 + var opts = &options{} + var wg sync.WaitGroup + var throttled int32 + wg.Add(1) + opts.onRequest = func(ctx context.Context, connection Connection) error { + if int64(connection.Reader().Len()) < readThreshold { + return nil + } + defer wg.Done() + + atomic.StoreInt32(&throttled, 1) + // check if no more read data when throttled + inbuffered := connection.Reader().Len() + t.Logf("Inbuffered: %d", inbuffered) + time.Sleep(time.Millisecond * 100) + Equal(t, inbuffered, connection.Reader().Len()) + + // read non-throttled data + buf, err := connection.Reader().Next(int(readThreshold)) + Equal(t, int64(len(buf)), readThreshold) + MustNil(t, err) + err = connection.Reader().Release() + MustNil(t, err) + t.Logf("read non-throttled data") + + // continue read throttled data + buf, err = connection.Reader().Next(5) + MustNil(t, err) + t.Logf("read throttled data: [%s]", buf) + Equal(t, len(buf), 5) + MustNil(t, err) + err = connection.Reader().Release() + MustNil(t, err) + Equal(t, connection.Reader().Len(), 0) + return nil + } + + WithReadBufferThreshold(readThreshold).f(opts) + r, w := GetSysFdPairs() + rconn, wconn := &connection{}, &connection{} + rconn.init(&netFD{fd: r}, opts) + wconn.init(&netFD{fd: w}, opts) + Assert(t, rconn.readBufferThreshold == readThreshold) + + msg := make([]byte, readThreshold) + _, err := wconn.Writer().WriteBinary(msg) + MustNil(t, err) + err = wconn.Writer().Flush() + MustNil(t, err) + _, err = wconn.Writer().WriteString("hello") + MustNil(t, err) + err = wconn.Writer().Flush() + MustNil(t, err) + t.Logf("flush final msg") + + wg.Wait() +} + +func TestConnectionReadThresholdWithClosed(t *testing.T) { + var readThreshold int64 = 1024 * 100 + var opts = &options{} + var trigger = make(chan struct{}) + opts.onRequest = func(ctx context.Context, connection Connection) error { + if int64(connection.Reader().Len()) < readThreshold { + return nil + } + Equal(t, connection.Reader().Len(), int(readThreshold)) + trigger <- struct{}{} // let client send final msg and close + <-trigger // wait for client send and close + + // read non-throttled data + buf, err := connection.Reader().Next(int(readThreshold)) + Equal(t, int64(len(buf)), readThreshold) + MustNil(t, err) + err = connection.Reader().Release() + MustNil(t, err) + t.Logf("read non-throttled data") + + // continue read throttled data with EOF + for !errors.Is(err, ErrEOF) { + buf, err = connection.Reader().Next(1) + } + trigger <- struct{}{} + return nil + } + + WithReadBufferThreshold(readThreshold).f(opts) + r, w := GetSysFdPairs() + rconn, wconn := &connection{}, &connection{} + rconn.init(&netFD{fd: r}, opts) + wconn.init(&netFD{fd: w}, opts) + Assert(t, rconn.readBufferThreshold == readThreshold) + + msg := make([]byte, readThreshold) + _, err := wconn.Writer().WriteBinary(msg) + MustNil(t, err) + err = wconn.Writer().Flush() + MustNil(t, err) + + <-trigger + _, err = wconn.Writer().WriteString("hello") + MustNil(t, err) + err = wconn.Writer().Flush() + MustNil(t, err) + t.Logf("flush final msg") + err = wconn.Close() + MustNil(t, err) + trigger <- struct{}{} + + <-trigger +} diff --git a/docs/guide/guide_cn.md b/docs/guide/guide_cn.md index f9b9a0db..d18d3ad7 100644 --- a/docs/guide/guide_cn.md +++ b/docs/guide/guide_cn.md @@ -519,6 +519,26 @@ func callback(connection netpoll.Connection) error { } ``` +## 8. 如何配置连接的读取阈值大小 ? + +Netpoll 默认不会对端发送数据的读取速度有任何限制,每当连接有数据时,Netpoll 会尽可能快地将数据存放在自己的 buffer 中。但有时候可能用户不希望数据过快发送,或者是希望控制服务内存使用量,又或者业务 OnRequest 回调处理速度很慢需要限制发送方速度,此时可以使用 `WithReadBufferThreshold` 来控制读取的最大阈值。 + +### Client 侧使用 + +``` +dialer := netpoll.NewDialer(netpoll.WithReadBufferThreshold(1024 * 1024 * 1024 * 1)) // 1GB +conn, _ = dialer.DialConnection(network, address, timeout) +``` + +### Server 侧使用 + +``` +eventLoop, _ := netpoll.NewEventLoop( + handle, + netpoll.WithReadBufferThreshold(1024 * 1024 * 1024 * 1), // 1GB +) +``` + # 注意事项 ## 1. 错误设置 NumLoops diff --git a/docs/guide/guide_en.md b/docs/guide/guide_en.md index 08c522f3..260fe92d 100644 --- a/docs/guide/guide_en.md +++ b/docs/guide/guide_en.md @@ -558,6 +558,30 @@ func callback(connection netpoll.Connection) error { } ``` +## 8. How to configure the read threshold of the connection? + +By default, Netpoll does not place any limit on the reading speed of data sent by the end. +Whenever there have more data on the connection, Netpoll will read the data into its own buffer as quickly as possible. + +But sometimes users may not want data to be read too quickly, or they want to control the service memory usage, or the user's OnRequest callback processing data very slowly and need to control the peer's send speed. +In this case, you can use `WithReadBufferThreshold` to control the maximum reading threshold. + +### Client side use + +``` +dialer := netpoll.NewDialer(netpoll.WithReadBufferThreshold(1024 * 1024 * 1024 * 1)) // 1GB +conn, _ = dialer.DialConnection(network, address, timeout) +``` + +### Server side use + +``` +eventLoop, _ := netpoll.NewEventLoop( + handle, + netpoll.WithReadBufferThreshold(1024 * 1024 * 1024 * 1), // 1GB +) +``` + # Attention ## 1. Wrong setting of NumLoops diff --git a/eventloop.go b/eventloop.go index 425cd95f..6911ba1b 100644 --- a/eventloop.go +++ b/eventloop.go @@ -68,13 +68,13 @@ type OnPrepare func(connection Connection) context.Context // // An example usage in TCP Proxy scenario: // -// func onConnect(ctx context.Context, upstream netpoll.Connection) context.Context { -// downstream, _ := netpoll.DialConnection("tcp", downstreamAddr, time.Second) -// return context.WithValue(ctx, downstreamKey, downstream) -// } -// func onRequest(ctx context.Context, upstream netpoll.Connection) error { -// downstream := ctx.Value(downstreamKey).(netpoll.Connection) -// } +// func onConnect(ctx context.Context, upstream netpoll.Connection) context.Context { +// downstream, _ := netpoll.DialConnection("tcp", downstreamAddr, time.Second) +// return context.WithValue(ctx, downstreamKey, downstream) +// } +// func onRequest(ctx context.Context, upstream netpoll.Connection) error { +// downstream := ctx.Value(downstreamKey).(netpoll.Connection) +// } type OnConnect func(ctx context.Context, connection Connection) context.Context // OnDisconnect is called once connection is going to be closed. @@ -86,14 +86,14 @@ type OnDisconnect func(ctx context.Context, connection Connection) // netpoll actively reads the data in LT mode and places it in the connection's input buffer. // Generally, OnRequest starts handling the data in the following way: // -// func OnRequest(ctx context, connection Connection) error { -// input := connection.Reader().Next(n) -// handling input data... -// send, _ := connection.Writer().Malloc(l) -// copy(send, output) -// connection.Flush() -// return nil -// } +// func OnRequest(ctx context, connection Connection) error { +// input := connection.Reader().Next(n) +// handling input data... +// send, _ := connection.Writer().Malloc(l) +// copy(send, output) +// connection.Flush() +// return nil +// } // // OnRequest will run in a separate goroutine and // it is guaranteed that there is one and only one OnRequest running at the same time. diff --git a/mux/shard_queue_test.go b/mux/shard_queue_test.go index 7a595d21..b0d3f5b4 100644 --- a/mux/shard_queue_test.go +++ b/mux/shard_queue_test.go @@ -19,6 +19,7 @@ package mux import ( "net" + "sync" "testing" "time" @@ -26,28 +27,25 @@ import ( ) func TestShardQueue(t *testing.T) { - var svrConn net.Conn accepted := make(chan struct{}) network, address := "tcp", ":18888" ln, err := net.Listen("tcp", ":18888") MustNil(t, err) - stop := make(chan int, 1) - defer close(stop) + count, pkgsize := 16, 11 + var wg sync.WaitGroup + wg.Add(1) go func() { - var err error - for { - select { - case <-stop: - err = ln.Close() - MustNil(t, err) - return - default: - } - svrConn, err = ln.Accept() - MustNil(t, err) - accepted <- struct{}{} - } + defer wg.Done() + svrConn, err := ln.Accept() + MustNil(t, err) + accepted <- struct{}{} + + total := count * pkgsize + recv := make([]byte, total) + rn, err := svrConn.Read(recv) + MustNil(t, err) + Equal(t, rn, total) }() conn, err := netpoll.DialConnection(network, address, time.Second) @@ -56,8 +54,7 @@ func TestShardQueue(t *testing.T) { // test queue := NewShardQueue(4, conn) - count, pkgsize := 16, 11 - for i := 0; i < int(count); i++ { + for i := 0; i < count; i++ { var getter WriterGetter = func() (buf netpoll.Writer, isNil bool) { buf = netpoll.NewLinkBuffer(pkgsize) buf.Malloc(pkgsize) @@ -68,14 +65,8 @@ func TestShardQueue(t *testing.T) { err = queue.Close() MustNil(t, err) - total := count * pkgsize - recv := make([]byte, total) - rn, err := svrConn.Read(recv) - MustNil(t, err) - Equal(t, rn, total) -} -// TODO: need mock flush -func BenchmarkShardQueue(b *testing.B) { - b.Skip() + wg.Wait() + err = ln.Close() + MustNil(t, err) } diff --git a/net_dialer.go b/net_dialer.go index 4c4e8dd2..edcafca1 100644 --- a/net_dialer.go +++ b/net_dialer.go @@ -29,13 +29,22 @@ func DialConnection(network, address string, timeout time.Duration) (connection } // NewDialer only support TCP and unix socket now. -func NewDialer() Dialer { - return &dialer{} +func NewDialer(opts ...Option) Dialer { + d := new(dialer) + if len(opts) > 0 { + d.opts = new(options) + for _, opt := range opts { + opt.f(d.opts) + } + } + return d } var defaultDialer = NewDialer() -type dialer struct{} +type dialer struct { + opts *options +} // DialTimeout implements Dialer. func (d *dialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { @@ -59,7 +68,7 @@ func (d *dialer) DialConnection(network, address string, timeout time.Duration) raddr := &UnixAddr{ UnixAddr: net.UnixAddr{Name: address, Net: network}, } - return DialUnix(network, nil, raddr) + return dialUnix(network, nil, raddr, d.opts) default: return nil, net.UnknownNetworkError(network) } @@ -95,9 +104,9 @@ func (d *dialer) dialTCP(ctx context.Context, network, address string) (connecti tcpAddr.Port = portnum tcpAddr.Zone = ipaddr.Zone if ipaddr.IP != nil && ipaddr.IP.To4() == nil { - connection, err = DialTCP(ctx, "tcp6", nil, tcpAddr) + connection, err = dialTCP(ctx, "tcp6", nil, tcpAddr, d.opts) } else { - connection, err = DialTCP(ctx, "tcp", nil, tcpAddr) + connection, err = dialTCP(ctx, "tcp", nil, tcpAddr, d.opts) } if err == nil { return connection, nil diff --git a/net_dialer_test.go b/net_dialer_test.go index 7383fd0d..deca3889 100644 --- a/net_dialer_test.go +++ b/net_dialer_test.go @@ -38,15 +38,14 @@ func TestDialerTCP(t *testing.T) { ln, err := CreateListener("tcp", ":1234") MustNil(t, err) - stop := make(chan int, 1) - defer close(stop) - + stop := make(chan int) go func() { for { select { case <-stop: err := ln.Close() MustNil(t, err) + close(stop) return default: } @@ -61,6 +60,9 @@ func TestDialerTCP(t *testing.T) { MustNil(t, err) MustTrue(t, strings.HasPrefix(conn.LocalAddr().String(), "127.0.0.1:")) Equal(t, conn.RemoteAddr().String(), "127.0.0.1:1234") + + stop <- 0 + <-stop } func TestDialerUnix(t *testing.T) { diff --git a/net_polldesc_test.go b/net_polldesc_test.go index 40804b62..6f379167 100644 --- a/net_polldesc_test.go +++ b/net_polldesc_test.go @@ -30,15 +30,14 @@ func TestRuntimePoll(t *testing.T) { ln, err := CreateListener("tcp", ":1234") MustNil(t, err) - stop := make(chan int, 1) - defer close(stop) - + stop := make(chan int) go func() { for { select { case <-stop: err := ln.Close() MustNil(t, err) + close(stop) return default: } @@ -54,4 +53,7 @@ func TestRuntimePoll(t *testing.T) { MustNil(t, err) conn.Close() } + + stop <- 0 + <-stop } diff --git a/net_sock.go b/net_sock.go index a3d318c7..c6ec98e8 100644 --- a/net_sock.go +++ b/net_sock.go @@ -55,29 +55,29 @@ func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, soty // address family, both AF_INET and AF_INET6, and a wildcard address // like the following: // -// - A listen for a wildcard communication domain, "tcp" or -// "udp", with a wildcard address: If the platform supports -// both IPv6 and IPv4-mapped IPv6 communication capabilities, -// or does not support IPv4, we use a dual stack, AF_INET6 and -// IPV6_V6ONLY=0, wildcard address listen. The dual stack -// wildcard address listen may fall back to an IPv6-only, -// AF_INET6 and IPV6_V6ONLY=1, wildcard address listen. -// Otherwise we prefer an IPv4-only, AF_INET, wildcard address -// listen. +// - A listen for a wildcard communication domain, "tcp" or +// "udp", with a wildcard address: If the platform supports +// both IPv6 and IPv4-mapped IPv6 communication capabilities, +// or does not support IPv4, we use a dual stack, AF_INET6 and +// IPV6_V6ONLY=0, wildcard address listen. The dual stack +// wildcard address listen may fall back to an IPv6-only, +// AF_INET6 and IPV6_V6ONLY=1, wildcard address listen. +// Otherwise we prefer an IPv4-only, AF_INET, wildcard address +// listen. // -// - A listen for a wildcard communication domain, "tcp" or -// "udp", with an IPv4 wildcard address: same as above. +// - A listen for a wildcard communication domain, "tcp" or +// "udp", with an IPv4 wildcard address: same as above. // -// - A listen for a wildcard communication domain, "tcp" or -// "udp", with an IPv6 wildcard address: same as above. +// - A listen for a wildcard communication domain, "tcp" or +// "udp", with an IPv6 wildcard address: same as above. // -// - A listen for an IPv4 communication domain, "tcp4" or "udp4", -// with an IPv4 wildcard address: We use an IPv4-only, AF_INET, -// wildcard address listen. +// - A listen for an IPv4 communication domain, "tcp4" or "udp4", +// with an IPv4 wildcard address: We use an IPv4-only, AF_INET, +// wildcard address listen. // -// - A listen for an IPv6 communication domain, "tcp6" or "udp6", -// with an IPv6 wildcard address: We use an IPv6-only, AF_INET6 -// and IPV6_V6ONLY=1, wildcard address listen. +// - A listen for an IPv6 communication domain, "tcp6" or "udp6", +// with an IPv6 wildcard address: We use an IPv6-only, AF_INET6 +// and IPV6_V6ONLY=1, wildcard address listen. // // Otherwise guess: If the addresses are IPv4 then returns AF_INET, // or else returns AF_INET6. It also returns a boolean value what diff --git a/net_tcpsock.go b/net_tcpsock.go index 2c90634b..87fb84eb 100644 --- a/net_tcpsock.go +++ b/net_tcpsock.go @@ -138,23 +138,16 @@ type TCPConnection struct { } // newTCPConnection wraps *TCPConnection. -func newTCPConnection(conn Conn) (connection *TCPConnection, err error) { +func newTCPConnection(conn Conn, opts *options) (connection *TCPConnection, err error) { connection = &TCPConnection{} - err = connection.init(conn, nil) + err = connection.init(conn, opts) if err != nil { return nil, err } return connection, nil } -// DialTCP acts like Dial for TCP networks. -// -// The network must be a TCP network name; see func Dial for details. -// -// If laddr is nil, a local address is automatically chosen. -// If the IP field of raddr is nil or an unspecified IP address, the -// local system is assumed. -func DialTCP(ctx context.Context, network string, laddr, raddr *TCPAddr) (*TCPConnection, error) { +func dialTCP(ctx context.Context, network string, laddr, raddr *TCPAddr, opts *options) (*TCPConnection, error) { switch network { case "tcp", "tcp4", "tcp6": default: @@ -167,14 +160,25 @@ func DialTCP(ctx context.Context, network string, laddr, raddr *TCPAddr) (*TCPCo ctx = context.Background() } sd := &sysDialer{network: network, address: raddr.String()} - c, err := sd.dialTCP(ctx, laddr, raddr) + c, err := sd.dialTCP(ctx, laddr, raddr, opts) if err != nil { return nil, &net.OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err} } return c, nil } -func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConnection, error) { +// DialTCP acts like Dial for TCP networks. +// +// The network must be a TCP network name; see func Dial for details. +// +// If laddr is nil, a local address is automatically chosen. +// If the IP field of raddr is nil or an unspecified IP address, the +// local system is assumed. +func DialTCP(ctx context.Context, network string, laddr, raddr *TCPAddr) (*TCPConnection, error) { + return dialTCP(ctx, network, laddr, raddr, nil) +} + +func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr, opts *options) (*TCPConnection, error) { conn, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial") // TCP has a rarely used mechanism called a 'simultaneous connection' in @@ -211,7 +215,7 @@ func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPCo if err != nil { return nil, err } - return newTCPConnection(conn) + return newTCPConnection(conn, opts) } func selfConnect(conn *netFD, err error) bool { diff --git a/net_unixsock.go b/net_unixsock.go index c5213a1c..9564dbc7 100644 --- a/net_unixsock.go +++ b/net_unixsock.go @@ -74,41 +74,45 @@ type UnixConnection struct { } // newUnixConnection wraps UnixConnection. -func newUnixConnection(conn Conn) (connection *UnixConnection, err error) { +func newUnixConnection(conn Conn, opts *options) (connection *UnixConnection, err error) { connection = &UnixConnection{} - err = connection.init(conn, nil) + err = connection.init(conn, opts) if err != nil { return nil, err } return connection, nil } -// DialUnix acts like Dial for Unix networks. -// -// The network must be a Unix network name; see func Dial for details. -// -// If laddr is non-nil, it is used as the local address for the -// connection. -func DialUnix(network string, laddr, raddr *UnixAddr) (*UnixConnection, error) { +func dialUnix(network string, laddr, raddr *UnixAddr, opts *options) (*UnixConnection, error) { switch network { case "unix", "unixgram", "unixpacket": default: return nil, &net.OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: net.UnknownNetworkError(network)} } sd := &sysDialer{network: network, address: raddr.String()} - c, err := sd.dialUnix(context.Background(), laddr, raddr) + c, err := sd.dialUnix(context.Background(), laddr, raddr, opts) if err != nil { return nil, &net.OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err} } return c, nil } -func (sd *sysDialer) dialUnix(ctx context.Context, laddr, raddr *UnixAddr) (*UnixConnection, error) { +// DialUnix acts like Dial for Unix networks. +// +// The network must be a Unix network name; see func Dial for details. +// +// If laddr is non-nil, it is used as the local address for the +// connection. +func DialUnix(network string, laddr, raddr *UnixAddr) (*UnixConnection, error) { + return dialUnix(network, laddr, raddr, nil) +} + +func (sd *sysDialer) dialUnix(ctx context.Context, laddr, raddr *UnixAddr, opts *options) (*UnixConnection, error) { conn, err := unixSocket(ctx, sd.network, laddr, raddr, "dial") if err != nil { return nil, err } - return newUnixConnection(conn) + return newUnixConnection(conn, opts) } func unixSocket(ctx context.Context, network string, laddr, raddr sockaddr, mode string) (conn *netFD, err error) { diff --git a/netpoll_options.go b/netpoll_options.go index 2cdb1c13..04e4b780 100644 --- a/netpoll_options.go +++ b/netpoll_options.go @@ -111,17 +111,26 @@ func WithIdleTimeout(timeout time.Duration) Option { }} } +// WithReadBufferThreshold sets the max read buffer threshold. +// If connection already read the threshold bytes data, it will stop read more data. +func WithReadBufferThreshold(threshold int64) Option { + return Option{func(op *options) { + op.readBufferThreshold = threshold + }} +} + // Option . type Option struct { f func(*options) } type options struct { - onPrepare OnPrepare - onConnect OnConnect - onDisconnect OnDisconnect - onRequest OnRequest - readTimeout time.Duration - writeTimeout time.Duration - idleTimeout time.Duration + onPrepare OnPrepare + onConnect OnConnect + onDisconnect OnDisconnect + onRequest OnRequest + readTimeout time.Duration + writeTimeout time.Duration + idleTimeout time.Duration + readBufferThreshold int64 // bytes } diff --git a/netpoll_test.go b/netpoll_test.go index fb985604..503c2892 100644 --- a/netpoll_test.go +++ b/netpoll_test.go @@ -507,6 +507,178 @@ func TestClientWriteAndClose(t *testing.T) { MustNil(t, err) } +func TestReadThresholdOption(t *testing.T) { + /* + client => server: 102400 bytes + 5 bytes + server cached: 102400 bytes, and throttled + server read: 102400 bytes, and unthrottled + server cached: 5 bytes + server read: 5 bytes + server write: 102400 bytes + 5 bytes + client cached: 102400 bytes, and throttled + client read: 102400 bytes, and unthrottled + client cached: 5 bytes + client read: 5 bytes + */ + readThreshold := 1024 * 100 + trigger := make(chan struct{}) + msg1 := make([]byte, readThreshold) + msg2 := []byte("hello") + var wg sync.WaitGroup + + // server + ln, err := CreateListener("tcp", ":12345") + MustNil(t, err) + wg.Add(3) + svr, _ := NewEventLoop(func(ctx context.Context, connection Connection) error { + if connection.Reader().Len() < readThreshold { + return nil + } + go func() { + defer wg.Done() + // server write + t.Logf("server writing msg1") + _, err := connection.Writer().WriteBinary(msg1) + MustNil(t, err) + err = connection.Writer().Flush() + MustNil(t, err) + <-trigger + time.Sleep(time.Millisecond * 100) + t.Logf("server writing msg2") + _, err = connection.Writer().WriteBinary(msg2) + MustNil(t, err) + err = connection.Writer().Flush() + MustNil(t, err) + }() + + // server read + defer wg.Done() + t.Logf("server reading msg1") + trigger <- struct{}{} // let client send msg2 + time.Sleep(time.Millisecond * 100) // ensure client send msg2 + Equal(t, connection.Reader().Len(), readThreshold) + msg, err := connection.Reader().Next(readThreshold) + MustNil(t, err) + Equal(t, len(msg), readThreshold) + t.Logf("server reading msg2") + msg, err = connection.Reader().Next(5) + MustNil(t, err) + Equal(t, len(msg), 5) + + _, err = connection.Reader().Next(1) + Assert(t, errors.Is(err, ErrEOF), err) + t.Logf("server closed") + return nil + }, WithReadBufferThreshold(int64(readThreshold))) + defer svr.Shutdown(context.Background()) + go func() { + svr.Serve(ln) + }() + time.Sleep(time.Millisecond * 100) + + // client write + dialer := NewDialer(WithReadBufferThreshold(int64(readThreshold))) + cli, err := dialer.DialConnection("tcp", "127.0.0.1:12345", time.Second) + MustNil(t, err) + go func() { + defer wg.Done() + t.Logf("client writing msg1") + _, err := cli.Writer().WriteBinary(msg1) + MustNil(t, err) + err = cli.Writer().Flush() + MustNil(t, err) + <-trigger + time.Sleep(time.Millisecond * 100) + t.Logf("client writing msg2") + _, err = cli.Writer().WriteBinary(msg2) + MustNil(t, err) + err = cli.Writer().Flush() + MustNil(t, err) + }() + + // client read + trigger <- struct{}{} // let server send msg2 + time.Sleep(time.Millisecond * 100) // ensure server send msg2 + Equal(t, cli.Reader().Len(), readThreshold) + t.Logf("client reading msg1") + msg, err := cli.Reader().Next(readThreshold) + MustNil(t, err) + Equal(t, len(msg), readThreshold) + t.Logf("client reading msg2") + msg, err = cli.Reader().Next(5) + MustNil(t, err) + Equal(t, len(msg), 5) + + err = cli.Close() + MustNil(t, err) + t.Logf("client closed") + wg.Wait() +} + +func TestReadThresholdClosed(t *testing.T) { + /* + client => server: 102400 bytes + 5 bytes + client => server: close connection + server cached: 102400 bytes, and throttled + server read: 102400 bytes, and unthrottled + server cached: 5 bytes + server read: 5 bytes + */ + readThreshold := 1024 * 100 + trigger := make(chan struct{}) + msg1 := make([]byte, readThreshold) + msg2 := []byte("hello") + + // server + ln, err := CreateListener("tcp", ":12345") + MustNil(t, err) + svr, _ := NewEventLoop(func(ctx context.Context, connection Connection) error { + if connection.Reader().Len() < readThreshold { + return nil + } + // server read + t.Logf("server reading msg1") + trigger <- struct{}{} // let client send msg2 + <-trigger // ensure client send msg2 and closed + for { + msg, err := connection.Reader().Next(1) + if errors.Is(err, ErrEOF) { + break + } + _ = msg + } + close(trigger) + return nil + }, WithReadBufferThreshold(int64(readThreshold))) + defer svr.Shutdown(context.Background()) + go func() { + svr.Serve(ln) + }() + time.Sleep(time.Millisecond * 100) + + // client write + dialer := NewDialer(WithReadBufferThreshold(int64(readThreshold))) + cli, err := dialer.DialConnection("tcp", "127.0.0.1:12345", time.Second) + MustNil(t, err) + t.Logf("client writing msg1") + _, err = cli.Writer().WriteBinary(msg1) + MustNil(t, err) + err = cli.Writer().Flush() + MustNil(t, err) + <-trigger + time.Sleep(time.Millisecond * 100) + t.Logf("client writing msg2") + _, err = cli.Writer().WriteBinary(msg2) + MustNil(t, err) + err = cli.Writer().Flush() + MustNil(t, err) + err = cli.Close() + MustNil(t, err) + t.Logf("client closed") + trigger <- struct{}{} + <-trigger +} + func createTestListener(network, address string) (Listener, error) { for { ln, err := CreateListener(network, address) diff --git a/nocopy.go b/nocopy.go index 80df5f9b..47ad2c6c 100644 --- a/nocopy.go +++ b/nocopy.go @@ -108,9 +108,9 @@ type Reader interface { // The usage of the design is a two-step operation, first apply for a section of memory, // fill it and then submit. E.g: // -// var buf, _ = Malloc(n) -// buf = append(buf[:0], ...) -// Flush() +// var buf, _ = Malloc(n) +// buf = append(buf[:0], ...) +// Flush() // // Note that it is not recommended to submit self-managed buffers to Writer. // Since the writer is processed asynchronously, if the self-managed buffer is used and recycled after submission, diff --git a/poll.go b/poll.go index c494ffd6..915b1f9d 100644 --- a/poll.go +++ b/poll.go @@ -48,19 +48,23 @@ type PollEvent int const ( // PollReadable is used to monitor whether the FDOperator registered by // listener and connection is readable or closed. - PollReadable PollEvent = 0x1 + PollReadable PollEvent = iota + 1 // PollWritable is used to monitor whether the FDOperator created by the dialer is writable or closed. // ET mode must be used (still need to poll hup after being writable) - PollWritable PollEvent = 0x2 + PollWritable // PollDetach is used to remove the FDOperator from poll. - PollDetach PollEvent = 0x3 + PollDetach // PollR2RW is used to monitor writable for FDOperator, // which is only called when the socket write buffer is full. - PollR2RW PollEvent = 0x5 - + PollR2RW // PollRW2R is used to remove the writable monitor of FDOperator, generally used with PollR2RW. - PollRW2R PollEvent = 0x6 + PollRW2R + + // PollRW2W is used to remove the readable monitor of FDOperator. + PollRW2W + // PollW2RW is used to add the readable monitor of FDOperator, generally used with PollRW2W. + PollW2RW ) diff --git a/poll_default_bsd.go b/poll_default_bsd.go index 9c8aa8c9..33c0b52b 100644 --- a/poll_default_bsd.go +++ b/poll_default_bsd.go @@ -197,6 +197,10 @@ func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error { evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_ADD|syscall.EV_ENABLE case PollRW2R: evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_DELETE + case PollRW2W: + evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_DELETE + case PollW2RW: + evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_ADD|syscall.EV_ENABLE } _, err := syscall.Kevent(p.fd, evs, nil, nil) return err diff --git a/poll_default_bsd_test.go b/poll_default_bsd_test.go new file mode 100644 index 00000000..92185f9e --- /dev/null +++ b/poll_default_bsd_test.go @@ -0,0 +1,83 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build darwin +// +build darwin + +package netpoll + +import ( + "syscall" + "testing" +) + +func TestKqueueEvent(t *testing.T) { + kqfd, err := syscall.Kqueue() + defer syscall.Close(kqfd) + _, err = syscall.Kevent(kqfd, []syscall.Kevent_t{{ + Ident: 0, + Filter: syscall.EVFILT_USER, + Flags: syscall.EV_ADD | syscall.EV_CLEAR, + }}, nil, nil) + MustNil(t, err) + + rfd, wfd := GetSysFdPairs() + defer syscall.Close(rfd) + defer syscall.Close(wfd) + + // add read event + changes := make([]syscall.Kevent_t, 1) + changes[0].Ident = uint64(rfd) + changes[0].Filter = syscall.EVFILT_READ + changes[0].Flags = syscall.EV_ADD + _, err = syscall.Kevent(kqfd, changes, nil, nil) + MustNil(t, err) + + // write + send := []byte("hello") + recv := make([]byte, 5) + _, err = syscall.Write(wfd, send) + MustNil(t, err) + + // check readable + events := make([]syscall.Kevent_t, 128) + n, err := syscall.Kevent(kqfd, nil, events, nil) + MustNil(t, err) + Equal(t, n, 1) + Assert(t, events[0].Filter == syscall.EVFILT_READ) + // read + _, err = syscall.Read(rfd, recv) + MustNil(t, err) + Equal(t, string(recv), string(send)) + + // delete read + changes[0].Ident = uint64(rfd) + changes[0].Filter = syscall.EVFILT_READ + changes[0].Flags = syscall.EV_DELETE + _, err = syscall.Kevent(kqfd, changes, nil, nil) + MustNil(t, err) + + // write + _, err = syscall.Write(wfd, send) + MustNil(t, err) + + // check readable + n, err = syscall.Kevent(kqfd, nil, events, &syscall.Timespec{Sec: 1}) + MustNil(t, err) + Equal(t, n, 0) + // read + _, err = syscall.Read(rfd, recv) + MustNil(t, err) + Equal(t, string(recv), string(send)) +} diff --git a/poll_default_linux.go b/poll_default_linux.go index a0087ee0..8637a577 100644 --- a/poll_default_linux.go +++ b/poll_default_linux.go @@ -255,6 +255,10 @@ func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error { op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLIN|syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR case PollRW2R: // connection wait read op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLIN|syscall.EPOLLRDHUP|syscall.EPOLLERR + case PollRW2W: + op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR + case PollW2RW: + op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLIN|syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR } return EpollCtl(p.fd, op, operator.FD, &evt) }