Skip to content

Commit 5a2d5aa

Browse files
committed
mcp/streamable: add resumability for the Streamable transport
This CL implements a retry mechanism to resume SSE streams to recover from network failures.
1 parent de4b788 commit 5a2d5aa

File tree

2 files changed

+307
-14
lines changed

2 files changed

+307
-14
lines changed

mcp/streamable.go

Lines changed: 161 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,14 @@ import (
99
"context"
1010
"fmt"
1111
"io"
12+
"math"
13+
"math/rand/v2"
1214
"net/http"
1315
"strconv"
1416
"strings"
1517
"sync"
1618
"sync/atomic"
19+
"time"
1720

1821
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
1922
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
@@ -594,12 +597,39 @@ type StreamableClientTransport struct {
594597
opts StreamableClientTransportOptions
595598
}
596599

600+
// StreamableClientReconnectOptions defines parameters for client reconnect attempts.
601+
type StreamableClientReconnectOptions struct {
602+
// InitialDelay is the base delay for the first reconnect attempt.
603+
InitialDelay time.Duration
604+
605+
// MaxDelay caps the backoff delay, preventing it from growing indefinitely.
606+
MaxDelay time.Duration
607+
608+
// GrowFactor is the multiplicative factor by which the delay increases after each attempt.
609+
// A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time.
610+
// It must be 1.0 or greater if MaxRetries is greater than 0.
611+
GrowFactor float64
612+
613+
// MaxRetries is the maximum number of times to attempt a reconnect before giving up.
614+
// A value of 0 or less means never retry.
615+
MaxRetries int
616+
}
617+
618+
// DefaultReconnectOptions provides sensible defaults for reconnect logic.
619+
var DefaultReconnectOptions = &StreamableClientReconnectOptions{
620+
InitialDelay: 1 * time.Second,
621+
MaxDelay: 30 * time.Second,
622+
GrowFactor: 1.5,
623+
MaxRetries: 5,
624+
}
625+
597626
// StreamableClientTransportOptions provides options for the
598627
// [NewStreamableClientTransport] constructor.
599628
type StreamableClientTransportOptions struct {
600629
// HTTPClient is the client to use for making HTTP requests. If nil,
601630
// http.DefaultClient is used.
602-
HTTPClient *http.Client
631+
HTTPClient *http.Client
632+
ReconnectOptions *StreamableClientReconnectOptions
603633
}
604634

605635
// NewStreamableClientTransport returns a new client transport that connects to
@@ -625,22 +655,42 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
625655
if client == nil {
626656
client = http.DefaultClient
627657
}
628-
return &streamableClientConn{
629-
url: t.url,
630-
client: client,
631-
incoming: make(chan []byte, 100),
632-
done: make(chan struct{}),
633-
}, nil
658+
reconnOpts := t.opts.ReconnectOptions
659+
if reconnOpts == nil {
660+
reconnOpts = DefaultReconnectOptions
661+
}
662+
// A factor less than 1.0 would cause the delay to shrink, which defeats
663+
// the purpose of backoff.
664+
if reconnOpts.MaxRetries > 0 && reconnOpts.GrowFactor < 1.0 {
665+
return nil, fmt.Errorf("invalid grow factor, cannot be less than 1.0: %v", reconnOpts.GrowFactor)
666+
}
667+
// Create a new cancellable context that will manage the connection's lifecycle.
668+
// This is crucial for cleanly shutting down the background SSE listener by
669+
// cancelling its blocking network operations, which prevents hangs on exit.
670+
connCtx, cancel := context.WithCancel(context.Background())
671+
conn := &streamableClientConn{
672+
url: t.url,
673+
client: client,
674+
incoming: make(chan []byte, 100),
675+
done: make(chan struct{}),
676+
ReconnectOptions: reconnOpts,
677+
ctx: connCtx,
678+
cancel: cancel,
679+
}
680+
return conn, nil
634681
}
635682

636683
type streamableClientConn struct {
637-
url string
638-
client *http.Client
639-
incoming chan []byte
640-
done chan struct{}
684+
url string
685+
client *http.Client
686+
incoming chan []byte
687+
done chan struct{}
688+
ReconnectOptions *StreamableClientReconnectOptions
641689

642690
closeOnce sync.Once
643691
closeErr error
692+
ctx context.Context
693+
cancel context.CancelFunc
644694

645695
mu sync.Mutex
646696
protocolVersion string
@@ -704,11 +754,19 @@ func (s *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
704754
if sessionID == "" {
705755
// locked
706756
s._sessionID = gotSessionID
757+
// With the session now established, launch the persistent background listener for server-pushed events.
758+
go s.establishSSE(&startSSEState{})
707759
}
708760

709761
return nil
710762
}
711763

764+
// startSSEState holds the state for initiating an SSE stream.
765+
type startSSEState struct {
766+
lastEventID string
767+
attempt int
768+
}
769+
712770
func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string, msg jsonrpc.Message) (string, error) {
713771
data, err := jsonrpc2.EncodeMessage(msg)
714772
if err != nil {
@@ -742,7 +800,7 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
742800
sessionID = resp.Header.Get(sessionIDHeader)
743801
switch ct := resp.Header.Get("Content-Type"); ct {
744802
case "text/event-stream":
745-
go s.handleSSE(resp)
803+
go s.handleSSE(resp, &startSSEState{})
746804
case "application/json":
747805
// TODO: read the body and send to s.incoming (in a select that also receives from s.done).
748806
resp.Body.Close()
@@ -754,17 +812,20 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
754812
return sessionID, nil
755813
}
756814

757-
func (s *streamableClientConn) handleSSE(resp *http.Response) {
815+
// handleSSE processes an incoming Server-Sent Events stream, pushing received messages to the client's channel.
816+
// If the stream breaks, it uses the last received event ID to automatically trigger the reconnect logic.
817+
func (s *streamableClientConn) handleSSE(resp *http.Response, opts *startSSEState) {
758818
defer resp.Body.Close()
759819

760820
done := make(chan struct{})
761821
go func() {
762822
defer close(done)
763823
for evt, err := range scanEvents(resp.Body) {
764824
if err != nil {
765-
// TODO: surface this error; possibly break the stream
825+
s.scheduleReconnect(opts)
766826
return
767827
}
828+
opts.lastEventID = evt.id
768829
select {
769830
case <-s.done:
770831
return
@@ -782,6 +843,8 @@ func (s *streamableClientConn) handleSSE(resp *http.Response) {
782843
// Close implements the [Connection] interface.
783844
func (s *streamableClientConn) Close() error {
784845
s.closeOnce.Do(func() {
846+
// Cancel any hanging network requests.
847+
s.cancel()
785848
close(s.done)
786849

787850
req, err := http.NewRequest(http.MethodDelete, s.url, nil)
@@ -800,3 +863,87 @@ func (s *streamableClientConn) Close() error {
800863
})
801864
return s.closeErr
802865
}
866+
867+
// establishSSE creates and manages the persistent SSE listening stream.
868+
// It is used for both the initial connection and all subsequent reconnect attempts,
869+
// using the Last-Event-ID header to resume a broken stream where it left off.
870+
// On a successful response, it delegates to handleSSE to process events;
871+
// on failure, it triggers the client's reconnect logic.
872+
func (s *streamableClientConn) establishSSE(opts *startSSEState) {
873+
select {
874+
case <-s.done:
875+
return
876+
default:
877+
}
878+
879+
req, err := http.NewRequestWithContext(s.ctx, http.MethodGet, s.url, nil)
880+
if err != nil {
881+
return
882+
}
883+
s.mu.Lock()
884+
if s._sessionID != "" {
885+
req.Header.Set("Mcp-Session-Id", s._sessionID)
886+
}
887+
s.mu.Unlock()
888+
if opts.lastEventID != "" {
889+
req.Header.Set("Last-Event-ID", opts.lastEventID)
890+
}
891+
req.Header.Set("Accept", "text/event-stream")
892+
893+
resp, err := s.client.Do(req)
894+
if err != nil {
895+
// On connection error, schedule a retry.
896+
s.scheduleReconnect(opts)
897+
return
898+
}
899+
900+
// Per the spec, a 405 response means the server doesn't support SSE streams at this endpoint.
901+
if resp.StatusCode == http.StatusMethodNotAllowed {
902+
resp.Body.Close()
903+
return
904+
}
905+
906+
if !strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") {
907+
resp.Body.Close()
908+
s.scheduleReconnect(opts)
909+
return
910+
}
911+
912+
s.handleSSE(resp, opts)
913+
}
914+
915+
// scheduleReconnect schedules the next SSE reconnect attempt after a delay.
916+
func (s *streamableClientConn) scheduleReconnect(opts *startSSEState) {
917+
reconnOpts := s.ReconnectOptions
918+
if opts.attempt >= reconnOpts.MaxRetries {
919+
s.mu.Lock()
920+
s.err = fmt.Errorf("connection failed after %d attempts", reconnOpts.MaxRetries)
921+
s.mu.Unlock()
922+
s.Close() // Close the connection to unblock any readers.
923+
return
924+
}
925+
926+
delay := calculateReconnectDelay(reconnOpts, opts.attempt)
927+
928+
select {
929+
case <-s.done:
930+
return
931+
case <-time.After(delay):
932+
opts.attempt++
933+
s.establishSSE(opts)
934+
}
935+
}
936+
937+
// calculateReconnectDelay calculates a delay using exponential backoff with full jitter.
938+
func calculateReconnectDelay(opts *StreamableClientReconnectOptions, attempt int) time.Duration {
939+
// Calculate the exponential backoff using the grow factor.
940+
backoffDuration := time.Duration(float64(opts.InitialDelay) * math.Pow(opts.GrowFactor, float64(attempt)))
941+
942+
// Cap the backoffDuration at maxDelay.
943+
backoffDuration = min(backoffDuration, opts.MaxDelay)
944+
945+
// Use a full jitter using backoffDuration
946+
jitter := rand.N(backoffDuration)
947+
948+
return backoffDuration + jitter
949+
}

0 commit comments

Comments
 (0)