From 353224d22c8c4f428936b852ba36b428c0647f87 Mon Sep 17 00:00:00 2001 From: Jesse Date: Thu, 3 Jul 2025 14:50:02 -0600 Subject: [PATCH 1/8] fix: always try to reconnect on stream disconnection --- src/client/streamableHttp.ts | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/src/client/streamableHttp.ts b/src/client/streamableHttp.ts index b81f1a5d..375dc411 100644 --- a/src/client/streamableHttp.ts +++ b/src/client/streamableHttp.ts @@ -349,18 +349,16 @@ const response = await (this._fetch ?? fetch)(this._url, { // Attempt to reconnect if the stream disconnects unexpectedly and we aren't closing if (this._abortController && !this._abortController.signal.aborted) { // Use the exponential backoff reconnection strategy - if (lastEventId !== undefined) { - try { - this._scheduleReconnection({ - resumptionToken: lastEventId, - onresumptiontoken, - replayMessageId - }, 0); - } - catch (error) { - this.onerror?.(new Error(`Failed to reconnect: ${error instanceof Error ? error.message : String(error)}`)); + try { + this._scheduleReconnection({ + resumptionToken: lastEventId, + onresumptiontoken, + replayMessageId + }, 0); + } + catch (error) { + this.onerror?.(new Error(`Failed to reconnect: ${error instanceof Error ? error.message : String(error)}`)); - } } } } From aec85cdbc955da1ab9448c7f75142b7cda1dd71a Mon Sep 17 00:00:00 2001 From: Jesse Date: Thu, 3 Jul 2025 16:48:23 -0600 Subject: [PATCH 2/8] test: tests should ensure reconnection functionality --- src/client/streamableHttp.test.ts | 79 +++++++++++++++++++++++++++++++ src/client/streamableHttp.ts | 8 ++-- 2 files changed, 83 insertions(+), 4 deletions(-) diff --git a/src/client/streamableHttp.test.ts b/src/client/streamableHttp.test.ts index dcd76528..dfbca8a4 100644 --- a/src/client/streamableHttp.test.ts +++ b/src/client/streamableHttp.test.ts @@ -406,6 +406,7 @@ describe("StreamableHTTPClientTransport", () => { const headers = fetchCall[1].headers; expect(headers.get("last-event-id")).toBe("test-event-id"); }); + it("should throw error when invalid content-type is received", async () => { // Clear any previous state from other tests @@ -592,4 +593,82 @@ describe("StreamableHTTPClientTransport", () => { await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1); }); + + + + it('should reconnect on stream failure even without a lastEventId', async () => { + // This test simulates a scenario where the initial SSE stream fails immediately, + // but the reconnection logic kicks in and successfully reconnects without needing a lastEventId. + + // ARRANGE + + // 1. Configure a transport that will retry quickly and at least once. + transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { + reconnectionOptions: { + initialReconnectionDelay: 10, // Reconnect almost instantly for the test + maxReconnectionDelay: 100, + reconnectionDelayGrowFactor: 1, + maxRetries: 1, // We only need to see one successful retry attempt + } + }); + + const errorSpy = jest.fn(); + transport.onerror = errorSpy; + + // 2. Mock the initial GET request. It will connect, but the stream will die immediately. + // This simulates the GCloud proxy killing the connection. + const failingStream = new ReadableStream({ + start(controller) { + // Simulate an abrupt network error. + controller.error(new Error("Network connection terminated")); + } + }); + + const fetchMock = global.fetch as jest.Mock; + fetchMock.mockResolvedValueOnce({ + ok: true, + status: 200, + headers: new Headers({ "content-type": "text/event-stream" }), + body: failingStream, + }); + + // 3. Mock the SECOND GET request (the reconnection attempt). This one can succeed. + fetchMock.mockResolvedValueOnce({ + ok: true, + status: 200, + headers: new Headers({ "content-type": "text/event-stream" }), + body: new ReadableStream(), // A stable, empty stream + }); + + // ACT + + // 4. Start the transport and initiate the SSE connection. + await transport.start(); + // We call the internal method directly to trigger the GET request. + // This is cleaner than sending a full 'initialize' message for this test. + await transport["_startOrAuthSse"]({}); + + // 5. Advance timers to trigger the setTimeout in _scheduleReconnection. + await jest.advanceTimersByTimeAsync(20); // More than the 10ms delay + + // ASSERT + + // 6. Verify the initial disconnect error was caught. + expect(errorSpy).toHaveBeenCalledTimes(1); + expect(errorSpy).toHaveBeenCalledWith( + expect.objectContaining({ + message: expect.stringContaining('SSE stream disconnected: Error: Network connection terminated'), + }) + ); + + // 7. THIS IS THE KEY ASSERTION: Verify that a second fetch call was made. + // This proves the reconnection logic was triggered. + expect(fetchMock).toHaveBeenCalledTimes(2); + + // 8. Verify the second call was a GET request without a last-event-id header. + const secondCall = fetchMock.mock.calls[1]; + const secondRequest = secondCall[1]; + expect(secondRequest.method).toBe('GET'); + expect(secondRequest.headers.has('last-event-id')).toBe(false); + }); }); diff --git a/src/client/streamableHttp.ts b/src/client/streamableHttp.ts index 375dc411..bce65048 100644 --- a/src/client/streamableHttp.ts +++ b/src/client/streamableHttp.ts @@ -291,10 +291,10 @@ const response = await (this._fetch ?? fetch)(this._url, { // Schedule the reconnection setTimeout(() => { - // Use the last event ID to resume where we left off + this._abortController = new AbortController(); + this._startOrAuthSse(options).catch(error => { this.onerror?.(new Error(`Failed to reconnect SSE stream: ${error instanceof Error ? error.message : String(error)}`)); - // Schedule another attempt if this one failed, incrementing the attempt counter this._scheduleReconnection(options, attemptCount + 1); }); }, delay); @@ -367,12 +367,12 @@ const response = await (this._fetch ?? fetch)(this._url, { } async start() { - if (this._abortController) { + if (this._abortController && !this._abortController.signal.aborted) { // Check if it's already running throw new Error( "StreamableHTTPClientTransport already started! If using Client class, note that connect() calls start() automatically.", ); } - + // Always create a fresh AbortController when starting a new connection sequence. this._abortController = new AbortController(); } From dfbcecfcb128c49e36e088debaa00453d4739d77 Mon Sep 17 00:00:00 2001 From: Jesse Date: Thu, 3 Jul 2025 17:06:09 -0600 Subject: [PATCH 3/8] fix: add jest timers to test --- src/client/streamableHttp.test.ts | 130 +++++++++++++++--------------- src/client/streamableHttp.ts | 3 +- 2 files changed, 67 insertions(+), 66 deletions(-) diff --git a/src/client/streamableHttp.test.ts b/src/client/streamableHttp.test.ts index dfbca8a4..89137abe 100644 --- a/src/client/streamableHttp.test.ts +++ b/src/client/streamableHttp.test.ts @@ -595,80 +595,82 @@ describe("StreamableHTTPClientTransport", () => { }); + describe('Reconnection Logic', () => { + // Use fake timers to control setTimeout and make the test instant. + beforeEach(() => jest.useFakeTimers()); + afterEach(() => jest.useRealTimers()); + + it('should reconnect on stream failure even without a lastEventId', async () => { + // ARRANGE + + // 1. Configure a transport that will retry quickly and at least once. + transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { + reconnectionOptions: { + initialReconnectionDelay: 10, // Reconnect almost instantly for the test + maxReconnectionDelay: 100, + reconnectionDelayGrowFactor: 1, + maxRetries: 1, // We only need to see one successful retry attempt + } + }); - it('should reconnect on stream failure even without a lastEventId', async () => { - // This test simulates a scenario where the initial SSE stream fails immediately, - // but the reconnection logic kicks in and successfully reconnects without needing a lastEventId. - - // ARRANGE - - // 1. Configure a transport that will retry quickly and at least once. - transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { - reconnectionOptions: { - initialReconnectionDelay: 10, // Reconnect almost instantly for the test - maxReconnectionDelay: 100, - reconnectionDelayGrowFactor: 1, - maxRetries: 1, // We only need to see one successful retry attempt - } - }); - - const errorSpy = jest.fn(); - transport.onerror = errorSpy; + const errorSpy = jest.fn(); + transport.onerror = errorSpy; - // 2. Mock the initial GET request. It will connect, but the stream will die immediately. - // This simulates the GCloud proxy killing the connection. - const failingStream = new ReadableStream({ - start(controller) { - // Simulate an abrupt network error. - controller.error(new Error("Network connection terminated")); - } - }); + // 2. Mock the initial GET request. It will connect, but the stream will die immediately. + // This simulates the GCloud proxy killing the connection. + const failingStream = new ReadableStream({ + start(controller) { + // Simulate an abrupt network error. + controller.error(new Error("Network connection terminated")); + } + }); - const fetchMock = global.fetch as jest.Mock; - fetchMock.mockResolvedValueOnce({ - ok: true, - status: 200, - headers: new Headers({ "content-type": "text/event-stream" }), - body: failingStream, - }); + const fetchMock = global.fetch as jest.Mock; + fetchMock.mockResolvedValueOnce({ + ok: true, + status: 200, + headers: new Headers({ "content-type": "text/event-stream" }), + body: failingStream, + }); - // 3. Mock the SECOND GET request (the reconnection attempt). This one can succeed. - fetchMock.mockResolvedValueOnce({ - ok: true, - status: 200, - headers: new Headers({ "content-type": "text/event-stream" }), - body: new ReadableStream(), // A stable, empty stream - }); + // 3. Mock the SECOND GET request (the reconnection attempt). This one can succeed. + fetchMock.mockResolvedValueOnce({ + ok: true, + status: 200, + headers: new Headers({ "content-type": "text/event-stream" }), + body: new ReadableStream(), // A stable, empty stream + }); - // ACT + // ACT - // 4. Start the transport and initiate the SSE connection. - await transport.start(); - // We call the internal method directly to trigger the GET request. - // This is cleaner than sending a full 'initialize' message for this test. - await transport["_startOrAuthSse"]({}); + // 4. Start the transport and initiate the SSE connection. + await transport.start(); + // We call the internal method directly to trigger the GET request. + // This is cleaner than sending a full 'initialize' message for this test. + await transport["_startOrAuthSse"]({}); - // 5. Advance timers to trigger the setTimeout in _scheduleReconnection. - await jest.advanceTimersByTimeAsync(20); // More than the 10ms delay + // 5. Advance timers to trigger the setTimeout in _scheduleReconnection. + await jest.advanceTimersByTimeAsync(20); // More than the 10ms delay - // ASSERT + // ASSERT - // 6. Verify the initial disconnect error was caught. - expect(errorSpy).toHaveBeenCalledTimes(1); - expect(errorSpy).toHaveBeenCalledWith( - expect.objectContaining({ - message: expect.stringContaining('SSE stream disconnected: Error: Network connection terminated'), - }) - ); + // 6. Verify the initial disconnect error was caught. + expect(errorSpy).toHaveBeenCalledTimes(1); + expect(errorSpy).toHaveBeenCalledWith( + expect.objectContaining({ + message: expect.stringContaining('SSE stream disconnected: Error: Network connection terminated'), + }) + ); - // 7. THIS IS THE KEY ASSERTION: Verify that a second fetch call was made. - // This proves the reconnection logic was triggered. - expect(fetchMock).toHaveBeenCalledTimes(2); + // 7. THIS IS THE KEY ASSERTION: Verify that a second fetch call was made. + // This proves the reconnection logic was triggered. + expect(fetchMock).toHaveBeenCalledTimes(2); - // 8. Verify the second call was a GET request without a last-event-id header. - const secondCall = fetchMock.mock.calls[1]; - const secondRequest = secondCall[1]; - expect(secondRequest.method).toBe('GET'); - expect(secondRequest.headers.has('last-event-id')).toBe(false); + // 8. Verify the second call was a GET request without a last-event-id header. + const secondCall = fetchMock.mock.calls[1]; + const secondRequest = secondCall[1]; + expect(secondRequest.method).toBe('GET'); + expect(secondRequest.headers.has('last-event-id')).toBe(false); + }); }); }); diff --git a/src/client/streamableHttp.ts b/src/client/streamableHttp.ts index bce65048..725b0457 100644 --- a/src/client/streamableHttp.ts +++ b/src/client/streamableHttp.ts @@ -367,12 +367,11 @@ const response = await (this._fetch ?? fetch)(this._url, { } async start() { - if (this._abortController && !this._abortController.signal.aborted) { // Check if it's already running + if (this._abortController && !this._abortController.signal.aborted) { throw new Error( "StreamableHTTPClientTransport already started! If using Client class, note that connect() calls start() automatically.", ); } - // Always create a fresh AbortController when starting a new connection sequence. this._abortController = new AbortController(); } From 8bbb18545917032647306318e0ca3b4e1125d06e Mon Sep 17 00:00:00 2001 From: Jesse Date: Thu, 3 Jul 2025 17:09:05 -0600 Subject: [PATCH 4/8] refactor: remove unused edits --- src/client/streamableHttp.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/client/streamableHttp.ts b/src/client/streamableHttp.ts index 725b0457..9b247b27 100644 --- a/src/client/streamableHttp.ts +++ b/src/client/streamableHttp.ts @@ -291,8 +291,7 @@ const response = await (this._fetch ?? fetch)(this._url, { // Schedule the reconnection setTimeout(() => { - this._abortController = new AbortController(); - + // Use the last event ID to resume where we left off this._startOrAuthSse(options).catch(error => { this.onerror?.(new Error(`Failed to reconnect SSE stream: ${error instanceof Error ? error.message : String(error)}`)); this._scheduleReconnection(options, attemptCount + 1); @@ -367,11 +366,12 @@ const response = await (this._fetch ?? fetch)(this._url, { } async start() { - if (this._abortController && !this._abortController.signal.aborted) { + if (this._abortController) { throw new Error( "StreamableHTTPClientTransport already started! If using Client class, note that connect() calls start() automatically.", ); } + this._abortController = new AbortController(); } From ad8264d485c5c30a52b3f9a5a6dd8a4d2b65d481 Mon Sep 17 00:00:00 2001 From: Jesse Date: Thu, 3 Jul 2025 17:10:15 -0600 Subject: [PATCH 5/8] refactor: remove unused edits --- src/client/streamableHttp.ts | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/client/streamableHttp.ts b/src/client/streamableHttp.ts index 9b247b27..9459ea3a 100644 --- a/src/client/streamableHttp.ts +++ b/src/client/streamableHttp.ts @@ -294,6 +294,7 @@ const response = await (this._fetch ?? fetch)(this._url, { // Use the last event ID to resume where we left off this._startOrAuthSse(options).catch(error => { this.onerror?.(new Error(`Failed to reconnect SSE stream: ${error instanceof Error ? error.message : String(error)}`)); + // Schedule another attempt if this one failed, incrementing the attempt counter this._scheduleReconnection(options, attemptCount + 1); }); }, delay); @@ -360,6 +361,7 @@ const response = await (this._fetch ?? fetch)(this._url, { } } + } }; processStream(); From 2145cf288e7500c42e1c980682852d43664b8742 Mon Sep 17 00:00:00 2001 From: Jesse Date: Thu, 3 Jul 2025 17:11:27 -0600 Subject: [PATCH 6/8] refactor: remove unused edits --- src/client/streamableHttp.test.ts | 1 - src/client/streamableHttp.ts | 1 - 2 files changed, 2 deletions(-) diff --git a/src/client/streamableHttp.test.ts b/src/client/streamableHttp.test.ts index 89137abe..92e7fe82 100644 --- a/src/client/streamableHttp.test.ts +++ b/src/client/streamableHttp.test.ts @@ -406,7 +406,6 @@ describe("StreamableHTTPClientTransport", () => { const headers = fetchCall[1].headers; expect(headers.get("last-event-id")).toBe("test-event-id"); }); - it("should throw error when invalid content-type is received", async () => { // Clear any previous state from other tests diff --git a/src/client/streamableHttp.ts b/src/client/streamableHttp.ts index 9459ea3a..375dc411 100644 --- a/src/client/streamableHttp.ts +++ b/src/client/streamableHttp.ts @@ -361,7 +361,6 @@ const response = await (this._fetch ?? fetch)(this._url, { } } - } }; processStream(); From d72e33d46a57bddf14dccbe06c579ede4f391e63 Mon Sep 17 00:00:00 2001 From: Jesse Date: Tue, 8 Jul 2025 07:38:46 -0600 Subject: [PATCH 7/8] fix: only resume GET SSE stream connections --- src/client/streamableHttp.test.ts | 132 ++++++++++++++++++------------ src/client/streamableHttp.ts | 16 +++- 2 files changed, 90 insertions(+), 58 deletions(-) diff --git a/src/client/streamableHttp.test.ts b/src/client/streamableHttp.test.ts index 92e7fe82..c2c48366 100644 --- a/src/client/streamableHttp.test.ts +++ b/src/client/streamableHttp.test.ts @@ -1,6 +1,6 @@ import { StreamableHTTPClientTransport, StreamableHTTPReconnectionOptions, StartSSEOptions } from "./streamableHttp.js"; import { OAuthClientProvider, UnauthorizedError } from "./auth.js"; -import { JSONRPCMessage } from "../types.js"; +import { JSONRPCMessage, JSONRPCRequest } from "../types.js"; describe("StreamableHTTPClientTransport", () => { @@ -592,84 +592,108 @@ describe("StreamableHTTPClientTransport", () => { await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1); }); - - describe('Reconnection Logic', () => { + let transport: StreamableHTTPClientTransport; + // Use fake timers to control setTimeout and make the test instant. beforeEach(() => jest.useFakeTimers()); afterEach(() => jest.useRealTimers()); - - it('should reconnect on stream failure even without a lastEventId', async () => { + + it('should reconnect a GET-initiated notification stream that fails', async () => { // ARRANGE - - // 1. Configure a transport that will retry quickly and at least once. transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { reconnectionOptions: { - initialReconnectionDelay: 10, // Reconnect almost instantly for the test - maxReconnectionDelay: 100, - reconnectionDelayGrowFactor: 1, - maxRetries: 1, // We only need to see one successful retry attempt - } + initialReconnectionDelay: 10, + maxRetries: 1, + maxReconnectionDelay: 1000, // Ensure it doesn't retry indefinitely + reconnectionDelayGrowFactor: 1 // No exponential backoff for simplicity + } }); - + const errorSpy = jest.fn(); transport.onerror = errorSpy; - - // 2. Mock the initial GET request. It will connect, but the stream will die immediately. - // This simulates the GCloud proxy killing the connection. + const failingStream = new ReadableStream({ - start(controller) { - // Simulate an abrupt network error. - controller.error(new Error("Network connection terminated")); - } + start(controller) { controller.error(new Error("Network failure")); } }); - + const fetchMock = global.fetch as jest.Mock; + // Mock the initial GET request, which will fail. fetchMock.mockResolvedValueOnce({ - ok: true, - status: 200, + ok: true, status: 200, headers: new Headers({ "content-type": "text/event-stream" }), body: failingStream, }); - - // 3. Mock the SECOND GET request (the reconnection attempt). This one can succeed. + // Mock the reconnection GET request, which will succeed. fetchMock.mockResolvedValueOnce({ - ok: true, - status: 200, + ok: true, status: 200, headers: new Headers({ "content-type": "text/event-stream" }), - body: new ReadableStream(), // A stable, empty stream + body: new ReadableStream(), }); - + // ACT - - // 4. Start the transport and initiate the SSE connection. await transport.start(); - // We call the internal method directly to trigger the GET request. - // This is cleaner than sending a full 'initialize' message for this test. + // Trigger the GET stream directly using the internal method for a clean test. await transport["_startOrAuthSse"]({}); - - // 5. Advance timers to trigger the setTimeout in _scheduleReconnection. - await jest.advanceTimersByTimeAsync(20); // More than the 10ms delay - + await jest.advanceTimersByTimeAsync(20); // Trigger reconnection timeout + // ASSERT - - // 6. Verify the initial disconnect error was caught. - expect(errorSpy).toHaveBeenCalledTimes(1); - expect(errorSpy).toHaveBeenCalledWith( - expect.objectContaining({ - message: expect.stringContaining('SSE stream disconnected: Error: Network connection terminated'), - }) - ); - - // 7. THIS IS THE KEY ASSERTION: Verify that a second fetch call was made. - // This proves the reconnection logic was triggered. + expect(errorSpy).toHaveBeenCalledWith(expect.objectContaining({ + message: expect.stringContaining('SSE stream disconnected: Error: Network failure'), + })); + // THE KEY ASSERTION: A second fetch call proves reconnection was attempted. expect(fetchMock).toHaveBeenCalledTimes(2); - - // 8. Verify the second call was a GET request without a last-event-id header. - const secondCall = fetchMock.mock.calls[1]; - const secondRequest = secondCall[1]; - expect(secondRequest.method).toBe('GET'); - expect(secondRequest.headers.has('last-event-id')).toBe(false); + expect(fetchMock.mock.calls[0][1]?.method).toBe('GET'); + expect(fetchMock.mock.calls[1][1]?.method).toBe('GET'); + }); + + it('should NOT reconnect a POST-initiated stream that fails', async () => { + // ARRANGE + transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { + reconnectionOptions: { + initialReconnectionDelay: 10, + maxRetries: 1, + maxReconnectionDelay: 1000, // Ensure it doesn't retry indefinitely + reconnectionDelayGrowFactor: 1 // No exponential backoff for simplicity + } + }); + + const errorSpy = jest.fn(); + transport.onerror = errorSpy; + + const failingStream = new ReadableStream({ + start(controller) { controller.error(new Error("Network failure")); } + }); + + const fetchMock = global.fetch as jest.Mock; + // Mock the POST request. It returns a streaming content-type but a failing body. + fetchMock.mockResolvedValueOnce({ + ok: true, status: 200, + headers: new Headers({ "content-type": "text/event-stream" }), + body: failingStream, + }); + + // A dummy request message to trigger the `send` logic. + const requestMessage: JSONRPCRequest = { + jsonrpc: '2.0', + method: 'long_running_tool', + id: 'request-1', + params: {}, + }; + + // ACT + await transport.start(); + // Use the public `send` method to initiate a POST that gets a stream response. + await transport.send(requestMessage); + await jest.advanceTimersByTimeAsync(20); // Advance time to check for reconnections + + // ASSERT + expect(errorSpy).toHaveBeenCalledWith(expect.objectContaining({ + message: expect.stringContaining('SSE stream disconnected: Error: Network failure'), + })); + // THE KEY ASSERTION: Fetch was only called ONCE. No reconnection was attempted. + expect(fetchMock).toHaveBeenCalledTimes(1); + expect(fetchMock.mock.calls[0][1]?.method).toBe('POST'); }); }); }); diff --git a/src/client/streamableHttp.ts b/src/client/streamableHttp.ts index 375dc411..b0894fce 100644 --- a/src/client/streamableHttp.ts +++ b/src/client/streamableHttp.ts @@ -231,7 +231,7 @@ const response = await (this._fetch ?? fetch)(this._url, { ); } - this._handleSseStream(response.body, options); + this._handleSseStream(response.body, options, true); } catch (error) { this.onerror?.(error as Error); throw error; @@ -300,7 +300,11 @@ const response = await (this._fetch ?? fetch)(this._url, { }, delay); } - private _handleSseStream(stream: ReadableStream | null, options: StartSSEOptions): void { + private _handleSseStream( + stream: ReadableStream | null, + options: StartSSEOptions, + isReconnectable: boolean, + ): void { if (!stream) { return; } @@ -347,7 +351,11 @@ const response = await (this._fetch ?? fetch)(this._url, { this.onerror?.(new Error(`SSE stream disconnected: ${error}`)); // Attempt to reconnect if the stream disconnects unexpectedly and we aren't closing - if (this._abortController && !this._abortController.signal.aborted) { + if ( + isReconnectable && + this._abortController && + !this._abortController.signal.aborted + ) { // Use the exponential backoff reconnection strategy try { this._scheduleReconnection({ @@ -471,7 +479,7 @@ const response = await (this._fetch ?? fetch)(this._url, init); // Handle SSE stream responses for requests // We use the same handler as standalone streams, which now supports // reconnection with the last event ID - this._handleSseStream(response.body, { onresumptiontoken }); + this._handleSseStream(response.body, { onresumptiontoken }, false); } else if (contentType?.includes("application/json")) { // For non-streaming servers, we might get direct JSON responses const data = await response.json(); From f44f960e02d7cfebfc07f58faba06abeaad6f372 Mon Sep 17 00:00:00 2001 From: Jesse Date: Tue, 15 Jul 2025 07:58:38 -0600 Subject: [PATCH 8/8] Merge branch '7-3-reconnection-bug' of https://github.com/jneums/typescript-sdk into 7-3-reconnection-bug --- .gitignore | 3 + README.md | 80 +- package-lock.json | 4 +- package.json | 5 +- src/cli.ts | 6 +- src/client/auth.test.ts | 695 +++++++++++++++-- src/client/auth.ts | 434 +++++++++-- src/client/cross-spawn.test.ts | 27 +- src/client/sse.test.ts | 196 +++++ src/client/sse.ts | 23 +- src/client/stdio.test.ts | 11 +- src/client/stdio.ts | 6 +- src/client/streamableHttp.test.ts | 2 + .../server/demoInMemoryOAuthProvider.ts | 6 +- .../server/jsonResponseStreamableHttp.ts | 13 +- src/examples/server/mcpServerOutputSchema.ts | 11 +- src/examples/server/simpleSseServer.ts | 6 +- .../server/simpleStatelessStreamableHttp.ts | 13 +- src/examples/server/simpleStreamableHttp.ts | 18 +- .../sseAndStreamableHttpCompatibleServer.ts | 13 +- .../standaloneSseWithGetStreamableHttp.ts | 6 +- src/server/auth/clients.ts | 2 +- src/server/auth/errors.ts | 104 +-- src/server/auth/handlers/register.test.ts | 20 + src/server/auth/handlers/register.ts | 18 +- src/server/auth/handlers/token.test.ts | 56 +- src/server/auth/middleware/bearerAuth.test.ts | 47 +- src/server/auth/middleware/bearerAuth.ts | 6 +- src/server/auth/router.ts | 2 +- src/server/mcp.test.ts | 2 +- src/server/mcp.ts | 68 +- src/server/streamableHttp.test.ts | 376 +++++++++- src/server/streamableHttp.ts | 21 +- src/shared/auth.ts | 1 + src/shared/protocol.test.ts | 183 +++++ src/shared/protocol.ts | 49 ++ src/spec.types.test.ts | 705 ++++++++++++++++++ 37 files changed, 2933 insertions(+), 305 deletions(-) create mode 100644 src/spec.types.test.ts diff --git a/.gitignore b/.gitignore index 6c4bf1a6..694735b6 100644 --- a/.gitignore +++ b/.gitignore @@ -69,6 +69,9 @@ web_modules/ # Output of 'npm pack' *.tgz +# Output of 'npm run fetch:spec-types' +spec.types.ts + # Yarn Integrity file .yarn-integrity diff --git a/README.md b/README.md index ac10e8cb..4684c67c 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ The Model Context Protocol allows applications to provide context for LLMs in a npm install @modelcontextprotocol/sdk ``` -> ⚠️ MCP requires Node v18.x up to work fine. +> ⚠️ MCP requires Node.js v18.x or higher to work fine. ## Quick Start @@ -570,20 +570,31 @@ app.listen(3000); ``` > [!TIP] -> When using this in a remote environment, make sure to allow the header parameter `mcp-session-id` in CORS. Otherwise, it may result in a `Bad Request: No valid session ID provided` error. -> -> For example, in Node.js you can configure it like this: -> -> ```ts -> app.use( -> cors({ -> origin: ['https://your-remote-domain.com, https://your-other-remote-domain.com'], -> exposedHeaders: ['mcp-session-id'], -> allowedHeaders: ['Content-Type', 'mcp-session-id'], -> }) -> ); +> When using this in a remote environment, make sure to allow the header parameter `mcp-session-id` in CORS. Otherwise, it may result in a `Bad Request: No valid session ID provided` error. Read the following section for examples. > ``` + +#### CORS Configuration for Browser-Based Clients + +If you'd like your server to be accessible by browser-based MCP clients, you'll need to configure CORS headers. The `Mcp-Session-Id` header must be exposed for browser clients to access it: + +```typescript +import cors from 'cors'; + +// Add CORS middleware before your MCP routes +app.use(cors({ + origin: '*', // Configure appropriately for production, for example: + // origin: ['https://your-remote-domain.com', 'https://your-other-remote-domain.com'], + exposedHeaders: ['Mcp-Session-Id'], + allowedHeaders: ['Content-Type', 'mcp-session-id'], +})); +``` + +This configuration is necessary because: +- The MCP streamable HTTP transport uses the `Mcp-Session-Id` header for session management +- Browsers restrict access to response headers unless explicitly exposed via CORS +- Without this configuration, browser-based clients won't be able to read the session ID from initialization responses + #### Without Session Management (Stateless) For simpler use cases where session management isn't needed: @@ -865,7 +876,7 @@ const putMessageTool = server.tool( "putMessage", { channel: z.string(), message: z.string() }, async ({ channel, message }) => ({ - content: [{ type: "text", text: await putMessage(channel, string) }] + content: [{ type: "text", text: await putMessage(channel, message) }] }) ); // Until we upgrade auth, `putMessage` is disabled (won't show up in listTools) @@ -873,7 +884,7 @@ putMessageTool.disable() const upgradeAuthTool = server.tool( "upgradeAuth", - { permission: z.enum(["write', admin"])}, + { permission: z.enum(["write", "admin"])}, // Any mutations here will automatically emit `listChanged` notifications async ({ permission }) => { const { ok, err, previous } = await upgradeAuthAndStoreToken(permission) @@ -902,6 +913,43 @@ const transport = new StdioServerTransport(); await server.connect(transport); ``` +### Improving Network Efficiency with Notification Debouncing + +When performing bulk updates that trigger notifications (e.g., enabling or disabling multiple tools in a loop), the SDK can send a large number of messages in a short period. To improve performance and reduce network traffic, you can enable notification debouncing. + +This feature coalesces multiple, rapid calls for the same notification type into a single message. For example, if you disable five tools in a row, only one `notifications/tools/list_changed` message will be sent instead of five. + +> [!IMPORTANT] +> This feature is designed for "simple" notifications that do not carry unique data in their parameters. To prevent silent data loss, debouncing is **automatically bypassed** for any notification that contains a `params` object or a `relatedRequestId`. Such notifications will always be sent immediately. + +This is an opt-in feature configured during server initialization. + +```typescript +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; + +const server = new McpServer( + { + name: "efficient-server", + version: "1.0.0" + }, + { + // Enable notification debouncing for specific methods + debouncedNotificationMethods: [ + 'notifications/tools/list_changed', + 'notifications/resources/list_changed', + 'notifications/prompts/list_changed' + ] + } +); + +// Now, any rapid changes to tools, resources, or prompts will result +// in a single, consolidated notification for each type. +server.registerTool("tool1", ...).disable(); +server.registerTool("tool2", ...).disable(); +server.registerTool("tool3", ...).disable(); +// Only one 'notifications/tools/list_changed' is sent. +``` + ### Low-Level Server For more control, you can use the low-level Server class directly: @@ -1164,7 +1212,7 @@ This setup allows you to: ### Backwards Compatibility -Clients and servers with StreamableHttp tranport can maintain [backwards compatibility](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#backwards-compatibility) with the deprecated HTTP+SSE transport (from protocol version 2024-11-05) as follows +Clients and servers with StreamableHttp transport can maintain [backwards compatibility](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#backwards-compatibility) with the deprecated HTTP+SSE transport (from protocol version 2024-11-05) as follows #### Client-Side Compatibility diff --git a/package-lock.json b/package-lock.json index f2c8cbfa..01bc0953 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@modelcontextprotocol/sdk", - "version": "1.14.0", + "version": "1.15.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@modelcontextprotocol/sdk", - "version": "1.14.0", + "version": "1.15.0", "license": "MIT", "dependencies": { "ajv": "^6.12.6", diff --git a/package.json b/package.json index 15b7753b..894081d7 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@modelcontextprotocol/sdk", - "version": "1.14.0", + "version": "1.15.1", "description": "Model Context Protocol implementation for TypeScript", "license": "MIT", "author": "Anthropic, PBC (https://anthropic.com)", @@ -35,6 +35,7 @@ "dist" ], "scripts": { + "fetch:spec-types": "curl -o spec.types.ts https://raw.githubusercontent.com/modelcontextprotocol/modelcontextprotocol/refs/heads/main/schema/draft/schema.ts", "build": "npm run build:esm && npm run build:cjs", "build:esm": "mkdir -p dist/esm && echo '{\"type\": \"module\"}' > dist/esm/package.json && tsc -p tsconfig.prod.json", "build:esm:w": "npm run build:esm -- -w", @@ -43,7 +44,7 @@ "examples:simple-server:w": "tsx --watch src/examples/server/simpleStreamableHttp.ts --oauth", "prepack": "npm run build:esm && npm run build:cjs", "lint": "eslint src/", - "test": "jest", + "test": "npm run fetch:spec-types && jest", "start": "npm run server", "server": "tsx watch --clear-screen=false src/cli.ts server", "client": "tsx src/cli.ts client" diff --git a/src/cli.ts b/src/cli.ts index b5000896..f580a624 100644 --- a/src/cli.ts +++ b/src/cli.ts @@ -102,7 +102,11 @@ async function runServer(port: number | null) { await transport.handlePostMessage(req, res); }); - app.listen(port, () => { + app.listen(port, (error) => { + if (error) { + console.error('Failed to start server:', error); + process.exit(1); + } console.log(`Server running on http://localhost:${port}/sse`); }); } else { diff --git a/src/client/auth.test.ts b/src/client/auth.test.ts index 8155e134..ce0cc708 100644 --- a/src/client/auth.test.ts +++ b/src/client/auth.test.ts @@ -10,6 +10,8 @@ import { auth, type OAuthClientProvider, } from "./auth.js"; +import {ServerError} from "../server/auth/errors.js"; +import { OAuthMetadata } from '../shared/auth.js'; // Mock fetch globally const mockFetch = jest.fn(); @@ -177,6 +179,174 @@ describe("OAuth Authorization", () => { await expect(discoverOAuthProtectedResourceMetadata("https://resource.example.com")) .rejects.toThrow(); }); + + it("returns metadata when discovery succeeds with path", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validMetadata, + }); + + const metadata = await discoverOAuthProtectedResourceMetadata("https://resource.example.com/path/name"); + expect(metadata).toEqual(validMetadata); + const calls = mockFetch.mock.calls; + expect(calls.length).toBe(1); + const [url] = calls[0]; + expect(url.toString()).toBe("https://resource.example.com/.well-known/oauth-protected-resource/path/name"); + }); + + it("preserves query parameters in path-aware discovery", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validMetadata, + }); + + const metadata = await discoverOAuthProtectedResourceMetadata("https://resource.example.com/path?param=value"); + expect(metadata).toEqual(validMetadata); + const calls = mockFetch.mock.calls; + expect(calls.length).toBe(1); + const [url] = calls[0]; + expect(url.toString()).toBe("https://resource.example.com/.well-known/oauth-protected-resource/path?param=value"); + }); + + it("falls back to root discovery when path-aware discovery returns 404", async () => { + // First call (path-aware) returns 404 + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 404, + }); + + // Second call (root fallback) succeeds + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validMetadata, + }); + + const metadata = await discoverOAuthProtectedResourceMetadata("https://resource.example.com/path/name"); + expect(metadata).toEqual(validMetadata); + + const calls = mockFetch.mock.calls; + expect(calls.length).toBe(2); + + // First call should be path-aware + const [firstUrl, firstOptions] = calls[0]; + expect(firstUrl.toString()).toBe("https://resource.example.com/.well-known/oauth-protected-resource/path/name"); + expect(firstOptions.headers).toEqual({ + "MCP-Protocol-Version": LATEST_PROTOCOL_VERSION + }); + + // Second call should be root fallback + const [secondUrl, secondOptions] = calls[1]; + expect(secondUrl.toString()).toBe("https://resource.example.com/.well-known/oauth-protected-resource"); + expect(secondOptions.headers).toEqual({ + "MCP-Protocol-Version": LATEST_PROTOCOL_VERSION + }); + }); + + it("throws error when both path-aware and root discovery return 404", async () => { + // First call (path-aware) returns 404 + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 404, + }); + + // Second call (root fallback) also returns 404 + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 404, + }); + + await expect(discoverOAuthProtectedResourceMetadata("https://resource.example.com/path/name")) + .rejects.toThrow("Resource server does not implement OAuth 2.0 Protected Resource Metadata."); + + const calls = mockFetch.mock.calls; + expect(calls.length).toBe(2); + }); + + it("does not fallback when the original URL is already at root path", async () => { + // First call (path-aware for root) returns 404 + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 404, + }); + + await expect(discoverOAuthProtectedResourceMetadata("https://resource.example.com/")) + .rejects.toThrow("Resource server does not implement OAuth 2.0 Protected Resource Metadata."); + + const calls = mockFetch.mock.calls; + expect(calls.length).toBe(1); // Should not attempt fallback + + const [url] = calls[0]; + expect(url.toString()).toBe("https://resource.example.com/.well-known/oauth-protected-resource"); + }); + + it("does not fallback when the original URL has no path", async () => { + // First call (path-aware for no path) returns 404 + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 404, + }); + + await expect(discoverOAuthProtectedResourceMetadata("https://resource.example.com")) + .rejects.toThrow("Resource server does not implement OAuth 2.0 Protected Resource Metadata."); + + const calls = mockFetch.mock.calls; + expect(calls.length).toBe(1); // Should not attempt fallback + + const [url] = calls[0]; + expect(url.toString()).toBe("https://resource.example.com/.well-known/oauth-protected-resource"); + }); + + it("falls back when path-aware discovery encounters CORS error", async () => { + // First call (path-aware) fails with TypeError (CORS) + mockFetch.mockImplementationOnce(() => Promise.reject(new TypeError("CORS error"))); + + // Retry path-aware without headers (simulating CORS retry) + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 404, + }); + + // Second call (root fallback) succeeds + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validMetadata, + }); + + const metadata = await discoverOAuthProtectedResourceMetadata("https://resource.example.com/deep/path"); + expect(metadata).toEqual(validMetadata); + + const calls = mockFetch.mock.calls; + expect(calls.length).toBe(3); + + // Final call should be root fallback + const [lastUrl, lastOptions] = calls[2]; + expect(lastUrl.toString()).toBe("https://resource.example.com/.well-known/oauth-protected-resource"); + expect(lastOptions.headers).toEqual({ + "MCP-Protocol-Version": LATEST_PROTOCOL_VERSION + }); + }); + + it("does not fallback when resourceMetadataUrl is provided", async () => { + // Call with explicit URL returns 404 + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 404, + }); + + await expect(discoverOAuthProtectedResourceMetadata("https://resource.example.com/path", { + resourceMetadataUrl: "https://custom.example.com/metadata" + })).rejects.toThrow("Resource server does not implement OAuth 2.0 Protected Resource Metadata."); + + const calls = mockFetch.mock.calls; + expect(calls.length).toBe(1); // Should not attempt fallback when explicit URL is provided + + const [url] = calls[0]; + expect(url.toString()).toBe("https://custom.example.com/metadata"); + }); }); describe("discoverOAuthMetadata", () => { @@ -231,7 +401,7 @@ describe("OAuth Authorization", () => { ok: false, status: 404, }); - + // Second call (root fallback) succeeds mockFetch.mockResolvedValueOnce({ ok: true, @@ -241,17 +411,17 @@ describe("OAuth Authorization", () => { const metadata = await discoverOAuthMetadata("https://auth.example.com/path/name"); expect(metadata).toEqual(validMetadata); - + const calls = mockFetch.mock.calls; expect(calls.length).toBe(2); - + // First call should be path-aware const [firstUrl, firstOptions] = calls[0]; expect(firstUrl.toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server/path/name"); expect(firstOptions.headers).toEqual({ "MCP-Protocol-Version": LATEST_PROTOCOL_VERSION }); - + // Second call should be root fallback const [secondUrl, secondOptions] = calls[1]; expect(secondUrl.toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server"); @@ -266,7 +436,7 @@ describe("OAuth Authorization", () => { ok: false, status: 404, }); - + // Second call (root fallback) also returns 404 mockFetch.mockResolvedValueOnce({ ok: false, @@ -275,7 +445,7 @@ describe("OAuth Authorization", () => { const metadata = await discoverOAuthMetadata("https://auth.example.com/path/name"); expect(metadata).toBeUndefined(); - + const calls = mockFetch.mock.calls; expect(calls.length).toBe(2); }); @@ -289,10 +459,10 @@ describe("OAuth Authorization", () => { const metadata = await discoverOAuthMetadata("https://auth.example.com/"); expect(metadata).toBeUndefined(); - + const calls = mockFetch.mock.calls; expect(calls.length).toBe(1); // Should not attempt fallback - + const [url] = calls[0]; expect(url.toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server"); }); @@ -306,10 +476,10 @@ describe("OAuth Authorization", () => { const metadata = await discoverOAuthMetadata("https://auth.example.com"); expect(metadata).toBeUndefined(); - + const calls = mockFetch.mock.calls; expect(calls.length).toBe(1); // Should not attempt fallback - + const [url] = calls[0]; expect(url.toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server"); }); @@ -317,13 +487,13 @@ describe("OAuth Authorization", () => { it("falls back when path-aware discovery encounters CORS error", async () => { // First call (path-aware) fails with TypeError (CORS) mockFetch.mockImplementationOnce(() => Promise.reject(new TypeError("CORS error"))); - + // Retry path-aware without headers (simulating CORS retry) mockFetch.mockResolvedValueOnce({ ok: false, status: 404, }); - + // Second call (root fallback) succeeds mockFetch.mockResolvedValueOnce({ ok: true, @@ -333,10 +503,10 @@ describe("OAuth Authorization", () => { const metadata = await discoverOAuthMetadata("https://auth.example.com/deep/path"); expect(metadata).toEqual(validMetadata); - + const calls = mockFetch.mock.calls; expect(calls.length).toBe(3); - + // Final call should be root fallback const [lastUrl, lastOptions] = calls[2]; expect(lastUrl.toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server"); @@ -427,10 +597,7 @@ describe("OAuth Authorization", () => { }); it("throws on non-404 errors", async () => { - mockFetch.mockResolvedValueOnce({ - ok: false, - status: 500, - }); + mockFetch.mockResolvedValueOnce(new Response(null, { status: 500 })); await expect( discoverOAuthMetadata("https://auth.example.com") @@ -438,14 +605,15 @@ describe("OAuth Authorization", () => { }); it("validates metadata schema", async () => { - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => ({ - // Missing required fields - issuer: "https://auth.example.com", - }), - }); + mockFetch.mockResolvedValueOnce( + Response.json( + { + // Missing required fields + issuer: "https://auth.example.com", + }, + { status: 200 } + ) + ); await expect( discoverOAuthMetadata("https://auth.example.com") @@ -545,6 +713,20 @@ describe("OAuth Authorization", () => { expect(authorizationUrl.searchParams.has("state")).toBe(false); }); + // OpenID Connect requires that the user is prompted for consent if the scope includes 'offline_access' + it("includes consent prompt parameter if scope includes 'offline_access'", async () => { + const { authorizationUrl } = await startAuthorization( + "https://auth.example.com", + { + clientInformation: validClientInfo, + redirectUrl: "http://localhost:3000/callback", + scope: "read write profile offline_access", + } + ); + + expect(authorizationUrl.searchParams.get("prompt")).toBe("consent"); + }); + it("uses metadata authorization_endpoint when provided", async () => { const { authorizationUrl } = await startAuthorization( "https://auth.example.com", @@ -600,6 +782,13 @@ describe("OAuth Authorization", () => { refresh_token: "refresh123", }; + const validMetadata = { + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + response_types_supported: ["code"] + }; + const validClientInfo = { client_id: "client123", client_secret: "secret123", @@ -629,9 +818,9 @@ describe("OAuth Authorization", () => { }), expect.objectContaining({ method: "POST", - headers: { + headers: new Headers({ "Content-Type": "application/x-www-form-urlencoded", - }, + }), }) ); @@ -645,6 +834,52 @@ describe("OAuth Authorization", () => { expect(body.get("resource")).toBe("https://api.example.com/mcp-server"); }); + it("exchanges code for tokens with auth", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokens, + }); + + const tokens = await exchangeAuthorization("https://auth.example.com", { + metadata: validMetadata, + clientInformation: validClientInfo, + authorizationCode: "code123", + codeVerifier: "verifier123", + redirectUri: "http://localhost:3000/callback", + addClientAuthentication: (headers: Headers, params: URLSearchParams, url: string | URL, metadata: OAuthMetadata) => { + headers.set("Authorization", "Basic " + btoa(validClientInfo.client_id + ":" + validClientInfo.client_secret)); + params.set("example_url", typeof url === 'string' ? url : url.toString()); + params.set("example_metadata", metadata.authorization_endpoint); + params.set("example_param", "example_value"); + }, + }); + + expect(tokens).toEqual(validTokens); + expect(mockFetch).toHaveBeenCalledWith( + expect.objectContaining({ + href: "https://auth.example.com/token", + }), + expect.objectContaining({ + method: "POST", + }) + ); + + const headers = mockFetch.mock.calls[0][1].headers as Headers; + expect(headers.get("Content-Type")).toBe("application/x-www-form-urlencoded"); + expect(headers.get("Authorization")).toBe("Basic Y2xpZW50MTIzOnNlY3JldDEyMw=="); + const body = mockFetch.mock.calls[0][1].body as URLSearchParams; + expect(body.get("grant_type")).toBe("authorization_code"); + expect(body.get("code")).toBe("code123"); + expect(body.get("code_verifier")).toBe("verifier123"); + expect(body.get("client_id")).toBeNull(); + expect(body.get("redirect_uri")).toBe("http://localhost:3000/callback"); + expect(body.get("example_url")).toBe("https://auth.example.com"); + expect(body.get("example_metadata")).toBe("https://auth.example.com/authorize"); + expect(body.get("example_param")).toBe("example_value"); + expect(body.get("client_secret")).toBeNull(); + }); + it("validates token response schema", async () => { mockFetch.mockResolvedValueOnce({ ok: true, @@ -666,10 +901,12 @@ describe("OAuth Authorization", () => { }); it("throws on error response", async () => { - mockFetch.mockResolvedValueOnce({ - ok: false, - status: 400, - }); + mockFetch.mockResolvedValueOnce( + Response.json( + new ServerError("Token exchange failed").toResponseObject(), + { status: 400 } + ) + ); await expect( exchangeAuthorization("https://auth.example.com", { @@ -693,6 +930,13 @@ describe("OAuth Authorization", () => { refresh_token: "newrefresh123", }; + const validMetadata = { + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + response_types_supported: ["code"] + }; + const validClientInfo = { client_id: "client123", client_secret: "secret123", @@ -720,9 +964,9 @@ describe("OAuth Authorization", () => { }), expect.objectContaining({ method: "POST", - headers: { + headers: new Headers({ "Content-Type": "application/x-www-form-urlencoded", - }, + }), }) ); @@ -734,6 +978,48 @@ describe("OAuth Authorization", () => { expect(body.get("resource")).toBe("https://api.example.com/mcp-server"); }); + it("exchanges refresh token for new tokens with auth", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokensWithNewRefreshToken, + }); + + const tokens = await refreshAuthorization("https://auth.example.com", { + metadata: validMetadata, + clientInformation: validClientInfo, + refreshToken: "refresh123", + addClientAuthentication: (headers: Headers, params: URLSearchParams, url: string | URL, metadata?: OAuthMetadata) => { + headers.set("Authorization", "Basic " + btoa(validClientInfo.client_id + ":" + validClientInfo.client_secret)); + params.set("example_url", typeof url === 'string' ? url : url.toString()); + params.set("example_metadata", metadata?.authorization_endpoint ?? '?'); + params.set("example_param", "example_value"); + }, + }); + + expect(tokens).toEqual(validTokensWithNewRefreshToken); + expect(mockFetch).toHaveBeenCalledWith( + expect.objectContaining({ + href: "https://auth.example.com/token", + }), + expect.objectContaining({ + method: "POST", + }) + ); + + const headers = mockFetch.mock.calls[0][1].headers as Headers; + expect(headers.get("Content-Type")).toBe("application/x-www-form-urlencoded"); + expect(headers.get("Authorization")).toBe("Basic Y2xpZW50MTIzOnNlY3JldDEyMw=="); + const body = mockFetch.mock.calls[0][1].body as URLSearchParams; + expect(body.get("grant_type")).toBe("refresh_token"); + expect(body.get("refresh_token")).toBe("refresh123"); + expect(body.get("client_id")).toBeNull(); + expect(body.get("example_url")).toBe("https://auth.example.com"); + expect(body.get("example_metadata")).toBe("https://auth.example.com/authorize"); + expect(body.get("example_param")).toBe("example_value"); + expect(body.get("client_secret")).toBeNull(); + }); + it("exchanges refresh token for new tokens and keep existing refresh token if none is returned", async () => { mockFetch.mockResolvedValueOnce({ ok: true, @@ -769,10 +1055,12 @@ describe("OAuth Authorization", () => { }); it("throws on error response", async () => { - mockFetch.mockResolvedValueOnce({ - ok: false, - status: 400, - }); + mockFetch.mockResolvedValueOnce( + Response.json( + new ServerError("Token refresh failed").toResponseObject(), + { status: 400 } + ) + ); await expect( refreshAuthorization("https://auth.example.com", { @@ -857,10 +1145,12 @@ describe("OAuth Authorization", () => { }); it("throws on error response", async () => { - mockFetch.mockResolvedValueOnce({ - ok: false, - status: 400, - }); + mockFetch.mockResolvedValueOnce( + Response.json( + new ServerError("Dynamic client registration failed").toResponseObject(), + { status: 400 } + ) + ); await expect( registerClient("https://auth.example.com", { @@ -1476,5 +1766,326 @@ describe("OAuth Authorization", () => { expect(body.get("grant_type")).toBe("refresh_token"); expect(body.get("refresh_token")).toBe("refresh123"); }); + + it("fetches AS metadata with path from serverUrl when PRM returns external AS", async () => { + // Mock PRM discovery that returns an external AS + mockFetch.mockImplementation((url) => { + const urlString = url.toString(); + + if (urlString === "https://my.resource.com/.well-known/oauth-protected-resource/path/name") { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + resource: "https://my.resource.com/", + authorization_servers: ["https://auth.example.com/oauth"], + }), + }); + } else if (urlString === "https://auth.example.com/.well-known/oauth-authorization-server/path/name") { + // Path-aware discovery on AS with path from serverUrl + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + }), + }); + } + + return Promise.resolve({ ok: false, status: 404 }); + }); + + // Mock provider methods + (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ + client_id: "test-client", + client_secret: "test-secret", + }); + (mockProvider.tokens as jest.Mock).mockResolvedValue(undefined); + (mockProvider.saveCodeVerifier as jest.Mock).mockResolvedValue(undefined); + (mockProvider.redirectToAuthorization as jest.Mock).mockResolvedValue(undefined); + + // Call auth with serverUrl that has a path + const result = await auth(mockProvider, { + serverUrl: "https://my.resource.com/path/name", + }); + + expect(result).toBe("REDIRECT"); + + // Verify the correct URLs were fetched + const calls = mockFetch.mock.calls; + + // First call should be to PRM + expect(calls[0][0].toString()).toBe("https://my.resource.com/.well-known/oauth-protected-resource/path/name"); + + // Second call should be to AS metadata with the path from authorization server + expect(calls[1][0].toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server/oauth"); + }); }); + + describe("exchangeAuthorization with multiple client authentication methods", () => { + const validTokens = { + access_token: "access123", + token_type: "Bearer", + expires_in: 3600, + refresh_token: "refresh123", + }; + + const validClientInfo = { + client_id: "client123", + client_secret: "secret123", + redirect_uris: ["http://localhost:3000/callback"], + client_name: "Test Client", + }; + + const metadataWithBasicOnly = { + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/auth", + token_endpoint: "https://auth.example.com/token", + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + token_endpoint_auth_methods_supported: ["client_secret_basic"], + }; + + const metadataWithPostOnly = { + ...metadataWithBasicOnly, + token_endpoint_auth_methods_supported: ["client_secret_post"], + }; + + const metadataWithNoneOnly = { + ...metadataWithBasicOnly, + token_endpoint_auth_methods_supported: ["none"], + }; + + const metadataWithAllBuiltinMethods = { + ...metadataWithBasicOnly, + token_endpoint_auth_methods_supported: ["client_secret_basic", "client_secret_post", "none"], + }; + + it("uses HTTP Basic authentication when client_secret_basic is supported", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokens, + }); + + const tokens = await exchangeAuthorization("https://auth.example.com", { + metadata: metadataWithBasicOnly, + clientInformation: validClientInfo, + authorizationCode: "code123", + redirectUri: "http://localhost:3000/callback", + codeVerifier: "verifier123", + }); + + expect(tokens).toEqual(validTokens); + const request = mockFetch.mock.calls[0][1]; + + // Check Authorization header + const authHeader = request.headers.get("Authorization"); + const expected = "Basic " + btoa("client123:secret123"); + expect(authHeader).toBe(expected); + + const body = request.body as URLSearchParams; + expect(body.get("client_id")).toBeNull(); + expect(body.get("client_secret")).toBeNull(); + }); + + it("includes credentials in request body when client_secret_post is supported", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokens, + }); + + const tokens = await exchangeAuthorization("https://auth.example.com", { + metadata: metadataWithPostOnly, + clientInformation: validClientInfo, + authorizationCode: "code123", + redirectUri: "http://localhost:3000/callback", + codeVerifier: "verifier123", + }); + + expect(tokens).toEqual(validTokens); + const request = mockFetch.mock.calls[0][1]; + + // Check no Authorization header + expect(request.headers.get("Authorization")).toBeNull(); + + const body = request.body as URLSearchParams; + expect(body.get("client_id")).toBe("client123"); + expect(body.get("client_secret")).toBe("secret123"); + }); + + it("it picks client_secret_basic when all builtin methods are supported", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokens, + }); + + const tokens = await exchangeAuthorization("https://auth.example.com", { + metadata: metadataWithAllBuiltinMethods, + clientInformation: validClientInfo, + authorizationCode: "code123", + redirectUri: "http://localhost:3000/callback", + codeVerifier: "verifier123", + }); + + expect(tokens).toEqual(validTokens); + const request = mockFetch.mock.calls[0][1]; + + // Check Authorization header - should use Basic auth as it's the most secure + const authHeader = request.headers.get("Authorization"); + const expected = "Basic " + btoa("client123:secret123"); + expect(authHeader).toBe(expected); + + // Credentials should not be in body when using Basic auth + const body = request.body as URLSearchParams; + expect(body.get("client_id")).toBeNull(); + expect(body.get("client_secret")).toBeNull(); + }); + + it("uses public client authentication when none method is specified", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokens, + }); + + const clientInfoWithoutSecret = { + client_id: "client123", + redirect_uris: ["http://localhost:3000/callback"], + client_name: "Test Client", + }; + + const tokens = await exchangeAuthorization("https://auth.example.com", { + metadata: metadataWithNoneOnly, + clientInformation: clientInfoWithoutSecret, + authorizationCode: "code123", + redirectUri: "http://localhost:3000/callback", + codeVerifier: "verifier123", + }); + + expect(tokens).toEqual(validTokens); + const request = mockFetch.mock.calls[0][1]; + + // Check no Authorization header + expect(request.headers.get("Authorization")).toBeNull(); + + const body = request.body as URLSearchParams; + expect(body.get("client_id")).toBe("client123"); + expect(body.get("client_secret")).toBeNull(); + }); + + it("defaults to client_secret_post when no auth methods specified", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokens, + }); + + const tokens = await exchangeAuthorization("https://auth.example.com", { + clientInformation: validClientInfo, + authorizationCode: "code123", + redirectUri: "http://localhost:3000/callback", + codeVerifier: "verifier123", + }); + + expect(tokens).toEqual(validTokens); + const request = mockFetch.mock.calls[0][1]; + + // Check headers + expect(request.headers.get("Content-Type")).toBe("application/x-www-form-urlencoded"); + expect(request.headers.get("Authorization")).toBeNull(); + + const body = request.body as URLSearchParams; + expect(body.get("client_id")).toBe("client123"); + expect(body.get("client_secret")).toBe("secret123"); + }); + }); + + describe("refreshAuthorization with multiple client authentication methods", () => { + const validTokens = { + access_token: "newaccess123", + token_type: "Bearer", + expires_in: 3600, + refresh_token: "newrefresh123", + }; + + const validClientInfo = { + client_id: "client123", + client_secret: "secret123", + redirect_uris: ["http://localhost:3000/callback"], + client_name: "Test Client", + }; + + const metadataWithBasicOnly = { + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/auth", + token_endpoint: "https://auth.example.com/token", + response_types_supported: ["code"], + token_endpoint_auth_methods_supported: ["client_secret_basic"], + }; + + const metadataWithPostOnly = { + ...metadataWithBasicOnly, + token_endpoint_auth_methods_supported: ["client_secret_post"], + }; + + it("uses client_secret_basic for refresh token", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokens, + }); + + const tokens = await refreshAuthorization("https://auth.example.com", { + metadata: metadataWithBasicOnly, + clientInformation: validClientInfo, + refreshToken: "refresh123", + }); + + expect(tokens).toEqual(validTokens); + const request = mockFetch.mock.calls[0][1]; + + // Check Authorization header + const authHeader = request.headers.get("Authorization"); + const expected = "Basic " + btoa("client123:secret123"); + expect(authHeader).toBe(expected); + + const body = request.body as URLSearchParams; + expect(body.get("client_id")).toBeNull(); // should not be in body + expect(body.get("client_secret")).toBeNull(); // should not be in body + expect(body.get("refresh_token")).toBe("refresh123"); + }); + + it("uses client_secret_post for refresh token", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokens, + }); + + const tokens = await refreshAuthorization("https://auth.example.com", { + metadata: metadataWithPostOnly, + clientInformation: validClientInfo, + refreshToken: "refresh123", + }); + + expect(tokens).toEqual(validTokens); + const request = mockFetch.mock.calls[0][1]; + + // Check no Authorization header + expect(request.headers.get("Authorization")).toBeNull(); + + const body = request.body as URLSearchParams; + expect(body.get("client_id")).toBe("client123"); + expect(body.get("client_secret")).toBe("secret123"); + expect(body.get("refresh_token")).toBe("refresh123"); + }); + }); + }); diff --git a/src/client/auth.ts b/src/client/auth.ts index 71101a42..4a8bbe2d 100644 --- a/src/client/auth.ts +++ b/src/client/auth.ts @@ -1,8 +1,24 @@ import pkceChallenge from "pkce-challenge"; import { LATEST_PROTOCOL_VERSION } from "../types.js"; -import type { OAuthClientMetadata, OAuthClientInformation, OAuthTokens, OAuthMetadata, OAuthClientInformationFull, OAuthProtectedResourceMetadata } from "../shared/auth.js"; +import { + OAuthClientMetadata, + OAuthClientInformation, + OAuthTokens, + OAuthMetadata, + OAuthClientInformationFull, + OAuthProtectedResourceMetadata, + OAuthErrorResponseSchema +} from "../shared/auth.js"; import { OAuthClientInformationFullSchema, OAuthMetadataSchema, OAuthProtectedResourceMetadataSchema, OAuthTokensSchema } from "../shared/auth.js"; import { checkResourceAllowed, resourceUrlFromServerUrl } from "../shared/auth-utils.js"; +import { + InvalidClientError, + InvalidGrantError, + OAUTH_ERRORS, + OAuthError, + ServerError, + UnauthorizedClientError +} from "../server/auth/errors.js"; /** * Implements an end-to-end OAuth client to be used with one MCP server. @@ -73,6 +89,26 @@ export interface OAuthClientProvider { */ codeVerifier(): string | Promise; + /** + * Adds custom client authentication to OAuth token requests. + * + * This optional method allows implementations to customize how client credentials + * are included in token exchange and refresh requests. When provided, this method + * is called instead of the default authentication logic, giving full control over + * the authentication mechanism. + * + * Common use cases include: + * - Supporting authentication methods beyond the standard OAuth 2.0 methods + * - Adding custom headers for proprietary authentication schemes + * - Implementing client assertion-based authentication (e.g., JWT bearer tokens) + * + * @param headers - The request headers (can be modified to add authentication) + * @param params - The request body parameters (can be modified to add credentials) + * @param url - The token endpoint URL being called + * @param metadata - Optional OAuth metadata for the server, which may include supported authentication methods + */ + addClientAuthentication?(headers: Headers, params: URLSearchParams, url: string | URL, metadata?: OAuthMetadata): void | Promise; + /** * If defined, overrides the selection and validation of the * RFC 8707 Resource Indicator. If left undefined, default @@ -81,6 +117,13 @@ export interface OAuthClientProvider { * Implementations must verify the returned resource matches the MCP server. */ validateResourceURL?(serverUrl: string | URL, resource?: string): Promise; + + /** + * If implemented, provides a way for the client to invalidate (e.g. delete) the specified + * credentials, in the case where the server has indicated that they are no longer valid. + * This avoids requiring the user to intervene manually. + */ + invalidateCredentials?(scope: 'all' | 'client' | 'tokens' | 'verifier'): void | Promise; } export type AuthResult = "AUTHORIZED" | "REDIRECT"; @@ -91,6 +134,141 @@ export class UnauthorizedError extends Error { } } +type ClientAuthMethod = 'client_secret_basic' | 'client_secret_post' | 'none'; + +/** + * Determines the best client authentication method to use based on server support and client configuration. + * + * Priority order (highest to lowest): + * 1. client_secret_basic (if client secret is available) + * 2. client_secret_post (if client secret is available) + * 3. none (for public clients) + * + * @param clientInformation - OAuth client information containing credentials + * @param supportedMethods - Authentication methods supported by the authorization server + * @returns The selected authentication method + */ +function selectClientAuthMethod( + clientInformation: OAuthClientInformation, + supportedMethods: string[] +): ClientAuthMethod { + const hasClientSecret = clientInformation.client_secret !== undefined; + + // If server doesn't specify supported methods, use RFC 6749 defaults + if (supportedMethods.length === 0) { + return hasClientSecret ? "client_secret_post" : "none"; + } + + // Try methods in priority order (most secure first) + if (hasClientSecret && supportedMethods.includes("client_secret_basic")) { + return "client_secret_basic"; + } + + if (hasClientSecret && supportedMethods.includes("client_secret_post")) { + return "client_secret_post"; + } + + if (supportedMethods.includes("none")) { + return "none"; + } + + // Fallback: use what we have + return hasClientSecret ? "client_secret_post" : "none"; +} + +/** + * Applies client authentication to the request based on the specified method. + * + * Implements OAuth 2.1 client authentication methods: + * - client_secret_basic: HTTP Basic authentication (RFC 6749 Section 2.3.1) + * - client_secret_post: Credentials in request body (RFC 6749 Section 2.3.1) + * - none: Public client authentication (RFC 6749 Section 2.1) + * + * @param method - The authentication method to use + * @param clientInformation - OAuth client information containing credentials + * @param headers - HTTP headers object to modify + * @param params - URL search parameters to modify + * @throws {Error} When required credentials are missing + */ +function applyClientAuthentication( + method: ClientAuthMethod, + clientInformation: OAuthClientInformation, + headers: Headers, + params: URLSearchParams +): void { + const { client_id, client_secret } = clientInformation; + + switch (method) { + case "client_secret_basic": + applyBasicAuth(client_id, client_secret, headers); + return; + case "client_secret_post": + applyPostAuth(client_id, client_secret, params); + return; + case "none": + applyPublicAuth(client_id, params); + return; + default: + throw new Error(`Unsupported client authentication method: ${method}`); + } +} + +/** + * Applies HTTP Basic authentication (RFC 6749 Section 2.3.1) + */ +function applyBasicAuth(clientId: string, clientSecret: string | undefined, headers: Headers): void { + if (!clientSecret) { + throw new Error("client_secret_basic authentication requires a client_secret"); + } + + const credentials = btoa(`${clientId}:${clientSecret}`); + headers.set("Authorization", `Basic ${credentials}`); +} + +/** + * Applies POST body authentication (RFC 6749 Section 2.3.1) + */ +function applyPostAuth(clientId: string, clientSecret: string | undefined, params: URLSearchParams): void { + params.set("client_id", clientId); + if (clientSecret) { + params.set("client_secret", clientSecret); + } +} + +/** + * Applies public client authentication (RFC 6749 Section 2.1) + */ +function applyPublicAuth(clientId: string, params: URLSearchParams): void { + params.set("client_id", clientId); +} + +/** + * Parses an OAuth error response from a string or Response object. + * + * If the input is a standard OAuth2.0 error response, it will be parsed according to the spec + * and an instance of the appropriate OAuthError subclass will be returned. + * If parsing fails, it falls back to a generic ServerError that includes + * the response status (if available) and original content. + * + * @param input - A Response object or string containing the error response + * @returns A Promise that resolves to an OAuthError instance + */ +export async function parseErrorResponse(input: Response | string): Promise { + const statusCode = input instanceof Response ? input.status : undefined; + const body = input instanceof Response ? await input.text() : input; + + try { + const result = OAuthErrorResponseSchema.parse(JSON.parse(body)); + const { error, error_description, error_uri } = result; + const errorClass = OAUTH_ERRORS[error] || ServerError; + return new errorClass(error_description || '', error_uri); + } catch (error) { + // Not a valid OAuth error response, but try to inform the user of the raw data anyway + const errorMessage = `${statusCode ? `HTTP ${statusCode}: ` : ''}Invalid OAuth error response: ${error}. Raw body: ${body}`; + return new ServerError(errorMessage); + } +} + /** * Orchestrates the full auth flow with a server. * @@ -98,6 +276,31 @@ export class UnauthorizedError extends Error { * instead of linking together the other lower-level functions in this module. */ export async function auth( + provider: OAuthClientProvider, + options: { + serverUrl: string | URL; + authorizationCode?: string; + scope?: string; + resourceMetadataUrl?: URL }): Promise { + + try { + return await authInternal(provider, options); + } catch (error) { + // Handle recoverable error types by invalidating credentials and retrying + if (error instanceof InvalidClientError || error instanceof UnauthorizedClientError) { + await provider.invalidateCredentials?.('all'); + return await authInternal(provider, options); + } else if (error instanceof InvalidGrantError) { + await provider.invalidateCredentials?.('tokens'); + return await authInternal(provider, options); + } + + // Throw otherwise + throw error + } +} + +async function authInternal( provider: OAuthClientProvider, { serverUrl, authorizationCode, @@ -107,12 +310,13 @@ export async function auth( serverUrl: string | URL; authorizationCode?: string; scope?: string; - resourceMetadataUrl?: URL }): Promise { + resourceMetadataUrl?: URL + }): Promise { let resourceMetadata: OAuthProtectedResourceMetadata | undefined; let authorizationServerUrl = serverUrl; try { - resourceMetadata = await discoverOAuthProtectedResourceMetadata(serverUrl, {resourceMetadataUrl}); + resourceMetadata = await discoverOAuthProtectedResourceMetadata(serverUrl, { resourceMetadataUrl }); if (resourceMetadata.authorization_servers && resourceMetadata.authorization_servers.length > 0) { authorizationServerUrl = resourceMetadata.authorization_servers[0]; } @@ -122,7 +326,9 @@ export async function auth( const resource: URL | undefined = await selectResourceURL(serverUrl, provider, resourceMetadata); - const metadata = await discoverOAuthMetadata(authorizationServerUrl); + const metadata = await discoverOAuthMetadata(serverUrl, { + authorizationServerUrl + }); // Handle client registration if needed let clientInformation = await Promise.resolve(provider.clientInformation()); @@ -154,10 +360,11 @@ export async function auth( codeVerifier, redirectUri: provider.redirectUrl, resource, + addClientAuthentication: provider.addClientAuthentication, }); await provider.saveTokens(tokens); - return "AUTHORIZED"; + return "AUTHORIZED" } const tokens = await provider.tokens(); @@ -171,12 +378,19 @@ export async function auth( clientInformation, refreshToken: tokens.refresh_token, resource, + addClientAuthentication: provider.addClientAuthentication, }); await provider.saveTokens(newTokens); - return "AUTHORIZED"; - } catch { - // Could not refresh OAuth tokens + return "AUTHORIZED" + } catch (error) { + // If this is a ServerError, or an unknown type, log it out and try to continue. Otherwise, escalate so we can fix things and retry. + if (!(error instanceof OAuthError) || error instanceof ServerError) { + // Could not refresh OAuth tokens + } else { + // Refresh failed for another reason, re-throw + throw error; + } } } @@ -194,10 +408,10 @@ export async function auth( await provider.saveCodeVerifier(codeVerifier); await provider.redirectToAuthorization(authorizationUrl); - return "REDIRECT"; + return "REDIRECT" } -export async function selectResourceURL(serverUrl: string| URL, provider: OAuthClientProvider, resourceMetadata?: OAuthProtectedResourceMetadata): Promise { +export async function selectResourceURL(serverUrl: string | URL, provider: OAuthClientProvider, resourceMetadata?: OAuthProtectedResourceMetadata): Promise { const defaultResource = resourceUrlFromServerUrl(serverUrl); // If provider has custom validation, delegate to it @@ -256,31 +470,16 @@ export async function discoverOAuthProtectedResourceMetadata( serverUrl: string | URL, opts?: { protocolVersion?: string, resourceMetadataUrl?: string | URL }, ): Promise { + const response = await discoverMetadataWithFallback( + serverUrl, + 'oauth-protected-resource', + { + protocolVersion: opts?.protocolVersion, + metadataUrl: opts?.resourceMetadataUrl, + }, + ); - let url: URL - if (opts?.resourceMetadataUrl) { - url = new URL(opts?.resourceMetadataUrl); - } else { - url = new URL("/.well-known/oauth-protected-resource", serverUrl); - } - - let response: Response; - try { - response = await fetch(url, { - headers: { - "MCP-Protocol-Version": opts?.protocolVersion ?? LATEST_PROTOCOL_VERSION - } - }); - } catch (error) { - // CORS errors come back as TypeError - if (error instanceof TypeError) { - response = await fetch(url); - } else { - throw error; - } - } - - if (response.status === 404) { + if (!response || response.status === 404) { throw new Error(`Resource server does not implement OAuth 2.0 Protected Resource Metadata.`); } @@ -318,8 +517,8 @@ async function fetchWithCorsRetry( /** * Constructs the well-known path for OAuth metadata discovery */ -function buildWellKnownPath(pathname: string): string { - let wellKnownPath = `/.well-known/oauth-authorization-server${pathname}`; +function buildWellKnownPath(wellKnownPrefix: string, pathname: string): string { + let wellKnownPath = `/.well-known/${wellKnownPrefix}${pathname}`; if (pathname.endsWith('/')) { // Strip trailing slash from pathname to avoid double slashes wellKnownPath = wellKnownPath.slice(0, -1); @@ -347,6 +546,38 @@ function shouldAttemptFallback(response: Response | undefined, pathname: string) return !response || response.status === 404 && pathname !== '/'; } +/** + * Generic function for discovering OAuth metadata with fallback support + */ +async function discoverMetadataWithFallback( + serverUrl: string | URL, + wellKnownType: 'oauth-authorization-server' | 'oauth-protected-resource', + opts?: { protocolVersion?: string; metadataUrl?: string | URL, metadataServerUrl?: string | URL }, +): Promise { + const issuer = new URL(serverUrl); + const protocolVersion = opts?.protocolVersion ?? LATEST_PROTOCOL_VERSION; + + let url: URL; + if (opts?.metadataUrl) { + url = new URL(opts.metadataUrl); + } else { + // Try path-aware discovery first + const wellKnownPath = buildWellKnownPath(wellKnownType, issuer.pathname); + url = new URL(wellKnownPath, opts?.metadataServerUrl ?? issuer); + url.search = issuer.search; + } + + let response = await tryMetadataDiscovery(url, protocolVersion); + + // If path-aware discovery fails with 404 and we're not already at root, try fallback to root discovery + if (!opts?.metadataUrl && shouldAttemptFallback(response, issuer.pathname)) { + const rootUrl = new URL(`/.well-known/${wellKnownType}`, issuer); + response = await tryMetadataDiscovery(rootUrl, protocolVersion); + } + + return response; +} + /** * Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata. * @@ -354,22 +585,35 @@ function shouldAttemptFallback(response: Response | undefined, pathname: string) * return `undefined`. Any other errors will be thrown as exceptions. */ export async function discoverOAuthMetadata( - authorizationServerUrl: string | URL, - opts?: { protocolVersion?: string }, + issuer: string | URL, + { + authorizationServerUrl, + protocolVersion, + }: { + authorizationServerUrl?: string | URL, + protocolVersion?: string, + } = {}, ): Promise { - const issuer = new URL(authorizationServerUrl); - const protocolVersion = opts?.protocolVersion ?? LATEST_PROTOCOL_VERSION; + if (typeof issuer === 'string') { + issuer = new URL(issuer); + } + if (!authorizationServerUrl) { + authorizationServerUrl = issuer; + } + if (typeof authorizationServerUrl === 'string') { + authorizationServerUrl = new URL(authorizationServerUrl); + } + protocolVersion ??= LATEST_PROTOCOL_VERSION; - // Try path-aware discovery first (RFC 8414 compliant) - const wellKnownPath = buildWellKnownPath(issuer.pathname); - const pathAwareUrl = new URL(wellKnownPath, issuer); - let response = await tryMetadataDiscovery(pathAwareUrl, protocolVersion); + const response = await discoverMetadataWithFallback( + authorizationServerUrl, + 'oauth-authorization-server', + { + protocolVersion, + metadataServerUrl: authorizationServerUrl, + }, + ); - // If path-aware discovery fails with 404, try fallback to root discovery - if (shouldAttemptFallback(response, issuer.pathname)) { - const rootUrl = new URL("/.well-known/oauth-authorization-server", issuer); - response = await tryMetadataDiscovery(rootUrl, protocolVersion); - } if (!response || response.status === 404) { return undefined; } @@ -451,6 +695,13 @@ export async function startAuthorization( authorizationUrl.searchParams.set("scope", scope); } + if (scope?.includes("offline_access")) { + // if the request includes the OIDC-only "offline_access" scope, + // we need to set the prompt to "consent" to ensure the user is prompted to grant offline access + // https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess + authorizationUrl.searchParams.append("prompt", "consent"); + } + if (resource) { authorizationUrl.searchParams.set("resource", resource.href); } @@ -460,6 +711,15 @@ export async function startAuthorization( /** * Exchanges an authorization code for an access token with the given server. + * + * Supports multiple client authentication methods as specified in OAuth 2.1: + * - Automatically selects the best authentication method based on server support + * - Falls back to appropriate defaults when server metadata is unavailable + * + * @param authorizationServerUrl - The authorization server's base URL + * @param options - Configuration object containing client info, auth code, etc. + * @returns Promise resolving to OAuth tokens + * @throws {Error} When token exchange fails or authentication is invalid */ export async function exchangeAuthorization( authorizationServerUrl: string | URL, @@ -470,6 +730,7 @@ export async function exchangeAuthorization( codeVerifier, redirectUri, resource, + addClientAuthentication }: { metadata?: OAuthMetadata; clientInformation: OAuthClientInformation; @@ -477,37 +738,43 @@ export async function exchangeAuthorization( codeVerifier: string; redirectUri: string | URL; resource?: URL; + addClientAuthentication?: OAuthClientProvider["addClientAuthentication"]; }, ): Promise { const grantType = "authorization_code"; - let tokenUrl: URL; - if (metadata) { - tokenUrl = new URL(metadata.token_endpoint); + const tokenUrl = metadata?.token_endpoint + ? new URL(metadata.token_endpoint) + : new URL("/token", authorizationServerUrl); - if ( - metadata.grant_types_supported && + if ( + metadata?.grant_types_supported && !metadata.grant_types_supported.includes(grantType) - ) { - throw new Error( + ) { + throw new Error( `Incompatible auth server: does not support grant type ${grantType}`, - ); - } - } else { - tokenUrl = new URL("/token", authorizationServerUrl); + ); } // Exchange code for tokens + const headers = new Headers({ + "Content-Type": "application/x-www-form-urlencoded", + }); const params = new URLSearchParams({ grant_type: grantType, - client_id: clientInformation.client_id, code: authorizationCode, code_verifier: codeVerifier, redirect_uri: String(redirectUri), }); - if (clientInformation.client_secret) { - params.set("client_secret", clientInformation.client_secret); + if (addClientAuthentication) { + addClientAuthentication(headers, params, authorizationServerUrl, metadata); + } else { + // Determine and apply client authentication method + const supportedMethods = metadata?.token_endpoint_auth_methods_supported ?? []; + const authMethod = selectClientAuthMethod(clientInformation, supportedMethods); + + applyClientAuthentication(authMethod, clientInformation, headers, params); } if (resource) { @@ -516,14 +783,12 @@ export async function exchangeAuthorization( const response = await fetch(tokenUrl, { method: "POST", - headers: { - "Content-Type": "application/x-www-form-urlencoded", - }, + headers, body: params, }); if (!response.ok) { - throw new Error(`Token exchange failed: HTTP ${response.status}`); + throw await parseErrorResponse(response); } return OAuthTokensSchema.parse(await response.json()); @@ -531,6 +796,15 @@ export async function exchangeAuthorization( /** * Exchange a refresh token for an updated access token. + * + * Supports multiple client authentication methods as specified in OAuth 2.1: + * - Automatically selects the best authentication method based on server support + * - Preserves the original refresh token if a new one is not returned + * + * @param authorizationServerUrl - The authorization server's base URL + * @param options - Configuration object containing client info, refresh token, etc. + * @returns Promise resolving to OAuth tokens (preserves original refresh_token if not replaced) + * @throws {Error} When token refresh fails or authentication is invalid */ export async function refreshAuthorization( authorizationServerUrl: string | URL, @@ -539,12 +813,14 @@ export async function refreshAuthorization( clientInformation, refreshToken, resource, + addClientAuthentication, }: { metadata?: OAuthMetadata; clientInformation: OAuthClientInformation; refreshToken: string; resource?: URL; - }, + addClientAuthentication?: OAuthClientProvider["addClientAuthentication"]; + } ): Promise { const grantType = "refresh_token"; @@ -565,14 +841,22 @@ export async function refreshAuthorization( } // Exchange refresh token + const headers = new Headers({ + "Content-Type": "application/x-www-form-urlencoded", + }); const params = new URLSearchParams({ grant_type: grantType, - client_id: clientInformation.client_id, refresh_token: refreshToken, }); - if (clientInformation.client_secret) { - params.set("client_secret", clientInformation.client_secret); + if (addClientAuthentication) { + addClientAuthentication(headers, params, authorizationServerUrl, metadata); + } else { + // Determine and apply client authentication method + const supportedMethods = metadata?.token_endpoint_auth_methods_supported ?? []; + const authMethod = selectClientAuthMethod(clientInformation, supportedMethods); + + applyClientAuthentication(authMethod, clientInformation, headers, params); } if (resource) { @@ -581,13 +865,11 @@ export async function refreshAuthorization( const response = await fetch(tokenUrl, { method: "POST", - headers: { - "Content-Type": "application/x-www-form-urlencoded", - }, + headers, body: params, }); if (!response.ok) { - throw new Error(`Token refresh failed: HTTP ${response.status}`); + throw await parseErrorResponse(response); } return OAuthTokensSchema.parse({ refresh_token: refreshToken, ...(await response.json()) }); @@ -627,7 +909,7 @@ export async function registerClient( }); if (!response.ok) { - throw new Error(`Dynamic client registration failed: HTTP ${response.status}`); + throw await parseErrorResponse(response); } return OAuthClientInformationFullSchema.parse(await response.json()); diff --git a/src/client/cross-spawn.test.ts b/src/client/cross-spawn.test.ts index 11e81bf6..8480d94f 100644 --- a/src/client/cross-spawn.test.ts +++ b/src/client/cross-spawn.test.ts @@ -1,4 +1,4 @@ -import { StdioClientTransport } from "./stdio.js"; +import { StdioClientTransport, getDefaultEnvironment } from "./stdio.js"; import spawn from "cross-spawn"; import { JSONRPCMessage } from "../types.js"; import { ChildProcess } from "node:child_process"; @@ -67,12 +67,33 @@ describe("StdioClientTransport using cross-spawn", () => { await transport.start(); - // verify environment variables are passed correctly + // verify environment variables are merged correctly expect(mockSpawn).toHaveBeenCalledWith( "test-command", [], expect.objectContaining({ - env: customEnv + env: { + ...getDefaultEnvironment(), + ...customEnv + } + }) + ); + }); + + test("should use default environment when env is undefined", async () => { + const transport = new StdioClientTransport({ + command: "test-command", + env: undefined + }); + + await transport.start(); + + // verify default environment is used + expect(mockSpawn).toHaveBeenCalledWith( + "test-command", + [], + expect.objectContaining({ + env: getDefaultEnvironment() }) ); }); diff --git a/src/client/sse.test.ts b/src/client/sse.test.ts index 2d116344..2cc4a1dd 100644 --- a/src/client/sse.test.ts +++ b/src/client/sse.test.ts @@ -4,6 +4,7 @@ import { JSONRPCMessage } from "../types.js"; import { SSEClientTransport } from "./sse.js"; import { OAuthClientProvider, UnauthorizedError } from "./auth.js"; import { OAuthTokens } from "../shared/auth.js"; +import { InvalidClientError, InvalidGrantError, UnauthorizedClientError } from "../server/auth/errors.js"; describe("SSEClientTransport", () => { let resourceServer: Server; @@ -363,6 +364,7 @@ describe("SSEClientTransport", () => { redirectToAuthorization: jest.fn(), saveCodeVerifier: jest.fn(), codeVerifier: jest.fn(), + invalidateCredentials: jest.fn(), }; }); @@ -382,6 +384,29 @@ describe("SSEClientTransport", () => { expect(mockAuthProvider.tokens).toHaveBeenCalled(); }); + it("attaches custom header from provider on initial SSE connection", async () => { + mockAuthProvider.tokens.mockResolvedValue({ + access_token: "test-token", + token_type: "Bearer" + }); + const customHeaders = { + "X-Custom-Header": "custom-value", + }; + + transport = new SSEClientTransport(resourceBaseUrl, { + authProvider: mockAuthProvider, + requestInit: { + headers: customHeaders, + }, + }); + + await transport.start(); + + expect(lastServerRequest.headers.authorization).toBe("Bearer test-token"); + expect(lastServerRequest.headers["x-custom-header"]).toBe("custom-value"); + expect(mockAuthProvider.tokens).toHaveBeenCalled(); + }); + it("attaches auth header from provider on POST requests", async () => { mockAuthProvider.tokens.mockResolvedValue({ access_token: "test-token", @@ -911,5 +936,176 @@ describe("SSEClientTransport", () => { await expect(() => transport.start()).rejects.toThrow(UnauthorizedError); expect(mockAuthProvider.redirectToAuthorization).toHaveBeenCalled(); }); + + it("invalidates all credentials on InvalidClientError during token refresh", async () => { + // Mock tokens() to return token with refresh token + mockAuthProvider.tokens.mockResolvedValue({ + access_token: "expired-token", + token_type: "Bearer", + refresh_token: "refresh-token" + }); + + let baseUrl = resourceBaseUrl; + + // Create server that returns InvalidClientError on token refresh + const server = createServer((req, res) => { + lastServerRequest = req; + + // Handle OAuth metadata discovery + if (req.url === "/.well-known/oauth-authorization-server" && req.method === "GET") { + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ + issuer: baseUrl.href, + authorization_endpoint: `${baseUrl.href}authorize`, + token_endpoint: `${baseUrl.href}token`, + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + })); + return; + } + + if (req.url === "/token" && req.method === "POST") { + // Handle token refresh request - return InvalidClientError + const error = new InvalidClientError("Client authentication failed"); + res.writeHead(400, { 'Content-Type': 'application/json' }) + .end(JSON.stringify(error.toResponseObject())); + return; + } + + if (req.url !== "/") { + res.writeHead(404).end(); + return; + } + res.writeHead(401).end(); + }); + + await new Promise(resolve => { + server.listen(0, "127.0.0.1", () => { + const addr = server.address() as AddressInfo; + baseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); + + transport = new SSEClientTransport(baseUrl, { + authProvider: mockAuthProvider, + }); + + await expect(() => transport.start()).rejects.toThrow(InvalidClientError); + expect(mockAuthProvider.invalidateCredentials).toHaveBeenCalledWith('all'); + }); + + it("invalidates all credentials on UnauthorizedClientError during token refresh", async () => { + // Mock tokens() to return token with refresh token + mockAuthProvider.tokens.mockResolvedValue({ + access_token: "expired-token", + token_type: "Bearer", + refresh_token: "refresh-token" + }); + + let baseUrl = resourceBaseUrl; + + const server = createServer((req, res) => { + lastServerRequest = req; + + // Handle OAuth metadata discovery + if (req.url === "/.well-known/oauth-authorization-server" && req.method === "GET") { + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ + issuer: baseUrl.href, + authorization_endpoint: `${baseUrl.href}authorize`, + token_endpoint: `${baseUrl.href}token`, + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + })); + return; + } + + if (req.url === "/token" && req.method === "POST") { + // Handle token refresh request - return UnauthorizedClientError + const error = new UnauthorizedClientError("Client not authorized"); + res.writeHead(400, { 'Content-Type': 'application/json' }) + .end(JSON.stringify(error.toResponseObject())); + return; + } + + if (req.url !== "/") { + res.writeHead(404).end(); + return; + } + res.writeHead(401).end(); + }); + + await new Promise(resolve => { + server.listen(0, "127.0.0.1", () => { + const addr = server.address() as AddressInfo; + baseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); + + transport = new SSEClientTransport(baseUrl, { + authProvider: mockAuthProvider, + }); + + await expect(() => transport.start()).rejects.toThrow(UnauthorizedClientError); + expect(mockAuthProvider.invalidateCredentials).toHaveBeenCalledWith('all'); + }); + + it("invalidates tokens on InvalidGrantError during token refresh", async () => { + // Mock tokens() to return token with refresh token + mockAuthProvider.tokens.mockResolvedValue({ + access_token: "expired-token", + token_type: "Bearer", + refresh_token: "refresh-token" + }); + let baseUrl = resourceBaseUrl; + + const server = createServer((req, res) => { + lastServerRequest = req; + + // Handle OAuth metadata discovery + if (req.url === "/.well-known/oauth-authorization-server" && req.method === "GET") { + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ + issuer: baseUrl.href, + authorization_endpoint: `${baseUrl.href}authorize`, + token_endpoint: `${baseUrl.href}token`, + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + })); + return; + } + + if (req.url === "/token" && req.method === "POST") { + // Handle token refresh request - return InvalidGrantError + const error = new InvalidGrantError("Invalid refresh token"); + res.writeHead(400, { 'Content-Type': 'application/json' }) + .end(JSON.stringify(error.toResponseObject())); + return; + } + + if (req.url !== "/") { + res.writeHead(404).end(); + return; + } + res.writeHead(401).end(); + }); + + await new Promise(resolve => { + server.listen(0, "127.0.0.1", () => { + const addr = server.address() as AddressInfo; + baseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); + + transport = new SSEClientTransport(baseUrl, { + authProvider: mockAuthProvider, + }); + + await expect(() => transport.start()).rejects.toThrow(InvalidGrantError); + expect(mockAuthProvider.invalidateCredentials).toHaveBeenCalledWith('tokens'); + }); }); }); diff --git a/src/client/sse.ts b/src/client/sse.ts index faffecc4..568a5159 100644 --- a/src/client/sse.ts +++ b/src/client/sse.ts @@ -106,10 +106,8 @@ export class SSEClientTransport implements Transport { return await this._startOrAuth(); } - private async _commonHeaders(): Promise { - const headers = { - ...this._requestInit?.headers, - } as HeadersInit & Record; + private async _commonHeaders(): Promise { + const headers: HeadersInit = {}; if (this._authProvider) { const tokens = await this._authProvider.tokens(); if (tokens) { @@ -120,24 +118,24 @@ export class SSEClientTransport implements Transport { headers["mcp-protocol-version"] = this._protocolVersion; } - return headers; + return new Headers( + { ...headers, ...this._requestInit?.headers } + ); } private _startOrAuth(): Promise { -const fetchImpl = (this?._eventSourceInit?.fetch ?? this._fetch ?? fetch) as typeof fetch + const fetchImpl = (this?._eventSourceInit?.fetch ?? this._fetch ?? fetch) as typeof fetch return new Promise((resolve, reject) => { this._eventSource = new EventSource( this._url.href, { ...this._eventSourceInit, fetch: async (url, init) => { - const headers = await this._commonHeaders() + const headers = await this._commonHeaders(); + headers.set("Accept", "text/event-stream"); const response = await fetchImpl(url, { ...init, - headers: new Headers({ - ...headers, - Accept: "text/event-stream" - }) + headers, }) if (response.status === 401 && response.headers.has('www-authenticate')) { @@ -238,8 +236,7 @@ const fetchImpl = (this?._eventSourceInit?.fetch ?? this._fetch ?? fetch) as typ } try { - const commonHeaders = await this._commonHeaders(); - const headers = new Headers(commonHeaders); + const headers = await this._commonHeaders(); headers.set("content-type", "application/json"); const init = { ...this._requestInit, diff --git a/src/client/stdio.test.ts b/src/client/stdio.test.ts index b2132446..2e4d92c2 100644 --- a/src/client/stdio.test.ts +++ b/src/client/stdio.test.ts @@ -1,10 +1,17 @@ import { JSONRPCMessage } from "../types.js"; import { StdioClientTransport, StdioServerParameters } from "./stdio.js"; -const serverParameters: StdioServerParameters = { - command: "/usr/bin/tee", +// Configure default server parameters based on OS +// Uses 'more' command for Windows and 'tee' command for Unix/Linux +const getDefaultServerParameters = (): StdioServerParameters => { + if (process.platform === "win32") { + return { command: "more" }; + } + return { command: "/usr/bin/tee" }; }; +const serverParameters = getDefaultServerParameters(); + test("should start then close cleanly", async () => { const client = new StdioClientTransport(serverParameters); client.onerror = (error) => { diff --git a/src/client/stdio.ts b/src/client/stdio.ts index e9c9fa8f..62292ce1 100644 --- a/src/client/stdio.ts +++ b/src/client/stdio.ts @@ -122,7 +122,11 @@ export class StdioClientTransport implements Transport { this._serverParams.command, this._serverParams.args ?? [], { - env: this._serverParams.env ?? getDefaultEnvironment(), + // merge default env with server env because mcp server needs some env vars + env: { + ...getDefaultEnvironment(), + ...this._serverParams.env, + }, stdio: ["pipe", "pipe", this._serverParams.stderr ?? "inherit"], shell: false, signal: this._abortController.signal, diff --git a/src/client/streamableHttp.test.ts b/src/client/streamableHttp.test.ts index c2c48366..a731bd96 100644 --- a/src/client/streamableHttp.test.ts +++ b/src/client/streamableHttp.test.ts @@ -17,6 +17,7 @@ describe("StreamableHTTPClientTransport", () => { redirectToAuthorization: jest.fn(), saveCodeVerifier: jest.fn(), codeVerifier: jest.fn(), + invalidateCredentials: jest.fn(), }; transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { authProvider: mockAuthProvider }); jest.spyOn(global, "fetch"); @@ -592,6 +593,7 @@ describe("StreamableHTTPClientTransport", () => { await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1); }); + describe('Reconnection Logic', () => { let transport: StreamableHTTPClientTransport; diff --git a/src/examples/server/demoInMemoryOAuthProvider.ts b/src/examples/server/demoInMemoryOAuthProvider.ts index 274a504a..c83748d3 100644 --- a/src/examples/server/demoInMemoryOAuthProvider.ts +++ b/src/examples/server/demoInMemoryOAuthProvider.ts @@ -200,7 +200,11 @@ export const setupAuthServer = ({authServerUrl, mcpServerUrl, strictResource}: { const auth_port = authServerUrl.port; // Start the auth server - authApp.listen(auth_port, () => { + authApp.listen(auth_port, (error) => { + if (error) { + console.error('Failed to start server:', error); + process.exit(1); + } console.log(`OAuth Authorization Server listening on port ${auth_port}`); }); diff --git a/src/examples/server/jsonResponseStreamableHttp.ts b/src/examples/server/jsonResponseStreamableHttp.ts index 02d8c2de..d6501d27 100644 --- a/src/examples/server/jsonResponseStreamableHttp.ts +++ b/src/examples/server/jsonResponseStreamableHttp.ts @@ -4,6 +4,7 @@ import { McpServer } from '../../server/mcp.js'; import { StreamableHTTPServerTransport } from '../../server/streamableHttp.js'; import { z } from 'zod'; import { CallToolResult, isInitializeRequest } from '../../types.js'; +import cors from 'cors'; // Create an MCP server with implementation details @@ -81,6 +82,12 @@ const getServer = () => { const app = express(); app.use(express.json()); +// Configure CORS to expose Mcp-Session-Id header for browser-based clients +app.use(cors({ + origin: '*', // Allow all origins - adjust as needed for production + exposedHeaders: ['Mcp-Session-Id'] +})); + // Map to store transports by session ID const transports: { [sessionId: string]: StreamableHTTPServerTransport } = {}; @@ -151,7 +158,11 @@ app.get('/mcp', async (req: Request, res: Response) => { // Start the server const PORT = 3000; -app.listen(PORT, () => { +app.listen(PORT, (error) => { + if (error) { + console.error('Failed to start server:', error); + process.exit(1); + } console.log(`MCP Streamable HTTP Server listening on port ${PORT}`); }); diff --git a/src/examples/server/mcpServerOutputSchema.ts b/src/examples/server/mcpServerOutputSchema.ts index de3b363e..75bfe690 100644 --- a/src/examples/server/mcpServerOutputSchema.ts +++ b/src/examples/server/mcpServerOutputSchema.ts @@ -43,14 +43,7 @@ server.registerTool( void country; // Simulate weather API call const temp_c = Math.round((Math.random() * 35 - 5) * 10) / 10; - const conditionCandidates = [ - "sunny", - "cloudy", - "rainy", - "stormy", - "snowy", - ] as const; - const conditions = conditionCandidates[Math.floor(Math.random() * conditionCandidates.length)]; + const conditions = ["sunny", "cloudy", "rainy", "stormy", "snowy"][Math.floor(Math.random() * 5)]; const structuredContent = { temperature: { @@ -84,4 +77,4 @@ async function main() { main().catch((error) => { console.error("Server error:", error); process.exit(1); -}); +}); \ No newline at end of file diff --git a/src/examples/server/simpleSseServer.ts b/src/examples/server/simpleSseServer.ts index c3417920..f8bdd466 100644 --- a/src/examples/server/simpleSseServer.ts +++ b/src/examples/server/simpleSseServer.ts @@ -145,7 +145,11 @@ app.post('/messages', async (req: Request, res: Response) => { // Start the server const PORT = 3000; -app.listen(PORT, () => { +app.listen(PORT, (error) => { + if (error) { + console.error('Failed to start server:', error); + process.exit(1); + } console.log(`Simple SSE Server (deprecated protocol version 2024-11-05) listening on port ${PORT}`); }); diff --git a/src/examples/server/simpleStatelessStreamableHttp.ts b/src/examples/server/simpleStatelessStreamableHttp.ts index 6fb2ae83..b5a1e291 100644 --- a/src/examples/server/simpleStatelessStreamableHttp.ts +++ b/src/examples/server/simpleStatelessStreamableHttp.ts @@ -3,6 +3,7 @@ import { McpServer } from '../../server/mcp.js'; import { StreamableHTTPServerTransport } from '../../server/streamableHttp.js'; import { z } from 'zod'; import { CallToolResult, GetPromptResult, ReadResourceResult } from '../../types.js'; +import cors from 'cors'; const getServer = () => { // Create an MCP server with implementation details @@ -96,6 +97,12 @@ const getServer = () => { const app = express(); app.use(express.json()); +// Configure CORS to expose Mcp-Session-Id header for browser-based clients +app.use(cors({ + origin: '*', // Allow all origins - adjust as needed for production + exposedHeaders: ['Mcp-Session-Id'] +})); + app.post('/mcp', async (req: Request, res: Response) => { const server = getServer(); try { @@ -151,7 +158,11 @@ app.delete('/mcp', async (req: Request, res: Response) => { // Start the server const PORT = 3000; -app.listen(PORT, () => { +app.listen(PORT, (error) => { + if (error) { + console.error('Failed to start server:', error); + process.exit(1); + } console.log(`MCP Stateless Streamable HTTP Server listening on port ${PORT}`); }); diff --git a/src/examples/server/simpleStreamableHttp.ts b/src/examples/server/simpleStreamableHttp.ts index 3d523543..98f9d351 100644 --- a/src/examples/server/simpleStreamableHttp.ts +++ b/src/examples/server/simpleStreamableHttp.ts @@ -11,6 +11,8 @@ import { setupAuthServer } from './demoInMemoryOAuthProvider.js'; import { OAuthMetadata } from 'src/shared/auth.js'; import { checkResourceAllowed } from 'src/shared/auth-utils.js'; +import cors from 'cors'; + // Check for OAuth flag const useOAuth = process.argv.includes('--oauth'); const strictOAuth = process.argv.includes('--oauth-strict'); @@ -420,12 +422,18 @@ const getServer = () => { return server; }; -const MCP_PORT = 3000; -const AUTH_PORT = 3001; +const MCP_PORT = process.env.MCP_PORT ? parseInt(process.env.MCP_PORT, 10) : 3000; +const AUTH_PORT = process.env.MCP_AUTH_PORT ? parseInt(process.env.MCP_AUTH_PORT, 10) : 3001; const app = express(); app.use(express.json()); +// Allow CORS all domains, expose the Mcp-Session-Id header +app.use(cors({ + origin: '*', // Allow all origins + exposedHeaders: ["Mcp-Session-Id"] +})); + // Set up OAuth if enabled let authMiddleware = null; if (useOAuth) { @@ -640,7 +648,11 @@ if (useOAuth && authMiddleware) { app.delete('/mcp', mcpDeleteHandler); } -app.listen(MCP_PORT, () => { +app.listen(MCP_PORT, (error) => { + if (error) { + console.error('Failed to start server:', error); + process.exit(1); + } console.log(`MCP Streamable HTTP Server listening on port ${MCP_PORT}`); }); diff --git a/src/examples/server/sseAndStreamableHttpCompatibleServer.ts b/src/examples/server/sseAndStreamableHttpCompatibleServer.ts index ded110a1..e097ca70 100644 --- a/src/examples/server/sseAndStreamableHttpCompatibleServer.ts +++ b/src/examples/server/sseAndStreamableHttpCompatibleServer.ts @@ -6,6 +6,7 @@ import { SSEServerTransport } from '../../server/sse.js'; import { z } from 'zod'; import { CallToolResult, isInitializeRequest } from '../../types.js'; import { InMemoryEventStore } from '../shared/inMemoryEventStore.js'; +import cors from 'cors'; /** * This example server demonstrates backwards compatibility with both: @@ -71,6 +72,12 @@ const getServer = () => { const app = express(); app.use(express.json()); +// Configure CORS to expose Mcp-Session-Id header for browser-based clients +app.use(cors({ + origin: '*', // Allow all origins - adjust as needed for production + exposedHeaders: ['Mcp-Session-Id'] +})); + // Store transports by session ID const transports: Record = {}; @@ -203,7 +210,11 @@ app.post("/messages", async (req: Request, res: Response) => { // Start the server const PORT = 3000; -app.listen(PORT, () => { +app.listen(PORT, (error) => { + if (error) { + console.error('Failed to start server:', error); + process.exit(1); + } console.log(`Backwards compatible MCP server listening on port ${PORT}`); console.log(` ============================================== diff --git a/src/examples/server/standaloneSseWithGetStreamableHttp.ts b/src/examples/server/standaloneSseWithGetStreamableHttp.ts index 8c8c3baa..27981813 100644 --- a/src/examples/server/standaloneSseWithGetStreamableHttp.ts +++ b/src/examples/server/standaloneSseWithGetStreamableHttp.ts @@ -112,7 +112,11 @@ app.get('/mcp', async (req: Request, res: Response) => { // Start the server const PORT = 3000; -app.listen(PORT, () => { +app.listen(PORT, (error) => { + if (error) { + console.error('Failed to start server:', error); + process.exit(1); + } console.log(`Server listening on port ${PORT}`); }); diff --git a/src/server/auth/clients.ts b/src/server/auth/clients.ts index 1b61a4de..3b9110d5 100644 --- a/src/server/auth/clients.ts +++ b/src/server/auth/clients.ts @@ -16,5 +16,5 @@ export interface OAuthRegisteredClientsStore { * * If unimplemented, dynamic client registration is unsupported. */ - registerClient?(client: OAuthClientInformationFull): OAuthClientInformationFull | Promise; + registerClient?(client: Omit): OAuthClientInformationFull | Promise; } \ No newline at end of file diff --git a/src/server/auth/errors.ts b/src/server/auth/errors.ts index 428199ce..791b3b86 100644 --- a/src/server/auth/errors.ts +++ b/src/server/auth/errors.ts @@ -4,8 +4,9 @@ import { OAuthErrorResponse } from "../../shared/auth.js"; * Base class for all OAuth errors */ export class OAuthError extends Error { + static errorCode: string; + constructor( - public readonly errorCode: string, message: string, public readonly errorUri?: string ) { @@ -28,6 +29,10 @@ export class OAuthError extends Error { return response; } + + get errorCode(): string { + return (this.constructor as typeof OAuthError).errorCode + } } /** @@ -36,9 +41,7 @@ export class OAuthError extends Error { * or is otherwise malformed. */ export class InvalidRequestError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("invalid_request", message, errorUri); - } + static errorCode = "invalid_request"; } /** @@ -46,9 +49,7 @@ export class InvalidRequestError extends OAuthError { * authentication included, or unsupported authentication method). */ export class InvalidClientError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("invalid_client", message, errorUri); - } + static errorCode = "invalid_client"; } /** @@ -57,9 +58,7 @@ export class InvalidClientError extends OAuthError { * authorization request, or was issued to another client. */ export class InvalidGrantError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("invalid_grant", message, errorUri); - } + static errorCode = "invalid_grant"; } /** @@ -67,9 +66,7 @@ export class InvalidGrantError extends OAuthError { * this authorization grant type. */ export class UnauthorizedClientError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("unauthorized_client", message, errorUri); - } + static errorCode = "unauthorized_client"; } /** @@ -77,9 +74,7 @@ export class UnauthorizedClientError extends OAuthError { * by the authorization server. */ export class UnsupportedGrantTypeError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("unsupported_grant_type", message, errorUri); - } + static errorCode = "unsupported_grant_type"; } /** @@ -87,18 +82,14 @@ export class UnsupportedGrantTypeError extends OAuthError { * exceeds the scope granted by the resource owner. */ export class InvalidScopeError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("invalid_scope", message, errorUri); - } + static errorCode = "invalid_scope"; } /** * Access denied error - The resource owner or authorization server denied the request. */ export class AccessDeniedError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("access_denied", message, errorUri); - } + static errorCode = "access_denied"; } /** @@ -106,9 +97,7 @@ export class AccessDeniedError extends OAuthError { * that prevented it from fulfilling the request. */ export class ServerError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("server_error", message, errorUri); - } + static errorCode = "server_error"; } /** @@ -116,9 +105,7 @@ export class ServerError extends OAuthError { * handle the request due to a temporary overloading or maintenance of the server. */ export class TemporarilyUnavailableError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("temporarily_unavailable", message, errorUri); - } + static errorCode = "temporarily_unavailable"; } /** @@ -126,9 +113,7 @@ export class TemporarilyUnavailableError extends OAuthError { * obtaining an authorization code using this method. */ export class UnsupportedResponseTypeError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("unsupported_response_type", message, errorUri); - } + static errorCode = "unsupported_response_type"; } /** @@ -136,9 +121,7 @@ export class UnsupportedResponseTypeError extends OAuthError { * the requested token type. */ export class UnsupportedTokenTypeError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("unsupported_token_type", message, errorUri); - } + static errorCode = "unsupported_token_type"; } /** @@ -146,9 +129,7 @@ export class UnsupportedTokenTypeError extends OAuthError { * or invalid for other reasons. */ export class InvalidTokenError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("invalid_token", message, errorUri); - } + static errorCode = "invalid_token"; } /** @@ -156,9 +137,7 @@ export class InvalidTokenError extends OAuthError { * (Custom, non-standard error) */ export class MethodNotAllowedError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("method_not_allowed", message, errorUri); - } + static errorCode = "method_not_allowed"; } /** @@ -166,9 +145,7 @@ export class MethodNotAllowedError extends OAuthError { * (Custom, non-standard error based on RFC 6585) */ export class TooManyRequestsError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("too_many_requests", message, errorUri); - } + static errorCode = "too_many_requests"; } /** @@ -176,16 +153,47 @@ export class TooManyRequestsError extends OAuthError { * (Custom error for dynamic client registration - RFC 7591) */ export class InvalidClientMetadataError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("invalid_client_metadata", message, errorUri); - } + static errorCode = "invalid_client_metadata"; } /** * Insufficient scope error - The request requires higher privileges than provided by the access token. */ export class InsufficientScopeError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("insufficient_scope", message, errorUri); + static errorCode = "insufficient_scope"; +} + +/** + * A utility class for defining one-off error codes + */ +export class CustomOAuthError extends OAuthError { + constructor(private readonly customErrorCode: string, message: string, errorUri?: string) { + super(message, errorUri); + } + + get errorCode(): string { + return this.customErrorCode; } } + +/** + * A full list of all OAuthErrors, enabling parsing from error responses + */ +export const OAUTH_ERRORS = { + [InvalidRequestError.errorCode]: InvalidRequestError, + [InvalidClientError.errorCode]: InvalidClientError, + [InvalidGrantError.errorCode]: InvalidGrantError, + [UnauthorizedClientError.errorCode]: UnauthorizedClientError, + [UnsupportedGrantTypeError.errorCode]: UnsupportedGrantTypeError, + [InvalidScopeError.errorCode]: InvalidScopeError, + [AccessDeniedError.errorCode]: AccessDeniedError, + [ServerError.errorCode]: ServerError, + [TemporarilyUnavailableError.errorCode]: TemporarilyUnavailableError, + [UnsupportedResponseTypeError.errorCode]: UnsupportedResponseTypeError, + [UnsupportedTokenTypeError.errorCode]: UnsupportedTokenTypeError, + [InvalidTokenError.errorCode]: InvalidTokenError, + [MethodNotAllowedError.errorCode]: MethodNotAllowedError, + [TooManyRequestsError.errorCode]: TooManyRequestsError, + [InvalidClientMetadataError.errorCode]: InvalidClientMetadataError, + [InsufficientScopeError.errorCode]: InsufficientScopeError, +} as const; diff --git a/src/server/auth/handlers/register.test.ts b/src/server/auth/handlers/register.test.ts index a961f654..d95e6d82 100644 --- a/src/server/auth/handlers/register.test.ts +++ b/src/server/auth/handlers/register.test.ts @@ -218,6 +218,26 @@ describe('Client Registration Handler', () => { expect(response.body.client_secret_expires_at).toBe(0); }); + it('sets no client_id when clientIdGeneration=false', async () => { + // Create handler with no expiry + const customApp = express(); + const options: ClientRegistrationHandlerOptions = { + clientsStore: mockClientStoreWithRegistration, + clientIdGeneration: false + }; + + customApp.use('/register', clientRegistrationHandler(options)); + + const response = await supertest(customApp) + .post('/register') + .send({ + redirect_uris: ['https://example.com/callback'] + }); + + expect(response.status).toBe(201); + expect(response.body.client_id).toBeUndefined(); + }); + it('handles client with all metadata fields', async () => { const fullClientMetadata: OAuthClientMetadata = { redirect_uris: ['https://example.com/callback'], diff --git a/src/server/auth/handlers/register.ts b/src/server/auth/handlers/register.ts index c3137348..197e0053 100644 --- a/src/server/auth/handlers/register.ts +++ b/src/server/auth/handlers/register.ts @@ -31,6 +31,13 @@ export type ClientRegistrationHandlerOptions = { * Registration endpoints are particularly sensitive to abuse and should be rate limited. */ rateLimit?: Partial | false; + + /** + * Whether to generate a client ID before calling the client registration endpoint. + * + * If not set, defaults to true. + */ + clientIdGeneration?: boolean; }; const DEFAULT_CLIENT_SECRET_EXPIRY_SECONDS = 30 * 24 * 60 * 60; // 30 days @@ -38,7 +45,8 @@ const DEFAULT_CLIENT_SECRET_EXPIRY_SECONDS = 30 * 24 * 60 * 60; // 30 days export function clientRegistrationHandler({ clientsStore, clientSecretExpirySeconds = DEFAULT_CLIENT_SECRET_EXPIRY_SECONDS, - rateLimit: rateLimitConfig + rateLimit: rateLimitConfig, + clientIdGeneration = true, }: ClientRegistrationHandlerOptions): RequestHandler { if (!clientsStore.registerClient) { throw new Error("Client registration store does not support registering clients"); @@ -78,7 +86,6 @@ export function clientRegistrationHandler({ const isPublicClient = clientMetadata.token_endpoint_auth_method === 'none' // Generate client credentials - const clientId = crypto.randomUUID(); const clientSecret = isPublicClient ? undefined : crypto.randomBytes(32).toString('hex'); @@ -89,14 +96,17 @@ export function clientRegistrationHandler({ const secretExpiryTime = clientsDoExpire ? clientIdIssuedAt + clientSecretExpirySeconds : 0 const clientSecretExpiresAt = isPublicClient ? undefined : secretExpiryTime - let clientInfo: OAuthClientInformationFull = { + let clientInfo: Omit & { client_id?: string } = { ...clientMetadata, - client_id: clientId, client_secret: clientSecret, client_id_issued_at: clientIdIssuedAt, client_secret_expires_at: clientSecretExpiresAt, }; + if (clientIdGeneration) { + clientInfo.client_id = crypto.randomUUID(); + } + clientInfo = await clientsStore.registerClient!(clientInfo); res.status(201).json(clientInfo); } catch (error) { diff --git a/src/server/auth/handlers/token.test.ts b/src/server/auth/handlers/token.test.ts index 4b7fae02..946cc691 100644 --- a/src/server/auth/handlers/token.test.ts +++ b/src/server/auth/handlers/token.test.ts @@ -16,6 +16,18 @@ jest.mock('pkce-challenge', () => ({ }) })); +const mockTokens = { + access_token: 'mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'mock_refresh_token' +}; + +const mockTokensWithIdToken = { + ...mockTokens, + id_token: 'mock_id_token' +} + describe('Token Handler', () => { // Mock client data const validClient: OAuthClientInformationFull = { @@ -58,12 +70,7 @@ describe('Token Handler', () => { async exchangeAuthorizationCode(client: OAuthClientInformationFull, authorizationCode: string): Promise { if (authorizationCode === 'valid_code') { - return { - access_token: 'mock_access_token', - token_type: 'bearer', - expires_in: 3600, - refresh_token: 'mock_refresh_token' - }; + return mockTokens; } throw new InvalidGrantError('The authorization code is invalid or has expired'); }, @@ -291,18 +298,36 @@ describe('Token Handler', () => { ); }); + it('returns id token in code exchange if provided', async () => { + mockProvider.exchangeAuthorizationCode = async (client: OAuthClientInformationFull, authorizationCode: string): Promise => { + if (authorizationCode === 'valid_code') { + return mockTokensWithIdToken; + } + throw new InvalidGrantError('The authorization code is invalid or has expired'); + }; + + const response = await supertest(app) + .post('/token') + .type('form') + .send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + code: 'valid_code', + code_verifier: 'valid_verifier' + }); + + expect(response.status).toBe(200); + expect(response.body.id_token).toBe('mock_id_token'); + }); + it('passes through code verifier when using proxy provider', async () => { const originalFetch = global.fetch; try { global.fetch = jest.fn().mockResolvedValue({ ok: true, - json: () => Promise.resolve({ - access_token: 'mock_access_token', - token_type: 'bearer', - expires_in: 3600, - refresh_token: 'mock_refresh_token' - }) + json: () => Promise.resolve(mockTokens) }); const proxyProvider = new ProxyOAuthServerProvider({ @@ -359,12 +384,7 @@ describe('Token Handler', () => { try { global.fetch = jest.fn().mockResolvedValue({ ok: true, - json: () => Promise.resolve({ - access_token: 'mock_access_token', - token_type: 'bearer', - expires_in: 3600, - refresh_token: 'mock_refresh_token' - }) + json: () => Promise.resolve(mockTokens) }); const proxyProvider = new ProxyOAuthServerProvider({ diff --git a/src/server/auth/middleware/bearerAuth.test.ts b/src/server/auth/middleware/bearerAuth.test.ts index b8953e5c..38639b1d 100644 --- a/src/server/auth/middleware/bearerAuth.test.ts +++ b/src/server/auth/middleware/bearerAuth.test.ts @@ -1,7 +1,7 @@ import { Request, Response } from "express"; import { requireBearerAuth } from "./bearerAuth.js"; import { AuthInfo } from "../types.js"; -import { InsufficientScopeError, InvalidTokenError, OAuthError, ServerError } from "../errors.js"; +import { InsufficientScopeError, InvalidTokenError, CustomOAuthError, ServerError } from "../errors.js"; import { OAuthTokenVerifier } from "../provider.js"; // Mock verifier @@ -37,6 +37,7 @@ describe("requireBearerAuth middleware", () => { token: "valid-token", clientId: "client-123", scopes: ["read", "write"], + expiresAt: Math.floor(Date.now() / 1000) + 3600, // Token expires in an hour }; mockVerifyAccessToken.mockResolvedValue(validAuthInfo); @@ -53,13 +54,17 @@ describe("requireBearerAuth middleware", () => { expect(mockResponse.status).not.toHaveBeenCalled(); expect(mockResponse.json).not.toHaveBeenCalled(); }); - - it("should reject expired tokens", async () => { + + it.each([ + [100], // Token expired 100 seconds ago + [0], // Token expires at the same time as now + ])("should reject expired tokens (expired %s seconds ago)", async (expiredSecondsAgo: number) => { + const expiresAt = Math.floor(Date.now() / 1000) - expiredSecondsAgo; const expiredAuthInfo: AuthInfo = { token: "expired-token", clientId: "client-123", scopes: ["read", "write"], - expiresAt: Math.floor(Date.now() / 1000) - 100, // Token expired 100 seconds ago + expiresAt }; mockVerifyAccessToken.mockResolvedValue(expiredAuthInfo); @@ -82,6 +87,37 @@ describe("requireBearerAuth middleware", () => { expect(nextFunction).not.toHaveBeenCalled(); }); + it.each([ + [undefined], // Token has no expiration time + [NaN], // Token has no expiration time + ])("should reject tokens with no expiration time (expiresAt: %s)", async (expiresAt: number | undefined) => { + const noExpirationAuthInfo: AuthInfo = { + token: "no-expiration-token", + clientId: "client-123", + scopes: ["read", "write"], + expiresAt + }; + mockVerifyAccessToken.mockResolvedValue(noExpirationAuthInfo); + + mockRequest.headers = { + authorization: "Bearer expired-token", + }; + + const middleware = requireBearerAuth({ verifier: mockVerifier }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockVerifyAccessToken).toHaveBeenCalledWith("expired-token"); + expect(mockResponse.status).toHaveBeenCalledWith(401); + expect(mockResponse.set).toHaveBeenCalledWith( + "WWW-Authenticate", + expect.stringContaining('Bearer error="invalid_token"') + ); + expect(mockResponse.json).toHaveBeenCalledWith( + expect.objectContaining({ error: "invalid_token", error_description: "Token has no expiration time" }) + ); + expect(nextFunction).not.toHaveBeenCalled(); + }); + it("should accept non-expired tokens", async () => { const nonExpiredAuthInfo: AuthInfo = { token: "valid-token", @@ -141,6 +177,7 @@ describe("requireBearerAuth middleware", () => { token: "valid-token", clientId: "client-123", scopes: ["read", "write", "admin"], + expiresAt: Math.floor(Date.now() / 1000) + 3600, // Token expires in an hour }; mockVerifyAccessToken.mockResolvedValue(authInfo); @@ -268,7 +305,7 @@ describe("requireBearerAuth middleware", () => { authorization: "Bearer valid-token", }; - mockVerifyAccessToken.mockRejectedValue(new OAuthError("custom_error", "Some OAuth error")); + mockVerifyAccessToken.mockRejectedValue(new CustomOAuthError("custom_error", "Some OAuth error")); const middleware = requireBearerAuth({ verifier: mockVerifier }); await middleware(mockRequest as Request, mockResponse as Response, nextFunction); diff --git a/src/server/auth/middleware/bearerAuth.ts b/src/server/auth/middleware/bearerAuth.ts index 91f763a9..7b6d8f61 100644 --- a/src/server/auth/middleware/bearerAuth.ts +++ b/src/server/auth/middleware/bearerAuth.ts @@ -63,8 +63,10 @@ export function requireBearerAuth({ verifier, requiredScopes = [], resourceMetad } } - // Check if the token is expired - if (!!authInfo.expiresAt && authInfo.expiresAt < Date.now() / 1000) { + // Check if the token is set to expire or if it is expired + if (typeof authInfo.expiresAt !== 'number' || isNaN(authInfo.expiresAt)) { + throw new InvalidTokenError("Token has no expiration time"); + } else if (authInfo.expiresAt < Date.now() / 1000) { throw new InvalidTokenError("Token has expired"); } diff --git a/src/server/auth/router.ts b/src/server/auth/router.ts index 3e752e7a..a06bf73a 100644 --- a/src/server/auth/router.ts +++ b/src/server/auth/router.ts @@ -142,7 +142,7 @@ export function mcpAuthRouter(options: AuthRouterOptions): RequestHandler { new URL(oauthMetadata.registration_endpoint).pathname, clientRegistrationHandler({ clientsStore: options.provider.clientsStore, - ...options, + ...options.clientRegistrationOptions, }) ); } diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts index dc96a1b0..10e550df 100644 --- a/src/server/mcp.test.ts +++ b/src/server/mcp.test.ts @@ -1312,7 +1312,7 @@ describe("tool()", () => { resultType: "structured", // Missing required 'timestamp' field someExtraField: "unexpected" // Extra field not in schema - } as unknown as { processedInput: string; resultType: string; timestamp: string }, // Type assertion to bypass TypeScript validation for testing purposes + }, }) ); diff --git a/src/server/mcp.ts b/src/server/mcp.ts index a5624e15..791facef 100644 --- a/src/server/mcp.ts +++ b/src/server/mcp.ts @@ -169,7 +169,7 @@ export class McpServer { } const args = parseResult.data; - const cb = tool.callback as ToolCallback; + const cb = tool.callback as ToolCallback; try { result = await Promise.resolve(cb(args, extra)); } catch (error) { @@ -184,7 +184,7 @@ export class McpServer { }; } } else { - const cb = tool.callback as ToolCallback; + const cb = tool.callback as ToolCallback; try { result = await Promise.resolve(cb(extra)); } catch (error) { @@ -772,7 +772,7 @@ export class McpServer { inputSchema: ZodRawShape | undefined, outputSchema: ZodRawShape | undefined, annotations: ToolAnnotations | undefined, - callback: ToolCallback + callback: ToolCallback ): RegisteredTool { const registeredTool: RegisteredTool = { title, @@ -929,7 +929,7 @@ export class McpServer { outputSchema?: OutputArgs; annotations?: ToolAnnotations; }, - cb: ToolCallback + cb: ToolCallback ): RegisteredTool { if (this._registeredTools[name]) { throw new Error(`Tool ${name} is already registered`); @@ -944,7 +944,7 @@ export class McpServer { inputSchema, outputSchema, annotations, - cb as ToolCallback + cb as ToolCallback ); } @@ -1138,16 +1138,6 @@ export class ResourceTemplate { } } -/** - * Type helper to create a strongly-typed CallToolResult with structuredContent - */ -type TypedCallToolResult = - OutputArgs extends ZodRawShape - ? CallToolResult & { - structuredContent?: z.objectOutputType; - } - : CallToolResult; - /** * Callback for a tool handler registered with Server.tool(). * @@ -1158,21 +1148,13 @@ type TypedCallToolResult = * - `content` if the tool does not have an outputSchema * - Both fields are optional but typically one should be provided */ -export type ToolCallback< - InputArgs extends undefined | ZodRawShape = undefined, - OutputArgs extends undefined | ZodRawShape = undefined -> = InputArgs extends ZodRawShape +export type ToolCallback = + Args extends ZodRawShape ? ( - args: z.objectOutputType, - extra: RequestHandlerExtra - ) => - | TypedCallToolResult - | Promise> - : ( - extra: RequestHandlerExtra - ) => - | TypedCallToolResult - | Promise>; + args: z.objectOutputType, + extra: RequestHandlerExtra, + ) => CallToolResult | Promise + : (extra: RequestHandlerExtra) => CallToolResult | Promise; export type RegisteredTool = { title?: string; @@ -1180,24 +1162,22 @@ export type RegisteredTool = { inputSchema?: AnyZodObject; outputSchema?: AnyZodObject; annotations?: ToolAnnotations; - callback: ToolCallback; + callback: ToolCallback; enabled: boolean; enable(): void; disable(): void; - update< - InputArgs extends ZodRawShape, - OutputArgs extends ZodRawShape - >(updates: { - name?: string | null; - title?: string; - description?: string; - paramsSchema?: InputArgs; - outputSchema?: OutputArgs; - annotations?: ToolAnnotations; - callback?: ToolCallback - enabled?: boolean - }): void; - remove(): void; + update( + updates: { + name?: string | null, + title?: string, + description?: string, + paramsSchema?: InputArgs, + outputSchema?: OutputArgs, + annotations?: ToolAnnotations, + callback?: ToolCallback, + enabled?: boolean + }): void + remove(): void }; const EMPTY_OBJECT_JSON_SCHEMA = { diff --git a/src/server/streamableHttp.test.ts b/src/server/streamableHttp.test.ts index 502435ea..3a0a5c06 100644 --- a/src/server/streamableHttp.test.ts +++ b/src/server/streamableHttp.test.ts @@ -29,6 +29,8 @@ interface TestServerConfig { enableJsonResponse?: boolean; customRequestHandler?: (req: IncomingMessage, res: ServerResponse, parsedBody?: unknown) => Promise; eventStore?: EventStore; + onsessioninitialized?: (sessionId: string) => void | Promise; + onsessionclosed?: (sessionId: string) => void | Promise; } /** @@ -57,7 +59,9 @@ async function createTestServer(config: TestServerConfig = { sessionIdGenerator: const transport = new StreamableHTTPServerTransport({ sessionIdGenerator: config.sessionIdGenerator, enableJsonResponse: config.enableJsonResponse ?? false, - eventStore: config.eventStore + eventStore: config.eventStore, + onsessioninitialized: config.onsessioninitialized, + onsessionclosed: config.onsessionclosed }); await mcpServer.connect(transport); @@ -111,7 +115,9 @@ async function createTestAuthServer(config: TestServerConfig = { sessionIdGenera const transport = new StreamableHTTPServerTransport({ sessionIdGenerator: config.sessionIdGenerator, enableJsonResponse: config.enableJsonResponse ?? false, - eventStore: config.eventStore + eventStore: config.eventStore, + onsessioninitialized: config.onsessioninitialized, + onsessionclosed: config.onsessionclosed }); await mcpServer.connect(transport); @@ -1504,6 +1510,372 @@ describe("StreamableHTTPServerTransport in stateless mode", () => { }); }); +// Test onsessionclosed callback +describe("StreamableHTTPServerTransport onsessionclosed callback", () => { + it("should call onsessionclosed callback when session is closed via DELETE", async () => { + const mockCallback = jest.fn(); + + // Create server with onsessionclosed callback + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessionclosed: mockCallback, + }); + + const tempServer = result.server; + const tempUrl = result.baseUrl; + + // Initialize to get a session ID + const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + const tempSessionId = initResponse.headers.get("mcp-session-id"); + expect(tempSessionId).toBeDefined(); + + // DELETE the session + const deleteResponse = await fetch(tempUrl, { + method: "DELETE", + headers: { + "mcp-session-id": tempSessionId || "", + "mcp-protocol-version": "2025-03-26", + }, + }); + + expect(deleteResponse.status).toBe(200); + expect(mockCallback).toHaveBeenCalledWith(tempSessionId); + expect(mockCallback).toHaveBeenCalledTimes(1); + + // Clean up + tempServer.close(); + }); + + it("should not call onsessionclosed callback when not provided", async () => { + // Create server without onsessionclosed callback + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + }); + + const tempServer = result.server; + const tempUrl = result.baseUrl; + + // Initialize to get a session ID + const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + const tempSessionId = initResponse.headers.get("mcp-session-id"); + + // DELETE the session - should not throw error + const deleteResponse = await fetch(tempUrl, { + method: "DELETE", + headers: { + "mcp-session-id": tempSessionId || "", + "mcp-protocol-version": "2025-03-26", + }, + }); + + expect(deleteResponse.status).toBe(200); + + // Clean up + tempServer.close(); + }); + + it("should not call onsessionclosed callback for invalid session DELETE", async () => { + const mockCallback = jest.fn(); + + // Create server with onsessionclosed callback + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessionclosed: mockCallback, + }); + + const tempServer = result.server; + const tempUrl = result.baseUrl; + + // Initialize to get a valid session + await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + + // Try to DELETE with invalid session ID + const deleteResponse = await fetch(tempUrl, { + method: "DELETE", + headers: { + "mcp-session-id": "invalid-session-id", + "mcp-protocol-version": "2025-03-26", + }, + }); + + expect(deleteResponse.status).toBe(404); + expect(mockCallback).not.toHaveBeenCalled(); + + // Clean up + tempServer.close(); + }); + + it("should call onsessionclosed callback with correct session ID when multiple sessions exist", async () => { + const mockCallback = jest.fn(); + + // Create first server + const result1 = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessionclosed: mockCallback, + }); + + const server1 = result1.server; + const url1 = result1.baseUrl; + + // Create second server + const result2 = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessionclosed: mockCallback, + }); + + const server2 = result2.server; + const url2 = result2.baseUrl; + + // Initialize both servers + const initResponse1 = await sendPostRequest(url1, TEST_MESSAGES.initialize); + const sessionId1 = initResponse1.headers.get("mcp-session-id"); + + const initResponse2 = await sendPostRequest(url2, TEST_MESSAGES.initialize); + const sessionId2 = initResponse2.headers.get("mcp-session-id"); + + expect(sessionId1).toBeDefined(); + expect(sessionId2).toBeDefined(); + expect(sessionId1).not.toBe(sessionId2); + + // DELETE first session + const deleteResponse1 = await fetch(url1, { + method: "DELETE", + headers: { + "mcp-session-id": sessionId1 || "", + "mcp-protocol-version": "2025-03-26", + }, + }); + + expect(deleteResponse1.status).toBe(200); + expect(mockCallback).toHaveBeenCalledWith(sessionId1); + expect(mockCallback).toHaveBeenCalledTimes(1); + + // DELETE second session + const deleteResponse2 = await fetch(url2, { + method: "DELETE", + headers: { + "mcp-session-id": sessionId2 || "", + "mcp-protocol-version": "2025-03-26", + }, + }); + + expect(deleteResponse2.status).toBe(200); + expect(mockCallback).toHaveBeenCalledWith(sessionId2); + expect(mockCallback).toHaveBeenCalledTimes(2); + + // Clean up + server1.close(); + server2.close(); + }); +}); + +// Test async callbacks for onsessioninitialized and onsessionclosed +describe("StreamableHTTPServerTransport async callbacks", () => { + it("should support async onsessioninitialized callback", async () => { + const initializationOrder: string[] = []; + + // Create server with async onsessioninitialized callback + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessioninitialized: async (sessionId: string) => { + initializationOrder.push('async-start'); + // Simulate async operation + await new Promise(resolve => setTimeout(resolve, 10)); + initializationOrder.push('async-end'); + initializationOrder.push(sessionId); + }, + }); + + const tempServer = result.server; + const tempUrl = result.baseUrl; + + // Initialize to trigger the callback + const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + const tempSessionId = initResponse.headers.get("mcp-session-id"); + + // Give time for async callback to complete + await new Promise(resolve => setTimeout(resolve, 50)); + + expect(initializationOrder).toEqual(['async-start', 'async-end', tempSessionId]); + + // Clean up + tempServer.close(); + }); + + it("should support sync onsessioninitialized callback (backwards compatibility)", async () => { + const capturedSessionId: string[] = []; + + // Create server with sync onsessioninitialized callback + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessioninitialized: (sessionId: string) => { + capturedSessionId.push(sessionId); + }, + }); + + const tempServer = result.server; + const tempUrl = result.baseUrl; + + // Initialize to trigger the callback + const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + const tempSessionId = initResponse.headers.get("mcp-session-id"); + + expect(capturedSessionId).toEqual([tempSessionId]); + + // Clean up + tempServer.close(); + }); + + it("should support async onsessionclosed callback", async () => { + const closureOrder: string[] = []; + + // Create server with async onsessionclosed callback + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessionclosed: async (sessionId: string) => { + closureOrder.push('async-close-start'); + // Simulate async operation + await new Promise(resolve => setTimeout(resolve, 10)); + closureOrder.push('async-close-end'); + closureOrder.push(sessionId); + }, + }); + + const tempServer = result.server; + const tempUrl = result.baseUrl; + + // Initialize to get a session ID + const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + const tempSessionId = initResponse.headers.get("mcp-session-id"); + expect(tempSessionId).toBeDefined(); + + // DELETE the session + const deleteResponse = await fetch(tempUrl, { + method: "DELETE", + headers: { + "mcp-session-id": tempSessionId || "", + "mcp-protocol-version": "2025-03-26", + }, + }); + + expect(deleteResponse.status).toBe(200); + + // Give time for async callback to complete + await new Promise(resolve => setTimeout(resolve, 50)); + + expect(closureOrder).toEqual(['async-close-start', 'async-close-end', tempSessionId]); + + // Clean up + tempServer.close(); + }); + + it("should propagate errors from async onsessioninitialized callback", async () => { + const consoleErrorSpy = jest.spyOn(console, 'error').mockImplementation(); + + // Create server with async onsessioninitialized callback that throws + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessioninitialized: async (_sessionId: string) => { + throw new Error('Async initialization error'); + }, + }); + + const tempServer = result.server; + const tempUrl = result.baseUrl; + + // Initialize should fail when callback throws + const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + expect(initResponse.status).toBe(400); + + // Clean up + consoleErrorSpy.mockRestore(); + tempServer.close(); + }); + + it("should propagate errors from async onsessionclosed callback", async () => { + const consoleErrorSpy = jest.spyOn(console, 'error').mockImplementation(); + + // Create server with async onsessionclosed callback that throws + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessionclosed: async (_sessionId: string) => { + throw new Error('Async closure error'); + }, + }); + + const tempServer = result.server; + const tempUrl = result.baseUrl; + + // Initialize to get a session ID + const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + const tempSessionId = initResponse.headers.get("mcp-session-id"); + + // DELETE should fail when callback throws + const deleteResponse = await fetch(tempUrl, { + method: "DELETE", + headers: { + "mcp-session-id": tempSessionId || "", + "mcp-protocol-version": "2025-03-26", + }, + }); + + expect(deleteResponse.status).toBe(500); + + // Clean up + consoleErrorSpy.mockRestore(); + tempServer.close(); + }); + + it("should handle both async callbacks together", async () => { + const events: string[] = []; + + // Create server with both async callbacks + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessioninitialized: async (sessionId: string) => { + await new Promise(resolve => setTimeout(resolve, 5)); + events.push(`initialized:${sessionId}`); + }, + onsessionclosed: async (sessionId: string) => { + await new Promise(resolve => setTimeout(resolve, 5)); + events.push(`closed:${sessionId}`); + }, + }); + + const tempServer = result.server; + const tempUrl = result.baseUrl; + + // Initialize to trigger first callback + const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + const tempSessionId = initResponse.headers.get("mcp-session-id"); + + // Wait for async callback + await new Promise(resolve => setTimeout(resolve, 20)); + + expect(events).toContain(`initialized:${tempSessionId}`); + + // DELETE to trigger second callback + const deleteResponse = await fetch(tempUrl, { + method: "DELETE", + headers: { + "mcp-session-id": tempSessionId || "", + "mcp-protocol-version": "2025-03-26", + }, + }); + + expect(deleteResponse.status).toBe(200); + + // Wait for async callback + await new Promise(resolve => setTimeout(resolve, 20)); + + expect(events).toContain(`closed:${tempSessionId}`); + expect(events).toHaveLength(2); + + // Clean up + tempServer.close(); + }); +}); + // Test DNS rebinding protection describe("StreamableHTTPServerTransport DNS rebinding protection", () => { let server: Server; diff --git a/src/server/streamableHttp.ts b/src/server/streamableHttp.ts index 022d1a47..3bf84e43 100644 --- a/src/server/streamableHttp.ts +++ b/src/server/streamableHttp.ts @@ -47,7 +47,19 @@ export interface StreamableHTTPServerTransportOptions { * and need to keep track of them. * @param sessionId The generated session ID */ - onsessioninitialized?: (sessionId: string) => void; + onsessioninitialized?: (sessionId: string) => void | Promise; + + /** + * A callback for session close events + * This is called when the server closes a session due to a DELETE request. + * Useful in cases when you need to clean up resources associated with the session. + * Note that this is different from the transport closing, if you are handling + * HTTP requests from multiple nodes you might want to close each + * StreamableHTTPServerTransport after a request is completed while still keeping the + * session open/running. + * @param sessionId The session ID that was closed + */ + onsessionclosed?: (sessionId: string) => void | Promise; /** * If true, the server will return JSON responses instead of starting an SSE stream. @@ -126,7 +138,8 @@ export class StreamableHTTPServerTransport implements Transport { private _enableJsonResponse: boolean = false; private _standaloneSseStreamId: string = '_GET_stream'; private _eventStore?: EventStore; - private _onsessioninitialized?: (sessionId: string) => void; + private _onsessioninitialized?: (sessionId: string) => void | Promise; + private _onsessionclosed?: (sessionId: string) => void | Promise; private _allowedHosts?: string[]; private _allowedOrigins?: string[]; private _enableDnsRebindingProtection: boolean; @@ -141,6 +154,7 @@ export class StreamableHTTPServerTransport implements Transport { this._enableJsonResponse = options.enableJsonResponse ?? false; this._eventStore = options.eventStore; this._onsessioninitialized = options.onsessioninitialized; + this._onsessionclosed = options.onsessionclosed; this._allowedHosts = options.allowedHosts; this._allowedOrigins = options.allowedOrigins; this._enableDnsRebindingProtection = options.enableDnsRebindingProtection ?? false; @@ -446,7 +460,7 @@ export class StreamableHTTPServerTransport implements Transport { // If we have a session ID and an onsessioninitialized handler, call it immediately // This is needed in cases where the server needs to keep track of multiple sessions if (this.sessionId && this._onsessioninitialized) { - this._onsessioninitialized(this.sessionId); + await Promise.resolve(this._onsessioninitialized(this.sessionId)); } } @@ -538,6 +552,7 @@ export class StreamableHTTPServerTransport implements Transport { if (!this.validateProtocolVersion(req, res)) { return; } + await Promise.resolve(this._onsessionclosed?.(this.sessionId!)); await this.close(); res.writeHead(200).end(); } diff --git a/src/shared/auth.ts b/src/shared/auth.ts index b906de3d..467680a5 100644 --- a/src/shared/auth.ts +++ b/src/shared/auth.ts @@ -62,6 +62,7 @@ export const OAuthMetadataSchema = z export const OAuthTokensSchema = z .object({ access_token: z.string(), + id_token: z.string().optional(), // Optional for OAuth 2.1, but necessary in OpenID Connect token_type: z.string(), expires_in: z.number().optional(), scope: z.string().optional(), diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index b16db73f..f4e74c8b 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -466,6 +466,189 @@ describe("protocol tests", () => { await expect(requestPromise).resolves.toEqual({ result: "success" }); }); }); + + describe("Debounced Notifications", () => { + // We need to flush the microtask queue to test the debouncing logic. + // This helper function does that. + const flushMicrotasks = () => new Promise(resolve => setImmediate(resolve)); + + it("should NOT debounce a notification that has parameters", async () => { + // ARRANGE + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ debouncedNotificationMethods: ['test/debounced_with_params'] }); + await protocol.connect(transport); + + // ACT + // These notifications are configured for debouncing but contain params, so they should be sent immediately. + await protocol.notification({ method: 'test/debounced_with_params', params: { data: 1 } }); + await protocol.notification({ method: 'test/debounced_with_params', params: { data: 2 } }); + + // ASSERT + // Both should have been sent immediately to avoid data loss. + expect(sendSpy).toHaveBeenCalledTimes(2); + expect(sendSpy).toHaveBeenCalledWith(expect.objectContaining({ params: { data: 1 } }), undefined); + expect(sendSpy).toHaveBeenCalledWith(expect.objectContaining({ params: { data: 2 } }), undefined); + }); + + it("should NOT debounce a notification that has a relatedRequestId", async () => { + // ARRANGE + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ debouncedNotificationMethods: ['test/debounced_with_options'] }); + await protocol.connect(transport); + + // ACT + await protocol.notification({ method: 'test/debounced_with_options' }, { relatedRequestId: 'req-1' }); + await protocol.notification({ method: 'test/debounced_with_options' }, { relatedRequestId: 'req-2' }); + + // ASSERT + expect(sendSpy).toHaveBeenCalledTimes(2); + expect(sendSpy).toHaveBeenCalledWith(expect.any(Object), { relatedRequestId: 'req-1' }); + expect(sendSpy).toHaveBeenCalledWith(expect.any(Object), { relatedRequestId: 'req-2' }); + }); + + it("should clear pending debounced notifications on connection close", async () => { + // ARRANGE + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ debouncedNotificationMethods: ['test/debounced'] }); + await protocol.connect(transport); + + // ACT + // Schedule a notification but don't flush the microtask queue. + protocol.notification({ method: 'test/debounced' }); + + // Close the connection. This should clear the pending set. + await protocol.close(); + + // Now, flush the microtask queue. + await flushMicrotasks(); + + // ASSERT + // The send should never have happened because the transport was cleared. + expect(sendSpy).not.toHaveBeenCalled(); + }); + + it("should debounce multiple synchronous calls when params property is omitted", async () => { + // ARRANGE + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ debouncedNotificationMethods: ['test/debounced'] }); + await protocol.connect(transport); + + // ACT + // This is the more idiomatic way to write a notification with no params. + protocol.notification({ method: 'test/debounced' }); + protocol.notification({ method: 'test/debounced' }); + protocol.notification({ method: 'test/debounced' }); + + expect(sendSpy).not.toHaveBeenCalled(); + await flushMicrotasks(); + + // ASSERT + expect(sendSpy).toHaveBeenCalledTimes(1); + // The final sent object might not even have the `params` key, which is fine. + // We can check that it was called and that the params are "falsy". + const sentNotification = sendSpy.mock.calls[0][0]; + expect(sentNotification.method).toBe('test/debounced'); + expect(sentNotification.params).toBeUndefined(); + }); + + it("should debounce calls when params is explicitly undefined", async () => { + // ARRANGE + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ debouncedNotificationMethods: ['test/debounced'] }); + await protocol.connect(transport); + + // ACT + protocol.notification({ method: 'test/debounced', params: undefined }); + protocol.notification({ method: 'test/debounced', params: undefined }); + await flushMicrotasks(); + + // ASSERT + expect(sendSpy).toHaveBeenCalledTimes(1); + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + method: 'test/debounced', + params: undefined + }), + undefined + ); + }); + + it("should send non-debounced notifications immediately and multiple times", async () => { + // ARRANGE + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ debouncedNotificationMethods: ['test/debounced'] }); // Configure for a different method + await protocol.connect(transport); + + // ACT + // Call a non-debounced notification method multiple times. + await protocol.notification({ method: 'test/immediate' }); + await protocol.notification({ method: 'test/immediate' }); + + // ASSERT + // Since this method is not in the debounce list, it should be sent every time. + expect(sendSpy).toHaveBeenCalledTimes(2); + }); + + it("should not debounce any notifications if the option is not provided", async () => { + // ARRANGE + // Use the default protocol from beforeEach, which has no debounce options. + await protocol.connect(transport); + + // ACT + await protocol.notification({ method: 'any/method' }); + await protocol.notification({ method: 'any/method' }); + + // ASSERT + // Without the config, behavior should be immediate sending. + expect(sendSpy).toHaveBeenCalledTimes(2); + }); + + it("should handle sequential batches of debounced notifications correctly", async () => { + // ARRANGE + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ debouncedNotificationMethods: ['test/debounced'] }); + await protocol.connect(transport); + + // ACT (Batch 1) + protocol.notification({ method: 'test/debounced' }); + protocol.notification({ method: 'test/debounced' }); + await flushMicrotasks(); + + // ASSERT (Batch 1) + expect(sendSpy).toHaveBeenCalledTimes(1); + + // ACT (Batch 2) + // After the first batch has been sent, a new batch should be possible. + protocol.notification({ method: 'test/debounced' }); + protocol.notification({ method: 'test/debounced' }); + await flushMicrotasks(); + + // ASSERT (Batch 2) + // The total number of sends should now be 2. + expect(sendSpy).toHaveBeenCalledTimes(2); + }); + }); }); describe("mergeCapabilities", () => { diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 50bdcc3c..6142140d 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -45,6 +45,13 @@ export type ProtocolOptions = { * Currently this defaults to false, for backwards compatibility with SDK versions that did not advertise capabilities correctly. In future, this will default to true. */ enforceStrictCapabilities?: boolean; + /** + * An array of notification method names that should be automatically debounced. + * Any notifications with a method in this list will be coalesced if they + * occur in the same tick of the event loop. + * e.g., ['notifications/tools/list_changed'] + */ + debouncedNotificationMethods?: string[]; }; /** @@ -191,6 +198,7 @@ export abstract class Protocol< > = new Map(); private _progressHandlers: Map = new Map(); private _timeoutInfo: Map = new Map(); + private _pendingDebouncedNotifications = new Set(); /** * Callback for when the connection is closed for any reason. @@ -321,6 +329,7 @@ export abstract class Protocol< const responseHandlers = this._responseHandlers; this._responseHandlers = new Map(); this._progressHandlers.clear(); + this._pendingDebouncedNotifications.clear(); this._transport = undefined; this.onclose?.(); @@ -632,6 +641,46 @@ export abstract class Protocol< this.assertNotificationCapability(notification.method); + const debouncedMethods = this._options?.debouncedNotificationMethods ?? []; + // A notification can only be debounced if it's in the list AND it's "simple" + // (i.e., has no parameters and no related request ID that could be lost). + const canDebounce = debouncedMethods.includes(notification.method) + && !notification.params + && !(options?.relatedRequestId); + + if (canDebounce) { + // If a notification of this type is already scheduled, do nothing. + if (this._pendingDebouncedNotifications.has(notification.method)) { + return; + } + + // Mark this notification type as pending. + this._pendingDebouncedNotifications.add(notification.method); + + // Schedule the actual send to happen in the next microtask. + // This allows all synchronous calls in the current event loop tick to be coalesced. + Promise.resolve().then(() => { + // Un-mark the notification so the next one can be scheduled. + this._pendingDebouncedNotifications.delete(notification.method); + + // SAFETY CHECK: If the connection was closed while this was pending, abort. + if (!this._transport) { + return; + } + + const jsonrpcNotification: JSONRPCNotification = { + ...notification, + jsonrpc: "2.0", + }; + // Send the notification, but don't await it here to avoid blocking. + // Handle potential errors with a .catch(). + this._transport?.send(jsonrpcNotification, options).catch(error => this._onerror(error)); + }); + + // Return immediately. + return; + } + const jsonrpcNotification: JSONRPCNotification = { ...notification, jsonrpc: "2.0", diff --git a/src/spec.types.test.ts b/src/spec.types.test.ts new file mode 100644 index 00000000..09cd6c2d --- /dev/null +++ b/src/spec.types.test.ts @@ -0,0 +1,705 @@ +/** + * This contains: + * - Static type checks to verify the Spec's types are compatible with the SDK's types + * (mutually assignable, w/ slight affordances to get rid of ZodObject.passthrough() index signatures, etc) + * - Runtime checks to verify each Spec type has a static check + * (note: a few don't have SDK types, see MISSING_SDK_TYPES below) + */ +import * as SDKTypes from "./types.js"; +import * as SpecTypes from "../spec.types.js"; +import fs from "node:fs"; + +/* eslint-disable @typescript-eslint/no-unused-vars */ +/* eslint-disable @typescript-eslint/no-unsafe-function-type */ + +// Removes index signatures added by ZodObject.passthrough(). +type RemovePassthrough = T extends object + ? T extends Array + ? Array> + : T extends Function + ? T + : {[K in keyof T as string extends K ? never : K]: RemovePassthrough} + : T; + +type IsUnknown = [unknown] extends [T] ? [T] extends [unknown] ? true : false : false; + +// Turns {x?: unknown} into {x: unknown} but keeps {_meta?: unknown} unchanged (and leaves other optional properties unchanged, e.g. {x?: string}). +// This works around an apparent quirk of ZodObject.unknown() (makes fields optional) +type MakeUnknownsNotOptional = + IsUnknown extends true + ? unknown + : (T extends object + ? (T extends Array + ? Array> + : (T extends Function + ? T + : Pick & { + // Start with empty object to avoid duplicates + // Make unknown properties required (except _meta) + [K in keyof T as '_meta' extends K ? never : IsUnknown extends true ? K : never]-?: unknown; + } & + Pick extends true ? never : K + }[keyof T]> & { + // Recurse on the picked properties + [K in keyof Pick extends true ? never : K}[keyof T]>]: MakeUnknownsNotOptional + })) + : T); + +function checkCancelledNotification( + sdk: SDKTypes.CancelledNotification, + spec: SpecTypes.CancelledNotification +) { + sdk = spec; + spec = sdk; +} +function checkBaseMetadata( + sdk: RemovePassthrough, + spec: SpecTypes.BaseMetadata +) { + sdk = spec; + spec = sdk; +} +function checkImplementation( + sdk: RemovePassthrough, + spec: SpecTypes.Implementation +) { + sdk = spec; + spec = sdk; +} +function checkProgressNotification( + sdk: SDKTypes.ProgressNotification, + spec: SpecTypes.ProgressNotification +) { + sdk = spec; + spec = sdk; +} + +function checkSubscribeRequest( + sdk: SDKTypes.SubscribeRequest, + spec: SpecTypes.SubscribeRequest +) { + sdk = spec; + spec = sdk; +} +function checkUnsubscribeRequest( + sdk: SDKTypes.UnsubscribeRequest, + spec: SpecTypes.UnsubscribeRequest +) { + sdk = spec; + spec = sdk; +} +function checkPaginatedRequest( + sdk: SDKTypes.PaginatedRequest, + spec: SpecTypes.PaginatedRequest +) { + sdk = spec; + spec = sdk; +} +function checkPaginatedResult( + sdk: SDKTypes.PaginatedResult, + spec: SpecTypes.PaginatedResult +) { + sdk = spec; + spec = sdk; +} +function checkListRootsRequest( + sdk: SDKTypes.ListRootsRequest, + spec: SpecTypes.ListRootsRequest +) { + sdk = spec; + spec = sdk; +} +function checkListRootsResult( + sdk: RemovePassthrough, + spec: SpecTypes.ListRootsResult +) { + sdk = spec; + spec = sdk; +} +function checkRoot( + sdk: RemovePassthrough, + spec: SpecTypes.Root +) { + sdk = spec; + spec = sdk; +} +function checkElicitRequest( + sdk: RemovePassthrough, + spec: SpecTypes.ElicitRequest +) { + sdk = spec; + spec = sdk; +} +function checkElicitResult( + sdk: RemovePassthrough, + spec: SpecTypes.ElicitResult +) { + sdk = spec; + spec = sdk; +} +function checkCompleteRequest( + sdk: RemovePassthrough, + spec: SpecTypes.CompleteRequest +) { + sdk = spec; + spec = sdk; +} +function checkCompleteResult( + sdk: SDKTypes.CompleteResult, + spec: SpecTypes.CompleteResult +) { + sdk = spec; + spec = sdk; +} +function checkProgressToken( + sdk: SDKTypes.ProgressToken, + spec: SpecTypes.ProgressToken +) { + sdk = spec; + spec = sdk; +} +function checkCursor( + sdk: SDKTypes.Cursor, + spec: SpecTypes.Cursor +) { + sdk = spec; + spec = sdk; +} +function checkRequest( + sdk: SDKTypes.Request, + spec: SpecTypes.Request +) { + sdk = spec; + spec = sdk; +} +function checkResult( + sdk: SDKTypes.Result, + spec: SpecTypes.Result +) { + sdk = spec; + spec = sdk; +} +function checkRequestId( + sdk: SDKTypes.RequestId, + spec: SpecTypes.RequestId +) { + sdk = spec; + spec = sdk; +} +function checkJSONRPCRequest( + sdk: SDKTypes.JSONRPCRequest, + spec: SpecTypes.JSONRPCRequest +) { + sdk = spec; + spec = sdk; +} +function checkJSONRPCNotification( + sdk: SDKTypes.JSONRPCNotification, + spec: SpecTypes.JSONRPCNotification +) { + sdk = spec; + spec = sdk; +} +function checkJSONRPCResponse( + sdk: SDKTypes.JSONRPCResponse, + spec: SpecTypes.JSONRPCResponse +) { + sdk = spec; + spec = sdk; +} +function checkEmptyResult( + sdk: SDKTypes.EmptyResult, + spec: SpecTypes.EmptyResult +) { + sdk = spec; + spec = sdk; +} +function checkNotification( + sdk: SDKTypes.Notification, + spec: SpecTypes.Notification +) { + sdk = spec; + spec = sdk; +} +function checkClientResult( + sdk: SDKTypes.ClientResult, + spec: SpecTypes.ClientResult +) { + sdk = spec; + spec = sdk; +} +function checkClientNotification( + sdk: SDKTypes.ClientNotification, + spec: SpecTypes.ClientNotification +) { + sdk = spec; + spec = sdk; +} +function checkServerResult( + sdk: SDKTypes.ServerResult, + spec: SpecTypes.ServerResult +) { + sdk = spec; + spec = sdk; +} +function checkResourceTemplateReference( + sdk: RemovePassthrough, + spec: SpecTypes.ResourceTemplateReference +) { + sdk = spec; + spec = sdk; +} +function checkPromptReference( + sdk: RemovePassthrough, + spec: SpecTypes.PromptReference +) { + sdk = spec; + spec = sdk; +} +function checkToolAnnotations( + sdk: RemovePassthrough, + spec: SpecTypes.ToolAnnotations +) { + sdk = spec; + spec = sdk; +} +function checkTool( + sdk: RemovePassthrough, + spec: SpecTypes.Tool +) { + sdk = spec; + spec = sdk; +} +function checkListToolsRequest( + sdk: SDKTypes.ListToolsRequest, + spec: SpecTypes.ListToolsRequest +) { + sdk = spec; + spec = sdk; +} +function checkListToolsResult( + sdk: RemovePassthrough, + spec: SpecTypes.ListToolsResult +) { + sdk = spec; + spec = sdk; +} +function checkCallToolResult( + sdk: RemovePassthrough, + spec: SpecTypes.CallToolResult +) { + sdk = spec; + spec = sdk; +} +function checkCallToolRequest( + sdk: SDKTypes.CallToolRequest, + spec: SpecTypes.CallToolRequest +) { + sdk = spec; + spec = sdk; +} +function checkToolListChangedNotification( + sdk: SDKTypes.ToolListChangedNotification, + spec: SpecTypes.ToolListChangedNotification +) { + sdk = spec; + spec = sdk; +} +function checkResourceListChangedNotification( + sdk: SDKTypes.ResourceListChangedNotification, + spec: SpecTypes.ResourceListChangedNotification +) { + sdk = spec; + spec = sdk; +} +function checkPromptListChangedNotification( + sdk: SDKTypes.PromptListChangedNotification, + spec: SpecTypes.PromptListChangedNotification +) { + sdk = spec; + spec = sdk; +} +function checkRootsListChangedNotification( + sdk: SDKTypes.RootsListChangedNotification, + spec: SpecTypes.RootsListChangedNotification +) { + sdk = spec; + spec = sdk; +} +function checkResourceUpdatedNotification( + sdk: SDKTypes.ResourceUpdatedNotification, + spec: SpecTypes.ResourceUpdatedNotification +) { + sdk = spec; + spec = sdk; +} +function checkSamplingMessage( + sdk: RemovePassthrough, + spec: SpecTypes.SamplingMessage +) { + sdk = spec; + spec = sdk; +} +function checkCreateMessageResult( + sdk: RemovePassthrough, + spec: SpecTypes.CreateMessageResult +) { + sdk = spec; + spec = sdk; +} +function checkSetLevelRequest( + sdk: SDKTypes.SetLevelRequest, + spec: SpecTypes.SetLevelRequest +) { + sdk = spec; + spec = sdk; +} +function checkPingRequest( + sdk: SDKTypes.PingRequest, + spec: SpecTypes.PingRequest +) { + sdk = spec; + spec = sdk; +} +function checkInitializedNotification( + sdk: SDKTypes.InitializedNotification, + spec: SpecTypes.InitializedNotification +) { + sdk = spec; + spec = sdk; +} +function checkListResourcesRequest( + sdk: SDKTypes.ListResourcesRequest, + spec: SpecTypes.ListResourcesRequest +) { + sdk = spec; + spec = sdk; +} +function checkListResourcesResult( + sdk: RemovePassthrough, + spec: SpecTypes.ListResourcesResult +) { + sdk = spec; + spec = sdk; +} +function checkListResourceTemplatesRequest( + sdk: SDKTypes.ListResourceTemplatesRequest, + spec: SpecTypes.ListResourceTemplatesRequest +) { + sdk = spec; + spec = sdk; +} +function checkListResourceTemplatesResult( + sdk: RemovePassthrough, + spec: SpecTypes.ListResourceTemplatesResult +) { + sdk = spec; + spec = sdk; +} +function checkReadResourceRequest( + sdk: SDKTypes.ReadResourceRequest, + spec: SpecTypes.ReadResourceRequest +) { + sdk = spec; + spec = sdk; +} +function checkReadResourceResult( + sdk: RemovePassthrough, + spec: SpecTypes.ReadResourceResult +) { + sdk = spec; + spec = sdk; +} +function checkResourceContents( + sdk: RemovePassthrough, + spec: SpecTypes.ResourceContents +) { + sdk = spec; + spec = sdk; +} +function checkTextResourceContents( + sdk: RemovePassthrough, + spec: SpecTypes.TextResourceContents +) { + sdk = spec; + spec = sdk; +} +function checkBlobResourceContents( + sdk: RemovePassthrough, + spec: SpecTypes.BlobResourceContents +) { + sdk = spec; + spec = sdk; +} +function checkResource( + sdk: RemovePassthrough, + spec: SpecTypes.Resource +) { + sdk = spec; + spec = sdk; +} +function checkResourceTemplate( + sdk: RemovePassthrough, + spec: SpecTypes.ResourceTemplate +) { + sdk = spec; + spec = sdk; +} +function checkPromptArgument( + sdk: RemovePassthrough, + spec: SpecTypes.PromptArgument +) { + sdk = spec; + spec = sdk; +} +function checkPrompt( + sdk: RemovePassthrough, + spec: SpecTypes.Prompt +) { + sdk = spec; + spec = sdk; +} +function checkListPromptsRequest( + sdk: SDKTypes.ListPromptsRequest, + spec: SpecTypes.ListPromptsRequest +) { + sdk = spec; + spec = sdk; +} +function checkListPromptsResult( + sdk: RemovePassthrough, + spec: SpecTypes.ListPromptsResult +) { + sdk = spec; + spec = sdk; +} +function checkGetPromptRequest( + sdk: SDKTypes.GetPromptRequest, + spec: SpecTypes.GetPromptRequest +) { + sdk = spec; + spec = sdk; +} +function checkTextContent( + sdk: RemovePassthrough, + spec: SpecTypes.TextContent +) { + sdk = spec; + spec = sdk; +} +function checkImageContent( + sdk: RemovePassthrough, + spec: SpecTypes.ImageContent +) { + sdk = spec; + spec = sdk; +} +function checkAudioContent( + sdk: RemovePassthrough, + spec: SpecTypes.AudioContent +) { + sdk = spec; + spec = sdk; +} +function checkEmbeddedResource( + sdk: RemovePassthrough, + spec: SpecTypes.EmbeddedResource +) { + sdk = spec; + spec = sdk; +} +function checkResourceLink( + sdk: RemovePassthrough, + spec: SpecTypes.ResourceLink +) { + sdk = spec; + spec = sdk; +} +function checkContentBlock( + sdk: RemovePassthrough, + spec: SpecTypes.ContentBlock +) { + sdk = spec; + spec = sdk; +} +function checkPromptMessage( + sdk: RemovePassthrough, + spec: SpecTypes.PromptMessage +) { + sdk = spec; + spec = sdk; +} +function checkGetPromptResult( + sdk: RemovePassthrough, + spec: SpecTypes.GetPromptResult +) { + sdk = spec; + spec = sdk; +} +function checkBooleanSchema( + sdk: RemovePassthrough, + spec: SpecTypes.BooleanSchema +) { + sdk = spec; + spec = sdk; +} +function checkStringSchema( + sdk: RemovePassthrough, + spec: SpecTypes.StringSchema +) { + sdk = spec; + spec = sdk; +} +function checkNumberSchema( + sdk: RemovePassthrough, + spec: SpecTypes.NumberSchema +) { + sdk = spec; + spec = sdk; +} +function checkEnumSchema( + sdk: RemovePassthrough, + spec: SpecTypes.EnumSchema +) { + sdk = spec; + spec = sdk; +} +function checkPrimitiveSchemaDefinition( + sdk: RemovePassthrough, + spec: SpecTypes.PrimitiveSchemaDefinition +) { + sdk = spec; + spec = sdk; +} +function checkJSONRPCError( + sdk: SDKTypes.JSONRPCError, + spec: SpecTypes.JSONRPCError +) { + sdk = spec; + spec = sdk; +} +function checkJSONRPCMessage( + sdk: SDKTypes.JSONRPCMessage, + spec: SpecTypes.JSONRPCMessage +) { + sdk = spec; + spec = sdk; +} +function checkCreateMessageRequest( + sdk: RemovePassthrough, + spec: SpecTypes.CreateMessageRequest +) { + sdk = spec; + spec = sdk; +} +function checkInitializeRequest( + sdk: RemovePassthrough, + spec: SpecTypes.InitializeRequest +) { + sdk = spec; + spec = sdk; +} +function checkInitializeResult( + sdk: RemovePassthrough, + spec: SpecTypes.InitializeResult +) { + sdk = spec; + spec = sdk; +} +function checkClientCapabilities( + sdk: RemovePassthrough, + spec: SpecTypes.ClientCapabilities +) { + sdk = spec; + spec = sdk; +} +function checkServerCapabilities( + sdk: RemovePassthrough, + spec: SpecTypes.ServerCapabilities +) { + sdk = spec; + spec = sdk; +} +function checkClientRequest( + sdk: RemovePassthrough, + spec: SpecTypes.ClientRequest +) { + sdk = spec; + spec = sdk; +} +function checkServerRequest( + sdk: RemovePassthrough, + spec: SpecTypes.ServerRequest +) { + sdk = spec; + spec = sdk; +} +function checkLoggingMessageNotification( + sdk: MakeUnknownsNotOptional, + spec: SpecTypes.LoggingMessageNotification +) { + sdk = spec; + spec = sdk; +} +function checkServerNotification( + sdk: MakeUnknownsNotOptional, + spec: SpecTypes.ServerNotification +) { + sdk = spec; + spec = sdk; +} +function checkLoggingLevel( + sdk: SDKTypes.LoggingLevel, + spec: SpecTypes.LoggingLevel +) { + sdk = spec; + spec = sdk; +} + +// This file is .gitignore'd, and fetched by `npm run fetch:spec-types` (called by `npm run test`) +const SPEC_TYPES_FILE = 'spec.types.ts'; +const SDK_TYPES_FILE = 'src/types.ts'; + +const MISSING_SDK_TYPES = [ + // These are inlined in the SDK: + 'Role', + + // These aren't supported by the SDK yet: + // TODO: Add definitions to the SDK + 'Annotations', + 'ModelHint', + 'ModelPreferences', +] + +function extractExportedTypes(source: string): string[] { + return [...source.matchAll(/export\s+(?:interface|class|type)\s+(\w+)\b/g)].map(m => m[1]); +} + +describe('Spec Types', () => { + const specTypes = extractExportedTypes(fs.readFileSync(SPEC_TYPES_FILE, 'utf-8')); + const sdkTypes = extractExportedTypes(fs.readFileSync(SDK_TYPES_FILE, 'utf-8')); + const testSource = fs.readFileSync(__filename, 'utf-8'); + + it('should define some expected types', () => { + expect(specTypes).toContain('JSONRPCNotification'); + expect(specTypes).toContain('ElicitResult'); + expect(specTypes).toHaveLength(91); + }); + + it('should have up to date list of missing sdk types', () => { + for (const typeName of MISSING_SDK_TYPES) { + expect(sdkTypes).not.toContain(typeName); + } + }); + + for (const type of specTypes) { + if (MISSING_SDK_TYPES.includes(type)) { + continue; // Skip missing SDK types + } + it(`${type} should have a compatibility test`, () => { + expect(testSource).toContain(`function check${type}(`); + }); + } +});