Skip to content

mcp: add resource subscriptions #138

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions design/design.md
Original file line number Diff line number Diff line change
Expand Up @@ -748,13 +748,26 @@ Server sessions also support the spec methods `ListResources` and `ListResourceT

#### Subscriptions

ClientSessions can manage change notifications on particular resources:
##### Client-Side Usage

Use the Subscribe and Unsubscribe methods on a ClientSession to start or stop receiving updates for a specific resource.

```go
func (*ClientSession) Subscribe(context.Context, *SubscribeParams) error
func (*ClientSession) Unsubscribe(context.Context, *UnsubscribeParams) error
```

To process incoming update notifications, you must provide a ResourceUpdatedHandler in your ClientOptions. The SDK calls this function automatically whenever the server sends a notification for a resource you're subscribed to.

```go
type ClientOptions struct {
...
ResourceUpdatedHandler func(context.Context, *ClientSession, *ResourceUpdatedNotificationParams)
}
```

##### Server-Side Implementation

The server does not implement resource subscriptions. It passes along subscription requests to the user, and supplies a method to notify clients of changes. It tracks which sessions have subscribed to which resources so the user doesn't have to.

If a server author wants to support resource subscriptions, they must provide handlers to be called when clients subscribe and unsubscribe. It is an error to provide only one of these handlers.
Expand All @@ -772,7 +785,7 @@ type ServerOptions struct {
User code should call `ResourceUpdated` when a subscribed resource changes.

```go
func (*Server) ResourceUpdated(context.Context, *ResourceUpdatedNotification) error
func (*Server) ResourceUpdated(context.Context, *ResourceUpdatedNotificationParams) error
```

The server routes these notifications to the server sessions that subscribed to the resource.
Expand Down
20 changes: 20 additions & 0 deletions mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ type ClientOptions struct {
ToolListChangedHandler func(context.Context, *ClientSession, *ToolListChangedParams)
PromptListChangedHandler func(context.Context, *ClientSession, *PromptListChangedParams)
ResourceListChangedHandler func(context.Context, *ClientSession, *ResourceListChangedParams)
ResourceUpdatedHandler func(context.Context, *ClientSession, *ResourceUpdatedNotificationParams)
LoggingMessageHandler func(context.Context, *ClientSession, *LoggingMessageParams)
ProgressNotificationHandler func(context.Context, *ClientSession, *ProgressNotificationParams)
// If non-zero, defines an interval for regular "ping" requests.
Expand Down Expand Up @@ -293,6 +294,7 @@ var clientMethodInfos = map[string]methodInfo{
notificationToolListChanged: newMethodInfo(clientMethod((*Client).callToolChangedHandler)),
notificationPromptListChanged: newMethodInfo(clientMethod((*Client).callPromptChangedHandler)),
notificationResourceListChanged: newMethodInfo(clientMethod((*Client).callResourceChangedHandler)),
notificationResourceUpdated: newMethodInfo(clientMethod((*Client).callResourceUpdatedHandler)),
notificationLoggingMessage: newMethodInfo(clientMethod((*Client).callLoggingHandler)),
notificationProgress: newMethodInfo(sessionMethod((*ClientSession).callProgressNotificationHandler)),
}
Expand Down Expand Up @@ -386,6 +388,20 @@ func (cs *ClientSession) Complete(ctx context.Context, params *CompleteParams) (
return handleSend[*CompleteResult](ctx, cs, methodComplete, orZero[Params](params))
}

// Subscribe sends a "resources/subscribe" request to the server, asking for
// notifications when the specified resource changes.
func (cs *ClientSession) Subscribe(ctx context.Context, params *SubscribeParams) error {
_, err := handleSend[*emptyResult](ctx, cs, methodSubscribe, orZero[Params](params))
return err
}

// Unsubscribe sends a "resources/unsubscribe" request to the server, cancelling
// a previous subscription.
func (cs *ClientSession) Unsubscribe(ctx context.Context, params *UnsubscribeParams) error {
_, err := handleSend[*emptyResult](ctx, cs, methodUnsubscribe, orZero[Params](params))
return err
}

func (c *Client) callToolChangedHandler(ctx context.Context, s *ClientSession, params *ToolListChangedParams) (Result, error) {
return callNotificationHandler(ctx, c.opts.ToolListChangedHandler, s, params)
}
Expand All @@ -398,6 +414,10 @@ func (c *Client) callResourceChangedHandler(ctx context.Context, s *ClientSessio
return callNotificationHandler(ctx, c.opts.ResourceListChangedHandler, s, params)
}

func (c *Client) callResourceUpdatedHandler(ctx context.Context, s *ClientSession, params *ResourceUpdatedNotificationParams) (Result, error) {
return callNotificationHandler(ctx, c.opts.ResourceUpdatedHandler, s, params)
}

func (c *Client) callLoggingHandler(ctx context.Context, cs *ClientSession, params *LoggingMessageParams) (Result, error) {
if h := c.opts.LoggingMessageHandler; h != nil {
h(ctx, cs, params)
Expand Down
44 changes: 43 additions & 1 deletion mcp/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func TestEndToEnd(t *testing.T) {

// Channels to check if notification callbacks happened.
notificationChans := map[string]chan int{}
for _, name := range []string{"initialized", "roots", "tools", "prompts", "resources", "progress_server", "progress_client"} {
for _, name := range []string{"initialized", "roots", "tools", "prompts", "resources", "progress_server", "progress_client", "resource_updated", "subscribe", "unsubscribe"} {
notificationChans[name] = make(chan int, 1)
}
waitForNotification := func(t *testing.T, name string) {
Expand All @@ -78,6 +78,14 @@ func TestEndToEnd(t *testing.T) {
ProgressNotificationHandler: func(context.Context, *ServerSession, *ProgressNotificationParams) {
notificationChans["progress_server"] <- 0
},
SubscribeHandler: func(context.Context, *SubscribeParams) error {
notificationChans["subscribe"] <- 0
return nil
},
UnsubscribeHandler: func(context.Context, *UnsubscribeParams) error {
notificationChans["unsubscribe"] <- 0
return nil
},
}
s := NewServer(testImpl, sopts)
AddTool(s, &Tool{
Expand Down Expand Up @@ -128,6 +136,9 @@ func TestEndToEnd(t *testing.T) {
ProgressNotificationHandler: func(context.Context, *ClientSession, *ProgressNotificationParams) {
notificationChans["progress_client"] <- 0
},
ResourceUpdatedHandler: func(context.Context, *ClientSession, *ResourceUpdatedNotificationParams) {
notificationChans["resource_updated"] <- 0
},
}
c := NewClient(testImpl, opts)
rootAbs, err := filepath.Abs(filepath.FromSlash("testdata/files"))
Expand Down Expand Up @@ -421,6 +432,37 @@ func TestEndToEnd(t *testing.T) {
waitForNotification(t, "progress_server")
})

t.Run("resource_subscriptions", func(t *testing.T) {
err := cs.Subscribe(ctx, &SubscribeParams{
URI: "test",
})
if err != nil {
t.Fatal(err)
}
waitForNotification(t, "subscribe")
s.ResourceUpdated(ctx, &ResourceUpdatedNotificationParams{
URI: "test",
})
waitForNotification(t, "resource_updated")
err = cs.Unsubscribe(ctx, &UnsubscribeParams{
URI: "test",
})
if err != nil {
t.Fatal(err)
}
waitForNotification(t, "unsubscribe")

// Verify the client does not receive the update after unsubscribing.
s.ResourceUpdated(ctx, &ResourceUpdatedNotificationParams{
URI: "test",
})
select {
case <-notificationChans["resource_updated"]:
t.Fatalf("resource updated after unsubscription")
case <-time.After(time.Second):
}
})

// Disconnect.
cs.Close()
clientWG.Wait()
Expand Down
32 changes: 32 additions & 0 deletions mcp/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,38 @@ type ToolListChangedParams struct {
func (x *ToolListChangedParams) GetProgressToken() any { return getProgressToken(x) }
func (x *ToolListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) }

// Sent from the client to request resources/updated notifications from the
// server whenever a particular resource changes.
type SubscribeParams struct {
// This property is reserved by the protocol to allow clients and servers to
// attach additional metadata to their responses.
Meta `json:"_meta,omitempty"`
// The URI of the resource to subscribe to.
URI string `json:"uri"`
}

// Sent from the client to request cancellation of resources/updated
// notifications from the server. This should follow a previous
// resources/subscribe request.
type UnsubscribeParams struct {
// This property is reserved by the protocol to allow clients and servers to
// attach additional metadata to their responses.
Meta `json:"_meta,omitempty"`
// The URI of the resource to unsubscribe from.
URI string `json:"uri"`
}

// A notification from the server to the client, informing it that a resource
// has changed and may need to be read again. This should only be sent if the
// client previously sent a resources/subscribe request.
type ResourceUpdatedNotificationParams struct {
// This property is reserved by the protocol to allow clients and servers to
// attach additional metadata to their responses.
Meta `json:"_meta,omitempty"`
// The URI of the resource that has been updated. This might be a sub-resource of the one that the client actually subscribed to.
URI string `json:"uri"`
}

// TODO(jba): add CompleteRequest and related types.

// TODO(jba): add ElicitRequest and related types.
Expand Down
74 changes: 74 additions & 0 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ type Server struct {
sessions []*ServerSession
sendingMethodHandler_ MethodHandler[*ServerSession]
receivingMethodHandler_ MethodHandler[*ServerSession]
resourceSubscriptions map[string][]*ServerSession // uri -> session
}

// ServerOptions is used to configure behavior of the server.
Expand All @@ -64,6 +65,10 @@ type ServerOptions struct {
// If the peer fails to respond to pings originating from the keepalive check,
// the session is automatically closed.
KeepAlive time.Duration
// Function called when a client session subscribes to a resource.
SubscribeHandler func(context.Context, *SubscribeParams) error
// Function called when a client session unsubscribes from a resource.
UnsubscribeHandler func(context.Context, *UnsubscribeParams) error
}

// NewServer creates a new MCP server. The resulting server has no features:
Expand All @@ -88,6 +93,12 @@ func NewServer(impl *Implementation, opts *ServerOptions) *Server {
if opts.PageSize == 0 {
opts.PageSize = DefaultPageSize
}
if opts.SubscribeHandler != nil && opts.UnsubscribeHandler == nil {
panic("SubscribeHandler requires UnsubscribeHandler")
}
if opts.UnsubscribeHandler != nil && opts.SubscribeHandler == nil {
panic("UnsubscribeHandler requires SubscribeHandler")
}
return &Server{
impl: impl,
opts: *opts,
Expand All @@ -97,6 +108,7 @@ func NewServer(impl *Implementation, opts *ServerOptions) *Server {
resourceTemplates: newFeatureSet(func(t *serverResourceTemplate) string { return t.resourceTemplate.URITemplate }),
sendingMethodHandler_: defaultSendingMethodHandler[*ServerSession],
receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession],
resourceSubscriptions: make(map[string][]*ServerSession),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this would be a little nicer if it was a map[string]map[string]bool. Adding and deleting would be trivial.

}
}

Expand Down Expand Up @@ -224,6 +236,12 @@ func (s *Server) capabilities() *serverCapabilities {
if s.resources.len() > 0 || s.resourceTemplates.len() > 0 {
caps.Resources = &resourceCapabilities{ListChanged: true}
}
if s.opts.SubscribeHandler != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm.
I think that if there are no resources, it doesn't matter if there's a subscribe handler. There is no resource capability.
It's OK to provide a handler, it just doesn't do anything.

if caps.Resources == nil {
caps.Resources = &resourceCapabilities{}
}
caps.Resources.Subscribe = true
}
return caps
}

Expand Down Expand Up @@ -426,6 +444,55 @@ func fileResourceHandler(dir string) ResourceHandler {
}
}

func (s *Server) ResourceUpdated(ctx context.Context, params *ResourceUpdatedNotificationParams) error {
s.mu.Lock()
sessions := slices.Clone(s.resourceSubscriptions[params.URI])
s.mu.Unlock()
if len(sessions) == 0 {
return nil
}
notifySessions(sessions, notificationResourceUpdated, params)
return nil
}

func (s *Server) subscribe(ctx context.Context, ss *ServerSession, params *SubscribeParams) (*emptyResult, error) {
if s.opts.SubscribeHandler == nil {
return nil, fmt.Errorf("%w: server does not support resource subscriptions", jsonrpc2.ErrMethodNotFound)
}
if err := s.opts.SubscribeHandler(ctx, params); err != nil {
return nil, err
}
s.mu.Lock()
defer s.mu.Unlock()
uri := params.URI
subscribers := s.resourceSubscriptions[uri]
if !slices.Contains(subscribers, ss) {
s.resourceSubscriptions[uri] = append(subscribers, ss)
}
return &emptyResult{}, nil
}

func (s *Server) unsubscribe(ctx context.Context, ss *ServerSession, params *UnsubscribeParams) (*emptyResult, error) {
if s.opts.UnsubscribeHandler == nil {
return nil, jsonrpc2.ErrMethodNotFound
}

if err := s.opts.UnsubscribeHandler(ctx, params); err != nil {
return nil, err
}

s.mu.Lock()
defer s.mu.Unlock()

uri := params.URI
if sessions, ok := s.resourceSubscriptions[uri]; ok {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might as well skip the check and just call DeleteFunc.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(If you prefer to use a slice.)

s.resourceSubscriptions[uri] = slices.DeleteFunc(sessions, func(s *ServerSession) bool {
return s == ss
})
}
return &emptyResult{}, nil
}

// Run runs the server over the given transport, which must be persistent.
//
// Run blocks until the client terminates the connection or the provided
Expand Down Expand Up @@ -473,6 +540,11 @@ func (s *Server) disconnect(cc *ServerSession) {
s.sessions = slices.DeleteFunc(s.sessions, func(cc2 *ServerSession) bool {
return cc2 == cc
})
for uri, sessions := range s.resourceSubscriptions {
s.resourceSubscriptions[uri] = slices.DeleteFunc(sessions, func(cc2 *ServerSession) bool {
return cc2 == cc
})
}
}

// Connect connects the MCP server over the given transport and starts handling
Expand Down Expand Up @@ -614,6 +686,8 @@ var serverMethodInfos = map[string]methodInfo{
methodListResourceTemplates: newMethodInfo(serverMethod((*Server).listResourceTemplates)),
methodReadResource: newMethodInfo(serverMethod((*Server).readResource)),
methodSetLevel: newMethodInfo(sessionMethod((*ServerSession).setLevel)),
methodSubscribe: newMethodInfo(serverMethod((*Server).subscribe)),
methodUnsubscribe: newMethodInfo(serverMethod((*Server).unsubscribe)),

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this ever be run in reality?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are run on the server side when the client calls Subscribe and Unsubscribe.

notificationInitialized: newMethodInfo(serverMethod((*Server).callInitializedHandler)),
notificationRootsListChanged: newMethodInfo(serverMethod((*Server).callRootsListChangedHandler)),
notificationProgress: newMethodInfo(sessionMethod((*ServerSession).callProgressNotificationHandler)),
Expand Down
31 changes: 29 additions & 2 deletions mcp/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package mcp

import (
"context"
"log"
"slices"
"testing"
Expand Down Expand Up @@ -232,6 +233,7 @@ func TestServerCapabilities(t *testing.T) {
testCases := []struct {
name string
configureServer func(s *Server)
serverOpts ServerOptions
wantCapabilities *serverCapabilities
}{
{
Expand Down Expand Up @@ -275,6 +277,23 @@ func TestServerCapabilities(t *testing.T) {
Resources: &resourceCapabilities{ListChanged: true},
},
},
{
name: "With resource subscriptions",
configureServer: func(s *Server) {},
serverOpts: ServerOptions{
SubscribeHandler: func(ctx context.Context, sp *SubscribeParams) error {
return nil
},
UnsubscribeHandler: func(ctx context.Context, up *UnsubscribeParams) error {
return nil
},
},
wantCapabilities: &serverCapabilities{
Completions: &completionCapabilities{},
Logging: &loggingCapabilities{},
Resources: &resourceCapabilities{Subscribe: true},
},
},
{
name: "With tools",
configureServer: func(s *Server) {
Expand All @@ -294,19 +313,27 @@ func TestServerCapabilities(t *testing.T) {
s.AddResourceTemplate(&ResourceTemplate{URITemplate: "file:///rt"}, nil)
s.AddTool(&Tool{Name: "t"}, nil)
},
serverOpts: ServerOptions{
SubscribeHandler: func(ctx context.Context, sp *SubscribeParams) error {
return nil
},
UnsubscribeHandler: func(ctx context.Context, up *UnsubscribeParams) error {
return nil
},
},
wantCapabilities: &serverCapabilities{
Completions: &completionCapabilities{},
Logging: &loggingCapabilities{},
Prompts: &promptCapabilities{ListChanged: true},
Resources: &resourceCapabilities{ListChanged: true},
Resources: &resourceCapabilities{ListChanged: true, Subscribe: true},
Tools: &toolCapabilities{ListChanged: true},
},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
server := NewServer(testImpl, nil)
server := NewServer(testImpl, &tc.serverOpts)
tc.configureServer(server)
gotCapabilities := server.capabilities()
if diff := cmp.Diff(tc.wantCapabilities, gotCapabilities); diff != "" {
Expand Down
Loading