-
Notifications
You must be signed in to change notification settings - Fork 80
mcp/streamable: add resumability for the Streamable transport #133
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,11 +9,14 @@ import ( | |
"context" | ||
"fmt" | ||
"io" | ||
"math" | ||
"math/rand/v2" | ||
"net/http" | ||
"strconv" | ||
"strings" | ||
"sync" | ||
"sync/atomic" | ||
"time" | ||
|
||
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" | ||
"github.com/modelcontextprotocol/go-sdk/jsonrpc" | ||
|
@@ -597,12 +600,39 @@ type StreamableClientTransport struct { | |
opts StreamableClientTransportOptions | ||
} | ||
|
||
// StreamableReconnectOptions defines parameters for client reconnect attempts. | ||
type StreamableReconnectOptions struct { | ||
// MaxRetries is the maximum number of times to attempt a reconnect before giving up. | ||
// A value of 0 or less means never retry. | ||
MaxRetries int | ||
|
||
// growFactor is the multiplicative factor by which the delay increases after each attempt. | ||
// A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time. | ||
// It must be 1.0 or greater if MaxRetries is greater than 0. | ||
growFactor float64 | ||
|
||
// initialDelay is the base delay for the first reconnect attempt. | ||
initialDelay time.Duration | ||
|
||
// maxDelay caps the backoff delay, preventing it from growing indefinitely. | ||
maxDelay time.Duration | ||
} | ||
|
||
// DefaultReconnectOptions provides sensible defaults for reconnect logic. | ||
var DefaultReconnectOptions = &StreamableReconnectOptions{ | ||
MaxRetries: 5, | ||
growFactor: 1.5, | ||
initialDelay: 1 * time.Second, | ||
maxDelay: 30 * time.Second, | ||
} | ||
|
||
// StreamableClientTransportOptions provides options for the | ||
// [NewStreamableClientTransport] constructor. | ||
type StreamableClientTransportOptions struct { | ||
// HTTPClient is the client to use for making HTTP requests. If nil, | ||
// http.DefaultClient is used. | ||
HTTPClient *http.Client | ||
HTTPClient *http.Client | ||
ReconnectOptions *StreamableReconnectOptions | ||
} | ||
|
||
// NewStreamableClientTransport returns a new client transport that connects to | ||
|
@@ -628,22 +658,37 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er | |
if client == nil { | ||
client = http.DefaultClient | ||
} | ||
return &streamableClientConn{ | ||
url: t.url, | ||
client: client, | ||
incoming: make(chan []byte, 100), | ||
done: make(chan struct{}), | ||
}, nil | ||
reconnOpts := t.opts.ReconnectOptions | ||
if reconnOpts == nil { | ||
reconnOpts = DefaultReconnectOptions | ||
} | ||
// Create a new cancellable context that will manage the connection's lifecycle. | ||
// This is crucial for cleanly shutting down the background SSE listener by | ||
// cancelling its blocking network operations, which prevents hangs on exit. | ||
connCtx, cancel := context.WithCancel(context.Background()) | ||
conn := &streamableClientConn{ | ||
url: t.url, | ||
client: client, | ||
incoming: make(chan []byte, 100), | ||
done: make(chan struct{}), | ||
ReconnectOptions: reconnOpts, | ||
ctx: connCtx, | ||
cancel: cancel, | ||
} | ||
return conn, nil | ||
} | ||
|
||
type streamableClientConn struct { | ||
url string | ||
client *http.Client | ||
incoming chan []byte | ||
done chan struct{} | ||
url string | ||
client *http.Client | ||
incoming chan []byte | ||
done chan struct{} | ||
ReconnectOptions *StreamableReconnectOptions | ||
|
||
closeOnce sync.Once | ||
closeErr error | ||
ctx context.Context | ||
cancel context.CancelFunc | ||
|
||
mu sync.Mutex | ||
protocolVersion string | ||
|
@@ -665,6 +710,12 @@ func (c *streamableClientConn) SessionID() string { | |
|
||
// Read implements the [Connection] interface. | ||
func (s *streamableClientConn) Read(ctx context.Context) (jsonrpc.Message, error) { | ||
s.mu.Lock() | ||
err := s.err | ||
s.mu.Unlock() | ||
if err != nil { | ||
return nil, err | ||
} | ||
select { | ||
case <-ctx.Done(): | ||
return nil, ctx.Err() | ||
|
@@ -745,6 +796,7 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string | |
sessionID = resp.Header.Get(sessionIDHeader) | ||
switch ct := resp.Header.Get("Content-Type"); ct { | ||
case "text/event-stream": | ||
// Section 2.1: The SSE stream is initiated after a POST. | ||
go s.handleSSE(resp) | ||
case "application/json": | ||
// TODO: read the body and send to s.incoming (in a select that also receives from s.done). | ||
|
@@ -757,34 +809,115 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string | |
return sessionID, nil | ||
} | ||
|
||
func (s *streamableClientConn) handleSSE(resp *http.Response) { | ||
// handleSSE manages the entire lifecycle of an SSE connection. It processes | ||
// an incoming Server-Sent Events stream and automatically handles reconnection | ||
// logic if the stream breaks. | ||
func (s *streamableClientConn) handleSSE(initialResp *http.Response) { | ||
resp := initialResp | ||
var lastEventID string | ||
|
||
for { | ||
eventID, clientClosed := s.processStream(resp) | ||
lastEventID = eventID | ||
|
||
// If the connection was closed by the client, we're done. | ||
if clientClosed { | ||
return | ||
} | ||
|
||
// The stream was interrupted or ended by the server. Attempt to reconnect. | ||
newResp, err := s.reconnect(lastEventID) | ||
if err != nil { | ||
// All reconnection attempts failed. Set the final error, close the | ||
// connection, and exit the goroutine. | ||
s.mu.Lock() | ||
s.err = err | ||
s.mu.Unlock() | ||
s.Close() | ||
jba marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return | ||
} | ||
|
||
// Reconnection was successful. Continue the loop with the new response. | ||
resp = newResp | ||
} | ||
} | ||
|
||
// processStream reads from a single response body, sending events to the | ||
// incoming channel. It returns the ID of the last processed event, any error | ||
// that occurred, and a flag indicating if the connection was closed by the client. | ||
func (s *streamableClientConn) processStream(resp *http.Response) (lastEventID string, clientClosed bool) { | ||
defer resp.Body.Close() | ||
|
||
done := make(chan struct{}) | ||
go func() { | ||
defer close(done) | ||
for evt, err := range scanEvents(resp.Body) { | ||
for evt, err := range scanEvents(resp.Body) { | ||
if err != nil { | ||
return lastEventID, false | ||
} | ||
|
||
if evt.ID != "" { | ||
lastEventID = evt.ID | ||
} | ||
|
||
select { | ||
case s.incoming <- evt.Data: | ||
case <-s.done: | ||
// The connection was closed by the client; exit gracefully. | ||
return lastEventID, true | ||
} | ||
} | ||
|
||
// The loop finished without an error, indicating the server closed the stream. | ||
// We'll attempt to reconnect, so this is not a client-side close. | ||
return lastEventID, false | ||
} | ||
|
||
// reconnect handles the logic of retrying a connection with an exponential | ||
// backoff strategy. It returns a new, valid HTTP response if successful, or | ||
// an error if all retries are exhausted. | ||
func (s *streamableClientConn) reconnect(lastEventID string) (*http.Response, error) { | ||
var finalErr error | ||
|
||
for attempt := 0; attempt < s.ReconnectOptions.MaxRetries; attempt++ { | ||
select { | ||
case <-s.done: | ||
return nil, fmt.Errorf("connection closed by client during reconnect") | ||
case <-time.After(calculateReconnectDelay(s.ReconnectOptions, attempt)): | ||
resp, err := s.establishSSE(lastEventID) | ||
if err != nil { | ||
// TODO: surface this error; possibly break the stream | ||
return | ||
finalErr = err // Store the error and try again. | ||
continue | ||
} | ||
select { | ||
case <-s.done: | ||
return | ||
case s.incoming <- evt.Data: | ||
|
||
if !isResumable(resp) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we may also want to try resuming if the error from establishSSE is non-nil. For example, if the network is partitioned, that might manifest as a timeout error instead of an HTTP response. But we can leave that for a later PR. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. With the way it's written- we will retry another attempt if the error is non-nil. Are you saying that we have to do something special in that case? |
||
// The server indicated we should not continue. | ||
resp.Body.Close() | ||
return nil, fmt.Errorf("reconnection failed with unresumable status: %s", resp.Status) | ||
} | ||
|
||
return resp, nil | ||
} | ||
}() | ||
} | ||
// If the loop completes, all retries have failed. | ||
if finalErr != nil { | ||
return nil, fmt.Errorf("connection failed after %d attempts: %w", s.ReconnectOptions.MaxRetries, finalErr) | ||
} | ||
return nil, fmt.Errorf("connection failed after %d attempts", s.ReconnectOptions.MaxRetries) | ||
} | ||
|
||
select { | ||
case <-s.done: | ||
case <-done: | ||
// isResumable checks if an HTTP response indicates a valid SSE stream that can be processed. | ||
func isResumable(resp *http.Response) bool { | ||
// Per the spec, a 405 response means the server doesn't support SSE streams at this endpoint. | ||
if resp.StatusCode == http.StatusMethodNotAllowed { | ||
return false | ||
} | ||
|
||
return strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") | ||
} | ||
|
||
// Close implements the [Connection] interface. | ||
func (s *streamableClientConn) Close() error { | ||
s.closeOnce.Do(func() { | ||
// Cancel any hanging network requests. | ||
s.cancel() | ||
close(s.done) | ||
|
||
req, err := http.NewRequest(http.MethodDelete, s.url, nil) | ||
|
@@ -803,3 +936,37 @@ func (s *streamableClientConn) Close() error { | |
}) | ||
return s.closeErr | ||
} | ||
|
||
// establishSSE establishes the persistent SSE listening stream. | ||
// It is used for reconnect attempts using the Last-Event-ID header to | ||
// resume a broken stream where it left off. | ||
func (s *streamableClientConn) establishSSE(lastEventID string) (*http.Response, error) { | ||
req, err := http.NewRequestWithContext(s.ctx, http.MethodGet, s.url, nil) | ||
if err != nil { | ||
return nil, err | ||
} | ||
s.mu.Lock() | ||
if s._sessionID != "" { | ||
req.Header.Set("Mcp-Session-Id", s._sessionID) | ||
} | ||
s.mu.Unlock() | ||
if lastEventID != "" { | ||
req.Header.Set("Last-Event-ID", lastEventID) | ||
} | ||
req.Header.Set("Accept", "text/event-stream") | ||
|
||
return s.client.Do(req) | ||
} | ||
|
||
// calculateReconnectDelay calculates a delay using exponential backoff with full jitter. | ||
func calculateReconnectDelay(opts *StreamableReconnectOptions, attempt int) time.Duration { | ||
// Calculate the exponential backoff using the grow factor. | ||
backoffDuration := time.Duration(float64(opts.initialDelay) * math.Pow(opts.growFactor, float64(attempt))) | ||
// Cap the backoffDuration at maxDelay. | ||
backoffDuration = min(backoffDuration, opts.maxDelay) | ||
|
||
// Use a full jitter using backoffDuration | ||
jitter := rand.N(backoffDuration) | ||
|
||
return backoffDuration + jitter | ||
} |
Uh oh!
There was an error while loading. Please reload this page.