@@ -9,11 +9,14 @@ import (
9
9
"context"
10
10
"fmt"
11
11
"io"
12
+ "math"
13
+ "math/rand/v2"
12
14
"net/http"
13
15
"strconv"
14
16
"strings"
15
17
"sync"
16
18
"sync/atomic"
19
+ "time"
17
20
18
21
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
19
22
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
@@ -594,12 +597,38 @@ type StreamableClientTransport struct {
594
597
opts StreamableClientTransportOptions
595
598
}
596
599
600
+ // StreamableClientReconnectionOptions defines parameters for client reconnection attempts.
601
+ type StreamableClientReconnectionOptions struct {
602
+ // InitialDelay is the base delay for the first reconnection attempt
603
+ InitialDelay time.Duration // default: 1 second
604
+
605
+ // MaxDelay caps the backoff delay, preventing it from growing indefinitely.
606
+ MaxDelay time.Duration // default: 30 seconds
607
+
608
+ // GrowFactor is the multiplicative factor by which the delay increases after each attempt.
609
+ // A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time.
610
+ GrowFactor float64 // default: 1.5
611
+
612
+ // MaxRetries is the maximum number of times to attempt reconnection before giving up.
613
+ // A value of 0 or less means never retry.
614
+ MaxRetries int // default: 5
615
+ }
616
+
617
+ // DefaultReconnectionOptions provides sensible defaults for reconnection logic.
618
+ var DefaultReconnectionOptions = & StreamableClientReconnectionOptions {
619
+ InitialDelay : 1 * time .Second ,
620
+ MaxDelay : 30 * time .Second ,
621
+ GrowFactor : 1.5 ,
622
+ MaxRetries : 5 ,
623
+ }
624
+
597
625
// StreamableClientTransportOptions provides options for the
598
626
// [NewStreamableClientTransport] constructor.
599
627
type StreamableClientTransportOptions struct {
600
628
// HTTPClient is the client to use for making HTTP requests. If nil,
601
629
// http.DefaultClient is used.
602
- HTTPClient * http.Client
630
+ HTTPClient * http.Client
631
+ ReconnectionOptions * StreamableClientReconnectionOptions
603
632
}
604
633
605
634
// NewStreamableClientTransport returns a new client transport that connects to
@@ -625,19 +654,26 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
625
654
if client == nil {
626
655
client = http .DefaultClient
627
656
}
628
- return & streamableClientConn {
629
- url : t .url ,
630
- client : client ,
631
- incoming : make (chan []byte , 100 ),
632
- done : make (chan struct {}),
633
- }, nil
657
+ reconnOpts := t .opts .ReconnectionOptions
658
+ if reconnOpts == nil {
659
+ reconnOpts = DefaultReconnectionOptions
660
+ }
661
+ conn := & streamableClientConn {
662
+ url : t .url ,
663
+ client : client ,
664
+ incoming : make (chan []byte , 100 ),
665
+ done : make (chan struct {}),
666
+ reconnectionOptions : reconnOpts ,
667
+ }
668
+ return conn , nil
634
669
}
635
670
636
671
type streamableClientConn struct {
637
- url string
638
- client * http.Client
639
- incoming chan []byte
640
- done chan struct {}
672
+ url string
673
+ client * http.Client
674
+ incoming chan []byte
675
+ done chan struct {}
676
+ reconnectionOptions * StreamableClientReconnectionOptions
641
677
642
678
closeOnce sync.Once
643
679
closeErr error
@@ -704,11 +740,19 @@ func (s *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
704
740
if sessionID == "" {
705
741
// locked
706
742
s ._sessionID = gotSessionID
743
+ // With the session now established, launch the persistent background listener for server-pushed events.
744
+ go s .establishSSE (context .Background (), & startSSEOptions {})
707
745
}
708
746
709
747
return nil
710
748
}
711
749
750
+ // startSSEOptions holds parameters for initiating an SSE stream.
751
+ type startSSEOptions struct {
752
+ lastEventID string
753
+ attempt int
754
+ }
755
+
712
756
func (s * streamableClientConn ) postMessage (ctx context.Context , sessionID string , msg jsonrpc.Message ) (string , error ) {
713
757
data , err := jsonrpc2 .EncodeMessage (msg )
714
758
if err != nil {
@@ -742,7 +786,7 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
742
786
sessionID = resp .Header .Get (sessionIDHeader )
743
787
switch ct := resp .Header .Get ("Content-Type" ); ct {
744
788
case "text/event-stream" :
745
- go s .handleSSE (resp )
789
+ go s .handleSSE (context . Background (), resp , & startSSEOptions {} )
746
790
case "application/json" :
747
791
// TODO: read the body and send to s.incoming (in a select that also receives from s.done).
748
792
resp .Body .Close ()
@@ -754,17 +798,20 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
754
798
return sessionID , nil
755
799
}
756
800
757
- func (s * streamableClientConn ) handleSSE (resp * http.Response ) {
801
+ // handleSSE processes an incoming Server-Sent Events stream, pushing received messages to the client's channel.
802
+ // If the stream breaks, it uses the last received event ID to automatically trigger the reconnection logic.
803
+ func (s * streamableClientConn ) handleSSE (ctx context.Context , resp * http.Response , opts * startSSEOptions ) {
758
804
defer resp .Body .Close ()
759
805
760
806
done := make (chan struct {})
761
807
go func () {
762
808
defer close (done )
763
809
for evt , err := range scanEvents (resp .Body ) {
764
810
if err != nil {
765
- // TODO: surface this error; possibly break the stream
811
+ s . scheduleReconnection ( ctx , opts )
766
812
return
767
813
}
814
+ opts .lastEventID = evt .id
768
815
select {
769
816
case <- s .done :
770
817
return
@@ -800,3 +847,85 @@ func (s *streamableClientConn) Close() error {
800
847
})
801
848
return s .closeErr
802
849
}
850
+
851
+ // establishSSE creates and manages the persistent SSE listening stream.
852
+ // It is used for both the initial connection and all subsequent reconnection attempts,
853
+ // using the Last-Event-ID header to resume a broken stream where it left off.
854
+ // On a successful response, it delegates to handleSSE to process events;
855
+ // on failure, it triggers the client's reconnection logic.
856
+ func (s * streamableClientConn ) establishSSE (ctx context.Context , opts * startSSEOptions ) {
857
+ select {
858
+ case <- s .done :
859
+ return
860
+ case <- ctx .Done ():
861
+ return
862
+ default :
863
+ }
864
+
865
+ req , err := http .NewRequestWithContext (ctx , http .MethodGet , s .url , nil )
866
+ if err != nil {
867
+ return
868
+ }
869
+ s .mu .Lock ()
870
+ if s ._sessionID != "" {
871
+ req .Header .Set ("Mcp-Session-Id" , s ._sessionID )
872
+ }
873
+ s .mu .Unlock ()
874
+ if opts .lastEventID != "" {
875
+ req .Header .Set ("Last-Event-ID" , opts .lastEventID )
876
+ }
877
+ req .Header .Set ("Accept" , "text/event-stream" )
878
+
879
+ resp , err := s .client .Do (req )
880
+ if err != nil {
881
+ // On connection error, schedule a retry.
882
+ s .scheduleReconnection (ctx , opts )
883
+ return
884
+ }
885
+
886
+ // Per the spec, a 405 response means the server doesn't support SSE streams at this endpoint.
887
+ if resp .StatusCode == http .StatusMethodNotAllowed {
888
+ resp .Body .Close ()
889
+ return
890
+ }
891
+
892
+ if ! strings .Contains (resp .Header .Get ("Content-Type" ), "text/event-stream" ) {
893
+ resp .Body .Close ()
894
+ s .scheduleReconnection (ctx , opts )
895
+ return
896
+ }
897
+
898
+ s .handleSSE (ctx , resp , opts )
899
+ }
900
+
901
+ // scheduleReconnection schedules the next SSE reconnection attempt after a delay.
902
+ func (s * streamableClientConn ) scheduleReconnection (ctx context.Context , opts * startSSEOptions ) {
903
+ reconnOpts := s .reconnectionOptions
904
+ if opts .attempt >= reconnOpts .MaxRetries {
905
+ return
906
+ }
907
+
908
+ delay := calculateReconnectionDelay (reconnOpts , opts .attempt )
909
+
910
+ select {
911
+ case <- s .done :
912
+ return
913
+ case <- time .After (delay ):
914
+ opts .attempt ++
915
+ s .establishSSE (ctx , opts )
916
+ }
917
+ }
918
+
919
+ // calculateReconnectionDelay calculates a delay using exponential backoff with full jitter.
920
+ func calculateReconnectionDelay (opts * StreamableClientReconnectionOptions , attempt int ) time.Duration {
921
+ // Calculate the exponential backoff using the grow factor.
922
+ backoffDuration := time .Duration (float64 (opts .InitialDelay ) * math .Pow (opts .GrowFactor , float64 (attempt )))
923
+
924
+ // Cap the backoffDuration at maxDelay.
925
+ backoffDuration = min (backoffDuration , opts .MaxDelay )
926
+
927
+ // Use a full jitter using backoffDuration
928
+ jitter := rand .N (backoffDuration )
929
+
930
+ return backoffDuration + jitter
931
+ }
0 commit comments