diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index e0e85a1..a76e68c 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -14,10 +14,12 @@ import ( "net/http/cookiejar" "net/http/httptest" "net/url" + "runtime" "strings" "sync" "sync/atomic" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" @@ -597,3 +599,54 @@ func TestEventID(t *testing.T) { }) } } + +func TestStreamableClientConnSSEGoroutineLeak(t *testing.T) { + // Initialize a streamableClientConn instance with channels + conn := &streamableClientConn{ + incoming: make(chan []byte, 1), + done: make(chan struct{}), + } + + // Construct mock SSE response data + var builder strings.Builder + for range 3 { + builder.WriteString("data: hello world\n\n") + } + resp := &http.Response{ + Body: io.NopCloser(strings.NewReader(builder.String())), + } + + // Start the handleSSE goroutine manually + var wg sync.WaitGroup + wg.Add(1) + go func() { + conn.handleSSE(resp) + wg.Done() + }() + + // Wait until incoming channel is filled + deadlineCtx, cancelFunc := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancelFunc() + for len(conn.incoming) < cap(conn.incoming) { // Ensure enough events are written + select { + case <-deadlineCtx.Done(): + t.Fatalf("timeout when waiting for streamableClientConn.incoming to be full: %v", deadlineCtx.Err()) + default: + // Continue checking until the channel is full + } + } + + // Now simulate calling Close() and blocking the goroutine + close(conn.done) + wg.Wait() + + // Check if "scanEvents" goroutine is still running + leakKey := "scanEvents" + buf := make([]byte, 1024*1024) + n := runtime.Stack(buf, true) + stack := string(buf[:n]) + + if idx := strings.Index(stack, leakKey); idx != -1 { + t.Fatalf("goroutine leak detected: %s still active", leakKey) + } +}