Skip to content

mcp/streamable: add resumability for the Streamable transport #133

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 193 additions & 26 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@ import (
"context"
"fmt"
"io"
"math"
"math/rand/v2"
"net/http"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
Expand Down Expand Up @@ -597,12 +600,39 @@ type StreamableClientTransport struct {
opts StreamableClientTransportOptions
}

// StreamableReconnectOptions defines parameters for client reconnect attempts.
//
// Only MaxRetries is exported. The backoff fields are unexported, so code
// outside this package cannot set them; a value constructed elsewhere leaves
// them at their zero values.
type StreamableReconnectOptions struct {
// MaxRetries is the maximum number of times to attempt a reconnect before giving up.
// A value of 0 or less means never retry.
MaxRetries int

// growFactor is the multiplicative factor by which the delay increases after each attempt.
// A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time.
// It must be 1.0 or greater if MaxRetries is greater than 0.
growFactor float64

// initialDelay is the base delay for the first reconnect attempt, before
// any growth or jitter is applied.
initialDelay time.Duration

// maxDelay caps the backoff delay, preventing it from growing indefinitely.
// The cap is applied before jitter is added, so the actual wait can be
// longer than maxDelay.
maxDelay time.Duration
}

// DefaultReconnectOptions provides sensible defaults for reconnect logic:
// up to 5 attempts, starting from a 1 second base delay that grows by a
// factor of 1.5 per attempt and is capped at 30 seconds.
//
// Connect stores this pointer directly in the connection when no
// ReconnectOptions are supplied, so modifying it affects every connection
// that relies on the defaults.
var DefaultReconnectOptions = &StreamableReconnectOptions{
MaxRetries: 5,
growFactor: 1.5,
initialDelay: 1 * time.Second,
maxDelay: 30 * time.Second,
}

// StreamableClientTransportOptions provides options for the
// [NewStreamableClientTransport] constructor.
type StreamableClientTransportOptions struct {
// HTTPClient is the client to use for making HTTP requests. If nil,
// http.DefaultClient is used.
HTTPClient *http.Client
HTTPClient *http.Client
// ReconnectOptions configures how the client retries a broken SSE stream.
// If nil, DefaultReconnectOptions is used.
ReconnectOptions *StreamableReconnectOptions
}

// NewStreamableClientTransport returns a new client transport that connects to
Expand All @@ -628,22 +658,37 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
if client == nil {
client = http.DefaultClient
}
return &streamableClientConn{
url: t.url,
client: client,
incoming: make(chan []byte, 100),
done: make(chan struct{}),
}, nil
reconnOpts := t.opts.ReconnectOptions
if reconnOpts == nil {
reconnOpts = DefaultReconnectOptions
}
// Create a new cancellable context that will manage the connection's lifecycle.
// This is crucial for cleanly shutting down the background SSE listener by
// cancelling its blocking network operations, which prevents hangs on exit.
connCtx, cancel := context.WithCancel(context.Background())
conn := &streamableClientConn{
url: t.url,
client: client,
incoming: make(chan []byte, 100),
done: make(chan struct{}),
ReconnectOptions: reconnOpts,
ctx: connCtx,
cancel: cancel,
}
return conn, nil
}

type streamableClientConn struct {
url string
client *http.Client
incoming chan []byte
done chan struct{}
url string
client *http.Client
incoming chan []byte
done chan struct{}
ReconnectOptions *StreamableReconnectOptions

closeOnce sync.Once
closeErr error
ctx context.Context
cancel context.CancelFunc

mu sync.Mutex
protocolVersion string
Expand All @@ -665,6 +710,12 @@ func (c *streamableClientConn) SessionID() string {

// Read implements the [Connection] interface.
func (s *streamableClientConn) Read(ctx context.Context) (jsonrpc.Message, error) {
s.mu.Lock()
err := s.err
s.mu.Unlock()
if err != nil {
return nil, err
}
select {
case <-ctx.Done():
return nil, ctx.Err()
Expand Down Expand Up @@ -745,6 +796,7 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
sessionID = resp.Header.Get(sessionIDHeader)
switch ct := resp.Header.Get("Content-Type"); ct {
case "text/event-stream":
// Section 2.1: The SSE stream is initiated after a POST.
go s.handleSSE(resp)
case "application/json":
// TODO: read the body and send to s.incoming (in a select that also receives from s.done).
Expand All @@ -757,34 +809,115 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
return sessionID, nil
}

func (s *streamableClientConn) handleSSE(resp *http.Response) {
// handleSSE manages the entire lifecycle of an SSE connection. It processes
// the events of initialResp and, whenever the stream ends without the client
// having closed the connection, tries to re-establish it via reconnect and
// continues with the new response.
//
// The most recent non-empty event ID seen on any stream is remembered and
// passed to reconnect so the server can resume where the client left off.
// If reconnect fails, the error is recorded in s.err (under s.mu), the
// connection is closed, and the method returns.
//
// handleSSE blocks until the connection is finished; callers run it in its
// own goroutine.
func (s *streamableClientConn) handleSSE(initialResp *http.Response) {
	resp := initialResp
	var lastEventID string

	for {
		eventID, clientClosed := s.processStream(resp)

		// processStream reports "" when the stream broke before delivering
		// any event with an ID. Keep the previously known ID in that case;
		// otherwise the next reconnect would drop its resume position.
		if eventID != "" {
			lastEventID = eventID
		}

		// If the connection was closed by the client, we're done.
		if clientClosed {
			return
		}

		// The stream was interrupted or ended by the server. Attempt to reconnect.
		newResp, err := s.reconnect(lastEventID)
		if err != nil {
			// All reconnection attempts failed. Record the final error,
			// close the connection, and exit the goroutine.
			s.mu.Lock()
			s.err = err
			s.mu.Unlock()
			s.Close()
			return
		}

		// Reconnection was successful. Continue the loop with the new response.
		resp = newResp
	}
}

// processStream reads from a single response body, sending events to the
// incoming channel. It returns the ID of the last event that carried an ID
// and a flag indicating if the connection was closed by the client.
func (s *streamableClientConn) processStream(resp *http.Response) (lastEventID string, clientClosed bool) {
defer resp.Body.Close()

done := make(chan struct{})
go func() {
defer close(done)
for evt, err := range scanEvents(resp.Body) {
for evt, err := range scanEvents(resp.Body) {
if err != nil {
return lastEventID, false
}

if evt.ID != "" {
lastEventID = evt.ID
}

select {
case s.incoming <- evt.Data:
case <-s.done:
// The connection was closed by the client; exit gracefully.
return lastEventID, true
}
}

// The loop finished without an error, indicating the server closed the stream.
// We'll attempt to reconnect, so this is not a client-side close.
return lastEventID, false
}

// reconnect handles the logic of retrying a connection with an exponential
// backoff strategy. It returns a new, valid HTTP response if successful, or
// an error if all retries are exhausted.
func (s *streamableClientConn) reconnect(lastEventID string) (*http.Response, error) {
var finalErr error

for attempt := 0; attempt < s.ReconnectOptions.MaxRetries; attempt++ {
select {
case <-s.done:
return nil, fmt.Errorf("connection closed by client during reconnect")
case <-time.After(calculateReconnectDelay(s.ReconnectOptions, attempt)):
resp, err := s.establishSSE(lastEventID)
if err != nil {
// TODO: surface this error; possibly break the stream
return
finalErr = err // Store the error and try again.
continue
}
select {
case <-s.done:
return
case s.incoming <- evt.Data:

if !isResumable(resp) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we may also want to try resuming if the error from establishSSE is non-nil. For example, if the network is partitioned, that might manifest as a timeout error instead of an HTTP response. But we can leave that for a later PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the way it's written, we already retry with another attempt if the error is non-nil. Are you saying that we have to do something special in that case?

// The server indicated we should not continue.
resp.Body.Close()
return nil, fmt.Errorf("reconnection failed with unresumable status: %s", resp.Status)
}

return resp, nil
}
}()
}
// If the loop completes, all retries have failed.
if finalErr != nil {
return nil, fmt.Errorf("connection failed after %d attempts: %w", s.ReconnectOptions.MaxRetries, finalErr)
}
return nil, fmt.Errorf("connection failed after %d attempts", s.ReconnectOptions.MaxRetries)
}

select {
case <-s.done:
case <-done:
// isResumable reports whether an HTTP response carries an SSE stream that
// the client can go on processing.
//
// It returns false for a 405 (Method Not Allowed) response, which per the
// spec means the server doesn't support SSE streams at this endpoint.
// Otherwise it returns true only when the Content-Type media type is
// text/event-stream. The comparison ignores case and any parameters such as
// "; charset=utf-8", because media types are case-insensitive and a plain
// substring search would also match unrelated types that merely mention
// text/event-stream in a parameter.
func isResumable(resp *http.Response) bool {
	if resp.StatusCode == http.StatusMethodNotAllowed {
		return false
	}

	// Strip optional parameters, then compare the bare media type.
	mediaType, _, _ := strings.Cut(resp.Header.Get("Content-Type"), ";")
	return strings.EqualFold(strings.TrimSpace(mediaType), "text/event-stream")
}

// Close implements the [Connection] interface.
func (s *streamableClientConn) Close() error {
s.closeOnce.Do(func() {
// Cancel any hanging network requests.
s.cancel()
close(s.done)

req, err := http.NewRequest(http.MethodDelete, s.url, nil)
Expand All @@ -803,3 +936,37 @@ func (s *streamableClientConn) Close() error {
})
return s.closeErr
}

// establishSSE opens the persistent SSE listening stream by issuing a GET
// request to the connection's URL, bound to the connection's context.
//
// It is used for reconnect attempts: when lastEventID is non-empty it is sent
// as the Last-Event-ID header so a broken stream can be resumed where it left
// off. The session ID, if one is known, is sent as Mcp-Session-Id.
// The caller owns the returned response and must close its body.
func (s *streamableClientConn) establishSSE(lastEventID string) (*http.Response, error) {
	req, err := http.NewRequestWithContext(s.ctx, http.MethodGet, s.url, nil)
	if err != nil {
		return nil, err
	}

	// Copy the session ID while holding the lock; build headers afterwards.
	s.mu.Lock()
	sessionID := s._sessionID
	s.mu.Unlock()

	header := req.Header
	header.Set("Accept", "text/event-stream")
	if sessionID != "" {
		header.Set("Mcp-Session-Id", sessionID)
	}
	if lastEventID != "" {
		header.Set("Last-Event-ID", lastEventID)
	}

	return s.client.Do(req)
}

// calculateReconnectDelay returns how long to wait before the given reconnect
// attempt (0-based), using exponential backoff with additive random jitter.
//
// The base backoff is initialDelay * growFactor^attempt, capped at maxDelay.
// A random jitter in [0, backoff) is then added, so the result lies in
// [backoff, 2*backoff) and may exceed maxDelay.
//
// If the backoff is not positive (for example when the delay fields are left
// at their zero values, which is the only possibility for options built
// outside this package), the function returns 0 instead of calling rand.N,
// which panics for arguments <= 0.
func calculateReconnectDelay(opts *StreamableReconnectOptions, attempt int) time.Duration {
	// Grow in floating point and cap before converting: for large attempts
	// the product overflows int64, and converting such a float to
	// time.Duration yields an unspecified (possibly negative) value.
	backoff := float64(opts.initialDelay) * math.Pow(opts.growFactor, float64(attempt))

	// Written as a negation so that NaN is rejected as well as negative and
	// sub-nanosecond values.
	if !(backoff >= 1) {
		return 0
	}

	backoffDuration := opts.maxDelay
	if backoff < float64(opts.maxDelay) {
		backoffDuration = time.Duration(backoff)
	}
	if backoffDuration <= 0 {
		// maxDelay is zero or negative: nothing to wait for, and rand.N
		// would panic.
		return 0
	}

	jitter := rand.N(backoffDuration)

	return backoffDuration + jitter
}
Loading
Loading