Skip to content

Commit abc4c9d

Browse files
committed
mcp/streamable: add resumability for the Streamable transport
This CL implements a retry mechanism to resume SSE streams to recover from network failures.
1 parent de4b788 commit abc4c9d

File tree

2 files changed

+297
-14
lines changed

2 files changed

+297
-14
lines changed

mcp/streamable.go

Lines changed: 151 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,14 @@ import (
99
"context"
1010
"fmt"
1111
"io"
12+
"math"
13+
"math/rand/v2"
1214
"net/http"
1315
"strconv"
1416
"strings"
1517
"sync"
1618
"sync/atomic"
19+
"time"
1720

1821
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
1922
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
@@ -594,12 +597,38 @@ type StreamableClientTransport struct {
594597
opts StreamableClientTransportOptions
595598
}
596599

600+
// StreamableClientReconnectionOptions configures how a streamable client
// retries a broken SSE stream: how long it waits between attempts and how
// many attempts it makes before giving up.
type StreamableClientReconnectionOptions struct {
	// InitialDelay is the base delay used for the first reconnection attempt.
	// DefaultReconnectionOptions sets this to 1 second.
	InitialDelay time.Duration

	// MaxDelay is the upper bound applied to the backoff delay so that it
	// cannot grow indefinitely. DefaultReconnectionOptions sets this to
	// 30 seconds.
	MaxDelay time.Duration

	// GrowFactor is the multiplier applied to the delay after each attempt.
	// With 1.0 the delay stays constant; with 2.0 it doubles every time.
	// DefaultReconnectionOptions sets this to 1.5.
	GrowFactor float64

	// MaxRetries is the number of reconnection attempts made before giving
	// up. Zero or a negative value disables retries entirely.
	// DefaultReconnectionOptions sets this to 5.
	MaxRetries int
}
616+
617+
// DefaultReconnectionOptions provides sensible defaults for reconnection logic.
618+
var DefaultReconnectionOptions = &StreamableClientReconnectionOptions{
619+
InitialDelay: 1 * time.Second,
620+
MaxDelay: 30 * time.Second,
621+
GrowFactor: 1.5,
622+
MaxRetries: 5,
623+
}
624+
597625
// StreamableClientTransportOptions provides options for the
598626
// [NewStreamableClientTransport] constructor.
599627
type StreamableClientTransportOptions struct {
600628
// HTTPClient is the client to use for making HTTP requests. If nil,
601629
// http.DefaultClient is used.
602-
HTTPClient *http.Client
630+
HTTPClient *http.Client
631+
ReconnectionOptions *StreamableClientReconnectionOptions
603632
}
604633

605634
// NewStreamableClientTransport returns a new client transport that connects to
@@ -625,22 +654,37 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
625654
if client == nil {
626655
client = http.DefaultClient
627656
}
628-
return &streamableClientConn{
629-
url: t.url,
630-
client: client,
631-
incoming: make(chan []byte, 100),
632-
done: make(chan struct{}),
633-
}, nil
657+
reconnOpts := t.opts.ReconnectionOptions
658+
if reconnOpts == nil {
659+
reconnOpts = DefaultReconnectionOptions
660+
}
661+
// Create a new cancellable context that will manage the connection's lifecycle.
662+
// This is crucial for cleanly shutting down the background SSE listener by
663+
// cancelling its blocking network operations, which prevents hangs on exit.
664+
connCtx, cancel := context.WithCancel(context.Background())
665+
conn := &streamableClientConn{
666+
url: t.url,
667+
client: client,
668+
incoming: make(chan []byte, 100),
669+
done: make(chan struct{}),
670+
reconnectionOptions: reconnOpts,
671+
ctx: connCtx,
672+
cancel: cancel,
673+
}
674+
return conn, nil
634675
}
635676

636677
type streamableClientConn struct {
637-
url string
638-
client *http.Client
639-
incoming chan []byte
640-
done chan struct{}
678+
url string
679+
client *http.Client
680+
incoming chan []byte
681+
done chan struct{}
682+
reconnectionOptions *StreamableClientReconnectionOptions
641683

642684
closeOnce sync.Once
643685
closeErr error
686+
ctx context.Context
687+
cancel context.CancelFunc
644688

645689
mu sync.Mutex
646690
protocolVersion string
@@ -704,11 +748,19 @@ func (s *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
704748
if sessionID == "" {
705749
// locked
706750
s._sessionID = gotSessionID
751+
// With the session now established, launch the persistent background listener for server-pushed events.
752+
go s.establishSSE(&startSSEOptions{})
707753
}
708754

709755
return nil
710756
}
711757

758+
// startSSEOptions carries the per-stream state used when opening or resuming
// an SSE stream.
type startSSEOptions struct {
	// lastEventID is the ID of the most recent event seen on the stream. When
	// non-empty it is sent as the Last-Event-ID request header on reconnect.
	lastEventID string

	// attempt is the number of reconnection attempts made so far; it is
	// compared against MaxRetries and used to size the backoff delay.
	attempt int
}
763+
712764
func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string, msg jsonrpc.Message) (string, error) {
713765
data, err := jsonrpc2.EncodeMessage(msg)
714766
if err != nil {
@@ -742,7 +794,7 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
742794
sessionID = resp.Header.Get(sessionIDHeader)
743795
switch ct := resp.Header.Get("Content-Type"); ct {
744796
case "text/event-stream":
745-
go s.handleSSE(resp)
797+
go s.handleSSE(resp, &startSSEOptions{})
746798
case "application/json":
747799
// TODO: read the body and send to s.incoming (in a select that also receives from s.done).
748800
resp.Body.Close()
@@ -754,17 +806,20 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
754806
return sessionID, nil
755807
}
756808

757-
func (s *streamableClientConn) handleSSE(resp *http.Response) {
809+
// handleSSE processes an incoming Server-Sent Events stream, pushing received messages to the client's channel.
810+
// If the stream breaks, it uses the last received event ID to automatically trigger the reconnection logic.
811+
func (s *streamableClientConn) handleSSE(resp *http.Response, opts *startSSEOptions) {
758812
defer resp.Body.Close()
759813

760814
done := make(chan struct{})
761815
go func() {
762816
defer close(done)
763817
for evt, err := range scanEvents(resp.Body) {
764818
if err != nil {
765-
// TODO: surface this error; possibly break the stream
819+
s.scheduleReconnection(opts)
766820
return
767821
}
822+
opts.lastEventID = evt.id
768823
select {
769824
case <-s.done:
770825
return
@@ -782,6 +837,8 @@ func (s *streamableClientConn) handleSSE(resp *http.Response) {
782837
// Close implements the [Connection] interface.
783838
func (s *streamableClientConn) Close() error {
784839
s.closeOnce.Do(func() {
840+
// Cancel any hanging network requests.
841+
s.cancel()
785842
close(s.done)
786843

787844
req, err := http.NewRequest(http.MethodDelete, s.url, nil)
@@ -800,3 +857,83 @@ func (s *streamableClientConn) Close() error {
800857
})
801858
return s.closeErr
802859
}
860+
861+
// establishSSE creates and manages the persistent SSE listening stream.
862+
// It is used for both the initial connection and all subsequent reconnection attempts,
863+
// using the Last-Event-ID header to resume a broken stream where it left off.
864+
// On a successful response, it delegates to handleSSE to process events;
865+
// on failure, it triggers the client's reconnection logic.
866+
func (s *streamableClientConn) establishSSE(opts *startSSEOptions) {
867+
select {
868+
case <-s.done:
869+
return
870+
default:
871+
}
872+
873+
req, err := http.NewRequestWithContext(s.ctx, http.MethodGet, s.url, nil)
874+
if err != nil {
875+
return
876+
}
877+
s.mu.Lock()
878+
if s._sessionID != "" {
879+
req.Header.Set("Mcp-Session-Id", s._sessionID)
880+
}
881+
s.mu.Unlock()
882+
if opts.lastEventID != "" {
883+
req.Header.Set("Last-Event-ID", opts.lastEventID)
884+
}
885+
req.Header.Set("Accept", "text/event-stream")
886+
887+
resp, err := s.client.Do(req)
888+
if err != nil {
889+
// On connection error, schedule a retry.
890+
s.scheduleReconnection(opts)
891+
return
892+
}
893+
894+
// Per the spec, a 405 response means the server doesn't support SSE streams at this endpoint.
895+
if resp.StatusCode == http.StatusMethodNotAllowed {
896+
resp.Body.Close()
897+
return
898+
}
899+
900+
if !strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") {
901+
resp.Body.Close()
902+
s.scheduleReconnection(opts)
903+
return
904+
}
905+
906+
s.handleSSE(resp, opts)
907+
}
908+
909+
// scheduleReconnection schedules the next SSE reconnection attempt after a delay.
910+
func (s *streamableClientConn) scheduleReconnection(opts *startSSEOptions) {
911+
reconnOpts := s.reconnectionOptions
912+
if opts.attempt >= reconnOpts.MaxRetries {
913+
return
914+
}
915+
916+
delay := calculateReconnectionDelay(reconnOpts, opts.attempt)
917+
918+
select {
919+
case <-s.done:
920+
return
921+
case <-time.After(delay):
922+
opts.attempt++
923+
s.establishSSE(opts)
924+
}
925+
}
926+
927+
// calculateReconnectionDelay calculates a delay using exponential backoff with full jitter.
928+
func calculateReconnectionDelay(opts *StreamableClientReconnectionOptions, attempt int) time.Duration {
929+
// Calculate the exponential backoff using the grow factor.
930+
backoffDuration := time.Duration(float64(opts.InitialDelay) * math.Pow(opts.GrowFactor, float64(attempt)))
931+
932+
// Cap the backoffDuration at maxDelay.
933+
backoffDuration = min(backoffDuration, opts.MaxDelay)
934+
935+
// Use a full jitter using backoffDuration
936+
jitter := rand.N(backoffDuration)
937+
938+
return backoffDuration + jitter
939+
}

0 commit comments

Comments
 (0)