diff --git a/pkg/labels/labels_test.go b/pkg/labels/labels_test.go index 0cee180ae..2f14ac115 100644 --- a/pkg/labels/labels_test.go +++ b/pkg/labels/labels_test.go @@ -725,3 +725,316 @@ func TestGetToolType(t *testing.T) { }) } } + +func TestAddNetworkLabels(t *testing.T) { + t.Parallel() + tests := []struct { + name string + networkName string + expected map[string]string + }{ + { + name: "Add network labels with standard name", + networkName: "test-network", + expected: map[string]string{ + LabelToolHive: "true", + LabelName: "test-network", + }, + }, + { + name: "Add network labels with complex name", + networkName: "my-app-network-prod", + expected: map[string]string{ + LabelToolHive: "true", + LabelName: "my-app-network-prod", + }, + }, + { + name: "Add network labels with empty name", + networkName: "", + expected: map[string]string{ + LabelToolHive: "true", + LabelName: "", + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + labels := make(map[string]string) + AddNetworkLabels(labels, tc.networkName) + + // Verify all expected labels are present with correct values + for key, expectedValue := range tc.expected { + actualValue, exists := labels[key] + if !exists { + t.Errorf("Expected label %s to exist, but it doesn't", key) + } + if actualValue != expectedValue { + t.Errorf("Expected label %s to be %s, but got %s", key, expectedValue, actualValue) + } + } + + // Verify no unexpected labels are present + if len(labels) != len(tc.expected) { + t.Errorf("Expected %d labels, but got %d", len(tc.expected), len(labels)) + } + }) + } +} + +func TestAddNetworkIsolationLabel(t *testing.T) { + t.Parallel() + tests := []struct { + name string + networkIsolation bool + expectedValue string + }{ + { + name: "Network isolation enabled", + networkIsolation: true, + expectedValue: "true", + }, + { + name: "Network isolation disabled", + networkIsolation: false, + expectedValue: "false", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + labels := make(map[string]string) + AddNetworkIsolationLabel(labels, tc.networkIsolation) + + // Verify the network isolation label is set correctly + actualValue, exists := labels[LabelNetworkIsolation] + if !exists { + t.Errorf("Expected label %s to exist, but it doesn't", LabelNetworkIsolation) + } + if actualValue != tc.expectedValue { + t.Errorf("Expected label %s to be %s, but got %s", LabelNetworkIsolation, tc.expectedValue, actualValue) + } + + // Verify only one label was added + if len(labels) != 1 { + t.Errorf("Expected 1 label, but got %d", len(labels)) + } + }) + } +} + +func TestGetGroup(t *testing.T) { + t.Parallel() + tests := []struct { + name string + labels map[string]string + expected string + }{ + { + name: "Group name exists", + labels: map[string]string{ + LabelGroup: "production", + }, + expected: "production", + }, + { + name: "Group name with complex value", + labels: map[string]string{ + LabelGroup: "my-app-group", + }, + expected: "my-app-group", + }, + { + name: "Group name doesn't exist", + labels: map[string]string{}, + expected: "", + }, + { + name: "Group name exists but empty", + labels: map[string]string{ + LabelGroup: "", + }, + expected: "", + }, + { + name: "Other labels present but no group", + labels: map[string]string{ + LabelName: "test-container", + LabelToolHive: "true", + }, + expected: "", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + result := GetGroup(tc.labels) + if result != tc.expected { + t.Errorf("Expected group name to be %s, but got %s", tc.expected, result) + } + }) + } +} + +func TestSetGroup(t *testing.T) { + t.Parallel() + tests := []struct { + name string + groupName string + expected string + }{ + { + name: "Set standard group name", + groupName: "production", + expected: "production", + }, + { + name: "Set complex group name", + groupName: "my-app-group-dev", + expected: "my-app-group-dev", + }, + { + name: "Set empty group name", + groupName: "", + expected: "", + }, + { + name: "Set group name with special characters", + groupName: "group_123", + expected: "group_123", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + labels := make(map[string]string) + SetGroup(labels, tc.groupName) + + // Verify the group label is set correctly + actualValue, exists := labels[LabelGroup] + if !exists { + t.Errorf("Expected label %s to exist, but it doesn't", LabelGroup) + } + if actualValue != tc.expected { + t.Errorf("Expected label %s to be %s, but got %s", LabelGroup, tc.expected, actualValue) + } + + // Verify only one label was added + if len(labels) != 1 { + t.Errorf("Expected 1 label, but got %d", len(labels)) + } + }) + } +} + +func TestSetGroupOverwrite(t *testing.T) { + t.Parallel() + // Test that SetGroup overwrites existing group value + labels := map[string]string{ + LabelGroup: "old-group", + LabelName: "test-container", + LabelToolHive: "true", + } + + SetGroup(labels, "new-group") + + // Verify the group was updated + if labels[LabelGroup] != "new-group" { + t.Errorf("Expected group to be updated to new-group, but got %s", labels[LabelGroup]) + } + + // Verify other labels were not affected + if labels[LabelName] != "test-container" { + t.Errorf("Expected name label to remain unchanged") + } + if labels[LabelToolHive] != "true" { + t.Errorf("Expected toolhive label to remain unchanged") + } + + // Verify total number of labels + if len(labels) != 3 { + t.Errorf("Expected 3 labels, but got %d", len(labels)) + } +} + +func TestIsAuxiliaryWorkload(t *testing.T) { + t.Parallel() + tests := []struct { + name string + labels map[string]string + expected bool + }{ + { + name: "Auxiliary workload with lowercase true", + labels: map[string]string{ + LabelAuxiliary: "true", + }, + expected: true, + }, + { + name: "Auxiliary workload with uppercase TRUE", + labels: map[string]string{ + LabelAuxiliary: "TRUE", + }, + expected: true, + }, + { + name: "Auxiliary workload with mixed case TrUe", + labels: map[string]string{ + LabelAuxiliary: "TrUe", + }, + expected: true, + }, + { + name: "Not an auxiliary workload - false value", + labels: map[string]string{ + LabelAuxiliary: "false", + }, + expected: false, + }, + { + name: "Not an auxiliary workload - other value", + labels: map[string]string{ + LabelAuxiliary: "yes", + }, + expected: false, + }, + { + name: "Not an auxiliary workload - label missing", + labels: map[string]string{}, + expected: false, + }, + { + name: "Not an auxiliary workload - other labels present", + labels: map[string]string{ + LabelName: "test-container", + LabelToolHive: "true", + }, + expected: false, + }, + { + name: "Auxiliary workload with other labels", + labels: map[string]string{ + LabelAuxiliary: "true", + LabelName: "inspector", + LabelToolHive: "true", + }, + expected: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + result := IsAuxiliaryWorkload(tc.labels) + if result != tc.expected { + t.Errorf("Expected IsAuxiliaryWorkload to return %v, but got %v", tc.expected, result) + } + }) + } +} diff --git a/pkg/mcp/tool_middleware_test.go b/pkg/mcp/tool_middleware_test.go index 3de2ec3fc..2f12247ff 100644 --- a/pkg/mcp/tool_middleware_test.go +++ b/pkg/mcp/tool_middleware_test.go @@ -483,7 +483,12 @@ func TestNewListToolsMappingMiddleware_SSE_Scenarios(t *testing.T) { // Verify results assert.Equal(t, "2.0", response.JSONRPC) assert.Equal(t, float64(1), response.ID) - assert.Equal(t, tt.expected, response.Result.Tools) + // Use ElementsMatch for order-independent comparison of tools + if tt.expected != nil && response.Result.Tools != nil { + assert.ElementsMatch(t, *tt.expected, *response.Result.Tools, "Tools should match regardless of order") + } else { + assert.Equal(t, tt.expected, response.Result.Tools) + } }) } } diff --git a/pkg/networking/port_test.go b/pkg/networking/port_test.go new file mode 100644 index 000000000..70537acc3 --- /dev/null +++ b/pkg/networking/port_test.go @@ -0,0 +1,515 @@ +package networking + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIsAvailable(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setupPort func(t *testing.T) int + expected bool + description string + }{ + { + name: "available port returns true", + setupPort: func(t *testing.T) int { + t.Helper() + // Find a truly available port by binding to :0 + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + port := listener.Addr().(*net.TCPAddr).Port + require.NoError(t, listener.Close()) + return port + }, + expected: true, + description: "Port should be available after closing listener", + }, + { + name: "tcp occupied port returns false", + setupPort: func(t *testing.T) int { + t.Helper() + // Bind to a port and keep it open + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { + listener.Close() + }) + return listener.Addr().(*net.TCPAddr).Port + }, + expected: false, + description: "Port should not be available when TCP listener is active", + }, + { + name: "udp occupied port returns false", + setupPort: func(t *testing.T) int { + t.Helper() + // Bind UDP first to get a port + udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") + require.NoError(t, err) + udpConn, err := net.ListenUDP("udp", udpAddr) + require.NoError(t, err) + t.Cleanup(func() { + udpConn.Close() + }) + return udpConn.LocalAddr().(*net.UDPAddr).Port + }, + expected: false, + description: "Port should not be available when UDP listener is active", + }, + { + name: "port 0 returns true (special case for any port)", + setupPort: func(t *testing.T) int { + t.Helper() + return 0 + }, + expected: true, + description: "Port 0 is a special case that means 'any available port' and succeeds", + }, + { + name: "negative port returns false", + setupPort: func(t *testing.T) int { + t.Helper() + return -1 + }, + expected: false, + description: "Negative port numbers should return false", + }, + { + name: "port above max range returns false", + setupPort: func(t *testing.T) int { + t.Helper() + return 65536 + }, + expected: false, + description: "Port above 65535 should return false", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + port := tt.setupPort(t) + result := IsAvailable(port) + + assert.Equal(t, tt.expected, result, tt.description) + }) + } +} + +func TestIsIPv6Available(t *testing.T) { + t.Parallel() + + // This test checks if the function correctly detects IPv6 availability + // The result may vary based on the system configuration + t.Run("check IPv6 availability detection", func(t *testing.T) { + t.Parallel() + + result := IsIPv6Available() + + // We can't assert a specific value since it depends on the system, + // but we can verify the function returns without panicking + t.Logf("IPv6 available: %v", result) + + // Verify by manually checking interfaces + interfaces, err := net.Interfaces() + require.NoError(t, err) + + hasIPv6 := false + for _, iface := range interfaces { + if iface.Flags&net.FlagUp == 0 { + continue + } + + addrs, err := iface.Addrs() + if err != nil { + continue + } + + for _, addr := range addrs { + ipNet, ok := addr.(*net.IPNet) + if !ok { + continue + } + + if ipNet.IP.To4() == nil && !ipNet.IP.IsLoopback() { + hasIPv6 = true + break + } + } + if hasIPv6 { + break + } + } + + // The function result should match our manual check + assert.Equal(t, hasIPv6, result, "IsIPv6Available should match manual interface check") + }) + + t.Run("handles down interfaces correctly", func(t *testing.T) { + t.Parallel() + + // This test verifies the function skips down interfaces + // We can't mock net.Interfaces() directly, but we can verify + // the function completes successfully + result := IsIPv6Available() + + // Function should return a boolean value without error + assert.IsType(t, false, result) + }) +} + +func TestFindAvailable(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + validate func(t *testing.T, port int) + description string + }{ + { + name: "finds available port in valid range", + validate: func(t *testing.T, port int) { + t.Helper() + if port == 0 { + t.Skip("No available ports found (system may be under heavy load)") + } + assert.GreaterOrEqual(t, port, MinPort, "Port should be >= MinPort") + assert.LessOrEqual(t, port, MaxPort, "Port should be <= MaxPort") + }, + description: "Should find a port between MinPort and MaxPort", + }, + { + name: "returned port is actually available", + validate: func(t *testing.T, port int) { + t.Helper() + if port == 0 { + t.Skip("No available ports found (system may be under heavy load)") + } + + // Verify the port is actually available by trying to bind to it + tcpAddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0") + require.NoError(t, err) + tcpAddr.Port = port + + tcpListener, err := net.ListenTCP("tcp", tcpAddr) + if err == nil { + defer tcpListener.Close() + } + assert.NoError(t, err, "Returned port should be available for TCP") + + udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") + require.NoError(t, err) + udpAddr.Port = port + + udpConn, err := net.ListenUDP("udp", udpAddr) + if err == nil { + defer udpConn.Close() + } + assert.NoError(t, err, "Returned port should be available for UDP") + }, + description: "Returned port should be bindable for both TCP and UDP", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + port := FindAvailable() + tt.validate(t, port) + }) + } +} + +func TestFindAvailable_ConsecutiveCalls(t *testing.T) { + t.Parallel() + + // Test that consecutive calls can find multiple available ports + t.Run("finds multiple different ports", func(t *testing.T) { + t.Parallel() + + // Find first port + port1 := FindAvailable() + if port1 == 0 { + t.Skip("No available ports found") + } + + // Bind to the first port + listener1, err := net.Listen("tcp", net.JoinHostPort("127.0.0.1", string(rune(port1)))) + if err != nil { + // Try alternative binding method + listener1, err = net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + } + t.Cleanup(func() { + listener1.Close() + }) + + // Find second port (should be different) + port2 := FindAvailable() + if port2 == 0 { + t.Skip("Could not find second available port") + } + + // Ports could theoretically be the same if port1 was freed and randomly selected again, + // but that's extremely unlikely, so we just verify both are valid + assert.GreaterOrEqual(t, port1, MinPort) + assert.GreaterOrEqual(t, port2, MinPort) + }) +} + +func TestFindOrUsePort(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setupPort func(t *testing.T) int + expectError bool + errorContains string + validate func(t *testing.T, inputPort, returnedPort int, err error) + description string + }{ + { + name: "port 0 finds available port", + setupPort: func(t *testing.T) int { + t.Helper() + return 0 + }, + expectError: false, + validate: func(t *testing.T, _ /* inputPort */, returnedPort int, err error) { + t.Helper() + assert.NoError(t, err) + if returnedPort == 0 { + t.Skip("No available ports found") + } + assert.GreaterOrEqual(t, returnedPort, MinPort, "Returned port should be >= MinPort") + assert.LessOrEqual(t, returnedPort, MaxPort, "Returned port should be <= MaxPort") + }, + description: "When port is 0, should find an available port", + }, + { + name: "available port is returned unchanged", + setupPort: func(t *testing.T) int { + t.Helper() + // Find an available port + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + port := listener.Addr().(*net.TCPAddr).Port + require.NoError(t, listener.Close()) + return port + }, + expectError: false, + validate: func(t *testing.T, inputPort, returnedPort int, err error) { + t.Helper() + assert.NoError(t, err) + assert.Equal(t, inputPort, returnedPort, "Available port should be returned unchanged") + }, + description: "When requested port is available, should return the same port", + }, + { + name: "unavailable port returns alternative", + setupPort: func(t *testing.T) int { + t.Helper() + // Bind to a port and keep it occupied + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { + listener.Close() + }) + return listener.Addr().(*net.TCPAddr).Port + }, + expectError: false, + validate: func(t *testing.T, _ /* inputPort */, returnedPort int, err error) { + t.Helper() + assert.NoError(t, err) + if returnedPort == 0 { + t.Skip("No alternative port found") + } + // Note: We can't check that returnedPort differs from inputPort since + // inputPort parameter was removed, but the function will find an alternative + // when the originally requested port is unavailable + assert.GreaterOrEqual(t, returnedPort, MinPort, "Alternative port should be >= MinPort") + assert.LessOrEqual(t, returnedPort, MaxPort, "Alternative port should be <= MaxPort") + }, + description: "When requested port is unavailable, should return an alternative", + }, + { + name: "negative port returns alternative", + setupPort: func(t *testing.T) int { + t.Helper() + return -1 + }, + expectError: false, + validate: func(t *testing.T, inputPort, returnedPort int, err error) { + t.Helper() + assert.NoError(t, err) + if returnedPort == 0 { + t.Skip("No available port found") + } + assert.NotEqual(t, inputPort, returnedPort, "Should not return negative port") + assert.GreaterOrEqual(t, returnedPort, MinPort) + }, + description: "Invalid port should trigger finding an alternative", + }, + { + name: "port above max range returns alternative", + setupPort: func(t *testing.T) int { + t.Helper() + return 70000 + }, + expectError: false, + validate: func(t *testing.T, _ /* inputPort */, returnedPort int, err error) { + t.Helper() + assert.NoError(t, err) + if returnedPort == 0 { + t.Skip("No available port found") + } + assert.LessOrEqual(t, returnedPort, MaxPort, "Should return port within valid range") + }, + description: "Port above max should trigger finding an alternative", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + inputPort := tt.setupPort(t) + returnedPort, err := FindOrUsePort(inputPort) + + if tt.expectError { + assert.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + } else if tt.validate != nil { + tt.validate(t, inputPort, returnedPort, err) + } + }) + } +} + +func TestIsAvailable_Concurrent(t *testing.T) { + t.Parallel() + + // Test that IsAvailable is safe to call concurrently + t.Run("concurrent calls are safe", func(t *testing.T) { + t.Parallel() + + const numGoroutines = 10 + done := make(chan bool, numGoroutines) + + // Find a port to test + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + testPort := listener.Addr().(*net.TCPAddr).Port + require.NoError(t, listener.Close()) + + // Launch concurrent calls + for i := 0; i < numGoroutines; i++ { + go func() { + defer func() { + if r := recover(); r != nil { + t.Errorf("IsAvailable panicked: %v", r) + } + done <- true + }() + + // Call IsAvailable multiple times + for j := 0; j < 5; j++ { + IsAvailable(testPort) + } + }() + } + + // Wait for all goroutines to complete + for i := 0; i < numGoroutines; i++ { + <-done + } + }) +} + +func TestFindAvailable_Concurrent(t *testing.T) { + t.Parallel() + + // Test that FindAvailable is safe to call concurrently + t.Run("concurrent calls are safe", func(t *testing.T) { + t.Parallel() + + const numGoroutines = 10 + done := make(chan int, numGoroutines) + + // Launch concurrent calls + for i := 0; i < numGoroutines; i++ { + go func() { + defer func() { + if r := recover(); r != nil { + t.Errorf("FindAvailable panicked: %v", r) + done <- 0 + return + } + }() + + port := FindAvailable() + done <- port + }() + } + + // Collect results + ports := make([]int, 0, numGoroutines) + for i := 0; i < numGoroutines; i++ { + port := <-done + ports = append(ports, port) + } + + // Verify all calls completed and returned valid results + for i, port := range ports { + if port == 0 { + t.Logf("Goroutine %d: no available port found", i) + continue + } + assert.GreaterOrEqual(t, port, MinPort, "Port from goroutine %d should be >= MinPort", i) + assert.LessOrEqual(t, port, MaxPort, "Port from goroutine %d should be <= MaxPort", i) + } + }) +} + +// Benchmark tests +func BenchmarkIsAvailable(b *testing.B) { + // Find an available port + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(b, err) + port := listener.Addr().(*net.TCPAddr).Port + require.NoError(b, listener.Close()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + IsAvailable(port) + } +} + +func BenchmarkFindAvailable(b *testing.B) { + for i := 0; i < b.N; i++ { + FindAvailable() + } +} + +func BenchmarkFindOrUsePort(b *testing.B) { + for i := 0; i < b.N; i++ { + FindOrUsePort(0) + } +} + +func BenchmarkIsIPv6Available(b *testing.B) { + for i := 0; i < b.N; i++ { + IsIPv6Available() + } +}