diff --git a/design/design.md b/design/design.md index 8ab7c15..9c08b60 100644 --- a/design/design.md +++ b/design/design.md @@ -748,13 +748,26 @@ Server sessions also support the spec methods `ListResources` and `ListResourceT #### Subscriptions -ClientSessions can manage change notifications on particular resources: +##### Client-Side Usage + +Use the Subscribe and Unsubscribe methods on a ClientSession to start or stop receiving updates for a specific resource. ```go func (*ClientSession) Subscribe(context.Context, *SubscribeParams) error func (*ClientSession) Unsubscribe(context.Context, *UnsubscribeParams) error ``` +To process incoming update notifications, you must provide a ResourceUpdatedHandler in your ClientOptions. The SDK calls this function automatically whenever the server sends a notification for a resource you're subscribed to. + +```go +type ClientOptions struct { + ... + ResourceUpdatedHandler func(context.Context, *ClientSession, *ResourceUpdatedNotificationParams) +} +``` + +##### Server-Side Implementation + The server does not implement resource subscriptions. It passes along subscription requests to the user, and supplies a method to notify clients of changes. It tracks which sessions have subscribed to which resources so the user doesn't have to. If a server author wants to support resource subscriptions, they must provide handlers to be called when clients subscribe and unsubscribe. It is an error to provide only one of these handlers. @@ -772,7 +785,7 @@ type ServerOptions struct { User code should call `ResourceUpdated` when a subscribed resource changes. ```go -func (*Server) ResourceUpdated(context.Context, *ResourceUpdatedNotification) error +func (*Server) ResourceUpdated(context.Context, *ResourceUpdatedNotificationParams) error ``` The server routes these notifications to the server sessions that subscribed to the resource. diff --git a/mcp/client.go b/mcp/client.go index b48ad7a..b386294 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -60,6 +60,7 @@ type ClientOptions struct { ToolListChangedHandler func(context.Context, *ClientSession, *ToolListChangedParams) PromptListChangedHandler func(context.Context, *ClientSession, *PromptListChangedParams) ResourceListChangedHandler func(context.Context, *ClientSession, *ResourceListChangedParams) + ResourceUpdatedHandler func(context.Context, *ClientSession, *ResourceUpdatedNotificationParams) LoggingMessageHandler func(context.Context, *ClientSession, *LoggingMessageParams) ProgressNotificationHandler func(context.Context, *ClientSession, *ProgressNotificationParams) // If non-zero, defines an interval for regular "ping" requests. @@ -293,6 +294,7 @@ var clientMethodInfos = map[string]methodInfo{ notificationToolListChanged: newMethodInfo(clientMethod((*Client).callToolChangedHandler)), notificationPromptListChanged: newMethodInfo(clientMethod((*Client).callPromptChangedHandler)), notificationResourceListChanged: newMethodInfo(clientMethod((*Client).callResourceChangedHandler)), + notificationResourceUpdated: newMethodInfo(clientMethod((*Client).callResourceUpdatedHandler)), notificationLoggingMessage: newMethodInfo(clientMethod((*Client).callLoggingHandler)), notificationProgress: newMethodInfo(sessionMethod((*ClientSession).callProgressNotificationHandler)), } @@ -386,6 +388,20 @@ func (cs *ClientSession) Complete(ctx context.Context, params *CompleteParams) ( return handleSend[*CompleteResult](ctx, cs, methodComplete, orZero[Params](params)) } +// Subscribe sends a "resources/subscribe" request to the server, asking for +// notifications when the specified resource changes. +func (cs *ClientSession) Subscribe(ctx context.Context, params *SubscribeParams) error { + _, err := handleSend[*emptyResult](ctx, cs, methodSubscribe, orZero[Params](params)) + return err +} + +// Unsubscribe sends a "resources/unsubscribe" request to the server, cancelling +// a previous subscription. +func (cs *ClientSession) Unsubscribe(ctx context.Context, params *UnsubscribeParams) error { + _, err := handleSend[*emptyResult](ctx, cs, methodUnsubscribe, orZero[Params](params)) + return err +} + func (c *Client) callToolChangedHandler(ctx context.Context, s *ClientSession, params *ToolListChangedParams) (Result, error) { return callNotificationHandler(ctx, c.opts.ToolListChangedHandler, s, params) } @@ -398,6 +414,10 @@ func (c *Client) callResourceChangedHandler(ctx context.Context, s *ClientSessio return callNotificationHandler(ctx, c.opts.ResourceListChangedHandler, s, params) } +func (c *Client) callResourceUpdatedHandler(ctx context.Context, s *ClientSession, params *ResourceUpdatedNotificationParams) (Result, error) { + return callNotificationHandler(ctx, c.opts.ResourceUpdatedHandler, s, params) +} + func (c *Client) callLoggingHandler(ctx context.Context, cs *ClientSession, params *LoggingMessageParams) (Result, error) { if h := c.opts.LoggingMessageHandler; h != nil { h(ctx, cs, params) diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 7da2b85..032181a 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -60,7 +60,7 @@ func TestEndToEnd(t *testing.T) { // Channels to check if notification callbacks happened. notificationChans := map[string]chan int{} - for _, name := range []string{"initialized", "roots", "tools", "prompts", "resources", "progress_server", "progress_client"} { + for _, name := range []string{"initialized", "roots", "tools", "prompts", "resources", "progress_server", "progress_client", "resource_updated", "subscribe", "unsubscribe"} { notificationChans[name] = make(chan int, 1) } waitForNotification := func(t *testing.T, name string) { @@ -78,6 +78,14 @@ func TestEndToEnd(t *testing.T) { ProgressNotificationHandler: func(context.Context, *ServerSession, *ProgressNotificationParams) { notificationChans["progress_server"] <- 0 }, + SubscribeHandler: func(context.Context, *SubscribeParams) error { + notificationChans["subscribe"] <- 0 + return nil + }, + UnsubscribeHandler: func(context.Context, *UnsubscribeParams) error { + notificationChans["unsubscribe"] <- 0 + return nil + }, } s := NewServer(testImpl, sopts) AddTool(s, &Tool{ @@ -128,6 +136,9 @@ func TestEndToEnd(t *testing.T) { ProgressNotificationHandler: func(context.Context, *ClientSession, *ProgressNotificationParams) { notificationChans["progress_client"] <- 0 }, + ResourceUpdatedHandler: func(context.Context, *ClientSession, *ResourceUpdatedNotificationParams) { + notificationChans["resource_updated"] <- 0 + }, } c := NewClient(testImpl, opts) rootAbs, err := filepath.Abs(filepath.FromSlash("testdata/files")) @@ -421,6 +432,37 @@ func TestEndToEnd(t *testing.T) { waitForNotification(t, "progress_server") }) + t.Run("resource_subscriptions", func(t *testing.T) { + err := cs.Subscribe(ctx, &SubscribeParams{ + URI: "test", + }) + if err != nil { + t.Fatal(err) + } + waitForNotification(t, "subscribe") + s.ResourceUpdated(ctx, &ResourceUpdatedNotificationParams{ + URI: "test", + }) + waitForNotification(t, "resource_updated") + err = cs.Unsubscribe(ctx, &UnsubscribeParams{ + URI: "test", + }) + if err != nil { + t.Fatal(err) + } + waitForNotification(t, "unsubscribe") + + // Verify the client does not receive the update after unsubscribing. + s.ResourceUpdated(ctx, &ResourceUpdatedNotificationParams{ + URI: "test", + }) + select { + case <-notificationChans["resource_updated"]: + t.Fatalf("resource updated after unsubscription") + case <-time.After(time.Second): + } + }) + // Disconnect. cs.Close() clientWG.Wait() diff --git a/mcp/protocol.go b/mcp/protocol.go index 4f47c96..00dcd14 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -859,6 +859,38 @@ type ToolListChangedParams struct { func (x *ToolListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *ToolListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } +// Sent from the client to request resources/updated notifications from the +// server whenever a particular resource changes. +type SubscribeParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The URI of the resource to subscribe to. + URI string `json:"uri"` +} + +// Sent from the client to request cancellation of resources/updated +// notifications from the server. This should follow a previous +// resources/subscribe request. +type UnsubscribeParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The URI of the resource to unsubscribe from. + URI string `json:"uri"` +} + +// A notification from the server to the client, informing it that a resource +// has changed and may need to be read again. This should only be sent if the +// client previously sent a resources/subscribe request. +type ResourceUpdatedNotificationParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The URI of the resource that has been updated. This might be a sub-resource of the one that the client actually subscribed to. + URI string `json:"uri"` +} + // TODO(jba): add CompleteRequest and related types. // TODO(jba): add ElicitRequest and related types. diff --git a/mcp/server.go b/mcp/server.go index 6b287ad..c4113e7 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -43,6 +43,7 @@ type Server struct { sessions []*ServerSession sendingMethodHandler_ MethodHandler[*ServerSession] receivingMethodHandler_ MethodHandler[*ServerSession] + resourceSubscriptions map[string][]*ServerSession // uri -> session } // ServerOptions is used to configure behavior of the server. @@ -64,6 +65,10 @@ type ServerOptions struct { // If the peer fails to respond to pings originating from the keepalive check, // the session is automatically closed. KeepAlive time.Duration + // Function called when a client session subscribes to a resource. + SubscribeHandler func(context.Context, *SubscribeParams) error + // Function called when a client session unsubscribes from a resource. + UnsubscribeHandler func(context.Context, *UnsubscribeParams) error } // NewServer creates a new MCP server. The resulting server has no features: @@ -88,6 +93,12 @@ func NewServer(impl *Implementation, opts *ServerOptions) *Server { if opts.PageSize == 0 { opts.PageSize = DefaultPageSize } + if opts.SubscribeHandler != nil && opts.UnsubscribeHandler == nil { + panic("SubscribeHandler requires UnsubscribeHandler") + } + if opts.UnsubscribeHandler != nil && opts.SubscribeHandler == nil { + panic("UnsubscribeHandler requires SubscribeHandler") + } return &Server{ impl: impl, opts: *opts, @@ -97,6 +108,7 @@ func NewServer(impl *Implementation, opts *ServerOptions) *Server { resourceTemplates: newFeatureSet(func(t *serverResourceTemplate) string { return t.resourceTemplate.URITemplate }), sendingMethodHandler_: defaultSendingMethodHandler[*ServerSession], receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession], + resourceSubscriptions: make(map[string][]*ServerSession), } } @@ -224,6 +236,12 @@ func (s *Server) capabilities() *serverCapabilities { if s.resources.len() > 0 || s.resourceTemplates.len() > 0 { caps.Resources = &resourceCapabilities{ListChanged: true} } + if s.opts.SubscribeHandler != nil { + if caps.Resources == nil { + caps.Resources = &resourceCapabilities{} + } + caps.Resources.Subscribe = true + } return caps } @@ -426,6 +444,55 @@ func fileResourceHandler(dir string) ResourceHandler { } } +func (s *Server) ResourceUpdated(ctx context.Context, params *ResourceUpdatedNotificationParams) error { + s.mu.Lock() + sessions := slices.Clone(s.resourceSubscriptions[params.URI]) + s.mu.Unlock() + if len(sessions) == 0 { + return nil + } + notifySessions(sessions, notificationResourceUpdated, params) + return nil +} + +func (s *Server) subscribe(ctx context.Context, ss *ServerSession, params *SubscribeParams) (*emptyResult, error) { + if s.opts.SubscribeHandler == nil { + return nil, fmt.Errorf("%w: server does not support resource subscriptions", jsonrpc2.ErrMethodNotFound) + } + if err := s.opts.SubscribeHandler(ctx, params); err != nil { + return nil, err + } + s.mu.Lock() + defer s.mu.Unlock() + uri := params.URI + subscribers := s.resourceSubscriptions[uri] + if !slices.Contains(subscribers, ss) { + s.resourceSubscriptions[uri] = append(subscribers, ss) + } + return &emptyResult{}, nil +} + +func (s *Server) unsubscribe(ctx context.Context, ss *ServerSession, params *UnsubscribeParams) (*emptyResult, error) { + if s.opts.UnsubscribeHandler == nil { + return nil, jsonrpc2.ErrMethodNotFound + } + + if err := s.opts.UnsubscribeHandler(ctx, params); err != nil { + return nil, err + } + + s.mu.Lock() + defer s.mu.Unlock() + + uri := params.URI + if sessions, ok := s.resourceSubscriptions[uri]; ok { + s.resourceSubscriptions[uri] = slices.DeleteFunc(sessions, func(s *ServerSession) bool { + return s == ss + }) + } + return &emptyResult{}, nil +} + // Run runs the server over the given transport, which must be persistent. // // Run blocks until the client terminates the connection or the provided @@ -473,6 +540,11 @@ func (s *Server) disconnect(cc *ServerSession) { s.sessions = slices.DeleteFunc(s.sessions, func(cc2 *ServerSession) bool { return cc2 == cc }) + for uri, sessions := range s.resourceSubscriptions { + s.resourceSubscriptions[uri] = slices.DeleteFunc(sessions, func(cc2 *ServerSession) bool { + return cc2 == cc + }) + } } // Connect connects the MCP server over the given transport and starts handling @@ -614,6 +686,8 @@ var serverMethodInfos = map[string]methodInfo{ methodListResourceTemplates: newMethodInfo(serverMethod((*Server).listResourceTemplates)), methodReadResource: newMethodInfo(serverMethod((*Server).readResource)), methodSetLevel: newMethodInfo(sessionMethod((*ServerSession).setLevel)), + methodSubscribe: newMethodInfo(serverMethod((*Server).subscribe)), + methodUnsubscribe: newMethodInfo(serverMethod((*Server).unsubscribe)), notificationInitialized: newMethodInfo(serverMethod((*Server).callInitializedHandler)), notificationRootsListChanged: newMethodInfo(serverMethod((*Server).callRootsListChangedHandler)), notificationProgress: newMethodInfo(sessionMethod((*ServerSession).callProgressNotificationHandler)), diff --git a/mcp/server_test.go b/mcp/server_test.go index d4243d7..7ca4b92 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -5,6 +5,7 @@ package mcp import ( + "context" "log" "slices" "testing" @@ -232,6 +233,7 @@ func TestServerCapabilities(t *testing.T) { testCases := []struct { name string configureServer func(s *Server) + serverOpts ServerOptions wantCapabilities *serverCapabilities }{ { @@ -275,6 +277,23 @@ func TestServerCapabilities(t *testing.T) { Resources: &resourceCapabilities{ListChanged: true}, }, }, + { + name: "With resource subscriptions", + configureServer: func(s *Server) {}, + serverOpts: ServerOptions{ + SubscribeHandler: func(ctx context.Context, sp *SubscribeParams) error { + return nil + }, + UnsubscribeHandler: func(ctx context.Context, up *UnsubscribeParams) error { + return nil + }, + }, + wantCapabilities: &serverCapabilities{ + Completions: &completionCapabilities{}, + Logging: &loggingCapabilities{}, + Resources: &resourceCapabilities{Subscribe: true}, + }, + }, { name: "With tools", configureServer: func(s *Server) { @@ -294,11 +313,19 @@ func TestServerCapabilities(t *testing.T) { s.AddResourceTemplate(&ResourceTemplate{URITemplate: "file:///rt"}, nil) s.AddTool(&Tool{Name: "t"}, nil) }, + serverOpts: ServerOptions{ + SubscribeHandler: func(ctx context.Context, sp *SubscribeParams) error { + return nil + }, + UnsubscribeHandler: func(ctx context.Context, up *UnsubscribeParams) error { + return nil + }, + }, wantCapabilities: &serverCapabilities{ Completions: &completionCapabilities{}, Logging: &loggingCapabilities{}, Prompts: &promptCapabilities{ListChanged: true}, - Resources: &resourceCapabilities{ListChanged: true}, + Resources: &resourceCapabilities{ListChanged: true, Subscribe: true}, Tools: &toolCapabilities{ListChanged: true}, }, }, @@ -306,7 +333,7 @@ func TestServerCapabilities(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - server := NewServer(testImpl, nil) + server := NewServer(testImpl, &tc.serverOpts) tc.configureServer(server) gotCapabilities := server.capabilities() if diff := cmp.Diff(tc.wantCapabilities, gotCapabilities); diff != "" {