Skip to content

Commit 9d77544

Browse files
committed
mcp/streamable: add resumability for the Streamable transport
This CL implements a retry mechanism to resume SSE streams to recover from network failures.
1 parent de4b788 commit 9d77544

File tree

2 files changed

+313
-14
lines changed

2 files changed

+313
-14
lines changed

mcp/streamable.go

Lines changed: 167 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,14 @@ import (
99
"context"
1010
"fmt"
1111
"io"
12+
"math"
13+
"math/rand/v2"
1214
"net/http"
1315
"strconv"
1416
"strings"
1517
"sync"
1618
"sync/atomic"
19+
"time"
1720

1821
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
1922
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
@@ -594,12 +597,39 @@ type StreamableClientTransport struct {
594597
opts StreamableClientTransportOptions
595598
}
596599

600+
// StreamableClientReconnectOptions configures the backoff schedule used when re-establishing a broken SSE stream.
601+
type StreamableClientReconnectOptions struct {
602+
// InitialDelay is the base delay: the backoff before jitter is InitialDelay * GrowFactor^attempt, with attempt starting at 0.
603+
InitialDelay time.Duration
604+
605+
// MaxDelay caps the backoff before jitter is added, so the jittered wait can approach twice MaxDelay.
606+
MaxDelay time.Duration
607+
608+
// GrowFactor is the multiplicative factor by which the delay increases after each attempt.
609+
// A value of 1.0 results in a constant backoff, while a value of 2.0 would double it each time.
610+
// It must be 1.0 or greater if MaxRetries is greater than 0; Connect returns an error otherwise.
611+
GrowFactor float64
612+
613+
// MaxRetries is the maximum number of times to attempt a reconnect before giving up.
614+
// A value of 0 or less means never retry: the first stream failure records an error and closes the connection.
615+
MaxRetries int
616+
}
617+
618+
// DefaultReconnectOptions is the configuration Connect falls back to when ReconnectOptions is nil; the pointer is stored as-is, so mutating it affects every connection that uses the defaults.
619+
var DefaultReconnectOptions = &StreamableClientReconnectOptions{
620+
InitialDelay: 1 * time.Second,
621+
MaxDelay: 30 * time.Second,
622+
GrowFactor: 1.5,
623+
MaxRetries: 5,
624+
}
625+
597626
// StreamableClientTransportOptions provides options for the
598627
// [NewStreamableClientTransport] constructor.
599628
type StreamableClientTransportOptions struct {
600629
// HTTPClient is the client to use for making HTTP requests. If nil,
601630
// http.DefaultClient is used.
602-
HTTPClient *http.Client
631+
HTTPClient *http.Client
632+
ReconnectOptions *StreamableClientReconnectOptions
603633
}
604634

605635
// NewStreamableClientTransport returns a new client transport that connects to
@@ -625,22 +655,42 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
625655
if client == nil {
626656
client = http.DefaultClient
627657
}
628-
return &streamableClientConn{
629-
url: t.url,
630-
client: client,
631-
incoming: make(chan []byte, 100),
632-
done: make(chan struct{}),
633-
}, nil
658+
reconnOpts := t.opts.ReconnectOptions
659+
if reconnOpts == nil {
660+
reconnOpts = DefaultReconnectOptions
661+
}
662+
// A factor less than 1.0 would cause the delay to shrink, which defeats
663+
// the purpose of backoff.
664+
if reconnOpts.MaxRetries > 0 && reconnOpts.GrowFactor < 1.0 {
665+
return nil, fmt.Errorf("invalid grow factor, cannot be less than 1.0: %v", reconnOpts.GrowFactor)
666+
}
667+
// Create a new cancellable context that will manage the connection's lifecycle.
668+
// This is crucial for cleanly shutting down the background SSE listener by
669+
// cancelling its blocking network operations, which prevents hangs on exit.
670+
connCtx, cancel := context.WithCancel(context.Background())
671+
conn := &streamableClientConn{
672+
url: t.url,
673+
client: client,
674+
incoming: make(chan []byte, 100),
675+
done: make(chan struct{}),
676+
ReconnectOptions: reconnOpts,
677+
ctx: connCtx,
678+
cancel: cancel,
679+
}
680+
return conn, nil
634681
}
635682

636683
type streamableClientConn struct {
637-
url string
638-
client *http.Client
639-
incoming chan []byte
640-
done chan struct{}
684+
url string
685+
client *http.Client
686+
incoming chan []byte
687+
done chan struct{}
688+
ReconnectOptions *StreamableClientReconnectOptions
641689

642690
closeOnce sync.Once
643691
closeErr error
692+
ctx context.Context
693+
cancel context.CancelFunc
644694

645695
mu sync.Mutex
646696
protocolVersion string
@@ -662,6 +712,12 @@ func (c *streamableClientConn) SessionID() string {
662712

663713
// Read implements the [Connection] interface.
664714
func (s *streamableClientConn) Read(ctx context.Context) (jsonrpc.Message, error) {
715+
s.mu.Lock()
716+
err := s.err
717+
s.mu.Unlock()
718+
if err != nil {
719+
return nil, err
720+
}
665721
select {
666722
case <-ctx.Done():
667723
return nil, ctx.Err()
@@ -704,11 +760,19 @@ func (s *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
704760
if sessionID == "" {
705761
// locked
706762
s._sessionID = gotSessionID
763+
// With the session now established, launch the persistent background listener for server-pushed events.
764+
go s.establishSSE(&startSSEState{})
707765
}
708766

709767
return nil
710768
}
711769

770+
// startSSEState carries the resume state of one SSE stream across reconnect attempts.
771+
type startSSEState struct {
772+
lastEventID string // ID of the most recent event received; sent as Last-Event-ID when reconnecting
773+
attempt int // reconnect attempts made so far; compared against MaxRetries
774+
}
775+
712776
func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string, msg jsonrpc.Message) (string, error) {
713777
data, err := jsonrpc2.EncodeMessage(msg)
714778
if err != nil {
@@ -742,7 +806,7 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
742806
sessionID = resp.Header.Get(sessionIDHeader)
743807
switch ct := resp.Header.Get("Content-Type"); ct {
744808
case "text/event-stream":
745-
go s.handleSSE(resp)
809+
go s.handleSSE(resp, &startSSEState{})
746810
case "application/json":
747811
// TODO: read the body and send to s.incoming (in a select that also receives from s.done).
748812
resp.Body.Close()
@@ -754,17 +818,20 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
754818
return sessionID, nil
755819
}
756820

757-
func (s *streamableClientConn) handleSSE(resp *http.Response) {
821+
// handleSSE processes an incoming Server-Sent Events stream, pushing received messages to the client's channel.
822+
// If the stream breaks, it uses the last received event ID to automatically trigger the reconnect logic.
823+
func (s *streamableClientConn) handleSSE(resp *http.Response, opts *startSSEState) {
758824
defer resp.Body.Close()
759825

760826
done := make(chan struct{})
761827
go func() {
762828
defer close(done)
763829
for evt, err := range scanEvents(resp.Body) {
764830
if err != nil {
765-
// TODO: surface this error; possibly break the stream
831+
s.scheduleReconnect(opts)
766832
return
767833
}
834+
opts.lastEventID = evt.id
768835
select {
769836
case <-s.done:
770837
return
@@ -782,6 +849,8 @@ func (s *streamableClientConn) handleSSE(resp *http.Response) {
782849
// Close implements the [Connection] interface.
783850
func (s *streamableClientConn) Close() error {
784851
s.closeOnce.Do(func() {
852+
// Cancel any hanging network requests.
853+
s.cancel()
785854
close(s.done)
786855

787856
req, err := http.NewRequest(http.MethodDelete, s.url, nil)
@@ -800,3 +869,87 @@ func (s *streamableClientConn) Close() error {
800869
})
801870
return s.closeErr
802871
}
872+
873+
// establishSSE creates and manages the persistent SSE listening stream.
874+
// It is used for both the initial connection and all subsequent reconnect attempts,
875+
// using the Last-Event-ID header to resume a broken stream where it left off.
876+
// On a successful response, it delegates to handleSSE to process events;
877+
// on failure, it triggers the client's reconnect logic.
878+
func (s *streamableClientConn) establishSSE(opts *startSSEState) {
879+
select {
880+
case <-s.done:
881+
return
882+
default:
883+
}
884+
885+
req, err := http.NewRequestWithContext(s.ctx, http.MethodGet, s.url, nil)
886+
if err != nil {
887+
return
888+
}
889+
s.mu.Lock()
890+
if s._sessionID != "" {
891+
req.Header.Set("Mcp-Session-Id", s._sessionID)
892+
}
893+
s.mu.Unlock()
894+
if opts.lastEventID != "" {
895+
req.Header.Set("Last-Event-ID", opts.lastEventID)
896+
}
897+
req.Header.Set("Accept", "text/event-stream")
898+
899+
resp, err := s.client.Do(req)
900+
if err != nil {
901+
// On connection error, schedule a retry.
902+
s.scheduleReconnect(opts)
903+
return
904+
}
905+
906+
// Per the spec, a 405 response means the server doesn't support SSE streams at this endpoint.
907+
if resp.StatusCode == http.StatusMethodNotAllowed {
908+
resp.Body.Close()
909+
return
910+
}
911+
912+
if !strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") {
913+
resp.Body.Close()
914+
s.scheduleReconnect(opts)
915+
return
916+
}
917+
918+
s.handleSSE(resp, opts)
919+
}
920+
921+
// scheduleReconnect schedules the next SSE reconnect attempt after a delay.
922+
func (s *streamableClientConn) scheduleReconnect(opts *startSSEState) {
923+
reconnOpts := s.ReconnectOptions
924+
if opts.attempt >= reconnOpts.MaxRetries {
925+
s.mu.Lock()
926+
s.err = fmt.Errorf("connection failed after %d attempts", reconnOpts.MaxRetries)
927+
s.mu.Unlock()
928+
s.Close() // Close the connection to unblock any readers.
929+
return
930+
}
931+
932+
delay := calculateReconnectDelay(reconnOpts, opts.attempt)
933+
934+
select {
935+
case <-s.done:
936+
return
937+
case <-time.After(delay):
938+
opts.attempt++
939+
s.establishSSE(opts)
940+
}
941+
}
942+
943+
// calculateReconnectDelay calculates a delay using exponential backoff with full jitter.
944+
func calculateReconnectDelay(opts *StreamableClientReconnectOptions, attempt int) time.Duration {
945+
// Calculate the exponential backoff using the grow factor.
946+
backoffDuration := time.Duration(float64(opts.InitialDelay) * math.Pow(opts.GrowFactor, float64(attempt)))
947+
948+
// Cap the backoffDuration at maxDelay.
949+
backoffDuration = min(backoffDuration, opts.MaxDelay)
950+
951+
// Use a full jitter using backoffDuration
952+
jitter := rand.N(backoffDuration)
953+
954+
return backoffDuration + jitter
955+
}

0 commit comments

Comments
 (0)