
Commit 891161c

mcp/streamable: add resumability for the Streamable transport
This CL implements a retry mechanism that resumes SSE streams, allowing the client to recover from transient network failures.
1 parent de4b788 commit 891161c

2 files changed: 289 additions & 14 deletions

mcp/streamable.go

Lines changed: 143 additions & 14 deletions
@@ -9,11 +9,14 @@ import (
 	"context"
 	"fmt"
 	"io"
+	"math"
+	"math/rand/v2"
 	"net/http"
 	"strconv"
 	"strings"
 	"sync"
 	"sync/atomic"
+	"time"

 	"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
 	"github.com/modelcontextprotocol/go-sdk/jsonrpc"
@@ -594,12 +597,38 @@ type StreamableClientTransport struct {
 	opts StreamableClientTransportOptions
 }

+// StreamableClientReconnectionOptions defines parameters for client reconnection attempts.
+type StreamableClientReconnectionOptions struct {
+	// InitialDelay is the base delay for the first reconnection attempt.
+	InitialDelay time.Duration // default: 1 second
+
+	// MaxDelay caps the backoff delay, preventing it from growing indefinitely.
+	MaxDelay time.Duration // default: 30 seconds
+
+	// GrowFactor is the multiplicative factor by which the delay increases after each attempt.
+	// A value of 1.0 results in a constant delay, while a value of 2.0 doubles it each time.
+	GrowFactor float64 // default: 1.5
+
+	// MaxRetries is the maximum number of times to attempt reconnection before giving up.
+	// A value of 0 or less means never retry.
+	MaxRetries int // default: 5
+}
+
+// DefaultReconnectionOptions provides sensible defaults for reconnection logic.
+var DefaultReconnectionOptions = &StreamableClientReconnectionOptions{
+	InitialDelay: 1 * time.Second,
+	MaxDelay:     30 * time.Second,
+	GrowFactor:   1.5,
+	MaxRetries:   5,
+}
+
 // StreamableClientTransportOptions provides options for the
 // [NewStreamableClientTransport] constructor.
 type StreamableClientTransportOptions struct {
 	// HTTPClient is the client to use for making HTTP requests. If nil,
 	// http.DefaultClient is used.
-	HTTPClient *http.Client
+	HTTPClient          *http.Client
+	ReconnectionOptions *StreamableClientReconnectionOptions
 }

 // NewStreamableClientTransport returns a new client transport that connects to
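
For illustration, here is a minimal sketch of overriding these defaults from client code. It assumes the second parameter of NewStreamableClientTransport is the *StreamableClientTransportOptions shown above (the test below passes nil for it), and the endpoint URL and import path "github.com/modelcontextprotocol/go-sdk/mcp" are placeholders:

	package main

	import (
		"context"
		"log"
		"time"

		"github.com/modelcontextprotocol/go-sdk/mcp"
	)

	func main() {
		// Placeholder endpoint; point this at a real streamable MCP server.
		transport := mcp.NewStreamableClientTransport("http://localhost:8080/mcp",
			&mcp.StreamableClientTransportOptions{
				ReconnectionOptions: &mcp.StreamableClientReconnectionOptions{
					InitialDelay: 500 * time.Millisecond, // retry sooner than the 1s default
					MaxDelay:     10 * time.Second,       // cap backoff below the 30s default
					GrowFactor:   2.0,                    // double the delay after each attempt
					MaxRetries:   8,
				},
			})
		conn, err := transport.Connect(context.Background())
		if err != nil {
			log.Fatal(err)
		}
		defer conn.Close()
	}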
@@ -625,19 +654,26 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
 	if client == nil {
 		client = http.DefaultClient
 	}
-	return &streamableClientConn{
-		url:      t.url,
-		client:   client,
-		incoming: make(chan []byte, 100),
-		done:     make(chan struct{}),
-	}, nil
+	reconnOpts := t.opts.ReconnectionOptions
+	if reconnOpts == nil {
+		reconnOpts = DefaultReconnectionOptions
+	}
+	conn := &streamableClientConn{
+		url:                 t.url,
+		client:              client,
+		incoming:            make(chan []byte, 100),
+		done:                make(chan struct{}),
+		reconnectionOptions: reconnOpts,
+	}
+	return conn, nil
 }

 type streamableClientConn struct {
-	url      string
-	client   *http.Client
-	incoming chan []byte
-	done     chan struct{}
+	url                 string
+	client              *http.Client
+	incoming            chan []byte
+	done                chan struct{}
+	reconnectionOptions *StreamableClientReconnectionOptions

 	closeOnce sync.Once
 	closeErr  error
@@ -704,11 +740,19 @@ func (s *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
 	if sessionID == "" {
 		// locked
 		s._sessionID = gotSessionID
+		// With the session now established, launch the persistent background listener for server-pushed events.
+		go s.establishSSE(context.Background(), &startSSEOptions{})
 	}

 	return nil
 }

+// startSSEOptions holds parameters for initiating an SSE stream.
+type startSSEOptions struct {
+	lastEventID string
+	attempt     int
+}
+
 func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string, msg jsonrpc.Message) (string, error) {
 	data, err := jsonrpc2.EncodeMessage(msg)
 	if err != nil {
@@ -742,7 +786,7 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
 	sessionID = resp.Header.Get(sessionIDHeader)
 	switch ct := resp.Header.Get("Content-Type"); ct {
 	case "text/event-stream":
-		go s.handleSSE(resp)
+		go s.handleSSE(context.Background(), resp, &startSSEOptions{})
 	case "application/json":
 		// TODO: read the body and send to s.incoming (in a select that also receives from s.done).
 		resp.Body.Close()
@@ -754,17 +798,20 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
 	return sessionID, nil
 }

-func (s *streamableClientConn) handleSSE(resp *http.Response) {
+// handleSSE processes an incoming Server-Sent Events stream, pushing received messages to the client's channel.
+// If the stream breaks, it uses the last received event ID to automatically trigger the reconnection logic.
+func (s *streamableClientConn) handleSSE(ctx context.Context, resp *http.Response, opts *startSSEOptions) {
 	defer resp.Body.Close()

 	done := make(chan struct{})
 	go func() {
 		defer close(done)
 		for evt, err := range scanEvents(resp.Body) {
 			if err != nil {
-				// TODO: surface this error; possibly break the stream
+				s.scheduleReconnection(ctx, opts)
 				return
 			}
+			opts.lastEventID = evt.id
 			select {
 			case <-s.done:
 				return
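
To make the resumption flow concrete, consider a hypothetical exchange (illustrative payloads, not actual SDK output). handleSSE records each event's id in opts.lastEventID; when the stream breaks, the reconnection path below (establishSSE) reissues the GET with that id so the server can replay what was missed:

	id: 1
	data: {"jsonrpc":"2.0","method":"notifications/progress", ...}

	id: 2
	data: {"jsonrpc":"2.0","method":"notifications/progress", ...}

	(connection drops; the client reconnects after a backoff delay)

	GET /mcp HTTP/1.1
	Accept: text/event-stream
	Mcp-Session-Id: <session>
	Last-Event-ID: 2

	(the server replays events with ids 3, 4, ... on the new stream)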
@@ -800,3 +847,85 @@ func (s *streamableClientConn) Close() error {
 	})
 	return s.closeErr
 }
+
+// establishSSE creates and manages the persistent SSE listening stream.
+// It is used for both the initial connection and all subsequent reconnection attempts,
+// using the Last-Event-ID header to resume a broken stream where it left off.
+// On a successful response, it delegates to handleSSE to process events;
+// on failure, it triggers the client's reconnection logic.
+func (s *streamableClientConn) establishSSE(ctx context.Context, opts *startSSEOptions) {
+	select {
+	case <-s.done:
+		return
+	case <-ctx.Done():
+		return
+	default:
+	}
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, s.url, nil)
+	if err != nil {
+		return
+	}
+	s.mu.Lock()
+	if s._sessionID != "" {
+		req.Header.Set("Mcp-Session-Id", s._sessionID)
+	}
+	s.mu.Unlock()
+	if opts.lastEventID != "" {
+		req.Header.Set("Last-Event-ID", opts.lastEventID)
+	}
+	req.Header.Set("Accept", "text/event-stream")
+
+	resp, err := s.client.Do(req)
+	if err != nil {
+		// On connection error, schedule a retry.
+		s.scheduleReconnection(ctx, opts)
+		return
+	}
+
+	// Per the spec, a 405 response means the server doesn't support SSE streams at this endpoint.
+	if resp.StatusCode == http.StatusMethodNotAllowed {
+		resp.Body.Close()
+		return
+	}
+
+	if !strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") {
+		resp.Body.Close()
+		s.scheduleReconnection(ctx, opts)
+		return
+	}
+
+	s.handleSSE(ctx, resp, opts)
+}
+
+// scheduleReconnection schedules the next SSE reconnection attempt after a delay.
+func (s *streamableClientConn) scheduleReconnection(ctx context.Context, opts *startSSEOptions) {
+	reconnOpts := s.reconnectionOptions
+	if opts.attempt >= reconnOpts.MaxRetries {
+		return
+	}
+
+	delay := calculateReconnectionDelay(reconnOpts, opts.attempt)
+
+	select {
+	case <-s.done:
+		return
+	case <-time.After(delay):
+		opts.attempt++
+		s.establishSSE(ctx, opts)
+	}
+}
+
+// calculateReconnectionDelay calculates the delay for a reconnection attempt
+// using exponential backoff with random jitter.
+func calculateReconnectionDelay(opts *StreamableClientReconnectionOptions, attempt int) time.Duration {
+	// Compute the exponential backoff using the grow factor.
+	backoffDuration := time.Duration(float64(opts.InitialDelay) * math.Pow(opts.GrowFactor, float64(attempt)))
+
+	// Cap the backoff at MaxDelay.
+	backoffDuration = min(backoffDuration, opts.MaxDelay)
+
+	// Add random jitter in [0, backoffDuration).
+	jitter := rand.N(backoffDuration)
+
+	return backoffDuration + jitter
+}
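
As a sanity check on the backoff math, this small standalone sketch mirrors calculateReconnectionDelay with the jitter term omitted so the output is deterministic. Under the defaults, the base delays are 1s, 1.5s, 2.25s, 3.375s, 5.0625s; the actual delay then lands in [base, 2*base) once rand.N(base) is added:

	package main

	import (
		"fmt"
		"math"
		"time"
	)

	func main() {
		const (
			initialDelay = 1 * time.Second
			maxDelay     = 30 * time.Second
			growFactor   = 1.5
		)
		for attempt := 0; attempt < 5; attempt++ {
			// Same formula as calculateReconnectionDelay, minus the jitter term.
			base := time.Duration(float64(initialDelay) * math.Pow(growFactor, float64(attempt)))
			base = min(base, maxDelay) // cap at maxDelay (Go 1.21+ built-in min)
			fmt.Printf("attempt %d: base delay %v\n", attempt, base)
		}
	}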

mcp/streamable_test.go

Lines changed: 146 additions & 0 deletions
@@ -10,14 +10,17 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
+	"net"
 	"net/http"
 	"net/http/cookiejar"
 	"net/http/httptest"
+	"net/http/httputil"
 	"net/url"
 	"strings"
 	"sync"
 	"sync/atomic"
 	"testing"
+	"time"

 	"github.com/google/go-cmp/cmp"
 	"github.com/google/go-cmp/cmp/cmpopts"
@@ -105,6 +108,149 @@ func TestStreamableTransports(t *testing.T) {
 	}
 }

+// TestClientReplayAfterProxyBreak verifies that the client can recover from a
+// mid-stream network failure and receive replayed messages. It uses a proxy
+// that is killed and restarted to simulate a recoverable network outage.
+func TestClientReplayAfterProxyBreak(t *testing.T) {
+	// 1. Configure the real MCP server.
+	server := NewServer(testImpl, nil)
+
+	// Use a channel to synchronize the server's message sending with the test's
+	// proxy-killing action.
+	serverReadyToKillProxy := make(chan struct{})
+	var serverClosed sync.WaitGroup
+	AddTool(server, &Tool{Name: "multiMessageTool"}, func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[any]) (*CallToolResultFor[any], error) {
+		go func() {
+			bgCtx := context.Background()
+			// Send the first two messages immediately.
+			_ = ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg1"})
+			_ = ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg2"})
+
+			// Signal the test that it can now kill the proxy.
+			serverClosed.Add(1)
+			close(serverReadyToKillProxy)
+			// Wait for the test to kill the proxy before sending the rest.
+			serverClosed.Wait()
+
+			// These messages should be queued for replay by the server after
+			// the client's connection drops.
+			_ = ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg3"})
+			_ = ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg4"})
+		}()
+		return &CallToolResultFor[any]{}, nil
+	})
+	realServer := httptest.NewServer(NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil))
+	defer realServer.Close()
+	realServerURL, err := url.Parse(realServer.URL)
+	if err != nil {
+		t.Fatalf("Failed to parse real server URL: %v", err)
+	}
+
+	// 2. Configure a proxy that sits between the client and the real server.
+	proxyHandler := httputil.NewSingleHostReverseProxy(realServerURL)
+	proxy := httptest.NewServer(proxyHandler)
+	proxyAddr := proxy.Listener.Addr().String() // Save the address so the proxy can be restarted later.
+
+	// 3. Configure the client to connect to the proxy with default options.
+	clientTransport := NewStreamableClientTransport(proxy.URL, nil)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+
+	// 4. Connect, perform the handshake, and trigger the tool.
+	conn, err := clientTransport.Connect(ctx)
+	if err != nil {
+		t.Fatalf("Connect() failed: %v", err)
+	}
+
+	// Perform the handshake.
+	initReq := &jsonrpc.Request{ID: jsonrpc2.Int64ID(100), Method: "initialize", Params: mustMarshal(t, &InitializeParams{})}
+	if err := conn.Write(ctx, initReq); err != nil {
+		t.Fatalf("Write(initialize) failed: %v", err)
+	}
+	if _, err := conn.Read(ctx); err != nil {
+		t.Fatalf("Read(initialize resp) failed: %v", err)
+	}
+	if err := conn.Write(ctx, &jsonrpc.Request{Method: "initialized", Params: mustMarshal(t, &InitializedParams{})}); err != nil {
+		t.Fatalf("Write(initialized) failed: %v", err)
+	}
+
+	callReq := &jsonrpc.Request{ID: jsonrpc2.Int64ID(1), Method: "tools/call", Params: mustMarshal(t, &CallToolParams{Name: "multiMessageTool"})}
+	if err := conn.Write(ctx, callReq); err != nil {
+		t.Fatalf("Write(tools/call) failed: %v", err)
+	}
+
+	// 5. Read and verify messages until the server signals it's ready for the proxy kill.
+	receivedNotifications := readProgressNotifications(t, ctx, conn, 2)
+
+	wantReceived := []jsonrpc.Message{
+		&jsonrpc.Request{Method: "notifications/progress", Params: mustMarshal(t, &ProgressNotificationParams{Message: "msg1"})},
+		&jsonrpc.Request{Method: "notifications/progress", Params: mustMarshal(t, &ProgressNotificationParams{Message: "msg2"})},
+	}
+	transform := cmpopts.AcyclicTransformer("jsonrpcid", func(id jsonrpc.ID) any { return id.Raw() })
+
+	if diff := cmp.Diff(wantReceived, receivedNotifications, transform); diff != "" {
+		t.Errorf("Initial notifications mismatch (-want +got):\n%s", diff)
+	}
+
+	select {
+	case <-serverReadyToKillProxy:
+		// Server has sent the first two messages and is paused.
+	case <-ctx.Done():
+		t.Fatalf("Context timed out before server was ready to kill proxy")
+	}
+
+	// 6. Simulate a total network failure by closing the proxy.
+	t.Log("--- Killing proxy to simulate network failure ---")
+	proxy.CloseClientConnections()
+	proxy.Close()
+	serverClosed.Done()
+
+	// 7. Simulate network recovery by restarting the proxy on the same address.
+	t.Logf("--- Restarting proxy on %s ---", proxyAddr)
+	listener, err := net.Listen("tcp", proxyAddr)
+	if err != nil {
+		t.Fatalf("Failed to listen on proxy address: %v", err)
+	}
+	restartedProxy := &http.Server{Handler: proxyHandler}
+	go restartedProxy.Serve(listener)
+	defer restartedProxy.Close()
+
+	// 8. Continue reading from the same connection object.
+	// Its internal logic should successfully retry, reconnect to the new proxy,
+	// and receive the replayed messages.
+	recoveredNotifications := readProgressNotifications(t, ctx, conn, 2)
+
+	// 9. Verify the correct messages were received on the recovered connection.
+	wantRecovered := []jsonrpc.Message{
+		&jsonrpc.Request{Method: "notifications/progress", Params: mustMarshal(t, &ProgressNotificationParams{Message: "msg3"})},
+		&jsonrpc.Request{Method: "notifications/progress", Params: mustMarshal(t, &ProgressNotificationParams{Message: "msg4"})},
+	}
+
+	if diff := cmp.Diff(wantRecovered, recoveredNotifications, transform); diff != "" {
+		t.Errorf("Recovered notifications mismatch (-want +got):\n%s", diff)
+	}
+}
+
+// readProgressNotifications reads exactly count progress notifications from conn.
+func readProgressNotifications(t *testing.T, ctx context.Context, conn Connection, count int) []jsonrpc.Message {
+	t.Helper()
+	var notifications []jsonrpc.Message
+	for len(notifications) < count && ctx.Err() == nil {
+		msg, err := conn.Read(ctx)
+		if err != nil {
+			t.Fatalf("Failed to read notification: %v", err)
+		}
+		if req, ok := msg.(*jsonrpc.Request); ok && req.Method == "notifications/progress" {
+			notifications = append(notifications, req)
+		}
+	}
+	if len(notifications) != count {
+		t.Fatalf("Expected to read %d notifications, but got %d", count, len(notifications))
+	}
+	return notifications
+}
+
 func TestStreamableServerTransport(t *testing.T) {
 	// This test checks detailed behavior of the streamable server transport, by
 	// faking the behavior of a streamable client using a sequence of HTTP
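
The replay that this test depends on is server-side: the server must buffer the events of a broken stream and resend everything after the client's Last-Event-ID. As a rough, hypothetical sketch of that bookkeeping (not the SDK's actual implementation), a replay buffer can be as simple as:

	package main

	import (
		"fmt"
		"strconv"
		"sync"
	)

	// event pairs an SSE id with its payload.
	type event struct {
		id   int
		data string
	}

	// replayBuffer records every event published on a stream so that a
	// resuming client can be sent the suffix it missed.
	type replayBuffer struct {
		mu     sync.Mutex
		events []event
	}

	// add stores data under the next sequential id and returns that id.
	func (b *replayBuffer) add(data string) int {
		b.mu.Lock()
		defer b.mu.Unlock()
		id := len(b.events) + 1
		b.events = append(b.events, event{id: id, data: data})
		return id
	}

	// since returns all events with an id greater than lastEventID: exactly
	// what a server replays when a client resumes with a Last-Event-ID header.
	func (b *replayBuffer) since(lastEventID string) []event {
		last, err := strconv.Atoi(lastEventID)
		if err != nil {
			last = 0
		}
		b.mu.Lock()
		defer b.mu.Unlock()
		if last >= len(b.events) {
			return nil
		}
		return b.events[last:]
	}

	func main() {
		var buf replayBuffer
		for _, msg := range []string{"msg1", "msg2", "msg3", "msg4"} {
			buf.add(msg)
		}
		// A client that saw events 1 and 2 before the proxy died resumes here.
		for _, evt := range buf.since("2") {
			fmt.Printf("id: %d\ndata: %s\n\n", evt.id, evt.data)
		}
	}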
