Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
package io.modelcontextprotocol.common;

import java.util.Map;
import java.util.Optional;

import io.modelcontextprotocol.spec.HttpHeaders;
import io.modelcontextprotocol.util.Assert;

/**
Expand All @@ -28,6 +30,21 @@ public Object get(String key) {
return this.metadata.get(key);
}

@Override
public Optional<String> lastEventId() {
return Optional.ofNullable(metadata.get(HttpHeaders.LAST_EVENT_ID)).map(Object::toString);
}

@Override
public Optional<String> sessionId() {
return Optional.ofNullable(metadata.get(HttpHeaders.MCP_SESSION_ID)).map(Object::toString);
}

@Override
public Optional<String> protocolVersion() {
return Optional.ofNullable(metadata.get(HttpHeaders.PROTOCOL_VERSION)).map(Object::toString);
}

@Override
public boolean equals(Object o) {
if (o == null || getClass() != o.getClass())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@

package io.modelcontextprotocol.common;

import java.security.Principal;
import java.util.Collections;
import java.util.Map;
import java.util.Optional;

/**
* Context associated with the transport layer. It allows to add transport-level metadata
Expand Down Expand Up @@ -43,4 +45,32 @@ static McpTransportContext create(Map<String, Object> metadata) {
*/
Object get(String key);

/**
* @return The MCP Protocl Version
*/
default Optional<String> protocolVersion() {
return Optional.empty();
}

/**
* @return The Session ID
*/
default Optional<String> sessionId() {
return Optional.empty();
}

/**
* @return The Last Event ID
*/
default Optional<String> lastEventId() {
return Optional.empty();
}

/**
* @return The Principal. it may represent the authenticated user.
*/
default Optional<Principal> principal() {
return Optional.empty();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright 2024-2024 the original author or authors.
*/
package io.modelcontextprotocol.server.servlet;

import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.server.McpTransportContextExtractor;
import io.modelcontextprotocol.spec.ProtocolVersions;
import jakarta.servlet.http.HttpServletRequest;

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

/**
* {@link McpTransportContextExtractor} implementation for {@link HttpServletRequest}.
*/
public class HttpServletRequestMcpTransportContextExtractor
implements McpTransportContextExtractor<HttpServletRequest> {

@Override
public McpTransportContext extract(HttpServletRequest request) {
return McpTransportContext.create(metadata(request));
}

/**
* @param request Servlet Request
* @return Extracts Map for MCP Transport Context
*/
protected Map<String, Object> metadata(HttpServletRequest request) {
Map<String, Object> metadata = new HashMap<>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might be worth passing a size to this hash map

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added f02ec6f

metadata.put(io.modelcontextprotocol.spec.HttpHeaders.PROTOCOL_VERSION,
Optional.ofNullable(request.getHeader(io.modelcontextprotocol.spec.HttpHeaders.PROTOCOL_VERSION))
.orElse(ProtocolVersions.MCP_2025_03_26));
Optional.ofNullable(request.getHeader(io.modelcontextprotocol.spec.HttpHeaders.MCP_SESSION_ID))
.ifPresent(v -> metadata.put(io.modelcontextprotocol.spec.HttpHeaders.MCP_SESSION_ID, v));
Optional.ofNullable(request.getHeader(io.modelcontextprotocol.spec.HttpHeaders.LAST_EVENT_ID))
.ifPresent(v -> metadata.put(io.modelcontextprotocol.spec.HttpHeaders.LAST_EVENT_ID, v));
return metadata;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
/*
* Copyright 2024-2024 the original author or authors.
*/
/**
* Classes related with servlet support.
*/
package io.modelcontextprotocol.server.servlet;
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import io.modelcontextprotocol.json.TypeRef;
import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.server.McpTransportContextExtractor;
import io.modelcontextprotocol.server.servlet.HttpServletRequestMcpTransportContextExtractor;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerSession;
Expand Down Expand Up @@ -503,8 +504,7 @@ public static class Builder {

private String sseEndpoint = DEFAULT_SSE_ENDPOINT;

private McpTransportContextExtractor<HttpServletRequest> contextExtractor = (
serverRequest) -> McpTransportContext.EMPTY;
private McpTransportContextExtractor<HttpServletRequest> contextExtractor;

private Duration keepAliveInterval;

Expand Down Expand Up @@ -594,7 +594,8 @@ public HttpServletSseServerTransportProvider build() {
}
return new HttpServletSseServerTransportProvider(
jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, baseUrl, messageEndpoint, sseEndpoint,
keepAliveInterval, contextExtractor);
keepAliveInterval,
contextExtractor == null ? new HttpServletRequestMcpTransportContextExtractor() : contextExtractor);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.io.IOException;
import java.io.PrintWriter;

import io.modelcontextprotocol.server.servlet.HttpServletRequestMcpTransportContextExtractor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -240,8 +241,7 @@ public static class Builder {

private String mcpEndpoint = "/mcp";

private McpTransportContextExtractor<HttpServletRequest> contextExtractor = (
serverRequest) -> McpTransportContext.EMPTY;
private McpTransportContextExtractor<HttpServletRequest> contextExtractor;

private Builder() {
// used by a static method
Expand Down Expand Up @@ -297,7 +297,8 @@ public Builder contextExtractor(McpTransportContextExtractor<HttpServletRequest>
public HttpServletStatelessServerTransport build() {
Assert.notNull(mcpEndpoint, "Message endpoint must be set");
return new HttpServletStatelessServerTransport(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper,
mcpEndpoint, contextExtractor);
mcpEndpoint,
contextExtractor == null ? new HttpServletRequestMcpTransportContextExtractor() : contextExtractor);
}

}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
/*
* Copyright 2024-2024 the original author or authors.
*/

package io.modelcontextprotocol.server.transport;

import java.io.BufferedReader;
Expand All @@ -13,6 +12,7 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReentrantLock;

import io.modelcontextprotocol.server.servlet.HttpServletRequestMcpTransportContextExtractor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -769,8 +769,7 @@ public static class Builder {

private boolean disallowDelete = false;

private McpTransportContextExtractor<HttpServletRequest> contextExtractor = (
serverRequest) -> McpTransportContext.EMPTY;
private McpTransportContextExtractor<HttpServletRequest> contextExtractor;

private Duration keepAliveInterval;

Expand Down Expand Up @@ -843,7 +842,8 @@ public HttpServletStreamableServerTransportProvider build() {
Assert.notNull(this.mcpEndpoint, "MCP endpoint must be set");
return new HttpServletStreamableServerTransportProvider(
jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, mcpEndpoint, disallowDelete,
contextExtractor, keepAliveInterval);
contextExtractor == null ? new HttpServletRequestMcpTransportContextExtractor() : contextExtractor,
keepAliveInterval);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import io.modelcontextprotocol.server.McpServerFeatures;
import io.modelcontextprotocol.server.McpStatelessServerFeatures;
import io.modelcontextprotocol.server.McpTransportContextExtractor;
import io.modelcontextprotocol.server.servlet.HttpServletRequestMcpTransportContextExtractor;
import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider;
import io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport;
import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider;
Expand Down Expand Up @@ -91,10 +92,16 @@ public class AsyncServerMcpTransportContextIntegrationTests {
return Mono.just(builder);
};

private final McpTransportContextExtractor<HttpServletRequest> serverContextExtractor = (HttpServletRequest r) -> {
var headerValue = r.getHeader(HEADER_NAME);
return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue))
: McpTransportContext.EMPTY;
private final McpTransportContextExtractor<HttpServletRequest> serverContextExtractor = new HttpServletRequestMcpTransportContextExtractor() {
@Override
protected Map<String, Object> metadata(HttpServletRequest r) {
Map<String, Object> m = super.metadata(r);
var headerValue = r.getHeader(HEADER_NAME);
if (headerValue != null) {
m.put("server-side-header-value", headerValue);
}
return m;
}
};

private final HttpServletStatelessServerTransport statelessServerTransport = HttpServletStatelessServerTransport
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package io.modelcontextprotocol.common;

import io.modelcontextprotocol.spec.HttpHeaders;
import org.junit.jupiter.api.Test;

import java.util.Collections;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;

class DefaultMcpTransportContextTest {

@Test
void protocolVersionNotPresent() {
var ctx = new DefaultMcpTransportContext(Collections.emptyMap());
assertFalse(ctx.protocolVersion().isPresent());
}

@Test
void sessionIdNotPresent() {
var ctx = new DefaultMcpTransportContext(Collections.emptyMap());
assertFalse(ctx.sessionId().isPresent());
}

@Test
void lastEventIdNotPresent() {
var ctx = new DefaultMcpTransportContext(Collections.emptyMap());
assertFalse(ctx.lastEventId().isPresent());
}

@Test
void protocolVersion_returnsProvidedValue() {
var ctx = new DefaultMcpTransportContext(Map.of(HttpHeaders.PROTOCOL_VERSION, "2025-01-01",
HttpHeaders.MCP_SESSION_ID, "session-123", HttpHeaders.LAST_EVENT_ID, "evt-456"));
assertEquals("2025-01-01", ctx.protocolVersion().orElseThrow());
}

@Test
void sessionId_returnsProvidedValue() {
var ctx = new DefaultMcpTransportContext(Map.of(HttpHeaders.PROTOCOL_VERSION, "2025-01-01",
HttpHeaders.MCP_SESSION_ID, "session-abc", HttpHeaders.LAST_EVENT_ID, "evt-456"));
assertEquals("session-abc", ctx.sessionId().orElseThrow());
}

@Test
void lastEventId_returnsProvidedValue() {
var ctx = new DefaultMcpTransportContext(Map.of(HttpHeaders.PROTOCOL_VERSION, "2025-01-01",
HttpHeaders.MCP_SESSION_ID, "session-abc", HttpHeaders.LAST_EVENT_ID, "evt-999"));
assertEquals("evt-999", ctx.lastEventId().orElseThrow());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import io.modelcontextprotocol.server.McpStatelessServerFeatures;
import io.modelcontextprotocol.server.McpSyncServerExchange;
import io.modelcontextprotocol.server.McpTransportContextExtractor;
import io.modelcontextprotocol.server.servlet.HttpServletRequestMcpTransportContextExtractor;
import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider;
import io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport;
import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider;
Expand Down Expand Up @@ -71,10 +72,16 @@ public class SyncServerMcpTransportContextIntegrationTests {
}
};

private final McpTransportContextExtractor<HttpServletRequest> serverContextExtractor = (HttpServletRequest r) -> {
var headerValue = r.getHeader(HEADER_NAME);
return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue))
: McpTransportContext.EMPTY;
private final McpTransportContextExtractor<HttpServletRequest> serverContextExtractor = new HttpServletRequestMcpTransportContextExtractor() {
@Override
protected Map<String, Object> metadata(HttpServletRequest r) {
Map<String, Object> m = super.metadata(r);
var headerValue = r.getHeader(HEADER_NAME);
if (headerValue != null) {
m.put("server-side-header-value", headerValue);
}
return m;
}
};

private final BiFunction<McpTransportContext, McpSchema.CallToolRequest, McpSchema.CallToolResult> statelessHandler = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.server.McpServer.AsyncSpecification;
import io.modelcontextprotocol.server.McpServer.SyncSpecification;
import io.modelcontextprotocol.server.servlet.HttpServletRequestMcpTransportContextExtractor;
import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider;
import io.modelcontextprotocol.server.transport.TomcatTestUtil;
import jakarta.servlet.http.HttpServletRequest;
Expand Down Expand Up @@ -98,7 +99,13 @@ public void after() {
protected void prepareClients(int port, String mcpEndpoint) {
}

static McpTransportContextExtractor<HttpServletRequest> TEST_CONTEXT_EXTRACTOR = (r) -> McpTransportContext
.create(Map.of("important", "value"));
static McpTransportContextExtractor<HttpServletRequest> TEST_CONTEXT_EXTRACTOR = new HttpServletRequestMcpTransportContextExtractor() {
@Override
protected Map<String, Object> metadata(HttpServletRequest r) {
Map<String, Object> m = super.metadata(r);
m.put("important", "value");
return m;
}
};

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@

import io.modelcontextprotocol.client.McpClient;
import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport;
import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.server.McpServer.AsyncSpecification;
import io.modelcontextprotocol.server.McpServer.SyncSpecification;
import io.modelcontextprotocol.server.servlet.HttpServletRequestMcpTransportContextExtractor;
import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider;
import io.modelcontextprotocol.server.transport.TomcatTestUtil;
import jakarta.servlet.http.HttpServletRequest;
Expand Down Expand Up @@ -96,7 +96,13 @@ public void after() {
protected void prepareClients(int port, String mcpEndpoint) {
}

static McpTransportContextExtractor<HttpServletRequest> TEST_CONTEXT_EXTRACTOR = (r) -> McpTransportContext
.create(Map.of("important", "value"));
static McpTransportContextExtractor<HttpServletRequest> TEST_CONTEXT_EXTRACTOR = new HttpServletRequestMcpTransportContextExtractor() {
@Override
protected Map<String, Object> metadata(HttpServletRequest r) {
Map<String, Object> m = super.metadata(r);
m.put("important", "value");
return m;
}
};

}
Loading
Loading