Skip to content

Commit c037ba5

Browse files
authored
mcp: negotiate protocol version (#112)
Negotiate the protocol version properly, and set the header on HTTP transports. We follow the logic of the Typescript SDK. - On initialization, the client sends the latest version it can support, and accepts the version that the server returns if the client supports it. If not, the connection fails. - The server accepts the client's version unless it doesn't support it, in which case it replies with its latest version. Fixes #103.
1 parent a1a3510 commit c037ba5

File tree

6 files changed

+73
-16
lines changed

6 files changed

+73
-16
lines changed

mcp/client.go

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package mcp
66

77
import (
88
"context"
9+
"fmt"
910
"iter"
1011
"slices"
1112
"sync"
@@ -86,6 +87,15 @@ func (c *Client) disconnect(cs *ClientSession) {
8687
})
8788
}
8889

90+
// TODO: Consider exporting this type and its field.
91+
type unsupportedProtocolVersionError struct {
92+
version string
93+
}
94+
95+
func (e unsupportedProtocolVersionError) Error() string {
96+
return fmt.Sprintf("unsupported protocol version: %q", e.version)
97+
}
98+
8999
// Connect begins an MCP session by connecting to a server over the given
90100
// transport, and initializing the session.
91101
//
@@ -106,19 +116,22 @@ func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, e
106116
}
107117

108118
params := &InitializeParams{
119+
ProtocolVersion: latestProtocolVersion,
109120
ClientInfo: &implementation{Name: c.name, Version: c.version},
110121
Capabilities: caps,
111-
ProtocolVersion: "2025-03-26",
112122
}
113-
// TODO(rfindley): handle protocol negotiation gracefully. If the server
114-
// responds with 2024-11-05, surface that failure to the caller of connect,
115-
// so that they can choose a different transport.
116123
res, err := handleSend[*InitializeResult](ctx, cs, methodInitialize, params)
117124
if err != nil {
118125
_ = cs.Close()
119126
return nil, err
120127
}
128+
if !slices.Contains(supportedProtocolVersions, res.ProtocolVersion) {
129+
return nil, unsupportedProtocolVersionError{res.ProtocolVersion}
130+
}
121131
cs.initializeResult = res
132+
if hc, ok := cs.mcpConn.(httpConnection); ok {
133+
hc.setProtocolVersion(res.ProtocolVersion)
134+
}
122135
if err := handleNotify(ctx, cs, notificationInitialized, &InitializedParams{}); err != nil {
123136
_ = cs.Close()
124137
return nil, err

mcp/server.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -647,10 +647,11 @@ func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParam
647647
ss.mu.Unlock()
648648
}()
649649

650-
version := "2025-03-26" // preferred version
651-
switch v := params.ProtocolVersion; v {
652-
case "2024-11-05", "2025-03-26":
653-
version = v
650+
// If we support the client's version, reply with it. Otherwise, reply with our
651+
// latest version.
652+
version := params.ProtocolVersion
653+
if !slices.Contains(supportedProtocolVersions, params.ProtocolVersion) {
654+
version = latestProtocolVersion
654655
}
655656

656657
return &InitializeResult{

mcp/shared.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,16 @@ import (
2222
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
2323
)
2424

25+
// latestProtocolVersion is the latest protocol version that this version of the SDK supports.
26+
// It is the version that the client sends in the initialization request.
27+
const latestProtocolVersion = "2025-06-18"
28+
29+
var supportedProtocolVersions = []string{
30+
latestProtocolVersion,
31+
"2025-03-26",
32+
"2024-11-05",
33+
}
34+
2535
// A MethodHandler handles MCP messages.
2636
// For methods, exactly one of the return values must be nil.
2737
// For notifications, both must be nil.

mcp/streamable.go

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ import (
1818
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
1919
)
2020

21+
const (
22+
protocolVersionHeader = "Mcp-Protocol-Version"
23+
sessionIDHeader = "Mcp-Session-Id"
24+
)
25+
2126
// A StreamableHTTPHandler is an http.Handler that serves streamable MCP
2227
// sessions, as defined by the [MCP spec].
2328
//
@@ -88,7 +93,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
8893
}
8994

9095
var session *StreamableServerTransport
91-
if id := req.Header.Get("Mcp-Session-Id"); id != "" {
96+
if id := req.Header.Get(sessionIDHeader); id != "" {
9297
h.sessionsMu.Lock()
9398
session, _ = h.sessions[id]
9499
h.sessionsMu.Unlock()
@@ -386,7 +391,7 @@ func (t *StreamableServerTransport) streamResponse(w http.ResponseWriter, req *h
386391
t.mu.Unlock()
387392
}
388393

389-
w.Header().Set("Mcp-Session-Id", t.id)
394+
w.Header().Set(sessionIDHeader, t.id)
390395
w.Header().Set("Content-Type", "text/event-stream") // Accept checked in [StreamableHTTPHandler]
391396
w.Header().Set("Cache-Control", "no-cache, no-transform")
392397
w.Header().Set("Connection", "keep-alive")
@@ -636,12 +641,19 @@ type streamableClientConn struct {
636641
closeOnce sync.Once
637642
closeErr error
638643

639-
mu sync.Mutex
640-
_sessionID string
644+
mu sync.Mutex
645+
protocolVersion string
646+
_sessionID string
641647
// bodies map[*http.Response]io.Closer
642648
err error
643649
}
644650

651+
func (c *streamableClientConn) setProtocolVersion(s string) {
652+
c.mu.Lock()
653+
defer c.mu.Unlock()
654+
c.protocolVersion = s
655+
}
656+
645657
func (c *streamableClientConn) SessionID() string {
646658
c.mu.Lock()
647659
defer c.mu.Unlock()
@@ -707,8 +719,11 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
707719
if err != nil {
708720
return "", err
709721
}
722+
if s.protocolVersion != "" {
723+
req.Header.Set(protocolVersionHeader, s.protocolVersion)
724+
}
710725
if sessionID != "" {
711-
req.Header.Set("Mcp-Session-Id", sessionID)
726+
req.Header.Set(sessionIDHeader, sessionID)
712727
}
713728
req.Header.Set("Content-Type", "application/json")
714729
req.Header.Set("Accept", "application/json, text/event-stream")
@@ -724,7 +739,7 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
724739
return "", fmt.Errorf("broken session: %v", resp.Status)
725740
}
726741

727-
sessionID = resp.Header.Get("Mcp-Session-Id")
742+
sessionID = resp.Header.Get(sessionIDHeader)
728743
if resp.Header.Get("Content-Type") == "text/event-stream" {
729744
go s.handleSSE(resp)
730745
} else {
@@ -763,7 +778,11 @@ func (s *streamableClientConn) Close() error {
763778
if err != nil {
764779
s.closeErr = err
765780
} else {
766-
req.Header.Set("Mcp-Session-Id", s._sessionID)
781+
// TODO(jba): confirm that we don't need a lock here, or add locking.
782+
if s.protocolVersion != "" {
783+
req.Header.Set(protocolVersionHeader, s.protocolVersion)
784+
}
785+
req.Header.Set(sessionIDHeader, s._sessionID)
767786
if _, err := s.client.Do(req); err != nil {
768787
s.closeErr = err
769788
}

mcp/streamable_test.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ func TestStreamableTransports(t *testing.T) {
3636
// 2. Start an httptest.Server with the StreamableHTTPHandler, wrapped in a
3737
// cookie-checking middleware.
3838
handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil)
39+
var header http.Header
3940
httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
41+
header = r.Header
4042
cookie, err := r.Cookie("test-cookie")
4143
if err != nil {
4244
t.Errorf("missing cookie: %v", err)
@@ -72,6 +74,9 @@ func TestStreamableTransports(t *testing.T) {
7274
if sid == "" {
7375
t.Error("empty session ID")
7476
}
77+
if g, w := session.mcpConn.(*streamableClientConn).protocolVersion, latestProtocolVersion; g != w {
78+
t.Fatalf("got protocol version %q, want %q", g, w)
79+
}
7580
// 4. The client calls the "greet" tool.
7681
params := &CallToolParams{
7782
Name: "greet",
@@ -84,6 +89,9 @@ func TestStreamableTransports(t *testing.T) {
8489
if g := session.ID(); g != sid {
8590
t.Errorf("session ID: got %q, want %q", g, sid)
8691
}
92+
if g, w := header.Get(protocolVersionHeader), latestProtocolVersion; g != w {
93+
t.Errorf("got protocol version header %q, want %q", g, w)
94+
}
8795

8896
// 5. Verify that the correct response is received.
8997
want := &CallToolResult{
@@ -154,7 +162,7 @@ func TestStreamableServerTransport(t *testing.T) {
154162
Resources: &resourceCapabilities{ListChanged: true},
155163
Tools: &toolCapabilities{ListChanged: true},
156164
},
157-
ProtocolVersion: "2025-03-26",
165+
ProtocolVersion: latestProtocolVersion,
158166
ServerInfo: &implementation{Name: "testServer", Version: "v1.0.0"},
159167
}, nil)
160168
initializedMsg := req(0, "initialized", &InitializedParams{})

mcp/transport.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@ type Connection interface {
5353
SessionID() string
5454
}
5555

56+
// An httpConnection is a [Connection] that runs over HTTP.
57+
type httpConnection interface {
58+
Connection
59+
setProtocolVersion(string)
60+
}
61+
5662
// A StdioTransport is a [Transport] that communicates over stdin/stdout using
5763
// newline-delimited JSON.
5864
type StdioTransport struct {

0 commit comments

Comments
 (0)