1
1
"""Base classes for FastMCP prompts."""
2
2
3
+ from __future__ import annotations
4
+
3
5
import inspect
4
6
from collections .abc import Awaitable , Callable , Sequence
5
- from typing import Any , Literal
7
+ from typing import TYPE_CHECKING , Any , Literal
6
8
7
9
import pydantic_core
8
10
from pydantic import BaseModel , Field , TypeAdapter , validate_call
9
11
12
+ from mcp .server .fastmcp .utilities .context_injection import find_context_parameter , inject_context
13
+ from mcp .server .fastmcp .utilities .func_metadata import func_metadata
10
14
from mcp .types import ContentBlock , TextContent
11
15
16
+ if TYPE_CHECKING :
17
+ from mcp .server .fastmcp .server import Context
18
+ from mcp .server .session import ServerSessionT
19
+ from mcp .shared .context import LifespanContextT , RequestT
20
+
12
21
13
22
class Message (BaseModel ):
14
23
"""Base class for all prompt messages."""
@@ -62,6 +71,7 @@ class Prompt(BaseModel):
62
71
description : str | None = Field (None , description = "Description of what the prompt does" )
63
72
arguments : list [PromptArgument ] | None = Field (None , description = "Arguments that can be passed to the prompt" )
64
73
fn : Callable [..., PromptResult | Awaitable [PromptResult ]] = Field (exclude = True )
74
+ context_kwarg : str | None = Field (None , description = "Name of the kwarg that should receive context" , exclude = True )
65
75
66
76
@classmethod
67
77
def from_function (
@@ -70,7 +80,8 @@ def from_function(
70
80
name : str | None = None ,
71
81
title : str | None = None ,
72
82
description : str | None = None ,
73
- ) -> "Prompt" :
83
+ context_kwarg : str | None = None ,
84
+ ) -> Prompt :
74
85
"""Create a Prompt from a function.
75
86
76
87
The function can return:
@@ -84,8 +95,16 @@ def from_function(
84
95
if func_name == "<lambda>" :
85
96
raise ValueError ("You must provide a name for lambda functions" )
86
97
87
- # Get schema from TypeAdapter - will fail if function isn't properly typed
88
- parameters = TypeAdapter (fn ).json_schema ()
98
+ # Find context parameter if it exists
99
+ if context_kwarg is None :
100
+ context_kwarg = find_context_parameter (fn )
101
+
102
+ # Get schema from func_metadata, excluding context parameter
103
+ func_arg_metadata = func_metadata (
104
+ fn ,
105
+ skip_names = [context_kwarg ] if context_kwarg is not None else [],
106
+ )
107
+ parameters = func_arg_metadata .arg_model .model_json_schema ()
89
108
90
109
# Convert parameters to PromptArguments
91
110
arguments : list [PromptArgument ] = []
@@ -109,9 +128,14 @@ def from_function(
109
128
description = description or fn .__doc__ or "" ,
110
129
arguments = arguments ,
111
130
fn = fn ,
131
+ context_kwarg = context_kwarg ,
112
132
)
113
133
114
- async def render (self , arguments : dict [str , Any ] | None = None ) -> list [Message ]:
134
+ async def render (
135
+ self ,
136
+ arguments : dict [str , Any ] | None = None ,
137
+ context : Context [ServerSessionT , LifespanContextT , RequestT ] | None = None ,
138
+ ) -> list [Message ]:
115
139
"""Render the prompt with arguments."""
116
140
# Validate required arguments
117
141
if self .arguments :
@@ -122,8 +146,11 @@ async def render(self, arguments: dict[str, Any] | None = None) -> list[Message]
122
146
raise ValueError (f"Missing required arguments: { missing } " )
123
147
124
148
try :
149
+ # Add context to arguments if needed
150
+ call_args = inject_context (self .fn , arguments or {}, context , self .context_kwarg )
151
+
125
152
# Call function and check if result is a coroutine
126
- result = self .fn (** ( arguments or {}) )
153
+ result = self .fn (** call_args )
127
154
if inspect .iscoroutine (result ):
128
155
result = await result
129
156
0 commit comments