Skip to content
130 changes: 90 additions & 40 deletions src/client/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import {
ListResourceTemplatesRequest,
ListResourceTemplatesResultSchema,
ListToolsRequest,
ListToolsResult,
ListToolsResultSchema,
LoggingLevel,
Notification,
Expand Down Expand Up @@ -76,7 +77,7 @@ export type ClientOptions = ProtocolOptions & {
export class Client<
RequestT extends Request = Request,
NotificationT extends Notification = Notification,
ResultT extends Result = Result,
ResultT extends Result = Result
> extends Protocol<
ClientRequest | RequestT,
ClientNotification | NotificationT,
Expand All @@ -87,15 +88,38 @@ export class Client<
private _capabilities: ClientCapabilities;
private _instructions?: string;

/**
* Callback for when the server indicates that the tools list has changed.
* Client should typically refresh its list of tools in response.
*/
onToolListChanged?: (tools?: ListToolsResult["tools"]) => void;

/**
* Initializes this client with the given name and version information.
*/
constructor(
private _clientInfo: Implementation,
options?: ClientOptions,
) {
constructor(private _clientInfo: Implementation, options?: ClientOptions) {
super(options);
this._capabilities = options?.capabilities ?? {};

// Set up notification handlers
this.setNotificationHandler(
"notifications/tools/list_changed",
async () => {
// Automatically refresh the tools list when the server indicates a change
try {
// Only refresh if the server supports tools
if (this._serverCapabilities?.tools) {
const result = await this.listTools();
// Call the user's callback with the updated tools list
this.onToolListChanged?.(result.tools);
}
} catch (error) {
console.error("Failed to refresh tools list:", error);
// Still call the callback even if refresh failed
this.onToolListChanged?.(undefined);
}
}
);
}

/**
Expand All @@ -106,7 +130,7 @@ export class Client<
public registerCapabilities(capabilities: ClientCapabilities): void {
if (this.transport) {
throw new Error(
"Cannot register capabilities after connecting to transport",
"Cannot register capabilities after connecting to transport"
);
}

Expand All @@ -115,11 +139,11 @@ export class Client<

protected assertCapability(
capability: keyof ServerCapabilities,
method: string,
method: string
): void {
if (!this._serverCapabilities?.[capability]) {
throw new Error(
`Server does not support ${capability} (required for ${method})`,
`Server does not support ${String(capability)} (required for ${method})`
);
}
}
Expand All @@ -137,7 +161,7 @@ export class Client<
clientInfo: this._clientInfo,
},
},
InitializeResultSchema,
InitializeResultSchema
);

if (result === undefined) {
Expand All @@ -146,7 +170,7 @@ export class Client<

if (!SUPPORTED_PROTOCOL_VERSIONS.includes(result.protocolVersion)) {
throw new Error(
`Server's protocol version is not supported: ${result.protocolVersion}`,
`Server's protocol version is not supported: ${result.protocolVersion}`
);
}

Expand Down Expand Up @@ -191,7 +215,7 @@ export class Client<
case "logging/setLevel":
if (!this._serverCapabilities?.logging) {
throw new Error(
`Server does not support logging (required for ${method})`,
`Server does not support logging (required for ${method})`
);
}
break;
Expand All @@ -200,7 +224,7 @@ export class Client<
case "prompts/list":
if (!this._serverCapabilities?.prompts) {
throw new Error(
`Server does not support prompts (required for ${method})`,
`Server does not support prompts (required for ${method})`
);
}
break;
Expand All @@ -212,7 +236,7 @@ export class Client<
case "resources/unsubscribe":
if (!this._serverCapabilities?.resources) {
throw new Error(
`Server does not support resources (required for ${method})`,
`Server does not support resources (required for ${method})`
);
}

Expand All @@ -221,7 +245,7 @@ export class Client<
!this._serverCapabilities.resources.subscribe
) {
throw new Error(
`Server does not support resource subscriptions (required for ${method})`,
`Server does not support resource subscriptions (required for ${method})`
);
}

Expand All @@ -231,15 +255,15 @@ export class Client<
case "tools/list":
if (!this._serverCapabilities?.tools) {
throw new Error(
`Server does not support tools (required for ${method})`,
`Server does not support tools (required for ${method})`
);
}
break;

case "completion/complete":
if (!this._serverCapabilities?.prompts) {
throw new Error(
`Server does not support prompts (required for ${method})`,
`Server does not support prompts (required for ${method})`
);
}
break;
Expand All @@ -255,13 +279,23 @@ export class Client<
}

protected assertNotificationCapability(
method: NotificationT["method"],
method: NotificationT["method"]
): void {
switch (method as ClientNotification["method"]) {
case "notifications/roots/list_changed":
if (!this._capabilities.roots?.listChanged) {
throw new Error(
`Client does not support roots list changed notifications (required for ${method})`,
`Client does not support roots list changed notifications (required for ${method})`
);
}
break;

case "notifications/tools/list_changed":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same, client will not send notifications/tools/list_changed

if (!this._capabilities.tools?.listChanged) {
throw new Error(
`Client does not support tools capability (required for ${String(
method
)})`
);
}
break;
Expand All @@ -285,15 +319,15 @@ export class Client<
case "sampling/createMessage":
if (!this._capabilities.sampling) {
throw new Error(
`Client does not support sampling capability (required for ${method})`,
`Client does not support sampling capability (required for ${method})`
);
}
break;

case "roots/list":
if (!this._capabilities.roots) {
throw new Error(
`Client does not support roots capability (required for ${method})`,
`Client does not support roots capability (required for ${method})`
);
}
break;
Expand All @@ -312,92 +346,92 @@ export class Client<
return this.request(
{ method: "completion/complete", params },
CompleteResultSchema,
options,
options
);
}

async setLoggingLevel(level: LoggingLevel, options?: RequestOptions) {
return this.request(
{ method: "logging/setLevel", params: { level } },
EmptyResultSchema,
options,
options
);
}

async getPrompt(
params: GetPromptRequest["params"],
options?: RequestOptions,
options?: RequestOptions
) {
return this.request(
{ method: "prompts/get", params },
GetPromptResultSchema,
options,
options
);
}

async listPrompts(
params?: ListPromptsRequest["params"],
options?: RequestOptions,
options?: RequestOptions
) {
return this.request(
{ method: "prompts/list", params },
ListPromptsResultSchema,
options,
options
);
}

async listResources(
params?: ListResourcesRequest["params"],
options?: RequestOptions,
options?: RequestOptions
) {
return this.request(
{ method: "resources/list", params },
ListResourcesResultSchema,
options,
options
);
}

async listResourceTemplates(
params?: ListResourceTemplatesRequest["params"],
options?: RequestOptions,
options?: RequestOptions
) {
return this.request(
{ method: "resources/templates/list", params },
ListResourceTemplatesResultSchema,
options,
options
);
}

async readResource(
params: ReadResourceRequest["params"],
options?: RequestOptions,
options?: RequestOptions
) {
return this.request(
{ method: "resources/read", params },
ReadResourceResultSchema,
options,
options
);
}

async subscribeResource(
params: SubscribeRequest["params"],
options?: RequestOptions,
options?: RequestOptions
) {
return this.request(
{ method: "resources/subscribe", params },
EmptyResultSchema,
options,
options
);
}

async unsubscribeResource(
params: UnsubscribeRequest["params"],
options?: RequestOptions,
options?: RequestOptions
) {
return this.request(
{ method: "resources/unsubscribe", params },
EmptyResultSchema,
options,
options
);
}

Expand All @@ -406,27 +440,43 @@ export class Client<
resultSchema:
| typeof CallToolResultSchema
| typeof CompatibilityCallToolResultSchema = CallToolResultSchema,
options?: RequestOptions,
options?: RequestOptions
) {
return this.request(
{ method: "tools/call", params },
resultSchema,
options,
options
);
}

async listTools(
params?: ListToolsRequest["params"],
options?: RequestOptions,
options?: RequestOptions
) {
return this.request(
{ method: "tools/list", params },
ListToolsResultSchema,
options,
options
);
}

/**
* Registers a callback to be called when the server indicates that
* the tools list has changed. The callback should typically refresh the tools list.
*
* @param callback Function to call when tools list changes
*/
setToolListChangedCallback(
callback: (tools?: ListToolsResult["tools"]) => void
): void {
this.onToolListChanged = callback;
}

async sendRootsListChanged() {
return this.notification({ method: "notifications/roots/list_changed" });
}

async sendToolListChanged() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this doesn't look right, client should receive the "notifications/tools/list_changed" notification, not send it

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@johnjjung did you have a chance to check this?

return this.notification({ method: "notifications/tools/list_changed" });
}
}