Skip to content
Merged
106 changes: 105 additions & 1 deletion src/client/streamableHttp.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { StreamableHTTPClientTransport, StreamableHTTPReconnectionOptions, StartSSEOptions } from "./streamableHttp.js";
import { OAuthClientProvider, UnauthorizedError } from "./auth.js";
import { JSONRPCMessage } from "../types.js";
import { JSONRPCMessage, JSONRPCRequest } from "../types.js";


describe("StreamableHTTPClientTransport", () => {
Expand Down Expand Up @@ -592,4 +592,108 @@ describe("StreamableHTTPClientTransport", () => {
await expect(transport.send(message)).rejects.toThrow(UnauthorizedError);
expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1);
});
describe('Reconnection Logic', () => {
let transport: StreamableHTTPClientTransport;

// Use fake timers to control setTimeout and make the test instant.
beforeEach(() => jest.useFakeTimers());
afterEach(() => jest.useRealTimers());

it('should reconnect a GET-initiated notification stream that fails', async () => {
// ARRANGE
transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), {
reconnectionOptions: {
initialReconnectionDelay: 10,
maxRetries: 1,
maxReconnectionDelay: 1000, // Ensure it doesn't retry indefinitely
reconnectionDelayGrowFactor: 1 // No exponential backoff for simplicity
}
});

const errorSpy = jest.fn();
transport.onerror = errorSpy;

const failingStream = new ReadableStream({
start(controller) { controller.error(new Error("Network failure")); }
});

const fetchMock = global.fetch as jest.Mock;
// Mock the initial GET request, which will fail.
fetchMock.mockResolvedValueOnce({
ok: true, status: 200,
headers: new Headers({ "content-type": "text/event-stream" }),
body: failingStream,
});
// Mock the reconnection GET request, which will succeed.
fetchMock.mockResolvedValueOnce({
ok: true, status: 200,
headers: new Headers({ "content-type": "text/event-stream" }),
body: new ReadableStream(),
});

// ACT
await transport.start();
// Trigger the GET stream directly using the internal method for a clean test.
await transport["_startOrAuthSse"]({});
await jest.advanceTimersByTimeAsync(20); // Trigger reconnection timeout

// ASSERT
expect(errorSpy).toHaveBeenCalledWith(expect.objectContaining({
message: expect.stringContaining('SSE stream disconnected: Error: Network failure'),
}));
// THE KEY ASSERTION: A second fetch call proves reconnection was attempted.
expect(fetchMock).toHaveBeenCalledTimes(2);
expect(fetchMock.mock.calls[0][1]?.method).toBe('GET');
expect(fetchMock.mock.calls[1][1]?.method).toBe('GET');
});

it('should NOT reconnect a POST-initiated stream that fails', async () => {
// ARRANGE
transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), {
reconnectionOptions: {
initialReconnectionDelay: 10,
maxRetries: 1,
maxReconnectionDelay: 1000, // Ensure it doesn't retry indefinitely
reconnectionDelayGrowFactor: 1 // No exponential backoff for simplicity
}
});

const errorSpy = jest.fn();
transport.onerror = errorSpy;

const failingStream = new ReadableStream({
start(controller) { controller.error(new Error("Network failure")); }
});

const fetchMock = global.fetch as jest.Mock;
// Mock the POST request. It returns a streaming content-type but a failing body.
fetchMock.mockResolvedValueOnce({
ok: true, status: 200,
headers: new Headers({ "content-type": "text/event-stream" }),
body: failingStream,
});

// A dummy request message to trigger the `send` logic.
const requestMessage: JSONRPCRequest = {
jsonrpc: '2.0',
method: 'long_running_tool',
id: 'request-1',
params: {},
};

// ACT
await transport.start();
// Use the public `send` method to initiate a POST that gets a stream response.
await transport.send(requestMessage);
await jest.advanceTimersByTimeAsync(20); // Advance time to check for reconnections

// ASSERT
expect(errorSpy).toHaveBeenCalledWith(expect.objectContaining({
message: expect.stringContaining('SSE stream disconnected: Error: Network failure'),
}));
// THE KEY ASSERTION: Fetch was only called ONCE. No reconnection was attempted.
expect(fetchMock).toHaveBeenCalledTimes(1);
expect(fetchMock.mock.calls[0][1]?.method).toBe('POST');
});
});
});
36 changes: 21 additions & 15 deletions src/client/streamableHttp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ const response = await (this._fetch ?? fetch)(this._url, {
);
}

this._handleSseStream(response.body, options);
this._handleSseStream(response.body, options, true);
} catch (error) {
this.onerror?.(error as Error);
throw error;
Expand Down Expand Up @@ -300,7 +300,11 @@ const response = await (this._fetch ?? fetch)(this._url, {
}, delay);
}

private _handleSseStream(stream: ReadableStream<Uint8Array> | null, options: StartSSEOptions): void {
private _handleSseStream(
stream: ReadableStream<Uint8Array> | null,
options: StartSSEOptions,
isReconnectable: boolean,
): void {
if (!stream) {
return;
}
Expand Down Expand Up @@ -347,20 +351,22 @@ const response = await (this._fetch ?? fetch)(this._url, {
this.onerror?.(new Error(`SSE stream disconnected: ${error}`));

// Attempt to reconnect if the stream disconnects unexpectedly and we aren't closing
if (this._abortController && !this._abortController.signal.aborted) {
if (
isReconnectable &&
this._abortController &&
!this._abortController.signal.aborted
) {
// Use the exponential backoff reconnection strategy
if (lastEventId !== undefined) {
try {
this._scheduleReconnection({
resumptionToken: lastEventId,
onresumptiontoken,
replayMessageId
}, 0);
}
catch (error) {
this.onerror?.(new Error(`Failed to reconnect: ${error instanceof Error ? error.message : String(error)}`));
try {
this._scheduleReconnection({
resumptionToken: lastEventId,
onresumptiontoken,
replayMessageId
}, 0);
}
catch (error) {
this.onerror?.(new Error(`Failed to reconnect: ${error instanceof Error ? error.message : String(error)}`));

}
}
}
}
Expand Down Expand Up @@ -473,7 +479,7 @@ const response = await (this._fetch ?? fetch)(this._url, init);
// Handle SSE stream responses for requests
// We use the same handler as standalone streams, which now supports
// reconnection with the last event ID
this._handleSseStream(response.body, { onresumptiontoken });
this._handleSseStream(response.body, { onresumptiontoken }, false);
} else if (contentType?.includes("application/json")) {
// For non-streaming servers, we might get direct JSON responses
const data = await response.json();
Expand Down