From 649f399c5e096bd7532c68c6e097027d31f545f4 Mon Sep 17 00:00:00 2001 From: Sam Thanawalla Date: Wed, 9 Jul 2025 16:19:49 +0000 Subject: [PATCH 1/3] mcp/streamable: add resumability for the Streamable transport This CL implements a retry mechanism to resume SSE streams to recover from network failures. --- mcp/streamable.go | 225 ++++++++++++++++++++++++++++++++++++----- mcp/streamable_test.go | 153 ++++++++++++++++++++++++++++ 2 files changed, 351 insertions(+), 27 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index ef740d3..2660eae 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -9,11 +9,14 @@ import ( "context" "fmt" "io" + "math" + "math/rand/v2" "net/http" "strconv" "strings" "sync" "sync/atomic" + "time" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" @@ -597,12 +600,39 @@ type StreamableClientTransport struct { opts StreamableClientTransportOptions } +// StreamableReconnectOptions defines parameters for client reconnect attempts. +type StreamableReconnectOptions struct { + // MaxRetries is the maximum number of times to attempt a reconnect before giving up. + // A value of 0 or less means never retry. + MaxRetries int + + // growFactor is the multiplicative factor by which the delay increases after each attempt. + // A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time. + // It must be 1.0 or greater if MaxRetries is greater than 0. + growFactor float64 + + // initialDelay is the base delay for the first reconnect attempt. + initialDelay time.Duration + + // maxDelay caps the backoff delay, preventing it from growing indefinitely. + maxDelay time.Duration +} + +// DefaultReconnectOptions provides sensible defaults for reconnect logic. +var DefaultReconnectOptions = &StreamableReconnectOptions{ + MaxRetries: 5, + growFactor: 1.5, + initialDelay: 1 * time.Second, + maxDelay: 30 * time.Second, +} + // StreamableClientTransportOptions provides options for the // [NewStreamableClientTransport] constructor. type StreamableClientTransportOptions struct { // HTTPClient is the client to use for making HTTP requests. If nil, // http.DefaultClient is used. - HTTPClient *http.Client + HTTPClient *http.Client + ReconnectOptions *StreamableReconnectOptions } // NewStreamableClientTransport returns a new client transport that connects to @@ -628,22 +658,37 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er if client == nil { client = http.DefaultClient } - return &streamableClientConn{ - url: t.url, - client: client, - incoming: make(chan []byte, 100), - done: make(chan struct{}), - }, nil + reconnOpts := t.opts.ReconnectOptions + if reconnOpts == nil { + reconnOpts = DefaultReconnectOptions + } + // Create a new cancellable context that will manage the connection's lifecycle. + // This is crucial for cleanly shutting down the background SSE listener by + // cancelling its blocking network operations, which prevents hangs on exit. 
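+	// Deriving the context from context.Background() rather than the ctx passed
+	// to Connect keeps the background listening stream alive beyond the Connect
+	// call itself; it is torn down only when Close cancels it.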
+ connCtx, cancel := context.WithCancel(context.Background()) + conn := &streamableClientConn{ + url: t.url, + client: client, + incoming: make(chan []byte, 100), + done: make(chan struct{}), + ReconnectOptions: reconnOpts, + ctx: connCtx, + cancel: cancel, + } + return conn, nil } type streamableClientConn struct { - url string - client *http.Client - incoming chan []byte - done chan struct{} + url string + client *http.Client + incoming chan []byte + done chan struct{} + ReconnectOptions *StreamableReconnectOptions closeOnce sync.Once closeErr error + ctx context.Context + cancel context.CancelFunc mu sync.Mutex protocolVersion string @@ -665,6 +710,12 @@ func (c *streamableClientConn) SessionID() string { // Read implements the [Connection] interface. func (s *streamableClientConn) Read(ctx context.Context) (jsonrpc.Message, error) { + s.mu.Lock() + err := s.err + s.mu.Unlock() + if err != nil { + return nil, err + } select { case <-ctx.Done(): return nil, ctx.Err() @@ -745,6 +796,7 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string sessionID = resp.Header.Get(sessionIDHeader) switch ct := resp.Header.Get("Content-Type"); ct { case "text/event-stream": + // Section 2.1: The SSE stream is initiated after a POST. go s.handleSSE(resp) case "application/json": // TODO: read the body and send to s.incoming (in a select that also receives from s.done). @@ -757,34 +809,118 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string return sessionID, nil } -func (s *streamableClientConn) handleSSE(resp *http.Response) { +// handleSSE manages the entire lifecycle of an SSE connection. It processes +// an incoming Server-Sent Events stream and automatically handles reconnection +// logic if the stream breaks. +func (s *streamableClientConn) handleSSE(initialResp *http.Response) { + resp := initialResp + var lastEventID string + + for { + eventID, clientClosed := s.processStream(resp) + lastEventID = eventID + + // If the connection was closed by the client, we're done. + if clientClosed { + return + } + + // The stream was interrupted or ended by the server. Attempt to reconnect. + newResp, reconnectErr := s.reconnect(lastEventID) + if reconnectErr != nil { + // All reconnection attempts failed. Set the final error, close the + // connection, and exit the goroutine. + s.mu.Lock() + s.err = reconnectErr + s.mu.Unlock() + s.Close() + return + } + + // Reconnection was successful. Continue the loop with the new response. + resp = newResp + } +} + +// processStream reads from a single response body, sending events to the +// incoming channel. It returns the ID of the last processed event, any error +// that occurred, and a flag indicating if the connection was closed by the client. +func (s *streamableClientConn) processStream(resp *http.Response) (lastEventID string, clientClosed bool) { defer resp.Body.Close() - done := make(chan struct{}) - go func() { - defer close(done) - for evt, err := range scanEvents(resp.Body) { - if err != nil { - // TODO: surface this error; possibly break the stream - return + for evt, err := range scanEvents(resp.Body) { + if err != nil { + return lastEventID, false + } + + if evt.ID != "" { + lastEventID = evt.ID + } + + select { + case s.incoming <- evt.Data: + case <-s.done: + // The connection was closed by the client; exit gracefully. + return lastEventID, true + } + } + + // The loop finished without an error, indicating the server closed the stream. 
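+	// (Read errors from scanEvents take the same path via the early return above.)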
+ // We'll attempt to reconnect, so this is not a client-side close. + return lastEventID, false +} + +// reconnect handles the logic of retrying a connection with an exponential +// backoff strategy. It returns a new, valid HTTP response if successful, or +// an error if all retries are exhausted. +func (s *streamableClientConn) reconnect(lastEventID string) (*http.Response, error) { + var finalErr error + + for attempt := 0; attempt < s.ReconnectOptions.MaxRetries; attempt++ { + select { + case <-s.done: + return nil, fmt.Errorf("connection closed by client during reconnect") + case <-time.After(calculateReconnectDelay(s.ReconnectOptions, attempt)): + resp, reconnectErr := s.establishSSE(lastEventID) + if reconnectErr != nil { + finalErr = reconnectErr // Store the error and try again. + continue } - select { - case <-s.done: - return - case s.incoming <- evt.Data: + + if !isResumable(resp) { + // The server indicated we should not continue. + resp.Body.Close() + return nil, fmt.Errorf("reconnection failed with unresumable status: %s", resp.Status) } + + return resp, nil } - }() + } + // If the loop completes, all retries have failed. + if finalErr != nil { + return nil, fmt.Errorf("connection failed after %d attempts: %w", s.ReconnectOptions.MaxRetries, finalErr) + } + return nil, fmt.Errorf("connection failed after %d attempts", s.ReconnectOptions.MaxRetries) +} - select { - case <-s.done: - case <-done: +// isResumable checks if an HTTP response indicates a valid SSE stream that can be processed. +func isResumable(resp *http.Response) bool { + // Per the spec, a 405 response means the server doesn't support SSE streams at this endpoint. + if resp.StatusCode == http.StatusMethodNotAllowed { + return false } + + if !strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") { + return false + } + return true } // Close implements the [Connection] interface. func (s *streamableClientConn) Close() error { s.closeOnce.Do(func() { + // Cancel any hanging network requests. + s.cancel() close(s.done) req, err := http.NewRequest(http.MethodDelete, s.url, nil) @@ -803,3 +939,38 @@ func (s *streamableClientConn) Close() error { }) return s.closeErr } + +// establishSSE establishes the persistent SSE listening stream. +// It is used for reconnect attempts using the Last-Event-ID header to +// resume a broken stream where it left off. +func (s *streamableClientConn) establishSSE(lastEventID string) (*http.Response, error) { + req, err := http.NewRequestWithContext(s.ctx, http.MethodGet, s.url, nil) + if err != nil { + return nil, err + } + s.mu.Lock() + if s._sessionID != "" { + req.Header.Set("Mcp-Session-Id", s._sessionID) + } + s.mu.Unlock() + if lastEventID != "" { + req.Header.Set("Last-Event-ID", lastEventID) + } + req.Header.Set("Accept", "text/event-stream") + + resp, err := s.client.Do(req) + return resp, err +} + +// calculateReconnectDelay calculates a delay using exponential backoff with full jitter. +func calculateReconnectDelay(opts *StreamableReconnectOptions, attempt int) time.Duration { + // Calculate the exponential backoff using the grow factor. + backoffDuration := time.Duration(float64(opts.initialDelay) * math.Pow(opts.growFactor, float64(attempt))) + // Cap the backoffDuration at maxDelay. 
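+	// This also keeps any single wait under 2*maxDelay once jitter is added below.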
+ backoffDuration = min(backoffDuration, opts.maxDelay) + + // Use a full jitter using backoffDuration + jitter := rand.N(backoffDuration) + + return backoffDuration + jitter +} diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 05bf59e..4829c6e 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -10,14 +10,17 @@ import ( "encoding/json" "fmt" "io" + "net" "net/http" "net/http/cookiejar" "net/http/httptest" + "net/http/httputil" "net/url" "strings" "sync" "sync/atomic" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" @@ -105,6 +108,156 @@ func TestStreamableTransports(t *testing.T) { } } +// TestClientReplayAfterProxyBreak verifies that the client can recover from a +// mid-stream network failure and receive replayed messages. It uses a proxy +// that is killed and restarted to simulate a recoverable network outage. +func TestClientReplayAfterProxyBreak(t *testing.T) { + // 1. Configure the real MCP server. + server := NewServer(testImpl, nil) + + // Use a channel to synchronize the server's message sending with the test's + // proxy-killing action. + serverReadyToKillProxy := make(chan struct{}) + var serverClosed sync.WaitGroup + AddTool(server, &Tool{Name: "multiMessageTool"}, func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[any]) (*CallToolResultFor[any], error) { + go func() { + bgCtx := context.Background() + // Send the first two messages immediately. + _ = ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg1"}) + _ = ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg2"}) + + // Signal the test that it can now kill the proxy. + serverClosed.Add(1) + close(serverReadyToKillProxy) + // Wait for the test to kill the proxy before sending the rest. + serverClosed.Wait() + + // These messages should be queued for replay by the server after + // the client's connection drops. + _ = ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg3"}) + _ = ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg4"}) + }() + return &CallToolResultFor[any]{}, nil + }) + realServer := httptest.NewServer(NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil)) + defer realServer.Close() + realServerURL, err := url.Parse(realServer.URL) + if err != nil { + t.Fatalf("Failed to parse real server URL: %v", err) + } + + // 2. Configure a proxy that sits between the client and the real server. + proxyHandler := httputil.NewSingleHostReverseProxy(realServerURL) + proxy := httptest.NewServer(proxyHandler) + proxyAddr := proxy.Listener.Addr().String() // Get the address to restart it later. + + // 3. Configure the client to connect to the proxy with default options. + clientTransport := NewStreamableClientTransport(proxy.URL, &StreamableClientTransportOptions{ + ReconnectOptions: &StreamableReconnectOptions{ + maxDelay: 50 * time.Millisecond, + MaxRetries: 5, + growFactor: 1.0, + initialDelay: 10 * time.Millisecond, + }, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // 4. Connect, perform handshake, and trigger the tool. + conn, err := clientTransport.Connect(ctx) + if err != nil { + t.Fatalf("Connect() failed: %v", err) + } + + // Perform handshake. 
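+	// The raw Connection is driven directly (initialize, initialized, then the
+	// tool call) so the test can observe individual JSON-RPC messages on Read.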
+ initReq := &jsonrpc.Request{ID: jsonrpc2.Int64ID(100), Method: "initialize", Params: mustMarshal(t, &InitializeParams{})} + if err := conn.Write(ctx, initReq); err != nil { + t.Fatalf("Write(initialize) failed: %v", err) + } + if _, err := conn.Read(ctx); err != nil { + t.Fatalf("Read(initialize resp) failed: %v", err) + } + if err := conn.Write(ctx, &jsonrpc.Request{Method: "initialized", Params: mustMarshal(t, &InitializedParams{})}); err != nil { + t.Fatalf("Write(initialized) failed: %v", err) + } + + callReq := &jsonrpc.Request{ID: jsonrpc2.Int64ID(1), Method: "tools/call", Params: mustMarshal(t, &CallToolParams{Name: "multiMessageTool"})} + if err := conn.Write(ctx, callReq); err != nil { + t.Fatalf("Write(tool/call) failed: %v", err) + } + + // 5. Read and verify messages until the server signals it's ready for the proxy kill. + receivedNotifications := readProgressNotifications(t, ctx, conn, 2) + + wantReceived := []jsonrpc.Message{ + &jsonrpc.Request{Method: "notifications/progress", Params: mustMarshal(t, &ProgressNotificationParams{Message: "msg1"})}, + &jsonrpc.Request{Method: "notifications/progress", Params: mustMarshal(t, &ProgressNotificationParams{Message: "msg2"})}, + } + transform := cmpopts.AcyclicTransformer("jsonrpcid", func(id jsonrpc.ID) any { return id.Raw() }) + + if diff := cmp.Diff(wantReceived, receivedNotifications, transform); diff != "" { + t.Errorf("Recovered notifications mismatch (-want +got):\n%s", diff) + } + + select { + case <-serverReadyToKillProxy: + // Server has sent the first two messages and is paused. + case <-ctx.Done(): + t.Fatalf("Context timed out before server was ready to kill proxy") + } + + // 6. Simulate a total network failure by closing the proxy. + t.Log("--- Killing proxy to simulate network failure ---") + proxy.CloseClientConnections() + proxy.Close() + serverClosed.Done() + + // 7. Simulate network recovery by restarting the proxy on the same address. + t.Logf("--- Restarting proxy on %s ---", proxyAddr) + listener, err := net.Listen("tcp", proxyAddr) + if err != nil { + t.Fatalf("Failed to listen on proxy address: %v", err) + } + restartedProxy := &http.Server{Handler: proxyHandler} + go restartedProxy.Serve(listener) + defer restartedProxy.Close() + + // 8. Continue reading from the same connection object. + // Its internal logic should successfully retry, reconnect to the new proxy, + // and receive the replayed messages. + recoveredNotifications := readProgressNotifications(t, ctx, conn, 2) + + // 9. Verify the correct messages were received on the recovered connection. + wantRecovered := []jsonrpc.Message{ + &jsonrpc.Request{Method: "notifications/progress", Params: mustMarshal(t, &ProgressNotificationParams{Message: "msg3"})}, + &jsonrpc.Request{Method: "notifications/progress", Params: mustMarshal(t, &ProgressNotificationParams{Message: "msg4"})}, + } + + if diff := cmp.Diff(wantRecovered, recoveredNotifications, transform); diff != "" { + t.Errorf("Recovered notifications mismatch (-want +got):\n%s", diff) + } +} + +// Helper to read a specific number of progress notifications. 
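+// Messages other than notifications/progress (such as the tools/call response) are skipped.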
+func readProgressNotifications(t *testing.T, ctx context.Context, conn Connection, count int) []jsonrpc.Message { + t.Helper() + var notifications []jsonrpc.Message + for len(notifications) < count && ctx.Err() == nil { + msg, err := conn.Read(ctx) + if err != nil { + t.Fatalf("Failed to read notification: %v", err) + } + if req, ok := msg.(*jsonrpc.Request); ok && req.Method == "notifications/progress" { + notifications = append(notifications, req) + } + } + if len(notifications) != count { + t.Fatalf("Expected to read %d notifications, but got %d", count, len(notifications)) + } + return notifications +} + func TestStreamableServerTransport(t *testing.T) { // This test checks detailed behavior of the streamable server transport, by // faking the behavior of a streamable client using a sequence of HTTP From 38dcd47361caf84c6b8c6cedec2bb4630ed12567 Mon Sep 17 00:00:00 2001 From: Sam Thanawalla Date: Tue, 22 Jul 2025 18:26:28 +0000 Subject: [PATCH 2/3] mcp/streamable: update test to use client and minor tweaks --- mcp/streamable.go | 20 +++--- mcp/streamable_test.go | 146 +++++++++++++++++------------------------ 2 files changed, 67 insertions(+), 99 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index 2660eae..cd2d1cb 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -826,12 +826,12 @@ func (s *streamableClientConn) handleSSE(initialResp *http.Response) { } // The stream was interrupted or ended by the server. Attempt to reconnect. - newResp, reconnectErr := s.reconnect(lastEventID) - if reconnectErr != nil { + newResp, err := s.reconnect(lastEventID) + if err != nil { // All reconnection attempts failed. Set the final error, close the // connection, and exit the goroutine. s.mu.Lock() - s.err = reconnectErr + s.err = err s.mu.Unlock() s.Close() return @@ -881,9 +881,9 @@ func (s *streamableClientConn) reconnect(lastEventID string) (*http.Response, er case <-s.done: return nil, fmt.Errorf("connection closed by client during reconnect") case <-time.After(calculateReconnectDelay(s.ReconnectOptions, attempt)): - resp, reconnectErr := s.establishSSE(lastEventID) - if reconnectErr != nil { - finalErr = reconnectErr // Store the error and try again. + resp, err := s.establishSSE(lastEventID) + if err != nil { + finalErr = err // Store the error and try again. continue } @@ -910,10 +910,7 @@ func isResumable(resp *http.Response) bool { return false } - if !strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") { - return false - } - return true + return strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") } // Close implements the [Connection] interface. @@ -958,8 +955,7 @@ func (s *streamableClientConn) establishSSE(lastEventID string) (*http.Response, } req.Header.Set("Accept", "text/event-stream") - resp, err := s.client.Do(req) - return resp, err + return s.client.Do(req) } // calculateReconnectDelay calculates a delay using exponential backoff with full jitter. 
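 // With the default options (initialDelay 1s, growFactor 1.5, maxDelay 30s) the
 // resulting waits fall roughly in [1s, 2s), [1.5s, 3s), [2.25s, 4.5s), and so
 // on, because a jitter of up to the capped backoff is added on top.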
diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 4829c6e..873473a 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -26,6 +26,7 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" + "github.com/modelcontextprotocol/go-sdk/jsonschema" ) func TestStreamableTransports(t *testing.T) { @@ -108,37 +109,37 @@ func TestStreamableTransports(t *testing.T) { } } -// TestClientReplayAfterProxyBreak verifies that the client can recover from a +// TestClientReplay verifies that the client can recover from a // mid-stream network failure and receive replayed messages. It uses a proxy // that is killed and restarted to simulate a recoverable network outage. -func TestClientReplayAfterProxyBreak(t *testing.T) { +func TestClientReplay(t *testing.T) { + notifications := make(chan string, 10) // 1. Configure the real MCP server. server := NewServer(testImpl, nil) // Use a channel to synchronize the server's message sending with the test's // proxy-killing action. serverReadyToKillProxy := make(chan struct{}) - var serverClosed sync.WaitGroup - AddTool(server, &Tool{Name: "multiMessageTool"}, func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[any]) (*CallToolResultFor[any], error) { - go func() { - bgCtx := context.Background() - // Send the first two messages immediately. - _ = ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg1"}) - _ = ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg2"}) - - // Signal the test that it can now kill the proxy. - serverClosed.Add(1) - close(serverReadyToKillProxy) - // Wait for the test to kill the proxy before sending the rest. - serverClosed.Wait() - - // These messages should be queued for replay by the server after - // the client's connection drops. - _ = ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg3"}) - _ = ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg4"}) - }() - return &CallToolResultFor[any]{}, nil - }) + serverClosed := make(chan struct{}) + server.AddTool(&Tool{Name: "multiMessageTool", InputSchema: &jsonschema.Schema{}}, + func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[map[string]any]) (*CallToolResult, error) { + go func() { + bgCtx := context.Background() + // Send the first two messages immediately. + ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg1"}) + ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg2"}) + + // Signal the test that it can now kill the proxy. + close(serverReadyToKillProxy) + <-serverClosed + + // These messages should be queued for replay by the server after + // the client's connection drops. + ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg3"}) + ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg4"}) + }() + return &CallToolResult{}, nil + }) realServer := httptest.NewServer(NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil)) defer realServer.Close() realServerURL, err := url.Parse(realServer.URL) @@ -152,52 +153,25 @@ func TestClientReplayAfterProxyBreak(t *testing.T) { proxyAddr := proxy.Listener.Addr().String() // Get the address to restart it later. // 3. Configure the client to connect to the proxy with default options. 
- clientTransport := NewStreamableClientTransport(proxy.URL, &StreamableClientTransportOptions{ - ReconnectOptions: &StreamableReconnectOptions{ - maxDelay: 50 * time.Millisecond, - MaxRetries: 5, - growFactor: 1.0, - initialDelay: 10 * time.Millisecond, - }, - }) - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - - // 4. Connect, perform handshake, and trigger the tool. - conn, err := clientTransport.Connect(ctx) + client := NewClient(testImpl, &ClientOptions{ + ProgressNotificationHandler: func(ctx context.Context, cc *ClientSession, params *ProgressNotificationParams) { + notifications <- params.Message + }}) + clientSession, err := client.Connect(ctx, NewStreamableClientTransport(proxy.URL, nil)) if err != nil { - t.Fatalf("Connect() failed: %v", err) - } - - // Perform handshake. - initReq := &jsonrpc.Request{ID: jsonrpc2.Int64ID(100), Method: "initialize", Params: mustMarshal(t, &InitializeParams{})} - if err := conn.Write(ctx, initReq); err != nil { - t.Fatalf("Write(initialize) failed: %v", err) - } - if _, err := conn.Read(ctx); err != nil { - t.Fatalf("Read(initialize resp) failed: %v", err) - } - if err := conn.Write(ctx, &jsonrpc.Request{Method: "initialized", Params: mustMarshal(t, &InitializedParams{})}); err != nil { - t.Fatalf("Write(initialized) failed: %v", err) - } - - callReq := &jsonrpc.Request{ID: jsonrpc2.Int64ID(1), Method: "tools/call", Params: mustMarshal(t, &CallToolParams{Name: "multiMessageTool"})} - if err := conn.Write(ctx, callReq); err != nil { - t.Fatalf("Write(tool/call) failed: %v", err) + t.Fatalf("client.Connect() failed: %v", err) } + defer clientSession.Close() + clientSession.CallTool(ctx, &CallToolParams{Name: "multiMessageTool"}) - // 5. Read and verify messages until the server signals it's ready for the proxy kill. - receivedNotifications := readProgressNotifications(t, ctx, conn, 2) - - wantReceived := []jsonrpc.Message{ - &jsonrpc.Request{Method: "notifications/progress", Params: mustMarshal(t, &ProgressNotificationParams{Message: "msg1"})}, - &jsonrpc.Request{Method: "notifications/progress", Params: mustMarshal(t, &ProgressNotificationParams{Message: "msg2"})}, - } - transform := cmpopts.AcyclicTransformer("jsonrpcid", func(id jsonrpc.ID) any { return id.Raw() }) + // 4. Read and verify messages until the server signals it's ready for the proxy kill. + receivedNotifications := readProgressNotifications(t, ctx, notifications, 2) - if diff := cmp.Diff(wantReceived, receivedNotifications, transform); diff != "" { - t.Errorf("Recovered notifications mismatch (-want +got):\n%s", diff) + wantReceived := []string{"msg1", "msg2"} + if diff := cmp.Diff(wantReceived, receivedNotifications); diff != "" { + t.Errorf("Received notifications mismatch (-want +got):\n%s", diff) } select { @@ -207,13 +181,13 @@ func TestClientReplayAfterProxyBreak(t *testing.T) { t.Fatalf("Context timed out before server was ready to kill proxy") } - // 6. Simulate a total network failure by closing the proxy. + // 5. Simulate a total network failure by closing the proxy. t.Log("--- Killing proxy to simulate network failure ---") proxy.CloseClientConnections() proxy.Close() - serverClosed.Done() + close(serverClosed) - // 7. Simulate network recovery by restarting the proxy on the same address. + // 6. Simulate network recovery by restarting the proxy on the same address. 
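+	// Restarting on the saved proxyAddr matters: the client keeps retrying the
+	// original URL, so recovery requires the listener to come back on the same
+	// host:port.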
t.Logf("--- Restarting proxy on %s ---", proxyAddr) listener, err := net.Listen("tcp", proxyAddr) if err != nil { @@ -223,39 +197,37 @@ func TestClientReplayAfterProxyBreak(t *testing.T) { go restartedProxy.Serve(listener) defer restartedProxy.Close() - // 8. Continue reading from the same connection object. + // 7. Continue reading from the same connection object. // Its internal logic should successfully retry, reconnect to the new proxy, // and receive the replayed messages. - recoveredNotifications := readProgressNotifications(t, ctx, conn, 2) + recoveredNotifications := readProgressNotifications(t, ctx, notifications, 2) - // 9. Verify the correct messages were received on the recovered connection. - wantRecovered := []jsonrpc.Message{ - &jsonrpc.Request{Method: "notifications/progress", Params: mustMarshal(t, &ProgressNotificationParams{Message: "msg3"})}, - &jsonrpc.Request{Method: "notifications/progress", Params: mustMarshal(t, &ProgressNotificationParams{Message: "msg4"})}, - } + // 8. Verify the correct messages were received on the recovered connection. + wantRecovered := []string{"msg3", "msg4"} - if diff := cmp.Diff(wantRecovered, recoveredNotifications, transform); diff != "" { + if diff := cmp.Diff(wantRecovered, recoveredNotifications); diff != "" { t.Errorf("Recovered notifications mismatch (-want +got):\n%s", diff) } } // Helper to read a specific number of progress notifications. -func readProgressNotifications(t *testing.T, ctx context.Context, conn Connection, count int) []jsonrpc.Message { +func readProgressNotifications(t *testing.T, ctx context.Context, notifications chan string, count int) []string { t.Helper() - var notifications []jsonrpc.Message - for len(notifications) < count && ctx.Err() == nil { - msg, err := conn.Read(ctx) - if err != nil { - t.Fatalf("Failed to read notification: %v", err) - } - if req, ok := msg.(*jsonrpc.Request); ok && req.Method == "notifications/progress" { - notifications = append(notifications, req) + var collectedNotifications []string + for { + select { + case n := <-notifications: + collectedNotifications = append(collectedNotifications, n) + if len(collectedNotifications) == count { + return collectedNotifications + } + case <-ctx.Done(): + if len(collectedNotifications) != count { + t.Fatalf("readProgressNotifications(): did not receive expected notifications, got %d, want %d", len(collectedNotifications), count) + } + return collectedNotifications } } - if len(notifications) != count { - t.Fatalf("Expected to read %d notifications, but got %d", count, len(notifications)) - } - return notifications } func TestStreamableServerTransport(t *testing.T) { From 7d6d014f5dc74e78bd36b291154dc3d339632926 Mon Sep 17 00:00:00 2001 From: Sam Thanawalla Date: Tue, 22 Jul 2025 19:00:51 +0000 Subject: [PATCH 3/3] mcp/streamable: update notification channel size --- mcp/streamable_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 873473a..02721d6 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -113,7 +113,7 @@ func TestStreamableTransports(t *testing.T) { // mid-stream network failure and receive replayed messages. It uses a proxy // that is killed and restarted to simulate a recoverable network outage. func TestClientReplay(t *testing.T) { - notifications := make(chan string, 10) + notifications := make(chan string) // 1. Configure the real MCP server. server := NewServer(testImpl, nil)
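
For reference, a minimal sketch of how a client might wire up the reconnect options added by this series. It is illustrative only and not part of the patch: the endpoint URL and the Implementation values are placeholders, and NewClient/Connect are used the same way the test above uses them.

package main

import (
	"context"
	"log"

	"github.com/modelcontextprotocol/go-sdk/mcp"
)

func main() {
	// Passing DefaultReconnectOptions explicitly; leaving ReconnectOptions nil
	// makes Connect fall back to the same defaults (5 retries, 1s initial
	// delay, 30s cap).
	transport := mcp.NewStreamableClientTransport("http://localhost:8080/mcp", // placeholder URL
		&mcp.StreamableClientTransportOptions{ReconnectOptions: mcp.DefaultReconnectOptions})

	client := mcp.NewClient(&mcp.Implementation{Name: "example-client", Version: "v0.0.1"}, nil)
	session, err := client.Connect(context.Background(), transport)
	if err != nil {
		log.Fatal(err)
	}
	defer session.Close()

	// If the SSE stream later drops, the transport retries with jittered
	// backoff and resumes from the last received event ID.
}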