Skip to content

feat: Support Progress Flow #407

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1016,6 +1016,111 @@ void testLoggingNotification(String clientType) throws InterruptedException {
mcpServer.close();
}

// ---------------------------------------
// Progress Tests
// ---------------------------------------
@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient", "webflux" })
void testProgressNotification(String clientType) throws InterruptedException {
int expectedNotificationsCount = 4; // 3 notifications + 1 for another progress
// token
CountDownLatch latch = new CountDownLatch(expectedNotificationsCount);
// Create a list to store received logging notifications
List<McpSchema.ProgressNotification> receivedNotifications = new CopyOnWriteArrayList<>();

var clientBuilder = clientBuilders.get(clientType);

// Create server with a tool that sends logging notifications
McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder()
.tool(McpSchema.Tool.builder()
.name("progress-test")
.description("Test progress notifications")
.inputSchema(emptyJsonSchema)
.build())
.callHandler((exchange, request) -> {

// Create and send notifications
var progressToken = (String) request.meta().get("progressToken");

return exchange
.progressNotification(
new McpSchema.ProgressNotification(progressToken, 0.0, 1.0, "Processing started"))
.then(exchange.progressNotification(
new McpSchema.ProgressNotification(progressToken, 0.5, 1.0, "Processing data")))
.then(// Send a progress notification with another progress value
// should
exchange.progressNotification(new McpSchema.ProgressNotification("another-progress-token",
0.0, 1.0, "Another processing started")))
.then(exchange.progressNotification(
new McpSchema.ProgressNotification(progressToken, 1.0, 1.0, "Processing completed")))
.thenReturn(new CallToolResult(("Progress test completed"), false));
})
.build();

var mcpServer = McpServer.async(mcpServerTransportProvider)
.serverInfo("test-server", "1.0.0")
.capabilities(ServerCapabilities.builder().tools(true).build())
.tools(tool)
.build();

try (
// Create client with progress notification handler
var mcpClient = clientBuilder.progressConsumer(notification -> {
receivedNotifications.add(notification);
latch.countDown();
}).build()) {

// Initialize client
InitializeResult initResult = mcpClient.initialize();
assertThat(initResult).isNotNull();

// Call the tool that sends progress notifications
McpSchema.CallToolRequest callToolRequest = McpSchema.CallToolRequest.builder()
.name("progress-test")
.meta(Map.of("progressToken", "test-progress-token"))
.build();
CallToolResult result = mcpClient.callTool(callToolRequest);
assertThat(result).isNotNull();
assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class);
assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Progress test completed");

assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue();

// Should have received 3 notifications
assertThat(receivedNotifications).hasSize(expectedNotificationsCount);

Map<String, McpSchema.ProgressNotification> notificationMap = receivedNotifications.stream()
.collect(Collectors.toMap(n -> n.message(), n -> n));

// First notification should be 0.0/1.0 progress
assertThat(notificationMap.get("Processing started").progressToken()).isEqualTo("test-progress-token");
assertThat(notificationMap.get("Processing started").progress()).isEqualTo(0.0);
assertThat(notificationMap.get("Processing started").total()).isEqualTo(1.0);
assertThat(notificationMap.get("Processing started").message()).isEqualTo("Processing started");

// Second notification should be 0.5/1.0 progress
assertThat(notificationMap.get("Processing data").progressToken()).isEqualTo("test-progress-token");
assertThat(notificationMap.get("Processing data").progress()).isEqualTo(0.5);
assertThat(notificationMap.get("Processing data").total()).isEqualTo(1.0);
assertThat(notificationMap.get("Processing data").message()).isEqualTo("Processing data");

// Third notification should be another progress token with 0.0/1.0 progress
assertThat(notificationMap.get("Another processing started").progressToken())
.isEqualTo("another-progress-token");
assertThat(notificationMap.get("Another processing started").progress()).isEqualTo(0.0);
assertThat(notificationMap.get("Another processing started").total()).isEqualTo(1.0);
assertThat(notificationMap.get("Another processing started").message())
.isEqualTo("Another processing started");

// Fourth notification should be 1.0/1.0 progress
assertThat(notificationMap.get("Processing completed").progressToken()).isEqualTo("test-progress-token");
assertThat(notificationMap.get("Processing completed").progress()).isEqualTo(1.0);
assertThat(notificationMap.get("Processing completed").total()).isEqualTo(1.0);
assertThat(notificationMap.get("Processing completed").message()).isEqualTo("Processing completed");
}
mcpServer.close();
}

// ---------------------------------------
// Completion Tests
// ---------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ public class McpAsyncClient {
public static final TypeReference<LoggingMessageNotification> LOGGING_MESSAGE_NOTIFICATION_TYPE_REF = new TypeReference<>() {
};

public static final TypeReference<McpSchema.ProgressNotification> PROGRESS_NOTIFICATION_TYPE_REF = new TypeReference<>() {
};

/**
* Client capabilities.
*/
Expand Down Expand Up @@ -253,6 +256,16 @@ public class McpAsyncClient {
notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_MESSAGE,
asyncLoggingNotificationHandler(loggingConsumersFinal));

// Utility Progress Notification
List<Function<McpSchema.ProgressNotification, Mono<Void>>> progressConsumersFinal = new ArrayList<>();
progressConsumersFinal
.add((notification) -> Mono.fromRunnable(() -> logger.debug("Progress: {}", notification)));
if (!Utils.isEmpty(features.progressConsumers())) {
progressConsumersFinal.addAll(features.progressConsumers());
}
notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_PROGRESS,
asyncProgressNotificationHandler(progressConsumersFinal));

this.initializer = new LifecycleInitializer(clientCapabilities, clientInfo,
List.of(McpSchema.LATEST_PROTOCOL_VERSION), initializationTimeout,
ctx -> new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers,
Expand Down Expand Up @@ -846,6 +859,28 @@ public Mono<Void> setLoggingLevel(LoggingLevel loggingLevel) {
});
}

/**
* Create a notification handler for progress notifications from the server. This
* handler automatically distributes progress notifications to all registered
* consumers.
* @param progressConsumers List of consumers that will be notified when a progress
* message is received. Each consumer receives the progress notification.
* @return A NotificationHandler that processes progress notifications by distributing
* the message to all registered consumers
*/
private NotificationHandler asyncProgressNotificationHandler(
List<Function<McpSchema.ProgressNotification, Mono<Void>>> progressConsumers) {

return params -> {
McpSchema.ProgressNotification progressNotification = transport.unmarshalFrom(params,
PROGRESS_NOTIFICATION_TYPE_REF);

return Flux.fromIterable(progressConsumers)
.flatMap(consumer -> consumer.apply(progressNotification))
.then();
};
}

/**
* This method is package-private and used for test only. Should not be called by user
* code.
Expand Down
72 changes: 69 additions & 3 deletions mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ class SyncSpec {

private final List<Consumer<McpSchema.LoggingMessageNotification>> loggingConsumers = new ArrayList<>();

private final List<Consumer<McpSchema.ProgressNotification>> progressConsumers = new ArrayList<>();

private Function<CreateMessageRequest, CreateMessageResult> samplingHandler;

private Function<ElicitRequest, ElicitResult> elicitationHandler;
Expand Down Expand Up @@ -377,6 +379,36 @@ public SyncSpec loggingConsumers(List<Consumer<McpSchema.LoggingMessageNotificat
return this;
}

/**
* Adds a consumer to be notified of progress notifications from the server. This
* allows the client to track long-running operations and provide feedback to
* users.
* @param progressConsumer A consumer that receives progress notifications. Must
* not be null.
* @return This builder instance for method chaining
* @throws IllegalArgumentException if progressConsumer is null
*/
public SyncSpec progressConsumer(Consumer<McpSchema.ProgressNotification> progressConsumer) {
Assert.notNull(progressConsumer, "Progress consumer must not be null");
this.progressConsumers.add(progressConsumer);
return this;
}

/**
* Adds a multiple consumers to be notified of progress notifications from the
* server. This allows the client to track long-running operations and provide
* feedback to users.
* @param progressConsumers A list of consumers that receives progress
* notifications. Must not be null.
* @return This builder instance for method chaining
* @throws IllegalArgumentException if progressConsumer is null
*/
public SyncSpec progressConsumers(List<Consumer<McpSchema.ProgressNotification>> progressConsumers) {
Assert.notNull(progressConsumers, "Progress consumers must not be null");
this.progressConsumers.addAll(progressConsumers);
return this;
}

/**
* Create an instance of {@link McpSyncClient} with the provided configurations or
* sensible defaults.
Expand All @@ -385,7 +417,8 @@ public SyncSpec loggingConsumers(List<Consumer<McpSchema.LoggingMessageNotificat
public McpSyncClient build() {
McpClientFeatures.Sync syncFeatures = new McpClientFeatures.Sync(this.clientInfo, this.capabilities,
this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.resourcesUpdateConsumers,
this.promptsChangeConsumers, this.loggingConsumers, this.samplingHandler, this.elicitationHandler);
this.promptsChangeConsumers, this.loggingConsumers, this.progressConsumers, this.samplingHandler,
this.elicitationHandler);

McpClientFeatures.Async asyncFeatures = McpClientFeatures.Async.fromSync(syncFeatures);

Expand Down Expand Up @@ -435,6 +468,8 @@ class AsyncSpec {

private final List<Function<McpSchema.LoggingMessageNotification, Mono<Void>>> loggingConsumers = new ArrayList<>();

private final List<Function<McpSchema.ProgressNotification, Mono<Void>>> progressConsumers = new ArrayList<>();

private Function<CreateMessageRequest, Mono<CreateMessageResult>> samplingHandler;

private Function<ElicitRequest, Mono<ElicitResult>> elicitationHandler;
Expand Down Expand Up @@ -654,6 +689,37 @@ public AsyncSpec loggingConsumers(
return this;
}

/**
* Adds a consumer to be notified of progress notifications from the server. This
* allows the client to track long-running operations and provide feedback to
* users.
* @param progressConsumer A consumer that receives progress notifications. Must
* not be null.
* @return This builder instance for method chaining
* @throws IllegalArgumentException if progressConsumer is null
*/
public AsyncSpec progressConsumer(Function<McpSchema.ProgressNotification, Mono<Void>> progressConsumer) {
Assert.notNull(progressConsumer, "Progress consumer must not be null");
this.progressConsumers.add(progressConsumer);
return this;
}

/**
* Adds a multiple consumers to be notified of progress notifications from the
* server. This allows the client to track long-running operations and provide
* feedback to users.
* @param progressConsumers A list of consumers that receives progress
* notifications. Must not be null.
* @return This builder instance for method chaining
* @throws IllegalArgumentException if progressConsumer is null
*/
public AsyncSpec progressConsumers(
List<Function<McpSchema.ProgressNotification, Mono<Void>>> progressConsumers) {
Assert.notNull(progressConsumers, "Progress consumers must not be null");
this.progressConsumers.addAll(progressConsumers);
return this;
}

/**
* Create an instance of {@link McpAsyncClient} with the provided configurations
* or sensible defaults.
Expand All @@ -663,8 +729,8 @@ public McpAsyncClient build() {
return new McpAsyncClient(this.transport, this.requestTimeout, this.initializationTimeout,
new McpClientFeatures.Async(this.clientInfo, this.capabilities, this.roots,
this.toolsChangeConsumers, this.resourcesChangeConsumers, this.resourcesUpdateConsumers,
this.promptsChangeConsumers, this.loggingConsumers, this.samplingHandler,
this.elicitationHandler));
this.promptsChangeConsumers, this.loggingConsumers, this.progressConsumers,
this.samplingHandler, this.elicitationHandler));
}

}
Expand Down
Loading