Skip to content

Commit c3717e7

Browse files
authored
Fix context injection for resources and prompts (#1336)
1 parent c47c767 commit c3717e7

File tree

12 files changed

+741
-349
lines changed

12 files changed

+741
-349
lines changed

.gitattribute

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Generated
2+
uv.lock linguist-generated=true

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ coverage.xml
5252
*.py,cover
5353
.hypothesis/
5454
.pytest_cache/
55+
.ruff_cache/
5556
cover/
5657

5758
# Translations
@@ -168,3 +169,6 @@ cython_debug/
168169
.vscode/
169170
.windsurfrules
170171
**/CLAUDE.local.md
172+
173+
# claude code
174+
.claude/

src/mcp/server/auth/handlers/authorize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ async def error_response(
9999
if client is None and attempt_load_client:
100100
# make last-ditch attempt to load the client
101101
client_id = best_effort_extract_string("client_id", params)
102-
client = client_id and await self.provider.get_client(client_id)
102+
client = await self.provider.get_client(client_id) if client_id else None
103103
if redirect_uri is None and client:
104104
# make last-ditch effort to load the redirect uri
105105
try:

src/mcp/server/fastmcp/prompts/base.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,23 @@
11
"""Base classes for FastMCP prompts."""
22

3+
from __future__ import annotations
4+
35
import inspect
46
from collections.abc import Awaitable, Callable, Sequence
5-
from typing import Any, Literal
7+
from typing import TYPE_CHECKING, Any, Literal
68

79
import pydantic_core
810
from pydantic import BaseModel, Field, TypeAdapter, validate_call
911

12+
from mcp.server.fastmcp.utilities.context_injection import find_context_parameter, inject_context
13+
from mcp.server.fastmcp.utilities.func_metadata import func_metadata
1014
from mcp.types import ContentBlock, TextContent
1115

16+
if TYPE_CHECKING:
17+
from mcp.server.fastmcp.server import Context
18+
from mcp.server.session import ServerSessionT
19+
from mcp.shared.context import LifespanContextT, RequestT
20+
1221

1322
class Message(BaseModel):
1423
"""Base class for all prompt messages."""
@@ -62,6 +71,7 @@ class Prompt(BaseModel):
6271
description: str | None = Field(None, description="Description of what the prompt does")
6372
arguments: list[PromptArgument] | None = Field(None, description="Arguments that can be passed to the prompt")
6473
fn: Callable[..., PromptResult | Awaitable[PromptResult]] = Field(exclude=True)
74+
context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context", exclude=True)
6575

6676
@classmethod
6777
def from_function(
@@ -70,7 +80,8 @@ def from_function(
7080
name: str | None = None,
7181
title: str | None = None,
7282
description: str | None = None,
73-
) -> "Prompt":
83+
context_kwarg: str | None = None,
84+
) -> Prompt:
7485
"""Create a Prompt from a function.
7586
7687
The function can return:
@@ -84,8 +95,16 @@ def from_function(
8495
if func_name == "<lambda>":
8596
raise ValueError("You must provide a name for lambda functions")
8697

87-
# Get schema from TypeAdapter - will fail if function isn't properly typed
88-
parameters = TypeAdapter(fn).json_schema()
98+
# Find context parameter if it exists
99+
if context_kwarg is None:
100+
context_kwarg = find_context_parameter(fn)
101+
102+
# Get schema from func_metadata, excluding context parameter
103+
func_arg_metadata = func_metadata(
104+
fn,
105+
skip_names=[context_kwarg] if context_kwarg is not None else [],
106+
)
107+
parameters = func_arg_metadata.arg_model.model_json_schema()
89108

90109
# Convert parameters to PromptArguments
91110
arguments: list[PromptArgument] = []
@@ -109,9 +128,14 @@ def from_function(
109128
description=description or fn.__doc__ or "",
110129
arguments=arguments,
111130
fn=fn,
131+
context_kwarg=context_kwarg,
112132
)
113133

114-
async def render(self, arguments: dict[str, Any] | None = None) -> list[Message]:
134+
async def render(
135+
self,
136+
arguments: dict[str, Any] | None = None,
137+
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
138+
) -> list[Message]:
115139
"""Render the prompt with arguments."""
116140
# Validate required arguments
117141
if self.arguments:
@@ -122,8 +146,11 @@ async def render(self, arguments: dict[str, Any] | None = None) -> list[Message]
122146
raise ValueError(f"Missing required arguments: {missing}")
123147

124148
try:
149+
# Add context to arguments if needed
150+
call_args = inject_context(self.fn, arguments or {}, context, self.context_kwarg)
151+
125152
# Call function and check if result is a coroutine
126-
result = self.fn(**(arguments or {}))
153+
result = self.fn(**call_args)
127154
if inspect.iscoroutine(result):
128155
result = await result
129156

src/mcp/server/fastmcp/prompts/manager.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
"""Prompt management functionality."""
22

3-
from typing import Any
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING, Any
46

57
from mcp.server.fastmcp.prompts.base import Message, Prompt
68
from mcp.server.fastmcp.utilities.logging import get_logger
79

10+
if TYPE_CHECKING:
11+
from mcp.server.fastmcp.server import Context
12+
from mcp.server.session import ServerSessionT
13+
from mcp.shared.context import LifespanContextT, RequestT
14+
815
logger = get_logger(__name__)
916

1017

@@ -39,10 +46,15 @@ def add_prompt(
3946
self._prompts[prompt.name] = prompt
4047
return prompt
4148

42-
async def render_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> list[Message]:
49+
async def render_prompt(
50+
self,
51+
name: str,
52+
arguments: dict[str, Any] | None = None,
53+
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
54+
) -> list[Message]:
4355
"""Render a prompt by name with arguments."""
4456
prompt = self.get_prompt(name)
4557
if not prompt:
4658
raise ValueError(f"Unknown prompt: {name}")
4759

48-
return await prompt.render(arguments)
60+
return await prompt.render(arguments, context=context)

src/mcp/server/fastmcp/resources/resource_manager.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11
"""Resource manager functionality."""
22

3+
from __future__ import annotations
4+
35
from collections.abc import Callable
4-
from typing import Any
6+
from typing import TYPE_CHECKING, Any
57

68
from pydantic import AnyUrl
79

810
from mcp.server.fastmcp.resources.base import Resource
911
from mcp.server.fastmcp.resources.templates import ResourceTemplate
1012
from mcp.server.fastmcp.utilities.logging import get_logger
1113

14+
if TYPE_CHECKING:
15+
from mcp.server.fastmcp.server import Context
16+
from mcp.server.session import ServerSessionT
17+
from mcp.shared.context import LifespanContextT, RequestT
18+
1219
logger = get_logger(__name__)
1320

1421

@@ -67,7 +74,11 @@ def add_template(
6774
self._templates[template.uri_template] = template
6875
return template
6976

70-
async def get_resource(self, uri: AnyUrl | str) -> Resource | None:
77+
async def get_resource(
78+
self,
79+
uri: AnyUrl | str,
80+
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
81+
) -> Resource | None:
7182
"""Get resource by URI, checking concrete resources first, then templates."""
7283
uri_str = str(uri)
7384
logger.debug("Getting resource", extra={"uri": uri_str})
@@ -80,7 +91,7 @@ async def get_resource(self, uri: AnyUrl | str) -> Resource | None:
8091
for template in self._templates.values():
8192
if params := template.matches(uri_str):
8293
try:
83-
return await template.create_resource(uri_str, params)
94+
return await template.create_resource(uri_str, params, context=context)
8495
except Exception as e:
8596
raise ValueError(f"Error creating resource from template: {e}")
8697

src/mcp/server/fastmcp/resources/templates.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,18 @@
55
import inspect
66
import re
77
from collections.abc import Callable
8-
from typing import Any
8+
from typing import TYPE_CHECKING, Any
99

10-
from pydantic import BaseModel, Field, TypeAdapter, validate_call
10+
from pydantic import BaseModel, Field, validate_call
1111

1212
from mcp.server.fastmcp.resources.types import FunctionResource, Resource
13+
from mcp.server.fastmcp.utilities.context_injection import find_context_parameter, inject_context
14+
from mcp.server.fastmcp.utilities.func_metadata import func_metadata
15+
16+
if TYPE_CHECKING:
17+
from mcp.server.fastmcp.server import Context
18+
from mcp.server.session import ServerSessionT
19+
from mcp.shared.context import LifespanContextT, RequestT
1320

1421

1522
class ResourceTemplate(BaseModel):
@@ -22,6 +29,7 @@ class ResourceTemplate(BaseModel):
2229
mime_type: str = Field(default="text/plain", description="MIME type of the resource content")
2330
fn: Callable[..., Any] = Field(exclude=True)
2431
parameters: dict[str, Any] = Field(description="JSON schema for function parameters")
32+
context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context")
2533

2634
@classmethod
2735
def from_function(
@@ -32,14 +40,23 @@ def from_function(
3240
title: str | None = None,
3341
description: str | None = None,
3442
mime_type: str | None = None,
43+
context_kwarg: str | None = None,
3544
) -> ResourceTemplate:
3645
"""Create a template from a function."""
3746
func_name = name or fn.__name__
3847
if func_name == "<lambda>":
3948
raise ValueError("You must provide a name for lambda functions")
4049

41-
# Get schema from TypeAdapter - will fail if function isn't properly typed
42-
parameters = TypeAdapter(fn).json_schema()
50+
# Find context parameter if it exists
51+
if context_kwarg is None:
52+
context_kwarg = find_context_parameter(fn)
53+
54+
# Get schema from func_metadata, excluding context parameter
55+
func_arg_metadata = func_metadata(
56+
fn,
57+
skip_names=[context_kwarg] if context_kwarg is not None else [],
58+
)
59+
parameters = func_arg_metadata.arg_model.model_json_schema()
4360

4461
# ensure the arguments are properly cast
4562
fn = validate_call(fn)
@@ -52,6 +69,7 @@ def from_function(
5269
mime_type=mime_type or "text/plain",
5370
fn=fn,
5471
parameters=parameters,
72+
context_kwarg=context_kwarg,
5573
)
5674

5775
def matches(self, uri: str) -> dict[str, Any] | None:
@@ -63,9 +81,17 @@ def matches(self, uri: str) -> dict[str, Any] | None:
6381
return match.groupdict()
6482
return None
6583

66-
async def create_resource(self, uri: str, params: dict[str, Any]) -> Resource:
84+
async def create_resource(
85+
self,
86+
uri: str,
87+
params: dict[str, Any],
88+
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
89+
) -> Resource:
6790
"""Create a resource from the template with the given parameters."""
6891
try:
92+
# Add context to params if needed
93+
params = inject_context(self.fn, params, context, self.context_kwarg)
94+
6995
# Call function and check if result is a coroutine
7096
result = self.fn(**params)
7197
if inspect.iscoroutine(result):

src/mcp/server/fastmcp/server.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from mcp.server.fastmcp.prompts import Prompt, PromptManager
3131
from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager
3232
from mcp.server.fastmcp.tools import Tool, ToolManager
33+
from mcp.server.fastmcp.utilities.context_injection import find_context_parameter
3334
from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger
3435
from mcp.server.lowlevel.helper_types import ReadResourceContents
3536
from mcp.server.lowlevel.server import LifespanResultT
@@ -326,7 +327,8 @@ async def list_resource_templates(self) -> list[MCPResourceTemplate]:
326327
async def read_resource(self, uri: AnyUrl | str) -> Iterable[ReadResourceContents]:
327328
"""Read a resource by URI."""
328329

329-
resource = await self._resource_manager.get_resource(uri)
330+
context = self.get_context()
331+
resource = await self._resource_manager.get_resource(uri, context=context)
330332
if not resource:
331333
raise ResourceError(f"Unknown resource: {uri}")
332334

@@ -510,13 +512,19 @@ async def get_weather(city: str) -> str:
510512

511513
def decorator(fn: AnyFunction) -> AnyFunction:
512514
# Check if this should be a template
515+
sig = inspect.signature(fn)
513516
has_uri_params = "{" in uri and "}" in uri
514-
has_func_params = bool(inspect.signature(fn).parameters)
517+
has_func_params = bool(sig.parameters)
515518

516519
if has_uri_params or has_func_params:
517-
# Validate that URI params match function params
520+
# Check for Context parameter to exclude from validation
521+
context_param = find_context_parameter(fn)
522+
523+
# Validate that URI params match function params (excluding context)
518524
uri_params = set(re.findall(r"{(\w+)}", uri))
519-
func_params = set(inspect.signature(fn).parameters.keys())
525+
# We need to remove the context_param from the resource function if
526+
# there is any.
527+
func_params = {p for p in sig.parameters.keys() if p != context_param}
520528

521529
if uri_params != func_params:
522530
raise ValueError(
@@ -982,7 +990,7 @@ async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) -
982990
if not prompt:
983991
raise ValueError(f"Unknown prompt: {name}")
984992

985-
messages = await prompt.render(arguments)
993+
messages = await prompt.render(arguments, context=self.get_context())
986994

987995
return GetPromptResult(
988996
description=prompt.description,

src/mcp/server/fastmcp/tools/base.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
import inspect
55
from collections.abc import Callable
66
from functools import cached_property
7-
from typing import TYPE_CHECKING, Any, get_origin
7+
from typing import TYPE_CHECKING, Any
88

99
from pydantic import BaseModel, Field
1010

1111
from mcp.server.fastmcp.exceptions import ToolError
12+
from mcp.server.fastmcp.utilities.context_injection import find_context_parameter
1213
from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata
1314
from mcp.types import ToolAnnotations
1415

@@ -49,8 +50,6 @@ def from_function(
4950
structured_output: bool | None = None,
5051
) -> Tool:
5152
"""Create a Tool from a function."""
52-
from mcp.server.fastmcp.server import Context
53-
5453
func_name = name or fn.__name__
5554

5655
if func_name == "<lambda>":
@@ -60,13 +59,7 @@ def from_function(
6059
is_async = _is_async_callable(fn)
6160

6261
if context_kwarg is None:
63-
sig = inspect.signature(fn)
64-
for param_name, param in sig.parameters.items():
65-
if get_origin(param.annotation) is not None:
66-
continue
67-
if issubclass(param.annotation, Context):
68-
context_kwarg = param_name
69-
break
62+
context_kwarg = find_context_parameter(fn)
7063

7164
func_arg_metadata = func_metadata(
7265
fn,

0 commit comments

Comments
 (0)