diff --git a/pkg/tcpip/transport/tcp/test/e2e/dual_stack_test.go b/pkg/tcpip/transport/tcp/test/e2e/dual_stack_test.go index ef50eda635..4c80181673 100644 --- a/pkg/tcpip/transport/tcp/test/e2e/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/test/e2e/dual_stack_test.go @@ -18,6 +18,7 @@ import ( "os" "strings" "testing" + "testing/synctest" "time" "github.com/google/go-cmp/cmp" @@ -34,217 +35,253 @@ import ( ) func TestV4MappedConnectOnV6Only(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateV6Endpoint(true) + c.CreateV6Endpoint(true) - // Start connection attempt, it must fail. - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrHostUnreachable{}, err); d != "" { - t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) - } + // Start connection attempt, it must fail. + err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) + if d := cmp.Diff(&tcpip.ErrHostUnreachable{}, err); d != "" { + t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) + } + }) } func TestV4MappedConnect(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateV6Endpoint(false) + c.CreateV6Endpoint(false) - // Test the connection request. - e2e.TestV4Connect(t, c) + // Test the connection request. + e2e.TestV4Connect(t, c) + }) } func TestV4ConnectWhenBoundToWildcard(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateV6Endpoint(false) + c.CreateV6Endpoint(false) - // Bind to wildcard. 
- if err := c.EP.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %v", err) - } + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{}); err != nil { + t.Fatalf("Bind failed: %v", err) + } - // Test the connection request. - e2e.TestV4Connect(t, c) + // Test the connection request. + e2e.TestV4Connect(t, c) + }) } func TestV4ConnectWhenBoundToV4MappedWildcard(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateV6Endpoint(false) + c.CreateV6Endpoint(false) - // Bind to v4 mapped wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr}); err != nil { - t.Fatalf("Bind failed: %v", err) - } + // Bind to v4 mapped wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr}); err != nil { + t.Fatalf("Bind failed: %v", err) + } - // Test the connection request. - e2e.TestV4Connect(t, c) + // Test the connection request. + e2e.TestV4Connect(t, c) + }) } func TestV4ConnectWhenBoundToV4Mapped(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateV6Endpoint(false) + c.CreateV6Endpoint(false) - // Bind to v4 mapped address. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr}); err != nil { - t.Fatalf("Bind failed: %v", err) - } + // Bind to v4 mapped address. + if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr}); err != nil { + t.Fatalf("Bind failed: %v", err) + } - // Test the connection request. - e2e.TestV4Connect(t, c) + // Test the connection request. 
+ e2e.TestV4Connect(t, c) + }) } func TestV6Connect(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateV6Endpoint(false) + c.CreateV6Endpoint(false) - // Test the connection request. - e2e.TestV6Connect(t, c) + // Test the connection request. + e2e.TestV6Connect(t, c) + }) } func TestV6ConnectV6Only(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateV6Endpoint(true) + c.CreateV6Endpoint(true) - // Test the connection request. - e2e.TestV6Connect(t, c) + // Test the connection request. + e2e.TestV6Connect(t, c) + }) } func TestV6ConnectWhenBoundToWildcard(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateV6Endpoint(false) + c.CreateV6Endpoint(false) - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %v", err) - } + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{}); err != nil { + t.Fatalf("Bind failed: %v", err) + } - // Test the connection request. - e2e.TestV6Connect(t, c) + // Test the connection request. + e2e.TestV6Connect(t, c) + }) } func TestStackV6OnlyConnectWhenBoundToWildcard(t *testing.T) { - c := context.NewWithOpts(t, context.Options{ - EnableV6: true, - MTU: e2e.DefaultMTU, - }) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.NewWithOpts(t, context.Options{ + EnableV6: true, + MTU: e2e.DefaultMTU, + }) + defer c.Cleanup() - // Create a v6 endpoint but don't set the v6-only TCP option. - c.CreateV6Endpoint(false) + // Create a v6 endpoint but don't set the v6-only TCP option. 
+ c.CreateV6Endpoint(false) - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %v", err) - } + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{}); err != nil { + t.Fatalf("Bind failed: %v", err) + } - // Test the connection request. - e2e.TestV6Connect(t, c) + // Test the connection request. + e2e.TestV6Connect(t, c) + }) } func TestV6ConnectWhenBoundToLocalAddress(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateV6Endpoint(false) + c.CreateV6Endpoint(false) - // Bind to local address. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV6Addr}); err != nil { - t.Fatalf("Bind failed: %v", err) - } + // Bind to local address. + if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV6Addr}); err != nil { + t.Fatalf("Bind failed: %v", err) + } - // Test the connection request. - e2e.TestV6Connect(t, c) + // Test the connection request. + e2e.TestV6Connect(t, c) + }) } func TestV4RefuseOnV6Only(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateV6Endpoint(true) + c.CreateV6Endpoint(true) - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %v", err) + } - // Start listening. - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } + // Start listening. + if err := c.EP.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } - // Send a SYN request. 
- irs := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) + // Send a SYN request. + irs := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) - // Receive the RST reply. - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.TCPAckNum(uint32(irs)+1), - ), - ) + // Receive the RST reply. + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, + checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.TCPAckNum(uint32(irs)+1), + ), + ) + }) } func TestV6RefuseOnBoundToV4Mapped(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateV6Endpoint(false) + c.CreateV6Endpoint(false) - // Bind and listen. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } + // Bind and listen. + if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %v", err) + } - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } + if err := c.EP.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } - // Send a SYN request. 
- irs := seqnum.Value(789) - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) + // Send a SYN request. + irs := seqnum.Value(789) + c.SendV6Packet(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) - // Receive the RST reply. - p := c.GetV6Packet() - defer p.Release() - checker.IPv6(t, p, - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.TCPAckNum(uint32(irs)+1), - ), - ) + // Receive the RST reply. + p := c.GetV6Packet() + defer p.Release() + checker.IPv6(t, p, + checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.TCPAckNum(uint32(irs)+1), + ), + ) + }) } func testV4Accept(t *testing.T, c *context.Context) { @@ -259,11 +296,11 @@ func testV4Accept(t *testing.T, c *context.Context) { // Send a SYN request. irs := seqnum.Value(789) c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, }) // Receive the SYN-ACK reply. @@ -282,12 +319,12 @@ func testV4Accept(t *testing.T, c *context.Context) { // Send ACK. c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + RcvWnd: 30000, }) // Try to accept the connection. 
@@ -333,142 +370,157 @@ func testV4Accept(t *testing.T, c *context.Context) { } func TestV4AcceptOnV6(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateV6Endpoint(false) + c.CreateV6Endpoint(false) - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %v", err) + } - // Test acceptance. - testV4Accept(t, c) + // Test acceptance. + testV4Accept(t, c) + }) } func TestV4AcceptOnBoundToV4MappedWildcard(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateV6Endpoint(false) + c.CreateV6Endpoint(false) - // Bind to v4 mapped wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } + // Bind to v4 mapped wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %v", err) + } - // Test acceptance. - testV4Accept(t, c) + // Test acceptance. + testV4Accept(t, c) + }) } func TestV4AcceptOnBoundToV4Mapped(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateV6Endpoint(false) + c.CreateV6Endpoint(false) - // Bind and listen. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } + // Bind and listen. 
+ if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr, Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %v", err) + } - // Test acceptance. - testV4Accept(t, c) + // Test acceptance. + testV4Accept(t, c) + }) } func TestV6AcceptOnV6(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateV6Endpoint(false) + c.CreateV6Endpoint(false) - // Bind and listen. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } + // Bind and listen. + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %v", err) + } - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } + if err := c.EP.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } - // Send a SYN request. - irs := seqnum.Value(789) - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) + // Send a SYN request. + irs := seqnum.Value(789) + c.SendV6Packet(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) - // Receive the SYN-ACK reply. - v := c.GetV6Packet() - defer v.Release() - tcp := header.TCP(header.IPv6(v.AsSlice()).Payload()) - iss := seqnum.Value(tcp.SequenceNumber()) - checker.IPv6(t, v, - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs)+1), - ), - ) + // Receive the SYN-ACK reply. 
+ v := c.GetV6Packet() + defer v.Release() + tcp := header.TCP(header.IPv6(v.AsSlice()).Payload()) + iss := seqnum.Value(tcp.SequenceNumber()) + checker.IPv6(t, v, + checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), + checker.TCPAckNum(uint32(irs)+1), + ), + ) - // Send ACK. - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) + // Send ACK. + c.SendV6Packet(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + RcvWnd: 30000, + }) - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) - c.WQ.EventRegister(&we) - defer c.WQ.EventUnregister(&we) - var addr tcpip.FullAddress - _, _, err := c.EP.Accept(&addr) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept(&addr) - if err != nil { - t.Fatalf("Accept failed: %v", err) + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) + c.WQ.EventRegister(&we) + defer c.WQ.EventUnregister(&we) + var addr tcpip.FullAddress + _, _, err := c.EP.Accept(&addr) + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { + // Wait for connection to be established. 
+ select { + case <-ch: + _, _, err = c.EP.Accept(&addr) + if err != nil { + t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") } - } - if addr.Addr != context.TestV6Addr { - t.Errorf("Unexpected remote address: got %s, want %s", addr.Addr, context.TestV6Addr) - } + if addr.Addr != context.TestV6Addr { + t.Errorf("Unexpected remote address: got %s, want %s", addr.Addr, context.TestV6Addr) + } + }) } func TestV4AcceptOnV4(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + // Create TCP endpoint. + var err tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %v", err) + } - // Test acceptance. - testV4Accept(t, c) + // Test acceptance. + testV4Accept(t, c) + }) } func testV4ListenClose(t *testing.T, c *context.Context) { @@ -488,11 +540,11 @@ func testV4ListenClose(t *testing.T, c *context.Context) { for i := uint16(0); i < n; i++ { // Send a SYN request. 
c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + i, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, + SrcPort: context.TestPort + i, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, }) } @@ -505,12 +557,12 @@ func testV4ListenClose(t *testing.T, c *context.Context) { iss := seqnum.Value(tcp.SequenceNumber()) // Send ACK. c.SendPacket(nil, &context.Headers{ - SrcPort: tcp.DestinationPort(), - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, + SrcPort: tcp.DestinationPort(), + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + RcvWnd: 30000, }) } @@ -537,31 +589,31 @@ func testV4ListenClose(t *testing.T, c *context.Context) { } func TestV4ListenCloseOnV4(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + // Create TCP endpoint. + var err tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %v", err) + } - // Test acceptance. - testV4ListenClose(t, c) + // Test acceptance. + testV4ListenClose(t, c) + }) } func TestMain(m *testing.M) { refs.SetLeakMode(refs.LeaksPanic) code := m.Run() - // Allow TCP async work to complete to avoid false reports of leaks. 
- // TODO(gvisor.dev/issue/5940): Use fake clock in tests. - time.Sleep(1 * time.Second) refs.DoLeakCheck() os.Exit(code) } diff --git a/pkg/tcpip/transport/tcp/test/e2e/forwarder_test.go b/pkg/tcpip/transport/tcp/test/e2e/forwarder_test.go index 5ef83c08a2..624a33ca4d 100644 --- a/pkg/tcpip/transport/tcp/test/e2e/forwarder_test.go +++ b/pkg/tcpip/transport/tcp/test/e2e/forwarder_test.go @@ -17,6 +17,7 @@ package forwarder_test import ( "os" "testing" + "testing/synctest" "time" "gvisor.dev/gvisor/pkg/atomicbitops" @@ -31,43 +32,46 @@ import ( ) func TestForwarderSendMSSLessThanMTU(t *testing.T) { - const maxPayload = 100 - const mtu = 1200 - c := context.New(t, mtu) - defer c.Cleanup() - - s := c.Stack() - ch := make(chan tcpip.Error, 1) - f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) { - var err tcpip.Error - c.EP, err = r.CreateEndpoint(&c.WQ) - ch <- err - close(ch) - r.Complete(false) - }) - s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket) + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + const maxPayload = 100 + const mtu = 1200 + c := context.New(t, mtu) + defer c.Cleanup() + + s := c.Stack() + ch := make(chan tcpip.Error, 1) + f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) { + var err tcpip.Error + c.EP, err = r.CreateEndpoint(&c.WQ) + ch <- err + close(ch) + r.Complete(false) + }) + s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket) - // Do 3-way handshake. - c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) + // Do 3-way handshake. + c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) - // Wait for connection to be available. - select { - case err := <-ch: - if err != nil { - t.Fatalf("Error creating endpoint: %s", err) + // Wait for connection to be available. 
+ select { + case err := <-ch: + if err != nil { + t.Fatalf("Error creating endpoint: %s", err) + } + case <-time.After(2 * time.Second): + t.Fatalf("Timed out waiting for connection") } - case <-time.After(2 * time.Second): - t.Fatalf("Timed out waiting for connection") - } - // Check that data gets properly segmented. - e2e.CheckBrokenUpWrite(t, c, maxPayload) + // Check that data gets properly segmented. + e2e.CheckBrokenUpWrite(t, c, maxPayload) + }) } func TestForwarderDoesNotRejectECNFlags(t *testing.T) { testCases := []struct { - name string - flags header.TCPFlags + name string + flags header.TCPFlags }{ {name: "non-setup ECN SYN w/ ECE", flags: header.TCPFlagEce}, {name: "non-setup ECN SYN w/ CWR", flags: header.TCPFlagCwr}, @@ -76,146 +80,152 @@ func TestForwarderDoesNotRejectECNFlags(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - const maxPayload = 100 - const mtu = 1200 - c := context.New(t, mtu) - defer c.Cleanup() - - s := c.Stack() - ch := make(chan tcpip.Error, 1) - f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) { - var err tcpip.Error - c.EP, err = r.CreateEndpoint(&c.WQ) - ch <- err - close(ch) - r.Complete(false) - }) - s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket) - - // Do 3-way handshake. - c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, Flags: tc.flags}) - - // Wait for connection to be available. 
- select { - case err := <-ch: - if err != nil { - t.Fatalf("Error creating endpoint: %s", err) + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + const maxPayload = 100 + const mtu = 1200 + c := context.New(t, mtu) + defer c.Cleanup() + + s := c.Stack() + ch := make(chan tcpip.Error, 1) + f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) { + var err tcpip.Error + c.EP, err = r.CreateEndpoint(&c.WQ) + ch <- err + close(ch) + r.Complete(false) + }) + s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket) + + // Do 3-way handshake. + c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, Flags: tc.flags}) + + // Wait for connection to be available. + select { + case err := <-ch: + if err != nil { + t.Fatalf("Error creating endpoint: %s", err) + } + case <-time.After(2 * time.Second): + t.Fatalf("Timed out waiting for connection") } - case <-time.After(2 * time.Second): - t.Fatalf("Timed out waiting for connection") - } + }) }) } } func TestForwarderFailedConnect(t *testing.T) { - const mtu = 1200 - c := context.New(t, mtu) - defer c.Cleanup() - - s := c.Stack() - ch := make(chan tcpip.Error, 1) - f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) { - var err tcpip.Error - c.EP, err = r.CreateEndpoint(&c.WQ) - ch <- err - close(ch) - r.Complete(false) - }) - s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket) - - // Initiate a connection that will be forwarded by the Forwarder. - // Send a SYN request. 
- iss := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + const mtu = 1200 + c := context.New(t, mtu) + defer c.Cleanup() + + s := c.Stack() + ch := make(chan tcpip.Error, 1) + f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) { + var err tcpip.Error + c.EP, err = r.CreateEndpoint(&c.WQ) + ch <- err + close(ch) + r.Complete(false) + }) + s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket) - // Receive the SYN-ACK reply. Make sure MSS and other expected options - // are present. - v := c.GetPacket() - defer v.Release() - tcp := header.TCP(header.IPv4(v.AsSlice()).Payload()) - c.IRS = seqnum.Value(tcp.SequenceNumber()) - - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.TCPAckNum(uint32(iss) + 1), - } - checker.IPv4(t, v, checker.TCP(tcpCheckers...)) - - // Now send an active RST to abort the handshake. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagRst, - SeqNum: iss + 1, - RcvWnd: 0, - }) + // Initiate a connection that will be forwarded by the Forwarder. + // Send a SYN request. + iss := seqnum.Value(context.TestInitialSequenceNumber) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) - // Wait for connect to fail. - select { - case err := <-ch: - if err == nil { - t.Fatalf("endpoint creation should have failed") + // Receive the SYN-ACK reply. Make sure MSS and other expected options + // are present. 
+ v := c.GetPacket() + defer v.Release() + tcp := header.TCP(header.IPv4(v.AsSlice()).Payload()) + c.IRS = seqnum.Value(tcp.SequenceNumber()) + + tcpCheckers := []checker.TransportChecker{ + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), + checker.TCPAckNum(uint32(iss) + 1), } - case <-time.After(2 * time.Second): - t.Fatalf("Timed out waiting for connection to fail") - } -} + checker.IPv4(t, v, checker.TCP(tcpCheckers...)) -func TestForwarderDroppedStats(t *testing.T) { - const maxPayload = 100 - const mtu = 1200 - c := context.New(t, mtu) - defer c.Cleanup() - - const maxInFlight = 2 - iters := atomicbitops.FromInt64(maxInFlight) - s := c.Stack() - checkedStats := make(chan struct{}) - done := make(chan struct{}) - f := tcp.NewForwarder(s, 65536, maxInFlight, func(r *tcp.ForwarderRequest) { - <-checkedStats - // Complete all requests without doing anything - r.Complete(false) - if iter := iters.Add(-1); iter == 0 { - close(done) + // Now send an active RST to abort the handshake. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagRst, + SeqNum: iss + 1, + RcvWnd: 0, + }) + + // Wait for connect to fail. 
+ select { + case err := <-ch: + if err == nil { + t.Fatalf("endpoint creation should have failed") + } + case <-time.After(2 * time.Second): + t.Fatalf("Timed out waiting for connection to fail") } }) - s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket) +} - for i := 0; i < maxInFlight+1; i++ { - iss := seqnum.Value(context.TestInitialSequenceNumber + i) - c.SendPacket(nil, &context.Headers{ - SrcPort: uint16(context.TestPort + i), - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, +func TestForwarderDroppedStats(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + const maxPayload = 100 + const mtu = 1200 + c := context.New(t, mtu) + defer c.Cleanup() + + const maxInFlight = 2 + iters := atomicbitops.FromInt64(maxInFlight) + s := c.Stack() + checkedStats := make(chan struct{}) + done := make(chan struct{}) + f := tcp.NewForwarder(s, 65536, maxInFlight, func(r *tcp.ForwarderRequest) { + <-checkedStats + // Complete all requests without doing anything + r.Complete(false) + if iter := iters.Add(-1); iter == 0 { + close(done) + } }) - } + s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket) + + for i := 0; i < maxInFlight+1; i++ { + iss := seqnum.Value(context.TestInitialSequenceNumber + i) + c.SendPacket(nil, &context.Headers{ + SrcPort: uint16(context.TestPort + i), + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + } - // Verify that we got one ignored packet. - if curr := s.Stats().TCP.ForwardMaxInFlightDrop.Value(); curr != 1 { - t.Errorf("Expected one dropped connection, but got %d", curr) - } - close(checkedStats) - <-done + // Verify that we got one ignored packet. 
+ if curr := s.Stats().TCP.ForwardMaxInFlightDrop.Value(); curr != 1 { + t.Errorf("Expected one dropped connection, but got %d", curr) + } + close(checkedStats) + <-done + }) } func TestMain(m *testing.M) { refs.SetLeakMode(refs.LeaksPanic) code := m.Run() - // Allow TCP async work to complete to avoid false reports of leaks. - // TODO(gvisor.dev/issue/5940): Use fake clock in tests. - time.Sleep(1 * time.Second) refs.DoLeakCheck() os.Exit(code) } diff --git a/pkg/tcpip/transport/tcp/test/e2e/sack_scoreboard_test.go b/pkg/tcpip/transport/tcp/test/e2e/sack_scoreboard_test.go index 5d931b49e8..df5bec7cd2 100644 --- a/pkg/tcpip/transport/tcp/test/e2e/sack_scoreboard_test.go +++ b/pkg/tcpip/transport/tcp/test/e2e/sack_scoreboard_test.go @@ -17,7 +17,7 @@ package sack_scoreboard_test import ( "os" "testing" - "time" + "testing/synctest" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -36,227 +36,236 @@ func initScoreboard(blocks []header.SACKBlock, iss seqnum.Value) *tcp.SACKScoreb } func TestSACKScoreboardIsSACKED(t *testing.T) { - type blockTest struct { - block header.SACKBlock - sacked bool - } - testCases := []struct { - comment string - scoreboardBlocks []header.SACKBlock - blockTests []blockTest - iss seqnum.Value - }{ - { - "Test holes and unsacked SACK blocks in SACKed ranges and insertion of overlapping SACK blocks", - []header.SACKBlock{{10, 20}, {10, 30}, {30, 40}, {41, 50}, {5, 10}, {1, 50}, {111, 120}, {101, 110}, {52, 120}}, - []blockTest{ - {header.SACKBlock{15, 21}, true}, - {header.SACKBlock{200, 201}, false}, - {header.SACKBlock{50, 51}, false}, - {header.SACKBlock{53, 120}, true}, + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + type blockTest struct { + block header.SACKBlock + sacked bool + } + testCases := []struct { + comment string + scoreboardBlocks []header.SACKBlock + blockTests []blockTest + iss seqnum.Value + }{ + { + "Test holes and unsacked SACK blocks in SACKed ranges and insertion of overlapping 
SACK blocks", + []header.SACKBlock{{10, 20}, {10, 30}, {30, 40}, {41, 50}, {5, 10}, {1, 50}, {111, 120}, {101, 110}, {52, 120}}, + []blockTest{ + {header.SACKBlock{15, 21}, true}, + {header.SACKBlock{200, 201}, false}, + {header.SACKBlock{50, 51}, false}, + {header.SACKBlock{53, 120}, true}, + }, + 0, }, - 0, - }, - { - "Test disjoint SACKBlocks", - []header.SACKBlock{{2288624809, 2288810057}, {2288811477, 2288838565}}, - []blockTest{ - {header.SACKBlock{2288624809, 2288810057}, true}, - {header.SACKBlock{2288811477, 2288838565}, true}, - {header.SACKBlock{2288810057, 2288811477}, false}, + { + "Test disjoint SACKBlocks", + []header.SACKBlock{{2288624809, 2288810057}, {2288811477, 2288838565}}, + []blockTest{ + {header.SACKBlock{2288624809, 2288810057}, true}, + {header.SACKBlock{2288811477, 2288838565}, true}, + {header.SACKBlock{2288810057, 2288811477}, false}, + }, + 2288624809, }, - 2288624809, - }, - { - "Test sequence number wrap around", - []header.SACKBlock{{4294254144, 225652}, {5340409, 5350509}}, - []blockTest{ - {header.SACKBlock{4294254144, 4294254145}, true}, - {header.SACKBlock{4294254143, 4294254144}, false}, - {header.SACKBlock{4294254144, 1}, true}, - {header.SACKBlock{225652, 5350509}, false}, - {header.SACKBlock{5340409, 5350509}, true}, - {header.SACKBlock{5350509, 5350609}, false}, + { + "Test sequence number wrap around", + []header.SACKBlock{{4294254144, 225652}, {5340409, 5350509}}, + []blockTest{ + {header.SACKBlock{4294254144, 4294254145}, true}, + {header.SACKBlock{4294254143, 4294254144}, false}, + {header.SACKBlock{4294254144, 1}, true}, + {header.SACKBlock{225652, 5350509}, false}, + {header.SACKBlock{5340409, 5350509}, true}, + {header.SACKBlock{5350509, 5350609}, false}, + }, + 4294254144, }, - 4294254144, - }, - { - "Test disjoint SACKBlocks out of order", - []header.SACKBlock{{827450276, 827454536}, {827426028, 827428868}}, - []blockTest{ - {header.SACKBlock{827426028, 827428867}, true}, - {header.SACKBlock{827450168, 827450275}, 
false}, + { + "Test disjoint SACKBlocks out of order", + []header.SACKBlock{{827450276, 827454536}, {827426028, 827428868}}, + []blockTest{ + {header.SACKBlock{827426028, 827428867}, true}, + {header.SACKBlock{827450168, 827450275}, false}, + }, + 827426000, }, - 827426000, - }, - } - for _, tc := range testCases { - sb := initScoreboard(tc.scoreboardBlocks, tc.iss) - for _, blkTest := range tc.blockTests { - if want, got := blkTest.sacked, sb.IsSACKED(blkTest.block); got != want { - t.Errorf("%s: s.IsSACKED(%v) = %v, want %v", tc.comment, blkTest.block, got, want) + } + for _, tc := range testCases { + sb := initScoreboard(tc.scoreboardBlocks, tc.iss) + for _, blkTest := range tc.blockTests { + if want, got := blkTest.sacked, sb.IsSACKED(blkTest.block); got != want { + t.Errorf("%s: s.IsSACKED(%v) = %v, want %v", tc.comment, blkTest.block, got, want) + } } } - } + }) } func TestSACKScoreboardIsRangeLost(t *testing.T) { - s := tcp.NewSACKScoreboard(10, 0) - s.Insert(header.SACKBlock{1, 25}) - s.Insert(header.SACKBlock{25, 50}) - s.Insert(header.SACKBlock{51, 100}) - s.Insert(header.SACKBlock{111, 120}) - s.Insert(header.SACKBlock{101, 110}) - s.Insert(header.SACKBlock{121, 141}) - s.Insert(header.SACKBlock{145, 146}) - s.Insert(header.SACKBlock{147, 148}) - s.Insert(header.SACKBlock{149, 150}) - s.Insert(header.SACKBlock{153, 154}) - s.Insert(header.SACKBlock{155, 156}) - testCases := []struct { - block header.SACKBlock - lost bool - }{ - // Block not covered by SACK block and has more than - // nDupAckThreshold discontiguous SACK blocks after it as well - // as (nDupAckThreshold -1) * 10 (smss) bytes that have been - // SACKED above the sequence number covered by this block. 
- {block: header.SACKBlock{0, 1}, lost: true}, + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + s := tcp.NewSACKScoreboard(10, 0) + s.Insert(header.SACKBlock{1, 25}) + s.Insert(header.SACKBlock{25, 50}) + s.Insert(header.SACKBlock{51, 100}) + s.Insert(header.SACKBlock{111, 120}) + s.Insert(header.SACKBlock{101, 110}) + s.Insert(header.SACKBlock{121, 141}) + s.Insert(header.SACKBlock{145, 146}) + s.Insert(header.SACKBlock{147, 148}) + s.Insert(header.SACKBlock{149, 150}) + s.Insert(header.SACKBlock{153, 154}) + s.Insert(header.SACKBlock{155, 156}) + testCases := []struct { + block header.SACKBlock + lost bool + }{ + // Block not covered by SACK block and has more than + // nDupAckThreshold discontiguous SACK blocks after it as well + // as (nDupAckThreshold -1) * 10 (smss) bytes that have been + // SACKED above the sequence number covered by this block. + {block: header.SACKBlock{0, 1}, lost: true}, - // These blocks have all been SACKed and should not be - // considered lost. - {block: header.SACKBlock{1, 2}, lost: false}, - {block: header.SACKBlock{25, 26}, lost: false}, - {block: header.SACKBlock{1, 45}, lost: false}, + // These blocks have all been SACKed and should not be + // considered lost. + {block: header.SACKBlock{1, 2}, lost: false}, + {block: header.SACKBlock{25, 26}, lost: false}, + {block: header.SACKBlock{1, 45}, lost: false}, - // Same as the first case above. - {block: header.SACKBlock{50, 51}, lost: true}, + // Same as the first case above. + {block: header.SACKBlock{50, 51}, lost: true}, - // This block has been SACKed and should not be considered lost. - {block: header.SACKBlock{119, 120}, lost: false}, + // This block has been SACKed and should not be considered lost. + {block: header.SACKBlock{119, 120}, lost: false}, - // This one should return true because there are > - // (nDupAckThreshold - 1) * 10 (smss) bytes that have been - // sacked above this sequence number. 
- {block: header.SACKBlock{120, 121}, lost: true}, + // This one should return true because there are > + // (nDupAckThreshold - 1) * 10 (smss) bytes that have been + // sacked above this sequence number. + {block: header.SACKBlock{120, 121}, lost: true}, - // This block has been SACKed and should not be considered lost. - {block: header.SACKBlock{125, 126}, lost: false}, + // This block has been SACKed and should not be considered lost. + {block: header.SACKBlock{125, 126}, lost: false}, - // This block has not been SACKed and there are nDupAckThreshold - // number of SACKed blocks after it. - {block: header.SACKBlock{141, 145}, lost: true}, + // This block has not been SACKed and there are nDupAckThreshold + // number of SACKed blocks after it. + {block: header.SACKBlock{141, 145}, lost: true}, - // This block has not been SACKed and there are less than - // nDupAckThreshold SACKed sequences after it. - {block: header.SACKBlock{151, 152}, lost: false}, - } - for _, tc := range testCases { - if want, got := tc.lost, s.IsRangeLost(tc.block); got != want { - t.Errorf("s.IsRangeLost(%v) = %v, want %v", tc.block, got, want) + // This block has not been SACKed and there are less than + // nDupAckThreshold SACKed sequences after it. 
+ {block: header.SACKBlock{151, 152}, lost: false}, } - } + for _, tc := range testCases { + if want, got := tc.lost, s.IsRangeLost(tc.block); got != want { + t.Errorf("s.IsRangeLost(%v) = %v, want %v", tc.block, got, want) + } + } + }) } func TestSACKScoreboardIsLost(t *testing.T) { - s := tcp.NewSACKScoreboard(10, 0) - s.Insert(header.SACKBlock{1, 25}) - s.Insert(header.SACKBlock{25, 50}) - s.Insert(header.SACKBlock{51, 100}) - s.Insert(header.SACKBlock{111, 120}) - s.Insert(header.SACKBlock{101, 110}) - s.Insert(header.SACKBlock{121, 141}) - s.Insert(header.SACKBlock{121, 141}) - s.Insert(header.SACKBlock{145, 146}) - s.Insert(header.SACKBlock{147, 148}) - s.Insert(header.SACKBlock{149, 150}) - s.Insert(header.SACKBlock{153, 154}) - s.Insert(header.SACKBlock{155, 156}) - testCases := []struct { - seq seqnum.Value - lost bool - }{ - // Sequence number not covered by SACK block and has more than - // nDupAckThreshold discontiguous SACK blocks after it as well - // as (nDupAckThreshold -1) * 10 (smss) bytes that have been - // SACKED above the sequence number. - {seq: 0, lost: true}, + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + s := tcp.NewSACKScoreboard(10, 0) + s.Insert(header.SACKBlock{1, 25}) + s.Insert(header.SACKBlock{25, 50}) + s.Insert(header.SACKBlock{51, 100}) + s.Insert(header.SACKBlock{111, 120}) + s.Insert(header.SACKBlock{101, 110}) + s.Insert(header.SACKBlock{121, 141}) + s.Insert(header.SACKBlock{121, 141}) + s.Insert(header.SACKBlock{145, 146}) + s.Insert(header.SACKBlock{147, 148}) + s.Insert(header.SACKBlock{149, 150}) + s.Insert(header.SACKBlock{153, 154}) + s.Insert(header.SACKBlock{155, 156}) + testCases := []struct { + seq seqnum.Value + lost bool + }{ + // Sequence number not covered by SACK block and has more than + // nDupAckThreshold discontiguous SACK blocks after it as well + // as (nDupAckThreshold -1) * 10 (smss) bytes that have been + // SACKED above the sequence number. 
+ {seq: 0, lost: true}, - // These sequence numbers have all been SACKed and should not be - // considered lost. - {seq: 1, lost: false}, - {seq: 25, lost: false}, - {seq: 45, lost: false}, + // These sequence numbers have all been SACKed and should not be + // considered lost. + {seq: 1, lost: false}, + {seq: 25, lost: false}, + {seq: 45, lost: false}, - // Same as first case above. - {seq: 50, lost: true}, + // Same as first case above. + {seq: 50, lost: true}, - // This block has been SACKed and should not be considered lost. - {seq: 119, lost: false}, + // This block has been SACKed and should not be considered lost. + {seq: 119, lost: false}, - // This one should return true because there are > - // (nDupAckThreshold - 1) * 10 (smss) bytes that have been - // sacked above this sequence number. - {seq: 120, lost: true}, + // This one should return true because there are > + // (nDupAckThreshold - 1) * 10 (smss) bytes that have been + // sacked above this sequence number. + {seq: 120, lost: true}, - // This sequence number has been SACKed and should not be - // considered lost. - {seq: 125, lost: false}, + // This sequence number has been SACKed and should not be + // considered lost. + {seq: 125, lost: false}, - // This sequence number has not been SACKed and there are - // nDupAckThreshold number of SACKed blocks after it. - {seq: 141, lost: true}, + // This sequence number has not been SACKed and there are + // nDupAckThreshold number of SACKed blocks after it. + {seq: 141, lost: true}, - // This sequence number has not been SACKed and there are less - // than nDupAckThreshold SACKed sequences after it. - {seq: 151, lost: false}, - } - for _, tc := range testCases { - if want, got := tc.lost, s.IsLost(tc.seq); got != want { - t.Errorf("s.IsLost(%v) = %v, want %v", tc.seq, got, want) + // This sequence number has not been SACKed and there are less + // than nDupAckThreshold SACKed sequences after it. 
+ {seq: 151, lost: false}, } - } + for _, tc := range testCases { + if want, got := tc.lost, s.IsLost(tc.seq); got != want { + t.Errorf("s.IsLost(%v) = %v, want %v", tc.seq, got, want) + } + } + }) } func TestSACKScoreboardDelete(t *testing.T) { - blocks := []header.SACKBlock{{4294254144, 225652}, {5340409, 5350509}} - s := initScoreboard(blocks, 4294254143) - s.Delete(5340408) - if s.Empty() { - t.Fatalf("s.Empty() = true, want false") - } - if got, want := s.Sacked(), blocks[1].Start.Size(blocks[1].End); got != want { - t.Fatalf("incorrect sacked bytes in scoreboard got: %v, want: %v", got, want) - } - s.Delete(5340410) - if s.Empty() { - t.Fatal("s.Empty() = true, want false") - } - newSB := header.SACKBlock{5340410, 5350509} - if !s.IsSACKED(newSB) { - t.Fatalf("s.IsSACKED(%v) = false, want true, scoreboard: %v", newSB, s) - } - s.Delete(5350509) - lastOctet := header.SACKBlock{5350508, 5350509} - if s.IsSACKED(lastOctet) { - t.Fatalf("s.IsSACKED(%v) = false, want true", lastOctet) - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + blocks := []header.SACKBlock{{4294254144, 225652}, {5340409, 5350509}} + s := initScoreboard(blocks, 4294254143) + s.Delete(5340408) + if s.Empty() { + t.Fatalf("s.Empty() = true, want false") + } + if got, want := s.Sacked(), blocks[1].Start.Size(blocks[1].End); got != want { + t.Fatalf("incorrect sacked bytes in scoreboard got: %v, want: %v", got, want) + } + s.Delete(5340410) + if s.Empty() { + t.Fatal("s.Empty() = true, want false") + } + newSB := header.SACKBlock{5340410, 5350509} + if !s.IsSACKED(newSB) { + t.Fatalf("s.IsSACKED(%v) = false, want true, scoreboard: %v", newSB, s) + } + s.Delete(5350509) + lastOctet := header.SACKBlock{5350508, 5350509} + if s.IsSACKED(lastOctet) { + t.Fatalf("s.IsSACKED(%v) = false, want true", lastOctet) + } - s.Delete(5350510) - if !s.Empty() { - t.Fatal("s.Empty() = false, want true") - } - if got, want := s.Sacked(), seqnum.Size(0); got != want { - t.Fatalf("incorrect 
sacked bytes in scoreboard got: %v, want: %v", got, want) - } + s.Delete(5350510) + if !s.Empty() { + t.Fatal("s.Empty() = false, want true") + } + if got, want := s.Sacked(), seqnum.Size(0); got != want { + t.Fatalf("incorrect sacked bytes in scoreboard got: %v, want: %v", got, want) + } + }) } func TestMain(m *testing.M) { refs.SetLeakMode(refs.LeaksPanic) code := m.Run() - // Allow TCP async work to complete to avoid false reports of leaks. - // TODO(gvisor.dev/issue/5940): Use fake clock in tests. - time.Sleep(1 * time.Second) refs.DoLeakCheck() os.Exit(code) } diff --git a/pkg/tcpip/transport/tcp/test/e2e/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/test/e2e/tcp_noracedetector_test.go index 1643cdb627..4883f132ae 100644 --- a/pkg/tcpip/transport/tcp/test/e2e/tcp_noracedetector_test.go +++ b/pkg/tcpip/transport/tcp/test/e2e/tcp_noracedetector_test.go @@ -24,6 +24,7 @@ import ( "math" "os" "testing" + "testing/synctest" "time" "gvisor.dev/gvisor/pkg/refs" @@ -36,305 +37,314 @@ import ( ) func TestFastRecovery(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const iterations = 3 - data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. 
- c.SendAck(790, bytesRead) + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + maxPayload := 32 + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + const iterations = 3 + data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) + for i := range data { + data[i] = byte(i) } - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload + // Write all the data in one shot. Packets will only be written at the + // MTU size though. + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) } - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) - } - - // Send 3 duplicate acks. This should force an immediate retransmit of - // the pending packet and put the sender into fast recovery. - rtxOffset := bytesRead - maxPayload*expected - for i := 0; i < 3; i++ { - c.SendAck(790, rtxOffset) - } + // Do slow start for a few iterations. + expected := tcp.InitialCwnd + bytesRead := 0 + for i := 0; i < iterations; i++ { + expected = tcp.InitialCwnd << uint(i) + if i > 0 { + // Acknowledge all the data received so far if not on + // first iteration. + c.SendAck(790, bytesRead) + } - // Receive the retransmitted packet. - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) + // Read all packets expected on this iteration. Don't + // acknowledge any of them just yet, so that we can measure the + // congestion window. 
+ for j := 0; j < expected; j++ { + c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) + bytesRead += maxPayload + } - // Wait before checking metrics. - metricPollFn := func() error { - if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want) + // Check we don't receive any more packets on this iteration. + // The timeout can't be too high or we'll trigger a timeout. + c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) } - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want) + + // Send 3 duplicate acks. This should force an immediate retransmit of + // the pending packet and put the sender into fast recovery. + rtxOffset := bytesRead - maxPayload*expected + for i := 0; i < 3; i++ { + c.SendAck(790, rtxOffset) } - if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.FastRecovery.Value = %d, want = %d", got, want) + // Receive the retransmitted packet. + c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) + + // Wait before checking metrics. 
+ metricPollFn := func() error { + if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want { + return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want) + } + if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { + return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want) + } + + if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want { + return fmt.Errorf("got stats.TCP.FastRecovery.Value = %d, want = %d", got, want) + } + return nil } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) + } - // Now send 7 mode duplicate acks. Each of these should cause a window - // inflation by 1 and cause the sender to send an extra packet. - for i := 0; i < 7; i++ { - c.SendAck(790, rtxOffset) - } + // Now send 7 mode duplicate acks. Each of these should cause a window + // inflation by 1 and cause the sender to send an extra packet. + for i := 0; i < 7; i++ { + c.SendAck(790, rtxOffset) + } - recover := bytesRead + recover := bytesRead - // Ensure no new packets arrive. - c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.", - 50*time.Millisecond) + // Ensure no new packets arrive. + c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.", + 50*time.Millisecond) - // Acknowledge half of the pending data. - rtxOffset = bytesRead - expected*maxPayload/2 - c.SendAck(790, rtxOffset) + // Acknowledge half of the pending data. + rtxOffset = bytesRead - expected*maxPayload/2 + c.SendAck(790, rtxOffset) - // Receive the retransmit due to partial ack. - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) + // Receive the retransmit due to partial ack. 
+ c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - // Wait before checking metrics. - metricPollFn = func() error { - if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want { - return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want) + // Wait before checking metrics. + metricPollFn = func() error { + if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want { + return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want) + } + if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want { + return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want) + } + return nil } - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want { - return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want) + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } - - // Receive the 10 extra packets that should have been released due to - // the congestion window inflation in recovery. - for i := 0; i < 10; i++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // A partial ACK during recovery should reduce congestion window by the - // number acked. Since we had "expected" packets outstanding before sending - // partial ack and we acked expected/2 , the cwnd and outstanding should - // be expected/2 + 10 (7 dupAcks + 3 for the original 3 dupacks that triggered - // fast recovery). Which means the sender should not send any more packets - // till we ack this one. - c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", - 50*time.Millisecond) - - // Acknowledge all pending data to recover point. 
- c.SendAck(790, recover) - - // At this point, the cwnd should reset to expected/2 and there are 10 - // packets outstanding. - // - // NOTE: Technically netstack is incorrect in that we adjust the cwnd on - // the same segment that takes us out of recovery. But because of that - // the actual cwnd at exit of recovery will be expected/2 + 1 as we - // acked a cwnd worth of packets which will increase the cwnd further by - // 1 in congestion avoidance. - // - // Now in the first iteration since there are 10 packets outstanding. - // We would expect to get expected/2 +1 - 10 packets. But subsequent - // iterations will send us expected/2 + 1 + 1 (per iteration). - expected = expected/2 + 1 - 10 - for i := 0; i < iterations; i++ { - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { + + // Receive the 10 extra packets that should have been released due to + // the congestion window inflation in recovery. + for i := 0; i < 10; i++ { c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) bytesRead += maxPayload } - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd.", expected), 50*time.Millisecond) + // A partial ACK during recovery should reduce congestion window by the + // number acked. Since we had "expected" packets outstanding before sending + // partial ack and we acked expected/2 , the cwnd and outstanding should + // be expected/2 + 10 (7 dupAcks + 3 for the original 3 dupacks that triggered + // fast recovery). Which means the sender should not send any more packets + // till we ack this one. 
+ c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", + 50*time.Millisecond) - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) + // Acknowledge all pending data to recover point. + c.SendAck(790, recover) + + // At this point, the cwnd should reset to expected/2 and there are 10 + // packets outstanding. + // + // NOTE: Technically netstack is incorrect in that we adjust the cwnd on + // the same segment that takes us out of recovery. But because of that + // the actual cwnd at exit of recovery will be expected/2 + 1 as we + // acked a cwnd worth of packets which will increase the cwnd further by + // 1 in congestion avoidance. + // + // Now in the first iteration since there are 10 packets outstanding. + // We would expect to get expected/2 +1 - 10 packets. But subsequent + // iterations will send us expected/2 + 1 + 1 (per iteration). + expected = expected/2 + 1 - 10 + for i := 0; i < iterations; i++ { + // Read all packets expected on this iteration. Don't + // acknowledge any of them just yet, so that we can measure the + // congestion window. + for j := 0; j < expected; j++ { + c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) + bytesRead += maxPayload + } - // In cogestion avoidance, the packets trains increase by 1 in - // each iteration. - if i == 0 { - // After the first iteration we expect to get the full - // congestion window worth of packets in every - // iteration. - expected += 10 + // Check we don't receive any more packets on this iteration. + // The timeout can't be too high or we'll trigger a timeout. + c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd.", expected), 50*time.Millisecond) + + // Acknowledge all the data received so far. + c.SendAck(790, bytesRead) + + // In cogestion avoidance, the packets trains increase by 1 in + // each iteration. 
+ if i == 0 { + // After the first iteration we expect to get the full + // congestion window worth of packets in every + // iteration. + expected += 10 + } + expected++ } - expected++ - } + }) } func TestExponentialIncreaseDuringSlowStart(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const iterations = 3 - data := make([]byte, maxPayload*(tcp.InitialCwnd<<(iterations+1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + maxPayload := 32 + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + const iterations = 3 + data := make([]byte, maxPayload*(tcp.InitialCwnd<<(iterations+1))) + for i := range data { + data[i] = byte(i) } - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) + // Write all the data in one shot. Packets will only be written at the + // MTU size though. 
+ var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) + expected := tcp.InitialCwnd + bytesRead := 0 + for i := 0; i < iterations; i++ { + // Read all packets expected on this iteration. Don't + // acknowledge any of them just yet, so that we can measure the + // congestion window. + for j := 0; j < expected; j++ { + c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) + bytesRead += maxPayload + } + + // Check we don't receive any more packets on this iteration. + // The timeout can't be too high or we'll trigger a timeout. + c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) + + // Acknowledge all the data received so far. + c.SendAck(790, bytesRead) - // Double the number of expected packets for the next iteration. - expected *= 2 - } + // Double the number of expected packets for the next iteration. + expected *= 2 + } + }) } func TestCongestionAvoidance(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const iterations = 3 - data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. 
- c.SendAck(790, bytesRead) + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + maxPayload := 32 + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + const iterations = 3 + data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) + for i := range data { + data[i] = byte(i) } - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload + // Write all the data in one shot. Packets will only be written at the + // MTU size though. + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) } - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd (slow start phase).", 50*time.Millisecond) - } - - // Don't acknowledge the first packet of the last packet train. Let's - // wait for them to time out, which will trigger a restart of slow - // start, and initialization of ssthresh to cwnd/2. - rtxOffset := bytesRead - maxPayload*expected - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - - // This part is tricky: when the timeout happened, we had "expected" - // packets pending, cwnd reset to 1, and ssthresh set to expected/2. - // By acknowledging "expected" packets, the slow-start part will - // increase cwnd to expected/2 (which "consumes" expected/2-1 of the - // acknowledgements), then the congestion avoidance part will consume - // an extra expected/2 acks to take cwnd to expected/2 + 1. 
One ack - // remains in the "ack count" (which will cause cwnd to be incremented - // once it reaches cwnd acks). - // - // So we're straight into congestion avoidance with cwnd set to - // expected/2 + 1. - // - // Check that packets trains of cwnd packets are sent, and that cwnd is - // incremented by 1 after we acknowledge each packet. - expected = expected/2 + 1 - for i := 0; i < iterations; i++ { - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload + // Do slow start for a few iterations. + expected := tcp.InitialCwnd + bytesRead := 0 + for i := 0; i < iterations; i++ { + expected = tcp.InitialCwnd << uint(i) + if i > 0 { + // Acknowledge all the data received so far if not on + // first iteration. + c.SendAck(790, bytesRead) + } + + // Read all packets expected on this iteration. Don't + // acknowledge any of them just yet, so that we can measure the + // congestion window. + for j := 0; j < expected; j++ { + c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) + bytesRead += maxPayload + } + + // Check we don't receive any more packets on this iteration. + // The timeout can't be too high or we'll trigger a timeout. + c.CheckNoPacketTimeout("More packets received than expected for this cwnd (slow start phase).", 50*time.Millisecond) } - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd (congestion avoidance phase).", 50*time.Millisecond) + // Don't acknowledge the first packet of the last packet train. Let's + // wait for them to time out, which will trigger a restart of slow + // start, and initialization of ssthresh to cwnd/2. 
+ rtxOffset := bytesRead - maxPayload*expected + c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) // Acknowledge all the data received so far. c.SendAck(790, bytesRead) - // In cogestion avoidance, the packets trains increase by 1 in - // each iteration. - expected++ - } + // This part is tricky: when the timeout happened, we had "expected" + // packets pending, cwnd reset to 1, and ssthresh set to expected/2. + // By acknowledging "expected" packets, the slow-start part will + // increase cwnd to expected/2 (which "consumes" expected/2-1 of the + // acknowledgements), then the congestion avoidance part will consume + // an extra expected/2 acks to take cwnd to expected/2 + 1. One ack + // remains in the "ack count" (which will cause cwnd to be incremented + // once it reaches cwnd acks). + // + // So we're straight into congestion avoidance with cwnd set to + // expected/2 + 1. + // + // Check that packets trains of cwnd packets are sent, and that cwnd is + // incremented by 1 after we acknowledge each packet. + expected = expected/2 + 1 + for i := 0; i < iterations; i++ { + // Read all packets expected on this iteration. Don't + // acknowledge any of them just yet, so that we can measure the + // congestion window. + for j := 0; j < expected; j++ { + c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) + bytesRead += maxPayload + } + + // Check we don't receive any more packets on this iteration. + // The timeout can't be too high or we'll trigger a timeout. + c.CheckNoPacketTimeout("More packets received than expected for this cwnd (congestion avoidance phase).", 50*time.Millisecond) + + // Acknowledge all the data received so far. + c.SendAck(790, bytesRead) + + // In cogestion avoidance, the packets trains increase by 1 in + // each iteration. 
+ expected++ + } + }) } // cubicCwnd returns an estimate of a cubic window given the @@ -351,219 +361,222 @@ func cubicCwnd(origCwnd int, wMax int, congEventTime time.Time, sRTT time.Durati } func TestCubicCongestionAvoidance(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - e2e.EnableCUBIC(t, c) - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const iterations = 3 - data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + maxPayload := 32 + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) + defer c.Cleanup() + + e2e.EnableCUBIC(t, c) + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + const iterations = 3 + data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) + for i := range data { + data[i] = byte(i) } - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload + // Write all the data in one shot. Packets will only be written at the + // MTU size though. 
+ var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) } - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd (during slow-start phase).", 50*time.Millisecond) - } - - // Don't acknowledge the first packet of the last packet train. Let's - // wait for them to time out, which will trigger a restart of slow - // start, and initialization of ssthresh to cwnd * 0.7. - rtxOffset := bytesRead - maxPayload*expected - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - // Acknowledge all pending data. - c.SendAck(790, bytesRead) - - // Store away the time we sent the ACK and assuming a 200ms RTO - // we estimate that the sender will have an RTO 200ms from now - // and go back into slow start. - packetDropTime := time.Now().Add(200 * time.Millisecond) - - // This part is tricky: when the timeout happened, we had "expected" - // packets pending, cwnd reset to 1, and ssthresh set to expected * 0.7. - // By acknowledging "expected" packets, the slow-start part will - // increase cwnd to expected/2 essentially putting the connection - // straight into congestion avoidance. - wMax := expected - // Lower expected as per cubic spec after a congestion event. - expected = int(float64(expected) * 0.7) - cwnd := expected - for i := 0; i < iterations; i++ { - // Cubic grows window independent of ACKs. Cubic Window growth - // is a function of time elapsed since last congestion event. - // As a result the congestion window does not grow - // deterministically in response to ACKs. - // - // We need to roughly estimate what the cwnd of the sender is - // based on when we sent the dupacks. - cwnd := cubicCwnd(cwnd, wMax, packetDropTime, 50*time.Millisecond) + // Do slow start for a few iterations. 
+ expected := tcp.InitialCwnd + bytesRead := 0 + for i := 0; i < iterations; i++ { + expected = tcp.InitialCwnd << uint(i) + if i > 0 { + // Acknowledge all the data received so far if not on + // first iteration. + c.SendAck(790, bytesRead) + } - packetsExpected := cwnd - for j := 0; j < packetsExpected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - t.Logf("expected packets received, next trying to receive any extra packets that may come") - - // If our estimate was correct there should be no more pending packets. - // We attempt to read a packet a few times with a short sleep in between - // to ensure that we don't see the sender send any unexpected packets. - unexpectedPackets := 0 - for { - gotPacket := c.ReceiveNonBlockingAndCheckPacket(data, bytesRead, maxPayload) - if !gotPacket { - break + // Read all packets expected on this iteration. Don't + // acknowledge any of them just yet, so that we can measure the + // congestion window. + for j := 0; j < expected; j++ { + c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) + bytesRead += maxPayload } - bytesRead += maxPayload - unexpectedPackets++ - time.Sleep(1 * time.Millisecond) - } - if unexpectedPackets != 0 { - t.Fatalf("received %d unexpected packets for iteration %d", unexpectedPackets, i) + + // Check we don't receive any more packets on this iteration. + // The timeout can't be too high or we'll trigger a timeout. + c.CheckNoPacketTimeout("More packets received than expected for this cwnd (during slow-start phase).", 50*time.Millisecond) } - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd(congestion avoidance)", 5*time.Millisecond) - // Acknowledge all the data received so far. + // Don't acknowledge the first packet of the last packet train. 
Let's + // wait for them to time out, which will trigger a restart of slow + // start, and initialization of ssthresh to cwnd * 0.7. + rtxOffset := bytesRead - maxPayload*expected + c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) + + // Acknowledge all pending data. c.SendAck(790, bytesRead) - } + + // Store away the time we sent the ACK and assuming a 200ms RTO + // we estimate that the sender will have an RTO 200ms from now + // and go back into slow start. + packetDropTime := time.Now().Add(200 * time.Millisecond) + + // This part is tricky: when the timeout happened, we had "expected" + // packets pending, cwnd reset to 1, and ssthresh set to expected * 0.7. + // By acknowledging "expected" packets, the slow-start part will + // increase cwnd to expected/2 essentially putting the connection + // straight into congestion avoidance. + wMax := expected + // Lower expected as per cubic spec after a congestion event. + expected = int(float64(expected) * 0.7) + cwnd := expected + for i := 0; i < iterations; i++ { + // Cubic grows window independent of ACKs. Cubic Window growth + // is a function of time elapsed since last congestion event. + // As a result the congestion window does not grow + // deterministically in response to ACKs. + // + // We need to roughly estimate what the cwnd of the sender is + // based on when we sent the dupacks. + cwnd := cubicCwnd(cwnd, wMax, packetDropTime, 50*time.Millisecond) + + packetsExpected := cwnd + for j := 0; j < packetsExpected; j++ { + c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) + bytesRead += maxPayload + } + t.Logf("expected packets received, next trying to receive any extra packets that may come") + + // If our estimate was correct there should be no more pending packets. + // We attempt to read a packet a few times with a short sleep in between + // to ensure that we don't see the sender send any unexpected packets. 
+ unexpectedPackets := 0 + for { + gotPacket := c.ReceiveNonBlockingAndCheckPacket(data, bytesRead, maxPayload) + if !gotPacket { + break + } + bytesRead += maxPayload + unexpectedPackets++ + time.Sleep(1 * time.Millisecond) + } + if unexpectedPackets != 0 { + t.Fatalf("received %d unexpected packets for iteration %d", unexpectedPackets, i) + } + // Check we don't receive any more packets on this iteration. + // The timeout can't be too high or we'll trigger a timeout. + c.CheckNoPacketTimeout("More packets received than expected for this cwnd(congestion avoidance)", 5*time.Millisecond) + + // Acknowledge all the data received so far. + c.SendAck(790, bytesRead) + } + }) } func TestRetransmit(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const iterations = 3 - data := make([]byte, maxPayload*(tcp.InitialCwnd<<(iterations+1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in two shots. Packets will only be written at the - // MTU size though. - var r bytes.Reader - r.Reset(data[:len(data)/2]) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - r.Reset(data[len(data)/2:]) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. 
- c.SendAck(790, bytesRead) + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + maxPayload := 32 + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + const iterations = 3 + data := make([]byte, maxPayload*(tcp.InitialCwnd<<(iterations+1))) + for i := range data { + data[i] = byte(i) } - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload + // Write all the data in two shots. Packets will only be written at the + // MTU size though. + var r bytes.Reader + r.Reset(data[:len(data)/2]) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + r.Reset(data[len(data)/2:]) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) } - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) - } + // Do slow start for a few iterations. + expected := tcp.InitialCwnd + bytesRead := 0 + for i := 0; i < iterations; i++ { + expected = tcp.InitialCwnd << uint(i) + if i > 0 { + // Acknowledge all the data received so far if not on + // first iteration. + c.SendAck(790, bytesRead) + } - // Wait for a timeout and retransmit. - rtxOffset := bytesRead - maxPayload*expected - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) + // Read all packets expected on this iteration. Don't + // acknowledge any of them just yet, so that we can measure the + // congestion window. 
+ for j := 0; j < expected; j++ { + c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) + bytesRead += maxPayload + } - metricPollFn := func() error { - if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.Timeouts.Value = %d, want = %d", got, want) + // Check we don't receive any more packets on this iteration. + // The timeout can't be too high or we'll trigger a timeout. + c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) } - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want) - } + // Wait for a timeout and retransmit. + rtxOffset := bytesRead - maxPayload*expected + c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want { - return fmt.Errorf("got EP SendErrors.Timeouts.Value = %d, want = %d", got, want) - } + metricPollFn := func() error { + if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want { + return fmt.Errorf("got stats.TCP.Timeouts.Value = %d, want = %d", got, want) + } - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want { - return fmt.Errorf("got EP stats SendErrors.Retransmits.Value = %d, want = %d", got, want) - } + if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { + return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want) + } - if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.SlowStartRetransmits.Value = %d, want = %d", got, want) - } + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want { + return fmt.Errorf("got EP SendErrors.Timeouts.Value = %d, want = %d", got, want) + } - return nil - 
} + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want { + return fmt.Errorf("got EP stats SendErrors.Retransmits.Value = %d, want = %d", got, want) + } - // Poll when checking metrics. - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } + if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want { + return fmt.Errorf("got stats.TCP.SlowStartRetransmits.Value = %d, want = %d", got, want) + } - // Acknowledge half of the pending data. - rtxOffset = bytesRead - expected*maxPayload/2 - c.SendAck(790, rtxOffset) + return nil + } - // Receive the remaining data, making sure that acknowledged data is not - // retransmitted. - for offset := rtxOffset; offset < len(data); offset += maxPayload { - c.ReceiveAndCheckPacket(data, offset, maxPayload) - c.SendAck(790, offset+maxPayload) - } + // Poll when checking metrics. + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) + } + + // Acknowledge half of the pending data. + rtxOffset = bytesRead - expected*maxPayload/2 + c.SendAck(790, rtxOffset) - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) + // Receive the remaining data, making sure that acknowledged data is not + // retransmitted. + for offset := rtxOffset; offset < len(data); offset += maxPayload { + c.ReceiveAndCheckPacket(data, offset, maxPayload) + c.SendAck(790, offset+maxPayload) + } + + c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) + }) } func TestMain(m *testing.M) { refs.SetLeakMode(refs.LeaksPanic) code := m.Run() - // Allow TCP async work to complete to avoid false reports of leaks. - // TODO(gvisor.dev/issue/5940): Use fake clock in tests. 
- time.Sleep(1 * time.Second) refs.DoLeakCheck() os.Exit(code) } diff --git a/pkg/tcpip/transport/tcp/test/e2e/tcp_rack_test.go b/pkg/tcpip/transport/tcp/test/e2e/tcp_rack_test.go index 9204d2a7f9..8b407c620c 100644 --- a/pkg/tcpip/transport/tcp/test/e2e/tcp_rack_test.go +++ b/pkg/tcpip/transport/tcp/test/e2e/tcp_rack_test.go @@ -19,6 +19,7 @@ import ( "fmt" "os" "testing" + "testing/synctest" "time" "gvisor.dev/gvisor/pkg/buffer" @@ -33,131 +34,137 @@ import ( ) const ( - maxPayload = 10 - maxTCPOptionSize = 40 - mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload + maxPayload = 10 + maxTCPOptionSize = 40 + mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload ) // TestRACKUpdate tests the RACK related fields are updated when an ACK is // received on a SACK enabled connection. func TestRACKUpdate(t *testing.T) { - var xmitTime tcpip.MonotonicTime - probeDone := make(chan struct{}) - probe := func(state *tcp.TCPEndpointState) { - // Validate that the endpoint Sender.RACKState is what we expect. - if state.Sender.RACKState.XmitTime.Before(xmitTime) { - t.Fatalf("RACK transmit time failed to update when an ACK is received") - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + var xmitTime tcpip.MonotonicTime + probeDone := make(chan struct{}) + probe := func(state *tcp.TCPEndpointState) { + // Validate that the endpoint Sender.RACKState is what we expect. 
+ if state.Sender.RACKState.XmitTime.Before(xmitTime) { + t.Fatalf("RACK transmit time failed to update when an ACK is received") + } - gotSeq := state.Sender.RACKState.EndSequence - wantSeq := state.Sender.SndNxt - if !gotSeq.LessThanEq(wantSeq) || gotSeq.LessThan(wantSeq) { - t.Fatalf("RACK sequence number failed to update, got: %v, but want: %v", gotSeq, wantSeq) - } + gotSeq := state.Sender.RACKState.EndSequence + wantSeq := state.Sender.SndNxt + if !gotSeq.LessThanEq(wantSeq) || gotSeq.LessThan(wantSeq) { + t.Fatalf("RACK sequence number failed to update, got: %v, but want: %v", gotSeq, wantSeq) + } - if state.Sender.RACKState.RTT == 0 { - t.Fatalf("RACK RTT failed to update when an ACK is received, got RACKState.RTT == 0 want != 0") + if state.Sender.RACKState.RTT == 0 { + t.Fatalf("RACK RTT failed to update when an ACK is received, got RACKState.RTT == 0 want != 0") + } + close(probeDone) } - close(probeDone) - } - c := context.NewWithProbe(t, uint32(mtu), probe) - defer c.Cleanup() + c := context.NewWithProbe(t, uint32(mtu), probe) + defer c.Cleanup() - e2e.SetStackSACKPermitted(t, c, true) - e2e.CreateConnectedWithSACKAndTS(c) + e2e.SetStackSACKPermitted(t, c, true) + e2e.CreateConnectedWithSACKAndTS(c) - data := make([]byte, maxPayload) - for i := range data { - data[i] = byte(i) - } + data := make([]byte, maxPayload) + for i := range data { + data[i] = byte(i) + } - // Write the data. - xmitTime = c.Stack().Clock().NowMonotonic() - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } + // Write the data. 
+ xmitTime = c.Stack().Clock().NowMonotonic() + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } - bytesRead := 0 - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, e2e.TSOptionSize) - bytesRead += maxPayload - c.SendAck(seqnum.Value(context.TestInitialSequenceNumber).Add(1), bytesRead) + bytesRead := 0 + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, e2e.TSOptionSize) + bytesRead += maxPayload + c.SendAck(seqnum.Value(context.TestInitialSequenceNumber).Add(1), bytesRead) - // Wait for the probe function to finish processing the ACK before the - // test completes. - <-probeDone + // Wait for the probe function to finish processing the ACK before the + // test completes. + <-probeDone + }) } // TestRACKDetectReorder tests that RACK detects packet reordering. func TestRACKDetectReorder(t *testing.T) { - t.Skipf("Skipping this test as reorder detection does not consider DSACK.") - - var n int - const ackNumToVerify = 2 - probeDone := make(chan struct{}) - probe := func(state *tcp.TCPEndpointState) { - gotSeq := state.Sender.RACKState.FACK - wantSeq := state.Sender.SndNxt - // FACK should be updated to the highest ending sequence number of the - // segment acknowledged most recently. - if !gotSeq.LessThanEq(wantSeq) || gotSeq.LessThan(wantSeq) { - t.Fatalf("RACK FACK failed to update, got: %v, but want: %v", gotSeq, wantSeq) - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + t.Skipf("Skipping this test as reorder detection does not consider DSACK.") + + var n int + const ackNumToVerify = 2 + probeDone := make(chan struct{}) + probe := func(state *tcp.TCPEndpointState) { + gotSeq := state.Sender.RACKState.FACK + wantSeq := state.Sender.SndNxt + // FACK should be updated to the highest ending sequence number of the + // segment acknowledged most recently. 
+ if !gotSeq.LessThanEq(wantSeq) || gotSeq.LessThan(wantSeq) { + t.Fatalf("RACK FACK failed to update, got: %v, but want: %v", gotSeq, wantSeq) + } - n++ - if n < ackNumToVerify { - if state.Sender.RACKState.Reord { - t.Fatalf("RACK reorder detected when there is no reordering") + n++ + if n < ackNumToVerify { + if state.Sender.RACKState.Reord { + t.Fatalf("RACK reorder detected when there is no reordering") + } + return } - return - } - if state.Sender.RACKState.Reord == false { - t.Fatalf("RACK reorder detection failed") + if state.Sender.RACKState.Reord == false { + t.Fatalf("RACK reorder detection failed") + } + close(probeDone) } - close(probeDone) - } - c := context.NewWithProbe(t, uint32(mtu), probe) - defer c.Cleanup() + c := context.NewWithProbe(t, uint32(mtu), probe) + defer c.Cleanup() - e2e.SetStackSACKPermitted(t, c, true) - e2e.CreateConnectedWithSACKAndTS(c) - data := make([]byte, ackNumToVerify*maxPayload) - for i := range data { - data[i] = byte(i) - } + e2e.SetStackSACKPermitted(t, c, true) + e2e.CreateConnectedWithSACKAndTS(c) + data := make([]byte, ackNumToVerify*maxPayload) + for i := range data { + data[i] = byte(i) + } - // Write the data. - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } + // Write the data. 
+ var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } - bytesRead := 0 - for i := 0; i < ackNumToVerify; i++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, e2e.TSOptionSize) - bytesRead += maxPayload - } + bytesRead := 0 + for i := 0; i < ackNumToVerify; i++ { + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, e2e.TSOptionSize) + bytesRead += maxPayload + } - start := c.IRS.Add(maxPayload + 1) - end := start.Add(maxPayload) - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}}) - c.SendAck(seq, bytesRead) + start := c.IRS.Add(maxPayload + 1) + end := start.Add(maxPayload) + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}}) + c.SendAck(seq, bytesRead) - // Wait for the probe function to finish processing the ACK before the - // test completes. - <-probeDone + // Wait for the probe function to finish processing the ACK before the + // test completes. + <-probeDone + }) } const ( - validDSACKDetected = 1 - failedToDetectDSACK = 2 - invalidDSACKDetected = 3 + validDSACKDetected = 1 + failedToDetectDSACK = 2 + invalidDSACKDetected = 3 ) func dsackSeenCheckerProbe(t *testing.T, numACK int, probeDone chan int) tcp.TCPProbeFunc { @@ -184,676 +191,712 @@ func dsackSeenCheckerProbe(t *testing.T, numACK int, probeDone chan int) tcp.TCP // case of a tail loss. This simulates a situation where the TLP is able to // insinuate the SACK holes and sender is able to retransmit the rest. func TestRACKTLPRecovery(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - // Send 8 packets. - numPackets := 8 - data := e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, true /* enableRACK */) - - // Packets [6-8] are lost. Send cumulative ACK for [1-5]. 
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 5 * maxPayload - c.SendAck(seq, bytesRead) - - // PTO should fire and send #8 packet as a TLP. - c.ReceiveAndCheckPacketWithOptions(data, 7*maxPayload, maxPayload, e2e.TSOptionSize) - var info tcpip.TCPInfoOption - if err := c.EP.GetSockOpt(&info); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - - // Send the SACK after RTT because RACK RFC states that if the ACK for a - // retransmission arrives before the smoothed RTT then the sender should not - // update RACK state as it could be a spurious inference. - time.Sleep(info.RTT) - - // Okay, let the sender know we got #8 using a SACK block. - eighthPStart := c.IRS.Add(1 + seqnum.Size(7*maxPayload)) - eighthPEnd := eighthPStart.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{eighthPStart, eighthPEnd}}) - - // The sender should be entering RACK based loss-recovery and sending #6 and - // #7 one after another. - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, e2e.TSOptionSize) - bytesRead += maxPayload - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, e2e.TSOptionSize) - bytesRead += 2 * maxPayload - c.SendAck(seq, bytesRead) - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - // One fast retransmit after the SACK. - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, - // Recovery should be SACK recovery. - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, - // Packets 6, 7 and 8 were retransmitted. - {tcpStats.Retransmits, "stats.TCP.Retransmits", 3}, - // TLP recovery should have been detected. - {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 1}, - // No RTOs should have occurred. 
- {tcpStats.Timeouts, "stats.TCP.Timeouts", 0}, + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + // Send 8 packets. + numPackets := 8 + data := e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, true /* enableRACK */) + + // Packets [6-8] are lost. Send cumulative ACK for [1-5]. + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + bytesRead := 5 * maxPayload + c.SendAck(seq, bytesRead) + + // PTO should fire and send #8 packet as a TLP. + c.ReceiveAndCheckPacketWithOptions(data, 7*maxPayload, maxPayload, e2e.TSOptionSize) + var info tcpip.TCPInfoOption + if err := c.EP.GetSockOpt(&info); err != nil { + t.Fatalf("GetSockOpt failed: %v", err) } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) + + // Send the SACK after RTT because RACK RFC states that if the ACK for a + // retransmission arrives before the smoothed RTT then the sender should not + // update RACK state as it could be a spurious inference. + time.Sleep(info.RTT) + + // Okay, let the sender know we got #8 using a SACK block. + eighthPStart := c.IRS.Add(1 + seqnum.Size(7*maxPayload)) + eighthPEnd := eighthPStart.Add(maxPayload) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{eighthPStart, eighthPEnd}}) + + // The sender should be entering RACK based loss-recovery and sending #6 and + // #7 one after another. + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, e2e.TSOptionSize) + bytesRead += maxPayload + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, e2e.TSOptionSize) + bytesRead += 2 * maxPayload + c.SendAck(seq, bytesRead) + + metricPollFn := func() error { + tcpStats := c.Stack().Stats().TCP + stats := []struct { + stat *tcpip.StatCounter + name string + want uint64 + }{ + // One fast retransmit after the SACK. 
+ {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, + // Recovery should be SACK recovery. + {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, + // Packets 6, 7 and 8 were retransmitted. + {tcpStats.Retransmits, "stats.TCP.Retransmits", 3}, + // TLP recovery should have been detected. + {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 1}, + // No RTOs should have occurred. + {tcpStats.Timeouts, "stats.TCP.Timeouts", 0}, } + for _, s := range stats { + if got, want := s.stat.Value(), s.want; got != want { + return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) + } + } + return nil } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) + } + }) } // TestRACKTLPFallbackRTO tests that RACK sends a tail loss probe (TLP) in the // case of a tail loss. This simulates a situation where either the TLP or its // ACK is lost. The sender should retransmit when RTO fires. func TestRACKTLPFallbackRTO(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - // Send 8 packets. - numPackets := 8 - data := e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, true /* enableRACK */) - - // Packets [6-8] are lost. Send cumulative ACK for [1-5]. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 5 * maxPayload - c.SendAck(seq, bytesRead) - - // PTO should fire and send #8 packet as a TLP. - c.ReceiveAndCheckPacketWithOptions(data, 7*maxPayload, maxPayload, e2e.TSOptionSize) - - // Either the TLP or the ACK the receiver sent with SACK blocks was lost. - - // Confirm that RTO fires and retransmits packet #6. - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, e2e.TSOptionSize) - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - // No fast retransmits happened. 
- {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 0}, - // No SACK recovery happened. - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 0}, - // TLP was unsuccessful. - {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0}, - // RTO should have fired. - {tcpStats.Timeouts, "stats.TCP.Timeouts", 1}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + // Send 8 packets. + numPackets := 8 + data := e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, true /* enableRACK */) + + // Packets [6-8] are lost. Send cumulative ACK for [1-5]. + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + bytesRead := 5 * maxPayload + c.SendAck(seq, bytesRead) + + // PTO should fire and send #8 packet as a TLP. + c.ReceiveAndCheckPacketWithOptions(data, 7*maxPayload, maxPayload, e2e.TSOptionSize) + + // Either the TLP or the ACK the receiver sent with SACK blocks was lost. + + // Confirm that RTO fires and retransmits packet #6. + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, e2e.TSOptionSize) + + metricPollFn := func() error { + tcpStats := c.Stack().Stats().TCP + stats := []struct { + stat *tcpip.StatCounter + name string + want uint64 + }{ + // No fast retransmits happened. + {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 0}, + // No SACK recovery happened. + {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 0}, + // TLP was unsuccessful. + {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0}, + // RTO should have fired. 
+ {tcpStats.Timeouts, "stats.TCP.Timeouts", 1}, } + for _, s := range stats { + if got, want := s.stat.Value(), s.want; got != want { + return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) + } + } + return nil } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) + } + }) } // TestNoTLPRecoveryOnDSACK tests the scenario where the sender speculates a // tail loss and sends a TLP. Everything is received and acked. The probe // segment is DSACKed. No fast recovery should be triggered in this case. func TestNoTLPRecoveryOnDSACK(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - // Send 8 packets. - numPackets := 8 - data := e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, true /* enableRACK */) - - // Packets [1-5] are received first. [6-8] took a detour and will take a - // while to arrive. Ack [1-5]. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 5 * maxPayload - c.SendAck(seq, bytesRead) - - // The tail loss probe (#8 packet) is received. - c.ReceiveAndCheckPacketWithOptions(data, 7*maxPayload, maxPayload, e2e.TSOptionSize) - - // Now that all 8 packets are received + duplicate 8th packet, send ack. - bytesRead += 3 * maxPayload - eighthPStart := c.IRS.Add(1 + seqnum.Size(7*maxPayload)) - eighthPEnd := eighthPStart.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{eighthPStart, eighthPEnd}}) - - // Wait for RTO and make sure that nothing else is received. 
- var info tcpip.TCPInfoOption - if err := c.EP.GetSockOpt(&info); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - var p *buffer.View - if p = c.GetPacketWithTimeout(info.RTO); p != nil { - t.Errorf("received an unexpected packet: %v", p) - p.Release() - } - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - // Make sure no recovery was entered. - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 0}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 0}, - {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0}, - // RTO should not have fired. - {tcpStats.Timeouts, "stats.TCP.Timeouts", 0}, - // Only #8 was retransmitted. - {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + // Send 8 packets. + numPackets := 8 + data := e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, true /* enableRACK */) + + // Packets [1-5] are received first. [6-8] took a detour and will take a + // while to arrive. Ack [1-5]. + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + bytesRead := 5 * maxPayload + c.SendAck(seq, bytesRead) + + // The tail loss probe (#8 packet) is received. + c.ReceiveAndCheckPacketWithOptions(data, 7*maxPayload, maxPayload, e2e.TSOptionSize) + + // Now that all 8 packets are received + duplicate 8th packet, send ack. + bytesRead += 3 * maxPayload + eighthPStart := c.IRS.Add(1 + seqnum.Size(7*maxPayload)) + eighthPEnd := eighthPStart.Add(maxPayload) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{eighthPStart, eighthPEnd}}) + + // Wait for RTO and make sure that nothing else is received. 
+ var info tcpip.TCPInfoOption + if err := c.EP.GetSockOpt(&info); err != nil { + t.Fatalf("GetSockOpt failed: %v", err) + } + var p *buffer.View + if p = c.GetPacketWithTimeout(info.RTO); p != nil { + t.Errorf("received an unexpected packet: %v", p) + p.Release() } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) + + metricPollFn := func() error { + tcpStats := c.Stack().Stats().TCP + stats := []struct { + stat *tcpip.StatCounter + name string + want uint64 + }{ + // Make sure no recovery was entered. + {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 0}, + {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 0}, + {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0}, + // RTO should not have fired. + {tcpStats.Timeouts, "stats.TCP.Timeouts", 0}, + // Only #8 was retransmitted. + {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, + } + for _, s := range stats { + if got, want := s.stat.Value(), s.want; got != want { + return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) + } } + return nil } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) + } + }) } // TestNoTLPOnSACK tests the scenario where there is not exactly a tail loss // due to the presence of multiple SACK holes. In such a scenario, TLP should // not be sent. func TestNoTLPOnSACK(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - // Send 8 packets. - numPackets := 8 - data := e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, true /* enableRACK */) - - // Packets [1-5] and #7 were received. #6 and #8 were dropped. 
-	seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
-	bytesRead := 5 * maxPayload
-	seventhStart := c.IRS.Add(1 + seqnum.Size(6*maxPayload))
-	seventhEnd := seventhStart.Add(maxPayload)
-	c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{seventhStart, seventhEnd}})
-
-	// The sender should retransmit #6. If the sender sends a TLP, then #8 will
-	// received and fail this test.
-	c.ReceiveAndCheckPacketWithOptions(data, 5*maxPayload, maxPayload, e2e.TSOptionSize)
-
-	metricPollFn := func() error {
-		tcpStats := c.Stack().Stats().TCP
-		stats := []struct {
-			stat *tcpip.StatCounter
-			name string
-			want uint64
-		}{
-			// #6 was retransmitted due to SACK recovery.
-			{tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1},
-			{tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1},
-			{tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0},
-			// RTO should not have fired.
-			{tcpStats.Timeouts, "stats.TCP.Timeouts", 0},
-			// Only #6 was retransmitted.
-			{tcpStats.Retransmits, "stats.TCP.Retransmits", 1},
-		}
-		for _, s := range stats {
-			if got, want := s.stat.Value(), s.want; got != want {
-				return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
+	synctest.Test(t, func(t *testing.T) {
+		defer synctest.Wait()
+		c := context.New(t, uint32(mtu))
+		defer c.Cleanup()
+
+		// Send 8 packets.
+		numPackets := 8
+		data := e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, true /* enableRACK */)
+
+		// Packets [1-5] and #7 were received. #6 and #8 were dropped.
+		seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+		bytesRead := 5 * maxPayload
+		seventhStart := c.IRS.Add(1 + seqnum.Size(6*maxPayload))
+		seventhEnd := seventhStart.Add(maxPayload)
+		c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{seventhStart, seventhEnd}})
+
+		// The sender should retransmit #6. If the sender sends a TLP, then #8 will
+		// be received and fail this test.
+ c.ReceiveAndCheckPacketWithOptions(data, 5*maxPayload, maxPayload, e2e.TSOptionSize) + + metricPollFn := func() error { + tcpStats := c.Stack().Stats().TCP + stats := []struct { + stat *tcpip.StatCounter + name string + want uint64 + }{ + // #6 was retransmitted due to SACK recovery. + {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, + {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, + {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0}, + // RTO should not have fired. + {tcpStats.Timeouts, "stats.TCP.Timeouts", 0}, + // Only #6 was retransmitted. + {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, + } + for _, s := range stats { + if got, want := s.stat.Value(), s.want; got != want { + return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) + } } + return nil } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) + } + }) } // TestRACKOnePacketTailLoss tests the trivial case of a tail loss of only one // packet. The probe should itself repairs the loss instead of having to go // into any recovery. func TestRACKOnePacketTailLoss(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - // Send 3 packets. - numPackets := 3 - data := e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, true /* enableRACK */) - - // Packets [1-2] are received. #3 is lost. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 2 * maxPayload - c.SendAck(seq, bytesRead) - - // PTO should fire and send #3 packet as a TLP. - c.ReceiveAndCheckPacketWithOptions(data, 2*maxPayload, maxPayload, e2e.TSOptionSize) - bytesRead += maxPayload - c.SendAck(seq, bytesRead) - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - // #3 was retransmitted as TLP. 
- {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 0}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, - {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0}, - // RTO should not have fired. - {tcpStats.Timeouts, "stats.TCP.Timeouts", 0}, - // Only #3 was retransmitted. - {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + // Send 3 packets. + numPackets := 3 + data := e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, true /* enableRACK */) + + // Packets [1-2] are received. #3 is lost. + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + bytesRead := 2 * maxPayload + c.SendAck(seq, bytesRead) + + // PTO should fire and send #3 packet as a TLP. + c.ReceiveAndCheckPacketWithOptions(data, 2*maxPayload, maxPayload, e2e.TSOptionSize) + bytesRead += maxPayload + c.SendAck(seq, bytesRead) + + metricPollFn := func() error { + tcpStats := c.Stack().Stats().TCP + stats := []struct { + stat *tcpip.StatCounter + name string + want uint64 + }{ + // #3 was retransmitted as TLP. + {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 0}, + {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, + {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0}, + // RTO should not have fired. + {tcpStats.Timeouts, "stats.TCP.Timeouts", 0}, + // Only #3 was retransmitted. 
+ {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, + } + for _, s := range stats { + if got, want := s.stat.Value(), s.want; got != want { + return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) + } } + return nil } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) + } + }) } // TestRACKDetectDSACK tests that RACK detects DSACK with duplicate segments. // See: https://tools.ietf.org/html/rfc2883#section-4.1.1. func TestRACKDetectDSACK(t *testing.T) { - probeDone := make(chan int) - const ackNumToVerify = 2 - probe := dsackSeenCheckerProbe(t, ackNumToVerify, probeDone) - - c := context.NewWithProbe(t, uint32(mtu), probe) - defer c.Cleanup() - - numPackets := 8 - data := e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, true /* enableRACK */) - - // Cumulative ACK for [1-5] packets and SACK #8 packet (to prevent TLP). - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 5 * maxPayload - eighthPStart := c.IRS.Add(1 + seqnum.Size(7*maxPayload)) - eighthPEnd := eighthPStart.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{eighthPStart, eighthPEnd}}) - - // Expect retransmission of #6 packet after RTO expires. - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, e2e.TSOptionSize) - - // Send DSACK block for #6 packet indicating both - // initial and retransmitted packet are received and - // packets [1-8] are received. - start := c.IRS.Add(1 + seqnum.Size(bytesRead)) - end := start.Add(maxPayload) - bytesRead += 3 * maxPayload - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Wait for the probe function to finish processing the - // ACK before the test completes. 
- err := <-probeDone - switch err { - case failedToDetectDSACK: - t.Fatalf("RACK DSACK detection failed") - case invalidDSACKDetected: - t.Fatalf("RACK DSACK detected when there is no duplicate SACK") - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + probeDone := make(chan int) + const ackNumToVerify = 2 + probe := dsackSeenCheckerProbe(t, ackNumToVerify, probeDone) + + c := context.NewWithProbe(t, uint32(mtu), probe) + defer c.Cleanup() + + numPackets := 8 + data := e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, true /* enableRACK */) + + // Cumulative ACK for [1-5] packets and SACK #8 packet (to prevent TLP). + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + bytesRead := 5 * maxPayload + eighthPStart := c.IRS.Add(1 + seqnum.Size(7*maxPayload)) + eighthPEnd := eighthPStart.Add(maxPayload) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{eighthPStart, eighthPEnd}}) + + // Expect retransmission of #6 packet after RTO expires. + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, e2e.TSOptionSize) - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - // Check DSACK was received for one segment. - {tcpStats.SegmentsAckedWithDSACK, "stats.TCP.SegmentsAckedWithDSACK", 1}, + // Send DSACK block for #6 packet indicating both + // initial and retransmitted packet are received and + // packets [1-8] are received. + start := c.IRS.Add(1 + seqnum.Size(bytesRead)) + end := start.Add(maxPayload) + bytesRead += 3 * maxPayload + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Wait for the probe function to finish processing the + // ACK before the test completes. 
+ err := <-probeDone + switch err { + case failedToDetectDSACK: + t.Fatalf("RACK DSACK detection failed") + case invalidDSACKDetected: + t.Fatalf("RACK DSACK detected when there is no duplicate SACK") } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) + + metricPollFn := func() error { + tcpStats := c.Stack().Stats().TCP + stats := []struct { + stat *tcpip.StatCounter + name string + want uint64 + }{ + // Check DSACK was received for one segment. + {tcpStats.SegmentsAckedWithDSACK, "stats.TCP.SegmentsAckedWithDSACK", 1}, } + for _, s := range stats { + if got, want := s.stat.Value(), s.want; got != want { + return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) + } + } + return nil } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) + } + }) } // TestRACKDetectDSACKWithOutOfOrder tests that RACK detects DSACK with out of // order segments. // See: https://tools.ietf.org/html/rfc2883#section-4.1.2. func TestRACKDetectDSACKWithOutOfOrder(t *testing.T) { - probeDone := make(chan int) - const ackNumToVerify = 2 - probe := dsackSeenCheckerProbe(t, ackNumToVerify, probeDone) - - c := context.NewWithProbe(t, uint32(mtu), probe) - defer c.Cleanup() - - numPackets := 10 - data := e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, true /* enableRACK */) - - // Cumulative ACK for [1-5] packets and SACK for #7 packet (to prevent TLP). - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 5 * maxPayload - seventhPStart := c.IRS.Add(1 + seqnum.Size(6*maxPayload)) - seventhPEnd := seventhPStart.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{seventhPStart, seventhPEnd}}) - - // Expect retransmission of #6 packet. 
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, e2e.TSOptionSize) - - // Send DSACK block for #6 packet indicating both - // initial and retransmitted packet are received and - // packets [1-7] are received. - start := c.IRS.Add(1 + seqnum.Size(bytesRead)) - end := start.Add(maxPayload) - bytesRead += 2 * maxPayload - // Send DSACK block for #6 along with SACK for out of - // order #9 packet. - start1 := c.IRS.Add(1 + seqnum.Size(bytesRead) + maxPayload) - end1 := start1.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}, {start1, end1}}) - - // Wait for the probe function to finish processing the - // ACK before the test completes. - err := <-probeDone - switch err { - case failedToDetectDSACK: - t.Fatalf("RACK DSACK detection failed") - case invalidDSACKDetected: - t.Fatalf("RACK DSACK detected when there is no duplicate SACK") - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + probeDone := make(chan int) + const ackNumToVerify = 2 + probe := dsackSeenCheckerProbe(t, ackNumToVerify, probeDone) + + c := context.NewWithProbe(t, uint32(mtu), probe) + defer c.Cleanup() + + numPackets := 10 + data := e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, true /* enableRACK */) + + // Cumulative ACK for [1-5] packets and SACK for #7 packet (to prevent TLP). + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + bytesRead := 5 * maxPayload + seventhPStart := c.IRS.Add(1 + seqnum.Size(6*maxPayload)) + seventhPEnd := seventhPStart.Add(maxPayload) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{seventhPStart, seventhPEnd}}) + + // Expect retransmission of #6 packet. + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, e2e.TSOptionSize) + + // Send DSACK block for #6 packet indicating both + // initial and retransmitted packet are received and + // packets [1-7] are received. 
+ start := c.IRS.Add(1 + seqnum.Size(bytesRead)) + end := start.Add(maxPayload) + bytesRead += 2 * maxPayload + // Send DSACK block for #6 along with SACK for out of + // order #9 packet. + start1 := c.IRS.Add(1 + seqnum.Size(bytesRead) + maxPayload) + end1 := start1.Add(maxPayload) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}, {start1, end1}}) + + // Wait for the probe function to finish processing the + // ACK before the test completes. + err := <-probeDone + switch err { + case failedToDetectDSACK: + t.Fatalf("RACK DSACK detection failed") + case invalidDSACKDetected: + t.Fatalf("RACK DSACK detected when there is no duplicate SACK") + } + }) } // TestRACKDetectDSACKWithOutOfOrderDup tests that DSACK is detected on a // duplicate of out of order packet. // See: https://tools.ietf.org/html/rfc2883#section-4.1.3 func TestRACKDetectDSACKWithOutOfOrderDup(t *testing.T) { - probeDone := make(chan int) - const ackNumToVerify = 4 - probe := dsackSeenCheckerProbe(t, ackNumToVerify, probeDone) - - c := context.NewWithProbe(t, uint32(mtu), probe) - defer c.Cleanup() - - numPackets := 10 - e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, true /* enableRACK */) - - // ACK [1-5] packets. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 5 * maxPayload - c.SendAck(seq, bytesRead) - - // Send SACK indicating #6 packet is missing and received #7 packet. - offset := seqnum.Size(bytesRead + maxPayload) - start := c.IRS.Add(1 + offset) - end := start.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Send SACK with #6 packet is missing and received [7-8] packets. - end = start.Add(2 * maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Consider #8 packet is duplicated on the network and send DSACK. 
- dsackStart := c.IRS.Add(1 + offset + maxPayload) - dsackEnd := dsackStart.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}}) - - // Wait for the probe function to finish processing the ACK before the - // test completes. - err := <-probeDone - switch err { - case failedToDetectDSACK: - t.Fatalf("RACK DSACK detection failed") - case invalidDSACKDetected: - t.Fatalf("RACK DSACK detected when there is no duplicate SACK") - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + probeDone := make(chan int) + const ackNumToVerify = 4 + probe := dsackSeenCheckerProbe(t, ackNumToVerify, probeDone) + + c := context.NewWithProbe(t, uint32(mtu), probe) + defer c.Cleanup() + + numPackets := 10 + e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, true /* enableRACK */) + + // ACK [1-5] packets. + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + bytesRead := 5 * maxPayload + c.SendAck(seq, bytesRead) + + // Send SACK indicating #6 packet is missing and received #7 packet. + offset := seqnum.Size(bytesRead + maxPayload) + start := c.IRS.Add(1 + offset) + end := start.Add(maxPayload) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Send SACK with #6 packet is missing and received [7-8] packets. + end = start.Add(2 * maxPayload) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Consider #8 packet is duplicated on the network and send DSACK. + dsackStart := c.IRS.Add(1 + offset + maxPayload) + dsackEnd := dsackStart.Add(maxPayload) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}}) + + // Wait for the probe function to finish processing the ACK before the + // test completes. 
+ err := <-probeDone + switch err { + case failedToDetectDSACK: + t.Fatalf("RACK DSACK detection failed") + case invalidDSACKDetected: + t.Fatalf("RACK DSACK detected when there is no duplicate SACK") + } + }) } // TestRACKDetectDSACKSingleDup tests DSACK for a single duplicate subsegment. // See: https://tools.ietf.org/html/rfc2883#section-4.2.1. func TestRACKDetectDSACKSingleDup(t *testing.T) { - probeDone := make(chan int) - const ackNumToVerify = 4 - probe := dsackSeenCheckerProbe(t, ackNumToVerify, probeDone) - - c := context.NewWithProbe(t, uint32(mtu), probe) - defer c.Cleanup() - - numPackets := 4 - data := e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, true /* enableRACK */) - - // Send ACK for #1 packet. - bytesRead := maxPayload - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendAck(seq, bytesRead) - - // Missing [2-3] packets and received #4 packet. - seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) - start := c.IRS.Add(1 + seqnum.Size(3*maxPayload)) - end := start.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Expect retransmission of #2 packet. - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, e2e.TSOptionSize) - - // ACK for retransmitted #2 packet. - bytesRead += maxPayload - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Simulate receiving delayed subsegment of #2 packet and delayed #3 packet by - // sending DSACK block for the subsegment. - dsackStart := c.IRS.Add(1 + seqnum.Size(bytesRead)) - dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2)) - c.SendAckWithSACK(seq, numPackets*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}}) - - // Wait for the probe function to finish processing the ACK before the - // test completes. 
- err := <-probeDone - switch err { - case failedToDetectDSACK: - t.Fatalf("RACK DSACK detection failed") - case invalidDSACKDetected: - t.Fatalf("RACK DSACK detected when there is no duplicate SACK") - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + probeDone := make(chan int) + const ackNumToVerify = 4 + probe := dsackSeenCheckerProbe(t, ackNumToVerify, probeDone) + + c := context.NewWithProbe(t, uint32(mtu), probe) + defer c.Cleanup() + + numPackets := 4 + data := e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, true /* enableRACK */) + + // Send ACK for #1 packet. + bytesRead := maxPayload + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendAck(seq, bytesRead) + + // Missing [2-3] packets and received #4 packet. + seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) + start := c.IRS.Add(1 + seqnum.Size(3*maxPayload)) + end := start.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Expect retransmission of #2 packet. + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, e2e.TSOptionSize) - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - // Check DSACK was received for a subsegment. - {tcpStats.SegmentsAckedWithDSACK, "stats.TCP.SegmentsAckedWithDSACK", 1}, + // ACK for retransmitted #2 packet. + bytesRead += maxPayload + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Simulate receiving delayed subsegment of #2 packet and delayed #3 packet by + // sending DSACK block for the subsegment. + dsackStart := c.IRS.Add(1 + seqnum.Size(bytesRead)) + dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2)) + c.SendAckWithSACK(seq, numPackets*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}}) + + // Wait for the probe function to finish processing the ACK before the + // test completes. 
+ err := <-probeDone + switch err { + case failedToDetectDSACK: + t.Fatalf("RACK DSACK detection failed") + case invalidDSACKDetected: + t.Fatalf("RACK DSACK detected when there is no duplicate SACK") } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) + + metricPollFn := func() error { + tcpStats := c.Stack().Stats().TCP + stats := []struct { + stat *tcpip.StatCounter + name string + want uint64 + }{ + // Check DSACK was received for a subsegment. + {tcpStats.SegmentsAckedWithDSACK, "stats.TCP.SegmentsAckedWithDSACK", 1}, } + for _, s := range stats { + if got, want := s.stat.Value(), s.want; got != want { + return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) + } + } + return nil } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) + } + }) } // TestRACKDetectDSACKDupWithCumulativeACK tests DSACK for two non-contiguous // duplicate subsegments covered by the cumulative acknowledgement. // See: https://tools.ietf.org/html/rfc2883#section-4.2.2. func TestRACKDetectDSACKDupWithCumulativeACK(t *testing.T) { - probeDone := make(chan int) - const ackNumToVerify = 5 - probe := dsackSeenCheckerProbe(t, ackNumToVerify, probeDone) - - c := context.NewWithProbe(t, uint32(mtu), probe) - defer c.Cleanup() - - numPackets := 6 - data := e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, true /* enableRACK */) - - // Send ACK for #1 packet. - bytesRead := maxPayload - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendAck(seq, bytesRead) - - // Missing [2-5] packets and received #6 packet. 
- seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) - start := c.IRS.Add(1 + seqnum.Size(5*maxPayload)) - end := start.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Expect retransmission of #2 packet. - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, e2e.TSOptionSize) - - // Received delayed #2 packet. - bytesRead += maxPayload - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Received delayed #4 packet. - start1 := c.IRS.Add(1 + seqnum.Size(3*maxPayload)) - end1 := start1.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}}) - - // Simulate receiving retransmitted subsegment for #2 packet and delayed #3 - // packet by sending DSACK block for #2 packet. - dsackStart := c.IRS.Add(1 + seqnum.Size(maxPayload)) - dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2)) - c.SendAckWithSACK(seq, 4*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}}) - - // Wait for the probe function to finish processing the ACK before the - // test completes. - err := <-probeDone - switch err { - case failedToDetectDSACK: - t.Fatalf("RACK DSACK detection failed") - case invalidDSACKDetected: - t.Fatalf("RACK DSACK detected when there is no duplicate SACK") - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + probeDone := make(chan int) + const ackNumToVerify = 5 + probe := dsackSeenCheckerProbe(t, ackNumToVerify, probeDone) + + c := context.NewWithProbe(t, uint32(mtu), probe) + defer c.Cleanup() + + numPackets := 6 + data := e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, true /* enableRACK */) + + // Send ACK for #1 packet. + bytesRead := maxPayload + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendAck(seq, bytesRead) + + // Missing [2-5] packets and received #6 packet. 
+ seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) + start := c.IRS.Add(1 + seqnum.Size(5*maxPayload)) + end := start.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Expect retransmission of #2 packet. + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, e2e.TSOptionSize) + + // Received delayed #2 packet. + bytesRead += maxPayload + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Received delayed #4 packet. + start1 := c.IRS.Add(1 + seqnum.Size(3*maxPayload)) + end1 := start1.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}}) + + // Simulate receiving retransmitted subsegment for #2 packet and delayed #3 + // packet by sending DSACK block for #2 packet. + dsackStart := c.IRS.Add(1 + seqnum.Size(maxPayload)) + dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2)) + c.SendAckWithSACK(seq, 4*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}}) + + // Wait for the probe function to finish processing the ACK before the + // test completes. + err := <-probeDone + switch err { + case failedToDetectDSACK: + t.Fatalf("RACK DSACK detection failed") + case invalidDSACKDetected: + t.Fatalf("RACK DSACK detected when there is no duplicate SACK") + } + }) } // TestRACKDetectDSACKDup tests two non-contiguous duplicate subsegments not // covered by the cumulative acknowledgement. // See: https://tools.ietf.org/html/rfc2883#section-4.2.3. func TestRACKDetectDSACKDup(t *testing.T) { - probeDone := make(chan int) - const ackNumToVerify = 5 - probe := dsackSeenCheckerProbe(t, ackNumToVerify, probeDone) - - c := context.NewWithProbe(t, uint32(mtu), probe) - defer c.Cleanup() - - numPackets := 7 - data := e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, true /* enableRACK */) - - // Send ACK for #1 packet. 
- bytesRead := maxPayload - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendAck(seq, bytesRead) - - // Missing [2-6] packets and SACK #7 packet. - seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) - start := c.IRS.Add(1 + seqnum.Size(6*maxPayload)) - end := start.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Received delayed #3 packet. - start1 := c.IRS.Add(1 + seqnum.Size(2*maxPayload)) - end1 := start1.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}}) - - // Expect retransmission of #2 packet. - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, e2e.TSOptionSize) - - // Consider #2 packet has been dropped and SACK #4 packet. - start2 := c.IRS.Add(1 + seqnum.Size(3*maxPayload)) - end2 := start2.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start2, end2}, {start1, end1}, {start, end}}) - - // Simulate receiving retransmitted subsegment for #3 packet and delayed #5 - // packet by sending DSACK block for the subsegment. - dsackStart := c.IRS.Add(1 + seqnum.Size(2*maxPayload)) - dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2)) - end1 = end1.Add(seqnum.Size(2 * maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{dsackStart, dsackEnd}, {start1, end1}}) - - // Wait for the probe function to finish processing the ACK before the - // test completes. 
- err := <-probeDone - switch err { - case failedToDetectDSACK: - t.Fatalf("RACK DSACK detection failed") - case invalidDSACKDetected: - t.Fatalf("RACK DSACK detected when there is no duplicate SACK") - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + probeDone := make(chan int) + const ackNumToVerify = 5 + probe := dsackSeenCheckerProbe(t, ackNumToVerify, probeDone) + + c := context.NewWithProbe(t, uint32(mtu), probe) + defer c.Cleanup() + + numPackets := 7 + data := e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, true /* enableRACK */) + + // Send ACK for #1 packet. + bytesRead := maxPayload + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendAck(seq, bytesRead) + + // Missing [2-6] packets and SACK #7 packet. + seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) + start := c.IRS.Add(1 + seqnum.Size(6*maxPayload)) + end := start.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Received delayed #3 packet. + start1 := c.IRS.Add(1 + seqnum.Size(2*maxPayload)) + end1 := start1.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}}) + + // Expect retransmission of #2 packet. + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, e2e.TSOptionSize) + + // Consider #2 packet has been dropped and SACK #4 packet. + start2 := c.IRS.Add(1 + seqnum.Size(3*maxPayload)) + end2 := start2.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start2, end2}, {start1, end1}, {start, end}}) + + // Simulate receiving retransmitted subsegment for #3 packet and delayed #5 + // packet by sending DSACK block for the subsegment. 
+ dsackStart := c.IRS.Add(1 + seqnum.Size(2*maxPayload)) + dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2)) + end1 = end1.Add(seqnum.Size(2 * maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{dsackStart, dsackEnd}, {start1, end1}}) + + // Wait for the probe function to finish processing the ACK before the + // test completes. + err := <-probeDone + switch err { + case failedToDetectDSACK: + t.Fatalf("RACK DSACK detection failed") + case invalidDSACKDetected: + t.Fatalf("RACK DSACK detected when there is no duplicate SACK") + } + }) } // TestRACKWithInvalidDSACKBlock tests that DSACK is not detected when DSACK // is not the first SACK block. func TestRACKWithInvalidDSACKBlock(t *testing.T) { - probeDone := make(chan struct{}) - const ackNumToVerify = 2 - var n int - probe := func(state *tcp.TCPEndpointState) { - // Validate that RACK does not detect DSACK when DSACK block is - // not the first SACK block. - n++ - t.Helper() - if state.Sender.RACKState.DSACKSeen { - t.Fatalf("RACK DSACK detected when there is no duplicate SACK") - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + probeDone := make(chan struct{}) + const ackNumToVerify = 2 + var n int + probe := func(state *tcp.TCPEndpointState) { + // Validate that RACK does not detect DSACK when DSACK block is + // not the first SACK block. + n++ + t.Helper() + if state.Sender.RACKState.DSACKSeen { + t.Fatalf("RACK DSACK detected when there is no duplicate SACK") + } - if n == ackNumToVerify { - close(probeDone) + if n == ackNumToVerify { + close(probeDone) + } } - } - c := context.NewWithProbe(t, uint32(mtu), probe) - defer c.Cleanup() - - numPackets := 10 - data := e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, true /* enableRACK */) - - // Cumulative ACK for [1-5] packets and SACK for #7 packet (to prevent TLP). 
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 5 * maxPayload - seventhPStart := c.IRS.Add(1 + seqnum.Size(6*maxPayload)) - seventhPEnd := seventhPStart.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{seventhPStart, seventhPEnd}}) - - // Expect retransmission of #6 packet. - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, e2e.TSOptionSize) - - // Send DSACK block for #6 packet indicating both - // initial and retransmitted packet are received and - // packets [1-7] are received. - start := c.IRS.Add(1 + seqnum.Size(bytesRead)) - end := start.Add(maxPayload) - bytesRead += 2 * maxPayload - - // Send DSACK block as second block. The first block is a SACK for #9 packet. - start1 := c.IRS.Add(1 + seqnum.Size(bytesRead) + maxPayload) - end1 := start1.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}}) - - // Wait for the probe function to finish processing the - // ACK before the test completes. - <-probeDone + c := context.NewWithProbe(t, uint32(mtu), probe) + defer c.Cleanup() + + numPackets := 10 + data := e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, true /* enableRACK */) + + // Cumulative ACK for [1-5] packets and SACK for #7 packet (to prevent TLP). + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + bytesRead := 5 * maxPayload + seventhPStart := c.IRS.Add(1 + seqnum.Size(6*maxPayload)) + seventhPEnd := seventhPStart.Add(maxPayload) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{seventhPStart, seventhPEnd}}) + + // Expect retransmission of #6 packet. + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, e2e.TSOptionSize) + + // Send DSACK block for #6 packet indicating both + // initial and retransmitted packet are received and + // packets [1-7] are received. 
+ start := c.IRS.Add(1 + seqnum.Size(bytesRead)) + end := start.Add(maxPayload) + bytesRead += 2 * maxPayload + + // Send DSACK block as second block. The first block is a SACK for #9 packet. + start1 := c.IRS.Add(1 + seqnum.Size(bytesRead) + maxPayload) + end1 := start1.Add(maxPayload) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}}) + + // Wait for the probe function to finish processing the + // ACK before the test completes. + <-probeDone + }) } func reorderWindowCheckerProbe(numACK int, probeDone chan error) tcp.TCPProbeFunc { @@ -884,198 +927,207 @@ func reorderWindowCheckerProbe(numACK int, probeDone chan error) tcp.TCPProbeFun } func TestRACKCheckReorderWindow(t *testing.T) { - probeDone := make(chan error) - const ackNumToVerify = 3 - probe := reorderWindowCheckerProbe(ackNumToVerify, probeDone) - - c := context.NewWithProbe(t, uint32(mtu), probe) - defer c.Cleanup() - - const numPackets = 7 - e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, true /* enableRACK */) - - // Send ACK for #1 packet. - bytesRead := maxPayload - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendAck(seq, bytesRead) - - // Missing [2-6] packets and SACK #7 packet. - start := c.IRS.Add(1 + seqnum.Size(6*maxPayload)) - end := start.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Received delayed packets [2-6] which indicates there is reordering - // in the connection. - bytesRead += 6 * maxPayload - c.SendAck(seq, bytesRead) - - // Wait for the probe function to finish processing the ACK before the - // test completes. 
- if err := <-probeDone; err != nil { - t.Fatalf("unexpected values for RACK variables: %v", err) - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + probeDone := make(chan error) + const ackNumToVerify = 3 + probe := reorderWindowCheckerProbe(ackNumToVerify, probeDone) + + c := context.NewWithProbe(t, uint32(mtu), probe) + defer c.Cleanup() + + const numPackets = 7 + e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, true /* enableRACK */) + + // Send ACK for #1 packet. + bytesRead := maxPayload + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendAck(seq, bytesRead) + + // Missing [2-6] packets and SACK #7 packet. + start := c.IRS.Add(1 + seqnum.Size(6*maxPayload)) + end := start.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Received delayed packets [2-6] which indicates there is reordering + // in the connection. + bytesRead += 6 * maxPayload + c.SendAck(seq, bytesRead) + + // Wait for the probe function to finish processing the ACK before the + // test completes. + if err := <-probeDone; err != nil { + t.Fatalf("unexpected values for RACK variables: %v", err) + } + }) } func TestRACKWithDuplicateACK(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - const numPackets = 4 - data := e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, true /* enableRACK */) - - // Send three duplicate ACKs to trigger fast recovery. The first - // segment is considered as lost and will be retransmitted after - // receiving the duplicate ACKs. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - start := c.IRS.Add(1 + seqnum.Size(maxPayload)) - end := start.Add(seqnum.Size(maxPayload)) - for i := 0; i < 3; i++ { - c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}}) - end = end.Add(seqnum.Size(maxPayload)) - } - - // Receive the retransmitted packet. 
- c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, e2e.TSOptionSize) - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, - {tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0}, + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + const numPackets = 4 + data := e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, true /* enableRACK */) + + // Send three duplicate ACKs to trigger fast recovery. The first + // segment is considered as lost and will be retransmitted after + // receiving the duplicate ACKs. + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + start := c.IRS.Add(1 + seqnum.Size(maxPayload)) + end := start.Add(seqnum.Size(maxPayload)) + for i := 0; i < 3; i++ { + c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}}) + end = end.Add(seqnum.Size(maxPayload)) } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) + + // Receive the retransmitted packet. 
+ c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, e2e.TSOptionSize) + + metricPollFn := func() error { + tcpStats := c.Stack().Stats().TCP + stats := []struct { + stat *tcpip.StatCounter + name string + want uint64 + }{ + {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, + {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, + {tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0}, } + for _, s := range stats { + if got, want := s.stat.Value(), s.want; got != want { + return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) + } + } + return nil } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) + } + }) } // TestRACKUpdateSackedOut tests the sacked out field is updated when a SACK // is received. func TestRACKUpdateSackedOut(t *testing.T) { - probeDone := make(chan struct{}) - ackNum := 0 - probe := func(state *tcp.TCPEndpointState) { - // Validate that the endpoint Sender.SackedOut is what we expect. - if state.Sender.SackedOut != 2 && ackNum == 0 { - t.Fatalf("SackedOut got updated to wrong value got: %v want: 2", state.Sender.SackedOut) - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + probeDone := make(chan struct{}) + ackNum := 0 + probe := func(state *tcp.TCPEndpointState) { + // Validate that the endpoint Sender.SackedOut is what we expect. 
+ if state.Sender.SackedOut != 2 && ackNum == 0 { + t.Fatalf("SackedOut got updated to wrong value got: %v want: 2", state.Sender.SackedOut) + } - if !state.Sender.FastRecovery.Active && state.Sender.SackedOut != 0 && ackNum == 1 { - t.Fatalf("SackedOut got updated to wrong value got: %v want: 0", state.Sender.SackedOut) - } + if !state.Sender.FastRecovery.Active && state.Sender.SackedOut != 0 && ackNum == 1 { + t.Fatalf("SackedOut got updated to wrong value got: %v want: 0", state.Sender.SackedOut) + } - if ackNum > 0 { - close(probeDone) + if ackNum > 0 { + close(probeDone) + } + ackNum++ } - ackNum++ - } - c := context.NewWithProbe(t, uint32(mtu), probe) - defer c.Cleanup() + c := context.NewWithProbe(t, uint32(mtu), probe) + defer c.Cleanup() - e2e.SendAndReceiveWithSACK(t, c, maxPayload, 8 /* numPackets */, true /* enableRACK */) + e2e.SendAndReceiveWithSACK(t, c, maxPayload, 8 /* numPackets */, true /* enableRACK */) - // ACK for [3-5] packets. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - start := c.IRS.Add(seqnum.Size(1 + 3*maxPayload)) - bytesRead := 2 * maxPayload - end := start.Add(seqnum.Size(bytesRead)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + // ACK for [3-5] packets. + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + start := c.IRS.Add(seqnum.Size(1 + 3*maxPayload)) + bytesRead := 2 * maxPayload + end := start.Add(seqnum.Size(bytesRead)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - bytesRead += 3 * maxPayload - c.SendAck(seq, bytesRead) + bytesRead += 3 * maxPayload + c.SendAck(seq, bytesRead) - // Wait for the probe function to finish processing the ACK before the - // test completes. - <-probeDone + // Wait for the probe function to finish processing the ACK before the + // test completes. + <-probeDone + }) } // TestRACKWithWindowFull tests that RACK honors the receive window size. 
func TestRACKWithWindowFull(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - e2e.SetStackSACKPermitted(t, c, true) - e2e.CreateConnectedWithSACKAndTS(c) - - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - const numPkts = 10 - data := make([]byte, numPkts*maxPayload) - for i := range data { - data[i] = byte(i) - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + e2e.SetStackSACKPermitted(t, c, true) + e2e.CreateConnectedWithSACKAndTS(c) + + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + const numPkts = 10 + data := make([]byte, numPkts*maxPayload) + for i := range data { + data[i] = byte(i) + } - // Write the data. - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } + // Write the data. + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } - bytesRead := 0 - for i := 0; i < numPkts; i++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, e2e.TSOptionSize) - bytesRead += maxPayload - } + bytesRead := 0 + for i := 0; i < numPkts; i++ { + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, e2e.TSOptionSize) + bytesRead += maxPayload + } - // Expect retransmission of last packet due to TLP. - c.ReceiveAndCheckPacketWithOptions(data, (numPkts-1)*maxPayload, maxPayload, e2e.TSOptionSize) + // Expect retransmission of last packet due to TLP. + c.ReceiveAndCheckPacketWithOptions(data, (numPkts-1)*maxPayload, maxPayload, e2e.TSOptionSize) - // SACK for first and last packet. 
- start := c.IRS.Add(seqnum.Size(maxPayload)) - end := start.Add(seqnum.Size(maxPayload)) - dsackStart := c.IRS.Add(seqnum.Size(1 + (numPkts-1)*maxPayload)) - dsackEnd := dsackStart.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, 2*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}}) + // SACK for first and last packet. + start := c.IRS.Add(seqnum.Size(maxPayload)) + end := start.Add(seqnum.Size(maxPayload)) + dsackStart := c.IRS.Add(seqnum.Size(1 + (numPkts-1)*maxPayload)) + dsackEnd := dsackStart.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, 2*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}}) - var info tcpip.TCPInfoOption - if err := c.EP.GetSockOpt(&info); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - // Wait for RTT to trigger recovery. - time.Sleep(info.RTT) - - // Expect retransmission of #2 packet. - c.ReceiveAndCheckPacketWithOptions(data, 2*maxPayload, maxPayload, e2e.TSOptionSize) - - // Send ACK for #2 packet. - c.SendAck(seq, 3*maxPayload) - - // Expect retransmission of #3 packet. - c.ReceiveAndCheckPacketWithOptions(data, 3*maxPayload, maxPayload, e2e.TSOptionSize) - - // Send ACK with zero window size. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seq, - AckNum: c.IRS.Add(1 + 4*maxPayload), - RcvWnd: 0, + var info tcpip.TCPInfoOption + if err := c.EP.GetSockOpt(&info); err != nil { + t.Fatalf("GetSockOpt failed: %v", err) + } + // Wait for RTT to trigger recovery. + time.Sleep(info.RTT) + + // Expect retransmission of #2 packet. + c.ReceiveAndCheckPacketWithOptions(data, 2*maxPayload, maxPayload, e2e.TSOptionSize) + + // Send ACK for #2 packet. + c.SendAck(seq, 3*maxPayload) + + // Expect retransmission of #3 packet. + c.ReceiveAndCheckPacketWithOptions(data, 3*maxPayload, maxPayload, e2e.TSOptionSize) + + // Send ACK with zero window size. 
+ c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seq, + AckNum: c.IRS.Add(1 + 4*maxPayload), + RcvWnd: 0, + }) + + // No packet should be received as the receive window size is zero. + c.CheckNoPacket("unexpected packet received after userTimeout has expired") }) - - // No packet should be received as the receive window size is zero. - c.CheckNoPacket("unexpected packet received after userTimeout has expired") } func TestMain(m *testing.M) { refs.SetLeakMode(refs.LeaksPanic) code := m.Run() - // Allow TCP async work to complete to avoid false reports of leaks. - // TODO(gvisor.dev/issue/5940): Use fake clock in tests. - time.Sleep(1 * time.Second) refs.DoLeakCheck() os.Exit(code) } diff --git a/pkg/tcpip/transport/tcp/test/e2e/tcp_sack_test.go b/pkg/tcpip/transport/tcp/test/e2e/tcp_sack_test.go index 6202a2aacc..6667103f50 100644 --- a/pkg/tcpip/transport/tcp/test/e2e/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/test/e2e/tcp_sack_test.go @@ -21,6 +21,7 @@ import ( "os" "slices" "testing" + "testing/synctest" "time" "gvisor.dev/gvisor/pkg/buffer" @@ -35,6 +36,15 @@ import ( "gvisor.dev/gvisor/pkg/test/testutil" ) +// withSynctest runs fn inside a synctest bubble and waits for goroutines to finish. +func withSynctest(t *testing.T, fn func(t *testing.T)) { + t.Helper() + synctest.Test(t, func(t *testing.T) { + fn(t) + synctest.Wait() + }) +} + const ( maxPayload = 10 tsOptionSize = 12 @@ -45,43 +55,46 @@ const ( // enabled. 
func TestSackPermittedConnect(t *testing.T) { for _, sackEnabled := range []bool{false, true} { + sackEnabled := sackEnabled t.Run(fmt.Sprintf("stack.sackEnabled: %v", sackEnabled), func(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - - e2e.SetStackSACKPermitted(t, c, sackEnabled) - e2e.SetStackTCPRecovery(t, c, 0) - rep := e2e.CreateConnectedWithSACKPermittedOption(c) - data := []byte{1, 2, 3} - - rep.SendPacket(data, nil) - savedSeqNum := rep.NextSeqNum - rep.VerifyACKNoSACK() - - // Make an out of order packet and send it. - rep.NextSeqNum += 3 - sackBlocks := []header.SACKBlock{ - {rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))}, - } - rep.SendPacket(data, nil) - - // Restore the saved sequence number so that the - // VerifyXXX calls use the right sequence number for - // checking ACK numbers. - rep.NextSeqNum = savedSeqNum - if sackEnabled { - rep.VerifyACKHasSACK(sackBlocks) - } else { + withSynctest(t, func(t *testing.T) { + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + e2e.SetStackSACKPermitted(t, c, sackEnabled) + e2e.SetStackTCPRecovery(t, c, 0) + rep := e2e.CreateConnectedWithSACKPermittedOption(c) + data := []byte{1, 2, 3} + + rep.SendPacket(data, nil) + savedSeqNum := rep.NextSeqNum rep.VerifyACKNoSACK() - } - // Send the missing segment. - rep.SendPacket(data, nil) - // The ACK should contain the cumulative ACK for all 9 - // bytes sent and no SACK blocks. - rep.NextSeqNum += 3 - // Check that no SACK block is returned in the ACK. - rep.VerifyACKNoSACK() + // Make an out of order packet and send it. + rep.NextSeqNum += 3 + sackBlocks := []header.SACKBlock{ + {rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))}, + } + rep.SendPacket(data, nil) + + // Restore the saved sequence number so that the + // VerifyXXX calls use the right sequence number for + // checking ACK numbers. 
+ rep.NextSeqNum = savedSeqNum + if sackEnabled { + rep.VerifyACKHasSACK(sackBlocks) + } else { + rep.VerifyACKNoSACK() + } + + // Send the missing segment. + rep.SendPacket(data, nil) + // The ACK should contain the cumulative ACK for all 9 + // bytes sent and no SACK blocks. + rep.NextSeqNum += 3 + // Check that no SACK block is returned in the ACK. + rep.VerifyACKNoSACK() + }) }) } } @@ -90,37 +103,40 @@ func TestSackPermittedConnect(t *testing.T) { // disabled and verifies that no SACKs are sent for out of order segments. func TestSackDisabledConnect(t *testing.T) { for _, sackEnabled := range []bool{false, true} { + sackEnabled := sackEnabled t.Run(fmt.Sprintf("sackEnabled: %v", sackEnabled), func(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + withSynctest(t, func(t *testing.T) { + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - e2e.SetStackSACKPermitted(t, c, sackEnabled) - e2e.SetStackTCPRecovery(t, c, 0) + e2e.SetStackSACKPermitted(t, c, sackEnabled) + e2e.SetStackTCPRecovery(t, c, 0) - rep := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{}) + rep := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{}) - data := []byte{1, 2, 3} + data := []byte{1, 2, 3} - rep.SendPacket(data, nil) - savedSeqNum := rep.NextSeqNum - rep.VerifyACKNoSACK() + rep.SendPacket(data, nil) + savedSeqNum := rep.NextSeqNum + rep.VerifyACKNoSACK() - // Make an out of order packet and send it. - rep.NextSeqNum += 3 - rep.SendPacket(data, nil) + // Make an out of order packet and send it. + rep.NextSeqNum += 3 + rep.SendPacket(data, nil) - // The ACK should contain the older sequence number and - // no SACK blocks. - rep.NextSeqNum = savedSeqNum - rep.VerifyACKNoSACK() + // The ACK should contain the older sequence number and + // no SACK blocks. + rep.NextSeqNum = savedSeqNum + rep.VerifyACKNoSACK() - // Send the missing segment. 
- rep.SendPacket(data, nil) - // The ACK should contain the cumulative ACK for all 9 - // bytes sent and no SACK blocks. - rep.NextSeqNum += 3 - // Check that no SACK block is returned in the ACK. - rep.VerifyACKNoSACK() + // Send the missing segment. + rep.SendPacket(data, nil) + // The ACK should contain the cumulative ACK for all 9 + // bytes sent and no SACK blocks. + rep.NextSeqNum += 3 + // Check that no SACK block is returned in the ACK. + rep.VerifyACKNoSACK() + }) }) } } @@ -144,56 +160,60 @@ func TestSackPermittedAccept(t *testing.T) { } for _, tc := range testCases { + tc := tc t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) { for _, sackEnabled := range []bool{false, true} { + sackEnabled := sackEnabled t.Run(fmt.Sprintf("test stack.sackEnabled: %v", sackEnabled), func(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + withSynctest(t, func(t *testing.T) { + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + if tc.cookieEnabled { + opt := tcpip.TCPAlwaysUseSynCookies(true) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } + } + e2e.SetStackSACKPermitted(t, c, sackEnabled) + e2e.SetStackTCPRecovery(t, c, 0) + + rep := c.AcceptWithOptionsNoDelay(tc.wndScale, header.TCPSynOptions{MSS: e2e.DefaultIPv4MSS, SACKPermitted: tc.sackPermitted}) + // Now verify no SACK blocks are + // received when sack is disabled. + data := []byte{1, 2, 3} + rep.SendPacket(data, nil) + rep.VerifyACKNoSACK() + + savedSeqNum := rep.NextSeqNum - if tc.cookieEnabled { - opt := tcpip.TCPAlwaysUseSynCookies(true) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + // Make an out of order packet and send + // it. 
+ rep.NextSeqNum += 3 + sackBlocks := []header.SACKBlock{ + {rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))}, + } + rep.SendPacket(data, nil) + + // The ACK should contain the older + // sequence number. + rep.NextSeqNum = savedSeqNum + if sackEnabled && tc.sackPermitted { + rep.VerifyACKHasSACK(sackBlocks) + } else { + rep.VerifyACKNoSACK() } - } - e2e.SetStackSACKPermitted(t, c, sackEnabled) - e2e.SetStackTCPRecovery(t, c, 0) - - rep := c.AcceptWithOptionsNoDelay(tc.wndScale, header.TCPSynOptions{MSS: e2e.DefaultIPv4MSS, SACKPermitted: tc.sackPermitted}) - // Now verify no SACK blocks are - // received when sack is disabled. - data := []byte{1, 2, 3} - rep.SendPacket(data, nil) - rep.VerifyACKNoSACK() - savedSeqNum := rep.NextSeqNum - - // Make an out of order packet and send - // it. - rep.NextSeqNum += 3 - sackBlocks := []header.SACKBlock{ - {rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))}, - } - rep.SendPacket(data, nil) - - // The ACK should contain the older - // sequence number. - rep.NextSeqNum = savedSeqNum - if sackEnabled && tc.sackPermitted { - rep.VerifyACKHasSACK(sackBlocks) - } else { + // Send the missing segment. + rep.SendPacket(data, nil) + // The ACK should contain the cumulative + // ACK for all 9 bytes sent and no SACK + // blocks. + rep.NextSeqNum += 3 + // Check that no SACK block is returned + // in the ACK. rep.VerifyACKNoSACK() - } - - // Send the missing segment. - rep.SendPacket(data, nil) - // The ACK should contain the cumulative - // ACK for all 9 bytes sent and no SACK - // blocks. - rep.NextSeqNum += 3 - // Check that no SACK block is returned - // in the ACK. 
- rep.VerifyACKNoSACK() + }) }) } }) @@ -217,50 +237,54 @@ func TestSackDisabledAccept(t *testing.T) { } for _, tc := range testCases { + tc := tc t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) { for _, sackEnabled := range []bool{false, true} { + sackEnabled := sackEnabled t.Run(fmt.Sprintf("test: sackEnabled: %v", sackEnabled), func(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - - if tc.cookieEnabled { - opt := tcpip.TCPAlwaysUseSynCookies(true) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + withSynctest(t, func(t *testing.T) { + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + if tc.cookieEnabled { + opt := tcpip.TCPAlwaysUseSynCookies(true) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } } - } - e2e.SetStackSACKPermitted(t, c, sackEnabled) - e2e.SetStackTCPRecovery(t, c, 0) + e2e.SetStackSACKPermitted(t, c, sackEnabled) + e2e.SetStackTCPRecovery(t, c, 0) - rep := c.AcceptWithOptionsNoDelay(tc.wndScale, header.TCPSynOptions{MSS: e2e.DefaultIPv4MSS}) + rep := c.AcceptWithOptionsNoDelay(tc.wndScale, header.TCPSynOptions{MSS: e2e.DefaultIPv4MSS}) - // Now verify no SACK blocks are - // received when sack is disabled. - data := []byte{1, 2, 3} - rep.SendPacket(data, nil) - rep.VerifyACKNoSACK() - savedSeqNum := rep.NextSeqNum + // Now verify no SACK blocks are + // received when sack is disabled. + data := []byte{1, 2, 3} + rep.SendPacket(data, nil) + rep.VerifyACKNoSACK() + savedSeqNum := rep.NextSeqNum - // Make an out of order packet and send - // it. - rep.NextSeqNum += 3 - rep.SendPacket(data, nil) + // Make an out of order packet and send + // it. 
+ rep.NextSeqNum += 3 + rep.SendPacket(data, nil) - // The ACK should contain the older - // sequence number and no SACK blocks. - rep.NextSeqNum = savedSeqNum - rep.VerifyACKNoSACK() + // The ACK should contain the older + // sequence number and no SACK blocks. + rep.NextSeqNum = savedSeqNum + rep.VerifyACKNoSACK() - // Send the missing segment. - rep.SendPacket(data, nil) - // The ACK should contain the cumulative - // ACK for all 9 bytes sent and no SACK - // blocks. - rep.NextSeqNum += 3 - // Check that no SACK block is returned - // in the ACK. - rep.VerifyACKNoSACK() + // Send the missing segment. + rep.SendPacket(data, nil) + // The ACK should contain the cumulative + // ACK for all 9 bytes sent and no SACK + // blocks. + rep.NextSeqNum += 3 + // Check that no SACK block is returned + // in the ACK. + rep.VerifyACKNoSACK() + }) }) } }) @@ -355,229 +379,231 @@ func TestTrimSackBlockList(t *testing.T) { } func TestSACKRecovery(t *testing.T) { - probe := func(s *tcp.TCPEndpointState) { - // We use log.Printf instead of t.Logf here because this probe - // can fire even when the test function has finished. This is - // because closing the endpoint in cleanup() does not mean the - // actual worker loop terminates immediately as it still has to - // do a full TCP shutdown. But this test can finish running - // before the shutdown is done. Using t.Logf in such a case - // causes the test to panic due to logging after test finished. - log.Printf("state: %+v\n", s) - } - const maxPayload = 10 - // See: tcp.makeOptions for why tsOptionSize is set to 12 here. - const tsOptionSize = 12 - // Enabling SACK means the payload size is reduced to account - // for the extra space required for the TCP options. - // - // We increase the MTU by e2e.MaxTCPOptionSize bytes to account for SACK - // and Timestamp options. 
- c := context.NewWithProbe(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+e2e.MaxTCPOptionSize+maxPayload), probe) - defer c.Cleanup() - - e2e.SetStackSACKPermitted(t, c, true) - e2e.SetStackTCPRecovery(t, c, 0) - e2e.CreateConnectedWithSACKAndTS(c) - - const iterations = 3 - data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Do slow start for a few iterations. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(seq, bytesRead) + withSynctest(t, func(t *testing.T) { + probe := func(s *tcp.TCPEndpointState) { + // We use log.Printf instead of t.Logf here because this probe + // can fire even when the test function has finished. This is + // because closing the endpoint in cleanup() does not mean the + // actual worker loop terminates immediately as it still has to + // do a full TCP shutdown. But this test can finish running + // before the shutdown is done. Using t.Logf in such a case + // causes the test to panic due to logging after test finished. + log.Printf("state: %+v\n", s) + } + const maxPayload = 10 + // See: tcp.makeOptions for why tsOptionSize is set to 12 here. + const tsOptionSize = 12 + // Enabling SACK means the payload size is reduced to account + // for the extra space required for the TCP options. + // + // We increase the MTU by e2e.MaxTCPOptionSize bytes to account for SACK + // and Timestamp options. 
+ c := context.NewWithProbe(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+e2e.MaxTCPOptionSize+maxPayload), probe) + defer c.Cleanup() + + e2e.SetStackSACKPermitted(t, c, true) + e2e.SetStackTCPRecovery(t, c, 0) + e2e.CreateConnectedWithSACKAndTS(c) + + const iterations = 3 + data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) + for i := range data { + data[i] = byte(i) } - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload + // Write all the data in one shot. Packets will only be written at the + // MTU size though. + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) } - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) - } + // Do slow start for a few iterations. + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + expected := tcp.InitialCwnd + bytesRead := 0 + for i := 0; i < iterations; i++ { + expected = tcp.InitialCwnd << uint(i) + if i > 0 { + // Acknowledge all the data received so far if not on + // first iteration. + c.SendAck(seq, bytesRead) + } - // Send 3 duplicate acks. This should force an immediate retransmit of - // the pending packet and put the sender into fast recovery. - rtxOffset := bytesRead - maxPayload*expected - start := c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1) - end := start.Add(10) - for i := 0; i < 3; i++ { - c.SendAckWithSACK(seq, rtxOffset, []header.SACKBlock{{start, end}}) - end = end.Add(10) - } + // Read all packets expected on this iteration. 
Don't + // acknowledge any of them just yet, so that we can measure the + // congestion window. + for j := 0; j < expected; j++ { + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + bytesRead += maxPayload + } - // Receive the retransmitted packet. - c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize) + // Check we don't receive any more packets on this iteration. + // The timeout can't be too high or we'll trigger a timeout. + c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) + } - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, - {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, - {tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0}, + // Send 3 duplicate acks. This should force an immediate retransmit of + // the pending packet and put the sender into fast recovery. + rtxOffset := bytesRead - maxPayload*expected + start := c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1) + end := start.Add(10) + for i := 0; i < 3; i++ { + c.SendAckWithSACK(seq, rtxOffset, []header.SACKBlock{{start, end}}) + end = end.Add(10) } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) + + // Receive the retransmitted packet. 
+ c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize) + + metricPollFn := func() error { + tcpStats := c.Stack().Stats().TCP + stats := []struct { + stat *tcpip.StatCounter + name string + want uint64 + }{ + {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, + {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, + {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, + {tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0}, } + for _, s := range stats { + if got, want := s.stat.Value(), s.want; got != want { + return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) + } + } + return nil } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) + } - // Now send 7 mode duplicate ACKs. In SACK TCP dupAcks do not cause - // window inflation and sending of packets is completely handled by the - // SACK Recovery algorithm. We should see no packets being released, as - // the cwnd at this point after entering recovery should be half of the - // outstanding number of packets in flight. - for i := 0; i < 7; i++ { - c.SendAckWithSACK(seq, rtxOffset, []header.SACKBlock{{start, end}}) - end = end.Add(10) - } + // Now send 7 mode duplicate ACKs. In SACK TCP dupAcks do not cause + // window inflation and sending of packets is completely handled by the + // SACK Recovery algorithm. We should see no packets being released, as + // the cwnd at this point after entering recovery should be half of the + // outstanding number of packets in flight. + for i := 0; i < 7; i++ { + c.SendAckWithSACK(seq, rtxOffset, []header.SACKBlock{{start, end}}) + end = end.Add(10) + } - recover := bytesRead - - // Ensure no new packets arrive. - c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.", - 50*time.Millisecond) - - // Acknowledge half of the pending data. 
This along with the 10 sacked - // segments above should reduce the outstanding below the current - // congestion window allowing the sender to transmit data. - rtxOffset = bytesRead - expected*maxPayload/2 - - // Now send a partial ACK w/ a SACK block that indicates that the next 3 - // segments are lost and we have received 6 segments after the lost - // segments. This should cause the sender to immediately transmit all 3 - // segments in response to this ACK unlike in FastRecovery where only 1 - // segment is retransmitted per ACK. - start = c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1) - end = start.Add(60) - c.SendAckWithSACK(seq, rtxOffset, []header.SACKBlock{{start, end}}) - - // At this point, we acked expected/2 packets and we SACKED 6 packets and - // 3 segments were considered lost due to the SACK block we sent. - // - // So total packets outstanding can be calculated as follows after 7 - // iterations of slow start -> 10/20/40/80/160/320/640. So expected - // should be 640 at start, then we went to recover at which point the - // cwnd should be set to 320 + 3 (for the 3 dupAcks which have left the - // network). - // Outstanding at this point after acking half the window - // (320 packets) will be: - // outstanding = 640-320-6(due to SACK block)-3 = 311 - // - // The last 3 is due to the fact that the first 3 packets after - // rtxOffset will be considered lost due to the SACK blocks sent. - // Receive the retransmit due to partial ack. - - c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize) - // Receive the 2 extra packets that should have been retransmitted as - // those should be considered lost and immediately retransmitted based - // on the SACK information in the previous ACK sent above. 
- for i := 0; i < 2; i++ { - c.ReceiveAndCheckPacketWithOptions(data, rtxOffset+maxPayload*(i+1), maxPayload, tsOptionSize) - } + recover := bytesRead - // Now we should get 9 more new unsent packets as the cwnd is 323 and - // outstanding is 311. - for i := 0; i < 9; i++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - } + // Ensure no new packets arrive. + c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.", + 50*time.Millisecond) - metricPollFn = func() error { - // In SACK recovery only the first segment is fast retransmitted when - // entering recovery. - if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want) - } + // Acknowledge half of the pending data. This along with the 10 sacked + // segments above should reduce the outstanding below the current + // congestion window allowing the sender to transmit data. + rtxOffset = bytesRead - expected*maxPayload/2 - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want { - return fmt.Errorf("got EP stats SendErrors.FastRetransmit = %d, want = %d", got, want) - } + // Now send a partial ACK w/ a SACK block that indicates that the next 3 + // segments are lost and we have received 6 segments after the lost + // segments. This should cause the sender to immediately transmit all 3 + // segments in response to this ACK unlike in FastRecovery where only 1 + // segment is retransmitted per ACK. 
+ start = c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1) + end = start.Add(60) + c.SendAckWithSACK(seq, rtxOffset, []header.SACKBlock{{start, end}}) - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want { - return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want) + // At this point, we acked expected/2 packets and we SACKED 6 packets and + // 3 segments were considered lost due to the SACK block we sent. + // + // So total packets outstanding can be calculated as follows after 7 + // iterations of slow start -> 10/20/40/80/160/320/640. So expected + // should be 640 at start, then we went to recover at which point the + // cwnd should be set to 320 + 3 (for the 3 dupAcks which have left the + // network). + // Outstanding at this point after acking half the window + // (320 packets) will be: + // outstanding = 640-320-6(due to SACK block)-3 = 311 + // + // The last 3 is due to the fact that the first 3 packets after + // rtxOffset will be considered lost due to the SACK blocks sent. + // Receive the retransmit due to partial ack. + + c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize) + // Receive the 2 extra packets that should have been retransmitted as + // those should be considered lost and immediately retransmitted based + // on the SACK information in the previous ACK sent above. 
+ for i := 0; i < 2; i++ { + c.ReceiveAndCheckPacketWithOptions(data, rtxOffset+maxPayload*(i+1), maxPayload, tsOptionSize) } - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want { - return fmt.Errorf("got EP stats Stats.SendErrors.Retransmits = %d, want = %d", got, want) - } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } - - c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", 50*time.Millisecond) - - // Acknowledge all pending data to recover point. - c.SendAck(seq, recover) - - // At this point, the cwnd should reset to expected/2 and there are 9 - // packets outstanding. - // - // Now in the first iteration since there are 9 packets outstanding. - // We would expect to get expected/2 - 9 packets. But subsequent - // iterations will send us expected/2 + 1 (per iteration). - expected = expected/2 - 9 - for i := 0; i < iterations; i++ { - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { + // Now we should get 9 more new unsent packets as the cwnd is 323 and + // outstanding is 311. + for i := 0; i < 9; i++ { c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) bytesRead += maxPayload } - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd and iteration: %d.", expected, i), 50*time.Millisecond) - // Acknowledge all the data received so far. - c.SendAck(seq, bytesRead) + metricPollFn = func() error { + // In SACK recovery only the first segment is fast retransmitted when + // entering recovery. 
+ if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want { + return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want) + } - // In cogestion avoidance, the packets trains increase by 1 in - // each iteration. - if i == 0 { - // After the first iteration we expect to get the full - // congestion window worth of packets in every - // iteration. - expected += 9 + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want { + return fmt.Errorf("got EP stats SendErrors.FastRetransmit = %d, want = %d", got, want) + } + + if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want { + return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want) + } + + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want { + return fmt.Errorf("got EP stats Stats.SendErrors.Retransmits = %d, want = %d", got, want) + } + return nil } - expected++ - } + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) + } + + c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", 50*time.Millisecond) + + // Acknowledge all pending data to recover point. + c.SendAck(seq, recover) + + // At this point, the cwnd should reset to expected/2 and there are 9 + // packets outstanding. + // + // Now in the first iteration since there are 9 packets outstanding. + // We would expect to get expected/2 - 9 packets. But subsequent + // iterations will send us expected/2 + 1 (per iteration). + expected = expected/2 - 9 + for i := 0; i < iterations; i++ { + // Read all packets expected on this iteration. Don't + // acknowledge any of them just yet, so that we can measure the + // congestion window. 
+ for j := 0; j < expected; j++ {
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+ bytesRead += maxPayload
+ }
+ // Check we don't receive any more packets on this iteration.
+ // The timeout can't be too high or we'll trigger a timeout.
+ c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd and iteration: %d.", expected, i), 50*time.Millisecond)
+
+ // Acknowledge all the data received so far.
+ c.SendAck(seq, bytesRead)
+
+ // In congestion avoidance, the packet trains increase by 1 in
+ // each iteration.
+ if i == 0 {
+ // After the first iteration we expect to get the full
+ // congestion window worth of packets in every
+ // iteration.
+ expected += 9
+ }
+ expected++
+ }
+ })
}

// TestRecoveryEntry tests the following two properties of entering recovery:
@@ -586,108 +612,111 @@ func TestSACKRecovery(t *testing.T) {
// - Only enter recovery when at least one more byte of data beyond the highest
// byte that was outstanding when fast retransmit was last entered is acked.
func TestRecoveryEntry(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- numPackets := 5
- data := e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, false /* enableRACK */)
-
- // Ack #1 packet.
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendAck(seq, maxPayload)
-
- // Now SACK #3, #4 and #5 packets. This will simulate a situation where
- // SND.UNA should be considered lost and the sender should enter fast recovery
- // (even though dupack count is still below threshold).
- p3Start := c.IRS.Add(1 + seqnum.Size(2*maxPayload))
- p3End := p3Start.Add(maxPayload)
- p4Start := p3End
- p4End := p4Start.Add(maxPayload)
- p5Start := p4End
- p5End := p5Start.Add(maxPayload)
- c.SendAckWithSACK(seq, maxPayload, []header.SACKBlock{{p3Start, p3End}, {p4Start, p4End}, {p5Start, p5End}})
-
- // Expect #2 to be retransmitted. 
- c.ReceiveAndCheckPacketWithOptions(data, maxPayload, maxPayload, tsOptionSize) - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - // SACK recovery must have happened. - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, - // #2 was retransmitted. - {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, - // No RTOs should have fired yet. - {tcpStats.Timeouts, "stats.TCP.Timeouts", 0}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) + withSynctest(t, func(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + numPackets := 5 + data := e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, false /* enableRACK */) + + // Ack #1 packet. + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendAck(seq, maxPayload) + + // Now SACK #3, #4 and #5 packets. This will simulate a situation where + // SND.UNA should be considered lost and the sender should enter fast recovery + // (even though dupack count is still below threshold). + p3Start := c.IRS.Add(1 + seqnum.Size(2*maxPayload)) + p3End := p3Start.Add(maxPayload) + p4Start := p3End + p4End := p4Start.Add(maxPayload) + p5Start := p4End + p5End := p5Start.Add(maxPayload) + c.SendAckWithSACK(seq, maxPayload, []header.SACKBlock{{p3Start, p3End}, {p4Start, p4End}, {p5Start, p5End}}) + + // Expect #2 to be retransmitted. + c.ReceiveAndCheckPacketWithOptions(data, maxPayload, maxPayload, tsOptionSize) + + metricPollFn := func() error { + tcpStats := c.Stack().Stats().TCP + stats := []struct { + stat *tcpip.StatCounter + name string + want uint64 + }{ + // SACK recovery must have happened. 
+ {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, + {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, + // #2 was retransmitted. + {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, + // No RTOs should have fired yet. + {tcpStats.Timeouts, "stats.TCP.Timeouts", 0}, + } + for _, s := range stats { + if got, want := s.stat.Value(), s.want; got != want { + return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) + } } + return nil } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } - - // Send 4 more packets. - var r bytes.Reader - data = append(data, data...) - r.Reset(data[5*maxPayload : 9*maxPayload]) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - var sackBlocks []header.SACKBlock - bytesRead := numPackets * maxPayload - for i := 0; i < 4; i++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - if i > 0 { - pStart := c.IRS.Add(1 + seqnum.Size(bytesRead)) - sackBlocks = append(sackBlocks, header.SACKBlock{pStart, pStart.Add(maxPayload)}) - c.SendAckWithSACK(seq, 5*maxPayload, sackBlocks) + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) } - bytesRead += maxPayload - } - // #6 should be retransmitted after RTO. The sender should NOT enter fast - // recovery because the highest byte that was outstanding when fast recovery - // was last entered is #5 packet's end. And the sender requires at least one - // more byte beyond that (#6 packet start) to be acked to enter recovery. - c.ReceiveAndCheckPacketWithOptions(data, 5*maxPayload, maxPayload, tsOptionSize) - c.SendAck(seq, 9*maxPayload) + // Send 4 more packets. + var r bytes.Reader + data = append(data, data...) 
+ r.Reset(data[5*maxPayload : 9*maxPayload]) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } - metricPollFn = func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - // Only 1 SACK recovery must have happened. - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, - // #2 and #6 were retransmitted. - {tcpStats.Retransmits, "stats.TCP.Retransmits", 2}, - // RTO should have fired once. - {tcpStats.Timeouts, "stats.TCP.Timeouts", 1}, + var sackBlocks []header.SACKBlock + bytesRead := numPackets * maxPayload + for i := 0; i < 4; i++ { + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + if i > 0 { + pStart := c.IRS.Add(1 + seqnum.Size(bytesRead)) + sackBlocks = append(sackBlocks, header.SACKBlock{pStart, pStart.Add(maxPayload)}) + c.SendAckWithSACK(seq, 5*maxPayload, sackBlocks) + } + bytesRead += maxPayload } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) + + // #6 should be retransmitted after RTO. The sender should NOT enter fast + // recovery because the highest byte that was outstanding when fast recovery + // was last entered is #5 packet's end. And the sender requires at least one + // more byte beyond that (#6 packet start) to be acked to enter recovery. + c.ReceiveAndCheckPacketWithOptions(data, 5*maxPayload, maxPayload, tsOptionSize) + c.SendAck(seq, 9*maxPayload) + + metricPollFn = func() error { + tcpStats := c.Stack().Stats().TCP + stats := []struct { + stat *tcpip.StatCounter + name string + want uint64 + }{ + // Only 1 SACK recovery must have happened. + {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, + {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, + // #2 and #6 were retransmitted. 
+ {tcpStats.Retransmits, "stats.TCP.Retransmits", 2}, + // RTO should have fired once. + {tcpStats.Timeouts, "stats.TCP.Timeouts", 1}, + } + for _, s := range stats { + if got, want := s.stat.Value(), s.want; got != want { + return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) + } } + return nil } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) + } + + }) } func verifySpuriousRecoveryMetric(t *testing.T, c *context.Context, numSpuriousRecovery, numSpuriousRTO uint64) { @@ -740,219 +769,222 @@ func buildTSOptionFromHeader(tcpHdr header.TCP) []byte { } func TestDetectSpuriousRecoveryWithRTO(t *testing.T) { - probeDone := make(chan struct{}) - probe := func(s *tcp.TCPEndpointState) { - if s.Sender.RetransmitTS == 0 { - t.Fatalf("RetransmitTS did not get updated, got: 0 want > 0") - } - if !s.Sender.SpuriousRecovery { - t.Fatalf("Spurious recovery was not detected") + withSynctest(t, func(t *testing.T) { + probeDone := make(chan struct{}) + probe := func(s *tcp.TCPEndpointState) { + if s.Sender.RetransmitTS == 0 { + t.Fatalf("RetransmitTS did not get updated, got: 0 want > 0") + } + if !s.Sender.SpuriousRecovery { + t.Fatalf("Spurious recovery was not detected") + } + close(probeDone) } - close(probeDone) - } - c := context.NewWithProbe(t, uint32(mtu), probe) - defer c.Cleanup() + c := context.NewWithProbe(t, uint32(mtu), probe) + defer c.Cleanup() - e2e.SetStackSACKPermitted(t, c, true) - e2e.CreateConnectedWithSACKAndTS(c) - numPackets := 5 - data := make([]byte, numPackets*maxPayload) - for i := range data { - data[i] = byte(i) - } - // Write the data. 
- var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } + e2e.SetStackSACKPermitted(t, c, true) + e2e.CreateConnectedWithSACKAndTS(c) + numPackets := 5 + data := make([]byte, numPackets*maxPayload) + for i := range data { + data[i] = byte(i) + } + // Write the data. + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } - var options []byte - var bytesRead uint32 - for i := 0; i < numPackets; i++ { - b := c.GetPacket() - defer b.Release() - tcpHdr := header.TCP(header.IPv4(b.AsSlice()).Payload()) - checkReceivedPacket(t, c, tcpHdr, bytesRead, b, data) - - // Get options only for the first packet. This will be sent with - // the ACK to indicate the acknowledgement is for the original - // packet. - if i == 0 && c.TimeStampEnabled { - options = buildTSOptionFromHeader(tcpHdr) - } - bytesRead += uint32(len(tcpHdr.Payload())) - } + var options []byte + var bytesRead uint32 + for i := 0; i < numPackets; i++ { + b := c.GetPacket() + defer b.Release() + tcpHdr := header.TCP(header.IPv4(b.AsSlice()).Payload()) + checkReceivedPacket(t, c, tcpHdr, bytesRead, b, data) + + // Get options only for the first packet. This will be sent with + // the ACK to indicate the acknowledgement is for the original + // packet. + if i == 0 && c.TimeStampEnabled { + options = buildTSOptionFromHeader(tcpHdr) + } + bytesRead += uint32(len(tcpHdr.Payload())) + } - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - // Expect #5 segment with TLP. - c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize) + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + // Expect #5 segment with TLP. + c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize) - // Expect #1 segment because of RTO. 
- c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize) + // Expect #1 segment because of RTO. + c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize) - info := tcpip.TCPInfoOption{} - if err := c.EP.GetSockOpt(&info); err != nil { - t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err) - } + info := tcpip.TCPInfoOption{} + if err := c.EP.GetSockOpt(&info); err != nil { + t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err) + } - if info.CcState != tcpip.RTORecovery { - t.Fatalf("Loss recovery did not happen, got: %v want: %v", info.CcState, tcpip.RTORecovery) - } + if info.CcState != tcpip.RTORecovery { + t.Fatalf("Loss recovery did not happen, got: %v want: %v", info.CcState, tcpip.RTORecovery) + } - // Acknowledge the data. - rcvWnd := seqnum.Size(30000) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seq, - AckNum: c.IRS.Add(1 + seqnum.Size(maxPayload)), - RcvWnd: rcvWnd, - TCPOpts: options, - }) + // Acknowledge the data. + rcvWnd := seqnum.Size(30000) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seq, + AckNum: c.IRS.Add(1 + seqnum.Size(maxPayload)), + RcvWnd: rcvWnd, + TCPOpts: options, + }) - // Wait for the probe function to finish processing the - // ACK before the test completes. - <-probeDone + // Wait for the probe function to finish processing the + // ACK before the test completes. 
+ <-probeDone - verifySpuriousRecoveryMetric(t, c, 1 /* numSpuriousRecovery */, 1 /* numSpuriousRTO */) + verifySpuriousRecoveryMetric(t, c, 1 /* numSpuriousRecovery */, 1 /* numSpuriousRTO */) + }) } func TestSACKDetectSpuriousRecoveryWithDupACK(t *testing.T) { - numAck := 0 - probeDone := make(chan struct{}) - probe := func(s *tcp.TCPEndpointState) { - if numAck < 3 { - numAck++ - return + withSynctest(t, func(t *testing.T) { + numAck := 0 + probeDone := make(chan struct{}) + probe := func(s *tcp.TCPEndpointState) { + if numAck < 3 { + numAck++ + return + } + + if s.Sender.RetransmitTS == 0 { + t.Fatalf("RetransmitTS did not get updated, got: 0 want > 0") + } + if !s.Sender.SpuriousRecovery { + t.Fatalf("Spurious recovery was not detected") + } + close(probeDone) } - if s.Sender.RetransmitTS == 0 { - t.Fatalf("RetransmitTS did not get updated, got: 0 want > 0") + c := context.NewWithProbe(t, uint32(mtu), probe) + defer c.Cleanup() + + e2e.SetStackSACKPermitted(t, c, true) + e2e.CreateConnectedWithSACKAndTS(c) + numPackets := 5 + data := make([]byte, numPackets*maxPayload) + for i := range data { + data[i] = byte(i) } - if !s.Sender.SpuriousRecovery { - t.Fatalf("Spurious recovery was not detected") + // Write the data. + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) } - close(probeDone) - } - c := context.NewWithProbe(t, uint32(mtu), probe) - defer c.Cleanup() + var options []byte + var bytesRead uint32 + for i := 0; i < numPackets; i++ { + b := c.GetPacket() + defer b.Release() + tcpHdr := header.TCP(header.IPv4(b.AsSlice()).Payload()) + checkReceivedPacket(t, c, tcpHdr, bytesRead, b, data) + + // Get options only for the first packet. This will be sent with + // the ACK to indicate the acknowledgement is for the original + // packet. 
+ if i == 0 && c.TimeStampEnabled { + options = buildTSOptionFromHeader(tcpHdr) + } + bytesRead += uint32(len(tcpHdr.Payload())) + } - e2e.SetStackSACKPermitted(t, c, true) - e2e.CreateConnectedWithSACKAndTS(c) - numPackets := 5 - data := make([]byte, numPackets*maxPayload) - for i := range data { - data[i] = byte(i) - } - // Write the data. - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } + // Receive the retransmitted packet after TLP. + c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize) - var options []byte - var bytesRead uint32 - for i := 0; i < numPackets; i++ { - b := c.GetPacket() - defer b.Release() - tcpHdr := header.TCP(header.IPv4(b.AsSlice()).Payload()) - checkReceivedPacket(t, c, tcpHdr, bytesRead, b, data) - - // Get options only for the first packet. This will be sent with - // the ACK to indicate the acknowledgement is for the original - // packet. - if i == 0 && c.TimeStampEnabled { - options = buildTSOptionFromHeader(tcpHdr) - } - bytesRead += uint32(len(tcpHdr.Payload())) - } + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + // Send ACK for #3 and #4 segments to avoid entering TLP. + start := c.IRS.Add(3*maxPayload + 1) + end := start.Add(2 * maxPayload) + c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}}) - // Receive the retransmitted packet after TLP. - c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize) + c.SendAck(seq, 0 /* bytesReceived */) + c.SendAck(seq, 0 /* bytesReceived */) - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - // Send ACK for #3 and #4 segments to avoid entering TLP. - start := c.IRS.Add(3*maxPayload + 1) - end := start.Add(2 * maxPayload) - c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}}) + // Receive the retransmitted packet after three duplicate ACKs. 
+ c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize) - c.SendAck(seq, 0 /* bytesReceived */) - c.SendAck(seq, 0 /* bytesReceived */) + info := tcpip.TCPInfoOption{} + if err := c.EP.GetSockOpt(&info); err != nil { + t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err) + } - // Receive the retransmitted packet after three duplicate ACKs. - c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize) + if info.CcState != tcpip.SACKRecovery { + t.Fatalf("Loss recovery did not happen, got: %v want: %v", info.CcState, tcpip.SACKRecovery) + } - info := tcpip.TCPInfoOption{} - if err := c.EP.GetSockOpt(&info); err != nil { - t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err) - } + // Acknowledge the data. + rcvWnd := seqnum.Size(30000) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seq, + AckNum: c.IRS.Add(1 + seqnum.Size(maxPayload)), + RcvWnd: rcvWnd, + TCPOpts: options, + }) - if info.CcState != tcpip.SACKRecovery { - t.Fatalf("Loss recovery did not happen, got: %v want: %v", info.CcState, tcpip.SACKRecovery) - } + // Wait for the probe function to finish processing the + // ACK before the test completes. + <-probeDone - // Acknowledge the data. - rcvWnd := seqnum.Size(30000) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seq, - AckNum: c.IRS.Add(1 + seqnum.Size(maxPayload)), - RcvWnd: rcvWnd, - TCPOpts: options, + verifySpuriousRecoveryMetric(t, c, 1 /* numSpuriousRecovery */, 0 /* numSpuriousRTO */) }) - - // Wait for the probe function to finish processing the - // ACK before the test completes. 
- <-probeDone - - verifySpuriousRecoveryMetric(t, c, 1 /* numSpuriousRecovery */, 0 /* numSpuriousRTO */) } func TestNoSpuriousRecoveryWithDSACK(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - e2e.SetStackSACKPermitted(t, c, true) - e2e.CreateConnectedWithSACKAndTS(c) - numPackets := 5 - data := e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, true /* enableRACK */) - - // Receive the retransmitted packet after TLP. - c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize) - - // Send ACK for #3 and #4 segments to avoid entering TLP. - start := c.IRS.Add(3*maxPayload + 1) - end := start.Add(2 * maxPayload) - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}}) - - c.SendAck(seq, 0 /* bytesReceived */) - c.SendAck(seq, 0 /* bytesReceived */) - - // Receive the retransmitted packet after three duplicate ACKs. - c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize) - - // Acknowledge the data with DSACK for #1 segment. - start = c.IRS.Add(maxPayload + 1) - end = start.Add(2 * maxPayload) - seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendAckWithSACK(seq, 6*maxPayload, []header.SACKBlock{{start, end}}) - - verifySpuriousRecoveryMetric(t, c, 0 /* numSpuriousRecovery */, 0 /* numSpuriousRTO */) + withSynctest(t, func(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + e2e.SetStackSACKPermitted(t, c, true) + e2e.CreateConnectedWithSACKAndTS(c) + numPackets := 5 + data := e2e.SendAndReceiveWithSACK(t, c, maxPayload, numPackets, true /* enableRACK */) + + // Receive the retransmitted packet after TLP. + c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize) + + // Send ACK for #3 and #4 segments to avoid entering TLP. 
+ start := c.IRS.Add(3*maxPayload + 1) + end := start.Add(2 * maxPayload) + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}}) + + c.SendAck(seq, 0 /* bytesReceived */) + c.SendAck(seq, 0 /* bytesReceived */) + + // Receive the retransmitted packet after three duplicate ACKs. + c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize) + + // Acknowledge the data with DSACK for #1 segment. + start = c.IRS.Add(maxPayload + 1) + end = start.Add(2 * maxPayload) + seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendAckWithSACK(seq, 6*maxPayload, []header.SACKBlock{{start, end}}) + + verifySpuriousRecoveryMetric(t, c, 0 /* numSpuriousRecovery */, 0 /* numSpuriousRTO */) + }) } func TestMain(m *testing.M) { refs.SetLeakMode(refs.LeaksPanic) code := m.Run() - // Allow TCP async work to complete to avoid false reports of leaks. - // TODO(gvisor.dev/issue/5940): Use fake clock in tests. - time.Sleep(1 * time.Second) refs.DoLeakCheck() os.Exit(code) } diff --git a/pkg/tcpip/transport/tcp/test/e2e/tcp_test.go b/pkg/tcpip/transport/tcp/test/e2e/tcp_test.go index 33a4c07dd3..16b866ceb0 100644 --- a/pkg/tcpip/transport/tcp/test/e2e/tcp_test.go +++ b/pkg/tcpip/transport/tcp/test/e2e/tcp_test.go @@ -22,6 +22,7 @@ import ( "os" "strings" "testing" + "testing/synctest" "time" "github.com/google/go-cmp/cmp" @@ -74,8 +75,8 @@ func (e *endpointTester) CheckRead(t *testing.T) []byte { t.Fatalf("ep.Read = _, %s; want _, nil", err) } if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), + Count: buf.Len(), + Total: buf.Len(), }, res, checker.IgnoreCmpPath("ControlMessages")); diff != "" { t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) } @@ -87,8 +88,8 @@ func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-cha t.Helper() var buf bytes.Buffer w := tcpip.LimitedWriter{ - W: &buf, - N: int64(count), + W: &buf, + N: int64(count), 
} for w.N != 0 { _, err := e.ep.Read(&w, tcpip.ReadOptions{}) @@ -108,1150 +109,1220 @@ func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-cha } func TestGiveUpConnect(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - var wq waiter.Queue - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } + var wq waiter.Queue + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } - // Register for notification, then start connection attempt. - waitEntry, notifyCh := waiter.NewChannelEntry(waiter.EventHUp) - wq.EventRegister(&waitEntry) - defer wq.EventUnregister(&waitEntry) + // Register for notification, then start connection attempt. + waitEntry, notifyCh := waiter.NewChannelEntry(waiter.EventHUp) + wq.EventRegister(&waitEntry) + defer wq.EventUnregister(&waitEntry) - { - err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { - t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) + { + err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) + } } - } - // Close the connection, wait for completion. - ep.Close() + // Close the connection, wait for completion. + ep.Close() - // Wait for ep to become writable. - <-notifyCh + // Wait for ep to become writable. + <-notifyCh - // Call Connect again to retrieve the handshake failure status - // and stats updates. 
- { - err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrAborted{}, err); d != "" { - t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) + // Call Connect again to retrieve the handshake failure status + // and stats updates. + { + err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if d := cmp.Diff(&tcpip.ErrAborted{}, err); d != "" { + t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) + } } - } - if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 { - t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = 1", got) - } + if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 { + t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = 1", got) + } - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) - } + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) + } + }) } // Test for ICMP error handling without completing handshake. 
func TestConnectICMPError(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - var wq waiter.Queue - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } + var wq waiter.Queue + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } - waitEntry, notifyCh := waiter.NewChannelEntry(waiter.EventHUp) - wq.EventRegister(&waitEntry) - defer wq.EventUnregister(&waitEntry) + waitEntry, notifyCh := waiter.NewChannelEntry(waiter.EventHUp) + wq.EventRegister(&waitEntry) + defer wq.EventUnregister(&waitEntry) - { - err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { - t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) + { + err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("ep.Connect(...) 
mismatch (-want +got):\n%s", d) + } } - } - syn := c.GetPacket() - defer syn.Release() - checker.IPv4(t, syn, checker.TCP(checker.TCPFlags(header.TCPFlagSyn))) + syn := c.GetPacket() + defer syn.Release() + checker.IPv4(t, syn, checker.TCP(checker.TCPFlags(header.TCPFlagSyn))) - wep := ep.(interface { - LastErrorLocked() tcpip.Error - }) + wep := ep.(interface { + LastErrorLocked() tcpip.Error + }) - c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4HostUnreachable, nil, syn, e2e.DefaultMTU) + c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4HostUnreachable, nil, syn, e2e.DefaultMTU) - for { - if err := wep.LastErrorLocked(); err != nil { - if d := cmp.Diff(&tcpip.ErrHostUnreachable{}, err); d != "" { - t.Errorf("ep.LastErrorLocked() mismatch (-want +got):\n%s", d) + for { + if err := wep.LastErrorLocked(); err != nil { + if d := cmp.Diff(&tcpip.ErrHostUnreachable{}, err); d != "" { + t.Errorf("ep.LastErrorLocked() mismatch (-want +got):\n%s", d) + } + break } - break + time.Sleep(time.Millisecond) } - time.Sleep(time.Millisecond) - } - <-notifyCh + <-notifyCh - // The stack would have unregistered the endpoint because of the ICMP error. - // Expect a RST for any subsequent packets sent to the endpoint. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(context.TestInitialSequenceNumber) + 1, - AckNum: c.IRS + 1, - }) + // The stack would have unregistered the endpoint because of the ICMP error. + // Expect a RST for any subsequent packets sent to the endpoint. 
+ c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: seqnum.Value(context.TestInitialSequenceNumber) + 1, + AckNum: c.IRS + 1, + }) - b := c.GetPacket() - defer b.Release() - checker.IPv4(t, b, checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(0), - checker.TCPFlags(header.TCPFlagRst))) + b := c.GetPacket() + defer b.Release() + checker.IPv4(t, b, checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(0), + checker.TCPFlags(header.TCPFlagRst))) + }) } func TestConnectIncrementActiveConnection(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - stats := c.Stack().Stats() - want := stats.TCP.ActiveConnectionOpenings.Value() + 1 + stats := c.Stack().Stats() + want := stats.TCP.ActiveConnectionOpenings.Value() + 1 - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want { - t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %d, want = %d", got, want) - } + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want { + t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %d, want = %d", got, want) + } + }) } func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - stats := c.Stack().Stats() - want := stats.TCP.FailedConnectionAttempts.Value() + stats := c.Stack().Stats() + 
want := stats.TCP.FailedConnectionAttempts.Value() - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { - t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want { - t.Errorf("got EP stats.FailedConnectionAttempts = %d, want = %d", got, want) - } + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { + t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want) + } + if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want { + t.Errorf("got EP stats.FailedConnectionAttempts = %d, want = %d", got, want) + } + }) } func TestActiveFailedConnectionAttemptIncrement(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - stats := c.Stack().Stats() - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - c.EP = ep - want := stats.TCP.FailedConnectionAttempts.Value() + 1 + stats := c.Stack().Stats() + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + c.EP = ep + want := stats.TCP.FailedConnectionAttempts.Value() + 1 - { - err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrNetworkUnreachable{}, err); d != "" { - t.Errorf("c.EP.Connect(...) 
mismatch (-want +got):\n%s", d) + { + err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}) + if d := cmp.Diff(&tcpip.ErrNetworkUnreachable{}, err); d != "" { + t.Errorf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) + } } - } - if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { - t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want { - t.Errorf("got EP stats FailedConnectionAttempts = %d, want = %d", got, want) - } + if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { + t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want) + } + if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want { + t.Errorf("got EP stats FailedConnectionAttempts = %d, want = %d", got, want) + } + }) } func TestCloseWithoutConnect(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + // Create TCP endpoint. 
+ var err tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } - c.EP.Close() - c.EP = nil + c.EP.Close() + c.EP = nil - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) - } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) + } + }) } func TestHandshakeTimeoutConnectedCount(t *testing.T) { - clock := faketime.NewManualClock() - c := context.NewWithOpts(t, context.Options{ - EnableV4: true, - EnableV6: true, - MTU: e2e.DefaultMTU, - Clock: clock, - }) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + clock := faketime.NewManualClock() + c := context.NewWithOpts(t, context.Options{ + EnableV4: true, + EnableV6: true, + MTU: e2e.DefaultMTU, + Clock: clock, + }) + defer c.Cleanup() - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - c.EP = ep + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + c.EP = ep - we, ch := waiter.NewChannelEntry(waiter.WritableEvents) - c.WQ.EventRegister(&we) - defer c.WQ.EventUnregister(&we) + we, ch := waiter.NewChannelEntry(waiter.WritableEvents) + c.WQ.EventRegister(&we) + defer c.WQ.EventUnregister(&we) - switch err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}).(type) { - case *tcpip.ErrConnectStarted: - default: - t.Fatalf("Connect did not start: %v", err) - } + switch err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}).(type) { + case *tcpip.ErrConnectStarted: + default: + t.Fatalf("Connect did not start: %v", err) + } - 
clock.Advance(tcp.DefaultKeepaliveInterval) - clock.Advance(tcp.DefaultKeepaliveInterval) - <-ch - switch err := c.EP.LastError().(type) { - case *tcpip.ErrTimeout: - default: - t.Fatalf("Connect didn't timeout: %v", err) - } - if got, want := c.Stack().Stats().TCP.CurrentConnected.Value(), uint64(0); got != want { - t.Fatalf("got stats.TCP.CurrentConnected.Value() = %d, want = %d", got, want) - } + clock.Advance(tcp.DefaultKeepaliveInterval) + clock.Advance(tcp.DefaultKeepaliveInterval) + <-ch + switch err := c.EP.LastError().(type) { + case *tcpip.ErrTimeout: + default: + t.Fatalf("Connect didn't timeout: %v", err) + } + if got, want := c.Stack().Stats().TCP.CurrentConnected.Value(), uint64(0); got != want { + t.Fatalf("got stats.TCP.CurrentConnected.Value() = %d, want = %d", got, want) + } + }) } func TestTCPSegmentsSentIncrement(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - // SYN and ACK - want := stats.TCP.SegmentsSent.Value() + 2 - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - if got := stats.TCP.SegmentsSent.Value(); got != want { - t.Errorf("got stats.TCP.SegmentsSent.Value() = %d, want = %d", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).SegmentsSent.Value(); got != want { - t.Errorf("got EP stats SegmentsSent.Value() = %d, want = %d", got, want) - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + stats := c.Stack().Stats() + // SYN and ACK + want := stats.TCP.SegmentsSent.Value() + 2 + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + + if got := stats.TCP.SegmentsSent.Value(); got != want { + t.Errorf("got stats.TCP.SegmentsSent.Value() = %d, want = %d", got, want) + } + if got := c.EP.Stats().(*tcp.Stats).SegmentsSent.Value(); got != want { + t.Errorf("got EP stats SegmentsSent.Value() = %d, want = %d", got, want) + } + }) } func 
TestTCPResetsSentIncrement(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - stats := c.Stack().Stats() - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - want := stats.TCP.SegmentsSent.Value() + 1 + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + stats := c.Stack().Stats() + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + want := stats.TCP.SegmentsSent.Value() + 1 - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } - // Send a SYN request. - iss := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - }) + // Send a SYN request. + iss := seqnum.Value(context.TestInitialSequenceNumber) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) - // Receive the SYN-ACK reply. - v := c.GetPacket() - defer v.Release() - tcpHdr := header.TCP(header.IPv4(v.AsSlice()).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - // If the AckNum is not the increment of the last sequence number, a RST - // segment is sent back in response. 
- AckNum: c.IRS + 2, - } + // Receive the SYN-ACK reply. + v := c.GetPacket() + defer v.Release() + tcpHdr := header.TCP(header.IPv4(v.AsSlice()).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + // If the AckNum is not the increment of the last sequence number, a RST + // segment is sent back in response. + AckNum: c.IRS + 2, + } - // Send ACK. - c.SendPacket(nil, ackHeaders) + // Send ACK. + c.SendPacket(nil, ackHeaders) - v = c.GetPacket() - defer v.Release() + v = c.GetPacket() + defer v.Release() - metricPollFn := func() error { - if got := stats.TCP.ResetsSent.Value(); got != want { - return fmt.Errorf("got stats.TCP.ResetsSent.Value() = %d, want = %d", got, want) + metricPollFn := func() error { + if got := stats.TCP.ResetsSent.Value(); got != want { + return fmt.Errorf("got stats.TCP.ResetsSent.Value() = %d, want = %d", got, want) + } + return nil } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) + } + }) } // TestTCPResetsSentNoICMP confirms that we don't get an ICMP DstUnreachable // packet when we try send a packet which is not part of an active session. func TestTCPResetsSentNoICMP(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - stats := c.Stack().Stats() - - // Send a SYN request for a closed port. This should elicit an RST - // but NOT an ICMPv4 DstUnreachable packet. 
- iss := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - }) + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + stats := c.Stack().Stats() + + // Send a SYN request for a closed port. This should elicit an RST + // but NOT an ICMPv4 DstUnreachable packet. + iss := seqnum.Value(context.TestInitialSequenceNumber) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) - // Receive whatever comes back. - v := c.GetPacket() - defer v.Release() - ipHdr := header.IPv4(v.AsSlice()) - if got, want := ipHdr.Protocol(), uint8(header.TCPProtocolNumber); got != want { - t.Errorf("unexpected protocol, got = %d, want = %d", got, want) - } + // Receive whatever comes back. + v := c.GetPacket() + defer v.Release() + ipHdr := header.IPv4(v.AsSlice()) + if got, want := ipHdr.Protocol(), uint8(header.TCPProtocolNumber); got != want { + t.Errorf("unexpected protocol, got = %d, want = %d", got, want) + } - // Read outgoing ICMP stats and check no ICMP DstUnreachable was recorded. - sent := stats.ICMP.V4.PacketsSent - if got, want := sent.DstUnreachable.Value(), uint64(0); got != want { - t.Errorf("got ICMP DstUnreachable.Value() = %d, want = %d", got, want) - } + // Read outgoing ICMP stats and check no ICMP DstUnreachable was recorded. + sent := stats.ICMP.V4.PacketsSent + if got, want := sent.DstUnreachable.Value(), uint64(0); got != want { + t.Errorf("got ICMP DstUnreachable.Value() = %d, want = %d", got, want) + } + }) } // TestTCPResetSentForACKWhenNotUsingSynCookies checks that the stack generates // a RST if an ACK is received on the listening socket for which there is no // active handshake in progress and we are not using SYN cookies. 
func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + // Set TCPLingerTimeout to 5 seconds so that sockets are marked closed + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } - // Set TCPLingerTimeout to 5 seconds so that sockets are marked closed - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } + // Send a SYN request. + iss := seqnum.Value(context.TestInitialSequenceNumber) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) - // Send a SYN request. - iss := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - }) + // Receive the SYN-ACK reply. + v := c.GetPacket() + defer v.Release() + tcpHdr := header.TCP(header.IPv4(v.AsSlice()).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } - // Receive the SYN-ACK reply. 
- v := c.GetPacket() - defer v.Release() - tcpHdr := header.TCP(header.IPv4(v.AsSlice()).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } + // Send ACK. + c.SendPacket(nil, ackHeaders) - // Send ACK. - c.SendPacket(nil, ackHeaders) + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) + wq.EventRegister(&we) + defer wq.EventUnregister(&we) - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) - wq.EventRegister(&we) - defer wq.EventUnregister(&we) + c.EP, _, err = ep.Accept(nil) + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) + } - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") } - } - // Lower stackwide TIME_WAIT timeout so that the reservations - // are released instantly on Close. - tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &tcpTW); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, tcpTW, tcpTW, err) - } + // Lower stackwide TIME_WAIT timeout so that the reservations + // are released instantly on Close. 
+ tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &tcpTW); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, tcpTW, tcpTW, err) + } - c.EP.Close() - b := c.GetPacket() - defer b.Release() - checker.IPv4(t, b, checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - } + c.EP.Close() + b := c.GetPacket() + defer b.Release() + checker.IPv4(t, b, checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } - c.SendPacket(nil, finHeaders) + c.SendPacket(nil, finHeaders) - // Get the ACK to the FIN we just sent. - b = c.GetPacket() - defer b.Release() + // Get the ACK to the FIN we just sent. + b = c.GetPacket() + defer b.Release() - // Since an active close was done we need to wait for a little more than - // tcpLingerTimeout for the port reservations to be released and the - // socket to move to a CLOSED state. - time.Sleep(20 * time.Millisecond) + // Since an active close was done we need to wait for a little more than + // tcpLingerTimeout for the port reservations to be released and the + // socket to move to a CLOSED state. 
+ time.Sleep(20 * time.Millisecond) - // Now resend the same ACK, this ACK should generate a RST as there - // should be no endpoint in SYN-RCVD state and we are not using - // syn-cookies yet. The reason we send the same ACK is we need a valid - // cookie(IRS) generated by the netstack without which the ACK will be - // rejected. - c.SendPacket(nil, ackHeaders) + // Now resend the same ACK, this ACK should generate a RST as there + // should be no endpoint in SYN-RCVD state and we are not using + // syn-cookies yet. The reason we send the same ACK is we need a valid + // cookie(IRS) generated by the netstack without which the ACK will be + // rejected. + c.SendPacket(nil, ackHeaders) - b = c.GetPacket() - defer b.Release() - checker.IPv4(t, b, checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(0), - checker.TCPFlags(header.TCPFlagRst))) + b = c.GetPacket() + defer b.Release() + checker.IPv4(t, b, checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(0), + checker.TCPFlags(header.TCPFlagRst))) + }) } func TestTCPResetsReceivedIncrement(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - stats := c.Stack().Stats() - want := stats.TCP.ResetsReceived.Value() + 1 - iss := seqnum.Value(context.TestInitialSequenceNumber) - rcvWnd := seqnum.Size(30000) - c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) + stats := c.Stack().Stats() + want := stats.TCP.ResetsReceived.Value() + 1 + iss := seqnum.Value(context.TestInitialSequenceNumber) + rcvWnd := seqnum.Size(30000) + c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - SeqNum: iss.Add(1), - AckNum: c.IRS.Add(1), 
- RcvWnd: rcvWnd, - Flags: header.TCPFlagRst, - }) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + SeqNum: iss.Add(1), + AckNum: c.IRS.Add(1), + RcvWnd: rcvWnd, + Flags: header.TCPFlagRst, + }) - if got := stats.TCP.ResetsReceived.Value(); got != want { - t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want) - } + if got := stats.TCP.ResetsReceived.Value(); got != want { + t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want) + } + }) } func TestTCPResetsDoNotGenerateResets(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - stats := c.Stack().Stats() - want := stats.TCP.ResetsReceived.Value() + 1 - iss := seqnum.Value(context.TestInitialSequenceNumber) - rcvWnd := seqnum.Size(30000) - c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) + stats := c.Stack().Stats() + want := stats.TCP.ResetsReceived.Value() + 1 + iss := seqnum.Value(context.TestInitialSequenceNumber) + rcvWnd := seqnum.Size(30000) + c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - SeqNum: iss.Add(1), - AckNum: c.IRS.Add(1), - RcvWnd: rcvWnd, - Flags: header.TCPFlagRst, - }) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + SeqNum: iss.Add(1), + AckNum: c.IRS.Add(1), + RcvWnd: rcvWnd, + Flags: header.TCPFlagRst, + }) - if got := stats.TCP.ResetsReceived.Value(); got != want { - t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want) - } - c.CheckNoPacketTimeout("got an unexpected packet", 100*time.Millisecond) + if got := stats.TCP.ResetsReceived.Value(); got != want { + t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want) + } + c.CheckNoPacketTimeout("got an unexpected packet", 100*time.Millisecond) + }) } 
func TestActiveHandshake(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + }) } func TestNonBlockingClose(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - ep := c.EP - c.EP = nil - - // Close the endpoint and measure how long it takes. - t0 := time.Now() - ep.Close() - if diff := time.Now().Sub(t0); diff > 3*time.Second { - t.Fatalf("Took too long to close: %s", diff) - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + ep := c.EP + c.EP = nil + + // Close the endpoint and measure how long it takes. + t0 := time.Now() + ep.Close() + if diff := time.Now().Sub(t0); diff > 3*time.Second { + t.Fatalf("Took too long to close: %s", diff) + } + }) } func TestConnectResetAfterClose(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + // Set TCPLinger to 3 seconds so that sockets are marked closed + // after 3 second in FIN_WAIT2 state. + tcpLingerTimeout := 3 * time.Second + opt := tcpip.TCPLingerTimeoutOption(tcpLingerTimeout) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) + } - // Set TCPLinger to 3 seconds so that sockets are marked closed - // after 3 second in FIN_WAIT2 state. 
- tcpLingerTimeout := 3 * time.Second - opt := tcpip.TCPLingerTimeoutOption(tcpLingerTimeout) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + ep := c.EP + c.EP = nil - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - ep := c.EP - c.EP = nil + // Close the endpoint, make sure we get a FIN segment, then acknowledge + // to complete closure of sender, but don't send our own FIN. + ep.Close() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + b := c.GetPacket() + defer b.Release() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) - // Close the endpoint, make sure we get a FIN segment, then acknowledge - // to complete closure of sender, but don't send our own FIN. - ep.Close() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - b := c.GetPacket() - defer b.Release() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) + // Wait for the ep to give up waiting for a FIN. + time.Sleep(tcpLingerTimeout + 1*time.Second) - // Wait for the ep to give up waiting for a FIN. 
- time.Sleep(tcpLingerTimeout + 1*time.Second) + // Now send an ACK and it should trigger a RST as the endpoint should + // not exist anymore. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) - // Now send an ACK and it should trigger a RST as the endpoint should - // not exist anymore. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) + for { + v := c.GetPacket() + defer v.Release() + tcpHdr := header.TCP(header.IPv4(v.AsSlice()).Payload()) + if tcpHdr.Flags() == header.TCPFlagAck|header.TCPFlagFin { + // This is a retransmit of the FIN, ignore it. + continue + } - for { - v := c.GetPacket() - defer v.Release() - tcpHdr := header.TCP(header.IPv4(v.AsSlice()).Payload()) - if tcpHdr.Flags() == header.TCPFlagAck|header.TCPFlagFin { - // This is a retransmit of the FIN, ignore it. - continue + checker.IPv4(t, v, + checker.TCP( + checker.DstPort(context.TestPort), + // RST is always generated with sndNxt which if the FIN + // has been sent will be 1 higher than the sequence number + // of the FIN itself. + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(0), + checker.TCPFlags(header.TCPFlagRst), + ), + ) + break } - - checker.IPv4(t, v, - checker.TCP( - checker.DstPort(context.TestPort), - // RST is always generated with sndNxt which if the FIN - // has been sent will be 1 higher than the sequence number - // of the FIN itself. - checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(0), - checker.TCPFlags(header.TCPFlagRst), - ), - ) - break - } + }) } // TestCurrentConnectedIncrement tests increment of the current // established and connected counters. 
func TestCurrentConnectedIncrement(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - - // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed - // after 1 second in TIME_WAIT state. - tcpTimeWaitTimeout := 1 * time.Second - opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed + // after 1 second in TIME_WAIT state. + tcpTimeWaitTimeout := 1 * time.Second + opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) + } - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - ep := c.EP - c.EP = nil + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + ep := c.EP + c.EP = nil - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 1 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 1", got) - } - gotConnected := c.Stack().Stats().TCP.CurrentConnected.Value() - if gotConnected != 1 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 1", gotConnected) - } + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 1 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 1", got) + } + gotConnected := c.Stack().Stats().TCP.CurrentConnected.Value() + if gotConnected != 1 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 1", gotConnected) + } - ep.Close() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - b := 
c.GetPacket() - defer b.Release() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) + ep.Close() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + b := c.GetPacket() + defer b.Release() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) - } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != gotConnected { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = %d", got, gotConnected) - } + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) + } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != gotConnected { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = %d", got, gotConnected) + } - // Ack and send FIN as well. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) + // Ack and send FIN as well. 
+ c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) - // Check that the stack acks the FIN. - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagAck), - ), - ) + // Check that the stack acks the FIN. + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagAck), + ), + ) - // Wait for a little more than the TIME-WAIT duration for the socket to - // transition to CLOSED state. - time.Sleep(1200 * time.Millisecond) + // Wait for a little more than the TIME-WAIT duration for the socket to + // transition to CLOSED state. + time.Sleep(1200 * time.Millisecond) - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) - } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) - } + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) + } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) + } + }) } // TestClosingWithEnqueuedSegments tests handling of still enqueued segments // when the endpoint transitions to StateClose. The in-flight segments would be // re-enqueued to a any listening endpoint. 
func TestClosingWithEnqueuedSegments(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - ep := c.EP - c.EP = nil + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + ep := c.EP + c.EP = nil - if got, want := tcp.EndpointState(ep.State()), tcp.StateEstablished; got != want { - t.Errorf("unexpected endpoint state: want %d, got %d", want, got) - } + if got, want := tcp.EndpointState(ep.State()), tcp.StateEstablished; got != want { + t.Errorf("unexpected endpoint state: want %d, got %d", want, got) + } - // Send a FIN for ESTABLISHED --> CLOSED-WAIT - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagFin | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) + // Send a FIN for ESTABLISHED --> CLOSED-WAIT + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagFin | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) - // Get the ACK for the FIN we sent. - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagAck), - ), - ) + // Get the ACK for the FIN we sent. 
+ v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagAck), + ), + ) - // Give the stack a few ms to transition the endpoint out of ESTABLISHED - // state. - time.Sleep(10 * time.Millisecond) + // Give the stack a few ms to transition the endpoint out of ESTABLISHED + // state. + time.Sleep(10 * time.Millisecond) - if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want { - t.Errorf("unexpected endpoint state: want %d, got %d", want, got) - } + if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want { + t.Errorf("unexpected endpoint state: want %d, got %d", want, got) + } - // Close the application endpoint for CLOSE_WAIT --> LAST_ACK - ep.Close() + // Close the application endpoint for CLOSE_WAIT --> LAST_ACK + ep.Close() - // Get the FIN - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) + // Get the FIN + v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) - if got, want := tcp.EndpointState(ep.State()), tcp.StateLastAck; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } + if got, want := tcp.EndpointState(ep.State()), tcp.StateLastAck; got != want { + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) + } - // Pause the endpoint. - ep.(interface{ StopWork() }).StopWork() + // Pause the endpoint. 
+ ep.(interface{ StopWork() }).StopWork() - // Enqueue last ACK followed by an ACK matching the endpoint - // - // Send Last ACK for LAST_ACK --> CLOSED - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(1), - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) + // Enqueue last ACK followed by an ACK matching the endpoint + // + // Send Last ACK for LAST_ACK --> CLOSED + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss.Add(1), + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) - // Send a packet with ACK set, this would generate RST when - // not using SYN cookies as in this test. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss.Add(2), - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) + // Send a packet with ACK set, this would generate RST when + // not using SYN cookies as in this test. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss.Add(2), + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) - // Unpause endpoint. - ep.(interface{ ResumeWork() }).ResumeWork() + // Unpause endpoint. + ep.(interface{ ResumeWork() }).ResumeWork() - // Wait for the endpoint to resume and update state. - time.Sleep(10 * time.Millisecond) + // Wait for the endpoint to resume and update state. + time.Sleep(10 * time.Millisecond) - // Expect the endpoint to be closed. - if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } + // Expect the endpoint to be closed. 
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) + } - if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != 1 { - t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = 1", got) - } + if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != 1 { + t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = 1", got) + } - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) - } + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) + } - // Check if the endpoint was moved to CLOSED and netstack sent a reset in - // response to the ACK packet that we sent after last-ACK. - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(0), - checker.TCPFlags(header.TCPFlagRst), - ), - ) + // Check if the endpoint was moved to CLOSED and netstack sent a reset in + // response to the ACK packet that we sent after last-ACK. 
+ v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(0), + checker.TCPFlags(header.TCPFlagRst), + ), + ) + }) } func TestSimpleReceive(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) + c.WQ.EventRegister(&we) + defer c.WQ.EventUnregister(&we) - we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) - c.WQ.EventRegister(&we) - defer c.WQ.EventUnregister(&we) + ept := endpointTester{c.EP} - ept := endpointTester{c.EP} + data := []byte{1, 2, 3} + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) - data := []byte{1, 2, 3} - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } + // Wait for receive to be notified. + select { + case <-ch: + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } - // Receive data. - v := ept.CheckRead(t) - if !bytes.Equal(data, v) { - t.Fatalf("got data = %v, want = %v", v, data) - } + // Receive data. 
+ v := ept.CheckRead(t) + if !bytes.Equal(data, v) { + t.Fatalf("got data = %v, want = %v", v, data) + } - // Check that ACK is received. - b := c.GetPacket() - defer b.Release() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) + // Check that ACK is received. + b := c.GetPacket() + defer b.Release() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+uint32(len(data))), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + }) } // TestUserSuppliedMSSOnConnect tests that the user supplied MSS is used when // creating a new active TCP socket. It should be present in the sent TCP // SYN segment. func TestUserSuppliedMSSOnConnect(t *testing.T) { - const mtu = 5000 - - ips := []struct { - name string - createEP func(*context.Context) - connectAddr tcpip.Address - checker func(*testing.T, *context.Context, uint16, int) - maxMSS uint16 - }{ - { - name: "IPv4", - createEP: func(c *context.Context) { - c.Create(-1) - }, - connectAddr: context.TestAddr, - checker: func(t *testing.T, c *context.Context, mss uint16, ws int) { - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws}))) - }, - maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, - }, - { - name: "IPv6", - createEP: func(c *context.Context) { - c.CreateV6Endpoint(true) - }, - connectAddr: context.TestV6Addr, - checker: func(t *testing.T, c *context.Context, mss uint16, ws int) { - v := c.GetV6Packet() - defer v.Release() - checker.IPv6(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws}))) - }, - 
maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize, - }, - } - - for _, ip := range ips { - t.Run(ip.name, func(t *testing.T) { - tests := []struct { - name string - setMSS uint16 - expMSS uint16 - }{ - { - name: "EqualToMaxMSS", - setMSS: ip.maxMSS, - expMSS: ip.maxMSS, + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + const mtu = 5000 + + ips := []struct { + name string + createEP func(*context.Context) + connectAddr tcpip.Address + checker func(*testing.T, *context.Context, uint16, int) + maxMSS uint16 + }{ + { + name: "IPv4", + createEP: func(c *context.Context) { + c.Create(-1) }, - { - name: "LessThanMaxMSS", - setMSS: ip.maxMSS - 1, - expMSS: ip.maxMSS - 1, + connectAddr: context.TestAddr, + checker: func(t *testing.T, c *context.Context, mss uint16, ws int) { + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws}))) }, - { - name: "GreaterThanMaxMSS", - setMSS: ip.maxMSS + 1, - expMSS: ip.maxMSS, + maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, + }, + { + name: "IPv6", + createEP: func(c *context.Context) { + c.CreateV6Endpoint(true) }, - } + connectAddr: context.TestV6Addr, + checker: func(t *testing.T, c *context.Context, mss uint16, ws int) { + v := c.GetV6Packet() + defer v.Release() + checker.IPv6(t, v, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws}))) + }, + maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize, + }, + } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, mtu) - defer c.Cleanup() + for _, ip := range ips { + t.Run(ip.name, func(t *testing.T) { + tests := []struct { + name string + setMSS uint16 + expMSS uint16 + }{ + { + name: "EqualToMaxMSS", + setMSS: ip.maxMSS, + expMSS: ip.maxMSS, + 
}, + { + name: "LessThanMaxMSS", + setMSS: ip.maxMSS - 1, + expMSS: ip.maxMSS - 1, + }, + { + name: "GreaterThanMaxMSS", + setMSS: ip.maxMSS + 1, + expMSS: ip.maxMSS, + }, + } - ip.createEP(c) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, mtu) + defer c.Cleanup() - // Set the MSS socket option. - if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { - t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err) - } + ip.createEP(c) - // Get expected window size. - rcvBufSize := c.EP.SocketOptions().GetReceiveBufferSize() - ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) + // Set the MSS socket option. + if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { + t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err) + } - connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort} - { - err := c.EP.Connect(connectAddr) - if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { - t.Fatalf("Connect(%+v) mismatch (-want +got):\n%s", connectAddr, d) + // Get expected window size. + rcvBufSize := c.EP.SocketOptions().GetReceiveBufferSize() + ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) + + connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort} + { + err := c.EP.Connect(connectAddr) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("Connect(%+v) mismatch (-want +got):\n%s", connectAddr, d) + } } - } - // Receive SYN packet with our user supplied MSS. - ip.checker(t, c, test.expMSS, ws) - }) - } - }) - } + // Receive SYN packet with our user supplied MSS. + ip.checker(t, c, test.expMSS, ws) + }) + } + }) + } + }) } // TestUserSuppliedMSSOnListenAccept tests that the user supplied MSS is used // when completing the handshake for a new TCP connection from a TCP // listening socket. It should be present in the sent TCP SYN-ACK segment. 
func TestUserSuppliedMSSOnListenAccept(t *testing.T) { - const mtu = 5000 - - ips := []struct { - name string - createEP func(*context.Context) - sendPkt func(*context.Context, *context.Headers) - checker func(*testing.T, *context.Context, uint16, uint16) - maxMSS uint16 - }{ - { - name: "IPv4", - createEP: func(c *context.Context) { - c.Create(-1) - }, - sendPkt: func(c *context.Context, h *context.Headers) { - c.SendPacket(nil, h) - }, - checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) { - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.DstPort(srcPort), - checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck), - checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1}))) - }, - maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, - }, - { - name: "IPv6", - createEP: func(c *context.Context) { - c.CreateV6Endpoint(false) - }, - sendPkt: func(c *context.Context, h *context.Headers) { - c.SendV6Packet(nil, h) - }, - checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) { - v := c.GetV6Packet() - defer v.Release() - checker.IPv6(t, v, checker.TCP( - checker.DstPort(srcPort), - checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck), - checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1}))) + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + const mtu = 5000 + + ips := []struct { + name string + createEP func(*context.Context) + sendPkt func(*context.Context, *context.Headers) + checker func(*testing.T, *context.Context, uint16, uint16) + maxMSS uint16 + }{ + { + name: "IPv4", + createEP: func(c *context.Context) { + c.Create(-1) + }, + sendPkt: func(c *context.Context, h *context.Headers) { + c.SendPacket(nil, h) + }, + checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) { + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.DstPort(srcPort), + checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck), + 
checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1}))) + }, + maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, }, - maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize, - }, - } - - for _, ip := range ips { - t.Run(ip.name, func(t *testing.T) { - tests := []struct { - name string - setMSS uint16 - expMSS uint16 - }{ - { - name: "EqualToMaxMSS", - setMSS: ip.maxMSS, - expMSS: ip.maxMSS, + { + name: "IPv6", + createEP: func(c *context.Context) { + c.CreateV6Endpoint(false) }, - { - name: "LessThanMaxMSS", - setMSS: ip.maxMSS - 1, - expMSS: ip.maxMSS - 1, + sendPkt: func(c *context.Context, h *context.Headers) { + c.SendV6Packet(nil, h) }, - { - name: "GreaterThanMaxMSS", - setMSS: ip.maxMSS + 1, - expMSS: ip.maxMSS, + checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) { + v := c.GetV6Packet() + defer v.Release() + checker.IPv6(t, v, checker.TCP( + checker.DstPort(srcPort), + checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck), + checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1}))) }, - } + maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize, + }, + } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, mtu) - defer c.Cleanup() + for _, ip := range ips { + t.Run(ip.name, func(t *testing.T) { + tests := []struct { + name string + setMSS uint16 + expMSS uint16 + }{ + { + name: "EqualToMaxMSS", + setMSS: ip.maxMSS, + expMSS: ip.maxMSS, + }, + { + name: "LessThanMaxMSS", + setMSS: ip.maxMSS - 1, + expMSS: ip.maxMSS - 1, + }, + { + name: "GreaterThanMaxMSS", + setMSS: ip.maxMSS + 1, + expMSS: ip.maxMSS, + }, + } - ip.createEP(c) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, mtu) + defer c.Cleanup() - if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { - t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err) - } + ip.createEP(c) - bindAddr := tcpip.FullAddress{Port: 
context.StackPort} - if err := c.EP.Bind(bindAddr); err != nil { - t.Fatalf("Bind(%+v): %s:", bindAddr, err) - } + if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { + t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err) + } - backlog := 5 - // Keep the number of client requests twice to the backlog - // such that half of the connections do not use syncookies - // and the other half does. - clientConnects := backlog * 2 + bindAddr := tcpip.FullAddress{Port: context.StackPort} + if err := c.EP.Bind(bindAddr); err != nil { + t.Fatalf("Bind(%+v): %s:", bindAddr, err) + } - if err := c.EP.Listen(backlog); err != nil { - t.Fatalf("Listen(%d): %s:", backlog, err) - } + backlog := 5 + // Keep the number of client requests twice to the backlog + // such that half of the connections do not use syncookies + // and the other half does. + clientConnects := backlog * 2 - for i := 0; i < clientConnects; i++ { - // Send a SYN requests. - iss := seqnum.Value(i) - srcPort := context.TestPort + uint16(i) - ip.sendPkt(c, &context.Headers{ - SrcPort: srcPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - }) + if err := c.EP.Listen(backlog); err != nil { + t.Fatalf("Listen(%d): %s:", backlog, err) + } - // Receive the SYN-ACK reply. - ip.checker(t, c, srcPort, test.expMSS) - } - }) - } - }) - } + for i := 0; i < clientConnects; i++ { + // Send a SYN requests. + iss := seqnum.Value(i) + srcPort := context.TestPort + uint16(i) + ip.sendPkt(c, &context.Headers{ + SrcPort: srcPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. 
+ ip.checker(t, c, srcPort, test.expMSS) + } + }) + } + }) + } + }) } + func TestSendRstOnListenerRxSynAckV4(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.Create(-1) + c.Create(-1) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: 100, - AckNum: 200, - }) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: 100, + AckNum: 200, + }) - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst), - checker.TCPSeqNum(200))) + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst), + checker.TCPSeqNum(200))) + }) } func TestSendRstOnListenerRxSynAckV6(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateV6Endpoint(true) + c.CreateV6Endpoint(true) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } - if err := c.EP.Listen(10); err != 
nil { - t.Fatal("Listen failed:", err) - } + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: 100, - AckNum: 200, - }) + c.SendV6Packet(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: 100, + AckNum: 200, + }) - v := c.GetV6Packet() - defer v.Release() - checker.IPv6(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst), - checker.TCPSeqNum(200))) + v := c.GetV6Packet() + defer v.Release() + checker.IPv6(t, v, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst), + checker.TCPSeqNum(200))) + }) } // TestNoSynCookieWithoutOverflow tests that SYN-COOKIEs are not issued when the @@ -1259,53 +1330,56 @@ func TestSendRstOnListenerRxSynAckV6(t *testing.T) { // accepted we do not see a SYN-COOKIE even > 2x listen backlog number of connections // are accepted. func TestNoSynCookieWithoutOverflow(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.Create(-1) + c.Create(-1) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } - const backlog = 10 - if err := c.EP.Listen(backlog); err != nil { - t.Fatal("Listen failed:", err) - } + const backlog = 10 + if err := c.EP.Listen(backlog); err != nil { + t.Fatal("Listen failed:", err) + } - doOne := func(portIndex int) { - // Try to accept the connection. 
- we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) - c.WQ.EventRegister(&we) - defer c.WQ.EventUnregister(&we) + doOne := func(portIndex int) { + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) + c.WQ.EventRegister(&we) + defer c.WQ.EventUnregister(&we) - _, _ = executeHandshake(t, c, context.TestPort+uint16(portIndex), false /* synCookiesInUse */) + _, _ = executeHandshake(t, c, context.TestPort+uint16(portIndex), false /* synCookiesInUse */) - _, _, err := c.EP.Accept(nil) - if err == nil { - return - } - switch { - case cmp.Equal(&tcpip.ErrWouldBlock{}, err): - { - select { - case <-ch: - _, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) + _, _, err := c.EP.Accept(nil) + if err == nil { + return + } + switch { + case cmp.Equal(&tcpip.ErrWouldBlock{}, err): + { + select { + case <-ch: + _, _, err = c.EP.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") } - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") } + default: + t.Fatalf("Accept failed: %s", err) } - default: - t.Fatalf("Accept failed: %s", err) } - } - for i := 0; i < backlog*5; i++ { - doOne(i) - } + for i := 0; i < backlog*5; i++ { + doOne(i) + } + }) } // TestNoSynCookieOnFailedHandshakes tests that failed handshakes clear @@ -1317,66 +1391,69 @@ func TestNoSynCookieWithoutOverflow(t *testing.T) { // list for the accepting endpoint then it will eventually result in a // SYN-COOKIE which we can identify with a SYN-ACK w/ a WS of -1. 
func TestNoSynCookieOnFailedHandshakes(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.Create(-1) + c.Create(-1) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } - const backlog = 10 - if err := c.EP.Listen(backlog); err != nil { - t.Fatal("Listen failed:", err) - } + const backlog = 10 + if err := c.EP.Listen(backlog); err != nil { + t.Fatal("Listen failed:", err) + } - doOne := func() { - // Send a SYN request. - options := []byte{header.TCPOptionWS, 3, 0, header.TCPOptionNOP} - irs := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - TCPOpts: options, - }) + doOne := func() { + // Send a SYN request. + options := []byte{header.TCPOptionWS, 3, 0, header.TCPOptionNOP} + irs := seqnum.Value(context.TestInitialSequenceNumber) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + TCPOpts: options, + }) - // Receive the SYN-ACK reply. - v := c.GetPacket() - defer v.Release() - tcpHdr := header.TCP(header.IPv4(v.AsSlice()).Payload()) - iss := seqnum.Value(tcpHdr.SequenceNumber()) - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs) + 1), - checker.TCPSynOptions(header.TCPSynOptions{ - WS: tcp.FindWndScale(tcp.DefaultReceiveBufferSize), - MSS: c.MSSWithoutOptions(), - }), - } + // Receive the SYN-ACK reply. 
+ v := c.GetPacket() + defer v.Release() + tcpHdr := header.TCP(header.IPv4(v.AsSlice()).Payload()) + iss := seqnum.Value(tcpHdr.SequenceNumber()) + tcpCheckers := []checker.TransportChecker{ + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), + checker.TCPAckNum(uint32(irs) + 1), + checker.TCPSynOptions(header.TCPSynOptions{ + WS: tcp.FindWndScale(tcp.DefaultReceiveBufferSize), + MSS: c.MSSWithoutOptions(), + }), + } - checker.IPv4(t, v, checker.TCP(tcpCheckers...)) + checker.IPv4(t, v, checker.TCP(tcpCheckers...)) - // Send a RST to abort the handshake. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagRst, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 0, - }) + // Send a RST to abort the handshake. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagRst, + SeqNum: irs + 1, + AckNum: iss + 1, + RcvWnd: 0, + }) - } + } - for i := 0; i < backlog*5; i++ { - doOne() - } + for i := 0; i < backlog*5; i++ { + doOne() + } + }) } // TestTCPAckBeforeAcceptV4 tests that once the 3-way handshake is complete, @@ -1385,46 +1462,49 @@ func TestNoSynCookieOnFailedHandshakes(t *testing.T) { // // This test uses IPv4. 
func TestTCPAckBeforeAcceptV4(t *testing.T) { - for _, cookieEnabled := range []tcpip.TCPAlwaysUseSynCookies{false, true} { - t.Run(fmt.Sprintf("syn-cookies enabled: %t", cookieEnabled), func(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + for _, cookieEnabled := range []tcpip.TCPAlwaysUseSynCookies{false, true} { + t.Run(fmt.Sprintf("syn-cookies enabled: %t", cookieEnabled), func(t *testing.T) { + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - if err := c.Stack().SetTransportProtocolOption(header.TCPProtocolNumber, &cookieEnabled); err != nil { - panic(fmt.Sprintf("SetTransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, cookieEnabled, err)) - } + if err := c.Stack().SetTransportProtocolOption(header.TCPProtocolNumber, &cookieEnabled); err != nil { + panic(fmt.Sprintf("SetTransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, cookieEnabled, err)) + } + + c.Create(-1) - c.Create(-1) + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } + irs, iss := executeHandshake(t, c, context.TestPort, bool(cookieEnabled)) - irs, iss := executeHandshake(t, c, context.TestPort, bool(cookieEnabled)) + // Send data before accepting the connection. + c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + }) - // Send data before accepting the connection. 
- c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, + // Receive ACK for the data we sent. + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) }) - - // Receive ACK for the data we sent. - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(iss+1)), - checker.TCPAckNum(uint32(irs+5)))) - }) - } + } + }) } // TestTCPAckBeforeAcceptV6 tests that once the 3-way handshake is complete, @@ -1433,684 +1513,747 @@ func TestTCPAckBeforeAcceptV4(t *testing.T) { // // This test uses IPv6. func TestTCPAckBeforeAcceptV6(t *testing.T) { - for _, cookieEnabled := range []tcpip.TCPAlwaysUseSynCookies{false, true} { - t.Run(fmt.Sprintf("syn-cookies enabled: %t", cookieEnabled), func(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + for _, cookieEnabled := range []tcpip.TCPAlwaysUseSynCookies{false, true} { + t.Run(fmt.Sprintf("syn-cookies enabled: %t", cookieEnabled), func(t *testing.T) { + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - if err := c.Stack().SetTransportProtocolOption(header.TCPProtocolNumber, &cookieEnabled); err != nil { - panic(fmt.Sprintf("SetTransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, cookieEnabled, err)) - } - c.CreateV6Endpoint(true) + if err := c.Stack().SetTransportProtocolOption(header.TCPProtocolNumber, &cookieEnabled); err != nil { + panic(fmt.Sprintf("SetTransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, cookieEnabled, err)) + } + c.CreateV6Endpoint(true) - if err := c.EP.Bind(tcpip.FullAddress{Port: 
context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } - irs, iss := executeV6Handshake(t, c, context.TestPort, bool(cookieEnabled)) + irs, iss := executeV6Handshake(t, c, context.TestPort, bool(cookieEnabled)) - // Send data before accepting the connection. - c.SendV6Packet([]byte{1, 2, 3, 4}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - }) + // Send data before accepting the connection. + c.SendV6Packet([]byte{1, 2, 3, 4}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + }) - // Receive ACK for the data we sent. - v := c.GetV6Packet() - defer v.Release() - checker.IPv6(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(iss+1)), - checker.TCPAckNum(uint32(irs+5)))) - }) - } + // Receive ACK for the data we sent. 
+ v := c.GetV6Packet() + defer v.Release() + checker.IPv6(t, v, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) + }) + } + }) } func TestSendRstOnListenerRxAckV4(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.Create(-1 /* epRcvBuf */) + c.Create(-1 /* epRcvBuf */) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } - if err := c.EP.Listen(10 /* backlog */); err != nil { - t.Fatal("Listen failed:", err) - } + if err := c.EP.Listen(10 /* backlog */); err != nil { + t.Fatal("Listen failed:", err) + } - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagFin | header.TCPFlagAck, - SeqNum: 100, - AckNum: 200, - }) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagFin | header.TCPFlagAck, + SeqNum: 100, + AckNum: 200, + }) - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst), - checker.TCPSeqNum(200))) + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst), + checker.TCPSeqNum(200))) + }) } func TestSendRstOnListenerRxAckV6(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateV6Endpoint(true /* v6Only */) + c.CreateV6Endpoint(true /* v6Only */) - if err := 
c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } - if err := c.EP.Listen(10 /* backlog */); err != nil { - t.Fatal("Listen failed:", err) - } + if err := c.EP.Listen(10 /* backlog */); err != nil { + t.Fatal("Listen failed:", err) + } - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagFin | header.TCPFlagAck, - SeqNum: 100, - AckNum: 200, - }) + c.SendV6Packet(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagFin | header.TCPFlagAck, + SeqNum: 100, + AckNum: 200, + }) - v := c.GetV6Packet() - defer v.Release() - checker.IPv6(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst), - checker.TCPSeqNum(200))) + v := c.GetV6Packet() + defer v.Release() + checker.IPv6(t, v, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst), + checker.TCPSeqNum(200))) + }) } // TestListenShutdown tests for the listening endpoint replying with RST // on read shutdown. 
func TestListenShutdown(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.Create(-1 /* epRcvBuf */) + c.Create(-1 /* epRcvBuf */) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } - if err := c.EP.Listen(1 /* backlog */); err != nil { - t.Fatal("Listen failed:", err) - } + if err := c.EP.Listen(1 /* backlog */); err != nil { + t.Fatal("Listen failed:", err) + } - if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { - t.Fatal("Shutdown failed:", err) - } + if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { + t.Fatal("Shutdown failed:", err) + } - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: 100, - AckNum: 200, - }) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: 100, + AckNum: 200, + }) - // Expect the listening endpoint to reset the connection. + // Expect the listening endpoint to reset the connection. 
- v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - )) + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + )) + }) } func TestListenerReadinessOnEvent(t *testing.T) { - s := stack.New(stack.Options{ - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - }) - defer s.Destroy() - { - ep := loopback.New() - if testing.Verbose() { - ep = sniffer.New(ep) + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + s := stack.New(stack.Options{ + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + }) + defer s.Destroy() + { + ep := loopback.New() + if testing.Verbose() { + ep = sniffer.New(ep) + } + const id = 1 + if err := s.CreateNIC(id, ep); err != nil { + t.Fatalf("CreateNIC(%d, %T): %s", id, ep, err) + } + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.Address(context.StackAddr).WithPrefix(), + } + if err := s.AddProtocolAddress(id, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", id, protocolAddr, err) + } + s.SetRouteTable([]tcpip.Route{ + {Destination: header.IPv4EmptySubnet, NIC: id}, + }) } - const id = 1 - if err := s.CreateNIC(id, ep); err != nil { - t.Fatalf("CreateNIC(%d, %T): %s", id, ep, err) + + var wq waiter.Queue + ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, _): %s", err) } - protocolAddr := tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.Address(context.StackAddr).WithPrefix(), + defer 
ep.Close() + + if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr}); err != nil { + t.Fatalf("Bind(%s): %s", context.StackAddr, err) } - if err := s.AddProtocolAddress(id, protocolAddr, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", id, protocolAddr, err) + const backlog = 1 + if err := ep.Listen(backlog); err != nil { + t.Fatalf("Listen(%d): %s", backlog, err) } - s.SetRouteTable([]tcpip.Route{ - {Destination: header.IPv4EmptySubnet, NIC: id}, - }) - } - var wq waiter.Queue - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, _): %s", err) - } - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr}); err != nil { - t.Fatalf("Bind(%s): %s", context.StackAddr, err) - } - const backlog = 1 - if err := ep.Listen(backlog); err != nil { - t.Fatalf("Listen(%d): %s", backlog, err) - } + address, err := ep.GetLocalAddress() + if err != nil { + t.Fatalf("GetLocalAddress(): %s", err) + } - address, err := ep.GetLocalAddress() - if err != nil { - t.Fatalf("GetLocalAddress(): %s", err) - } + conn, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, _): %s", err) + } + defer conn.Close() - conn, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, _): %s", err) - } - defer conn.Close() + events := make(chan waiter.EventMask) + // Scope `entry` to allow a binding of the same name below. + { + entry := waiter.NewFunctionEntry(waiter.EventIn, func(mask waiter.EventMask) { + events <- ep.Readiness(mask) + }) + wq.EventRegister(&entry) + defer wq.EventUnregister(&entry) + } - events := make(chan waiter.EventMask) - // Scope `entry` to allow a binding of the same name below. 
- { - entry := waiter.NewFunctionEntry(waiter.EventIn, func(mask waiter.EventMask) { - events <- ep.Readiness(mask) - }) + entry, ch := waiter.NewChannelEntry(waiter.EventOut) wq.EventRegister(&entry) defer wq.EventUnregister(&entry) - } - entry, ch := waiter.NewChannelEntry(waiter.EventOut) - wq.EventRegister(&entry) - defer wq.EventUnregister(&entry) - - switch err := conn.Connect(address).(type) { - case *tcpip.ErrConnectStarted: - default: - t.Fatalf("Connect(%#v): %v", address, err) - } + switch err := conn.Connect(address).(type) { + case *tcpip.ErrConnectStarted: + default: + t.Fatalf("Connect(%#v): %v", address, err) + } - // Read at least one event. - got := <-events - for { - select { - case event := <-events: - got |= event - continue - case <-ch: - if want := waiter.ReadableEvents; got != want { - t.Errorf("observed events = %b, want %b", got, want) + // Read at least one event. + got := <-events + for { + select { + case event := <-events: + got |= event + continue + case <-ch: + if want := waiter.ReadableEvents; got != want { + t.Errorf("observed events = %b, want %b", got, want) + } } + break } - break - } + }) } // TestListenCloseWhileConnect tests for the listening endpoint to // drain the accept-queue when closed. This should reset all of the // pending connections that are waiting to be accepted. 
func TestListenCloseWhileConnect(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.Create(-1 /* epRcvBuf */) + c.Create(-1 /* epRcvBuf */) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } - if err := c.EP.Listen(1 /* backlog */); err != nil { - t.Fatal("Listen failed:", err) - } + if err := c.EP.Listen(1 /* backlog */); err != nil { + t.Fatal("Listen failed:", err) + } - waitEntry, notifyCh := waiter.NewChannelEntry(waiter.ReadableEvents) - c.WQ.EventRegister(&waitEntry) - defer c.WQ.EventUnregister(&waitEntry) + waitEntry, notifyCh := waiter.NewChannelEntry(waiter.ReadableEvents) + c.WQ.EventRegister(&waitEntry) + defer c.WQ.EventUnregister(&waitEntry) - executeHandshake(t, c, context.TestPort, true /* synCookiesInUse */) - // Wait for the new endpoint created because of handshake to be delivered - // to the listening endpoint's accept queue. - <-notifyCh + executeHandshake(t, c, context.TestPort, true /* synCookiesInUse */) + // Wait for the new endpoint created because of handshake to be delivered + // to the listening endpoint's accept queue. + <-notifyCh - // Close the listening endpoint. - c.EP.Close() + // Close the listening endpoint. + c.EP.Close() - // Expect the listening endpoint to reset the connection. - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - )) + // Expect the listening endpoint to reset the connection. 
+ v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + )) + }) } func TestTOSV4(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - c.EP = ep + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + c.EP = ep - const tos = 0xC0 - if err := c.EP.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil { - t.Errorf("SetSockOptInt(IPv4TOSOption, %d) failed: %s", tos, err) - } + const tos = 0xC0 + if err := c.EP.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil { + t.Errorf("SetSockOptInt(IPv4TOSOption, %d) failed: %s", tos, err) + } - v, err := c.EP.GetSockOptInt(tcpip.IPv4TOSOption) - if err != nil { - t.Errorf("GetSockoptInt(IPv4TOSOption) failed: %s", err) - } + v, err := c.EP.GetSockOptInt(tcpip.IPv4TOSOption) + if err != nil { + t.Errorf("GetSockoptInt(IPv4TOSOption) failed: %s", err) + } - if v != tos { - t.Errorf("got GetSockOptInt(IPv4TOSOption) = %d, want = %d", v, tos) - } + if v != tos { + t.Errorf("got GetSockOptInt(IPv4TOSOption) = %d, want = %d", v, tos) + } - e2e.TestV4Connect(t, c, checker.TOS(tos, 0)) + e2e.TestV4Connect(t, c, checker.TOS(tos, 0)) - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } + data := []byte{1, 2, 3} + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } - // Check that data is received. 
- p := c.GetPacket() - defer p.Release() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, p, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), // Acknum is initial sequence number + 1 - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - checker.TOS(tos, 0), - ) + // Check that data is received. + p := c.GetPacket() + defer p.Release() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + checker.IPv4(t, p, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)), // Acknum is initial sequence number + 1 + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), + ), + checker.TOS(tos, 0), + ) - if b := p.AsSlice()[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, b) { - t.Errorf("got data = %x, want = %x", p.AsSlice(), data) - } + if b := p.AsSlice()[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, b) { + t.Errorf("got data = %x, want = %x", p.AsSlice(), data) + } + }) } func TestTrafficClassV6(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateV6Endpoint(false) + c.CreateV6Endpoint(false) - const tos = 0xC0 - if err := c.EP.SetSockOptInt(tcpip.IPv6TrafficClassOption, tos); err != nil { - t.Errorf("SetSockOpInt(IPv6TrafficClassOption, %d) failed: %s", tos, err) - } + const tos = 0xC0 + if err := c.EP.SetSockOptInt(tcpip.IPv6TrafficClassOption, tos); err != nil { + t.Errorf("SetSockOpInt(IPv6TrafficClassOption, %d) failed: %s", tos, err) + } - v, err := c.EP.GetSockOptInt(tcpip.IPv6TrafficClassOption) - if err != nil { - t.Fatalf("GetSockoptInt(IPv6TrafficClassOption) 
failed: %s", err) - } + v, err := c.EP.GetSockOptInt(tcpip.IPv6TrafficClassOption) + if err != nil { + t.Fatalf("GetSockoptInt(IPv6TrafficClassOption) failed: %s", err) + } - if v != tos { - t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = %d, want = %d", v, tos) - } + if v != tos { + t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = %d, want = %d", v, tos) + } - // Test the connection request. - e2e.TestV6Connect(t, c, checker.TOS(tos, 0)) + // Test the connection request. + e2e.TestV6Connect(t, c, checker.TOS(tos, 0)) - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } + data := []byte{1, 2, 3} + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } - // Check that data is received. - b := c.GetV6Packet() - defer b.Release() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv6(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - checker.TOS(tos, 0), - ) + // Check that data is received. 
+ b := c.GetV6Packet() + defer b.Release() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + checker.IPv6(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), + ), + checker.TOS(tos, 0), + ) - if p := b.AsSlice()[header.IPv6MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { - t.Errorf("got data = %x, want = %x", p, data) - } + if p := b.AsSlice()[header.IPv6MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { + t.Errorf("got data = %x, want = %x", p, data) + } + }) } func TestConnectBindToDevice(t *testing.T) { - for _, test := range []struct { - name string - device tcpip.NICID - want tcp.EndpointState - }{ - {"RightDevice", 1, tcp.StateEstablished}, - {"WrongDevice", 2, tcp.StateSynSent}, - {"AnyDevice", 0, tcp.StateEstablished}, - } { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + for _, test := range []struct { + name string + device tcpip.NICID + want tcp.EndpointState + }{ + {"RightDevice", 1, tcp.StateEstablished}, + {"WrongDevice", 2, tcp.StateSynSent}, + {"AnyDevice", 0, tcp.StateEstablished}, + } { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.Create(-1) - if err := c.EP.SocketOptions().SetBindToDevice(int32(test.device)); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%d)): %s", test.device, test.device, err) - } - // Start connection attempt. 
- waitEntry, _ := waiter.NewChannelEntry(waiter.WritableEvents) - c.WQ.EventRegister(&waitEntry) - defer c.WQ.EventUnregister(&waitEntry) + c.Create(-1) + if err := c.EP.SocketOptions().SetBindToDevice(int32(test.device)); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%d)): %s", test.device, test.device, err) + } + // Start connection attempt. + waitEntry, _ := waiter.NewChannelEntry(waiter.WritableEvents) + c.WQ.EventRegister(&waitEntry) + defer c.WQ.EventUnregister(&waitEntry) - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { - t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) - } + err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) + } - // Receive SYN packet. - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - ), - ) - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { - t.Fatalf("unexpected endpoint state: want %s, got %s", want, got) - } - tcpHdr := header.TCP(header.IPv4(v.AsSlice()).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + // Receive SYN packet. 
+ v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + ), + ) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + t.Fatalf("unexpected endpoint state: want %s, got %s", want, got) + } + tcpHdr := header.TCP(header.IPv4(v.AsSlice()).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - iss := seqnum.Value(context.TestInitialSequenceNumber) - rcvWnd := seqnum.Size(30000) - c.SendPacket(nil, &context.Headers{ - SrcPort: tcpHdr.DestinationPort(), - DstPort: tcpHdr.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: rcvWnd, - TCPOpts: nil, - }) + iss := seqnum.Value(context.TestInitialSequenceNumber) + rcvWnd := seqnum.Size(30000) + c.SendPacket(nil, &context.Headers{ + SrcPort: tcpHdr.DestinationPort(), + DstPort: tcpHdr.SourcePort(), + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: rcvWnd, + TCPOpts: nil, + }) - v = c.GetPacket() - defer v.Release() - if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want { - t.Fatalf("unexpected endpoint state: want %s, got %s", want, got) - } - }) - } + v = c.GetPacket() + defer v.Release() + if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want { + t.Fatalf("unexpected endpoint state: want %s, got %s", want, got) + } + }) + } + }) } func TestShutdownConnectingSocket(t *testing.T) { - for _, test := range []struct { - name string - shutdownMode tcpip.ShutdownFlags - }{ - {"ShutdownRead", tcpip.ShutdownRead}, - {"ShutdownWrite", tcpip.ShutdownWrite}, - {"ShutdownReadWrite", tcpip.ShutdownRead | tcpip.ShutdownWrite}, - } { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - - // Create an endpoint, don't handshake because we want to interfere with - // the handshake process. 
- c.Create(-1) - - waitEntry, ch := waiter.NewChannelEntry(waiter.EventHUp) - c.WQ.EventRegister(&waitEntry) - defer c.WQ.EventUnregister(&waitEntry) - - // Start connection attempt. - addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} - if d := cmp.Diff(&tcpip.ErrConnectStarted{}, c.EP.Connect(addr)); d != "" { - t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d) - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + for _, test := range []struct { + name string + shutdownMode tcpip.ShutdownFlags + }{ + {"ShutdownRead", tcpip.ShutdownRead}, + {"ShutdownWrite", tcpip.ShutdownWrite}, + {"ShutdownReadWrite", tcpip.ShutdownRead | tcpip.ShutdownWrite}, + } { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - // Check the SYN packet. - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - ), - ) + // Create an endpoint, don't handshake because we want to interfere with + // the handshake process. + c.Create(-1) - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { - t.Fatalf("got State() = %s, want %s", got, want) - } + waitEntry, ch := waiter.NewChannelEntry(waiter.EventHUp) + c.WQ.EventRegister(&waitEntry) + defer c.WQ.EventUnregister(&waitEntry) - if err := c.EP.Shutdown(test.shutdownMode); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } + // Start connection attempt. + addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, c.EP.Connect(addr)); d != "" { + t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d) + } - // The endpoint internal state is updated immediately. - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { - t.Fatalf("got State() = %s, want %s", got, want) - } + // Check the SYN packet. 
+ v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + ), + ) - select { - case <-ch: - default: - t.Fatal("endpoint was not notified") - } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + t.Fatalf("got State() = %s, want %s", got, want) + } - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrConnectionReset{}) + if err := c.EP.Shutdown(test.shutdownMode); err != nil { + t.Fatalf("Shutdown failed: %s", err) + } - // If the endpoint is not properly shutdown, it'll re-attempt to connect - // by sending another ACK packet. - c.CheckNoPacketTimeout("got an unexpected packet", tcp.InitialRTO+(500*time.Millisecond)) - }) - } + // The endpoint internal state is updated immediately. + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { + t.Fatalf("got State() = %s, want %s", got, want) + } + + select { + case <-ch: + default: + t.Fatal("endpoint was not notified") + } + + ept := endpointTester{c.EP} + ept.CheckReadError(t, &tcpip.ErrConnectionReset{}) + + // If the endpoint is not properly shutdown, it'll re-attempt to connect + // by sending another ACK packet. + c.CheckNoPacketTimeout("got an unexpected packet", tcp.InitialRTO+(500*time.Millisecond)) + }) + } + }) } func TestSynSent(t *testing.T) { - for _, test := range []struct { - name string - reset bool - }{ - {"RstOnSynSent", true}, - {"CloseOnSynSent", false}, - } { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - - // Create an endpoint, don't handshake because we want to interfere with the - // handshake process. 
- c.Create(-1) + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + for _, test := range []struct { + name string + reset bool + }{ + {"RstOnSynSent", true}, + {"CloseOnSynSent", false}, + } { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - // Start connection attempt. - waitEntry, ch := waiter.NewChannelEntry(waiter.EventHUp) - c.WQ.EventRegister(&waitEntry) - defer c.WQ.EventUnregister(&waitEntry) + // Create an endpoint, don't handshake because we want to interfere with the + // handshake process. + c.Create(-1) - addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} - err := c.EP.Connect(addr) - if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { - t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d) - } + // Start connection attempt. + waitEntry, ch := waiter.NewChannelEntry(waiter.EventHUp) + c.WQ.EventRegister(&waitEntry) + defer c.WQ.EventUnregister(&waitEntry) - // Receive SYN packet. - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - ), - ) + addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} + err := c.EP.Connect(addr) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d) + } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { - t.Fatalf("got State() = %s, want %s", got, want) - } - tcpHdr := header.TCP(header.IPv4(v.AsSlice()).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + // Receive SYN packet. + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + ), + ) - if test.reset { - // Send a packet with a proper ACK and a RST flag to cause the socket - // to error and close out. 
- iss := seqnum.Value(context.TestInitialSequenceNumber) - rcvWnd := seqnum.Size(30000) - c.SendPacket(nil, &context.Headers{ - SrcPort: tcpHdr.DestinationPort(), - DstPort: tcpHdr.SourcePort(), - Flags: header.TCPFlagRst | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: rcvWnd, - TCPOpts: nil, - }) - } else { - c.EP.Close() - } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + t.Fatalf("got State() = %s, want %s", got, want) + } + tcpHdr := header.TCP(header.IPv4(v.AsSlice()).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + if test.reset { + // Send a packet with a proper ACK and a RST flag to cause the socket + // to error and close out. + iss := seqnum.Value(context.TestInitialSequenceNumber) + rcvWnd := seqnum.Size(30000) + c.SendPacket(nil, &context.Headers{ + SrcPort: tcpHdr.DestinationPort(), + DstPort: tcpHdr.SourcePort(), + Flags: header.TCPFlagRst | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: rcvWnd, + TCPOpts: nil, + }) + } else { + c.EP.Close() + } - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(3 * time.Second): - t.Fatal("timed out waiting for packet to arrive") - } + // Wait for receive to be notified. 
+ select { + case <-ch: + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for packet to arrive") + } - ept := endpointTester{c.EP} - if test.reset { - ept.CheckReadError(t, &tcpip.ErrConnectionRefused{}) - } else { - ept.CheckReadError(t, &tcpip.ErrAborted{}) - } + ept := endpointTester{c.EP} + if test.reset { + ept.CheckReadError(t, &tcpip.ErrConnectionRefused{}) + } else { + ept.CheckReadError(t, &tcpip.ErrAborted{}) + } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) - } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) + } - // Due to the RST the endpoint should be in an error state. - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { - t.Fatalf("got State() = %s, want %s", got, want) - } - }) - } + // Due to the RST the endpoint should be in an error state. + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { + t.Fatalf("got State() = %s, want %s", got, want) + } + }) + } + }) } func TestOutOfOrderReceive(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) - c.WQ.EventRegister(&we) - defer c.WQ.EventUnregister(&we) - - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) + we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) + c.WQ.EventRegister(&we) + defer c.WQ.EventUnregister(&we) - // Send second half of data first, with seqnum 3 ahead of expected. 
- data := []byte{1, 2, 3, 4, 5, 6} - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket(data[3:], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(3), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) + ept := endpointTester{c.EP} + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - // Check that we get an ACK specifying which seqnum is expected. - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) + // Send second half of data first, with seqnum 3 ahead of expected. + data := []byte{1, 2, 3, 4, 5, 6} + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendPacket(data[3:], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss.Add(3), + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) - // Wait 200ms and check that no data has been received. - time.Sleep(200 * time.Millisecond) - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) + // Check that we get an ACK specifying which seqnum is expected. + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) - // Send the first 3 bytes now. - c.SendPacket(data[:3], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) + // Wait 200ms and check that no data has been received. + time.Sleep(200 * time.Millisecond) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) + + // Send the first 3 bytes now. 
+ c.SendPacket(data[:3], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) - // Receive data. - read := ept.CheckReadFull(t, 6, ch, 5*time.Second) + // Receive data. + read := ept.CheckReadFull(t, 6, ch, 5*time.Second) - // Check that we received the data in proper order. - if !bytes.Equal(data, read) { - t.Fatalf("got data = %v, want = %v", read, data) - } + // Check that we received the data in proper order. + if !bytes.Equal(data, read) { + t.Fatalf("got data = %v, want = %v", read, data) + } - // Check that the whole data is acknowledged. - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) + // Check that the whole data is acknowledged. + v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+uint32(len(data))), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + }) } func TestOutOfOrderFlood(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - rcvBufSz := math.MaxUint16 - c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBufSz) + rcvBufSz := math.MaxUint16 + c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBufSz) - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) + ept := endpointTester{c.EP} + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - // Send 100 packets with seqnum iss + 6 before the actual one that is - // expected. 
- data := []byte{1, 2, 3, 4, 5, 6} - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - for i := 0; i < 100; i++ { + // Send 100 packets with seqnum iss + 6 before the actual one that is + // expected. + data := []byte{1, 2, 3, 4, 5, 6} + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + for i := 0; i < 100; i++ { + c.SendPacket(data[3:], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss.Add(6), + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + } + + // Send packet with seqnum as initial + 3. It won't be discarded + // because the receive window limits the sender to rcvBufSize/2 bytes, + // but we allow (3/4)*rcvBufSize to be used for out-of-order bytes. So + // the sender hasn't filled the buffer and we still have space to + // receive it. c.SendPacket(data[3:], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(6), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss.Add(3), + AckNum: c.IRS.Add(1), + RcvWnd: 30000, }) v := c.GetPacket() @@ -2122,1309 +2265,1394 @@ func TestOutOfOrderFlood(t *testing.T) { checker.TCPFlags(header.TCPFlagAck), ), ) - } - // Send packet with seqnum as initial + 3. It won't be discarded - // because the receive window limits the sender to rcvBufSize/2 bytes, - // but we allow (3/4)*rcvBufSize to be used for out-of-order bytes. So - // the sender hasn't filled the buffer and we still have space to - // receive it. 
- c.SendPacket(data[3:], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(3), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) + // Now send the expected packet with initial sequence number. + c.SendPacket(data[:3], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) - // Now send the expected packet with initial sequence number. - c.SendPacket(data[:3], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, + // Check that all packets are acknowledged. + v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+9), + checker.TCPFlags(header.TCPFlagAck), + ), + ) }) - - // Check that all packets are acknowledged. 
- v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+9), - checker.TCPFlags(header.TCPFlagAck), - ), - ) } func TestRstOnCloseWithUnreadData(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) - c.WQ.EventRegister(&we) - defer c.WQ.EventUnregister(&we) + we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) + c.WQ.EventRegister(&we) + defer c.WQ.EventUnregister(&we) - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) + ept := endpointTester{c.EP} + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - data := []byte{1, 2, 3} - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(3 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } + data := []byte{1, 2, 3} + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) - // Check that ACK is received, this happens regardless of the read. 
- v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) + // Wait for receive to be notified. + select { + case <-ch: + case <-time.After(3 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } + + // Check that ACK is received, this happens regardless of the read. + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+uint32(len(data))), + checker.TCPFlags(header.TCPFlagAck), + ), + ) - // Now that we know we have unread data, let's just close the connection - // and verify that netstack sends an RST rather than a FIN. - c.EP.Close() + // Now that we know we have unread data, let's just close the connection + // and verify that netstack sends an RST rather than a FIN. + c.EP.Close() - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - // We shouldn't consume a sequence number on RST. - checker.TCPSeqNum(uint32(c.IRS)+1), - )) - // The RST puts the endpoint into an error state. - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } + v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + // We shouldn't consume a sequence number on RST. + checker.TCPSeqNum(uint32(c.IRS)+1), + )) + // The RST puts the endpoint into an error state. 
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) + } - // This final ACK should be ignored because an ACK on a reset doesn't mean - // anything. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(seqnum.Size(len(data))), - AckNum: c.IRS.Add(seqnum.Size(2)), - RcvWnd: 30000, + // This final ACK should be ignored because an ACK on a reset doesn't mean + // anything. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss.Add(seqnum.Size(len(data))), + AckNum: c.IRS.Add(seqnum.Size(2)), + RcvWnd: 30000, + }) }) } func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) - c.WQ.EventRegister(&we) - defer c.WQ.EventUnregister(&we) - - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - data := []byte{1, 2, 3} - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(3 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Check that ACK is received, this happens regardless of the read. 
- v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) + we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) + c.WQ.EventRegister(&we) + defer c.WQ.EventUnregister(&we) - // Cause a FIN to be generated. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } + ept := endpointTester{c.EP} + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - // Make sure we get the FIN but DON't ACK IT. - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - checker.TCPSeqNum(uint32(c.IRS)+1), - )) - - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } + data := []byte{1, 2, 3} + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) - // Cause a RST to be generated by closing the read end now since we have - // unread data. - if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } + // Wait for receive to be notified. + select { + case <-ch: + case <-time.After(3 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } - // Make sure we get the RST - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - // RST is always generated with sndNxt which if the FIN - // has been sent will be 1 higher than the sequence - // number of the FIN itself. 
- checker.TCPSeqNum(uint32(c.IRS)+2), - )) - // The RST puts the endpoint into an error state. - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } + // Check that ACK is received, this happens regardless of the read. + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+uint32(len(data))), + checker.TCPFlags(header.TCPFlagAck), + ), + ) - // The ACK to the FIN should now be rejected since the connection has been - // closed by a RST. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(seqnum.Size(len(data))), - AckNum: c.IRS.Add(seqnum.Size(2)), - RcvWnd: 30000, - }) -} + // Cause a FIN to be generated. + if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Shutdown failed: %s", err) + } -func TestShutdownRead(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + // Make sure we get the FIN but DON't ACK IT. + v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + checker.TCPSeqNum(uint32(c.IRS)+1), + )) - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) + } - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) + // Cause a RST to be generated by closing the read end now since we have + // unread data. 
+ if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { + t.Fatalf("Shutdown failed: %s", err) + } - if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } + // Make sure we get the RST + v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + // RST is always generated with sndNxt which if the FIN + // has been sent will be 1 higher than the sequence + // number of the FIN itself. + checker.TCPSeqNum(uint32(c.IRS)+2), + )) + // The RST puts the endpoint into an error state. + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) + } - ept.CheckReadError(t, &tcpip.ErrClosedForReceive{}) - var want uint64 = 1 - if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want { - t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %d want %d", got, want) - } + // The ACK to the FIN should now be rejected since the connection has been + // closed by a RST. 
+ c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss.Add(seqnum.Size(len(data))), + AckNum: c.IRS.Add(seqnum.Size(2)), + RcvWnd: 30000, + }) + }) } -func TestFullWindowReceive(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() +func TestShutdownRead(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - const rcvBufSz = 10 - c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBufSz) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) - c.WQ.EventRegister(&we) - defer c.WQ.EventUnregister(&we) + ept := endpointTester{c.EP} + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) + if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { + t.Fatalf("Shutdown failed: %s", err) + } - // Fill up the window w/ tcp.SegOverheadFactor*rcvBufSz as netstack multiplies - // the provided buffer value by tcp.SegOverheadFactor to calculate the actual - // receive buffer size. - data := make([]byte, tcp.SegOverheadFactor*rcvBufSz) - for i := range data { - data[i] = byte(i % 255) - } - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. 
- select { - case <-ch: - case <-time.After(5 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } + ept.CheckReadError(t, &tcpip.ErrClosedForReceive{}) + var want uint64 = 1 + if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want { + t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %d want %d", got, want) + } + }) +} - // Check that data is acknowledged, and window goes to zero. - b := c.GetPacket() - defer b.Release() - checker.IPv4(t, b, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(len(data))), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPWindow(0), - ), - ) +func TestFullWindowReceive(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - // Receive data and check it. - v := ept.CheckRead(t) - if !bytes.Equal(data, v) { - t.Fatalf("got data = %v, want = %v", v, data) - } + const rcvBufSz = 10 + c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBufSz) - var want uint64 = 1 - if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ZeroRcvWindowState.Value(); got != want { - t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %d want %d", got, want) - } + we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) + c.WQ.EventRegister(&we) + defer c.WQ.EventUnregister(&we) - // Check that we get an ACK for the newly non-zero window. 
- b = c.GetPacket() - defer b.Release() - checker.IPv4(t, b, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(len(data))), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPWindow(10), - ), - ) -} + ept := endpointTester{c.EP} + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) -func TestSmallReceiveBufferReadiness(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - }) - defer s.Destroy() + // Fill up the window w/ tcp.SegOverheadFactor*rcvBufSz as netstack multiplies + // the provided buffer value by tcp.SegOverheadFactor to calculate the actual + // receive buffer size. + data := make([]byte, tcp.SegOverheadFactor*rcvBufSz) + for i := range data { + data[i] = byte(i % 255) + } + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) - ep := loopback.New() - if testing.Verbose() { - ep = sniffer.New(ep) - } + // Wait for receive to be notified. + select { + case <-ch: + case <-time.After(5 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } - const nicID = 1 - nicOpts := stack.NICOptions{Name: "nic1"} - if err := s.CreateNICWithOptions(nicID, ep, nicOpts); err != nil { - t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %s", nicOpts, err) - } + // Check that data is acknowledged, and window goes to zero. 
+ b := c.GetPacket() + defer b.Release() + checker.IPv4(t, b, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+uint32(len(data))), + checker.TCPFlags(header.TCPFlagAck), + checker.TCPWindow(0), + ), + ) - protocolAddr := tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.AddrFromSlice([]byte("\x7f\x00\x00\x01")), - PrefixLen: 32, - }, - } - if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}) failed: %s", nicID, protocolAddr, err) - } + // Receive data and check it. + v := ept.CheckRead(t) + if !bytes.Equal(data, v) { + t.Fatalf("got data = %v, want = %v", v, data) + } - { - subnet, err := tcpip.NewSubnet(tcpip.AddrFromSlice([]byte("\x7f\x00\x00\x00")), tcpip.MaskFrom("\xff\x00\x00\x00")) - if err != nil { - t.Fatalf("tcpip.NewSubnet failed: %s", err) + var want uint64 = 1 + if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ZeroRcvWindowState.Value(); got != want { + t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %d want %d", got, want) } - s.SetRouteTable([]tcpip.Route{ - { - Destination: subnet, - NIC: nicID, - }, - }) - } - listenerEntry, listenerCh := waiter.NewChannelEntry(waiter.ReadableEvents) - var listenerWQ waiter.Queue - listener, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer listener.Close() - listenerWQ.EventRegister(&listenerEntry) - defer listenerWQ.EventUnregister(&listenerEntry) + // Check that we get an ACK for the newly non-zero window. 
+ b = c.GetPacket() + defer b.Release() + checker.IPv4(t, b, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+uint32(len(data))), + checker.TCPFlags(header.TCPFlagAck), + checker.TCPWindow(10), + ), + ) + }) +} - if err := listener.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - if err := listener.Listen(1); err != nil { - t.Fatalf("Bind failed: %s", err) - } +func TestSmallReceiveBufferReadiness(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + }) + defer s.Destroy() - localAddress, err := listener.GetLocalAddress() - if err != nil { - t.Fatalf("GetLocalAddress failed: %s", err) - } + ep := loopback.New() + if testing.Verbose() { + ep = sniffer.New(ep) + } - for i := 8; i > 0; i /= 2 { - size := int64(i << 12) - t.Run(fmt.Sprintf("size=%d", size), func(t *testing.T) { - var clientWQ waiter.Queue - client, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer client.Close() - switch err := client.Connect(localAddress).(type) { - case nil: - t.Fatal("Connect returned nil error") - case *tcpip.ErrConnectStarted: - default: - t.Fatalf("Connect failed: %s", err) - } + const nicID = 1 + nicOpts := stack.NICOptions{Name: "nic1"} + if err := s.CreateNICWithOptions(nicID, ep, nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %s", nicOpts, err) + } + + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.AddrFromSlice([]byte("\x7f\x00\x00\x01")), + PrefixLen: 32, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + 
t.Fatalf("AddProtocolAddress(%d, %+v, {}) failed: %s", nicID, protocolAddr, err) + } - <-listenerCh - server, serverWQ, err := listener.Accept(nil) + { + subnet, err := tcpip.NewSubnet(tcpip.AddrFromSlice([]byte("\x7f\x00\x00\x00")), tcpip.MaskFrom("\xff\x00\x00\x00")) if err != nil { - t.Fatalf("Accept failed: %s", err) + t.Fatalf("tcpip.NewSubnet failed: %s", err) } - defer server.Close() + s.SetRouteTable([]tcpip.Route{ + { + Destination: subnet, + NIC: nicID, + }, + }) + } - client.SocketOptions().SetReceiveBufferSize(size, true) - // Send buffer size doesn't seem to affect this test. - // server.SocketOptions().SetSendBufferSize(size, true) + listenerEntry, listenerCh := waiter.NewChannelEntry(waiter.ReadableEvents) + var listenerWQ waiter.Queue + listener, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + defer listener.Close() + listenerWQ.EventRegister(&listenerEntry) + defer listenerWQ.EventUnregister(&listenerEntry) - clientEntry, clientCh := waiter.NewChannelEntry(waiter.ReadableEvents) - clientWQ.EventRegister(&clientEntry) - defer clientWQ.EventUnregister(&clientEntry) + if err := listener.Bind(tcpip.FullAddress{}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + if err := listener.Listen(1); err != nil { + t.Fatalf("Bind failed: %s", err) + } - serverEntry, serverCh := waiter.NewChannelEntry(waiter.WritableEvents) - serverWQ.EventRegister(&serverEntry) - defer serverWQ.EventUnregister(&serverEntry) + localAddress, err := listener.GetLocalAddress() + if err != nil { + t.Fatalf("GetLocalAddress failed: %s", err) + } - var total int64 - for { - var b [64 << 10]byte - var r bytes.Reader - r.Reset(b[:]) - switch n, err := server.Write(&r, tcpip.WriteOptions{}); err.(type) { + for i := 8; i > 0; i /= 2 { + size := int64(i << 12) + t.Run(fmt.Sprintf("size=%d", size), func(t *testing.T) { + var clientWQ waiter.Queue + client, err := 
s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + defer client.Close() + switch err := client.Connect(localAddress).(type) { case nil: - t.Logf("wrote %d bytes", n) - total += n - continue - case *tcpip.ErrWouldBlock: - select { - case <-serverCh: + t.Fatal("Connect returned nil error") + case *tcpip.ErrConnectStarted: + default: + t.Fatalf("Connect failed: %s", err) + } + + <-listenerCh + server, serverWQ, err := listener.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + defer server.Close() + + client.SocketOptions().SetReceiveBufferSize(size, true) + // Send buffer size doesn't seem to affect this test. + // server.SocketOptions().SetSendBufferSize(size, true) + + clientEntry, clientCh := waiter.NewChannelEntry(waiter.ReadableEvents) + clientWQ.EventRegister(&clientEntry) + defer clientWQ.EventUnregister(&clientEntry) + + serverEntry, serverCh := waiter.NewChannelEntry(waiter.WritableEvents) + serverWQ.EventRegister(&serverEntry) + defer serverWQ.EventUnregister(&serverEntry) + + var total int64 + for { + var b [64 << 10]byte + var r bytes.Reader + r.Reset(b[:]) + switch n, err := server.Write(&r, tcpip.WriteOptions{}); err.(type) { + case nil: + t.Logf("wrote %d bytes", n) + total += n continue - case <-time.After(100 * time.Millisecond): - // Well and truly full. - t.Logf("send and receive queues are full") + case *tcpip.ErrWouldBlock: + select { + case <-serverCh: + continue + case <-time.After(100 * time.Millisecond): + // Well and truly full. 
+ t.Logf("send and receive queues are full") + } + default: + t.Fatalf("Write failed: %s", err) } - default: - t.Fatalf("Write failed: %s", err) + break } - break - } - t.Logf("wrote %d bytes in total", total) + t.Logf("wrote %d bytes in total", total) - var wg sync.WaitGroup - defer wg.Wait() + var wg sync.WaitGroup + defer wg.Wait() - wg.Add(2) - go func() { - defer wg.Done() + wg.Add(2) + go func() { + defer wg.Done() - var b [64 << 10]byte - var r bytes.Reader - r.Reset(b[:]) - if err := func() error { - var total int64 - defer t.Logf("wrote %d bytes in total", total) - for r.Len() != 0 { - switch n, err := server.Write(&r, tcpip.WriteOptions{}); err.(type) { - case nil: - t.Logf("wrote %d bytes", n) - total += n - case *tcpip.ErrWouldBlock: - for { - t.Logf("waiting on server") - select { - case <-serverCh: - case <-time.After(time.Second): - if readiness := server.Readiness(waiter.WritableEvents); readiness != 0 { - t.Logf("server.Readiness(%b) = %b but channel not signaled", waiter.WritableEvents, readiness) + var b [64 << 10]byte + var r bytes.Reader + r.Reset(b[:]) + if err := func() error { + var total int64 + defer t.Logf("wrote %d bytes in total", total) + for r.Len() != 0 { + switch n, err := server.Write(&r, tcpip.WriteOptions{}); err.(type) { + case nil: + t.Logf("wrote %d bytes", n) + total += n + case *tcpip.ErrWouldBlock: + for { + t.Logf("waiting on server") + select { + case <-serverCh: + case <-time.After(time.Second): + if readiness := server.Readiness(waiter.WritableEvents); readiness != 0 { + t.Logf("server.Readiness(%b) = %b but channel not signaled", waiter.WritableEvents, readiness) + } + continue } - continue + break } - break + default: + return fmt.Errorf("server.Write failed: %s", err) } - default: - return fmt.Errorf("server.Write failed: %s", err) } + if err := server.Shutdown(tcpip.ShutdownWrite); err != nil { + return fmt.Errorf("server.Shutdown failed: %s", err) + } + t.Logf("server end shutdown done") + return nil + }(); err != 
nil { + t.Error(err) } - if err := server.Shutdown(tcpip.ShutdownWrite); err != nil { - return fmt.Errorf("server.Shutdown failed: %s", err) - } - t.Logf("server end shutdown done") - return nil - }(); err != nil { - t.Error(err) - } - }() + }() - go func() { - defer wg.Done() + go func() { + defer wg.Done() - if err := func() error { - total := 0 - defer t.Logf("read %d bytes in total", total) - for { - switch res, err := client.Read(io.Discard, tcpip.ReadOptions{}); err.(type) { - case nil: - t.Logf("read %d bytes", res.Count) - total += res.Count - t.Logf("read total %d bytes till now", total) - case *tcpip.ErrClosedForReceive: - return nil - case *tcpip.ErrWouldBlock: - for { - t.Logf("waiting on client") - select { - case <-clientCh: - case <-time.After(time.Second): - if readiness := client.Readiness(waiter.ReadableEvents); readiness != 0 { - return fmt.Errorf("client.Readiness(%b) = %b but channel not signaled", waiter.ReadableEvents, readiness) + if err := func() error { + total := 0 + defer t.Logf("read %d bytes in total", total) + for { + switch res, err := client.Read(io.Discard, tcpip.ReadOptions{}); err.(type) { + case nil: + t.Logf("read %d bytes", res.Count) + total += res.Count + t.Logf("read total %d bytes till now", total) + case *tcpip.ErrClosedForReceive: + return nil + case *tcpip.ErrWouldBlock: + for { + t.Logf("waiting on client") + select { + case <-clientCh: + case <-time.After(time.Second): + if readiness := client.Readiness(waiter.ReadableEvents); readiness != 0 { + return fmt.Errorf("client.Readiness(%b) = %b but channel not signaled", waiter.ReadableEvents, readiness) + } + continue } - continue + break } - break + default: + return fmt.Errorf("client.Write failed: %s", err) } - default: - return fmt.Errorf("client.Write failed: %s", err) } + }(); err != nil { + t.Error(err) } - }(); err != nil { - t.Error(err) - } - }() - }) - } + }() + }) + } + }) } // Test the stack receive window advertisement on receiving segments smaller than // 
segment overhead. It tests for the right edge of the window to not grow when // the endpoint is not being read from. func TestSmallSegReceiveWindowAdvertisement(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - opt := tcpip.TCPReceiveBufferSizeRangeOption{ - Min: 1, - Default: tcp.DefaultReceiveBufferSize, - Max: tcp.DefaultReceiveBufferSize << tcp.FindWndScale(seqnum.Size(tcp.DefaultReceiveBufferSize)), - } - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) - } + opt := tcpip.TCPReceiveBufferSizeRangeOption{ + Min: 1, + Default: tcp.DefaultReceiveBufferSize, + Max: tcp.DefaultReceiveBufferSize << tcp.FindWndScale(seqnum.Size(tcp.DefaultReceiveBufferSize)), + } + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } - c.AcceptWithOptionsNoDelay(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: e2e.DefaultIPv4MSS}) + c.AcceptWithOptionsNoDelay(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: e2e.DefaultIPv4MSS}) - // Bump up the receive buffer size such that, when the receive window grows, - // the scaled window exceeds maxUint16. - c.EP.SocketOptions().SetReceiveBufferSize(int64(opt.Max)*2, true /* notify */) + // Bump up the receive buffer size such that, when the receive window grows, + // the scaled window exceeds maxUint16. + c.EP.SocketOptions().SetReceiveBufferSize(int64(opt.Max)*2, true /* notify */) - // Keep the payload size < segment overhead and such that it is a multiple - // of the window scaled value. This enables the test to perform equality - // checks on the incoming receive window. 
- payloadSize := 1 << c.RcvdWindowScale - if payloadSize >= tcp.SegOverheadSize { - t.Fatalf("payload size of %d is not less than the segment overhead of %d", payloadSize, tcp.SegOverheadSize) - } - payload := generateRandomPayload(t, payloadSize) - payloadLen := seqnum.Size(len(payload)) - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + // Keep the payload size < segment overhead and such that it is a multiple + // of the window scaled value. This enables the test to perform equality + // checks on the incoming receive window. + payloadSize := 1 << c.RcvdWindowScale + if payloadSize >= tcp.SegOverheadSize { + t.Fatalf("payload size of %d is not less than the segment overhead of %d", payloadSize, tcp.SegOverheadSize) + } + payload := generateRandomPayload(t, payloadSize) + payloadLen := seqnum.Size(len(payload)) + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - // Send payload to the endpoint and return the advertised receive window - // from the endpoint. - getIncomingRcvWnd := func() uint32 { - c.SendPacket(payload, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - SeqNum: iss, - AckNum: c.IRS.Add(1), - Flags: header.TCPFlagAck, - RcvWnd: 30000, - }) - iss = iss.Add(payloadLen) + // Send payload to the endpoint and return the advertised receive window + // from the endpoint. + getIncomingRcvWnd := func() uint32 { + c.SendPacket(payload, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + SeqNum: iss, + AckNum: c.IRS.Add(1), + Flags: header.TCPFlagAck, + RcvWnd: 30000, + }) + iss = iss.Add(payloadLen) - pkt := c.GetPacket() - defer pkt.Release() - return uint32(header.TCP(header.IPv4(pkt.AsSlice()).Payload()).WindowSize()) << c.RcvdWindowScale - } + pkt := c.GetPacket() + defer pkt.Release() + return uint32(header.TCP(header.IPv4(pkt.AsSlice()).Payload()).WindowSize()) << c.RcvdWindowScale + } - // Read the advertised receive window with the ACK for payload. 
- rcvWnd := getIncomingRcvWnd() + // Read the advertised receive window with the ACK for payload. + rcvWnd := getIncomingRcvWnd() - // Check if the subsequent ACK to our send has not grown the right edge of - // the window. - if got, want := getIncomingRcvWnd(), rcvWnd-uint32(len(payload)); got != want { - t.Fatalf("got incomingRcvwnd %d want %d", got, want) - } + // Check if the subsequent ACK to our send has not grown the right edge of + // the window. + if got, want := getIncomingRcvWnd(), rcvWnd-uint32(len(payload)); got != want { + t.Fatalf("got incomingRcvwnd %d want %d", got, want) + } - // Read the data so that the subsequent ACK from the endpoint - // grows the right edge of the window. - var buf bytes.Buffer - if _, err := c.EP.Read(&buf, tcpip.ReadOptions{}); err != nil { - t.Fatalf("c.EP.Read: %s", err) - } + // Read the data so that the subsequent ACK from the endpoint + // grows the right edge of the window. + var buf bytes.Buffer + if _, err := c.EP.Read(&buf, tcpip.ReadOptions{}); err != nil { + t.Fatalf("c.EP.Read: %s", err) + } - // Check if we have received max uint16 as our advertised - // scaled window now after a read above. - maxRcv := uint32(math.MaxUint16 << c.RcvdWindowScale) - if got, want := getIncomingRcvWnd(), maxRcv; got != want { - t.Fatalf("got incomingRcvwnd %d want %d", got, want) - } + // Check if we have received max uint16 as our advertised + // scaled window now after a read above. + maxRcv := uint32(math.MaxUint16 << c.RcvdWindowScale) + if got, want := getIncomingRcvWnd(), maxRcv; got != want { + t.Fatalf("got incomingRcvwnd %d want %d", got, want) + } - // Check if the subsequent ACK to our send has not grown the right edge of - // the window. - if got, want := getIncomingRcvWnd(), maxRcv-uint32(len(payload)); got != want { - t.Fatalf("got incomingRcvwnd %d want %d", got, want) - } + // Check if the subsequent ACK to our send has not grown the right edge of + // the window. 
+ if got, want := getIncomingRcvWnd(), maxRcv-uint32(len(payload)); got != want { + t.Fatalf("got incomingRcvwnd %d want %d", got, want) + } + }) } func TestNoWindowShrinking(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + // Start off with a certain receive buffer then cut it in half and verify that + // the right edge of the window does not shrink. + // NOTE: Netstack doubles the value specified here. + rcvBufSize := 65536 + // Enable window scaling with a scale of zero from our end. + c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, rcvBufSize, []byte{ + header.TCPOptionWS, 3, 0, header.TCPOptionNOP, + }) - // Start off with a certain receive buffer then cut it in half and verify that - // the right edge of the window does not shrink. - // NOTE: Netstack doubles the value specified here. - rcvBufSize := 65536 - // Enable window scaling with a scale of zero from our end. - c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, rcvBufSize, []byte{ - header.TCPOptionWS, 3, 0, header.TCPOptionNOP, - }) + we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) + c.WQ.EventRegister(&we) + defer c.WQ.EventUnregister(&we) + + ept := endpointTester{c.EP} + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) - c.WQ.EventRegister(&we) - defer c.WQ.EventUnregister(&we) + // Send a 1 byte payload so that we can record the current receive window. + // Send a payload of half the size of rcvBufSize. 
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + payload := []byte{1} + c.SendPacket(payload, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) + // Wait for receive to be notified. + select { + case <-ch: + case <-time.After(5 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } - // Send a 1 byte payload so that we can record the current receive window. - // Send a payload of half the size of rcvBufSize. - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - payload := []byte{1} - c.SendPacket(payload, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(5 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } + // Read the 1 byte payload we just sent. + if got, want := payload, ept.CheckRead(t); !bytes.Equal(got, want) { + t.Fatalf("got data: %v, want: %v", got, want) + } - // Read the 1 byte payload we just sent. - if got, want := payload, ept.CheckRead(t); !bytes.Equal(got, want) { - t.Fatalf("got data: %v, want: %v", got, want) - } + // Verify that the ACK does not shrink the window. + pkt := c.GetPacket() + defer pkt.Release() + iss = iss.Add(1) + checker.IPv4(t, pkt, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + // Stash the initial window. + initialWnd := header.TCP(header.IPv4(pkt.AsSlice()).Payload()).WindowSize() << c.RcvdWindowScale + initialLastAcceptableSeq := iss.Add(seqnum.Size(initialWnd)) + // Now shrink the receive buffer to half its original size. 
+ c.EP.SocketOptions().SetReceiveBufferSize(int64(rcvBufSize), true /* notify */) + + data := generateRandomPayload(t, rcvBufSize) + // Send a payload of half the size of rcvBufSize. + c.SendPacket(data[:rcvBufSize/2], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + iss = iss.Add(seqnum.Size(rcvBufSize / 2)) - // Verify that the ACK does not shrink the window. - pkt := c.GetPacket() - defer pkt.Release() - iss = iss.Add(1) - checker.IPv4(t, pkt, - checker.TCP( + // Verify that the ACK does not shrink the window. + pkt = c.GetPacket() + defer pkt.Release() + checker.IPv4(t, pkt, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + newWnd := header.TCP(header.IPv4(pkt.AsSlice()).Payload()).WindowSize() << c.RcvdWindowScale + newLastAcceptableSeq := iss.Add(seqnum.Size(newWnd)) + if newLastAcceptableSeq.LessThan(initialLastAcceptableSeq) { + t.Fatalf("receive window shrunk unexpectedly got: %d, want >= %d", newLastAcceptableSeq, initialLastAcceptableSeq) + } + + // Send another payload of half the size of rcvBufSize. This should fill up the + // socket receive buffer and we should see a zero window. + c.SendPacket(data[rcvBufSize/2:], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + iss = iss.Add(seqnum.Size(rcvBufSize / 2)) + + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck), + checker.TCPWindow(0), ), - ) - // Stash the initial window. 
- initialWnd := header.TCP(header.IPv4(pkt.AsSlice()).Payload()).WindowSize() << c.RcvdWindowScale - initialLastAcceptableSeq := iss.Add(seqnum.Size(initialWnd)) - // Now shrink the receive buffer to half its original size. - c.EP.SocketOptions().SetReceiveBufferSize(int64(rcvBufSize), true /* notify */) - - data := generateRandomPayload(t, rcvBufSize) - // Send a payload of half the size of rcvBufSize. - c.SendPacket(data[:rcvBufSize/2], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - iss = iss.Add(seqnum.Size(rcvBufSize / 2)) - - // Verify that the ACK does not shrink the window. - pkt = c.GetPacket() - defer pkt.Release() - checker.IPv4(t, pkt, - checker.TCP( + ) + + // Receive data and check it. + read := ept.CheckReadFull(t, len(data), ch, 5*time.Second) + if !bytes.Equal(data, read) { + t.Fatalf("got data = %v, want = %v", read, data) + } + + // Check that we get an ACK for the newly non-zero window, which is the new + // receive buffer size we set after the connection was established. + v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck), + checker.TCPWindow(uint16(rcvBufSize/2)>>c.RcvdWindowScale), ), - ) - newWnd := header.TCP(header.IPv4(pkt.AsSlice()).Payload()).WindowSize() << c.RcvdWindowScale - newLastAcceptableSeq := iss.Add(seqnum.Size(newWnd)) - if newLastAcceptableSeq.LessThan(initialLastAcceptableSeq) { - t.Fatalf("receive window shrunk unexpectedly got: %d, want >= %d", newLastAcceptableSeq, initialLastAcceptableSeq) - } - - // Send another payload of half the size of rcvBufSize. This should fill up the - // socket receive buffer and we should see a zero window. 
- c.SendPacket(data[rcvBufSize/2:], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, + ) }) - iss = iss.Add(seqnum.Size(rcvBufSize / 2)) - - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPWindow(0), - ), - ) - - // Receive data and check it. - read := ept.CheckReadFull(t, len(data), ch, 5*time.Second) - if !bytes.Equal(data, read) { - t.Fatalf("got data = %v, want = %v", read, data) - } - - // Check that we get an ACK for the newly non-zero window, which is the new - // receive buffer size we set after the connection was established. - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPWindow(uint16(rcvBufSize/2)>>c.RcvdWindowScale), - ), - ) } func TestSimpleSend(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } + data := []byte{1, 2, 3} + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } - // Check that data is received. 
- b := c.GetPacket() - defer b.Release() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) + // Check that data is received. + b := c.GetPacket() + defer b.Release() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), + ), + ) - if p := b.AsSlice()[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { - t.Fatalf("got data = %v, want = %v", p, data) - } + if p := b.AsSlice()[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { + t.Fatalf("got data = %v, want = %v", p, data) + } - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1 + seqnum.Size(len(data))), - RcvWnd: 30000, + // Acknowledge the data. 
+ c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1 + seqnum.Size(len(data))), + RcvWnd: 30000, + }) }) } func TestZeroWindowSend(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateConnected(context.TestInitialSequenceNumber, 0 /* rcvWnd */, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 0 /* rcvWnd */, -1 /* epRcvBuf */) - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } + data := []byte{1, 2, 3} + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } - // Check if we got a zero-window probe. - b := c.GetPacket() - defer b.Release() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, b, - checker.PayloadLen(header.TCPMinimumSize+1), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) + // Check if we got a zero-window probe. + b := c.GetPacket() + defer b.Release() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + checker.IPv4(t, b, + checker.PayloadLen(header.TCPMinimumSize+1), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)), + checker.TCPAckNum(uint32(iss)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), + ), + ) - // Open up the window. Data should be received now. 
- c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) + // Open up the window. Data should be received now. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) - // Check that data is received. - b = c.GetPacket() - defer b.Release() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) + // Check that data is received. + b = c.GetPacket() + defer b.Release() + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), + ), + ) - if p := b.AsSlice()[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { - t.Fatalf("got data = %v, want = %v", p, data) - } + if p := b.AsSlice()[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { + t.Fatalf("got data = %v, want = %v", p, data) + } - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1 + seqnum.Size(len(data))), - RcvWnd: 30000, + // Acknowledge the data. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1 + seqnum.Size(len(data))), + RcvWnd: 30000, + }) }) } func TestScaledWindowConnect(t *testing.T) { + synctest. 
// This test ensures that window scaling is used when the peer // does advertise it and connection is established with Connect(). - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + Test(t, func(t *testing.T) { + defer synctest.Wait() - // Set the window size greater than the maximum non-scaled window. - c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, 65535*3, []byte{ - header.TCPOptionWS, 3, 0, header.TCPOptionNOP, - }) + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } + // Set the window size greater than the maximum non-scaled window. + c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, 65535*3, []byte{ + header.TCPOptionWS, 3, 0, header.TCPOptionNOP, + }) - // Check that data is received, and that advertised window is 0x5fff, - // that is, that it is scaled. - b := c.GetPacket() - defer b.Release() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPWindow(0x5fff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) + data := []byte{1, 2, 3} + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + // Check that data is received, and that advertised window is 0x5fff, + // that is, that it is scaled. 
+ b := c.GetPacket() + defer b.Release() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)), + checker.TCPWindow(0x5fff), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), + ), + ) + }) } func TestNonScaledWindowConnect(t *testing.T) { + synctest. // This test ensures that window scaling is not used when the peer // doesn't advertise it and connection is established with Connect(). - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + Test(t, func(t *testing.T) { + defer synctest.Wait() - // Set the window size greater than the maximum non-scaled window. - c.CreateConnected(context.TestInitialSequenceNumber, 30000, 65535*3) + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } + // Set the window size greater than the maximum non-scaled window. + c.CreateConnected(context.TestInitialSequenceNumber, 30000, 65535*3) - // Check that data is received, and that advertised window is 0xffff, - // that is, that it's not scaled. 
- b := c.GetPacket() - defer b.Release() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPWindow(0xffff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) + data := []byte{1, 2, 3} + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + // Check that data is received, and that advertised window is 0xffff, + // that is, that it's not scaled. + b := c.GetPacket() + defer b.Release() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)), + checker.TCPWindow(0xffff), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), + ), + ) + }) } func TestScaledWindowAccept(t *testing.T) { + synctest. // This test ensures that window scaling is used when the peer // does advertise it and connection is established with Accept(). - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + Test(t, func(t *testing.T) { + defer synctest.Wait() - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer ep.Close() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - // Set the window size greater than the maximum non-scaled window. - ep.SocketOptions().SetReceiveBufferSize(65535*6, true /* notify */) + // Create EP and start listening. 
+ wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + defer ep.Close() - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } + // Set the window size greater than the maximum non-scaled window. + ep.SocketOptions().SetReceiveBufferSize(65535*6, true /* notify */) - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } - // Do 3-way handshake. - // wndScale expected is 3 as 65535 * 3 * 2 < 65535 * 2^3 but > 65535 *2 *2 - c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: e2e.DefaultIPv4MSS}, 0 /* delay */) + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) - wq.EventRegister(&we) - defer wq.EventUnregister(&we) + // Do 3-way handshake. + // wndScale expected is 3 as 65535 * 3 * 2 < 65535 * 2^3 but > 65535 *2 *2 + c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: e2e.DefaultIPv4MSS}, 0 /* delay */) - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) + wq.EventRegister(&we) + defer wq.EventUnregister(&we) - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") + c.EP, _, err = ep.Accept(nil) + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { + // Wait for connection to be established. 
+ select { + case <-ch: + c.EP, _, err = ep.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } } - } - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } + data := []byte{1, 2, 3} + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } - // Check that data is received, and that advertised window is 0x5fff, - // that is, that it is scaled. - b := c.GetPacket() - defer b.Release() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPWindow(0x5fff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) + // Check that data is received, and that advertised window is 0x5fff, + // that is, that it is scaled. + b := c.GetPacket() + defer b.Release() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)), + checker.TCPWindow(0x5fff), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), + ), + ) + }) } func TestNonScaledWindowAccept(t *testing.T) { + synctest. // This test ensures that window scaling is not used when the peer // doesn't advertise it and connection is established with Accept(). - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + Test(t, func(t *testing.T) { + defer synctest.Wait() - // Create EP and start listening. 
- wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer ep.Close() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - // Set the window size greater than the maximum non-scaled window. - ep.SocketOptions().SetReceiveBufferSize(65535*6, true /* notify */) + // Create EP and start listening. + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + defer ep.Close() - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } + // Set the window size greater than the maximum non-scaled window. + ep.SocketOptions().SetReceiveBufferSize(65535*6, true /* notify */) - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } - // Do 3-way handshake w/ window scaling disabled. The SYN-ACK to the SYN - // should not carry the window scaling option. - c.PassiveConnect(100, -1, header.TCPSynOptions{MSS: e2e.DefaultIPv4MSS}) + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) - wq.EventRegister(&we) - defer wq.EventUnregister(&we) + // Do 3-way handshake w/ window scaling disabled. The SYN-ACK to the SYN + // should not carry the window scaling option. + c.PassiveConnect(100, -1, header.TCPSynOptions{MSS: e2e.DefaultIPv4MSS}) - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } + // Try to accept the connection. 
+ we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) + wq.EventRegister(&we) + defer wq.EventUnregister(&we) - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } + c.EP, _, err = ep.Accept(nil) + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) + } - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } - // Check that data is received, and that advertised window is 0xffff, - // that is, that it's not scaled. - b := c.GetPacket() - defer b.Release() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPWindow(0xffff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) + data := []byte{1, 2, 3} + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + // Check that data is received, and that advertised window is 0xffff, + // that is, that it's not scaled. + b := c.GetPacket() + defer b.Release() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)), + checker.TCPWindow(0xffff), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), + ), + ) + }) } func TestZeroScaledWindowReceive(t *testing.T) { + synctest. 
// This test ensures that the endpoint sends a non-zero window size // advertisement when the scaled window transitions from 0 to non-zero, // but the actual window (not scaled) hasn't gotten to zero. - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + Test(t, func(t *testing.T) { + defer synctest.Wait() - // Set the buffer size such that a window scale of 5 will be used. - const bufSz = 65535 * 10 - const ws = uint32(5) - c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, bufSz, []byte{ - header.TCPOptionWS, 3, 0, header.TCPOptionNOP, - }) + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - // Write chunks of 50000 bytes. - remain := 0 - sent := 0 - data := make([]byte, 50000) - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - // Keep writing till the window drops below len(data). - for { - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(seqnum.Size(sent)), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, + // Set the buffer size such that a window scale of 5 will be used. + const bufSz = 65535 * 10 + const ws = uint32(5) + c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, bufSz, []byte{ + header.TCPOptionWS, 3, 0, header.TCPOptionNOP, }) - sent += len(data) - pkt := c.GetPacket() - defer pkt.Release() - checker.IPv4(t, pkt, - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(sent)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - // Don't reduce window to zero here. 
- if wnd := int(header.TCP(header.IPv4(pkt.AsSlice()).Payload()).WindowSize()); wnd<= 16 { + data = data[:remain-15] + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss.Add(seqnum.Size(sent)), + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + sent += len(data) + pkt := c.GetPacket() + defer pkt.Release() + checker.IPv4(t, pkt, + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+uint32(sent)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + // Since the receive buffer is split between window advertisement and + // application data buffer the window does not always reflect the space + // available and actual space available can be a bit more than what is + // advertised in the window. + wnd := int(header.TCP(header.IPv4(pkt.AsSlice()).Payload()).WindowSize()) + if wnd == 0 { + break + } remain = wnd << ws - break } - } - // Make the window non-zero, but the scaled window zero. - for remain >= 16 { - data = data[:remain-15] - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(seqnum.Size(sent)), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - sent += len(data) - pkt := c.GetPacket() - defer pkt.Release() - checker.IPv4(t, pkt, + // Read at least 2MSS of data. An ack should be sent in response to that. + // Since buffer space is now split in half between window and application + // data we need to read more than 1 MSS(65536) of data for a non-zero window + // update to be sent. For 1MSS worth of window to be available we need to + // read at least 128KB. Since our segments above were 50KB each it means + // we need to read at 3 packets. 
+ w := tcpip.LimitedWriter{ + W: io.Discard, + N: e2e.DefaultMTU * 2, + } + for w.N != 0 { + res, err := c.EP.Read(&w, tcpip.ReadOptions{}) + t.Logf("err=%v res=%#v", err, res) + if err != nil { + t.Fatalf("Read failed: %s", err) + } + } + + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(uint32(iss)+uint32(sent)), + checker.TCPWindowGreaterThanEq(uint16(e2e.DefaultMTU>>ws)), checker.TCPFlags(header.TCPFlagAck), - ), - ) - // Since the receive buffer is split between window advertisement and - // application data buffer the window does not always reflect the space - // available and actual space available can be a bit more than what is - // advertised in the window. - wnd := int(header.TCP(header.IPv4(pkt.AsSlice()).Payload()).WindowSize()) - if wnd == 0 { - break - } - remain = wnd << ws - } - - // Read at least 2MSS of data. An ack should be sent in response to that. - // Since buffer space is now split in half between window and application - // data we need to read more than 1 MSS(65536) of data for a non-zero window - // update to be sent. For 1MSS worth of window to be available we need to - // read at least 128KB. Since our segments above were 50KB each it means - // we need to read at 3 packets. 
- w := tcpip.LimitedWriter{ - W: io.Discard, - N: e2e.DefaultMTU * 2, - } - for w.N != 0 { - res, err := c.EP.Read(&w, tcpip.ReadOptions{}) - t.Logf("err=%v res=%#v", err, res) - if err != nil { - t.Fatalf("Read failed: %s", err) - } - } - - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(sent)), - checker.TCPWindowGreaterThanEq(uint16(e2e.DefaultMTU>>ws)), - checker.TCPFlags(header.TCPFlagAck), - )) + )) + }) } func TestSegmentMerging(t *testing.T) { - tests := []struct { - name string - stop func(tcpip.Endpoint) - resume func(tcpip.Endpoint) - }{ - { - "stop work", - func(ep tcpip.Endpoint) { - ep.(interface{ StopWork() }).StopWork() - }, - func(ep tcpip.Endpoint) { - ep.(interface{ ResumeWork() }).ResumeWork() - }, - }, - { - "cork", - func(ep tcpip.Endpoint) { - ep.SocketOptions().SetCorkOption(true) + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + tests := []struct { + name string + stop func(tcpip.Endpoint) + resume func(tcpip.Endpoint) + }{ + { + "stop work", + func(ep tcpip.Endpoint) { + ep.(interface{ StopWork() }).StopWork() + }, + func(ep tcpip.Endpoint) { + ep.(interface{ ResumeWork() }).ResumeWork() + }, }, - func(ep tcpip.Endpoint) { - ep.SocketOptions().SetCorkOption(false) + { + "cork", + func(ep tcpip.Endpoint) { + ep.SocketOptions().SetCorkOption(true) + }, + func(ep tcpip.Endpoint) { + ep.SocketOptions().SetCorkOption(false) + }, }, - }, - } + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* 
epRcvBuf */) + // Send tcp.InitialCwnd number of segments to fill up + // InitialWindow but don't ACK. That should prevent + // anymore packets from going out. + var r bytes.Reader + for i := 0; i < tcp.InitialCwnd; i++ { + r.Reset([]byte{0}) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write #%d failed: %s", i+1, err) + } + } - // Send tcp.InitialCwnd number of segments to fill up - // InitialWindow but don't ACK. That should prevent - // anymore packets from going out. - var r bytes.Reader - for i := 0; i < tcp.InitialCwnd; i++ { - r.Reset([]byte{0}) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %s", i+1, err) + // Now send the segments that should get merged as the congestion + // window is full and we won't be able to send any more packets. + var allData []byte + for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { + allData = append(allData, data...) + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write #%d failed: %s", i+1, err) + } } - } - // Now send the segments that should get merged as the congestion - // window is full and we won't be able to send any more packets. - var allData []byte - for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { - allData = append(allData, data...) - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %s", i+1, err) + // Check that we get tcp.InitialCwnd packets. 
+ iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + for i := 0; i < tcp.InitialCwnd; i++ { + b := c.GetPacket() + defer b.Release() + checker.IPv4(t, b, + checker.PayloadLen(header.TCPMinimumSize+1), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+uint32(i)+1), + checker.TCPAckNum(uint32(iss)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), + ), + ) } - } - // Check that we get tcp.InitialCwnd packets. - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - for i := 0; i < tcp.InitialCwnd; i++ { + // Acknowledge the data. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1 + 10), // 10 for the 10 bytes of payload. + RcvWnd: 30000, + }) + + // Check that data is received. b := c.GetPacket() defer b.Release() checker.IPv4(t, b, - checker.PayloadLen(header.TCPMinimumSize+1), + checker.PayloadLen(len(allData)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+uint32(i)+1), + checker.TCPSeqNum(uint32(c.IRS)+11), checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) - } - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1 + 10), // 10 for the 10 bytes of payload. - RcvWnd: 30000, + if got := b.AsSlice()[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, allData) { + t.Fatalf("got data = %v, want = %v", got, allData) + } + + // Acknowledge the data. 
+ c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(11 + seqnum.Size(len(allData))), + RcvWnd: 30000, + }) }) + } + }) +} + +func TestDelay(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + + c.EP.SocketOptions().SetDelayOption(true) + + var allData []byte + for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { + allData = append(allData, data...) + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write #%d failed: %s", i+1, err) + } + } + seq := c.IRS.Add(1) + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + for _, want := range [][]byte{allData[:1], allData[1:]} { // Check that data is received. b := c.GetPacket() defer b.Release() checker.IPv4(t, b, - checker.PayloadLen(len(allData)+header.TCPMinimumSize), + checker.PayloadLen(len(want)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+11), + checker.TCPSeqNum(uint32(seq)), checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) - if got := b.AsSlice()[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, allData) { - t.Fatalf("got data = %v, want = %v", got, allData) + if got := b.AsSlice()[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, want) { + t.Fatalf("got data = %v, want = %v", got, want) } + seq = seq.Add(seqnum.Size(len(want))) // Acknowledge the data. 
c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(11 + seqnum.Size(len(allData))), - RcvWnd: 30000, + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: seq, + RcvWnd: 30000, }) - }) - } + } + }) } -func TestDelay(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() +func TestUndelay(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - c.EP.SocketOptions().SetDelayOption(true) + c.EP.SocketOptions().SetDelayOption(true) - var allData []byte - for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { - allData = append(allData, data...) - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %s", i+1, err) + allData := [][]byte{{0}, {1, 2, 3}} + for i, data := range allData { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write #%d failed: %s", i+1, err) + } } - } - seq := c.IRS.Add(1) - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - for _, want := range [][]byte{allData[:1], allData[1:]} { + seq := c.IRS.Add(1) + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) // Check that data is received. 
- b := c.GetPacket() - defer b.Release() - checker.IPv4(t, b, - checker.PayloadLen(len(want)+header.TCPMinimumSize), + first := c.GetPacket() + defer first.Release() + checker.IPv4(t, first, + checker.PayloadLen(len(allData[0])+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(seq)), @@ -3433,840 +3661,826 @@ func TestDelay(t *testing.T) { ), ) - if got := b.AsSlice()[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, want) { - t.Fatalf("got data = %v, want = %v", got, want) + if got, want := first.AsSlice()[header.IPv4MinimumSize+header.TCPMinimumSize:], allData[0]; !bytes.Equal(got, want) { + t.Fatalf("got first packet's data = %v, want = %v", got, want) } - seq = seq.Add(seqnum.Size(len(want))) - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: seq, - RcvWnd: 30000, - }) - } -} + seq = seq.Add(seqnum.Size(len(allData[0]))) -func TestUndelay(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + // Check that we don't get the second packet yet. + c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond) - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + c.EP.SocketOptions().SetDelayOption(false) - c.EP.SocketOptions().SetDelayOption(true) + // Check that data is received. 
+ second := c.GetPacket() + defer second.Release() + checker.IPv4(t, second, + checker.PayloadLen(len(allData[1])+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(seq)), + checker.TCPAckNum(uint32(iss)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), + ), + ) - allData := [][]byte{{0}, {1, 2, 3}} - for i, data := range allData { - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %s", i+1, err) + if got, want := second.AsSlice()[header.IPv4MinimumSize+header.TCPMinimumSize:], allData[1]; !bytes.Equal(got, want) { + t.Fatalf("got second packet's data = %v, want = %v", got, want) } - } - seq := c.IRS.Add(1) - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - // Check that data is received. - first := c.GetPacket() - defer first.Release() - checker.IPv4(t, first, - checker.PayloadLen(len(allData[0])+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(seq)), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) + seq = seq.Add(seqnum.Size(len(allData[1]))) - if got, want := first.AsSlice()[header.IPv4MinimumSize+header.TCPMinimumSize:], allData[0]; !bytes.Equal(got, want) { - t.Fatalf("got first packet's data = %v, want = %v", got, want) - } + // Acknowledge the data. 
+ c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: seq, + RcvWnd: 30000, + }) + }) +} - seq = seq.Add(seqnum.Size(len(allData[0]))) +func TestMSSNotDelayed(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + tests := []struct { + name string + fn func(tcpip.Endpoint) + }{ + {"no-op", func(tcpip.Endpoint) {}}, + {"delay", func(ep tcpip.Endpoint) { ep.SocketOptions().SetDelayOption(true) }}, + {"cork", func(ep tcpip.Endpoint) { ep.SocketOptions().SetCorkOption(true) }}, + } - // Check that we don't get the second packet yet. - c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const maxPayload = 100 + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.EP.SocketOptions().SetDelayOption(false) + c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{ + header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), + }) - // Check that data is received. 
- second := c.GetPacket() - defer second.Release() - checker.IPv4(t, second, - checker.PayloadLen(len(allData[1])+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(seq)), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) + test.fn(c.EP) - if got, want := second.AsSlice()[header.IPv4MinimumSize+header.TCPMinimumSize:], allData[1]; !bytes.Equal(got, want) { - t.Fatalf("got second packet's data = %v, want = %v", got, want) - } + allData := [][]byte{{0}, make([]byte, maxPayload), make([]byte, maxPayload)} + for i, data := range allData { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write #%d failed: %s", i+1, err) + } + } - seq = seq.Add(seqnum.Size(len(allData[1]))) + seq := c.IRS.Add(1) + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + for i, data := range allData { + // Check that data is received. + packet := c.GetPacket() + defer packet.Release() + checker.IPv4(t, packet, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(seq)), + checker.TCPAckNum(uint32(iss)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), + ), + ) + + if got, want := packet.AsSlice()[header.IPv4MinimumSize+header.TCPMinimumSize:], data; !bytes.Equal(got, want) { + t.Fatalf("got packet #%d's data = %v, want = %v", i+1, got, want) + } - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: seq, - RcvWnd: 30000, + seq = seq.Add(seqnum.Size(len(data))) + } + + // Acknowledge the data. 
+ c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: seq, + RcvWnd: 30000, + }) + }) + } }) } -func TestMSSNotDelayed(t *testing.T) { - tests := []struct { - name string - fn func(tcpip.Endpoint) - }{ - {"no-op", func(tcpip.Endpoint) {}}, - {"delay", func(ep tcpip.Endpoint) { ep.SocketOptions().SetDelayOption(true) }}, - {"cork", func(ep tcpip.Endpoint) { ep.SocketOptions().SetCorkOption(true) }}, - } +func TestSendGreaterThanMTU(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + const maxPayload = 100 + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) + defer c.Cleanup() + + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + e2e.CheckBrokenUpWrite(t, c, maxPayload) + }) +} - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const maxPayload = 100 - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() +func TestDefaultTTL(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + for _, test := range []struct { + name string + protoNum tcpip.NetworkProtocolNumber + addr tcpip.Address + }{ + {"ipv4", ipv4.ProtocolNumber, context.TestAddr}, + {"ipv6", ipv6.ProtocolNumber, context.TestV6Addr}, + } { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, 65535) + defer c.Cleanup() - c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{ - header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), - }) + var err tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, test.protoNum, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } - test.fn(c.EP) + proto := c.Stack().NetworkProtocolInstance(test.protoNum) + if proto == nil { + t.Fatalf("c.s.NetworkProtocolInstance(flow.netProto()) did not return a protocol") + } - allData := 
[][]byte{{0}, make([]byte, maxPayload), make([]byte, maxPayload)} - for i, data := range allData { - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %s", i+1, err) + var initialDefaultTTL tcpip.DefaultTTLOption + if err := proto.Option(&initialDefaultTTL); err != nil { + t.Fatalf("proto.Option(&initialDefaultTTL) (%T) failed: %s", initialDefaultTTL, err) } - } - seq := c.IRS.Add(1) - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - for i, data := range allData { - // Check that data is received. - packet := c.GetPacket() - defer packet.Release() - checker.IPv4(t, packet, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(seq)), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - if got, want := packet.AsSlice()[header.IPv4MinimumSize+header.TCPMinimumSize:], data; !bytes.Equal(got, want) { - t.Fatalf("got packet #%d's data = %v, want = %v", i+1, got, want) - } - - seq = seq.Add(seqnum.Size(len(data))) - } - - // Acknowledge the data. 
- c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: seq, - RcvWnd: 30000, - }) - }) - } -} - -func TestSendGreaterThanMTU(t *testing.T) { - const maxPayload = 100 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - e2e.CheckBrokenUpWrite(t, c, maxPayload) -} - -func TestDefaultTTL(t *testing.T) { - for _, test := range []struct { - name string - protoNum tcpip.NetworkProtocolNumber - addr tcpip.Address - }{ - {"ipv4", ipv4.ProtocolNumber, context.TestAddr}, - {"ipv6", ipv6.ProtocolNumber, context.TestV6Addr}, - } { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, 65535) - defer c.Cleanup() - - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, test.protoNum, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - proto := c.Stack().NetworkProtocolInstance(test.protoNum) - if proto == nil { - t.Fatalf("c.s.NetworkProtocolInstance(flow.netProto()) did not return a protocol") - } - - var initialDefaultTTL tcpip.DefaultTTLOption - if err := proto.Option(&initialDefaultTTL); err != nil { - t.Fatalf("proto.Option(&initialDefaultTTL) (%T) failed: %s", initialDefaultTTL, err) - } - - { - err := c.EP.Connect(tcpip.FullAddress{Addr: test.addr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { - t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) + { + err := c.EP.Connect(tcpip.FullAddress{Addr: test.addr, Port: context.TestPort}) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("c.EP.Connect(...) 
mismatch (-want +got):\n%s", d) + } } - } - checkTTL := func(ttl uint8) { - if test.protoNum == ipv4.ProtocolNumber { - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TTL(ttl)) - } else { - v := c.GetV6Packet() - defer v.Release() - checker.IPv6(t, v, checker.TTL(ttl)) + checkTTL := func(ttl uint8) { + if test.protoNum == ipv4.ProtocolNumber { + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TTL(ttl)) + } else { + v := c.GetV6Packet() + defer v.Release() + checker.IPv6(t, v, checker.TTL(ttl)) + } } - } - // Receive SYN packet. - checkTTL(uint8(initialDefaultTTL)) + // Receive SYN packet. + checkTTL(uint8(initialDefaultTTL)) - newDefaultTTL := tcpip.DefaultTTLOption(initialDefaultTTL + 1) - if err := proto.SetOption(&newDefaultTTL); err != nil { - t.Fatalf("proto.SetOption(&%T(%d))) failed: %s", newDefaultTTL, newDefaultTTL, err) - } + newDefaultTTL := tcpip.DefaultTTLOption(initialDefaultTTL + 1) + if err := proto.SetOption(&newDefaultTTL); err != nil { + t.Fatalf("proto.SetOption(&%T(%d))) failed: %s", newDefaultTTL, newDefaultTTL, err) + } - // Receive retransmitted SYN packet. - checkTTL(uint8(newDefaultTTL)) - }) - } + // Receive retransmitted SYN packet. 
+ checkTTL(uint8(newDefaultTTL)) + }) + } + }) } func TestSetTTL(t *testing.T) { - for _, test := range []struct { - name string - protoNum tcpip.NetworkProtocolNumber - addr tcpip.Address - relevantOpt tcpip.SockOptInt - irrelevantOpt tcpip.SockOptInt - }{ - {"ipv4", ipv4.ProtocolNumber, context.TestAddr, tcpip.IPv4TTLOption, tcpip.IPv6HopLimitOption}, - {"ipv6", ipv6.ProtocolNumber, context.TestV6Addr, tcpip.IPv6HopLimitOption, tcpip.IPv4TTLOption}, - } { - t.Run(test.name, func(t *testing.T) { - for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} { - t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) { - c := context.New(t, 65535) - defer c.Cleanup() - - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, test.protoNum, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + for _, test := range []struct { + name string + protoNum tcpip.NetworkProtocolNumber + addr tcpip.Address + relevantOpt tcpip.SockOptInt + irrelevantOpt tcpip.SockOptInt + }{ + {"ipv4", ipv4.ProtocolNumber, context.TestAddr, tcpip.IPv4TTLOption, tcpip.IPv6HopLimitOption}, + {"ipv6", ipv6.ProtocolNumber, context.TestV6Addr, tcpip.IPv6HopLimitOption, tcpip.IPv4TTLOption}, + } { + t.Run(test.name, func(t *testing.T) { + for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} { + t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) { + c := context.New(t, 65535) + defer c.Cleanup() + + var err tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, test.protoNum, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } - if err := c.EP.SetSockOptInt(test.relevantOpt, int(wantTTL)); err != nil { - t.Fatalf("SetSockOptInt(%d, %d) failed: %s", test.relevantOpt, wantTTL, err) - } - // Set a different ttl/hoplimit for the unused protocol, showing that - // it does not affect the other protocol. 
- if err := c.EP.SetSockOptInt(test.irrelevantOpt, int(wantTTL+1)); err != nil { - t.Fatalf("SetSockOptInt(%d, %d) failed: %s", test.irrelevantOpt, wantTTL, err) - } + if err := c.EP.SetSockOptInt(test.relevantOpt, int(wantTTL)); err != nil { + t.Fatalf("SetSockOptInt(%d, %d) failed: %s", test.relevantOpt, wantTTL, err) + } + // Set a different ttl/hoplimit for the unused protocol, showing that + // it does not affect the other protocol. + if err := c.EP.SetSockOptInt(test.irrelevantOpt, int(wantTTL+1)); err != nil { + t.Fatalf("SetSockOptInt(%d, %d) failed: %s", test.irrelevantOpt, wantTTL, err) + } - { - err := c.EP.Connect(tcpip.FullAddress{Addr: test.addr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { - t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) + { + err := c.EP.Connect(tcpip.FullAddress{Addr: test.addr, Port: context.TestPort}) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) + } } - } - // Receive SYN packet. - if test.protoNum == ipv4.ProtocolNumber { - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TTL(wantTTL)) - } else { - v := c.GetV6Packet() - defer v.Release() - checker.IPv6(t, v, checker.TTL(wantTTL)) - } - }) - } - }) - } + // Receive SYN packet. + if test.protoNum == ipv4.ProtocolNumber { + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TTL(wantTTL)) + } else { + v := c.GetV6Packet() + defer v.Release() + checker.IPv6(t, v, checker.TTL(wantTTL)) + } + }) + } + }) + } + }) } func TestSendMSSLessThanOptionsSize(t *testing.T) { - const mss = 10 - const writeSize = 300 - c := context.New(t, 65535) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + const mss = 10 + const writeSize = 300 + c := context.New(t, 65535) + defer c.Cleanup() + + // The sizes of these options add up to 12. 
+ c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{ + header.TCPOptionMSS, 4, byte(mss / 256), byte(mss % 256), + header.TCPOptionTS, header.TCPOptionTSLength, 1, 2, 3, 4, 5, 6, 7, 8, + header.TCPOptionSACKPermitted, header.TCPOptionSackPermittedLength, + }) + e2e.CheckBrokenUpWrite(t, c, writeSize) - // The sizes of these options add up to 12. - c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{ - header.TCPOptionMSS, 4, byte(mss / 256), byte(mss % 256), - header.TCPOptionTS, header.TCPOptionTSLength, 1, 2, 3, 4, 5, 6, 7, 8, - header.TCPOptionSACKPermitted, header.TCPOptionSackPermittedLength, + var r bytes.Reader + r.Reset(make([]byte, writeSize)) + _, err := c.EP.Write(&r, tcpip.WriteOptions{}) + if err != nil { + t.Fatalf("Write failed: %s", err) + } }) - e2e.CheckBrokenUpWrite(t, c, writeSize) - - var r bytes.Reader - r.Reset(make([]byte, writeSize)) - _, err := c.EP.Write(&r, tcpip.WriteOptions{}) - if err != nil { - t.Fatalf("Write failed: %s", err) - } } func TestActiveSendMSSLessThanMTU(t *testing.T) { - const maxPayload = 100 - c := context.New(t, 65535) - defer c.Cleanup() - - c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{ - header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + const maxPayload = 100 + c := context.New(t, 65535) + defer c.Cleanup() + + c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{ + header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), + }) + e2e.CheckBrokenUpWrite(t, c, maxPayload) }) - e2e.CheckBrokenUpWrite(t, c, maxPayload) } func TestPassiveSendMSSLessThanMTU(t *testing.T) { - const maxPayload = 100 - const mtu = 1200 - c := context.New(t, mtu) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() 
+ const maxPayload = 100 + const mtu = 1200 + c := context.New(t, mtu) + defer c.Cleanup() + + // Create EP and start listening. + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + defer ep.Close() - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer ep.Close() + // Set the buffer size to a deterministic size so that we can check the + // window scaling option. + const rcvBufferSize = 0x20000 + ep.SocketOptions().SetReceiveBufferSize(rcvBufferSize*2, true /* notify */) - // Set the buffer size to a deterministic size so that we can check the - // window scaling option. - const rcvBufferSize = 0x20000 - ep.SocketOptions().SetReceiveBufferSize(rcvBufferSize*2, true /* notify */) + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } + // Do 3-way handshake. + c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) - // Do 3-way handshake. - c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) + wq.EventRegister(&we) + defer wq.EventUnregister(&we) - // Try to accept the connection. 
- we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) - wq.EventRegister(&we) - defer wq.EventUnregister(&we) + c.EP, _, err = ep.Accept(nil) + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) + } - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") } - } - // Check that data gets properly segmented. - e2e.CheckBrokenUpWrite(t, c, maxPayload) + // Check that data gets properly segmented. + e2e.CheckBrokenUpWrite(t, c, maxPayload) + }) } func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { - const maxPayload = 536 - const mtu = 2000 - c := context.New(t, mtu) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + const maxPayload = 536 + const mtu = 2000 + c := context.New(t, mtu) + defer c.Cleanup() + + opt := tcpip.TCPAlwaysUseSynCookies(true) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } - opt := tcpip.TCPAlwaysUseSynCookies(true) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) - } + // Create EP and start listening. + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + defer ep.Close() - // Create EP and start listening. 
- wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer ep.Close() + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } + // Do 3-way handshake. + c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) - // Do 3-way handshake. - c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) + wq.EventRegister(&we) + defer wq.EventUnregister(&we) - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) - wq.EventRegister(&we) - defer wq.EventUnregister(&we) + c.EP, _, err = ep.Accept(nil) + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) + } - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") } - } - // Check that data gets properly segmented. - e2e.CheckBrokenUpWrite(t, c, maxPayload) + // Check that data gets properly segmented. 
+ e2e.CheckBrokenUpWrite(t, c, maxPayload) + }) } func TestSynOptionsOnActiveConnect(t *testing.T) { - const mtu = 1400 - c := context.New(t, mtu) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + const mtu = 1400 + c := context.New(t, mtu) + defer c.Cleanup() + + // Create TCP endpoint. + var err tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } - // Set the buffer size to a deterministic size so that we can check the - // window scaling option. - const rcvBufferSize = 0x20000 - const wndScale = 3 - c.EP.SocketOptions().SetReceiveBufferSize(rcvBufferSize*2, true /* notify */) + // Set the buffer size to a deterministic size so that we can check the + // window scaling option. + const rcvBufferSize = 0x20000 + const wndScale = 3 + c.EP.SocketOptions().SetReceiveBufferSize(rcvBufferSize*2, true /* notify */) - // Start connection attempt. - we, ch := waiter.NewChannelEntry(waiter.WritableEvents) - c.WQ.EventRegister(&we) - defer c.WQ.EventUnregister(&we) + // Start connection attempt. + we, ch := waiter.NewChannelEntry(waiter.WritableEvents) + c.WQ.EventRegister(&we) + defer c.WQ.EventUnregister(&we) - { - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { - t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) + { + err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) + } } - } - // Receive SYN packet. 
- b := c.GetPacket() - defer b.Release() - mss := uint16(mtu - header.IPv4MinimumSize - header.TCPMinimumSize) - checker.IPv4(t, b, - checker.TCP( + // Receive SYN packet. + b := c.GetPacket() + defer b.Release() + mss := uint16(mtu - header.IPv4MinimumSize - header.TCPMinimumSize) + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}), + ), + ) + + tcpHdr := header.TCP(header.IPv4(b.AsSlice()).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + // Wait for retransmit. + time.Sleep(1 * time.Second) + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagSyn), + checker.SrcPort(tcpHdr.SourcePort()), + checker.TCPSeqNum(tcpHdr.SequenceNumber()), checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}), ), - ) - - tcpHdr := header.TCP(header.IPv4(b.AsSlice()).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - // Wait for retransmit. - time.Sleep(1 * time.Second) - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.SrcPort(tcpHdr.SourcePort()), - checker.TCPSeqNum(tcpHdr.SequenceNumber()), - checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}), - ), - ) + ) - // Send SYN-ACK. - iss := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: tcpHdr.DestinationPort(), - DstPort: tcpHdr.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) + // Send SYN-ACK. 
+ iss := seqnum.Value(context.TestInitialSequenceNumber) + c.SendPacket(nil, &context.Headers{ + SrcPort: tcpHdr.DestinationPort(), + DstPort: tcpHdr.SourcePort(), + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) - // Receive ACK packet. - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+1), - ), - ) + // Receive ACK packet. + v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+1), + ), + ) - // Wait for connection to be established. - select { - case <-ch: - if err := c.EP.LastError(); err != nil { - t.Fatalf("Connect failed: %s", err) + // Wait for connection to be established. + select { + case <-ch: + if err := c.EP.LastError(); err != nil { + t.Fatalf("Connect failed: %s", err) + } + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for connection") } - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for connection") - } + }) } func TestCloseListener(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - - // Create listener. - var wq waiter.Queue - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + // Create listener. 
+ var wq waiter.Queue + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } - if err := ep.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %s", err) - } + if err := ep.Bind(tcpip.FullAddress{}); err != nil { + t.Fatalf("Bind failed: %s", err) + } - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } - // Close the listener and measure how long it takes. - t0 := time.Now() - ep.Close() - if diff := time.Now().Sub(t0); diff > 3*time.Second { - t.Fatalf("Took too long to close: %s", diff) - } + // Close the listener and measure how long it takes. + t0 := time.Now() + ep.Close() + if diff := time.Now().Sub(t0); diff > 3*time.Second { + t.Fatalf("Took too long to close: %s", diff) + } + }) } func TestReceiveOnResetConnection(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - // Send RST segment. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagRst, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Try to read. - we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) - c.WQ.EventRegister(&we) - defer c.WQ.EventUnregister(&we) - -loop: - for { - switch _, err := c.EP.Read(io.Discard, tcpip.ReadOptions{}); err.(type) { - case *tcpip.ErrWouldBlock: - <-ch - // Expect the state to be StateError and subsequent Reads to fail with HardError. 
- _, err := c.EP.Read(io.Discard, tcpip.ReadOptions{}) - if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" { - t.Fatalf("c.EP.Read() mismatch (-want +got):\n%s", d) - } - break loop - case *tcpip.ErrConnectionReset: - break loop - default: - t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, &tcpip.ErrConnectionReset{}) - } - } + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - if tcp.EndpointState(c.EP.State()) != tcp.StateError { - t.Fatalf("got EP state is not StateError") - } + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + // Send RST segment. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagRst, + SeqNum: iss, + RcvWnd: 30000, + }) + + // Try to read. + we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) + c.WQ.EventRegister(&we) + defer c.WQ.EventUnregister(&we) - checkValid := func() []error { - var errors []error - if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 { - errors = append(errors, fmt.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got)) + loop: + for { + switch _, err := c.EP.Read(io.Discard, tcpip.ReadOptions{}); err.(type) { + case *tcpip.ErrWouldBlock: + <-ch + // Expect the state to be StateError and subsequent Reads to fail with HardError. 
+ _, err := c.EP.Read(io.Discard, tcpip.ReadOptions{}) + if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" { + t.Fatalf("c.EP.Read() mismatch (-want +got):\n%s", d) + } + break loop + case *tcpip.ErrConnectionReset: + break loop + default: + t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, &tcpip.ErrConnectionReset{}) + } } - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - errors = append(errors, fmt.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)) + + if tcp.EndpointState(c.EP.State()) != tcp.StateError { + t.Fatalf("got EP state is not StateError") } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - errors = append(errors, fmt.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)) + + checkValid := func() []error { + var errors []error + if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 { + errors = append(errors, fmt.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got)) + } + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + errors = append(errors, fmt.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)) + } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + errors = append(errors, fmt.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)) + } + return errors } - return errors - } - start := time.Now() - for time.Since(start) < time.Minute && len(checkValid()) > 0 { - time.Sleep(50 * time.Millisecond) - } - for _, err := range checkValid() { - t.Error(err) - } + start := time.Now() + for time.Since(start) < time.Minute && len(checkValid()) > 0 { + time.Sleep(50 * time.Millisecond) + } + for _, err := range checkValid() { + t.Error(err) + } + }) } func TestSendOnResetConnection(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + 
defer c.Cleanup() - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - // Send RST segment. - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagRst, - SeqNum: iss, - RcvWnd: 30000, - }) + // Send RST segment. + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagRst, + SeqNum: iss, + RcvWnd: 30000, + }) - // Wait for the RST to be received. - time.Sleep(1 * time.Second) + // Wait for the RST to be received. + time.Sleep(1 * time.Second) - // Try to write. - var r bytes.Reader - r.Reset(make([]byte, 10)) - _, err := c.EP.Write(&r, tcpip.WriteOptions{}) - if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" { - t.Fatalf("c.EP.Write(...) mismatch (-want +got):\n%s", d) - } + // Try to write. + var r bytes.Reader + r.Reset(make([]byte, 10)) + _, err := c.EP.Write(&r, tcpip.WriteOptions{}) + if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" { + t.Fatalf("c.EP.Write(...) mismatch (-want +got):\n%s", d) + } + }) } // TestMaxRetransmitsTimeout tests if the connection is timed out after // a segment has been retransmitted MaxRetries times. 
func TestMaxRetransmitsTimeout(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + const numRetries = 2 + opt := tcpip.TCPMaxRetriesOption(numRetries) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) + } - const numRetries = 2 - opt := tcpip.TCPMaxRetriesOption(numRetries) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } + // Wait for the connection to timeout after MaxRetries retransmits. + initRTO := time.Second + minRTOOpt := tcpip.TCPMinRTOOption(initRTO) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err) + } + c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */) - // Wait for the connection to timeout after MaxRetries retransmits. 
- initRTO := time.Second - minRTOOpt := tcpip.TCPMinRTOOption(initRTO) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err) - } - c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */) + waitEntry, notifyCh := waiter.NewChannelEntry(waiter.EventHUp) + c.WQ.EventRegister(&waitEntry) + defer c.WQ.EventUnregister(&waitEntry) + + var r bytes.Reader + r.Reset(make([]byte, 1)) + _, err := c.EP.Write(&r, tcpip.WriteOptions{}) + if err != nil { + t.Fatalf("Write failed: %s", err) + } - waitEntry, notifyCh := waiter.NewChannelEntry(waiter.EventHUp) - c.WQ.EventRegister(&waitEntry) - defer c.WQ.EventUnregister(&waitEntry) + // Expect first transmit and MaxRetries retransmits. + for i := 0; i < numRetries+1; i++ { + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh), + )) + } + select { + case <-notifyCh: + case <-time.After((2 << numRetries) * initRTO): + t.Fatalf("connection still alive after maximum retransmits.\n") + } - var r bytes.Reader - r.Reset(make([]byte, 1)) - _, err := c.EP.Write(&r, tcpip.WriteOptions{}) - if err != nil { - t.Fatalf("Write failed: %s", err) - } + // Send an ACK and expect a RST as the connection would have been closed. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + }) - // Expect first transmit and MaxRetries retransmits. 
- for i := 0; i < numRetries+1; i++ { v := c.GetPacket() defer v.Release() checker.IPv4(t, v, checker.TCP( checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh), + checker.TCPFlags(header.TCPFlagRst), )) - } - select { - case <-notifyCh: - case <-time.After((2 << numRetries) * initRTO): - t.Fatalf("connection still alive after maximum retransmits.\n") - } - // Send an ACK and expect a RST as the connection would have been closed. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, + if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 { + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got) + } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) + } }) - - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst), - )) - - if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 { - t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got) - } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) - } } // TestMaxRTO tests if the retransmit interval caps to MaxRTO. 
func TestMaxRTO(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - - rto := 1 * time.Second - minRTOOpt := tcpip.TCPMinRTOOption(rto / 2) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err) - } - maxRTOOpt := tcpip.TCPMaxRTOOption(rto) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &maxRTOOpt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, maxRTOOpt, maxRTOOpt, err) - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + rto := 1 * time.Second + minRTOOpt := tcpip.TCPMinRTOOption(rto / 2) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err) + } + maxRTOOpt := tcpip.TCPMaxRTOOption(rto) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &maxRTOOpt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, maxRTOOpt, maxRTOOpt, err) + } - c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */) - var r bytes.Reader - r.Reset(make([]byte, 1)) - _, err := c.EP.Write(&r, tcpip.WriteOptions{}) - if err != nil { - t.Fatalf("Write failed: %s", err) - } + var r bytes.Reader + r.Reset(make([]byte, 1)) + _, err := c.EP.Write(&r, tcpip.WriteOptions{}) + if err != nil { + t.Fatalf("Write failed: %s", err) + } - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - )) - const numRetransmits = 2 - for i := 0; i < numRetransmits; 
i++ { - start := time.Now() v := c.GetPacket() defer v.Release() checker.IPv4(t, v, checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), )) - if elapsed := time.Since(start); elapsed.Round(time.Second).Seconds() != rto.Seconds() { - newRto := float64(rto / time.Millisecond) - if i == 0 { - newRto /= 2 - } - curRto := float64(elapsed.Round(time.Millisecond).Milliseconds()) - if math.Abs(newRto-curRto) > 10 { - t.Errorf("Retransmit interval not capped to RTO(%v). %v", newRto, curRto) + const numRetransmits = 2 + for i := 0; i < numRetransmits; i++ { + start := time.Now() + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), + )) + if elapsed := time.Since(start); elapsed.Round(time.Second).Seconds() != rto.Seconds() { + newRto := float64(rto / time.Millisecond) + if i == 0 { + newRto /= 2 + } + curRto := float64(elapsed.Round(time.Millisecond).Milliseconds()) + if math.Abs(newRto-curRto) > 10 { + t.Errorf("Retransmit interval not capped to RTO(%v). %v", newRto, curRto) + } } } - } + }) } // TestZeroSizedWriteRetransmit tests that a zero sized write should not // result in a panic on an RTO as no segment should have been queued for // a zero sized write. func TestZeroSizedWriteRetransmit(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */) - var r bytes.Reader - _, err := c.EP.Write(&r, tcpip.WriteOptions{}) - if err != nil { - t.Fatalf("Write failed: %s", err) - } - // Now do a non-zero sized write to trigger actual sending of data. 
- r.Reset(make([]byte, 1)) - _, err = c.EP.Write(&r, tcpip.WriteOptions{}) - if err != nil { - t.Fatalf("Write failed: %s", err) - } - // Do not ACK the packet and expect an original transmit and a - // retransmit. This should not cause a panic. - for i := 0; i < 2; i++ { - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - )) - } + var r bytes.Reader + _, err := c.EP.Write(&r, tcpip.WriteOptions{}) + if err != nil { + t.Fatalf("Write failed: %s", err) + } + // Now do a non-zero sized write to trigger actual sending of data. + r.Reset(make([]byte, 1)) + _, err = c.EP.Write(&r, tcpip.WriteOptions{}) + if err != nil { + t.Fatalf("Write failed: %s", err) + } + // Do not ACK the packet and expect an original transmit and a + // retransmit. This should not cause a panic. + for i := 0; i < 2; i++ { + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), + )) + } + }) } // TestRetransmitIPv4IDUniqueness tests that the IPv4 Identification field is // unique on retransmits. 
func TestRetransmitIPv4IDUniqueness(t *testing.T) { - for _, tc := range []struct { - name string - size int - }{ - {"1Byte", 1}, - {"512Bytes", 512}, - } { - t.Run(tc.name, func(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + for _, tc := range []struct { + name string + size int + }{ + {"1Byte", 1}, + {"512Bytes", 512}, + } { + t.Run(tc.name, func(t *testing.T) { + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - minRTOOpt := tcpip.TCPMinRTOOption(time.Second) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err) - } - c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */) - - // Disabling PMTU discovery causes all packets sent from this socket to - // have DF=0. This needs to be done because the IPv4 ID uniqueness - // applies only to non-atomic IPv4 datagrams as defined in RFC 6864 - // Section 4, and datagrams with DF=0 are non-atomic. - if err := c.EP.SetSockOptInt(tcpip.MTUDiscoverOption, int(tcpip.PMTUDiscoveryDont)); err != nil { - t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err) - } + minRTOOpt := tcpip.TCPMinRTOOption(time.Second) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err) + } + c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */) + + // Disabling PMTU discovery causes all packets sent from this socket to + // have DF=0. This needs to be done because the IPv4 ID uniqueness + // applies only to non-atomic IPv4 datagrams as defined in RFC 6864 + // Section 4, and datagrams with DF=0 are non-atomic. 
+ if err := c.EP.SetSockOptInt(tcpip.MTUDiscoverOption, int(tcpip.PMTUDiscoveryDont)); err != nil { + t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err) + } - var r bytes.Reader - r.Reset(make([]byte, tc.size)) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - pkt := c.GetPacket() - defer pkt.Release() - checker.IPv4(t, pkt, - checker.FragmentFlags(0), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - idSet := map[uint16]struct{}{header.IPv4(pkt.AsSlice()).ID(): {}} - // Expect two retransmitted packets, and that all packets received have - // unique IPv4 ID values. - for i := 0; i <= 2; i++ { + var r bytes.Reader + r.Reset(make([]byte, tc.size)) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } pkt := c.GetPacket() defer pkt.Release() checker.IPv4(t, pkt, @@ -4276,541 +4490,576 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) { checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) - id := header.IPv4(pkt.AsSlice()).ID() - if _, exists := idSet[id]; exists { - t.Fatalf("duplicate IPv4 ID=%d found in retransmitted packet", id) + idSet := map[uint16]struct{}{header.IPv4(pkt.AsSlice()).ID(): {}} + // Expect two retransmitted packets, and that all packets received have + // unique IPv4 ID values. 
+ for i := 0; i <= 2; i++ { + pkt := c.GetPacket() + defer pkt.Release() + checker.IPv4(t, pkt, + checker.FragmentFlags(0), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), + ), + ) + id := header.IPv4(pkt.AsSlice()).ID() + if _, exists := idSet[id]; exists { + t.Fatalf("duplicate IPv4 ID=%d found in retransmitted packet", id) + } + idSet[id] = struct{}{} } - idSet[id] = struct{}{} - } - }) - } + }) + } + }) } func TestFinImmediately(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - // Shutdown immediately, check that we get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } + // Shutdown immediately, check that we get a FIN. + if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Shutdown failed: %s", err) + } - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) - // Ack and send FIN as well. 
- c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) + // Ack and send FIN as well. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) - // Check that the stack acks the FIN. - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagAck), - ), - ) + // Check that the stack acks the FIN. + v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + }) } func TestFinRetransmit(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - // Shutdown immediately, check that we get a FIN. 
- if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) + // Shutdown immediately, check that we get a FIN. + if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Shutdown failed: %s", err) + } - // Don't acknowledge yet. We should get a retransmit of the FIN. - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - // Ack and send FIN as well. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) - // Check that the stack acks the FIN. 
- v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagAck), - ), - ) + // Don't acknowledge yet. We should get a retransmit of the FIN. + v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + + // Ack and send FIN as well. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + // Check that the stack acks the FIN. + v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + }) } func TestFinWithNoPendingData(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - // Write something out, and have it acknowledged. - view := make([]byte, 10) - var r bytes.Reader - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } + // Write something out, and have it acknowledged. 
+ view := make([]byte, 10) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } - next := uint32(c.IRS) + 1 - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - next += uint32(len(view)) + next := uint32(c.IRS) + 1 + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(next), + checker.TCPAckNum(uint32(iss)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), + ), + ) + next += uint32(len(view)) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) - // Shutdown, check that we get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } + // Shutdown, check that we get a FIN. 
+ if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Shutdown failed: %s", err) + } - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - next++ + v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(next), + checker.TCPAckNum(uint32(iss)), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + next++ - // Ack and send FIN as well. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) + // Ack and send FIN as well. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) - // Check that the stack acks the FIN. - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagAck), - ), - ) + // Check that the stack acks the FIN. 
+ v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(next), + checker.TCPAckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + }) } func TestFinWithPendingDataCwndFull(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - // Write enough segments to fill the congestion window before ACK'ing - // any of them. - view := make([]byte, 10) - var r bytes.Reader - for i := tcp.InitialCwnd; i > 0; i-- { - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) + // Write enough segments to fill the congestion window before ACK'ing + // any of them. + view := make([]byte, 10) + var r bytes.Reader + for i := tcp.InitialCwnd; i > 0; i-- { + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + } + + next := uint32(c.IRS) + 1 + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + for i := tcp.InitialCwnd; i > 0; i-- { + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(next), + checker.TCPAckNum(uint32(iss)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), + ), + ) + next += uint32(len(view)) + } + + // Shutdown the connection, check that the FIN segment isn't sent + // because the congestion window doesn't allow it. Wait until a + // retransmit is received. 
+ if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Shutdown failed: %s", err) } - } - next := uint32(c.IRS) + 1 - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - for i := tcp.InitialCwnd; i > 0; i-- { v := c.GetPacket() defer v.Release() checker.IPv4(t, v, checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), + checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(uint32(iss)), checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) - next += uint32(len(view)) - } - - // Shutdown the connection, check that the FIN segment isn't sent - // because the congestion window doesn't allow it. Wait until a - // retransmit is received. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) + // Send the ACK that will allow the FIN to be sent as well. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) - // Send the ACK that will allow the FIN to be sent as well. 
- c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) + v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(next), + checker.TCPAckNum(uint32(iss)), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + next++ - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - next++ + // Send a FIN that acknowledges everything. Get an ACK back. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) - // Send a FIN that acknowledges everything. Get an ACK back. 
- c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss, - AckNum: seqnum.Value(next), - RcvWnd: 30000, + v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(next), + checker.TCPAckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagAck), + ), + ) }) - - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagAck), - ), - ) } func TestFinWithPendingData(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - // Write something out, and acknowledge it to get cwnd to 2. - view := make([]byte, 10) - var r bytes.Reader - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } + // Write something out, and acknowledge it to get cwnd to 2. 
+ view := make([]byte, 10) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } - next := uint32(c.IRS) + 1 - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - next += uint32(len(view)) + next := uint32(c.IRS) + 1 + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(next), + checker.TCPAckNum(uint32(iss)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), + ), + ) + next += uint32(len(view)) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) - // Write new data, but don't acknowledge it. - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } + // Write new data, but don't acknowledge it. 
+ r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - next += uint32(len(view)) + v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(next), + checker.TCPAckNum(uint32(iss)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), + ), + ) + next += uint32(len(view)) - // Shutdown the connection, check that we do get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } + // Shutdown the connection, check that we do get a FIN. + if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Shutdown failed: %s", err) + } - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - next++ + v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(next), + checker.TCPAckNum(uint32(iss)), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + next++ - // Send a FIN that acknowledges everything. Get an ACK back. 
- c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) + // Send a FIN that acknowledges everything. Get an ACK back. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagAck), - ), - ) + v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(next), + checker.TCPAckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + }) } func TestFinWithPartialAck(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - // Write something out, and acknowledge it to get cwnd to 2. Also send - // FIN from the test side. 
- view := make([]byte, 10) - var r bytes.Reader - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - next := uint32(c.IRS) + 1 - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - next += uint32(len(view)) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) + // Write something out, and acknowledge it to get cwnd to 2. Also send + // FIN from the test side. + view := make([]byte, 10) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } - // Check that we get an ACK for the fin. - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) + next := uint32(c.IRS) + 1 + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(next), + checker.TCPAckNum(uint32(iss)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), + ), + ) + next += uint32(len(view)) - // Write new data, but don't acknowledge it. 
- r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - next += uint32(len(view)) + // Check that we get an ACK for the fin. + v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(next), + checker.TCPAckNum(uint32(iss)+1), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), + ), + ) - // Shutdown the connection, check that we do get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } + // Write new data, but don't acknowledge it. 
+ r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - next++ + v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(next), + checker.TCPAckNum(uint32(iss)+1), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), + ), + ) + next += uint32(len(view)) - // Send an ACK for the data, but not for the FIN yet. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(1), - AckNum: seqnum.Value(next - 1), - RcvWnd: 30000, - }) + // Shutdown the connection, check that we do get a FIN. + if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Shutdown failed: %s", err) + } + + v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(next), + checker.TCPAckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + next++ - // Check that we don't get a retransmit of the FIN. - c.CheckNoPacketTimeout("FIN retransmitted when data was ack'd", 100*time.Millisecond) + // Send an ACK for the data, but not for the FIN yet. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss.Add(1), + AckNum: seqnum.Value(next - 1), + RcvWnd: 30000, + }) - // Ack the FIN. 
- c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss.Add(1), - AckNum: seqnum.Value(next), - RcvWnd: 30000, + // Check that we don't get a retransmit of the FIN. + c.CheckNoPacketTimeout("FIN retransmitted when data was ack'd", 100*time.Millisecond) + + // Ack the FIN. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss.Add(1), + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) }) } func TestUpdateListenBacklog(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - - // Create listener. - var wq waiter.Queue - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + // Create listener. + var wq waiter.Queue + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } - if err := ep.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %s", err) - } + if err := ep.Bind(tcpip.FullAddress{}); err != nil { + t.Fatalf("Bind failed: %s", err) + } - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } - // Update the backlog with another Listen() on the same endpoint. - if err := ep.Listen(20); err != nil { - t.Fatalf("Listen failed to update backlog: %s", err) - } + // Update the backlog with another Listen() on the same endpoint. 
+ if err := ep.Listen(20); err != nil { + t.Fatalf("Listen failed to update backlog: %s", err) + } - ep.Close() + ep.Close() + }) } func scaledSendWindow(t *testing.T, scale uint8) { @@ -4829,12 +5078,12 @@ func scaledSendWindow(t *testing.T, scale uint8) { // Open up the window with a scaled value. iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 1, + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: 1, }) // Send some data. Check that it's capped by the window size. @@ -4860,458 +5109,488 @@ func scaledSendWindow(t *testing.T, scale uint8) { // Reset the connection to free resources. c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagRst, - SeqNum: iss, + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagRst, + SeqNum: iss, }) } func TestScaledSendWindow(t *testing.T) { - for scale := uint8(0); scale <= 14; scale++ { - scaledSendWindow(t, scale) - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + for scale := uint8(0); scale <= 14; scale++ { + scaledSendWindow(t, scale) + } + }) } func TestReceivedValidSegmentCountIncrement(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - stats := c.Stack().Stats() - want := stats.TCP.ValidSegmentsReceived.Value() + 1 + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + stats := c.Stack().Stats() + want := stats.TCP.ValidSegmentsReceived.Value() + 1 - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket(nil, 
&context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) - if got := stats.TCP.ValidSegmentsReceived.Value(); got != want { - t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %d, want = %d", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).SegmentsReceived.Value(); got != want { - t.Errorf("got EP stats Stats.SegmentsReceived = %d, want = %d", got, want) - } - // Ensure there were no errors during handshake. If these stats have - // incremented, then the connection should not have been established. - if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 { - t.Errorf("got EP stats Stats.SendErrors.NoRoute = %d, want = %d", got, 0) - } + if got := stats.TCP.ValidSegmentsReceived.Value(); got != want { + t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %d, want = %d", got, want) + } + if got := c.EP.Stats().(*tcp.Stats).SegmentsReceived.Value(); got != want { + t.Errorf("got EP stats Stats.SegmentsReceived = %d, want = %d", got, want) + } + // Ensure there were no errors during handshake. If these stats have + // incremented, then the connection should not have been established. 
+ if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 { + t.Errorf("got EP stats Stats.SendErrors.NoRoute = %d, want = %d", got, 0) + } + }) } func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - stats := c.Stack().Stats() - want := stats.TCP.InvalidSegmentsReceived.Value() + 1 - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - buf := c.BuildSegment(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - defer buf.Release() - tcpbuf := buf.Flatten() - tcpbuf[header.IPv4MinimumSize+header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4 - - segbuf := buffer.MakeWithData(tcpbuf) - c.SendSegment(segbuf) - - if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want { - t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %d, want = %d", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want) - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + stats := c.Stack().Stats() + want := stats.TCP.InvalidSegmentsReceived.Value() + 1 + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + buf := c.BuildSegment(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + defer buf.Release() + tcpbuf := buf.Flatten() + tcpbuf[header.IPv4MinimumSize+header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4 + + segbuf := buffer.MakeWithData(tcpbuf) + 
c.SendSegment(segbuf) + + if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want { + t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %d, want = %d", got, want) + } + if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want) + } + }) } func TestReceivedIncorrectChecksumIncrement(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - stats := c.Stack().Stats() - want := stats.TCP.ChecksumErrors.Value() + 1 - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - buf := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - defer buf.Release() - tcpbuf := buf.Flatten() - // Overwrite a byte in the payload which should cause checksum - // verification to fail. 
- tcpbuf[header.IPv4MinimumSize+((tcpbuf[header.IPv4MinimumSize+header.TCPDataOffset]>>4)*4)] = 0x4 - - segbuf := buffer.MakeWithData(tcpbuf) - defer segbuf.Release() - c.SendSegment(buffer.MakeWithData(tcpbuf)) - - if got := stats.TCP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.TCP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ChecksumErrors.Value(); got != want { - t.Errorf("got EP stats Stats.ReceiveErrors.ChecksumErrors = %d, want = %d", got, want) - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + stats := c.Stack().Stats() + want := stats.TCP.ChecksumErrors.Value() + 1 + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + buf := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + defer buf.Release() + tcpbuf := buf.Flatten() + // Overwrite a byte in the payload which should cause checksum + // verification to fail. + tcpbuf[header.IPv4MinimumSize+((tcpbuf[header.IPv4MinimumSize+header.TCPDataOffset]>>4)*4)] = 0x4 + + segbuf := buffer.MakeWithData(tcpbuf) + defer segbuf.Release() + c.SendSegment(buffer.MakeWithData(tcpbuf)) + + if got := stats.TCP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.TCP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP stats Stats.ReceiveErrors.ChecksumErrors = %d, want = %d", got, want) + } + }) } func TestReceivedSegmentQueuing(t *testing.T) { + synctest. 
// This test sends 200 segments containing a few bytes each to an // endpoint and checks that they're all received and acknowledged by // the endpoint, that is, that none of the segments are dropped by // internal queues. - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + Test(t, func(t *testing.T) { + defer synctest.Wait() - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - // Send 200 segments. - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - data := []byte{1, 2, 3} - for i := 0; i < 200; i++ { - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(seqnum.Size(i * len(data))), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - } + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - // Receive ACKs for all segments. - last := iss.Add(seqnum.Size(200 * len(data))) - for { - b := c.GetPacket() - defer b.Release() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - tcpHdr := header.TCP(header.IPv4(b.AsSlice()).Payload()) - ack := seqnum.Value(tcpHdr.AckNumber()) - if ack == last { - break + // Send 200 segments. + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + data := []byte{1, 2, 3} + for i := 0; i < 200; i++ { + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss.Add(seqnum.Size(i * len(data))), + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) } - if last.LessThan(ack) { - t.Fatalf("Acknowledge (%v) beyond the expected (%v)", ack, last) + // Receive ACKs for all segments. 
+ last := iss.Add(seqnum.Size(200 * len(data))) + for { + b := c.GetPacket() + defer b.Release() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + tcpHdr := header.TCP(header.IPv4(b.AsSlice()).Payload()) + ack := seqnum.Value(tcpHdr.AckNumber()) + if ack == last { + break + } + + if last.LessThan(ack) { + t.Fatalf("Acknowledge (%v) beyond the expected (%v)", ack, last) + } + } + }) } func TestReadAfterClosedState(t *testing.T) { + synctest. // This test ensures that calling Read() or Peek() after the endpoint // has transitioned to closedState still works if there is pending // data. To transition to stateClosed without calling Close(), we must // shutdown the send path and the peer must send its own FIN. - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - - // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed - // after 1 second in TIME_WAIT state. - tcpTimeWaitTimeout := 1 * time.Second - opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } + Test(t, func(t *testing.T) { + defer synctest.Wait() + + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + // Set TCPTimeWaitTimeout to 1 second so that sockets are marked closed + // after 1 second in TIME_WAIT state. 
+ tcpTimeWaitTimeout := 1 * time.Second + opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) + } - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) - c.WQ.EventRegister(&we) - defer c.WQ.EventUnregister(&we) + we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) + c.WQ.EventRegister(&we) + defer c.WQ.EventUnregister(&we) - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) + ept := endpointTester{c.EP} + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - // Shutdown immediately for write, check that we get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } + // Shutdown immediately for write, check that we get a FIN. 
+ if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Shutdown failed: %s", err) + } - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - b := c.GetPacket() - defer b.Release() - checker.IPv4(t, b, - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + b := c.GetPacket() + defer b.Release() + checker.IPv4(t, b, + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) + } - // Send some data and acknowledge the FIN. - data := []byte{1, 2, 3} - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) + // Send some data and acknowledge the FIN. + data := []byte{1, 2, 3} + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) - // Check that ACK is received. - b = c.GetPacket() - defer b.Release() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(uint32(iss)+uint32(len(data))+1), - checker.TCPFlags(header.TCPFlagAck), - ), - ) + // Check that ACK is received. 
+ b = c.GetPacket() + defer b.Release() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(uint32(iss)+uint32(len(data))+1), + checker.TCPFlags(header.TCPFlagAck), + ), + ) - // Give the stack the chance to transition to closed state from - // TIME_WAIT. - time.Sleep(tcpTimeWaitTimeout * 2) + // Give the stack the chance to transition to closed state from + // TIME_WAIT. + time.Sleep(tcpTimeWaitTimeout * 2) - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want { + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) + } - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } + // Wait for receive to be notified. + select { + case <-ch: + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } - // Check that peek works. - var peekBuf bytes.Buffer - res, err := c.EP.Read(&peekBuf, tcpip.ReadOptions{Peek: true}) - if err != nil { - t.Fatalf("Peek failed: %s", err) - } + // Check that peek works. + var peekBuf bytes.Buffer + res, err := c.EP.Read(&peekBuf, tcpip.ReadOptions{Peek: true}) + if err != nil { + t.Fatalf("Peek failed: %s", err) + } - if got, want := res.Count, len(data); got != want { - t.Fatalf("res.Count = %d, want %d", got, want) - } - if !bytes.Equal(data, peekBuf.Bytes()) { - t.Fatalf("got data = %v, want = %v", peekBuf.Bytes(), data) - } + if got, want := res.Count, len(data); got != want { + t.Fatalf("res.Count = %d, want %d", got, want) + } + if !bytes.Equal(data, peekBuf.Bytes()) { + t.Fatalf("got data = %v, want = %v", peekBuf.Bytes(), data) + } - // Receive data. 
- v := ept.CheckRead(t) - if !bytes.Equal(data, v) { - t.Fatalf("got data = %v, want = %v", v, data) - } + // Receive data. + v := ept.CheckRead(t) + if !bytes.Equal(data, v) { + t.Fatalf("got data = %v, want = %v", v, data) + } - // Now that we drained the queue, check that functions fail with the - // right error code. - ept.CheckReadError(t, &tcpip.ErrClosedForReceive{}) - var buf bytes.Buffer - { - _, err := c.EP.Read(&buf, tcpip.ReadOptions{Peek: true}) - if d := cmp.Diff(&tcpip.ErrClosedForReceive{}, err); d != "" { - t.Fatalf("c.EP.Read(_, {Peek: true}) mismatch (-want +got):\n%s", d) + // Now that we drained the queue, check that functions fail with the + // right error code. + ept.CheckReadError(t, &tcpip.ErrClosedForReceive{}) + var buf bytes.Buffer + { + _, err := c.EP.Read(&buf, tcpip.ReadOptions{Peek: true}) + if d := cmp.Diff(&tcpip.ErrClosedForReceive{}, err); d != "" { + t.Fatalf("c.EP.Read(_, {Peek: true}) mismatch (-want +got):\n%s", d) + } } - } + }) } func TestReusePort(t *testing.T) { + synctest. // This test ensures that ports are immediately available for reuse // after Close on the endpoints using them returns. - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + Test(t, func(t *testing.T) { + defer synctest.Wait() - // First case, just an endpoint that was bound. 
- var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - c.EP.SocketOptions().SetReuseAddress(true) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.EP.Close() - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - c.EP.SocketOptions().SetReuseAddress(true) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - c.EP.Close() + // First case, just an endpoint that was bound. + var err tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %s", err) + } + c.EP.SocketOptions().SetReuseAddress(true) + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } - // Second case, an endpoint that was bound and is connecting.. - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - c.EP.SocketOptions().SetReuseAddress(true) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - { - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { - t.Fatalf("c.EP.Connect(...) 
mismatch (-want +got):\n%s", d) + c.EP.Close() + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %s", err) } - c.EP.Close() + c.EP.SocketOptions().SetReuseAddress(true) + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + c.EP.Close() - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - c.EP.SocketOptions().SetReuseAddress(true) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - c.EP.Close() + // Second case, an endpoint that was bound and is connecting. + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %s", err) + } + c.EP.SocketOptions().SetReuseAddress(true) + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + { + err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) + } + } + c.EP.Close() - // Third case, an endpoint that was bound and is listening. 
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - c.EP.SocketOptions().SetReuseAddress(true) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - c.EP.Close() + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %s", err) + } + c.EP.SocketOptions().SetReuseAddress(true) + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + c.EP.Close() - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - c.EP.SocketOptions().SetReuseAddress(true) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } + // Third case, an endpoint that was bound and is listening. 
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %s", err) + } + c.EP.SocketOptions().SetReuseAddress(true) + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + if err := c.EP.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + c.EP.Close() + + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %s", err) + } + c.EP.SocketOptions().SetReuseAddress(true) + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + if err := c.EP.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + }) } func TestTimeWaitAssassination(t *testing.T) { - var wg sync.WaitGroup - defer wg.Wait() - // We need to run this test lots of times because it triggers a very rare race - // condition in segment processing. - initalTestPort := 1024 - testRuns := 25 - for port := initalTestPort; port < initalTestPort+testRuns; port++ { - wg.Add(1) - go func(port uint16) { - defer wg.Done() - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - - twReuse := tcpip.TCPTimeWaitReuseOption(tcpip.TCPTimeWaitReuseGlobal) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil { - t.Errorf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &twReuse, err) - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + var wg sync.WaitGroup + defer wg.Wait() + // We need to run this test lots of times because it triggers a very rare race + // condition in segment processing. 
+ initalTestPort := 1024 + testRuns := 25 + for port := initalTestPort; port < initalTestPort+testRuns; port++ { + wg.Add(1) + go func(port uint16) { + defer wg.Done() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - if err := c.Stack().SetPortRange(port, port); err != nil { - t.Errorf("got s.SetPortRange(%d, %d) = %s, want = nil", port, port, err) - } + twReuse := tcpip.TCPTimeWaitReuseOption(tcpip.TCPTimeWaitReuseGlobal) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil { + t.Errorf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &twReuse, err) + } - iss := seqnum.Value(context.TestInitialSequenceNumber) - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1) - c.EP.Close() + if err := c.Stack().SetPortRange(port, port); err != nil { + t.Errorf("got s.SetPortRange(%d, %d) = %s, want = nil", port, port, err) + } - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.SrcPort(port), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - } + iss := seqnum.Value(context.TestInitialSequenceNumber) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1) + c.EP.Close() - c.SendPacket(nil, finHeaders) + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.SrcPort(port), + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } - // c.EP is in TIME_WAIT. 
We must allow for a second to pass before the - // new endpoint is allowed to take over the old endpoint's binding. - time.Sleep(time.Second) + c.SendPacket(nil, finHeaders) - seq := iss + 1 - ack := c.IRS + 2 + // c.EP is in TIME_WAIT. We must allow for a second to pass before the + // new endpoint is allowed to take over the old endpoint's binding. + time.Sleep(time.Second) - var wg sync.WaitGroup - defer wg.Wait() + seq := iss + 1 + ack := c.IRS + 2 - wg.Add(1) - go func() { - defer wg.Done() - // The new endpoint will take over the binding. - c.Create(-1) - timeout := time.After(5 * time.Second) - connect: - for { - select { - case <-timeout: - break connect - default: - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - // It can take some extra time for the port to be available. - if _, ok := err.(*tcpip.ErrNoPortAvailable); ok { - continue connect - } - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Errorf("Unexpected return value from Connect: %v", err) - } - break connect - } - } - }() + var wg sync.WaitGroup + defer wg.Wait() - // If the new endpoint does not properly transition to connecting before - // taking over the port reservation, sending acks will cause the processor - // to panic 1-5% of the time. - for i := 0; i < 5; i++ { wg.Add(1) go func() { defer wg.Done() - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: port, - Flags: header.TCPFlagAck, - SeqNum: seq, - AckNum: ack, - }) + // The new endpoint will take over the binding. + c.Create(-1) + timeout := time.After(5 * time.Second) + connect: + for { + select { + case <-timeout: + break connect + default: + err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + // It can take some extra time for the port to be available. 
+ if _, ok := err.(*tcpip.ErrNoPortAvailable); ok { + continue connect + } + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Errorf("Unexpected return value from Connect: %v", err) + } + break connect + } + } }() - } - }(uint16(port)) - } + + // If the new endpoint does not properly transition to connecting before + // taking over the port reservation, sending acks will cause the processor + // to panic 1-5% of the time. + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: port, + Flags: header.TCPFlagAck, + SeqNum: seq, + AckNum: ack, + }) + }() + } + }(uint16(port)) + } + }) } func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { @@ -5332,116 +5611,122 @@ func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { } func TestDefaultBufferSizes(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - }) - defer s.Destroy() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + }) + defer s.Destroy() - // Check the default values. - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - defer func() { - if ep != nil { - ep.Close() + // Check the default values. 
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %s", err) } - }() + defer func() { + if ep != nil { + ep.Close() + } + }() - checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize) - checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) + checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize) + checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) - // Change the default send buffer size. - { - opt := tcpip.TCPSendBufferSizeRangeOption{ - Min: 1, - Default: tcp.DefaultSendBufferSize * 2, - Max: tcp.DefaultSendBufferSize * 20, - } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + // Change the default send buffer size. + { + opt := tcpip.TCPSendBufferSizeRangeOption{ + Min: 1, + Default: tcp.DefaultSendBufferSize * 2, + Max: tcp.DefaultSendBufferSize * 20, + } + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } - } - ep.Close() - ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } + ep.Close() + ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %s", err) + } - checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2) - checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) + checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2) + checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) - // Change the default receive buffer size. 
- { - opt := tcpip.TCPReceiveBufferSizeRangeOption{ - Min: 1, - Default: tcp.DefaultReceiveBufferSize * 3, - Max: tcp.DefaultReceiveBufferSize * 30, - } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + // Change the default receive buffer size. + { + opt := tcpip.TCPReceiveBufferSizeRangeOption{ + Min: 1, + Default: tcp.DefaultReceiveBufferSize * 3, + Max: tcp.DefaultReceiveBufferSize * 30, + } + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } - } - ep.Close() - ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } + ep.Close() + ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %s", err) + } - checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2) - checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*3) + checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2) + checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*3) + }) } func TestBindToDeviceOption(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}}) - - defer s.Destroy() - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - defer ep.Close() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}}) + + defer s.Destroy() + ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, 
&waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %s", err) + } + defer ep.Close() - if err := s.CreateNIC(321, loopback.New()); err != nil { - t.Errorf("CreateNIC failed: %s", err) - } + if err := s.CreateNIC(321, loopback.New()); err != nil { + t.Errorf("CreateNIC failed: %s", err) + } - // nicIDPtr is used instead of taking the address of NICID literals, which is - // a compiler error. - nicIDPtr := func(s tcpip.NICID) *tcpip.NICID { - return &s - } + // nicIDPtr is used instead of taking the address of NICID literals, which is + // a compiler error. + nicIDPtr := func(s tcpip.NICID) *tcpip.NICID { + return &s + } - testActions := []struct { - name string - setBindToDevice *tcpip.NICID - setBindToDeviceError tcpip.Error - getBindToDevice int32 - }{ - {"GetDefaultValue", nil, nil, 0}, - {"BindToNonExistent", nicIDPtr(999), &tcpip.ErrUnknownDevice{}, 0}, - {"BindToExistent", nicIDPtr(321), nil, 321}, - {"UnbindToDevice", nicIDPtr(0), nil, 0}, - } - for _, testAction := range testActions { - t.Run(testAction.name, func(t *testing.T) { - if testAction.setBindToDevice != nil { - bindToDevice := int32(*testAction.setBindToDevice) - if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { - t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr) + testActions := []struct { + name string + setBindToDevice *tcpip.NICID + setBindToDeviceError tcpip.Error + getBindToDevice int32 + }{ + {"GetDefaultValue", nil, nil, 0}, + {"BindToNonExistent", nicIDPtr(999), &tcpip.ErrUnknownDevice{}, 0}, + {"BindToExistent", nicIDPtr(321), nil, 321}, + {"UnbindToDevice", nicIDPtr(0), nil, 0}, + } + for _, testAction := range testActions { + t.Run(testAction.name, func(t *testing.T) { + if testAction.setBindToDevice != nil { + bindToDevice := int32(*testAction.setBindToDevice) + if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), 
testAction.setBindToDeviceError; gotErr != wantErr { + t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr) + } } - } - bindToDevice := ep.SocketOptions().GetBindToDevice() - if bindToDevice != testAction.getBindToDevice { - t.Errorf("got bindToDevice = %d, want %d", bindToDevice, testAction.getBindToDevice) - } - }) - } + bindToDevice := ep.SocketOptions().GetBindToDevice() + if bindToDevice != testAction.getBindToDevice { + t.Errorf("got bindToDevice = %d, want %d", bindToDevice, testAction.getBindToDevice) + } + }) + } + }) } func makeStack() (*stack.Stack, tcpip.Error) { @@ -5450,7 +5735,7 @@ func makeStack() (*stack.Stack, tcpip.Error) { ipv4.NewProtocol, ipv6.NewProtocol, }, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) id := loopback.New() @@ -5463,15 +5748,15 @@ func makeStack() (*stack.Stack, tcpip.Error) { } for _, ct := range []struct { - number tcpip.NetworkProtocolNumber - addrWithPrefix tcpip.AddressWithPrefix + number tcpip.NetworkProtocolNumber + addrWithPrefix tcpip.AddressWithPrefix }{ {ipv4.ProtocolNumber, context.StackAddrWithPrefix}, {ipv6.ProtocolNumber, context.StackV6AddrWithPrefix}, } { protocolAddr := tcpip.ProtocolAddress{ - Protocol: ct.number, - AddressWithPrefix: ct.addrWithPrefix, + Protocol: ct.number, + AddressWithPrefix: ct.addrWithPrefix, } if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { return nil, err @@ -5480,12 +5765,12 @@ func makeStack() (*stack.Stack, tcpip.Error) { s.SetRouteTable([]tcpip.Route{ { - Destination: header.IPv4EmptySubnet, - NIC: 1, + Destination: header.IPv4EmptySubnet, + NIC: 1, }, { - Destination: header.IPv6EmptySubnet, - NIC: 1, + Destination: header.IPv6EmptySubnet, + NIC: 1, }, }) @@ -5493,623 +5778,654 @@ func makeStack() (*stack.Stack, tcpip.Error) { } func TestSelfConnect(t *testing.T) { + synctest. 
// This test ensures that intentional self-connects work. In particular, // it checks that if an endpoint binds to say 127.0.0.1:1000 then // connects to 127.0.0.1:1000, then it will be connected to itself, and // is able to send and receive data through the same endpoint. - s, err := makeStack() - if err != nil { - t.Fatal(err) - } - defer s.Destroy() + Test(t, func(t *testing.T) { + defer synctest.Wait() - var wq waiter.Queue - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer ep.Close() + s, err := makeStack() + if err != nil { + t.Fatal(err) + } + defer s.Destroy() - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } + var wq waiter.Queue + ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + defer ep.Close() - // Register for notification, then start connection attempt. - waitEntry, notifyCh := waiter.NewChannelEntry(waiter.WritableEvents) - wq.EventRegister(&waitEntry) - defer wq.EventUnregister(&waitEntry) + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + // Register for notification, then start connection attempt. + waitEntry, notifyCh := waiter.NewChannelEntry(waiter.WritableEvents) + wq.EventRegister(&waitEntry) + defer wq.EventUnregister(&waitEntry) - { - err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}) - if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { - t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) + { + err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("ep.Connect(...) 
mismatch (-want +got):\n%s", d) + } } - } - <-notifyCh - if err := ep.LastError(); err != nil { - t.Fatalf("Connect failed: %s", err) - } + <-notifyCh + if err := ep.LastError(); err != nil { + t.Fatalf("Connect failed: %s", err) + } - // Write something. - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } + // Write something. + data := []byte{1, 2, 3} + var r bytes.Reader + r.Reset(data) + if _, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } - // Read back what was written. - wq.EventUnregister(&waitEntry) - waitEntry, notifyCh = waiter.NewChannelEntry(waiter.ReadableEvents) - wq.EventRegister(&waitEntry) - ept := endpointTester{ep} - rd := ept.CheckReadFull(t, len(data), notifyCh, 5*time.Second) + // Read back what was written. + wq.EventUnregister(&waitEntry) + waitEntry, notifyCh = waiter.NewChannelEntry(waiter.ReadableEvents) + wq.EventRegister(&waitEntry) + ept := endpointTester{ep} + rd := ept.CheckReadFull(t, len(data), notifyCh, 5*time.Second) - if !bytes.Equal(data, rd) { - t.Fatalf("got data = %v, want = %v", rd, data) - } + if !bytes.Equal(data, rd) { + t.Fatalf("got data = %v, want = %v", rd, data) + } + }) } func TestConnectAvoidsBoundPorts(t *testing.T) { - addressTypes := func(t *testing.T, network string) []string { - switch network { - case "ipv4": - return []string{"v4"} - case "ipv6": - return []string{"v6"} - case "dual": - return []string{"v6", "mapped"} - default: - t.Fatalf("unknown network: '%s'", network) - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + addressTypes := func(t *testing.T, network string) []string { + switch network { + case "ipv4": + return []string{"v4"} + case "ipv6": + return []string{"v6"} + case "dual": + return []string{"v6", "mapped"} + default: + t.Fatalf("unknown network: '%s'", network) + } - panic("unreachable") - } + 
panic("unreachable") + } - address := func(t *testing.T, addressType string, isAny bool) tcpip.Address { - switch addressType { - case "v4": - if isAny { - return tcpip.Address{} - } - return context.StackAddr - case "v6": - if isAny { - return tcpip.Address{} - } - return context.StackV6Addr - case "mapped": - if isAny { - return context.V4MappedWildcardAddr + address := func(t *testing.T, addressType string, isAny bool) tcpip.Address { + switch addressType { + case "v4": + if isAny { + return tcpip.Address{} + } + return context.StackAddr + case "v6": + if isAny { + return tcpip.Address{} + } + return context.StackV6Addr + case "mapped": + if isAny { + return context.V4MappedWildcardAddr + } + return context.StackV4MappedAddr + default: + t.Fatalf("unknown address type: '%s'", addressType) } - return context.StackV4MappedAddr - default: - t.Fatalf("unknown address type: '%s'", addressType) + + panic("unreachable") } + // This test ensures that Endpoint.Connect doesn't select already-bound ports. + networks := []string{"ipv4", "ipv6", "dual"} + for _, exhaustedNetwork := range networks { + t.Run(fmt.Sprintf("exhaustedNetwork=%s", exhaustedNetwork), func(t *testing.T) { + for _, exhaustedAddressType := range addressTypes(t, exhaustedNetwork) { + t.Run(fmt.Sprintf("exhaustedAddressType=%s", exhaustedAddressType), func(t *testing.T) { + for _, isAny := range []bool{false, true} { + t.Run(fmt.Sprintf("isAny=%t", isAny), func(t *testing.T) { + for _, candidateNetwork := range networks { + t.Run(fmt.Sprintf("candidateNetwork=%s", candidateNetwork), func(t *testing.T) { + for _, candidateAddressType := range addressTypes(t, candidateNetwork) { + t.Run(fmt.Sprintf("candidateAddressType=%s", candidateAddressType), func(t *testing.T) { + s, err := makeStack() + if err != nil { + t.Fatal(err) + } + defer s.Destroy() - panic("unreachable") - } - // This test ensures that Endpoint.Connect doesn't select already-bound ports. 
- networks := []string{"ipv4", "ipv6", "dual"} - for _, exhaustedNetwork := range networks { - t.Run(fmt.Sprintf("exhaustedNetwork=%s", exhaustedNetwork), func(t *testing.T) { - for _, exhaustedAddressType := range addressTypes(t, exhaustedNetwork) { - t.Run(fmt.Sprintf("exhaustedAddressType=%s", exhaustedAddressType), func(t *testing.T) { - for _, isAny := range []bool{false, true} { - t.Run(fmt.Sprintf("isAny=%t", isAny), func(t *testing.T) { - for _, candidateNetwork := range networks { - t.Run(fmt.Sprintf("candidateNetwork=%s", candidateNetwork), func(t *testing.T) { - for _, candidateAddressType := range addressTypes(t, candidateNetwork) { - t.Run(fmt.Sprintf("candidateAddressType=%s", candidateAddressType), func(t *testing.T) { - s, err := makeStack() - if err != nil { - t.Fatal(err) - } - defer s.Destroy() - - var wq waiter.Queue - var eps []tcpip.Endpoint - defer func() { - for _, ep := range eps { - ep.Close() + var wq waiter.Queue + var eps []tcpip.Endpoint + defer func() { + for _, ep := range eps { + ep.Close() + } + }() + makeEP := func(network string) tcpip.Endpoint { + var networkProtocolNumber tcpip.NetworkProtocolNumber + switch network { + case "ipv4": + networkProtocolNumber = ipv4.ProtocolNumber + case "ipv6", "dual": + networkProtocolNumber = ipv6.ProtocolNumber + default: + t.Fatalf("unknown network: '%s'", network) + } + ep, err := s.NewEndpoint(tcp.ProtocolNumber, networkProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + eps = append(eps, ep) + switch network { + case "ipv4": + case "ipv6": + ep.SocketOptions().SetV6Only(true) + case "dual": + ep.SocketOptions().SetV6Only(false) + default: + t.Fatalf("unknown network: '%s'", network) + } + return ep } - }() - makeEP := func(network string) tcpip.Endpoint { - var networkProtocolNumber tcpip.NetworkProtocolNumber - switch network { - case "ipv4": - networkProtocolNumber = ipv4.ProtocolNumber - case "ipv6", "dual": - networkProtocolNumber = 
ipv6.ProtocolNumber + + var v4reserved, v6reserved bool + switch exhaustedAddressType { + case "v4", "mapped": + v4reserved = true + case "v6": + v6reserved = true + // Dual stack sockets bound to v6 any reserve on v4 as + // well. + if isAny { + switch exhaustedNetwork { + case "ipv6": + case "dual": + v4reserved = true + default: + t.Fatalf("unknown address type: '%s'", exhaustedNetwork) + } + } default: - t.Fatalf("unknown network: '%s'", network) - } - ep, err := s.NewEndpoint(tcp.ProtocolNumber, networkProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) + t.Fatalf("unknown address type: '%s'", exhaustedAddressType) } - eps = append(eps, ep) - switch network { - case "ipv4": - case "ipv6": - ep.SocketOptions().SetV6Only(true) - case "dual": - ep.SocketOptions().SetV6Only(false) + var collides bool + switch candidateAddressType { + case "v4", "mapped": + collides = v4reserved + case "v6": + collides = v6reserved default: - t.Fatalf("unknown network: '%s'", network) + t.Fatalf("unknown address type: '%s'", candidateAddressType) } - return ep - } - - var v4reserved, v6reserved bool - switch exhaustedAddressType { - case "v4", "mapped": - v4reserved = true - case "v6": - v6reserved = true - // Dual stack sockets bound to v6 any reserve on v4 as - // well. 
- if isAny { - switch exhaustedNetwork { - case "ipv6": - case "dual": - v4reserved = true - default: - t.Fatalf("unknown address type: '%s'", exhaustedNetwork) + + const ( + start = 16000 + end = 16050 + ) + if err := s.SetPortRange(start, end); err != nil { + t.Fatalf("got s.SetPortRange(%d, %d) = %s, want = nil", start, end, err) + } + for i := start; i <= end; i++ { + if err := makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil { + t.Fatalf("Bind(%d) failed: %s", i, err) } } - default: - t.Fatalf("unknown address type: '%s'", exhaustedAddressType) - } - var collides bool - switch candidateAddressType { - case "v4", "mapped": - collides = v4reserved - case "v6": - collides = v6reserved - default: - t.Fatalf("unknown address type: '%s'", candidateAddressType) - } - - const ( - start = 16000 - end = 16050 - ) - if err := s.SetPortRange(start, end); err != nil { - t.Fatalf("got s.SetPortRange(%d, %d) = %s, want = nil", start, end, err) - } - for i := start; i <= end; i++ { - if err := makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil { - t.Fatalf("Bind(%d) failed: %s", i, err) + var want tcpip.Error = &tcpip.ErrConnectStarted{} + if collides { + want = &tcpip.ErrNoPortAvailable{} } - } - var want tcpip.Error = &tcpip.ErrConnectStarted{} - if collides { - want = &tcpip.ErrNoPortAvailable{} - } - if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want { - t.Fatalf("got ep.Connect(..) = %s, want = %s", err, want) - } - }) - } - }) - } - }) - } - }) - } - }) - } + if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want { + t.Fatalf("got ep.Connect(..) 
= %s, want = %s", err, want) + } + }) + } + }) + } + }) + } + }) + } + }) + } + }) } func TestPathMTUDiscovery(t *testing.T) { + synctest. // This test verifies the stack retransmits packets after it receives an // ICMP packet indicating that the path MTU has been exceeded. - c := context.New(t, 1500) - defer c.Cleanup() - - // Create new connection with MSS of 1460. - const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize - c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{ - header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), - }) - - // Send 3200 bytes of data. - const writeSize = 3200 - data := make([]byte, writeSize) - for i := range data { - data[i] = byte(i) - } - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - receivePackets := func(c *context.Context, sizes []int, which int, seqNum uint32) *buffer.View { - var ret *buffer.View - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - for i, size := range sizes { - p := c.GetPacket() - if i == which { - ret = p - } else { - defer p.Release() - } - checker.IPv4(t, p, - checker.PayloadLen(size+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(seqNum), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - seqNum += uint32(size) - } - return ret - } - - // Receive three packets. - sizes := []int{maxPayload, maxPayload, writeSize - 2*maxPayload} - first := receivePackets(c, sizes, 0, uint32(c.IRS)+1) - defer first.Release() + Test(t, func(t *testing.T) { + defer synctest.Wait() - // Send "packet too big" messages back to netstack. 
- const newMTU = 1200 - const newMaxPayload = newMTU - header.IPv4MinimumSize - header.TCPMinimumSize - mtu := buffer.NewViewWithData([]byte{0, 0, newMTU / 256, newMTU % 256}) - defer mtu.Release() - c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4FragmentationNeeded, mtu, first, newMTU) + c := context.New(t, 1500) + defer c.Cleanup() - // See retransmitted packets. None exceeding the new max. - sizes = []int{newMaxPayload, maxPayload - newMaxPayload, newMaxPayload, maxPayload - newMaxPayload, writeSize - 2*maxPayload} - receivePackets(c, sizes, -1, uint32(c.IRS)+1) -} + // Create new connection with MSS of 1460. + const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize + c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{ + header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), + }) -func TestTCPEndpointProbe(t *testing.T) { - invoked := make(chan struct{}) - var port uint16 - probe := func(state *tcp.TCPEndpointState) { - // Validate that the endpoint ID is what we expect. - // - // We don't do an extensive validation of every field but a - // basic sanity test. - if got, want := state.ID.LocalAddress, tcpip.Address(context.StackAddr); got != want { - t.Fatalf("got LocalAddress: %q, want: %q", got, want) + // Send 3200 bytes of data. 
+ const writeSize = 3200 + data := make([]byte, writeSize) + for i := range data { + data[i] = byte(i) } - if got, want := state.ID.LocalPort, port; got != want { - t.Fatalf("got LocalPort: %d, want: %d", got, want) - } - if got, want := state.ID.RemoteAddress, tcpip.Address(context.TestAddr); got != want { - t.Fatalf("got RemoteAddress: %q, want: %q", got, want) - } - if got, want := state.ID.RemotePort, uint16(context.TestPort); got != want { - t.Fatalf("got RemotePort: %d, want: %d", got, want) + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) } - invoked <- struct{}{} - } - - c := context.NewWithProbe(t, 1500, probe) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - port = c.Port // c.Port is set during CreateConnected. + receivePackets := func(c *context.Context, sizes []int, which int, seqNum uint32) *buffer.View { + var ret *buffer.View + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + for i, size := range sizes { + p := c.GetPacket() + if i == which { + ret = p + } else { + defer p.Release() + } + checker.IPv4(t, p, + checker.PayloadLen(size+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(seqNum), + checker.TCPAckNum(uint32(iss)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), + ), + ) + seqNum += uint32(size) + } + return ret + } - data := []byte{1, 2, 3} - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - select { - case <-invoked: - case <-time.After(100 * time.Millisecond): - t.Fatalf("TCP Probe function was not called") - } + // Receive three packets. 
+ sizes := []int{maxPayload, maxPayload, writeSize - 2*maxPayload} + first := receivePackets(c, sizes, 0, uint32(c.IRS)+1) + defer first.Release() + + // Send "packet too big" messages back to netstack. + const newMTU = 1200 + const newMaxPayload = newMTU - header.IPv4MinimumSize - header.TCPMinimumSize + mtu := buffer.NewViewWithData([]byte{0, 0, newMTU / 256, newMTU % 256}) + defer mtu.Release() + c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4FragmentationNeeded, mtu, first, newMTU) + + // See retransmitted packets. None exceeding the new max. + sizes = []int{newMaxPayload, maxPayload - newMaxPayload, newMaxPayload, maxPayload - newMaxPayload, writeSize - 2*maxPayload} + receivePackets(c, sizes, -1, uint32(c.IRS)+1) + }) } -func TestStackSetCongestionControl(t *testing.T) { - testCases := []struct { - cc tcpip.CongestionControlOption - err tcpip.Error - }{ - {"reno", nil}, - {"cubic", nil}, - {"blahblah", &tcpip.ErrNoSuchFile{}}, - } - - for _, tc := range testCases { - t.Run(fmt.Sprintf("SetTransportProtocolOption(.., %v)", tc.cc), func(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - s := c.Stack() - - var oldCC tcpip.CongestionControlOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &oldCC); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %s", tcp.ProtocolNumber, &oldCC, err) - } - - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &tc.cc); err != tc.err { - t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = %s, want = %s", tcp.ProtocolNumber, tc.cc, tc.cc, err, tc.err) +func TestTCPEndpointProbe(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + invoked := make(chan struct{}) + var port uint16 + probe := func(state *tcp.TCPEndpointState) { + // Validate that the endpoint ID is what we expect. + // + // We don't do an extensive validation of every field but a + // basic sanity test. 
+ if got, want := state.ID.LocalAddress, tcpip.Address(context.StackAddr); got != want { + t.Fatalf("got LocalAddress: %q, want: %q", got, want) } - - var cc tcpip.CongestionControlOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) + if got, want := state.ID.LocalPort, port; got != want { + t.Fatalf("got LocalPort: %d, want: %d", got, want) } - - got, want := cc, oldCC - // If SetTransportProtocolOption is expected to succeed - // then the returned value for congestion control should - // match the one specified in the - // SetTransportProtocolOption call above, else it should - // be what it was before the call to - // SetTransportProtocolOption. - if tc.err == nil { - want = tc.cc + if got, want := state.ID.RemoteAddress, tcpip.Address(context.TestAddr); got != want { + t.Fatalf("got RemoteAddress: %q, want: %q", got, want) } - if got != want { - t.Fatalf("got congestion control: %v, want: %v", got, want) + if got, want := state.ID.RemotePort, uint16(context.TestPort); got != want { + t.Fatalf("got RemotePort: %d, want: %d", got, want) } - }) - } -} - -func TestStackAvailableCongestionControl(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - s := c.Stack() - // Query permitted congestion control algorithms. 
- var aCC tcpip.TCPAvailableCongestionControlOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err) - } - if got, want := aCC, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcpip.TCPAvailableCongestionControlOption: %v, want: %v", got, want) - } -} + invoked <- struct{}{} + } -func TestStackSetAvailableCongestionControl(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() + c := context.NewWithProbe(t, 1500, probe) + defer c.Cleanup() - s := c.Stack() + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + port = c.Port // c.Port is set during CreateConnected. - // Setting AvailableCongestionControlOption should fail. - aCC := tcpip.TCPAvailableCongestionControlOption("xyz") - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil { - t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = nil, want non-nil", tcp.ProtocolNumber, aCC, aCC) - } + data := []byte{1, 2, 3} + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) - // Verify that we still get the expected list of congestion control options. 
- var cc tcpip.TCPAvailableCongestionControlOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { - t.Fatalf("s.TransportProtocolOptio(%d, &%T(%s)): %s", tcp.ProtocolNumber, cc, cc, err) - } - if got, want := cc, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcpip.TCPAvailableCongestionControlOption = %s, want = %s", got, want) - } + select { + case <-invoked: + case <-time.After(100 * time.Millisecond): + t.Fatalf("TCP Probe function was not called") + } + }) } -func TestEndpointSetCongestionControl(t *testing.T) { - testCases := []struct { - cc tcpip.CongestionControlOption - err tcpip.Error - }{ - {"reno", nil}, - {"cubic", nil}, - {"blahblah", &tcpip.ErrNoSuchFile{}}, - } +func TestStackSetCongestionControl(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + testCases := []struct { + cc tcpip.CongestionControlOption + err tcpip.Error + }{ + {"reno", nil}, + {"cubic", nil}, + {"blahblah", &tcpip.ErrNoSuchFile{}}, + } - for _, connected := range []bool{false, true} { for _, tc := range testCases { - t.Run(fmt.Sprintf("SetSockOpt(.., %v) w/ connected = %v", tc.cc, connected), func(t *testing.T) { + t.Run(fmt.Sprintf("SetTransportProtocolOption(.., %v)", tc.cc), func(t *testing.T) { c := context.New(t, 1500) - defer c.Cleanup() - - // Create TCP endpoint. 
- var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } + defer c.Cleanup() - var oldCC tcpip.CongestionControlOption - if err := c.EP.GetSockOpt(&oldCC); err != nil { - t.Fatalf("c.EP.GetSockOpt(&%T) = %s", oldCC, err) - } + s := c.Stack() - if connected { - c.Connect(context.TestInitialSequenceNumber, 32768 /* rcvWnd */, nil) + var oldCC tcpip.CongestionControlOption + if err := s.TransportProtocolOption(tcp.ProtocolNumber, &oldCC); err != nil { + t.Fatalf("s.TransportProtocolOption(%v, %v) = %s", tcp.ProtocolNumber, &oldCC, err) } - if err := c.EP.SetSockOpt(&tc.cc); err != tc.err { - t.Fatalf("got c.EP.SetSockOpt(&%#v) = %s, want %s", tc.cc, err, tc.err) + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &tc.cc); err != tc.err { + t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = %s, want = %s", tcp.ProtocolNumber, tc.cc, tc.cc, err, tc.err) } var cc tcpip.CongestionControlOption - if err := c.EP.GetSockOpt(&cc); err != nil { - t.Fatalf("c.EP.GetSockOpt(&%T): %s", cc, err) + if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { + t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) } got, want := cc, oldCC - // If SetSockOpt is expected to succeed then the - // returned value for congestion control should match - // the one specified in the SetSockOpt above, else it - // should be what it was before the call to SetSockOpt. + // If SetTransportProtocolOption is expected to succeed + // then the returned value for congestion control should + // match the one specified in the + // SetTransportProtocolOption call above, else it should + // be what it was before the call to + // SetTransportProtocolOption. 
if tc.err == nil { want = tc.cc } if got != want { - t.Fatalf("got congestion control = %+v, want = %+v", got, want) + t.Fatalf("got congestion control: %v, want: %v", got, want) } }) } - } + }) +} + +func TestStackAvailableCongestionControl(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, 1500) + defer c.Cleanup() + + s := c.Stack() + + // Query permitted congestion control algorithms. + var aCC tcpip.TCPAvailableCongestionControlOption + if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil { + t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err) + } + if got, want := aCC, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want { + t.Fatalf("got tcpip.TCPAvailableCongestionControlOption: %v, want: %v", got, want) + } + }) +} + +func TestStackSetAvailableCongestionControl(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, 1500) + defer c.Cleanup() + + s := c.Stack() + + // Setting AvailableCongestionControlOption should fail. + aCC := tcpip.TCPAvailableCongestionControlOption("xyz") + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil { + t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = nil, want non-nil", tcp.ProtocolNumber, aCC, aCC) + } + + // Verify that we still get the expected list of congestion control options. 
+ var cc tcpip.TCPAvailableCongestionControlOption
+ if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil {
+ t.Fatalf("s.TransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, cc, cc, err)
+ }
+ if got, want := cc, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want {
+ t.Fatalf("got tcpip.TCPAvailableCongestionControlOption = %s, want = %s", got, want)
+ }
+ })
+}
+
+func TestEndpointSetCongestionControl(t *testing.T) {
+ synctest.Test(t, func(t *testing.T) {
+ defer synctest.Wait()
+ testCases := []struct {
+ cc tcpip.CongestionControlOption
+ err tcpip.Error
+ }{
+ {"reno", nil},
+ {"cubic", nil},
+ {"blahblah", &tcpip.ErrNoSuchFile{}},
+ }
+
+ for _, connected := range []bool{false, true} {
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("SetSockOpt(.., %v) w/ connected = %v", tc.cc, connected), func(t *testing.T) {
+ c := context.New(t, 1500)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+
+ var oldCC tcpip.CongestionControlOption
+ if err := c.EP.GetSockOpt(&oldCC); err != nil {
+ t.Fatalf("c.EP.GetSockOpt(&%T) = %s", oldCC, err)
+ }
+
+ if connected {
+ c.Connect(context.TestInitialSequenceNumber, 32768 /* rcvWnd */, nil)
+ }
+
+ if err := c.EP.SetSockOpt(&tc.cc); err != tc.err {
+ t.Fatalf("got c.EP.SetSockOpt(&%#v) = %s, want %s", tc.cc, err, tc.err)
+ }
+
+ var cc tcpip.CongestionControlOption
+ if err := c.EP.GetSockOpt(&cc); err != nil {
+ t.Fatalf("c.EP.GetSockOpt(&%T): %s", cc, err)
+ }
+
+ got, want := cc, oldCC
+ // If SetSockOpt is expected to succeed then the
+ // returned value for congestion control should match
+ // the one specified in the SetSockOpt above, else it
+ // should be what it was before the call to SetSockOpt. 
+ if tc.err == nil { + want = tc.cc + } + if got != want { + t.Fatalf("got congestion control = %+v, want = %+v", got, want) + } + }) + } + } + }) } func TestKeepalive(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + + const keepAliveIdle = 100 * time.Millisecond + const keepAliveInterval = 3 * time.Second + keepAliveIdleOpt := tcpip.KeepaliveIdleOption(keepAliveIdle) + if err := c.EP.SetSockOpt(&keepAliveIdleOpt); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIdleOpt, keepAliveIdle, err) + } + keepAliveIntervalOpt := tcpip.KeepaliveIntervalOption(keepAliveInterval) + if err := c.EP.SetSockOpt(&keepAliveIntervalOpt); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIntervalOpt, keepAliveInterval, err) + } + c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5) + if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5); err != nil { + t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5): %s", err) + } + c.EP.SocketOptions().SetKeepAlive(true) + + // 5 unacked keepalives are sent. ACK each one, and check that the + // connection stays alive after 5. + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + for i := 0; i < 10; i++ { + b := c.GetPacket() + defer b.Release() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)), + checker.TCPAckNum(uint32(iss)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + // Acknowledge the keepalive. 
+ c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS, + RcvWnd: 30000, + }) + } - const keepAliveIdle = 100 * time.Millisecond - const keepAliveInterval = 3 * time.Second - keepAliveIdleOpt := tcpip.KeepaliveIdleOption(keepAliveIdle) - if err := c.EP.SetSockOpt(&keepAliveIdleOpt); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIdleOpt, keepAliveIdle, err) - } - keepAliveIntervalOpt := tcpip.KeepaliveIntervalOption(keepAliveInterval) - if err := c.EP.SetSockOpt(&keepAliveIntervalOpt); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIntervalOpt, keepAliveInterval, err) - } - c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5) - if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5); err != nil { - t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5): %s", err) - } - c.EP.SocketOptions().SetKeepAlive(true) + // Check that the connection is still alive. + ept := endpointTester{c.EP} + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - // 5 unacked keepalives are sent. ACK each one, and check that the - // connection stays alive after 5. - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - for i := 0; i < 10; i++ { + // Send some data and wait before ACKing it. Keepalives should be disabled + // during this period. + view := make([]byte, 3) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + next := uint32(c.IRS) + 1 b := c.GetPacket() defer b.Release() checker.IPv4(t, b, + checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)), + checker.TCPSeqNum(next), checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) - // Acknowledge the keepalive. 
- c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS, - RcvWnd: 30000, - }) - } - - // Check that the connection is still alive. - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - // Send some data and wait before ACKing it. Keepalives should be disabled - // during this period. - view := make([]byte, 3) - var r bytes.Reader - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - next := uint32(c.IRS) + 1 - b := c.GetPacket() - defer b.Release() - checker.IPv4(t, b, - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - // Wait for the packet to be retransmitted. Verify that no keepalives - // were sent. - b = c.GetPacket() - defer b.Release() - checker.IPv4(t, b, - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh), - ), - ) - c.CheckNoPacket("Keepalive packet received while unACKed data is pending") - - next += uint32(len(view)) - - // Send ACK. Keepalives should start sending again. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - // Now receive 5 keepalives, but don't ACK them. The connection - // should be reset after 5. - for i := 0; i < 5; i++ { + // Wait for the packet to be retransmitted. Verify that no keepalives + // were sent. 
b = c.GetPacket() defer b.Release() checker.IPv4(t, b, + checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPSeqNum(next-1), + checker.TCPSeqNum(next), checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh), ), ) - } + c.CheckNoPacket("Keepalive packet received while unACKed data is pending") - // Sleep for a little over the KeepAlive interval to make sure - // the timer has time to fire after the last ACK and close the - // close the socket. - time.Sleep(keepAliveInterval + keepAliveInterval/2) + next += uint32(len(view)) - // The connection should be terminated after 5 unacked keepalives. - // Send an ACK to trigger a RST from the stack as the endpoint should - // be dead. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) + // Send ACK. Keepalives should start sending again. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.TCP(checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst)), - ) + // Now receive 5 keepalives, but don't ACK them. The connection + // should be reset after 5. 
+ for i := 0; i < 5; i++ {
+ b = c.GetPacket()
+ defer b.Release()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPSeqNum(next-1),
+ checker.TCPAckNum(uint32(iss)),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+ }

- if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got)
- }
+ // Sleep for a little over the KeepAlive interval to make sure
+ // the timer has time to fire after the last ACK and close
+ // the socket.
+ time.Sleep(keepAliveInterval + keepAliveInterval/2)

- ept.CheckReadError(t, &tcpip.ErrTimeout{})
+ // The connection should be terminated after 5 unacked keepalives.
+ // Send an ACK to trigger a RST from the stack as the endpoint should
+ // be dead.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss,
+ AckNum: seqnum.Value(next),
+ RcvWnd: 30000,
+ })

- if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
- }
- if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
- }
+ v := c.GetPacket()
+ defer v.Release()
+ checker.IPv4(t, v,
+ checker.TCP(checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst)),
+ )
+
+ if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 {
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got)
+ }
+
+ ept.CheckReadError(t, &tcpip.ErrTimeout{})
+
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got 
stats.TCP.CurrentConnected.Value() = %d, want = 0", got) + } + }) } func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { @@ -6119,12 +6435,12 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki options := []byte{header.TCPOptionWS, 3, 0, header.TCPOptionNOP} irs = seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ - SrcPort: srcPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - TCPOpts: options, + SrcPort: srcPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + TCPOpts: options, }) // Receive the SYN-ACK reply. @@ -6142,13 +6458,13 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki if synCookieInUse { // When cookies are in use window scaling is disabled. tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{ - WS: -1, - MSS: c.MSSWithoutOptions(), + WS: -1, + MSS: c.MSSWithoutOptions(), })) } else { tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{ - WS: tcp.FindWndScale(tcp.DefaultReceiveBufferSize), - MSS: c.MSSWithoutOptions(), + WS: tcp.FindWndScale(tcp.DefaultReceiveBufferSize), + MSS: c.MSSWithoutOptions(), })) } @@ -6156,12 +6472,12 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki // Send ACK. 
c.SendPacket(nil, &context.Headers{ - SrcPort: srcPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, + SrcPort: srcPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + RcvWnd: 30000, }) return irs, iss } @@ -6173,12 +6489,12 @@ func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCoo options := []byte{header.TCPOptionWS, 3, 0, header.TCPOptionNOP} irs = seqnum.Value(context.TestInitialSequenceNumber) c.SendV6Packet(nil, &context.Headers{ - SrcPort: srcPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - TCPOpts: options, + SrcPort: srcPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + TCPOpts: options, }) // Receive the SYN-ACK reply. @@ -6196,13 +6512,13 @@ func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCoo if synCookieInUse { // When cookies are in use window scaling is disabled. tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{ - WS: -1, - MSS: c.MSSWithoutOptionsV6(), + WS: -1, + MSS: c.MSSWithoutOptionsV6(), })) } else { tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{ - WS: tcp.FindWndScale(tcp.DefaultReceiveBufferSize), - MSS: c.MSSWithoutOptionsV6(), + WS: tcp.FindWndScale(tcp.DefaultReceiveBufferSize), + MSS: c.MSSWithoutOptionsV6(), })) } @@ -6210,12 +6526,12 @@ func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCoo // Send ACK. 
c.SendV6Packet(nil, &context.Headers{ - SrcPort: srcPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, + SrcPort: srcPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + RcvWnd: 30000, }) return irs, iss } @@ -6223,58 +6539,89 @@ func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCoo // TestListenBacklogFull tests that netstack does not complete handshakes if the // listen backlog for the endpoint is full. func TestListenBacklogFull(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + // Create TCP endpoint. + var err tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } + // Test acceptance. + // Start listening. + listenBacklog := 10 + if err := c.EP.Listen(listenBacklog); err != nil { + t.Fatalf("Listen failed: %s", err) + } - // Test acceptance. - // Start listening. 
- listenBacklog := 10 - if err := c.EP.Listen(listenBacklog); err != nil { - t.Fatalf("Listen failed: %s", err) - } + lastPortOffset := uint16(0) + for ; int(lastPortOffset) < listenBacklog; lastPortOffset++ { + executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) + } - lastPortOffset := uint16(0) - for ; int(lastPortOffset) < listenBacklog; lastPortOffset++ { - executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) - } + time.Sleep(50 * time.Millisecond) - time.Sleep(50 * time.Millisecond) + // Now execute send one more SYN. The stack should not respond as the backlog + // is full at this point. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort + lastPortOffset, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: seqnum.Value(context.TestInitialSequenceNumber), + RcvWnd: 30000, + }) + c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) - // Now execute send one more SYN. The stack should not respond as the backlog - // is full at this point. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + lastPortOffset, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(context.TestInitialSequenceNumber), - RcvWnd: 30000, - }) - c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) + // Try to accept the connections in the backlog. + we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) + c.WQ.EventRegister(&we) + defer c.WQ.EventUnregister(&we) - // Try to accept the connections in the backlog. - we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) - c.WQ.EventRegister(&we) - defer c.WQ.EventUnregister(&we) + for i := 0; i < listenBacklog; i++ { + _, _, err = c.EP.Accept(nil) + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { + // Wait for connection to be established. 
+ select { + case <-ch: + _, _, err = c.EP.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + } - for i := 0; i < listenBacklog; i++ { + // Now verify that there are no more connections that can be accepted. _, _, err = c.EP.Accept(nil) + if !cmp.Equal(&tcpip.ErrWouldBlock{}, err) { + select { + case <-ch: + t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) + case <-time.After(1 * time.Second): + } + } + + // Now a new handshake must succeed. + executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) + + newEP, _, err := c.EP.Accept(nil) if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: - _, _, err = c.EP.Accept(nil) + newEP, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6283,1109 +6630,1066 @@ func TestListenBacklogFull(t *testing.T) { t.Fatalf("Timed out waiting for accept") } } - } - - // Now verify that there are no more connections that can be accepted. - _, _, err = c.EP.Accept(nil) - if !cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - select { - case <-ch: - t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) - case <-time.After(1 * time.Second): - } - } - - // Now a new handshake must succeed. - executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) - newEP, _, err := c.EP.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - newEP, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") + // Now verify that the TCP socket is usable and in a connected state. 
+ data := "Don't panic" + var r strings.Reader + r.Reset(data) + newEP.Write(&r, tcpip.WriteOptions{}) + b := c.GetPacket() + defer b.Release() + tcp := header.TCP(header.IPv4(b.AsSlice()).Payload()) + if string(tcp.Payload()) != data { + t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data) } - } - - // Now verify that the TCP socket is usable and in a connected state. - data := "Don't panic" - var r strings.Reader - r.Reset(data) - newEP.Write(&r, tcpip.WriteOptions{}) - b := c.GetPacket() - defer b.Release() - tcp := header.TCP(header.IPv4(b.AsSlice()).Payload()) - if string(tcp.Payload()) != data { - t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data) - } + }) } // TestListenNoAcceptMulticastBroadcastV4 makes sure that TCP segments with a // non unicast IPv4 address are not accepted. func TestListenNoAcceptNonUnicastV4(t *testing.T) { - multicastAddr := tcpiptestutil.MustParse4("224.0.1.2") - otherMulticastAddr := tcpiptestutil.MustParse4("224.0.1.3") - subnet := context.StackAddrWithPrefix.Subnet() - subnetBroadcastAddr := subnet.Broadcast() - - tests := []struct { - name string - srcAddr tcpip.Address - dstAddr tcpip.Address - }{ - { - name: "SourceUnspecified", - srcAddr: header.IPv4Any, - dstAddr: context.StackAddr, - }, - { - name: "SourceBroadcast", - srcAddr: header.IPv4Broadcast, - dstAddr: context.StackAddr, - }, - { - name: "SourceOurMulticast", - srcAddr: multicastAddr, - dstAddr: context.StackAddr, - }, - { - name: "SourceOtherMulticast", - srcAddr: otherMulticastAddr, - dstAddr: context.StackAddr, - }, - { - name: "DestUnspecified", - srcAddr: context.TestAddr, - dstAddr: header.IPv4Any, - }, - { - name: "DestBroadcast", - srcAddr: context.TestAddr, - dstAddr: header.IPv4Broadcast, - }, - { - name: "DestOurMulticast", - srcAddr: context.TestAddr, - dstAddr: multicastAddr, - }, - { - name: "DestOtherMulticast", - srcAddr: context.TestAddr, - dstAddr: otherMulticastAddr, - }, - { - name: "SrcSubnetBroadcast", 
- srcAddr: subnetBroadcastAddr, - dstAddr: context.StackAddr, - }, - { - name: "DestSubnetBroadcast", - srcAddr: context.TestAddr, - dstAddr: subnetBroadcastAddr, - }, - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + multicastAddr := tcpiptestutil.MustParse4("224.0.1.2") + otherMulticastAddr := tcpiptestutil.MustParse4("224.0.1.3") + subnet := context.StackAddrWithPrefix.Subnet() + subnetBroadcastAddr := subnet.Broadcast() + + tests := []struct { + name string + srcAddr tcpip.Address + dstAddr tcpip.Address + }{ + { + name: "SourceUnspecified", + srcAddr: header.IPv4Any, + dstAddr: context.StackAddr, + }, + { + name: "SourceBroadcast", + srcAddr: header.IPv4Broadcast, + dstAddr: context.StackAddr, + }, + { + name: "SourceOurMulticast", + srcAddr: multicastAddr, + dstAddr: context.StackAddr, + }, + { + name: "SourceOtherMulticast", + srcAddr: otherMulticastAddr, + dstAddr: context.StackAddr, + }, + { + name: "DestUnspecified", + srcAddr: context.TestAddr, + dstAddr: header.IPv4Any, + }, + { + name: "DestBroadcast", + srcAddr: context.TestAddr, + dstAddr: header.IPv4Broadcast, + }, + { + name: "DestOurMulticast", + srcAddr: context.TestAddr, + dstAddr: multicastAddr, + }, + { + name: "DestOtherMulticast", + srcAddr: context.TestAddr, + dstAddr: otherMulticastAddr, + }, + { + name: "SrcSubnetBroadcast", + srcAddr: subnetBroadcastAddr, + dstAddr: context.StackAddr, + }, + { + name: "DestSubnetBroadcast", + srcAddr: context.TestAddr, + dstAddr: subnetBroadcastAddr, + }, + } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.Create(-1) + c.Create(-1) - if err := c.Stack().JoinGroup(header.IPv4ProtocolNumber, 1, multicastAddr); err != nil { - t.Fatalf("JoinGroup failed: %s", err) - } + if err := c.Stack().JoinGroup(header.IPv4ProtocolNumber, 1, 
multicastAddr); err != nil { + t.Fatalf("JoinGroup failed: %s", err) + } - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %s", err) - } + if err := c.EP.Listen(1); err != nil { + t.Fatalf("Listen failed: %s", err) + } - irs := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacketWithAddrs(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }, test.srcAddr, test.dstAddr) - c.CheckNoPacket("Should not have received a response") - - // Handle normal packet. - c.SendPacketWithAddrs(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }, context.TestAddr, context.StackAddr) - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs)+1))) - }) - } + irs := seqnum.Value(context.TestInitialSequenceNumber) + c.SendPacketWithAddrs(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }, test.srcAddr, test.dstAddr) + c.CheckNoPacket("Should not have received a response") + + // Handle normal packet. 
+ c.SendPacketWithAddrs(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }, context.TestAddr, context.StackAddr) + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, + checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), + checker.TCPAckNum(uint32(irs)+1))) + }) + } + }) } // TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a // non unicast IPv6 address are not accepted. func TestListenNoAcceptNonUnicastV6(t *testing.T) { - multicastAddr := tcpiptestutil.MustParse6("ff0e::101") - otherMulticastAddr := tcpiptestutil.MustParse6("ff0e::102") - - tests := []struct { - name string - srcAddr tcpip.Address - dstAddr tcpip.Address - }{ - { - "SourceUnspecified", - header.IPv6Any, - context.StackV6Addr, - }, - { - "SourceAllNodes", - header.IPv6AllNodesMulticastAddress, - context.StackV6Addr, - }, - { - "SourceOurMulticast", - multicastAddr, - context.StackV6Addr, - }, - { - "SourceOtherMulticast", - otherMulticastAddr, - context.StackV6Addr, - }, - { - "DestUnspecified", - context.TestV6Addr, - header.IPv6Any, - }, - { - "DestAllNodes", - context.TestV6Addr, - header.IPv6AllNodesMulticastAddress, - }, - { - "DestOurMulticast", - context.TestV6Addr, - multicastAddr, - }, - { - "DestOtherMulticast", - context.TestV6Addr, - otherMulticastAddr, - }, - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + multicastAddr := tcpiptestutil.MustParse6("ff0e::101") + otherMulticastAddr := tcpiptestutil.MustParse6("ff0e::102") + + tests := []struct { + name string + srcAddr tcpip.Address + dstAddr tcpip.Address + }{ + { + "SourceUnspecified", + header.IPv6Any, + context.StackV6Addr, + }, + { + "SourceAllNodes", + header.IPv6AllNodesMulticastAddress, + context.StackV6Addr, + }, + { + "SourceOurMulticast", + multicastAddr, + context.StackV6Addr, + }, + { + 
"SourceOtherMulticast", + otherMulticastAddr, + context.StackV6Addr, + }, + { + "DestUnspecified", + context.TestV6Addr, + header.IPv6Any, + }, + { + "DestAllNodes", + context.TestV6Addr, + header.IPv6AllNodesMulticastAddress, + }, + { + "DestOurMulticast", + context.TestV6Addr, + multicastAddr, + }, + { + "DestOtherMulticast", + context.TestV6Addr, + otherMulticastAddr, + }, + } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateV6Endpoint(true) + c.CreateV6Endpoint(true) - if err := c.Stack().JoinGroup(header.IPv6ProtocolNumber, 1, multicastAddr); err != nil { - t.Fatalf("JoinGroup failed: %s", err) - } + if err := c.Stack().JoinGroup(header.IPv6ProtocolNumber, 1, multicastAddr); err != nil { + t.Fatalf("JoinGroup failed: %s", err) + } - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %s", err) - } + if err := c.EP.Listen(1); err != nil { + t.Fatalf("Listen failed: %s", err) + } - irs := seqnum.Value(context.TestInitialSequenceNumber) - c.SendV6PacketWithAddrs(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }, test.srcAddr, test.dstAddr) - c.CheckNoPacket("Should not have received a response") - - // Handle normal packet. 
- c.SendV6PacketWithAddrs(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }, context.TestV6Addr, context.StackV6Addr) - v := c.GetV6Packet() - defer v.Release() - checker.IPv6(t, v, - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs)+1))) - }) - } + irs := seqnum.Value(context.TestInitialSequenceNumber) + c.SendV6PacketWithAddrs(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }, test.srcAddr, test.dstAddr) + c.CheckNoPacket("Should not have received a response") + + // Handle normal packet. + c.SendV6PacketWithAddrs(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }, context.TestV6Addr, context.StackV6Addr) + v := c.GetV6Packet() + defer v.Release() + checker.IPv6(t, v, + checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), + checker.TCPAckNum(uint32(irs)+1))) + }) + } + }) } func TestListenSynRcvdQueueFull(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + // Create TCP endpoint. + var err tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } + // Bind to wildcard. 
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } + // Test acceptance. + if err := c.EP.Listen(1); err != nil { + t.Fatalf("Listen failed: %s", err) + } - // Test acceptance. - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %s", err) - } + // Send two SYN's the first one should get a SYN-ACK, the + // second one should not get any response and is dropped as + // the accept queue is full. + irs := seqnum.Value(context.TestInitialSequenceNumber) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) - // Send two SYN's the first one should get a SYN-ACK, the - // second one should not get any response and is dropped as - // the accept queue is full. - irs := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) + // Receive the SYN-ACK reply. + b := c.GetPacket() + defer b.Release() + tcp := header.TCP(header.IPv4(b.AsSlice()).Payload()) + iss := seqnum.Value(tcp.SequenceNumber()) + tcpCheckers := []checker.TransportChecker{ + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), + checker.TCPAckNum(uint32(irs) + 1), + } + checker.IPv4(t, b, checker.TCP(tcpCheckers...)) - // Receive the SYN-ACK reply. 
- b := c.GetPacket() - defer b.Release() - tcp := header.TCP(header.IPv4(b.AsSlice()).Payload()) - iss := seqnum.Value(tcp.SequenceNumber()) - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs) + 1), - } - checker.IPv4(t, b, checker.TCP(tcpCheckers...)) + // Now complete the previous connection. + // Send ACK. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + RcvWnd: 30000, + }) - // Now complete the previous connection. - // Send ACK. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - - // Verify if that is delivered to the accept queue. - we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) - c.WQ.EventRegister(&we) - defer c.WQ.EventUnregister(&we) - <-ch - - // Now execute send one more SYN. The stack should not respond as the backlog - // is full at this point. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + 1, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(889), - RcvWnd: 30000, - }) - c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) + // Verify if that is delivered to the accept queue. + we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) + c.WQ.EventRegister(&we) + defer c.WQ.EventUnregister(&we) + <-ch - // Try to accept the connections in the backlog. - newEP, _, err := c.EP.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - newEP, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } + // Now execute send one more SYN. 
The stack should not respond as the backlog + // is full at this point. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort + 1, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: seqnum.Value(889), + RcvWnd: 30000, + }) + c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") + // Try to accept the connections in the backlog. + newEP, _, err := c.EP.Accept(nil) + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { + // Wait for connection to be established. + select { + case <-ch: + newEP, _, err = c.EP.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } } - } - // Now verify that the TCP socket is usable and in a connected state. - data := "Don't panic" - var r strings.Reader - r.Reset(data) - newEP.Write(&r, tcpip.WriteOptions{}) - pkt := c.GetPacket() - defer pkt.Release() - tcp = header.IPv4(pkt.AsSlice()).Payload() - if string(tcp.Payload()) != data { - t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data) - } + // Now verify that the TCP socket is usable and in a connected state. + data := "Don't panic" + var r strings.Reader + r.Reset(data) + newEP.Write(&r, tcpip.WriteOptions{}) + pkt := c.GetPacket() + defer pkt.Release() + tcp = header.IPv4(pkt.AsSlice()).Payload() + if string(tcp.Payload()) != data { + t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data) + } + }) } func TestListenBacklogFullSynCookieInUse(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + // Create TCP endpoint. 
+ var err tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } + // Test for SynCookies usage after filling up the backlog. + if err := c.EP.Listen(1); err != nil { + t.Fatalf("Listen failed: %s", err) + } - // Test for SynCookies usage after filling up the backlog. - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %s", err) - } + executeHandshake(t, c, context.TestPort, true) - executeHandshake(t, c, context.TestPort, true) + // Wait for this to be delivered to the accept queue. + time.Sleep(50 * time.Millisecond) - // Wait for this to be delivered to the accept queue. - time.Sleep(50 * time.Millisecond) + // Send a SYN request. + irs := seqnum.Value(context.TestInitialSequenceNumber) + c.SendPacket(nil, &context.Headers{ + // pick a different src port for new SYN. + SrcPort: context.TestPort + 1, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) + // The Syn should be dropped as the endpoint's backlog is full. + c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) - // Send a SYN request. - irs := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - // pick a different src port for new SYN. 
- SrcPort: context.TestPort + 1, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - // The Syn should be dropped as the endpoint's backlog is full. - c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) - - // Verify that there is only one acceptable connection at this point. - we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) - c.WQ.EventRegister(&we) - defer c.WQ.EventUnregister(&we) - - _, _, err = c.EP.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } + // Verify that there is only one acceptable connection at this point. + we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) + c.WQ.EventRegister(&we) + defer c.WQ.EventUnregister(&we) - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") + _, _, err = c.EP.Accept(nil) + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { + // Wait for connection to be established. + select { + case <-ch: + _, _, err = c.EP.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } } - } - // Now verify that there are no more connections that can be accepted. - _, _, err = c.EP.Accept(nil) - if !cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - select { - case <-ch: - t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) - case <-time.After(1 * time.Second): + // Now verify that there are no more connections that can be accepted. + _, _, err = c.EP.Accept(nil) + if !cmp.Equal(&tcpip.ErrWouldBlock{}, err) { + select { + case <-ch: + t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) + case <-time.After(1 * time.Second): + } } - } + }) } func TestSYNRetransmit(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. 
- var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + // Create TCP endpoint. + var err tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } - // Start listening. - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } - // Send the same SYN packet multiple times. We should still get a valid SYN-ACK - // reply. - irs := seqnum.Value(context.TestInitialSequenceNumber) - for i := 0; i < 5; i++ { - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - } + // Start listening. + if err := c.EP.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } - // Receive the SYN-ACK reply. - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs) + 1), - } - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP(tcpCheckers...)) -} + // Send the same SYN packet multiple times. We should still get a valid SYN-ACK + // reply. 
+ irs := seqnum.Value(context.TestInitialSequenceNumber) + for i := 0; i < 5; i++ { + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) + } -func TestSynRcvdBadSeqNumber(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + // Receive the SYN-ACK reply. + tcpCheckers := []checker.TransportChecker{ + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), + checker.TCPAckNum(uint32(irs) + 1), + } + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP(tcpCheckers...)) + }) +} - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } +func TestSynRcvdBadSeqNumber(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + // Create TCP endpoint. + var err tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } - // Start listening. - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } + // Start listening. + if err := c.EP.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } - // Send a SYN to get a SYN-ACK. 
This should put the ep into SYN-RCVD state - irs := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) + // Send a SYN to get a SYN-ACK. This should put the ep into SYN-RCVD state + irs := seqnum.Value(context.TestInitialSequenceNumber) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) - // Receive the SYN-ACK reply. - b := c.GetPacket() - defer b.Release() - tcpHdr := header.TCP(header.IPv4(b.AsSlice()).Payload()) - iss := seqnum.Value(tcpHdr.SequenceNumber()) - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs) + 1), - } - checker.IPv4(t, b, checker.TCP(tcpCheckers...)) + // Receive the SYN-ACK reply. 
+ b := c.GetPacket() + defer b.Release() + tcpHdr := header.TCP(header.IPv4(b.AsSlice()).Payload()) + iss := seqnum.Value(tcpHdr.SequenceNumber()) + tcpCheckers := []checker.TransportChecker{ + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), + checker.TCPAckNum(uint32(irs) + 1), + } + checker.IPv4(t, b, checker.TCP(tcpCheckers...)) - // Now send a packet with an out-of-window sequence number - largeSeqnum := irs + seqnum.Value(tcpHdr.WindowSize()) + 1 - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: largeSeqnum, - AckNum: iss + 1, - RcvWnd: 30000, - }) + // Now send a packet with an out-of-window sequence number + largeSeqnum := irs + seqnum.Value(tcpHdr.WindowSize()) + 1 + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: largeSeqnum, + AckNum: iss + 1, + RcvWnd: 30000, + }) - // Should receive an ACK with the expected SEQ number - b = c.GetPacket() - defer b.Release() - tcpCheckers = []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPAckNum(uint32(irs) + 1), - checker.TCPSeqNum(uint32(iss + 1)), - } - checker.IPv4(t, b, checker.TCP(tcpCheckers...)) + // Should receive an ACK with the expected SEQ number + b = c.GetPacket() + defer b.Release() + tcpCheckers = []checker.TransportChecker{ + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.TCPAckNum(uint32(irs) + 1), + checker.TCPSeqNum(uint32(iss + 1)), + } + checker.IPv4(t, b, checker.TCP(tcpCheckers...)) - // Now that the socket replied appropriately with the ACK, - // complete the connection to test that the large SEQ num - // did not change the state from SYN-RCVD. 
+ // Now that the socket replied appropriately with the ACK, + // complete the connection to test that the large SEQ num + // did not change the state from SYN-RCVD. - // Get setup to be notified about connection establishment. - we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) - c.WQ.EventRegister(&we) - defer c.WQ.EventUnregister(&we) + // Get setup to be notified about connection establishment. + we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) + c.WQ.EventRegister(&we) + defer c.WQ.EventUnregister(&we) - // Send ACK to move to ESTABLISHED state. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) + // Send ACK to move to ESTABLISHED state. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + RcvWnd: 30000, + }) - <-ch - newEP, _, err := c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } + <-ch + newEP, _, err := c.EP.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) + } - // Now verify that the TCP socket is usable and in a connected state. - data := "Don't panic" - var r strings.Reader - r.Reset(data) - if _, err := newEP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } + // Now verify that the TCP socket is usable and in a connected state. 
+ data := "Don't panic" + var r strings.Reader + r.Reset(data) + if _, err := newEP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } - pkt := c.GetPacket() - defer pkt.Release() - tcpHdr = header.IPv4(pkt.AsSlice()).Payload() - if string(tcpHdr.Payload()) != data { - t.Fatalf("unexpected data: got %s, want %s", string(tcpHdr.Payload()), data) - } + pkt := c.GetPacket() + defer pkt.Release() + tcpHdr = header.IPv4(pkt.AsSlice()).Payload() + if string(tcpHdr.Payload()) != data { + t.Fatalf("unexpected data: got %s, want %s", string(tcpHdr.Payload()), data) + } + }) } func TestPassiveConnectionAttemptIncrement(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - c.EP = ep - if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %s", err) - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateListen; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + c.EP = ep + if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) + } + if err := 
c.EP.Listen(1); err != nil { + t.Fatalf("Listen failed: %s", err) + } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateListen; got != want { + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) + } - stats := c.Stack().Stats() - want := stats.TCP.PassiveConnectionOpenings.Value() + 1 + stats := c.Stack().Stats() + want := stats.TCP.PassiveConnectionOpenings.Value() + 1 - srcPort := uint16(context.TestPort) - executeHandshake(t, c, srcPort+1, true /* synCookiesInUse */) + srcPort := uint16(context.TestPort) + executeHandshake(t, c, srcPort+1, true /* synCookiesInUse */) - we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) - c.WQ.EventRegister(&we) - defer c.WQ.EventUnregister(&we) + we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) + c.WQ.EventRegister(&we) + defer c.WQ.EventUnregister(&we) - // Verify that there is only one acceptable connection at this point. - _, _, err = c.EP.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } + // Verify that there is only one acceptable connection at this point. + _, _, err = c.EP.Accept(nil) + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { + // Wait for connection to be established. 
+ select { + case <-ch: + _, _, err = c.EP.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) + } - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } } - } - if got := stats.TCP.PassiveConnectionOpenings.Value(); got != want { - t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %d, want = %d", got, want) - } + if got := stats.TCP.PassiveConnectionOpenings.Value(); got != want { + t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %d, want = %d", got, want) + } + }) } func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - stats := c.Stack().Stats() - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - c.EP = ep - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %s", err) - } + stats := c.Stack().Stats() + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + c.EP = ep + if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + if err := c.EP.Listen(1); err != nil { + t.Fatalf("Listen failed: %s", err) + } - srcPort := uint16(context.TestPort) - // Now attempt a handshakes it will fill up the accept backlog. - executeHandshake(t, c, srcPort, true /* synCookesInUse */) + srcPort := uint16(context.TestPort) + // Now attempt a handshakes it will fill up the accept backlog. 
+ executeHandshake(t, c, srcPort, true /* synCookesInUse */) - // Give time for the final ACK to be processed as otherwise the next handshake could - // get accepted before the previous one based on goroutine scheduling. - time.Sleep(50 * time.Millisecond) + // Give time for the final ACK to be processed as otherwise the next handshake could + // get accepted before the previous one based on goroutine scheduling. + time.Sleep(50 * time.Millisecond) - want := stats.TCP.ListenOverflowSynDrop.Value() + 1 + want := stats.TCP.ListenOverflowSynDrop.Value() + 1 - // Now we will send one more SYN and this one should get dropped - // Send a SYN request. - c.SendPacket(nil, &context.Headers{ - SrcPort: srcPort + 2, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(context.TestInitialSequenceNumber), - RcvWnd: 30000, - }) + // Now we will send one more SYN and this one should get dropped + // Send a SYN request. + c.SendPacket(nil, &context.Headers{ + SrcPort: srcPort + 2, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: seqnum.Value(context.TestInitialSequenceNumber), + RcvWnd: 30000, + }) - checkValid := func() []error { - var errors []error - if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want { - errors = append(errors, fmt.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want)) - } - if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want { - errors = append(errors, fmt.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want)) + checkValid := func() []error { + var errors []error + if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want { + errors = append(errors, fmt.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want)) + } + if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want { + errors = append(errors, fmt.Errorf("got EP stats 
Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want)) + } + return errors } - return errors - } - start := time.Now() - for time.Since(start) < time.Minute && len(checkValid()) > 0 { - time.Sleep(50 * time.Millisecond) - } - for _, err := range checkValid() { - t.Error(err) - } - if t.Failed() { - t.FailNow() - } + start := time.Now() + for time.Since(start) < time.Minute && len(checkValid()) > 0 { + time.Sleep(50 * time.Millisecond) + } + for _, err := range checkValid() { + t.Error(err) + } + if t.Failed() { + t.FailNow() + } - we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) - c.WQ.EventRegister(&we) - defer c.WQ.EventUnregister(&we) + we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) + c.WQ.EventRegister(&we) + defer c.WQ.EventUnregister(&we) - // Now check that there is one acceptable connections. - _, _, err = c.EP.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - <-ch + // Now check that there is one acceptable connections. _, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { + // Wait for connection to be established. 
+ <-ch + _, _, err = c.EP.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) + } } - } + }) } func TestListenDropIncrement(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - stats := c.Stack().Stats() - c.Create(-1 /*epRcvBuf*/) + stats := c.Stack().Stats() + c.Create(-1 /*epRcvBuf*/) - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - if err := c.EP.Listen(1 /*backlog*/); err != nil { - t.Fatalf("Listen failed: %s", err) - } + if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + if err := c.EP.Listen(1 /*backlog*/); err != nil { + t.Fatalf("Listen failed: %s", err) + } - initialDropped := stats.DroppedPackets.Value() + initialDropped := stats.DroppedPackets.Value() - // Send RST, FIN segments, that are expected to be dropped by the listener. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagRst, - }) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagFin, - }) + // Send RST, FIN segments, that are expected to be dropped by the listener. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagRst, + }) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagFin, + }) - // To ensure that the RST, FIN sent earlier are indeed received and ignored - // by the listener, send a SYN and wait for the SYN to be ACKd. 
- irs := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, + // To ensure that the RST, FIN sent earlier are indeed received and ignored + // by the listener, send a SYN and wait for the SYN to be ACKd. + irs := seqnum.Value(context.TestInitialSequenceNumber) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + }) + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP(checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), + checker.TCPAckNum(uint32(irs)+1), + )) + + if got, want := stats.DroppedPackets.Value(), initialDropped+2; got != want { + t.Fatalf("got stats.DroppedPackets.Value() = %d, want = %d", got, want) + } }) - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP(checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs)+1), - )) - - if got, want := stats.DroppedPackets.Value(), initialDropped+2; got != want { - t.Fatalf("got stats.DroppedPackets.Value() = %d, want = %d", got, want) - } } func TestEndpointBindListenAcceptState(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err 
!= nil { - t.Fatalf("Bind failed: %s", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) + } - ept := endpointTester{ep} - ept.CheckReadError(t, &tcpip.ErrNotConnected{}) - if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 { - t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %d want %d", got, 1) - } + ept := endpointTester{ep} + ept.CheckReadError(t, &tcpip.ErrNotConnected{}) + if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 { + t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %d want %d", got, 1) + } - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) + } - c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: e2e.DefaultIPv4MSS}, 0 /* delay */) + c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: e2e.DefaultIPv4MSS}, 0 /* delay */) - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) - wq.EventRegister(&we) - defer wq.EventUnregister(&we) + // Try to accept the connection. 
+ we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) + wq.EventRegister(&we) + defer wq.EventUnregister(&we) - aep, _, err := ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - aep, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } + aep, _, err := ep.Accept(nil) + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { + // Wait for connection to be established. + select { + case <-ch: + aep, _, err = ep.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) + } - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } } - } - if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - { - err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrAlreadyConnected{}, err); d != "" { - t.Errorf("Connect(...) mismatch (-want +got):\n%s", d) + if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) + } + { + err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if d := cmp.Diff(&tcpip.ErrAlreadyConnected{}, err); d != "" { + t.Errorf("Connect(...) mismatch (-want +got):\n%s", d) + } + } + // Listening endpoint remains in listen state. + if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } - } - // Listening endpoint remains in listen state. - if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - ep.Close() - // Give worker goroutines time to receive the close notification. 
- time.Sleep(1 * time.Second) - if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - // Accepted endpoint remains open when the listen endpoint is closed. - if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } + ep.Close() + // Give worker goroutines time to receive the close notification. + time.Sleep(1 * time.Second) + if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) + } + // Accepted endpoint remains open when the listen endpoint is closed. + if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) + } + }) } // This test verifies that the auto tuning does not grow the receive buffer if // the application is not reading the data actively. func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { - const mtu = 1500 - const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize - - c := context.New(t, mtu) - defer c.Cleanup() - - stk := c.Stack() - // Set lower limits for auto-tuning tests. This is required because the - // test stops the worker which can cause packets to be dropped because - // the segment queue holding unprocessed packets is limited to 500. - const receiveBufferSize = 80 << 10 // 80KB. - const maxReceiveBufferSize = receiveBufferSize * 10 - { - opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize} - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) - } - } - - // Enable auto-tuning. 
- { - opt := tcpip.TCPModerateReceiveBufferOption(true) - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + const mtu = 1500 + const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize + + c := context.New(t, mtu) + defer c.Cleanup() + + stk := c.Stack() + // Set lower limits for auto-tuning tests. This is required because the + // test stops the worker which can cause packets to be dropped because + // the segment queue holding unprocessed packets is limited to 500. + const receiveBufferSize = 80 << 10 // 80KB. + const maxReceiveBufferSize = receiveBufferSize * 10 + { + opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize} + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } - } - // Change the expected window scale to match the value needed for the - // maximum buffer size defined above. - c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize)) - - rawEP := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true, WS: 4}) - - // NOTE: The timestamp values in the sent packets are meaningless to the - // peer so we just increment the timestamp value by 1 every batch as we - // are not really using them for anything. Send a single byte to verify - // the advertised window. - tsVal := rawEP.TSVal + 1 - - // Introduce a 25ms latency by delaying the first byte. - latency := 25 * time.Millisecond - time.Sleep(latency) - // Send an initial payload with atleast segment overhead size. The receive - // window would not grow for smaller segments. 
- rawEP.SendPacketWithTS(make([]byte, tcp.SegOverheadSize), tsVal) - - pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) - defer pkt.Release() - rcvWnd := header.TCP(header.IPv4(pkt.AsSlice()).Payload()).WindowSize() - - time.Sleep(25 * time.Millisecond) - - // Allocate a large enough payload for the test. - payloadSize := receiveBufferSize * 2 - b := make([]byte, payloadSize) - - worker := (c.EP).(interface { - StopWork() - ResumeWork() - }) - tsVal++ - - // Stop the worker goroutine. - worker.StopWork() - start := 0 - end := payloadSize / 2 - packetsSent := 0 - for ; start < end; start += mss { - packetEnd := start + mss - if start+mss > end { - packetEnd = end - } - rawEP.SendPacketWithTS(b[start:packetEnd], tsVal) - packetsSent++ - } - - // Resume the worker so that it only sees the packets once all of them - // are waiting to be read. - worker.ResumeWork() - - // Since we sent almost the full receive buffer worth of data (some may have - // been dropped due to segment overheads), we should get a zero window back. - pkt = c.GetPacket() - defer pkt.Release() - tcpHdr := header.TCP(header.IPv4(pkt.AsSlice()).Payload()) - gotRcvWnd := tcpHdr.WindowSize() - wantAckNum := tcpHdr.AckNumber() - if got, want := int(gotRcvWnd), 0; got != want { - t.Fatalf("got rcvWnd: %d, want: %d", got, want) - } - - time.Sleep(25 * time.Millisecond) - // Verify that sending more data when receiveBuffer is exhausted. - rawEP.SendPacketWithTS(b[start:start+mss], tsVal) - // Now read all the data from the endpoint and verify that advertised - // window increases to the full available buffer size. - for { - _, err := c.EP.Read(io.Discard, tcpip.ReadOptions{}) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - break + // Enable auto-tuning. 
+ { + opt := tcpip.TCPModerateReceiveBufferOption(true) + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } } - } + // Change the expected window scale to match the value needed for the + // maximum buffer size defined above. + c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize)) - // Verify that we receive a non-zero window update ACK. When running - // under thread sanitizer this test can end up sending more than 1 - // ack, 1 for the non-zero window - p := c.GetPacket() - defer p.Release() - checker.IPv4(t, p, checker.TCP( - checker.TCPAckNum(wantAckNum), - func(t *testing.T, h header.Transport) { - tcp, ok := h.(header.TCP) - if !ok { - return - } - // We use 10% here as the error margin upwards as the initial window we - // got was after 1 segment was already in the receive buffer queue. - tolerance := 1.1 - if w := tcp.WindowSize(); w == 0 || w > uint16(float64(rcvWnd)*tolerance) { - t.Errorf("expected a non-zero window: got %d, want <= %d", w, uint16(float64(rcvWnd)*tolerance)) - } - }, - )) -} + rawEP := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true, WS: 4}) -// This test verifies that the advertised window is auto-tuned up as the -// application is reading the data that is being received. -func TestReceiveBufferAutoTuning(t *testing.T) { - const mtu = 1500 - const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize + // NOTE: The timestamp values in the sent packets are meaningless to the + // peer so we just increment the timestamp value by 1 every batch as we + // are not really using them for anything. Send a single byte to verify + // the advertised window. + tsVal := rawEP.TSVal + 1 - c := context.New(t, mtu) - defer c.Cleanup() + // Introduce a 25ms latency by delaying the first byte. 
+ latency := 25 * time.Millisecond + time.Sleep(latency) + // Send an initial payload with atleast segment overhead size. The receive + // window would not grow for smaller segments. + rawEP.SendPacketWithTS(make([]byte, tcp.SegOverheadSize), tsVal) - // Enable Auto-tuning. - stk := c.Stack() - // Disable out of window rate limiting for this test by setting it to 0 as we - // use out of window ACKs to measure the advertised window. - var tcpInvalidRateLimit stack.TCPInvalidRateLimitOption - if err := stk.SetOption(tcpInvalidRateLimit); err != nil { - t.Fatalf("e.stack.SetOption(%#v) = %s", tcpInvalidRateLimit, err) - } + pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) + defer pkt.Release() + rcvWnd := header.TCP(header.IPv4(pkt.AsSlice()).Payload()).WindowSize() - const receiveBufferSize = 80 << 10 // 80KB. - const maxReceiveBufferSize = receiveBufferSize * 10 - { - opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize} - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) - } - } + time.Sleep(25 * time.Millisecond) - // Enable auto-tuning. - { - opt := tcpip.TCPModerateReceiveBufferOption(true) - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) - } - } - // Change the expected window scale to match the value needed for the - // maximum buffer size used by stack. 
- c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize)) - - rawEP := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true, WS: 4}) - tsVal := rawEP.TSVal - rawEP.NextSeqNum-- - rawEP.SendPacketWithTS(nil, tsVal) - rawEP.NextSeqNum++ - pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) - defer pkt.Release() - curRcvWnd := int(header.TCP(header.IPv4(pkt.AsSlice()).Payload()).WindowSize()) << c.WindowScale - scaleRcvWnd := func(rcvWnd int) uint16 { - return uint16(rcvWnd >> c.WindowScale) - } - // Allocate a large array to send to the endpoint. - b := make([]byte, receiveBufferSize*48) - - // In every iteration we will send double the number of bytes sent in - // the previous iteration and read the same from the app. The received - // window should grow by at least 2x of bytes read by the app in every - // RTT. - offset := 0 - payloadSize := receiveBufferSize / 8 - worker := (c.EP).(interface { - StopWork() - ResumeWork() - }) - latency := 1 * time.Millisecond - for i := 0; i < 5; i++ { + // Allocate a large enough payload for the test. + payloadSize := receiveBufferSize * 2 + b := make([]byte, payloadSize) + + worker := (c.EP).(interface { + StopWork() + ResumeWork() + }) tsVal++ // Stop the worker goroutine. worker.StopWork() - start := offset - end := offset + payloadSize - totalSent := 0 + start := 0 + end := payloadSize / 2 packetsSent := 0 for ; start < end; start += mss { - rawEP.SendPacketWithTS(b[start:start+mss], tsVal) - totalSent += mss + packetEnd := start + mss + if start+mss > end { + packetEnd = end + } + rawEP.SendPacketWithTS(b[start:packetEnd], tsVal) packetsSent++ } - // Resume it so that it only sees the packets once all of them + // Resume the worker so that it only sees the packets once all of them // are waiting to be read. worker.ResumeWork() - // Give 1ms for the worker to process the packets. 
- time.Sleep(1 * time.Millisecond) + // Since we sent almost the full receive buffer worth of data (some may have + // been dropped due to segment overheads), we should get a zero window back. + pkt = c.GetPacket() + defer pkt.Release() + tcpHdr := header.TCP(header.IPv4(pkt.AsSlice()).Payload()) + gotRcvWnd := tcpHdr.WindowSize() + wantAckNum := tcpHdr.AckNumber() + if got, want := int(gotRcvWnd), 0; got != want { + t.Fatalf("got rcvWnd: %d, want: %d", got, want) + } + + time.Sleep(25 * time.Millisecond) + // Verify that sending more data when receiveBuffer is exhausted. + rawEP.SendPacketWithTS(b[start:start+mss], tsVal) - lastACK := c.GetPacket() - defer lastACK.Release() - // Discard any intermediate ACKs and only check the last ACK we get in a - // short time period of few ms. + // Now read all the data from the endpoint and verify that advertised + // window increases to the full available buffer size. for { - time.Sleep(1 * time.Millisecond) - pkt := c.GetPacketNonBlocking() - if pkt == nil { + _, err := c.EP.Read(io.Discard, tcpip.ReadOptions{}) + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { break } - defer pkt.Release() - lastACK = pkt } - if got, want := int(header.TCP(header.IPv4(lastACK.AsSlice()).Payload()).WindowSize()), int(scaleRcvWnd(curRcvWnd)); got > want { - t.Fatalf("advertised window got: %d, want <= %d", got, want) + + // Verify that we receive a non-zero window update ACK. When running + // under thread sanitizer this test can end up sending more than 1 + // ack, 1 for the non-zero window + p := c.GetPacket() + defer p.Release() + checker.IPv4(t, p, checker.TCP( + checker.TCPAckNum(wantAckNum), + func(t *testing.T, h header.Transport) { + tcp, ok := h.(header.TCP) + if !ok { + return + } + // We use 10% here as the error margin upwards as the initial window we + // got was after 1 segment was already in the receive buffer queue. 
+ tolerance := 1.1 + if w := tcp.WindowSize(); w == 0 || w > uint16(float64(rcvWnd)*tolerance) { + t.Errorf("expected a non-zero window: got %d, want <= %d", w, uint16(float64(rcvWnd)*tolerance)) + } + }, + )) + }) +} + +// This test verifies that the advertised window is auto-tuned up as the +// application is reading the data that is being received. +func TestReceiveBufferAutoTuning(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + const mtu = 1500 + const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize + + c := context.New(t, mtu) + defer c.Cleanup() + + // Enable Auto-tuning. + stk := c.Stack() + // Disable out of window rate limiting for this test by setting it to 0 as we + // use out of window ACKs to measure the advertised window. + var tcpInvalidRateLimit stack.TCPInvalidRateLimitOption + if err := stk.SetOption(tcpInvalidRateLimit); err != nil { + t.Fatalf("e.stack.SetOption(%#v) = %s", tcpInvalidRateLimit, err) } - // Now read all the data from the endpoint and invoke the - // moderation API to allow for receive buffer auto-tuning - // to happen before we measure the new window. - totalCopied := 0 - for { - res, err := c.EP.Read(io.Discard, tcpip.ReadOptions{}) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - break + const receiveBufferSize = 80 << 10 // 80KB. + const maxReceiveBufferSize = receiveBufferSize * 10 + { + opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize} + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) } - totalCopied += res.Count } - // Invoke the moderation API. This is required for auto-tuning - // to happen. This method is normally expected to be invoked - // from a higher layer than tcpip.Endpoint. So we simulate - // copying to userspace by invoking it explicitly here. 
- c.EP.ModerateRecvBuf(totalCopied) + // Enable auto-tuning. + { + opt := tcpip.TCPModerateReceiveBufferOption(true) + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } + } + // Change the expected window scale to match the value needed for the + // maximum buffer size used by stack. + c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize)) - // Now send a keep-alive packet to trigger an ACK so that we can - // measure the new window. + rawEP := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true, WS: 4}) + tsVal := rawEP.TSVal rawEP.NextSeqNum-- rawEP.SendPacketWithTS(nil, tsVal) rawEP.NextSeqNum++ + pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) + defer pkt.Release() + curRcvWnd := int(header.TCP(header.IPv4(pkt.AsSlice()).Payload()).WindowSize()) << c.WindowScale + scaleRcvWnd := func(rcvWnd int) uint16 { + return uint16(rcvWnd >> c.WindowScale) + } + // Allocate a large array to send to the endpoint. + b := make([]byte, receiveBufferSize*48) + + // In every iteration we will send double the number of bytes sent in + // the previous iteration and read the same from the app. The received + // window should grow by at least 2x of bytes read by the app in every + // RTT. + offset := 0 + payloadSize := receiveBufferSize / 8 + worker := (c.EP).(interface { + StopWork() + ResumeWork() + }) + latency := 1 * time.Millisecond + for i := 0; i < 5; i++ { + tsVal++ + + // Stop the worker goroutine. + worker.StopWork() + start := offset + end := offset + payloadSize + totalSent := 0 + packetsSent := 0 + for ; start < end; start += mss { + rawEP.SendPacketWithTS(b[start:start+mss], tsVal) + totalSent += mss + packetsSent++ + } + + // Resume it so that it only sees the packets once all of them + // are waiting to be read. + worker.ResumeWork() + + // Give 1ms for the worker to process the packets. 
+ time.Sleep(1 * time.Millisecond) - if i == 0 { - // In the first iteration the receiver based RTT is not - // yet known as a result the moderation code should not - // increase the advertised window. - rawEP.VerifyACKRcvWnd(scaleRcvWnd(curRcvWnd)) - } else { - // Read loop above could generate an ACK if the window had dropped to - // zero and then read had opened it up. lastACK := c.GetPacket() defer lastACK.Release() // Discard any intermediate ACKs and only check the last ACK we get in a @@ -7399,49 +7703,101 @@ func TestReceiveBufferAutoTuning(t *testing.T) { defer pkt.Release() lastACK = pkt } - curRcvWnd = int(header.TCP(header.IPv4(lastACK.AsSlice()).Payload()).WindowSize()) << c.WindowScale - // If thew new current window is close maxReceiveBufferSize then terminate - // the loop. This can happen before all iterations are done due to timing - // differences when running the test. - if int(float64(curRcvWnd)*1.1) > maxReceiveBufferSize/2 { - break + if got, want := int(header.TCP(header.IPv4(lastACK.AsSlice()).Payload()).WindowSize()), int(scaleRcvWnd(curRcvWnd)); got > want { + t.Fatalf("advertised window got: %d, want <= %d", got, want) } - // Increase the latency after first two iterations to - // establish a low RTT value in the receiver since it - // only tracks the lowest value. This ensures that when - // ModerateRcvBuf is called the elapsed time is always > - // rtt. Without this the test is flaky due to delays due - // to scheduling/wakeup etc. 
- latency += 50 * time.Millisecond - } - time.Sleep(latency) - offset += payloadSize - payloadSize *= 2 - } - // Check that at the end of our iterations the receive window grew close to the maximum - // permissible size of maxReceiveBufferSize/2 - if got, want := int(float64(curRcvWnd)*1.1), maxReceiveBufferSize/2; got < want { - t.Fatalf("unexpected rcvWnd got: %d, want > %d", got, want) - } + // Now read all the data from the endpoint and invoke the + // moderation API to allow for receive buffer auto-tuning + // to happen before we measure the new window. + totalCopied := 0 + for { + res, err := c.EP.Read(io.Discard, tcpip.ReadOptions{}) + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { + break + } + totalCopied += res.Count + } + + // Invoke the moderation API. This is required for auto-tuning + // to happen. This method is normally expected to be invoked + // from a higher layer than tcpip.Endpoint. So we simulate + // copying to userspace by invoking it explicitly here. + c.EP.ModerateRecvBuf(totalCopied) + + // Now send a keep-alive packet to trigger an ACK so that we can + // measure the new window. + rawEP.NextSeqNum-- + rawEP.SendPacketWithTS(nil, tsVal) + rawEP.NextSeqNum++ + + if i == 0 { + // In the first iteration the receiver based RTT is not + // yet known as a result the moderation code should not + // increase the advertised window. + rawEP.VerifyACKRcvWnd(scaleRcvWnd(curRcvWnd)) + } else { + // Read loop above could generate an ACK if the window had dropped to + // zero and then read had opened it up. + lastACK := c.GetPacket() + defer lastACK.Release() + // Discard any intermediate ACKs and only check the last ACK we get in a + // short time period of few ms. 
+ for { + time.Sleep(1 * time.Millisecond) + pkt := c.GetPacketNonBlocking() + if pkt == nil { + break + } + defer pkt.Release() + lastACK = pkt + } + curRcvWnd = int(header.TCP(header.IPv4(lastACK.AsSlice()).Payload()).WindowSize()) << c.WindowScale + // If thew new current window is close maxReceiveBufferSize then terminate + // the loop. This can happen before all iterations are done due to timing + // differences when running the test. + if int(float64(curRcvWnd)*1.1) > maxReceiveBufferSize/2 { + break + } + // Increase the latency after first two iterations to + // establish a low RTT value in the receiver since it + // only tracks the lowest value. This ensures that when + // ModerateRcvBuf is called the elapsed time is always > + // rtt. Without this the test is flaky due to delays due + // to scheduling/wakeup etc. + latency += 50 * time.Millisecond + } + time.Sleep(latency) + offset += payloadSize + payloadSize *= 2 + } + // Check that at the end of our iterations the receive window grew close to the maximum + // permissible size of maxReceiveBufferSize/2 + if got, want := int(float64(curRcvWnd)*1.1), maxReceiveBufferSize/2; got < want { + t.Fatalf("unexpected rcvWnd got: %d, want > %d", got, want) + } + }) } func TestDelayEnabled(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - checkDelayOption(t, c, false, false) // Delay is disabled by default. 
- - for _, delayEnabled := range []bool{false, true} { - t.Run(fmt.Sprintf("delayEnabled=%t", delayEnabled), func(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - opt := tcpip.TCPDelayEnabled(delayEnabled) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, delayEnabled, err) - } - checkDelayOption(t, c, opt, delayEnabled) - }) - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + checkDelayOption(t, c, false, false) // Delay is disabled by default. + + for _, delayEnabled := range []bool{false, true} { + t.Run(fmt.Sprintf("delayEnabled=%t", delayEnabled), func(t *testing.T) { + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + opt := tcpip.TCPDelayEnabled(delayEnabled) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, delayEnabled, err) + } + checkDelayOption(t, c, opt, delayEnabled) + }) + } + }) } func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcpip.TCPDelayEnabled, wantDelayOption bool) { @@ -7466,1065 +7822,1115 @@ func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcpip.T } func TestTCPLingerTimeout(t *testing.T) { - c := context.New(t, 1500 /* mtu */) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - testCases := []struct { - name string - tcpLingerTimeout time.Duration - want time.Duration - }{ - {"NegativeLingerTimeout", -123123, -1}, - // Zero is treated same as the stack's default TCP_LINGER2 timeout. - {"ZeroLingerTimeout", 0, tcp.DefaultTCPLingerTimeout}, - {"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second}, - // Values > stack's TCPLingerTimeout are capped to the stack's - // value. 
Defaults to tcp.DefaultTCPLingerTimeout(60 seconds) - {"AboveMaxLingerTimeout", tcp.MaxTCPLingerTimeout + 5*time.Second, tcp.MaxTCPLingerTimeout}, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - v := tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout) - if err := c.EP.SetSockOpt(&v); err != nil { - t.Fatalf("SetSockOpt(&%T(%s)) = %s", v, tc.tcpLingerTimeout, err) - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, 1500 /* mtu */) + defer c.Cleanup() + + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + + testCases := []struct { + name string + tcpLingerTimeout time.Duration + want time.Duration + }{ + {"NegativeLingerTimeout", -123123, -1}, + // Zero is treated same as the stack's default TCP_LINGER2 timeout. + {"ZeroLingerTimeout", 0, tcp.DefaultTCPLingerTimeout}, + {"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second}, + // Values > stack's TCPLingerTimeout are capped to the stack's + // value. 
Defaults to tcp.DefaultTCPLingerTimeout(60 seconds) + {"AboveMaxLingerTimeout", tcp.MaxTCPLingerTimeout + 5*time.Second, tcp.MaxTCPLingerTimeout}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + v := tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout) + if err := c.EP.SetSockOpt(&v); err != nil { + t.Fatalf("SetSockOpt(&%T(%s)) = %s", v, tc.tcpLingerTimeout, err) + } - v = 0 - if err := c.EP.GetSockOpt(&v); err != nil { - t.Fatalf("GetSockOpt(&%T) = %s", v, err) - } - if got, want := time.Duration(v), tc.want; got != want { - t.Fatalf("got linger timeout = %s, want = %s", got, want) - } - }) - } + v = 0 + if err := c.EP.GetSockOpt(&v); err != nil { + t.Fatalf("GetSockOpt(&%T) = %s", v, err) + } + if got, want := time.Duration(v), tc.want; got != want { + t.Fatalf("got linger timeout = %s, want = %s", got, want) + } + }) + } + }) } func TestTCPTimeWaitRSTIgnored(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } - // Send a SYN request. 
- iss := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) + // Send a SYN request. + iss := seqnum.Value(context.TestInitialSequenceNumber) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) - // Receive the SYN-ACK reply. - b := c.GetPacket() - defer b.Release() - tcpHdr := header.TCP(header.IPv4(b.AsSlice()).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } + // Receive the SYN-ACK reply. + b := c.GetPacket() + defer b.Release() + tcpHdr := header.TCP(header.IPv4(b.AsSlice()).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } - // Send ACK. - c.SendPacket(nil, ackHeaders) + // Send ACK. + c.SendPacket(nil, ackHeaders) - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) - wq.EventRegister(&we) - defer wq.EventUnregister(&we) + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) + wq.EventRegister(&we) + defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } + c.EP, _, err = ep.Accept(nil) + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { + // Wait for connection to be established. 
+ select { + case <-ch: + c.EP, _, err = ep.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) + } - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } } - } - c.EP.Close() + c.EP.Close() - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - } + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } - c.SendPacket(nil, finHeaders) + c.SendPacket(nil, finHeaders) - // Get the ACK to the FIN we just sent. + // Get the ACK to the FIN we just sent. 
- v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+2)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) + v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) - // Now send a RST and this should be ignored and not - // generate an ACK. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagRst, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - }) + // Now send a RST and this should be ignored and not + // generate an ACK. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagRst, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + }) - c.CheckNoPacketTimeout("unexpected packet received in TIME_WAIT state", 1*time.Second) + c.CheckNoPacketTimeout("unexpected packet received in TIME_WAIT state", 1*time.Second) - // Out of order ACK should generate an immediate ACK in - // TIME_WAIT. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 3, - }) + // Out of order ACK should generate an immediate ACK in + // TIME_WAIT. 
+ c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 3, + }) - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+2)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) + v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + }) } func TestTCPTimeWaitOutOfOrder(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } - // Send a SYN request. - iss := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) + // Send a SYN request. 
+ iss := seqnum.Value(context.TestInitialSequenceNumber) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) - // Receive the SYN-ACK reply. - b := c.GetPacket() - defer b.Release() - tcpHdr := header.TCP(header.IPv4(b.AsSlice()).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } + // Receive the SYN-ACK reply. + b := c.GetPacket() + defer b.Release() + tcpHdr := header.TCP(header.IPv4(b.AsSlice()).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } - // Send ACK. - c.SendPacket(nil, ackHeaders) + // Send ACK. + c.SendPacket(nil, ackHeaders) - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) - wq.EventRegister(&we) - defer wq.EventUnregister(&we) + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) + wq.EventRegister(&we) + defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } + c.EP, _, err = ep.Accept(nil) + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { + // Wait for connection to be established. 
+ select { + case <-ch: + c.EP, _, err = ep.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) + } - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } } - } - c.EP.Close() + c.EP.Close() - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - } + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } - c.SendPacket(nil, finHeaders) + c.SendPacket(nil, finHeaders) - // Get the ACK to the FIN we just sent. - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+2)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) + // Get the ACK to the FIN we just sent. 
+ v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) - // Out of order ACK should generate an immediate ACK in - // TIME_WAIT. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 3, + // Out of order ACK should generate an immediate ACK in + // TIME_WAIT. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 3, + }) + + v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) }) - - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+2)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) } func TestTCPTimeWaitNewSyn(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := 
ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } - // Send a SYN request. - iss := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) + // Send a SYN request. + iss := seqnum.Value(context.TestInitialSequenceNumber) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) - // Receive the SYN-ACK reply. - b := c.GetPacket() - defer b.Release() - tcpHdr := header.TCP(header.IPv4(b.AsSlice()).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } + // Receive the SYN-ACK reply. + b := c.GetPacket() + defer b.Release() + tcpHdr := header.TCP(header.IPv4(b.AsSlice()).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } - // Send ACK. - c.SendPacket(nil, ackHeaders) + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) + wq.EventRegister(&we) + defer wq.EventUnregister(&we) - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) - wq.EventRegister(&we) - defer wq.EventUnregister(&we) + c.EP, _, err = ep.Accept(nil) + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { + // Wait for connection to be established. 
+ select { + case <-ch: + c.EP, _, err = ep.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) + } - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") } - } - c.EP.Close() + c.EP.Close() - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - } + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } - c.SendPacket(nil, finHeaders) + c.SendPacket(nil, finHeaders) - // Get the ACK to the FIN we just sent. - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+2)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - // Send a SYN request w/ sequence number lower than - // the highest sequence number sent. We just reuse - // the same number. 
- iss = seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) + // Get the ACK to the FIN we just sent. + v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Send a SYN request w/ sequence number lower than + // the highest sequence number sent. We just reuse + // the same number. + iss = seqnum.Value(context.TestInitialSequenceNumber) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) - c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second) + c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second) - // drain any older notifications from the notification channel before attempting - // 2nd connection. - select { - case <-ch: - default: - } + // drain any older notifications from the notification channel before attempting + // 2nd connection. + select { + case <-ch: + default: + } - // Send a SYN request w/ sequence number higher than - // the highest sequence number sent. - iss = iss.Add(3) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) + // Send a SYN request w/ sequence number higher than + // the highest sequence number sent. + iss = iss.Add(3) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) - // Receive the SYN-ACK reply. 
- b = c.GetPacket() - defer b.Release() - tcpHdr = header.IPv4(b.AsSlice()).Payload() - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders = &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } + // Receive the SYN-ACK reply. + b = c.GetPacket() + defer b.Release() + tcpHdr = header.IPv4(b.AsSlice()).Payload() + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } - // Send ACK. - c.SendPacket(nil, ackHeaders) + // Send ACK. + c.SendPacket(nil, ackHeaders) - // Try to accept the connection. - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } + // Try to accept the connection. + c.EP, _, err = ep.Accept(nil) + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) + } - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } } - } + }) } func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed + // after 5 seconds in TIME_WAIT state. 
+ tcpTimeWaitTimeout := 5 * time.Second + opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err) + } - // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed - // after 5 seconds in TIME_WAIT state. - tcpTimeWaitTimeout := 5 * time.Second - opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err) - } + want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1 - want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1 + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } + // Send a SYN request. + iss := seqnum.Value(context.TestInitialSequenceNumber) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) - // Send a SYN request. 
- iss := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) + // Receive the SYN-ACK reply. + b := c.GetPacket() + defer b.Release() + tcpHdr := header.TCP(header.IPv4(b.AsSlice()).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } - // Receive the SYN-ACK reply. - b := c.GetPacket() - defer b.Release() - tcpHdr := header.TCP(header.IPv4(b.AsSlice()).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } + // Send ACK. + c.SendPacket(nil, ackHeaders) - // Send ACK. - c.SendPacket(nil, ackHeaders) + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) + wq.EventRegister(&we) + defer wq.EventUnregister(&we) - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) - wq.EventRegister(&we) - defer wq.EventUnregister(&we) + c.EP, _, err = ep.Accept(nil) + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) + } - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. 
- select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") } - } - c.EP.Close() - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - } + c.EP.Close() + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } - c.SendPacket(nil, finHeaders) + c.SendPacket(nil, finHeaders) - // Get the ACK to the FIN we just sent. - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+2)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) + // Get the ACK to the FIN we just sent. 
+ v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) - time.Sleep(2 * time.Second) + time.Sleep(2 * time.Second) - // Now send a duplicate FIN. This should cause the TIME_WAIT to extend - // by another 5 seconds and also send us a duplicate ACK as it should - // indicate that the final ACK was potentially lost. - c.SendPacket(nil, finHeaders) + // Now send a duplicate FIN. This should cause the TIME_WAIT to extend + // by another 5 seconds and also send us a duplicate ACK as it should + // indicate that the final ACK was potentially lost. + c.SendPacket(nil, finHeaders) - // Get the ACK to the FIN we just sent. - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+2)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - // Sleep for 4 seconds so at this point we are 1 second past the - // original tcpLingerTimeout of 5 seconds. - time.Sleep(4 * time.Second) - - // Send an ACK and it should not generate any packet as the socket - // should still be in TIME_WAIT for another another 5 seconds due - // to the duplicate FIN we sent earlier. - *ackHeaders = *finHeaders - ackHeaders.SeqNum = ackHeaders.SeqNum + 1 - ackHeaders.Flags = header.TCPFlagAck - c.SendPacket(nil, ackHeaders) - - c.CheckNoPacketTimeout("unexpected packet received from endpoint in TIME_WAIT", 1*time.Second) - // Now sleep for another 2 seconds so that we are past the - // extended TIME_WAIT of 7 seconds (2 + 5). - time.Sleep(2 * time.Second) - - // Resend the same ACK. - c.SendPacket(nil, ackHeaders) - - // Receive the RST that should be generated as there is no valid - // endpoint. 
- v = c.GetPacket()
- defer v.Release()
- checker.IPv4(t, v, checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(ackHeaders.AckNum)),
- checker.TCPAckNum(0),
- checker.TCPFlags(header.TCPFlagRst))) + // Get the ACK to the FIN we just sent. + v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Sleep for 4 seconds so at this point we are 1 second past the + // original tcpLingerTimeout of 5 seconds. + time.Sleep(4 * time.Second) + + // Send an ACK and it should not generate any packet as the socket + // should still be in TIME_WAIT for another 5 seconds due + // to the duplicate FIN we sent earlier. + *ackHeaders = *finHeaders + ackHeaders.SeqNum = ackHeaders.SeqNum + 1 + ackHeaders.Flags = header.TCPFlagAck + c.SendPacket(nil, ackHeaders) + + c.CheckNoPacketTimeout("unexpected packet received from endpoint in TIME_WAIT", 1*time.Second) + // Now sleep for another 2 seconds so that we are past the + // extended TIME_WAIT of 7 seconds (2 + 5). + time.Sleep(2 * time.Second) + + // Resend the same ACK. + c.SendPacket(nil, ackHeaders) + + // Receive the RST that should be generated as there is no valid + // endpoint. 
+ v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(ackHeaders.AckNum)), + checker.TCPAckNum(0), + checker.TCPFlags(header.TCPFlagRst))) - if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want { - t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = %d", got, want) - } - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) - } + if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want { + t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = %d", got, want) + } + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) + } + }) } func TestTCPCloseWithData(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed + // after 5 seconds in TIME_WAIT state. + tcpTimeWaitTimeout := 5 * time.Second + opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err) + } - // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed - // after 5 seconds in TIME_WAIT state. 
- tcpTimeWaitTimeout := 5 * time.Second - opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err) - } + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } + // Send a SYN request. + iss := seqnum.Value(context.TestInitialSequenceNumber) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) - // Send a SYN request. - iss := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) + // Receive the SYN-ACK reply. + b := c.GetPacket() + defer b.Release() + tcpHdr := header.TCP(header.IPv4(b.AsSlice()).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + RcvWnd: 30000, + } - // Receive the SYN-ACK reply. 
- b := c.GetPacket() - defer b.Release() - tcpHdr := header.TCP(header.IPv4(b.AsSlice()).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - RcvWnd: 30000, - } + // Send ACK. + c.SendPacket(nil, ackHeaders) - // Send ACK. - c.SendPacket(nil, ackHeaders) + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) + wq.EventRegister(&we) + defer wq.EventUnregister(&we) - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) - wq.EventRegister(&we) - defer wq.EventUnregister(&we) + c.EP, _, err = ep.Accept(nil) + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) + } - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") } + } - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") + // Now trigger a passive close by sending a FIN. + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + RcvWnd: 30000, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. 
+ v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Now write a few bytes and then close the endpoint. + data := []byte{1, 2, 3} + + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) } - } - - // Now trigger a passive close by sending a FIN. - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - RcvWnd: 30000, - } - - c.SendPacket(nil, finHeaders) - // Get the ACK to the FIN we just sent. - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) + // Check that data is received. + b = c.GetPacket() + defer b.Release() + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+2), // Acknum is initial sequence number + 1 + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), + ), + ) - // Now write a few bytes and then close the endpoint. - data := []byte{1, 2, 3} + if p := b.AsSlice()[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { + t.Errorf("got data = %x, want = %x", p, data) + } - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } + c.EP.Close() + // Check the FIN. 
+ v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS+1)+uint32(len(data))), + checker.TCPAckNum(uint32(iss+2)), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + // First send a partial ACK. + ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 2, + AckNum: c.IRS + 1 + seqnum.Value(len(data)-1), + RcvWnd: 30000, + } + c.SendPacket(nil, ackHeaders) + + // Now send a full ACK. + ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 2, + AckNum: c.IRS + 1 + seqnum.Value(len(data)), + RcvWnd: 30000, + } + c.SendPacket(nil, ackHeaders) + + // Now ACK the FIN. + ackHeaders.AckNum++ + c.SendPacket(nil, ackHeaders) + + // Now send an ACK and we should get a RST back as the endpoint should + // be in CLOSED state. + ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 2, + AckNum: c.IRS + 1 + seqnum.Value(len(data)), + RcvWnd: 30000, + } + c.SendPacket(nil, ackHeaders) - // Check that data is received. - b = c.GetPacket() - defer b.Release() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( + // Check the RST. 
+ v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+2), // Acknum is initial sequence number + 1 - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) + checker.TCPSeqNum(uint32(ackHeaders.AckNum)), + checker.TCPAckNum(0), + checker.TCPFlags(header.TCPFlagRst))) + }) +} - if p := b.AsSlice()[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { - t.Errorf("got data = %x, want = %x", p, data) - } +func TestTCPUserTimeout(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + initRTO := 1 * time.Second + minRTOOpt := tcpip.TCPMinRTOOption(initRTO) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err) + } + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - c.EP.Close() - // Check the FIN. - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)+uint32(len(data))), - checker.TCPAckNum(uint32(iss+2)), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - // First send a partial ACK. - ackHeaders = &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 2, - AckNum: c.IRS + 1 + seqnum.Value(len(data)-1), - RcvWnd: 30000, - } - c.SendPacket(nil, ackHeaders) - - // Now send a full ACK. 
- ackHeaders = &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 2, - AckNum: c.IRS + 1 + seqnum.Value(len(data)), - RcvWnd: 30000, - } - c.SendPacket(nil, ackHeaders) - - // Now ACK the FIN. - ackHeaders.AckNum++ - c.SendPacket(nil, ackHeaders) - - // Now send an ACK and we should get a RST back as the endpoint should - // be in CLOSED state. - ackHeaders = &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 2, - AckNum: c.IRS + 1 + seqnum.Value(len(data)), - RcvWnd: 30000, - } - c.SendPacket(nil, ackHeaders) + waitEntry, notifyCh := waiter.NewChannelEntry(waiter.EventHUp) + c.WQ.EventRegister(&waitEntry) + defer c.WQ.EventUnregister(&waitEntry) - // Check the RST. - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(ackHeaders.AckNum)), - checker.TCPAckNum(0), - checker.TCPFlags(header.TCPFlagRst))) -} + origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() -func TestTCPUserTimeout(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + // Ensure that on the next retransmit timer fire, the user timeout has + // expired. + userTimeout := initRTO / 2 + v := tcpip.TCPUserTimeoutOption(userTimeout) + if err := c.EP.SetSockOpt(&v); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s): %s", v, userTimeout, err) + } - initRTO := 1 * time.Second - minRTOOpt := tcpip.TCPMinRTOOption(initRTO) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err) - } - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + // Send some data and wait before ACKing it. 
+ view := make([]byte, 3) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + next := uint32(c.IRS) + 1 + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + b := c.GetPacket() + defer b.Release() + checker.IPv4(t, b, + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(next), + checker.TCPAckNum(uint32(iss)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), + ), + ) - waitEntry, notifyCh := waiter.NewChannelEntry(waiter.EventHUp) - c.WQ.EventRegister(&waitEntry) - defer c.WQ.EventUnregister(&waitEntry) + // Wait for the retransmit timer to be fired and the user timeout to cause + // close of the connection. + select { + case <-notifyCh: + case <-time.After(2 * initRTO): + t.Fatalf("connection still alive after %s, should have been closed after %s", 2*initRTO, userTimeout) + } - origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() + // No packet should be received as the connection should be silently + // closed due to timeout. + c.CheckNoPacket("unexpected packet received after userTimeout has expired") - // Ensure that on the next retransmit timer fire, the user timeout has - // expired. - userTimeout := initRTO / 2 - v := tcpip.TCPUserTimeoutOption(userTimeout) - if err := c.EP.SetSockOpt(&v); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s): %s", v, userTimeout, err) - } + next += uint32(len(view)) - // Send some data and wait before ACKing it. - view := make([]byte, 3) - var r bytes.Reader - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } + // The connection should be terminated after userTimeout has expired. + // Send an ACK to trigger a RST from the stack as the endpoint should + // be dead. 
+ c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) - next := uint32(c.IRS) + 1 - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - b := c.GetPacket() - defer b.Release() - checker.IPv4(t, b, - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( + b = c.GetPacket() + defer b.Release() + checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), + checker.TCPAckNum(uint32(0)), + checker.TCPFlags(header.TCPFlagRst), ), - ) - - // Wait for the retransmit timer to be fired and the user timeout to cause - // close of the connection. - select { - case <-notifyCh: - case <-time.After(2 * initRTO): - t.Fatalf("connection still alive after %s, should have been closed after %s", 2*initRTO, userTimeout) - } - - // No packet should be received as the connection should be silently - // closed due to timeout. - c.CheckNoPacket("unexpected packet received after userTimeout has expired") + ) - next += uint32(len(view)) + ept := endpointTester{c.EP} + ept.CheckReadError(t, &tcpip.ErrTimeout{}) - // The connection should be terminated after userTimeout has expired. - // Send an ACK to trigger a RST from the stack as the endpoint should - // be dead. 
- c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: seqnum.Value(next), - RcvWnd: 30000, + if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want) + } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) + } }) +} - b = c.GetPacket() - defer b.Release() - checker.IPv4(t, b, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(0)), - checker.TCPFlags(header.TCPFlagRst), - ), - ) +func TestKeepaliveWithUserTimeout(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrTimeout{}) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { - t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want) - } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) - } -} + origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() -func TestKeepaliveWithUserTimeout(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + const keepAliveIdle = 100 * time.Millisecond + const keepAliveInterval = 3 * time.Second + keepAliveIdleOption := tcpip.KeepaliveIdleOption(keepAliveIdle) + if err := c.EP.SetSockOpt(&keepAliveIdleOption); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIdleOption, keepAliveIdle, err) + } + keepAliveIntervalOption := 
tcpip.KeepaliveIntervalOption(keepAliveInterval) + if err := c.EP.SetSockOpt(&keepAliveIntervalOption); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIntervalOption, keepAliveInterval, err) + } + if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10); err != nil { + t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10): %s", err) + } + c.EP.SocketOptions().SetKeepAlive(true) + + // Set userTimeout to be the duration to be 1 keepalive + // probes. Which means that after the first probe is sent + // the second one should cause the connection to be + // closed due to userTimeout being hit. + userTimeout := tcpip.TCPUserTimeoutOption(keepAliveInterval) + if err := c.EP.SetSockOpt(&userTimeout); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", userTimeout, keepAliveInterval, err) + } - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + // Check that the connection is still alive. + ept := endpointTester{c.EP} + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() + // Now receive 1 keepalives, but don't ACK it. 
+ b := c.GetPacket() + defer b.Release() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)), + checker.TCPAckNum(uint32(iss)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) - const keepAliveIdle = 100 * time.Millisecond - const keepAliveInterval = 3 * time.Second - keepAliveIdleOption := tcpip.KeepaliveIdleOption(keepAliveIdle) - if err := c.EP.SetSockOpt(&keepAliveIdleOption); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIdleOption, keepAliveIdle, err) - } - keepAliveIntervalOption := tcpip.KeepaliveIntervalOption(keepAliveInterval) - if err := c.EP.SetSockOpt(&keepAliveIntervalOption); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIntervalOption, keepAliveInterval, err) - } - if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10); err != nil { - t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10): %s", err) - } - c.EP.SocketOptions().SetKeepAlive(true) - - // Set userTimeout to be the duration to be 1 keepalive - // probes. Which means that after the first probe is sent - // the second one should cause the connection to be - // closed due to userTimeout being hit. - userTimeout := tcpip.TCPUserTimeoutOption(keepAliveInterval) - if err := c.EP.SetSockOpt(&userTimeout); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", userTimeout, keepAliveInterval, err) - } + // Sleep for a little over the KeepAlive interval to make sure + // the timer has time to fire after the last ACK and close the + // close the socket. + time.Sleep(keepAliveInterval + keepAliveInterval/2) - // Check that the connection is still alive. - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) + // The connection should be closed with a timeout. + // Send an ACK to trigger a RST from the stack as the endpoint should + // be dead. 
+ c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS + 1, + RcvWnd: 30000, + }) - // Now receive 1 keepalives, but don't ACK it. - b := c.GetPacket() - defer b.Release() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, b, - checker.TCP( + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(0)), + checker.TCPFlags(header.TCPFlagRst), ), - ) - - // Sleep for a little over the KeepAlive interval to make sure - // the timer has time to fire after the last ACK and close the - // close the socket. - time.Sleep(keepAliveInterval + keepAliveInterval/2) + ) - // The connection should be closed with a timeout. - // Send an ACK to trigger a RST from the stack as the endpoint should - // be dead. 
- c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS + 1, - RcvWnd: 30000, + ept.CheckReadError(t, &tcpip.ErrTimeout{}) + if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want) + } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) + } }) - - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(uint32(0)), - checker.TCPFlags(header.TCPFlagRst), - ), - ) - - ept.CheckReadError(t, &tcpip.ErrTimeout{}) - if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { - t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want) - } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) - } } func TestIncreaseWindowOnRead(t *testing.T) { + synctest. // This test ensures that the endpoint sends an ack, // after read() when the window grows by more than 1 MSS. - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + Test(t, func(t *testing.T) { + defer synctest.Wait() - const rcvBuf = 65535 * 10 - c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBuf) + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - // Write chunks of ~30000 bytes. It's important that two - // payloads make it equal or longer than MSS. 
- remain := rcvBuf * 2 - sent := 0 - data := make([]byte, e2e.DefaultMTU/2) - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - for remain > len(data) { - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(seqnum.Size(sent)), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - sent += len(data) - remain -= len(data) - pkt := c.GetPacket() - defer pkt.Release() - checker.IPv4(t, pkt, + const rcvBuf = 65535 * 10 + c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBuf) + + // Write chunks of ~30000 bytes. It's important that two + // payloads make it equal or longer than MSS. + remain := rcvBuf * 2 + sent := 0 + data := make([]byte, e2e.DefaultMTU/2) + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + for remain > len(data) { + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss.Add(seqnum.Size(sent)), + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + sent += len(data) + remain -= len(data) + pkt := c.GetPacket() + defer pkt.Release() + checker.IPv4(t, pkt, + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+uint32(sent)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + // Break once the window drops below e2e.DefaultMTU/2 + if wnd := header.TCP(header.IPv4(pkt.AsSlice()).Payload()).WindowSize(); wnd < e2e.DefaultMTU/2 { + break + } + } + + // We now have < 1 MSS in the buffer space. Read at least > 2 MSS + // worth of data as receive buffer space + w := tcpip.LimitedWriter{ + W: io.Discard, + // e2e.DefaultMTU is a good enough estimate for the MSS used for this + // connection. 
+ N: e2e.DefaultMTU * 2, + } + for w.N != 0 { + _, err := c.EP.Read(&w, tcpip.ReadOptions{}) + if err != nil { + t.Fatalf("Read failed: %s", err) + } + } + + // After reading > MSS worth of data, we surely crossed MSS. See the ack: + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(uint32(iss)+uint32(sent)), + checker.TCPWindow(uint16(0xffff)), checker.TCPFlags(header.TCPFlagAck), ), ) - // Break once the window drops below e2e.DefaultMTU/2 - if wnd := header.TCP(header.IPv4(pkt.AsSlice()).Payload()).WindowSize(); wnd < e2e.DefaultMTU/2 { - break - } - } - - // We now have < 1 MSS in the buffer space. Read at least > 2 MSS - // worth of data as receive buffer space - w := tcpip.LimitedWriter{ - W: io.Discard, - // e2e.DefaultMTU is a good enough estimate for the MSS used for this - // connection. - N: e2e.DefaultMTU * 2, - } - for w.N != 0 { - _, err := c.EP.Read(&w, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("Read failed: %s", err) - } - } - - // After reading > MSS worth of data, we surely crossed MSS. See the ack: - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(sent)), - checker.TCPWindow(uint16(0xffff)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) + }) } func TestIncreaseWindowOnBufferResize(t *testing.T) { + synctest. // This test ensures that the endpoint sends an ack, // after available recv buffer grows to more than 1 MSS. 
- c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + Test(t, func(t *testing.T) { + defer synctest.Wait() - const rcvBuf = 65535 * 10 - c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBuf) + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - // Write chunks of ~30000 bytes. It's important that two - // payloads make it equal or longer than MSS. - remain := rcvBuf - sent := 0 - data := make([]byte, e2e.DefaultMTU/2) - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - for remain > len(data) { - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(seqnum.Size(sent)), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - sent += len(data) - remain -= len(data) + const rcvBuf = 65535 * 10 + c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBuf) + + // Write chunks of ~30000 bytes. It's important that two + // payloads make it equal or longer than MSS. + remain := rcvBuf + sent := 0 + data := make([]byte, e2e.DefaultMTU/2) + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + for remain > len(data) { + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss.Add(seqnum.Size(sent)), + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + sent += len(data) + remain -= len(data) + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+uint32(sent)), + checker.TCPWindowLessThanEq(0xffff), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + } + + // Increasing the buffer from should generate an ACK, + // since window grew from small value to larger equal MSS + c.EP.SocketOptions().SetReceiveBufferSize(rcvBuf*4, true /* notify */) v := c.GetPacket() defer v.Release() checker.IPv4(t, v, @@ -8533,409 +8939,414 @@ 
func TestIncreaseWindowOnBufferResize(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(uint32(iss)+uint32(sent)), - checker.TCPWindowLessThanEq(0xffff), + checker.TCPWindow(uint16(0xffff)), checker.TCPFlags(header.TCPFlagAck), ), ) - } - - // Increasing the buffer from should generate an ACK, - // since window grew from small value to larger equal MSS - c.EP.SocketOptions().SetReceiveBufferSize(rcvBuf*4, true /* notify */) - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(sent)), - checker.TCPWindow(uint16(0xffff)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) + }) } func TestTCPDeferAccept(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.Create(-1) + c.Create(-1) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } - const tcpDeferAccept = 1 * time.Second - tcpDeferAcceptOption := tcpip.TCPDeferAcceptOption(tcpDeferAccept) - if err := c.EP.SetSockOpt(&tcpDeferAcceptOption); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", tcpDeferAcceptOption, tcpDeferAccept, err) - } + const tcpDeferAccept = 1 * time.Second + tcpDeferAcceptOption := tcpip.TCPDeferAcceptOption(tcpDeferAccept) + if err := c.EP.SetSockOpt(&tcpDeferAcceptOption); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", tcpDeferAcceptOption, tcpDeferAccept, err) + } - irs, iss 
:= executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) + irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) + + _, _, err := c.EP.Accept(nil) + if d := cmp.Diff(&tcpip.ErrWouldBlock{}, err); d != "" { + t.Fatalf("c.EP.Accept(nil) mismatch (-want +got):\n%s", d) + } + + // Send data. This should result in an acceptable endpoint. + c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + }) + + // Receive ACK for the data we sent. + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) - _, _, err := c.EP.Accept(nil) - if d := cmp.Diff(&tcpip.ErrWouldBlock{}, err); d != "" { - t.Fatalf("c.EP.Accept(nil) mismatch (-want +got):\n%s", d) - } + // Give a bit of time for the socket to be delivered to the accept queue. + time.Sleep(50 * time.Millisecond) + aep, _, err := c.EP.Accept(nil) + if err != nil { + t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err) + } - // Send data. This should result in an acceptable endpoint. - c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, + aep.Close() + // Closing aep without reading the data should trigger a RST. + v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) }) - - // Receive ACK for the data we sent. 
- v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(iss+1)), - checker.TCPAckNum(uint32(irs+5)))) - - // Give a bit of time for the socket to be delivered to the accept queue. - time.Sleep(50 * time.Millisecond) - aep, _, err := c.EP.Accept(nil) - if err != nil { - t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err) - } - - aep.Close() - // Closing aep without reading the data should trigger a RST. - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.TCPSeqNum(uint32(iss+1)), - checker.TCPAckNum(uint32(irs+5)))) } func TestTCPDeferAcceptTimeout(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.Create(-1) + c.Create(-1) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } - const tcpDeferAccept = 1 * time.Second - tcpDeferAcceptOpt := tcpip.TCPDeferAcceptOption(tcpDeferAccept) - if err := c.EP.SetSockOpt(&tcpDeferAcceptOpt); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s)) failed: %s", tcpDeferAcceptOpt, tcpDeferAccept, err) - } + const tcpDeferAccept = 1 * time.Second + tcpDeferAcceptOpt := tcpip.TCPDeferAcceptOption(tcpDeferAccept) + if err := c.EP.SetSockOpt(&tcpDeferAcceptOpt); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s)) failed: %s", tcpDeferAcceptOpt, tcpDeferAccept, err) + } - irs, iss := executeHandshake(t, c, context.TestPort, 
false /* synCookiesInUse */) + irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - _, _, err := c.EP.Accept(nil) - if d := cmp.Diff(&tcpip.ErrWouldBlock{}, err); d != "" { - t.Fatalf("c.EP.Accept(nil) mismatch (-want +got):\n%s", d) - } + _, _, err := c.EP.Accept(nil) + if d := cmp.Diff(&tcpip.ErrWouldBlock{}, err); d != "" { + t.Fatalf("c.EP.Accept(nil) mismatch (-want +got):\n%s", d) + } - // Sleep for a little of the tcpDeferAccept timeout. - time.Sleep(tcpDeferAccept + 100*time.Millisecond) + // Sleep for a little of the tcpDeferAccept timeout. + time.Sleep(tcpDeferAccept + 100*time.Millisecond) - // On timeout expiry we should get a SYN-ACK retransmission. - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs)+1))) + // On timeout expiry we should get a SYN-ACK retransmission. + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), + checker.TCPAckNum(uint32(irs)+1))) + + // Send data. This should result in an acceptable endpoint. + c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + }) - // Send data. This should result in an acceptable endpoint. - c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - }) + // Receive ACK for the data we sent. 
+ v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) - // Receive ACK for the data we sent. - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(iss+1)), - checker.TCPAckNum(uint32(irs+5)))) - - // Give sometime for the endpoint to be delivered to the accept queue. - time.Sleep(50 * time.Millisecond) - aep, _, err := c.EP.Accept(nil) - if err != nil { - t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err) - } + // Give sometime for the endpoint to be delivered to the accept queue. + time.Sleep(50 * time.Millisecond) + aep, _, err := c.EP.Accept(nil) + if err != nil { + t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err) + } - aep.Close() - // Closing aep without reading the data should trigger a RST. - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.TCPSeqNum(uint32(iss+1)), - checker.TCPAckNum(uint32(irs+5)))) + aep.Close() + // Closing aep without reading the data should trigger a RST. 
+ v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) + }) } func TestResetDuringClose(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRecvBuf */) + // Send some data to make sure there is some unread + // data to trigger a reset on c.Close. + irs := c.IRS + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: irs.Add(1), + RcvWnd: 30000, + }) - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRecvBuf */) - // Send some data to make sure there is some unread - // data to trigger a reset on c.Close. - irs := c.IRS - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: irs.Add(1), - RcvWnd: 30000, - }) + // Receive ACK for the data we sent. + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.TCPSeqNum(uint32(irs.Add(1))), + checker.TCPAckNum(uint32(iss)+4))) - // Receive ACK for the data we sent. 
- v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(irs.Add(1))), - checker.TCPAckNum(uint32(iss)+4))) - - // Close in a separate goroutine so that we can trigger - // a race with the RST we send below. This should not - // panic due to the route being released depending on - // whether Close() sends an active RST or the RST sent - // below is processed by the worker first. - var wg sync.WaitGroup - - wg.Add(1) - go func() { - defer wg.Done() - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - SeqNum: iss.Add(4), - AckNum: c.IRS.Add(5), - RcvWnd: 30000, - Flags: header.TCPFlagRst, - }) - }() + // Close in a separate goroutine so that we can trigger + // a race with the RST we send below. This should not + // panic due to the route being released depending on + // whether Close() sends an active RST or the RST sent + // below is processed by the worker first. 
+ var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - c.EP.Close() - }() + wg.Add(1) + go func() { + defer wg.Done() + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + SeqNum: iss.Add(4), + AckNum: c.IRS.Add(5), + RcvWnd: 30000, + Flags: header.TCPFlagRst, + }) + }() + + wg.Add(1) + go func() { + defer wg.Done() + c.EP.Close() + }() - wg.Wait() + wg.Wait() + }) } func TestStackTimeWaitReuse(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - s := c.Stack() - var twReuse tcpip.TCPTimeWaitReuseOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &twReuse, err) - } - if got, want := twReuse, tcpip.TCPTimeWaitReuseLoopbackOnly; got != want { - t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want) - } + s := c.Stack() + var twReuse tcpip.TCPTimeWaitReuseOption + if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil { + t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &twReuse, err) + } + if got, want := twReuse, tcpip.TCPTimeWaitReuseLoopbackOnly; got != want { + t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want) + } + }) } func TestSetStackTimeWaitReuse(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - - s := c.Stack() - testCases := []struct { - v int - err tcpip.Error - }{ - {int(tcpip.TCPTimeWaitReuseDisabled), nil}, - {int(tcpip.TCPTimeWaitReuseGlobal), nil}, - {int(tcpip.TCPTimeWaitReuseLoopbackOnly), nil}, - {int(tcpip.TCPTimeWaitReuseLoopbackOnly) + 1, &tcpip.ErrInvalidOptionValue{}}, - {int(tcpip.TCPTimeWaitReuseDisabled) - 1, &tcpip.ErrInvalidOptionValue{}}, - } - - for _, tc := range testCases { - opt := tcpip.TCPTimeWaitReuseOption(tc.v) - err := 
s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt) - if got, want := err, tc.err; got != want { - t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%d)) = %s, want = %s", tcp.ProtocolNumber, tc.v, tc.v, err, tc.err) - } - if tc.err != nil { - continue + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + s := c.Stack() + testCases := []struct { + v int + err tcpip.Error + }{ + {int(tcpip.TCPTimeWaitReuseDisabled), nil}, + {int(tcpip.TCPTimeWaitReuseGlobal), nil}, + {int(tcpip.TCPTimeWaitReuseLoopbackOnly), nil}, + {int(tcpip.TCPTimeWaitReuseLoopbackOnly) + 1, &tcpip.ErrInvalidOptionValue{}}, + {int(tcpip.TCPTimeWaitReuseDisabled) - 1, &tcpip.ErrInvalidOptionValue{}}, } - var twReuse tcpip.TCPTimeWaitReuseOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v, want nil", tcp.ProtocolNumber, &twReuse, err) - } + for _, tc := range testCases { + opt := tcpip.TCPTimeWaitReuseOption(tc.v) + err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt) + if got, want := err, tc.err; got != want { + t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%d)) = %s, want = %s", tcp.ProtocolNumber, tc.v, tc.v, err, tc.err) + } + if tc.err != nil { + continue + } - if got, want := twReuse, tcpip.TCPTimeWaitReuseOption(tc.v); got != want { - t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want) + var twReuse tcpip.TCPTimeWaitReuseOption + if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil { + t.Fatalf("s.TransportProtocolOption(%v, %v) = %v, want nil", tcp.ProtocolNumber, &twReuse, err) + } + + if got, want := twReuse, tcpip.TCPTimeWaitReuseOption(tc.v); got != want { + t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want) + } } - } + }) } func TestHandshakeRTT(t *testing.T) { - type testCase struct { - connect bool - tsEnabled bool - useCookie bool - retrans bool - delay 
time.Duration - wantRTT time.Duration - } - var testCases []testCase - for _, connect := range []bool{false, true} { - for _, tsEnabled := range []bool{false, true} { - for _, useCookie := range []bool{false, true} { - for _, retrans := range []bool{false, true} { - if connect && useCookie { - continue - } - delay := 800 * time.Millisecond - if retrans { - delay = 1200 * time.Millisecond - } - wantRTT := delay - // If syncookie is enabled, sample RTT only when TS option is enabled. - if !retrans && useCookie && !tsEnabled { - wantRTT = 0 - } - // If retransmitted, sample RTT only when TS option is enabled. - if retrans && !tsEnabled { - wantRTT = 0 + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + type testCase struct { + connect bool + tsEnabled bool + useCookie bool + retrans bool + delay time.Duration + wantRTT time.Duration + } + var testCases []testCase + for _, connect := range []bool{false, true} { + for _, tsEnabled := range []bool{false, true} { + for _, useCookie := range []bool{false, true} { + for _, retrans := range []bool{false, true} { + if connect && useCookie { + continue + } + delay := 800 * time.Millisecond + if retrans { + delay = 1200 * time.Millisecond + } + wantRTT := delay + // If syncookie is enabled, sample RTT only when TS option is enabled. + if !retrans && useCookie && !tsEnabled { + wantRTT = 0 + } + // If retransmitted, sample RTT only when TS option is enabled. 
+ if retrans && !tsEnabled { + wantRTT = 0 + } + testCases = append(testCases, testCase{connect, tsEnabled, useCookie, retrans, delay, wantRTT}) } - testCases = append(testCases, testCase{connect, tsEnabled, useCookie, retrans, delay, wantRTT}) } } } - } - for _, tt := range testCases { - tt := tt - t.Run(fmt.Sprintf("connect=%t,TS=%t,cookie=%t,retrans=%t)", tt.connect, tt.tsEnabled, tt.useCookie, tt.retrans), func(t *testing.T) { - t.Parallel() - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - if tt.useCookie { - opt := tcpip.TCPAlwaysUseSynCookies(true) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + for _, tt := range testCases { + tt := tt + t.Run(fmt.Sprintf("connect=%t,TS=%t,cookie=%t,retrans=%t)", tt.connect, tt.tsEnabled, tt.useCookie, tt.retrans), func(t *testing.T) { + t.Parallel() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + if tt.useCookie { + opt := tcpip.TCPAlwaysUseSynCookies(true) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } } - } - synOpts := header.TCPSynOptions{} - if tt.tsEnabled { - synOpts.TS = true - synOpts.TSVal = 42 - } - if tt.connect { - c.CreateConnectedWithOptions(synOpts, tt.delay) - } else { - synOpts.MSS = e2e.DefaultIPv4MSS - synOpts.WS = -1 - c.AcceptWithOptions(-1, synOpts, tt.delay) - } - var info tcpip.TCPInfoOption - if err := c.EP.GetSockOpt(&info); err != nil { - t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err) - } - if got := info.RTT.Round(tt.wantRTT); got != tt.wantRTT { - t.Fatalf("got info.RTT=%s, expect %s", got, tt.wantRTT) - } - if info.RTTVar != 0 && tt.wantRTT == 0 { - t.Fatalf("got info.RTTVar=%s, expect 0", info.RTTVar) - } - if info.RTTVar == 0 && tt.wantRTT != 0 { - t.Fatalf("got info.RTTVar=0, expect non zero") - 
} - }) - } + synOpts := header.TCPSynOptions{} + if tt.tsEnabled { + synOpts.TS = true + synOpts.TSVal = 42 + } + if tt.connect { + c.CreateConnectedWithOptions(synOpts, tt.delay) + } else { + synOpts.MSS = e2e.DefaultIPv4MSS + synOpts.WS = -1 + c.AcceptWithOptions(-1, synOpts, tt.delay) + } + var info tcpip.TCPInfoOption + if err := c.EP.GetSockOpt(&info); err != nil { + t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err) + } + if got := info.RTT.Round(tt.wantRTT); got != tt.wantRTT { + t.Fatalf("got info.RTT=%s, expect %s", got, tt.wantRTT) + } + if info.RTTVar != 0 && tt.wantRTT == 0 { + t.Fatalf("got info.RTTVar=%s, expect 0", info.RTTVar) + } + if info.RTTVar == 0 && tt.wantRTT != 0 { + t.Fatalf("got info.RTTVar=0, expect non zero") + } + }) + } + }) } func TestSetRTO(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - minRTO, maxRTO := tcpRTOMinMax(t, c) - c.Cleanup() - for _, tt := range []struct { - name string - RTO time.Duration - minRTO time.Duration - maxRTO time.Duration - err tcpip.Error - }{ - { - name: "invalid minRTO", - minRTO: maxRTO + time.Second, - err: &tcpip.ErrInvalidOptionValue{}, - }, - { - name: "invalid maxRTO", - maxRTO: minRTO - time.Millisecond, - err: &tcpip.ErrInvalidOptionValue{}, - }, - { - name: "valid minRTO", - minRTO: maxRTO - time.Second, - }, - { - name: "valid maxRTO", - maxRTO: minRTO + time.Millisecond, - }, - } { - t.Run(tt.name, func(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - var opt tcpip.SettableTransportProtocolOption - if tt.minRTO > 0 { - min := tcpip.TCPMinRTOOption(tt.minRTO) - opt = &min - } - if tt.maxRTO > 0 { - max := tcpip.TCPMaxRTOOption(tt.maxRTO) - opt = &max - } - err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt) - if got, want := err, tt.err; got != want { - t.Fatalf("c.Stack().SetTransportProtocolOption(TCP, &%T(%v)) = %v, want = %v", opt, opt, got, want) - } - if tt.err == nil { - minRTO, maxRTO := tcpRTOMinMax(t, c) - if tt.minRTO > 0 && tt.minRTO 
!= minRTO { - t.Fatalf("got minRTO = %s, want %s", minRTO, tt.minRTO) + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + minRTO, maxRTO := tcpRTOMinMax(t, c) + c.Cleanup() + for _, tt := range []struct { + name string + RTO time.Duration + minRTO time.Duration + maxRTO time.Duration + err tcpip.Error + }{ + { + name: "invalid minRTO", + minRTO: maxRTO + time.Second, + err: &tcpip.ErrInvalidOptionValue{}, + }, + { + name: "invalid maxRTO", + maxRTO: minRTO - time.Millisecond, + err: &tcpip.ErrInvalidOptionValue{}, + }, + { + name: "valid minRTO", + minRTO: maxRTO - time.Second, + }, + { + name: "valid maxRTO", + maxRTO: minRTO + time.Millisecond, + }, + } { + t.Run(tt.name, func(t *testing.T) { + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + var opt tcpip.SettableTransportProtocolOption + if tt.minRTO > 0 { + min := tcpip.TCPMinRTOOption(tt.minRTO) + opt = &min } - if tt.maxRTO > 0 && tt.maxRTO != maxRTO { - t.Fatalf("got maxRTO = %s, want %s", maxRTO, tt.maxRTO) + if tt.maxRTO > 0 { + max := tcpip.TCPMaxRTOOption(tt.maxRTO) + opt = &max } - } - }) - } + err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt) + if got, want := err, tt.err; got != want { + t.Fatalf("c.Stack().SetTransportProtocolOption(TCP, &%T(%v)) = %v, want = %v", opt, opt, got, want) + } + if tt.err == nil { + minRTO, maxRTO := tcpRTOMinMax(t, c) + if tt.minRTO > 0 && tt.minRTO != minRTO { + t.Fatalf("got minRTO = %s, want %s", minRTO, tt.minRTO) + } + if tt.maxRTO > 0 && tt.maxRTO != maxRTO { + t.Fatalf("got maxRTO = %s, want %s", maxRTO, tt.maxRTO) + } + } + }) + } + }) } func tcpRTOMinMax(t *testing.T, c *context.Context) (time.Duration, time.Duration) { @@ -8963,338 +9374,353 @@ func generateRandomPayload(t *testing.T, n int) []byte { } func TestSendBufferTuning(t *testing.T) { - const maxPayload = 536 - const mtu = header.TCPMinimumSize + header.IPv4MinimumSize + e2e.MaxTCPOptionSize + maxPayload - const 
packetOverheadFactor = 2 - - testCases := []struct { - name string - autoTuningDisabled bool - }{ - {"autoTuningDisabled", true}, - {"autoTuningEnabled", false}, - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + const maxPayload = 536 + const mtu = header.TCPMinimumSize + header.IPv4MinimumSize + e2e.MaxTCPOptionSize + maxPayload + const packetOverheadFactor = 2 + + testCases := []struct { + name string + autoTuningDisabled bool + }{ + {"autoTuningDisabled", true}, + {"autoTuningEnabled", false}, + } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - c := context.New(t, mtu) - defer c.Cleanup() + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := context.New(t, mtu) + defer c.Cleanup() - // Set the stack option for send buffer size. - const defaultSndBufSz = maxPayload * tcp.InitialCwnd - const maxSndBufSz = defaultSndBufSz * 10 - { - opt := tcpip.TCPSendBufferSizeRangeOption{Min: 1, Default: defaultSndBufSz, Max: maxSndBufSz} - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + // Set the stack option for send buffer size. 
+ const defaultSndBufSz = maxPayload * tcp.InitialCwnd + const maxSndBufSz = defaultSndBufSz * 10 + { + opt := tcpip.TCPSendBufferSizeRangeOption{Min: 1, Default: defaultSndBufSz, Max: maxSndBufSz} + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } - } - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - oldSz := c.EP.SocketOptions().GetSendBufferSize() - if oldSz != defaultSndBufSz { - t.Fatalf("Wrong send buffer size got %d want %d", oldSz, defaultSndBufSz) - } + oldSz := c.EP.SocketOptions().GetSendBufferSize() + if oldSz != defaultSndBufSz { + t.Fatalf("Wrong send buffer size got %d want %d", oldSz, defaultSndBufSz) + } - if tc.autoTuningDisabled { - c.EP.SocketOptions().SetSendBufferSize(defaultSndBufSz, true /* notify */) - } + if tc.autoTuningDisabled { + c.EP.SocketOptions().SetSendBufferSize(defaultSndBufSz, true /* notify */) + } - data := make([]byte, maxPayload) - for i := range data { - data[i] = byte(i) - } + data := make([]byte, maxPayload) + for i := range data { + data[i] = byte(i) + } - w, ch := waiter.NewChannelEntry(waiter.WritableEvents) - c.WQ.EventRegister(&w) - defer c.WQ.EventUnregister(&w) + w, ch := waiter.NewChannelEntry(waiter.WritableEvents) + c.WQ.EventRegister(&w) + defer c.WQ.EventUnregister(&w) - bytesRead := 0 - for { - // Packets will be sent till the send buffer - // size is reached. - var r bytes.Reader - r.Reset(data[bytesRead : bytesRead+maxPayload]) - _, err := c.EP.Write(&r, tcpip.WriteOptions{}) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - break - } + bytesRead := 0 + for { + // Packets will be sent till the send buffer + // size is reached. 
+ var r bytes.Reader + r.Reset(data[bytesRead : bytesRead+maxPayload]) + _, err := c.EP.Write(&r, tcpip.WriteOptions{}) + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { + break + } - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, 0) - bytesRead += maxPayload - data = append(data, data...) - } + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, 0) + bytesRead += maxPayload + data = append(data, data...) + } - // Send an ACK and wait for connection to become writable again. - c.SendAck(seqnum.Value(context.TestInitialSequenceNumber).Add(1), bytesRead) - select { - case <-ch: - if err := c.EP.LastError(); err != nil { - t.Fatalf("Write failed: %s", err) + // Send an ACK and wait for connection to become writable again. + c.SendAck(seqnum.Value(context.TestInitialSequenceNumber).Add(1), bytesRead) + select { + case <-ch: + if err := c.EP.LastError(); err != nil { + t.Fatalf("Write failed: %s", err) + } + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for connection") } - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for connection") - } - outSz := int64(defaultSndBufSz) - if !tc.autoTuningDisabled { - // Calculate the new auto tuned send buffer. - var info tcpip.TCPInfoOption - if err := c.EP.GetSockOpt(&info); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) + outSz := int64(defaultSndBufSz) + if !tc.autoTuningDisabled { + // Calculate the new auto tuned send buffer. 
+ var info tcpip.TCPInfoOption + if err := c.EP.GetSockOpt(&info); err != nil { + t.Fatalf("GetSockOpt failed: %v", err) + } + outSz = int64(info.SndCwnd) * packetOverheadFactor * maxPayload } - outSz = int64(info.SndCwnd) * packetOverheadFactor * maxPayload - } - if newSz := c.EP.SocketOptions().GetSendBufferSize(); newSz != outSz { - t.Fatalf("Wrong send buffer size, got %d want %d", newSz, outSz) - } - }) - } + if newSz := c.EP.SocketOptions().GetSendBufferSize(); newSz != outSz { + t.Fatalf("Wrong send buffer size, got %d want %d", newSz, outSz) + } + }) + } + }) } func TestTimestampSynCookies(t *testing.T) { - clock := faketime.NewManualClock() - tsNow := func() uint32 { - return uint32(clock.NowMonotonic().Sub(tcpip.MonotonicTime{}).Milliseconds()) - } - // Advance the clock so that NowMonotonic is non-zero. - clock.Advance(time.Second) - c := context.NewWithOpts(t, context.Options{ - EnableV4: true, - EnableV6: true, - MTU: e2e.DefaultMTU, - Clock: clock, - }) - defer c.Cleanup() - opt := tcpip.TCPAlwaysUseSynCookies(true) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) - } - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer ep.Close() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + clock := faketime.NewManualClock() + tsNow := func() uint32 { + return uint32(clock.NowMonotonic().Sub(tcpip.MonotonicTime{}).Milliseconds()) + } + // Advance the clock so that NowMonotonic is non-zero. 
+ clock.Advance(time.Second) + c := context.NewWithOpts(t, context.Options{ + EnableV4: true, + EnableV6: true, + MTU: e2e.DefaultMTU, + Clock: clock, + }) + defer c.Cleanup() + opt := tcpip.TCPAlwaysUseSynCookies(true) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + defer ep.Close() - tcpOpts := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} - header.EncodeTSOption(42, 0, tcpOpts[2:]) - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - iss := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - RcvWnd: seqnum.Size(512), - SeqNum: iss, - TCPOpts: tcpOpts[:], - }) - // Get the TSVal of SYN-ACK. - b := c.GetPacket() - defer b.Release() - tcpHdr := header.TCP(header.IPv4(b.AsSlice()).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - initialTSVal := tcpHdr.ParsedOptions().TSVal - // derive the tsOffset. 
- tsOffset := initialTSVal - tsNow() + tcpOpts := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} + header.EncodeTSOption(42, 0, tcpOpts[2:]) + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + iss := seqnum.Value(context.TestInitialSequenceNumber) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + RcvWnd: seqnum.Size(512), + SeqNum: iss, + TCPOpts: tcpOpts[:], + }) + // Get the TSVal of SYN-ACK. + b := c.GetPacket() + defer b.Release() + tcpHdr := header.TCP(header.IPv4(b.AsSlice()).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + initialTSVal := tcpHdr.ParsedOptions().TSVal + // derive the tsOffset. + tsOffset := initialTSVal - tsNow() + + header.EncodeTSOption(420, initialTSVal, tcpOpts[2:]) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + RcvWnd: seqnum.Size(512), + SeqNum: iss + 1, + AckNum: c.IRS + 1, + TCPOpts: tcpOpts[:], + }) + c.EP, _, err = ep.Accept(nil) + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) + wq.EventRegister(&we) + defer wq.EventUnregister(&we) + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) + } - header.EncodeTSOption(420, initialTSVal, tcpOpts[2:]) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - RcvWnd: seqnum.Size(512), - SeqNum: iss + 1, - AckNum: c.IRS + 1, - TCPOpts: tcpOpts[:], - }) - c.EP, _, err = ep.Accept(nil) - // Try to accept the connection. 
- we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) - wq.EventRegister(&we) - defer wq.EventUnregister(&we) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") + } else if err != nil { + t.Fatalf("failed to accept: %s", err) } - } else if err != nil { - t.Fatalf("failed to accept: %s", err) - } - // Advance the clock again so that we expect the next TSVal to change. - clock.Advance(time.Second) - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } + // Advance the clock again so that we expect the next TSVal to change. + clock.Advance(time.Second) + data := []byte{1, 2, 3} + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } - // The endpoint should have a correct TSOffset so that the received TSVal - // should match our expectation. - p := c.GetPacket() - defer p.Release() - if got, want := header.TCP(header.IPv4(p.AsSlice()).Payload()).ParsedOptions().TSVal, tsNow()+tsOffset; got != want { - t.Fatalf("got TSVal = %d, want %d", got, want) - } + // The endpoint should have a correct TSOffset so that the received TSVal + // should match our expectation. + p := c.GetPacket() + defer p.Release() + if got, want := header.TCP(header.IPv4(p.AsSlice()).Payload()).ParsedOptions().TSVal, tsNow()+tsOffset; got != want { + t.Fatalf("got TSVal = %d, want %d", got, want) + } + }) } // TestECNFlagsAccept tests that an ECN non-setup/setup SYN is accepted // and the connection is correctly completed. 
func TestECNFlagsAccept(t *testing.T) { - testCases := []struct { - name string - flags header.TCPFlags - }{ - {name: "non-setup ECN SYN w/ ECE", flags: header.TCPFlagEce}, - {name: "non-setup ECN SYN w/ CWR", flags: header.TCPFlagCwr}, - {name: "setup ECN SYN", flags: header.TCPFlagEce | header.TCPFlagCwr}, - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + testCases := []struct { + name string + flags header.TCPFlags + }{ + {name: "non-setup ECN SYN w/ ECE", flags: header.TCPFlagEce}, + {name: "non-setup ECN SYN w/ CWR", flags: header.TCPFlagCwr}, + {name: "setup ECN SYN", flags: header.TCPFlagEce | header.TCPFlagCwr}, + } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer ep.Close() + // Create EP and start listening. + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + defer ep.Close() - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } - // Do 3-way handshake. - const maxPayload = 100 + // Do 3-way handshake. 
+ const maxPayload = 100 - c.PassiveConnect(maxPayload, -1 /* wndScale */, header.TCPSynOptions{MSS: e2e.DefaultIPv4MSS, Flags: tc.flags}) + c.PassiveConnect(maxPayload, -1 /* wndScale */, header.TCPSynOptions{MSS: e2e.DefaultIPv4MSS, Flags: tc.flags}) - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) - wq.EventRegister(&we) - defer wq.EventUnregister(&we) + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) + wq.EventRegister(&we) + defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } + c.EP, _, err = ep.Accept(nil) + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) + } - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } else if err != nil { + t.Fatalf("Accept failed: %s", err) } - } else if err != nil { - t.Fatalf("Accept failed: %s", err) - } - }) - } + }) + } + }) } func TestReadAfterCloseWithBufferedData(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - con := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{}) - // Fill up the receive queue. - for i := 0; i < 300; i++ { - con.SendPacket([]byte{1, 2, 3, 4}, nil) - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + con := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{}) + // Fill up the receive queue. 
+ for i := 0; i < 300; i++ { + con.SendPacket([]byte{1, 2, 3, 4}, nil) + } - timeout := time.After(5 * time.Second) - // If the receive queue is not properly drained, the endpoint will never - // return ErrClosedForReceive. - c.EP.Close() - for { - select { - case <-timeout: - t.Fatalf("timed out waiting for read to return error %q", &tcpip.ErrClosedForReceive{}) - return - default: - if _, err := c.EP.Read(io.Discard, tcpip.ReadOptions{}); cmp.Equal(err, &tcpip.ErrClosedForReceive{}) { + timeout := time.After(5 * time.Second) + // If the receive queue is not properly drained, the endpoint will never + // return ErrClosedForReceive. + c.EP.Close() + for { + select { + case <-timeout: + t.Fatalf("timed out waiting for read to return error %q", &tcpip.ErrClosedForReceive{}) return + default: + if _, err := c.EP.Read(io.Discard, tcpip.ReadOptions{}); cmp.Equal(err, &tcpip.ErrClosedForReceive{}) { + return + } } } - } + }) } func TestReleaseDanglingEndpoints(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + ep := c.EP + c.EP = nil + + // Close the endpoint, make sure we get a FIN segment. The endpoint should be + // dangling. + ep.Close() + iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + v := c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + )) + tcpip.ReleaseDanglingEndpoints() - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - ep := c.EP - c.EP = nil + // ReleaseDanglingEndpoints should abort the half-closed endpoint causing + // a RST to be sent. 
+ v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(uint32(iss)), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + )) - // Close the endpoint, make sure we get a FIN segment. The endpoint should be - // dangling. - ep.Close() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - v := c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - )) - tcpip.ReleaseDanglingEndpoints() + // Now send an ACK and it should trigger a RST as the endpoint is aborted. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) - // ReleaseDanglingEndpoints should abort the half-closed endpoint causing - // a RST to be sent. - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - )) - - // Now send an ACK and it should trigger a RST as the endpoint is aborted. 
- c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, + v = c.GetPacket() + defer v.Release() + checker.IPv4(t, v, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(0), + checker.TCPFlags(header.TCPFlagRst), + )) }) - - v = c.GetPacket() - defer v.Release() - checker.IPv4(t, v, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(0), - checker.TCPFlags(header.TCPFlagRst), - )) } // TestLateSynCookieAck ensures that we properly handle the following case @@ -9306,283 +9732,292 @@ func TestReleaseDanglingEndpoints(t *testing.T) { // - We receive an ACK based on S. // - We respond with an RST because we expected an ACK based on S'. func TestLateSynCookieAck(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - stats := c.Stack().Stats() - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + stats := c.Stack().Stats() + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } - initial := stats.TCP.CurrentEstablished.Value() + initial := stats.TCP.CurrentEstablished.Value() - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } - // With a backlog of 2, we get one slot in the SYN queue before we - // start using SYN cookies. 
See - // //pkg/tcpip/transport/tcp/accept.go:handleListenSegment:useSynCookies - // for an explanation. - if err := ep.Listen(2); err != nil { - t.Fatalf("Listen failed: %s", err) - } + // With a backlog of 2, we get one slot in the SYN queue before we + // start using SYN cookies. See + // //pkg/tcpip/transport/tcp/accept.go:handleListenSegment:useSynCookies + // for an explanation. + if err := ep.Listen(2); err != nil { + t.Fatalf("Listen failed: %s", err) + } - // To reach our desired state, we're gonna do the following: - // - // - Send SYN S1 to force subsequent SYNs to return cookies. - // - Send SYN S2, which returns a cookie SYN/ACK. - // - Finish S1's handshake, opening space in the SYN queue. - // - Retransmit S2, which will give use a different seqnum. - // - Finish S2's handshake with the cookie SYN/ACK. - - // Send S1. - const otherTestPort = context.TestPort + 1 - iss := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: otherTestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - }) - // Receive the SYN-ACK reply. - s1Reply := c.GetPacket() - defer s1Reply.Release() - s1ReplyHdr := header.TCP(header.IPv4(s1Reply.AsSlice()).Payload()) + // To reach our desired state, we're gonna do the following: + // + // - Send SYN S1 to force subsequent SYNs to return cookies. + // - Send SYN S2, which returns a cookie SYN/ACK. + // - Finish S1's handshake, opening space in the SYN queue. + // - Retransmit S2, which will give use a different seqnum. + // - Finish S2's handshake with the cookie SYN/ACK. + + // Send S1. + const otherTestPort = context.TestPort + 1 + iss := seqnum.Value(context.TestInitialSequenceNumber) + c.SendPacket(nil, &context.Headers{ + SrcPort: otherTestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + // Receive the SYN-ACK reply. 
+ s1Reply := c.GetPacket() + defer s1Reply.Release() + s1ReplyHdr := header.TCP(header.IPv4(s1Reply.AsSlice()).Payload()) - // Send S2. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - }) - // Receive the SYN-ACK reply. - s2CookieReply := c.GetPacket() - defer s2CookieReply.Release() - s2CookieReplyHdr := header.TCP(header.IPv4(s2CookieReply.AsSlice()).Payload()) - - // Finish the S1 handshake. - ackHeaders := &context.Headers{ - SrcPort: otherTestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: seqnum.Value(s1ReplyHdr.SequenceNumber()) + 1, - } - c.SendPacket(nil, ackHeaders) + // Send S2. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + // Receive the SYN-ACK reply. + s2CookieReply := c.GetPacket() + defer s2CookieReply.Release() + s2CookieReplyHdr := header.TCP(header.IPv4(s2CookieReply.AsSlice()).Payload()) + + // Finish the S1 handshake. + ackHeaders := &context.Headers{ + SrcPort: otherTestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: seqnum.Value(s1ReplyHdr.SequenceNumber()) + 1, + } + c.SendPacket(nil, ackHeaders) - // Wait for S1's connection to move from the SYN to the accept queue. - metricPollFn := func() error { - if got, want := stats.TCP.CurrentEstablished.Value(), initial+1; got != want { - return fmt.Errorf("connection never established: got stats.TCP.CurrentEstablished.Value() = %d, want = %d", got, want) + // Wait for S1's connection to move from the SYN to the accept queue. 
+ metricPollFn := func() error { + if got, want := stats.TCP.CurrentEstablished.Value(), initial+1; got != want { + return fmt.Errorf("connection never established: got stats.TCP.CurrentEstablished.Value() = %d, want = %d", got, want) + } + return nil + } + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Fatal(err) } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Fatal(err) - } - // Retransmit S2. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - }) - // Receive the SYN-ACK reply. - s2QueueReply := c.GetPacket() - defer s2QueueReply.Release() - s2QueueReplyHdr := header.TCP(header.IPv4(s2QueueReply.AsSlice()).Payload()) - if s2CookieReplyHdr.SequenceNumber() == s2QueueReplyHdr.SequenceNumber() { - t.Fatalf("the SYN cookie and regular seqnum are equal; is the backlog too large?") - } + // Retransmit S2. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + // Receive the SYN-ACK reply. + s2QueueReply := c.GetPacket() + defer s2QueueReply.Release() + s2QueueReplyHdr := header.TCP(header.IPv4(s2QueueReply.AsSlice()).Payload()) + if s2CookieReplyHdr.SequenceNumber() == s2QueueReplyHdr.SequenceNumber() { + t.Fatalf("the SYN cookie and regular seqnum are equal; is the backlog too large?") + } - // Finish S2's handshake using the cookie. - ackHeaders = &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: seqnum.Value(s2CookieReplyHdr.SequenceNumber()) + 1, - } - c.SendPacket(nil, ackHeaders) + // Finish S2's handshake using the cookie. 
+ ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: seqnum.Value(s2CookieReplyHdr.SequenceNumber()) + 1, + } + c.SendPacket(nil, ackHeaders) - // Verify that we've completed two connections. - metricPollFn = func() error { - if got, want := stats.TCP.CurrentEstablished.Value(), initial+2; got != want { - return fmt.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = %d", got, want) + // Verify that we've completed two connections. + metricPollFn = func() error { + if got, want := stats.TCP.CurrentEstablished.Value(), initial+2; got != want { + return fmt.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = %d", got, want) + } + return nil } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) + } + }) } func TestSetExperimentOption(t *testing.T) { - c := context.NewWithOpts(t, context.Options{ - EnableV4: true, - MTU: e2e.DefaultMTU, - EnableExperimentIPOption: true, - }) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.NewWithOpts(t, context.Options{ + EnableV4: true, + MTU: e2e.DefaultMTU, + EnableExperimentIPOption: true, + }) + defer c.Cleanup() - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - var expval uint16 = 99 - c.EP.SocketOptions().SetExperimentOptionValue(expval) + var expval uint16 = 99 + c.EP.SocketOptions().SetExperimentOptionValue(expval) - var r bytes.Reader - r.Reset(make([]byte, 1)) - _, err := c.EP.Write(&r, tcpip.WriteOptions{}) - if err != nil { - t.Fatalf("Write failed: %s", err) - } + var r bytes.Reader + r.Reset(make([]byte, 1)) + _, err := c.EP.Write(&r, tcpip.WriteOptions{}) + if err != nil { + t.Fatalf("Write failed: %s", err) + 
} - v := c.GetPacket() - defer v.Release() - want := header.IPv4Options{ - byte(header.IPv4OptionExperimentType), - byte(header.IPv4OptionExperimentLength), - 0, - byte(expval), - } - checker.IPv4(t, v, checker.IPv4Options(want)) + v := c.GetPacket() + defer v.Release() + want := header.IPv4Options{ + byte(header.IPv4OptionExperimentType), + byte(header.IPv4OptionExperimentLength), + 0, + byte(expval), + } + checker.IPv4(t, v, checker.IPv4Options(want)) + }) } func TestSetExperimentOptionIPv6(t *testing.T) { - c := context.NewWithOpts(t, context.Options{ - EnableV4: false, - EnableV6: true, - MTU: e2e.DefaultMTU, - EnableExperimentIPOption: true, - }) - defer c.Cleanup() - - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - c.EP = ep + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.NewWithOpts(t, context.Options{ + EnableV4: false, + EnableV6: true, + MTU: e2e.DefaultMTU, + EnableExperimentIPOption: true, + }) + defer c.Cleanup() - // Start connection attempt. - waitEntry, notifyCh := waiter.NewChannelEntry(waiter.WritableEvents) - c.WQ.EventRegister(&waitEntry) - defer c.WQ.EventUnregister(&waitEntry) + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + c.EP = ep - err = c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("Unexpected return value from Connect: %v", err) - } + // Start connection attempt. + waitEntry, notifyCh := waiter.NewChannelEntry(waiter.WritableEvents) + c.WQ.EventRegister(&waitEntry) + defer c.WQ.EventUnregister(&waitEntry) - // Receive SYN packet. 
- b := c.GetV6Packet() - defer b.Release() - checker.IPv6(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - ), - ) - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { - t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } + err = c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("Unexpected return value from Connect: %v", err) + } - iss := seqnum.Value(context.TestInitialSequenceNumber) - rcvWnd := seqnum.Size(30000) - tcpHdr := header.TCP(header.IPv6(b.AsSlice()).Payload()) - synOpts := header.ParseSynOptions(tcpHdr.Options(), false /* isAck */) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + // Receive SYN packet. + b := c.GetV6Packet() + defer b.Release() + checker.IPv6(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + ), + ) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } - c.SendV6Packet(nil, &context.Headers{ - SrcPort: tcpHdr.DestinationPort(), - DstPort: tcpHdr.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: rcvWnd, - TCPOpts: nil, - }) + iss := seqnum.Value(context.TestInitialSequenceNumber) + rcvWnd := seqnum.Size(30000) + tcpHdr := header.TCP(header.IPv6(b.AsSlice()).Payload()) + synOpts := header.ParseSynOptions(tcpHdr.Options(), false /* isAck */) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + c.SendV6Packet(nil, &context.Headers{ + SrcPort: tcpHdr.DestinationPort(), + DstPort: tcpHdr.SourcePort(), + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: rcvWnd, + TCPOpts: nil, + }) - // Receive ACK packet. 
- b = c.GetV6Packet() - defer b.Release() - checker.IPv6(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+1), - ), - ) + // Receive ACK packet. + b = c.GetV6Packet() + defer b.Release() + checker.IPv6(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+1), + ), + ) - // Wait for connection to be established. - select { - case <-notifyCh: - if err := c.EP.LastError(); err != nil { - t.Fatalf("Unexpected error when connecting: %v", err) + // Wait for connection to be established. + select { + case <-notifyCh: + if err := c.EP.LastError(); err != nil { + t.Fatalf("Unexpected error when connecting: %v", err) + } + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for connection") + } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { + t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) } - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for connection") - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { - t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - c.RcvdWindowScale = uint8(synOpts.WS) - c.Port = tcpHdr.SourcePort() + c.RcvdWindowScale = uint8(synOpts.WS) + c.Port = tcpHdr.SourcePort() - var expval uint16 = 99 - c.EP.SocketOptions().SetExperimentOptionValue(expval) + var expval uint16 = 99 + c.EP.SocketOptions().SetExperimentOptionValue(expval) - var r bytes.Reader - r.Reset(make([]byte, 1)) - _, err = c.EP.Write(&r, tcpip.WriteOptions{}) - if err != nil { - t.Fatalf("Write failed: %s", err) - } + var r bytes.Reader + r.Reset(make([]byte, 1)) + _, err = c.EP.Write(&r, tcpip.WriteOptions{}) + if err != nil { + t.Fatalf("Write failed: %s", err) + } - v := c.GetV6Packet() - defer v.Release() + 
v := c.GetV6Packet() + defer v.Release() - checker.IPv6WithExtHdr(t, v, checker.IPv6ExtHdr(checker.IPv6ExperimentHeader(expval))) + checker.IPv6WithExtHdr(t, v, checker.IPv6ExtHdr(checker.IPv6ExperimentHeader(expval))) + }) } func TestSetExperimentOptionWithOptionDisabled(t *testing.T) { - c := context.NewWithOpts(t, context.Options{ - EnableV4: true, - MTU: e2e.DefaultMTU, - EnableExperimentIPOption: false, - }) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.NewWithOpts(t, context.Options{ + EnableV4: true, + MTU: e2e.DefaultMTU, + EnableExperimentIPOption: false, + }) + defer c.Cleanup() - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) + c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - var expval uint16 = 99 - c.EP.SocketOptions().SetExperimentOptionValue(expval) + var expval uint16 = 99 + c.EP.SocketOptions().SetExperimentOptionValue(expval) - var r bytes.Reader - r.Reset(make([]byte, 1)) - _, err := c.EP.Write(&r, tcpip.WriteOptions{}) - if err != nil { - t.Fatalf("Write failed: %s", err) - } + var r bytes.Reader + r.Reset(make([]byte, 1)) + _, err := c.EP.Write(&r, tcpip.WriteOptions{}) + if err != nil { + t.Fatalf("Write failed: %s", err) + } - v := c.GetPacket() - defer v.Release() - want := header.IPv4Options{} - checker.IPv4(t, v, checker.IPv4Options(want)) + v := c.GetPacket() + defer v.Release() + want := header.IPv4Options{} + checker.IPv4(t, v, checker.IPv4Options(want)) + }) } func TestMain(m *testing.M) { refs.SetLeakMode(refs.LeaksPanic) code := m.Run() tcpip.ReleaseDanglingEndpoints() - // Allow TCP async work to complete to avoid false reports of leaks. - // TODO(gvisor.dev/issue/5940): Use fake clock in tests. 
- time.Sleep(1 * time.Second) refs.DoLeakCheck() os.Exit(code) } diff --git a/pkg/tcpip/transport/tcp/test/e2e/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/test/e2e/tcp_timestamp_test.go index 6b6dd0ca9f..d9d6b41eaf 100644 --- a/pkg/tcpip/transport/tcp/test/e2e/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/test/e2e/tcp_timestamp_test.go @@ -19,6 +19,7 @@ import ( "math/rand" "os" "testing" + "testing/synctest" "time" "github.com/google/go-cmp/cmp" @@ -42,88 +43,91 @@ func createConnectedWithTimestampOption(c *context.Context) *context.RawEndpoint // an active connect and sets the TS Echo Reply fields correctly when the // SYN-ACK also indicates support for the TS option and provides a TSVal. func TestTimeStampEnabledConnect(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() - - rep := createConnectedWithTimestampOption(c) - - // Register for read and validate that we have data to read. - we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) - c.WQ.EventRegister(&we) - defer c.WQ.EventUnregister(&we) - - // The following tests ensure that TS option once enabled behaves - // correctly as described in - // https://tools.ietf.org/html/rfc7323#section-4.3. - // - // We are not testing delayed ACKs here, but we do test out of order - // packet delivery and filling the sequence number hole created due to - // the out of order packet. - // - // The test also verifies that the sequence numbers and timestamps are - // as expected. - data := []byte{1, 2, 3} - - // First we increment tsVal by a small amount. - tsVal := rep.TSVal + 100 - rep.SendPacketWithTS(data, tsVal) - rep.VerifyACKWithTS(tsVal) - - // Next we send an out of order packet. - rep.NextSeqNum += 3 - tsVal += 200 - rep.SendPacketWithTS(data, tsVal) - - // The ACK should contain the original sequenceNumber and an older TS. 
- rep.NextSeqNum -= 6 - rep.VerifyACKWithTS(tsVal - 200) - - // Next we fill the hole and the returned ACK should contain the - // cumulative sequence number acking all data sent till now and have the - // latest timestamp sent below in its TSEcr field. - tsVal -= 100 - rep.SendPacketWithTS(data, tsVal) - rep.NextSeqNum += 3 - rep.VerifyACKWithTS(tsVal) - - // Increment tsVal by a large value that doesn't result in a wrap around. - tsVal += 0x7fffffff - rep.SendPacketWithTS(data, tsVal) - rep.VerifyACKWithTS(tsVal) - - // Increment tsVal again by a large value which should cause the - // timestamp value to wrap around. The returned ACK should contain the - // wrapped around timestamp in its tsEcr field and not the tsVal from - // the previous packet sent above. - tsVal += 0x7fffffff - rep.SendPacketWithTS(data, tsVal) - rep.VerifyACKWithTS(tsVal) - - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // There should be 5 views to read and each of them should - // contain the same data. - for i := 0; i < 5; i++ { - buf := make([]byte, len(data)) - w := tcpip.SliceWriter(buf) - result, err := c.EP.Read(&w, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("Unexpected error from Read: %v", err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: len(buf), - Total: len(buf), - }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { - t.Errorf("Read: unexpected result (-want +got):\n%s", diff) + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() + + rep := createConnectedWithTimestampOption(c) + + // Register for read and validate that we have data to read. 
+ we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) + c.WQ.EventRegister(&we) + defer c.WQ.EventUnregister(&we) + + // The following tests ensure that TS option once enabled behaves + // correctly as described in + // https://tools.ietf.org/html/rfc7323#section-4.3. + // + // We are not testing delayed ACKs here, but we do test out of order + // packet delivery and filling the sequence number hole created due to + // the out of order packet. + // + // The test also verifies that the sequence numbers and timestamps are + // as expected. + data := []byte{1, 2, 3} + + // First we increment tsVal by a small amount. + tsVal := rep.TSVal + 100 + rep.SendPacketWithTS(data, tsVal) + rep.VerifyACKWithTS(tsVal) + + // Next we send an out of order packet. + rep.NextSeqNum += 3 + tsVal += 200 + rep.SendPacketWithTS(data, tsVal) + + // The ACK should contain the original sequenceNumber and an older TS. + rep.NextSeqNum -= 6 + rep.VerifyACKWithTS(tsVal - 200) + + // Next we fill the hole and the returned ACK should contain the + // cumulative sequence number acking all data sent till now and have the + // latest timestamp sent below in its TSEcr field. + tsVal -= 100 + rep.SendPacketWithTS(data, tsVal) + rep.NextSeqNum += 3 + rep.VerifyACKWithTS(tsVal) + + // Increment tsVal by a large value that doesn't result in a wrap around. + tsVal += 0x7fffffff + rep.SendPacketWithTS(data, tsVal) + rep.VerifyACKWithTS(tsVal) + + // Increment tsVal again by a large value which should cause the + // timestamp value to wrap around. The returned ACK should contain the + // wrapped around timestamp in its tsEcr field and not the tsVal from + // the previous packet sent above. 
+ tsVal += 0x7fffffff + rep.SendPacketWithTS(data, tsVal) + rep.VerifyACKWithTS(tsVal) + + select { + case <-ch: + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") } - if got, want := buf, data; bytes.Compare(got, want) != 0 { - t.Fatalf("Data is different: got: %v, want: %v", got, want) + + // There should be 5 views to read and each of them should + // contain the same data. + for i := 0; i < 5; i++ { + buf := make([]byte, len(data)) + w := tcpip.SliceWriter(buf) + result, err := c.EP.Read(&w, tcpip.ReadOptions{}) + if err != nil { + t.Fatalf("Unexpected error from Read: %v", err) + } + if diff := cmp.Diff(tcpip.ReadResult{ + Count: len(buf), + Total: len(buf), + }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { + t.Errorf("Read: unexpected result (-want +got):\n%s", diff) + } + if got, want := buf, data; bytes.Compare(got, want) != 0 { + t.Fatalf("Data is different: got: %v, want: %v", got, want) + } } - } + }) } // TestTimeStampDisabledConnect tests that netstack sends timestamp option on an @@ -131,10 +135,13 @@ func TestTimeStampEnabledConnect(t *testing.T) { // timestamp option is not enabled and future packets do not contain a // timestamp. func TestTimeStampDisabledConnect(t *testing.T) { - c := context.New(t, e2e.DefaultMTU) - defer c.Cleanup() + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + c := context.New(t, e2e.DefaultMTU) + defer c.Cleanup() - c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{}) + c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{}) + }) } func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) { @@ -187,18 +194,21 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS // that Timestamp option is enabled in both cases if requested in the original // SYN. 
func TestTimeStampEnabledAccept(t *testing.T) { - testCases := []struct { - cookieEnabled bool - wndScale int - wndSize uint16 - }{ - {true, -1, 0xffff}, // When cookie is used window scaling is disabled. - // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be 1/2 of that. - {false, 5, 0x4000}, - } - for _, tc := range testCases { - timeStampEnabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize) - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + testCases := []struct { + cookieEnabled bool + wndScale int + wndSize uint16 + }{ + {true, -1, 0xffff}, // When cookie is used window scaling is disabled. + // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be 1/2 of that. + {false, 5, 0x4000}, + } + for _, tc := range testCases { + timeStampEnabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize) + } + }) } func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) { @@ -245,82 +255,88 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd // TestTimeStampDisabledAccept tests that Timestamp option is not used when the // peer doesn't advertise it and connection is established with Accept(). func TestTimeStampDisabledAccept(t *testing.T) { - testCases := []struct { - cookieEnabled bool - wndScale int - wndSize uint16 - }{ - {true, -1, 0xffff}, // When cookie is used window scaling is disabled. - // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be half of - // that. - {false, 5, 0x4000}, - } - for _, tc := range testCases { - timeStampDisabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize) - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + testCases := []struct { + cookieEnabled bool + wndScale int + wndSize uint16 + }{ + {true, -1, 0xffff}, // When cookie is used window scaling is disabled. + // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be half of + // that. 
+ {false, 5, 0x4000}, + } + for _, tc := range testCases { + timeStampDisabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize) + } + }) } func TestSendGreaterThanMTUWithOptions(t *testing.T) { - const maxPayload = 100 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - createConnectedWithTimestampOption(c) - e2e.CheckBrokenUpWrite(t, c, maxPayload) + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + const maxPayload = 100 + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) + defer c.Cleanup() + + createConnectedWithTimestampOption(c) + e2e.CheckBrokenUpWrite(t, c, maxPayload) + }) } func TestSegmentNotDroppedWhenTimestampMissing(t *testing.T) { - const maxPayload = 100 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - rep := createConnectedWithTimestampOption(c) - - // Register for read. - we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) - c.WQ.EventRegister(&we) - defer c.WQ.EventUnregister(&we) - - droppedPacketsStat := c.Stack().Stats().DroppedPackets - droppedPackets := droppedPacketsStat.Value() - data := []byte{1, 2, 3} - // Send a packet with no TCP options/timestamp. - rep.SendPacket(data, nil) - - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } + synctest.Test(t, func(t *testing.T) { + defer synctest.Wait() + const maxPayload = 100 + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) + defer c.Cleanup() + + rep := createConnectedWithTimestampOption(c) + + // Register for read. + we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) + c.WQ.EventRegister(&we) + defer c.WQ.EventUnregister(&we) + + droppedPacketsStat := c.Stack().Stats().DroppedPackets + droppedPackets := droppedPacketsStat.Value() + data := []byte{1, 2, 3} + // Send a packet with no TCP options/timestamp. 
+ rep.SendPacket(data, nil) + + select { + case <-ch: + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } - // Assert that DroppedPackets was not incremented. - if got, want := droppedPacketsStat.Value(), droppedPackets; got != want { - t.Fatalf("incorrect number of dropped packets, got: %v, want: %v", got, want) - } + // Assert that DroppedPackets was not incremented. + if got, want := droppedPacketsStat.Value(), droppedPackets; got != want { + t.Fatalf("incorrect number of dropped packets, got: %v, want: %v", got, want) + } - // Issue a read and we should data. - var buf bytes.Buffer - result, err := c.EP.Read(&buf, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("Unexpected error from Read: %v", err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), - }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { - t.Errorf("Read: unexpected result (-want +got):\n%s", diff) - } - if got, want := buf.Bytes(), data; bytes.Compare(got, want) != 0 { - t.Fatalf("Data is different: got: %v, want: %v", got, want) - } + // Issue a read and we should get data. + var buf bytes.Buffer + result, err := c.EP.Read(&buf, tcpip.ReadOptions{}) + if err != nil { + t.Fatalf("Unexpected error from Read: %v", err) + } + if diff := cmp.Diff(tcpip.ReadResult{ + Count: buf.Len(), + Total: buf.Len(), + }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { + t.Errorf("Read: unexpected result (-want +got):\n%s", diff) + } + if got, want := buf.Bytes(), data; bytes.Compare(got, want) != 0 { + t.Fatalf("Data is different: got: %v, want: %v", got, want) + } + }) } func TestMain(m *testing.M) { refs.SetLeakMode(refs.LeaksPanic) code := m.Run() - // Allow TCP async work to complete to avoid false reports of leaks. - // TODO(gvisor.dev/issue/5940): Use fake clock in tests. - time.Sleep(1 * time.Second) refs.DoLeakCheck() os.Exit(code) }