From 6f8673522fe6e2f6f7eae7b41ac003c506577bb1 Mon Sep 17 00:00:00 2001 From: kauabh <56749351+kauabh@users.noreply.github.com> Date: Sun, 6 Jul 2025 04:27:35 +0530 Subject: [PATCH 01/15] Adding CountTokens to Gemini Gemini provides an endpoint to count tokens before sending a request https://ai.google.dev/api/tokens#method:-models.counttokens --- pydantic_ai_slim/pydantic_ai/models/gemini.py | 122 ++++++++++++++++++ 1 file changed, 122 insertions(+) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 64008622b..9ca1b9dcc 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -397,6 +397,104 @@ def _map_response_schema(self, o: OutputObjectDefinition) -> dict[str, Any]: return response_schema + async def count_tokens( + self, + messages: list[ModelMessage], + model_settings: GeminiModelSettings | None, + model_request_parameters: ModelRequestParameters, + ) -> usage.Usage: + check_allow_model_requests() + async with self._make_count_request(messages, model_settings or {}, model_request_parameters) as http_response: + data = await http_response.aread() + response = _gemini_count_tokens_response_ta.validate_json(data) + return self._process_count_tokens_response(response) + + @asynccontextmanager + async def _make_count_request( + self, + messages: list[ModelMessage], + model_settings: GeminiModelSettings, + model_request_parameters: ModelRequestParameters, + ) -> AsyncIterator[HTTPResponse]: + tools = self._get_tools(model_request_parameters) + tool_config = self._get_tool_config(model_request_parameters, tools) + sys_prompt_parts, contents = await self._message_to_gemini_content(messages) + + request_data = _GeminiCountTokensRequest(contents=contents) + if sys_prompt_parts: + request_data['systemInstruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts) + if tools is not None: + request_data['tools'] = tools + if tool_config is not None: + request_data['toolConfig'] = tool_config + + generation_config = _settings_to_generation_config(model_settings) + if model_request_parameters.output_mode == 'native': + if tools: + raise UserError('Gemini does not support structured output and tools at the same time.') + generation_config['response_mime_type'] = 'application/json' + output_object = model_request_parameters.output_object + assert output_object is not None + generation_config['response_schema'] = self._map_response_schema(output_object) + elif model_request_parameters.output_mode == 'prompted' and not tools: + generation_config['response_mime_type'] = 'application/json' + + if generation_config: + request_data['generateContentRequest'] = { + 'contents': contents, + 'generationConfig': generation_config, + } + if sys_prompt_parts: + request_data['generateContentRequest']['systemInstruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts) + if tools is not None: + request_data['generateContentRequest']['tools'] = tools + if tool_config is not None: + request_data['generateContentRequest']['toolConfig'] = tool_config + + if gemini_safety_settings := model_settings.get('gemini_safety_settings'): + request_data['safetySettings'] = gemini_safety_settings + + if gemini_labels := model_settings.get('gemini_labels'): + if self._system == 'google-vertex': + request_data['labels'] = gemini_labels + + headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()} + url = f'/models/{self._model_name}:countTokens' + + request_json = 
_gemini_count_tokens_request_ta.dump_json(request_data, by_alias=True) + async with self.client.stream( + 'POST', + url, + content=request_json, + headers=headers, + timeout=model_settings.get('timeout', USE_CLIENT_DEFAULT), + ) as r: + if (status_code := r.status_code) != 200: + await r.aread() + if status_code >= 400: + raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=r.text) + raise UnexpectedModelBehavior( # pragma: no cover + f'Unexpected response from gemini {status_code}', r.text) + yield r + + def _process_count_tokens_response(self, response: _GeminiCountTokensResponse) -> usage.Usage: + details: dict[str, int] = {} + if cached_content_token_count := response.get('cachedContentTokenCount'): + details['cached_content_tokens'] = cached_content_token_count + + for key, metadata_details in response.items(): + if key.endswith('TokensDetails') and metadata_details: + metadata_details = cast(list[_GeminiModalityTokenCount], metadata_details) + suffix = key.removesuffix('TokensDetails').lower() + for detail in metadata_details: + details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count'] + + return usage.Usage( + request_tokens=response.get('totalTokens', 0), + response_tokens=0, # countTokens does not provide response tokens + total_tokens=response.get('totalTokens', 0), + details=details, + ) def _settings_to_generation_config(model_settings: GeminiModelSettings) -> _GeminiGenerationConfig: config: _GeminiGenerationConfig = {} @@ -809,6 +907,30 @@ class _GeminiResponse(TypedDict): vendor_id: NotRequired[Annotated[str, pydantic.Field(alias='responseId')]] +@pydantic.with_config(pydantic.ConfigDict(defer_build=True)) +class _GeminiCountTokensRequest(TypedDict): + """Schema for a countTokens API request to the Gemini API. + + See for API docs. + """ + + contents: NotRequired[list[_GeminiContent]] + generateContentRequest: NotRequired[_GeminiRequest] + + +@pydantic.with_config(pydantic.ConfigDict(defer_build=True)) +class _GeminiCountTokensResponse(TypedDict): + """Schema for the response from the Gemini countTokens API. + + See for API docs. 
+ """ + + totalTokens: int + cachedContentTokenCount: NotRequired[int] + promptTokensDetails: NotRequired[list[_GeminiModalityTokenCount]] + cacheTokensDetails: NotRequired[list[_GeminiModalityTokenCount]] + + class _GeminiCandidates(TypedDict): """See .""" From 5cd88e0e77c975427726188f6480448e6aef2e5d Mon Sep 17 00:00:00 2001 From: kauabh <56749351+kauabh@users.noreply.github.com> Date: Sun, 6 Jul 2025 04:52:53 +0530 Subject: [PATCH 02/15] Update gemini.py added type adaptor --- pydantic_ai_slim/pydantic_ai/models/gemini.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 9ca1b9dcc..14a5eed1e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -929,7 +929,11 @@ class _GeminiCountTokensResponse(TypedDict): cachedContentTokenCount: NotRequired[int] promptTokensDetails: NotRequired[list[_GeminiModalityTokenCount]] cacheTokensDetails: NotRequired[list[_GeminiModalityTokenCount]] - + + +_gemini_count_tokens_request_ta = pydantic.TypeAdapter(_GeminiCountTokensRequest) +_gemini_count_tokens_response_ta = pydantic.TypeAdapter(_GeminiCountTokensResponse) + class _GeminiCandidates(TypedDict): """See .""" From a30234527a8dbe0b51a5c8e9752aed31648f6921 Mon Sep 17 00:00:00 2001 From: kauabh <56749351+kauabh@users.noreply.github.com> Date: Sun, 6 Jul 2025 05:06:24 +0530 Subject: [PATCH 03/15] Update gemini.py Removed extra assignment --- pydantic_ai_slim/pydantic_ai/models/gemini.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 14a5eed1e..f2199af09 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -421,12 +421,6 @@ async def _make_count_request( sys_prompt_parts, contents = await self._message_to_gemini_content(messages) request_data = _GeminiCountTokensRequest(contents=contents) - if sys_prompt_parts: - request_data['systemInstruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts) - if tools is not None: - request_data['tools'] = tools - if tool_config is not None: - request_data['toolConfig'] = tool_config generation_config = _settings_to_generation_config(model_settings) if model_request_parameters.output_mode == 'native': @@ -451,13 +445,6 @@ async def _make_count_request( if tool_config is not None: request_data['generateContentRequest']['toolConfig'] = tool_config - if gemini_safety_settings := model_settings.get('gemini_safety_settings'): - request_data['safetySettings'] = gemini_safety_settings - - if gemini_labels := model_settings.get('gemini_labels'): - if self._system == 'google-vertex': - request_data['labels'] = gemini_labels - headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()} url = f'/models/{self._model_name}:countTokens' From 3b2e26ab82901c3af0e0dc23d4c00518d285ff95 Mon Sep 17 00:00:00 2001 From: kauabh <56749351+kauabh@users.noreply.github.com> Date: Sun, 6 Jul 2025 05:15:41 +0530 Subject: [PATCH 04/15] Update gemini.py Linting --- pydantic_ai_slim/pydantic_ai/models/gemini.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index f2199af09..917fd698e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ 
-401,8 +401,7 @@ async def count_tokens( self, messages: list[ModelMessage], model_settings: GeminiModelSettings | None, - model_request_parameters: ModelRequestParameters, - ) -> usage.Usage: + model_request_parameters: ModelRequestParameters,) -> usage.Usage: check_allow_model_requests() async with self._make_count_request(messages, model_settings or {}, model_request_parameters) as http_response: data = await http_response.aread() @@ -415,7 +414,7 @@ async def _make_count_request( messages: list[ModelMessage], model_settings: GeminiModelSettings, model_request_parameters: ModelRequestParameters, - ) -> AsyncIterator[HTTPResponse]: + ) -> AsyncIterator[HTTPResponse]: tools = self._get_tools(model_request_parameters) tool_config = self._get_tool_config(model_request_parameters, tools) sys_prompt_parts, contents = await self._message_to_gemini_content(messages) @@ -439,7 +438,9 @@ async def _make_count_request( 'generationConfig': generation_config, } if sys_prompt_parts: - request_data['generateContentRequest']['systemInstruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts) + request_data['generateContentRequest']['systemInstruction'] = _GeminiTextContent( + role='user', parts=sys_prompt_parts + ) if tools is not None: request_data['generateContentRequest']['tools'] = tools if tool_config is not None: @@ -460,7 +461,7 @@ async def _make_count_request( await r.aread() if status_code >= 400: raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=r.text) - raise UnexpectedModelBehavior( # pragma: no cover + raise UnexpectedModelBehavior( # pragma: no cover f'Unexpected response from gemini {status_code}', r.text) yield r From dc4d29b91e219a9d668bab21644c4fe68cfd0ebc Mon Sep 17 00:00:00 2001 From: kauabh <56749351+kauabh@users.noreply.github.com> Date: Sun, 6 Jul 2025 05:21:12 +0530 Subject: [PATCH 05/15] Update gemini.py Linting --- pydantic_ai_slim/pydantic_ai/models/gemini.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 917fd698e..dadc34feb 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -401,7 +401,8 @@ async def count_tokens( self, messages: list[ModelMessage], model_settings: GeminiModelSettings | None, - model_request_parameters: ModelRequestParameters,) -> usage.Usage: + model_request_parameters: ModelRequestParameters, + ) -> usage.Usage: check_allow_model_requests() async with self._make_count_request(messages, model_settings or {}, model_request_parameters) as http_response: data = await http_response.aread() @@ -462,7 +463,8 @@ async def _make_count_request( if status_code >= 400: raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=r.text) raise UnexpectedModelBehavior( # pragma: no cover - f'Unexpected response from gemini {status_code}', r.text) + f'Unexpected response from gemini {status_code}', r.text + ) yield r def _process_count_tokens_response(self, response: _GeminiCountTokensResponse) -> usage.Usage: From 16f18dcb699c919991f075a152a84661f36c13ef Mon Sep 17 00:00:00 2001 From: kauabh <56749351+kauabh@users.noreply.github.com> Date: Sun, 6 Jul 2025 05:32:00 +0530 Subject: [PATCH 06/15] Update gemini.py --- pydantic_ai_slim/pydantic_ai/models/gemini.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 
dadc34feb..4b3f1b8ca 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -485,6 +485,7 @@ def _process_count_tokens_response(self, response: _GeminiCountTokensResponse) - total_tokens=response.get('totalTokens', 0), details=details, ) + def _settings_to_generation_config(model_settings: GeminiModelSettings) -> _GeminiGenerationConfig: config: _GeminiGenerationConfig = {} @@ -911,7 +912,6 @@ class _GeminiCountTokensRequest(TypedDict): @pydantic.with_config(pydantic.ConfigDict(defer_build=True)) class _GeminiCountTokensResponse(TypedDict): """Schema for the response from the Gemini countTokens API. - See for API docs. """ From 24d6c25a446916ed1cf2dbe6844930061e52b5cf Mon Sep 17 00:00:00 2001 From: kauabh <56749351+kauabh@users.noreply.github.com> Date: Sun, 6 Jul 2025 05:42:23 +0530 Subject: [PATCH 07/15] Update gemini.py --- pydantic_ai_slim/pydantic_ai/models/gemini.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 4b3f1b8ca..b23e80548 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -466,7 +466,7 @@ async def _make_count_request( f'Unexpected response from gemini {status_code}', r.text ) yield r - + def _process_count_tokens_response(self, response: _GeminiCountTokensResponse) -> usage.Usage: details: dict[str, int] = {} if cached_content_token_count := response.get('cachedContentTokenCount'): @@ -485,7 +485,7 @@ def _process_count_tokens_response(self, response: _GeminiCountTokensResponse) - total_tokens=response.get('totalTokens', 0), details=details, ) - + def _settings_to_generation_config(model_settings: GeminiModelSettings) -> _GeminiGenerationConfig: config: _GeminiGenerationConfig = {} @@ -907,7 +907,7 @@ class _GeminiCountTokensRequest(TypedDict): contents: NotRequired[list[_GeminiContent]] generateContentRequest: NotRequired[_GeminiRequest] - + @pydantic.with_config(pydantic.ConfigDict(defer_build=True)) class _GeminiCountTokensResponse(TypedDict): From 90fc8bbaca3172cecc4218d1605b9a0596887907 Mon Sep 17 00:00:00 2001 From: kauabh <56749351+kauabh@users.noreply.github.com> Date: Sun, 6 Jul 2025 05:46:24 +0530 Subject: [PATCH 08/15] Update gemini.py Linting --- pydantic_ai_slim/pydantic_ai/models/gemini.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index b23e80548..041a01f75 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -912,6 +912,7 @@ class _GeminiCountTokensRequest(TypedDict): @pydantic.with_config(pydantic.ConfigDict(defer_build=True)) class _GeminiCountTokensResponse(TypedDict): """Schema for the response from the Gemini countTokens API. + See for API docs. 
""" From 2bfc8d083c2ef1e4e7edfbbdd5dfe8813a0a2be5 Mon Sep 17 00:00:00 2001 From: kauabh <56749351+kauabh@users.noreply.github.com> Date: Sun, 6 Jul 2025 05:49:18 +0530 Subject: [PATCH 09/15] Update gemini.py Removed White Space --- pydantic_ai_slim/pydantic_ai/models/gemini.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 041a01f75..ee6a659b8 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -912,7 +912,7 @@ class _GeminiCountTokensRequest(TypedDict): @pydantic.with_config(pydantic.ConfigDict(defer_build=True)) class _GeminiCountTokensResponse(TypedDict): """Schema for the response from the Gemini countTokens API. - + See for API docs. """ From bae4ca9d29792424248523e9e73e8f7b4df05e7b Mon Sep 17 00:00:00 2001 From: kauabh <56749351+kauabh@users.noreply.github.com> Date: Fri, 25 Jul 2025 14:51:03 +0530 Subject: [PATCH 10/15] Enabling Request Token Count in Google (#1) * adding count token for google * resolved conflicts --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 6 ++++ pydantic_ai_slim/pydantic_ai/messages.py | 27 +++++++++++++++ .../pydantic_ai/models/__init__.py | 18 +++++++++- .../pydantic_ai/models/anthropic.py | 8 +++++ .../pydantic_ai/models/bedrock.py | 8 +++++ pydantic_ai_slim/pydantic_ai/models/cohere.py | 8 +++++ .../pydantic_ai/models/fallback.py | 9 ++++- .../pydantic_ai/models/function.py | 9 +++++ pydantic_ai_slim/pydantic_ai/models/gemini.py | 8 +++++ pydantic_ai_slim/pydantic_ai/models/google.py | 34 +++++++++++++++++++ pydantic_ai_slim/pydantic_ai/models/groq.py | 8 +++++ .../pydantic_ai/models/huggingface.py | 8 +++++ .../pydantic_ai/models/mcp_sampling.py | 9 ++++- .../pydantic_ai/models/mistral.py | 8 +++++ pydantic_ai_slim/pydantic_ai/models/openai.py | 15 ++++++++ pydantic_ai_slim/pydantic_ai/models/test.py | 7 ++++ .../pydantic_ai/models/wrapper.py | 5 ++- pydantic_ai_slim/pydantic_ai/usage.py | 6 ++++ tests/evals/test_evaluators.py | 8 ++++- tests/models/test_google.py | 27 +++++++++++++++ 20 files changed, 231 insertions(+), 5 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 312a8a2fc..87a49a468 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -356,6 +356,12 @@ async def _make_request( message_history = await _process_message_history( ctx.state.message_history, ctx.deps.history_processors, build_run_context(ctx) ) + + if ctx.deps.usage_limits and ctx.deps.usage_limits.pre_request_token_check_with_overhead: + token_count = await ctx.deps.model.count_tokens(message_history) + + ctx.deps.usage_limits.check_tokens(_usage.Usage(request_tokens=token_count.total_tokens)) + model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters) ctx.state.usage.incr(_usage.Usage()) diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 379d70efd..afa0ec7f1 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -1143,3 +1143,30 @@ def tool_call_id(self) -> str: HandleResponseEvent = Annotated[ Union[FunctionToolCallEvent, FunctionToolResultEvent], pydantic.Discriminator('event_kind') ] + + +@dataclass(repr=False) +class BaseCountTokensResponse: + """Structured response for token count API calls from various model 
providers.""" + + total_tokens: int | None = field( + default=None, metadata={'description': 'Total number of tokens counted in the messages.'} + ) + """Total number of tokens counted in the messages.""" + + model_name: str | None = field(default=None) + """Name of the model that provided the token count.""" + + vendor_details: dict[str, Any] | None = field(default=None) + """Vendor-specific token count details (e.g., cached_content_token_count for Gemini).""" + + vendor_id: str | None = field(default=None) + """Vendor request ID for tracking the token count request.""" + + timestamp: datetime = field(default_factory=_now_utc) + """Timestamp of the token count response.""" + + error: str | None = field(default=None) + """Error message if the token count request failed.""" + + __repr__ = _utils.dataclasses_no_defaults_repr diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 6cdcbfbd6..fa8e5b726 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -24,7 +24,15 @@ from .._output import OutputObjectDefinition from .._parts_manager import ModelResponsePartsManager from ..exceptions import UserError -from ..messages import FileUrl, ModelMessage, ModelRequest, ModelResponse, ModelResponseStreamEvent, VideoUrl +from ..messages import ( + BaseCountTokensResponse, + FileUrl, + ModelMessage, + ModelRequest, + ModelResponse, + ModelResponseStreamEvent, + VideoUrl, +) from ..output import OutputMode from ..profiles._json_schema import JsonSchemaTransformer from ..settings import ModelSettings @@ -382,6 +390,14 @@ async def request( """Make a request to the model.""" raise NotImplementedError() + @abstractmethod + async def count_tokens( + self, + messages: list[ModelMessage], + ) -> BaseCountTokensResponse: + """Make a request to the model.""" + raise NotImplementedError() + @asynccontextmanager async def request_stream( self, diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 02f9111c2..361ee01f4 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -13,6 +13,7 @@ from .. 
import ModelHTTPError, UnexpectedModelBehavior, _utils, usage from .._utils import guard_tool_call_id as _guard_tool_call_id from ..messages import ( + BaseCountTokensResponse, BinaryContent, DocumentUrl, ImageUrl, @@ -165,6 +166,13 @@ async def request( model_response.usage.requests = 1 return model_response + async def count_tokens( + self, + messages: list[ModelMessage], + ) -> BaseCountTokensResponse: + """Token counting is not supported by the AnthropicModel.""" + raise NotImplementedError('Token counting is not supported by AnthropicModel') + @asynccontextmanager async def request_stream( self, diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index b63ed4e1f..746e8d66d 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -17,6 +17,7 @@ from pydantic_ai import _utils, usage from pydantic_ai.messages import ( AudioUrl, + BaseCountTokensResponse, BinaryContent, DocumentUrl, ImageUrl, @@ -258,6 +259,13 @@ async def request( model_response.usage.requests = 1 return model_response + async def count_tokens( + self, + messages: list[ModelMessage], + ) -> BaseCountTokensResponse: + """Token counting is not supported by the BedrockConverseModel.""" + raise NotImplementedError('Token counting is not supported by BedrockConverseModel') + @asynccontextmanager async def request_stream( self, diff --git a/pydantic_ai_slim/pydantic_ai/models/cohere.py b/pydantic_ai_slim/pydantic_ai/models/cohere.py index 4243ef492..087d47941 100644 --- a/pydantic_ai_slim/pydantic_ai/models/cohere.py +++ b/pydantic_ai_slim/pydantic_ai/models/cohere.py @@ -11,6 +11,7 @@ from .. import ModelHTTPError, usage from .._utils import generate_tool_call_id as _generate_tool_call_id, guard_tool_call_id as _guard_tool_call_id from ..messages import ( + BaseCountTokensResponse, ModelMessage, ModelRequest, ModelResponse, @@ -149,6 +150,13 @@ async def request( model_response.usage.requests = 1 return model_response + async def count_tokens( + self, + messages: list[ModelMessage], + ) -> BaseCountTokensResponse: + """Token counting is not supported by the CohereModel.""" + raise NotImplementedError('Token counting is not supported by CohereModel') + @property def model_name(self) -> CohereModelName: """The model name.""" diff --git a/pydantic_ai_slim/pydantic_ai/models/fallback.py b/pydantic_ai_slim/pydantic_ai/models/fallback.py index 4455defce..80957d452 100644 --- a/pydantic_ai_slim/pydantic_ai/models/fallback.py +++ b/pydantic_ai_slim/pydantic_ai/models/fallback.py @@ -13,7 +13,7 @@ from . 
import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model if TYPE_CHECKING: - from ..messages import ModelMessage, ModelResponse + from ..messages import BaseCountTokensResponse, ModelMessage, ModelResponse from ..settings import ModelSettings @@ -77,6 +77,13 @@ async def request( raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions) + async def count_tokens( + self, + messages: list[ModelMessage], + ) -> BaseCountTokensResponse: + """Token counting is not supported by the FallbackModel.""" + raise NotImplementedError('Token counting is not supported by FallbackModel') + @asynccontextmanager async def request_stream( self, diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index c48873f04..f7e583173 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -16,6 +16,8 @@ from .. import _utils, usage from .._utils import PeekableAsyncStream from ..messages import ( + AudioUrl, + BaseCountTokensResponse, BinaryContent, ModelMessage, ModelRequest, @@ -139,6 +141,13 @@ async def request( response.usage.requests = 1 return response + async def count_tokens( + self, + messages: list[ModelMessage], + ) -> BaseCountTokensResponse: + """Token counting is not supported by the FunctionModel.""" + raise NotImplementedError('Token counting is not supported by FunctionModel') + @asynccontextmanager async def request_stream( self, diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 3202e9dac..556f1040a 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -19,6 +19,7 @@ from .._output import OutputObjectDefinition from ..exceptions import UserError from ..messages import ( + BaseCountTokensResponse, BinaryContent, FileUrl, ModelMessage, @@ -158,6 +159,13 @@ async def request( response = _gemini_response_ta.validate_json(data) return self._process_response(response) + async def count_tokens( + self, + messages: list[ModelMessage], + ) -> BaseCountTokensResponse: + """Token counting is not supported by the GeminiModel.""" + raise NotImplementedError('Token counting is not supported by GeminiModel') + @asynccontextmanager async def request_stream( self, diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 082f5ba56..517d9cad4 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -14,6 +14,7 @@ from .._output import OutputObjectDefinition from ..exceptions import UserError from ..messages import ( + BaseCountTokensResponse, BinaryContent, FileUrl, ModelMessage, @@ -48,6 +49,7 @@ from google.genai.types import ( ContentDict, ContentUnionDict, + CountTokensResponse, FunctionCallDict, FunctionCallingConfigDict, FunctionCallingConfigMode, @@ -181,6 +183,18 @@ async def request( response = await self._generate_content(messages, False, model_settings, model_request_parameters) return self._process_response(response) + async def count_tokens( + self, + messages: list[ModelMessage], + ) -> BaseCountTokensResponse: + check_allow_model_requests() + _, contents = await self._map_messages(messages) + response = self.client.models.count_tokens( + model=self._model_name, + contents=contents, + ) + return self._process_count_tokens_response(response) + @asynccontextmanager async def request_stream( self, @@ -338,6 
+352,26 @@ async def _process_streamed_response(self, response: AsyncIterator[GenerateConte _timestamp=first_chunk.create_time or _utils.now_utc(), ) + def _process_count_tokens_response( + self, + response: CountTokensResponse, + ) -> BaseCountTokensResponse: + """Process Gemini token count response into BaseCountTokensResponse.""" + if not hasattr(response, 'total_tokens') or response.total_tokens is None: + raise UnexpectedModelBehavior('Total tokens missing from Gemini response', str(response)) + + vendor_details: dict[str, Any] | None = None + if hasattr(response, 'cached_content_token_count'): + vendor_details = {} + vendor_details['cached_content_token_count'] = response.cached_content_token_count + + return BaseCountTokensResponse( + total_tokens=response.total_tokens, + model_name=self._model_name, + vendor_details=vendor_details if vendor_details else None, + vendor_id=getattr(response, 'request_id', None), + ) + async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict | None, list[ContentUnionDict]]: contents: list[ContentUnionDict] = [] system_parts: list[PartDict] = [] diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index ffca84b44..f6f28e554 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -14,6 +14,7 @@ from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage from .._utils import guard_tool_call_id as _guard_tool_call_id, number_to_datetime from ..messages import ( + BaseCountTokensResponse, BinaryContent, DocumentUrl, ImageUrl, @@ -160,6 +161,13 @@ async def request( model_response.usage.requests = 1 return model_response + async def count_tokens( + self, + messages: list[ModelMessage], + ) -> BaseCountTokensResponse: + """Token counting is not supported by the GroqModel.""" + raise NotImplementedError('Token counting is not supported by GroqModel') + @asynccontextmanager async def request_stream( self, diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index 4b3c2ff40..4ad9da66c 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -16,6 +16,7 @@ from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc from ..messages import ( AudioUrl, + BaseCountTokensResponse, BinaryContent, DocumentUrl, ImageUrl, @@ -140,6 +141,13 @@ async def request( model_response.usage.requests = 1 return model_response + async def count_tokens( + self, + messages: list[ModelMessage], + ) -> BaseCountTokensResponse: + """Token counting is not supported by the HuggingFaceModel.""" + raise NotImplementedError('Token counting is not supported by HuggingFaceModel') + @asynccontextmanager async def request_stream( self, diff --git a/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py b/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py index ebfaac92d..bead387c0 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py +++ b/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, cast from .. import _mcp, exceptions, usage -from ..messages import ModelMessage, ModelResponse +from ..messages import BaseCountTokensResponse, ModelMessage, ModelResponse from ..settings import ModelSettings from . 
import Model, ModelRequestParameters, StreamedResponse @@ -70,6 +70,13 @@ async def request( f'Unexpected result from MCP sampling, expected "assistant" role, got {result.role}.' ) + async def count_tokens( + self, + messages: list[ModelMessage], + ) -> BaseCountTokensResponse: + """Token counting is not supported by the MCPSamplingModel.""" + raise NotImplementedError('Token counting is not supported by MCPSamplingModel') + @asynccontextmanager async def request_stream( self, diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index ca73558bc..800333efe 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -16,6 +16,7 @@ from .. import ModelHTTPError, UnexpectedModelBehavior, _utils from .._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc, number_to_datetime from ..messages import ( + BaseCountTokensResponse, BinaryContent, DocumentUrl, ImageUrl, @@ -167,6 +168,13 @@ async def request( model_response.usage.requests = 1 return model_response + async def count_tokens( + self, + messages: list[ModelMessage], + ) -> BaseCountTokensResponse: + """Token counting is not supported by the MistralModel.""" + raise NotImplementedError('Token counting is not supported by MistralModel') + @asynccontextmanager async def request_stream( self, diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 35dca2e03..7d4657988 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -20,6 +20,7 @@ from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc, number_to_datetime from ..messages import ( AudioUrl, + BaseCountTokensResponse, BinaryContent, DocumentUrl, ImageUrl, @@ -248,6 +249,13 @@ async def request( model_response.usage.requests = 1 return model_response + async def count_tokens( + self, + messages: list[ModelMessage], + ) -> BaseCountTokensResponse: + """Token counting is not supported by the OpenAIModel.""" + raise NotImplementedError('Token counting is not supported by OpenAIModel') + @asynccontextmanager async def request_stream( self, @@ -672,6 +680,13 @@ async def request( ) return self._process_response(response) + async def count_tokens( + self, + messages: list[ModelMessage], + ) -> BaseCountTokensResponse: + """Token counting is not supported by the OpenAIResponsesModel.""" + raise NotImplementedError('Token counting is not supported by the OpenAIResponsesModel') + @asynccontextmanager async def request_stream( self, diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index eebe00d44..05061ba38 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -13,6 +13,7 @@ from .. 
import _utils from ..messages import ( + BaseCountTokensResponse, ModelMessage, ModelRequest, ModelResponse, @@ -112,6 +113,12 @@ async def request( model_response.usage.requests = 1 return model_response + async def count_tokens( + self, + messages: list[ModelMessage], + ) -> BaseCountTokensResponse: + return BaseCountTokensResponse(total_tokens=1) + @asynccontextmanager async def request_stream( self, diff --git a/pydantic_ai_slim/pydantic_ai/models/wrapper.py b/pydantic_ai_slim/pydantic_ai/models/wrapper.py index cc91f9c72..d05b390dc 100644 --- a/pydantic_ai_slim/pydantic_ai/models/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/models/wrapper.py @@ -6,7 +6,7 @@ from functools import cached_property from typing import Any -from ..messages import ModelMessage, ModelResponse +from ..messages import BaseCountTokensResponse, ModelMessage, ModelResponse from ..profiles import ModelProfile from ..settings import ModelSettings from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model @@ -29,6 +29,9 @@ def __init__(self, wrapped: Model | KnownModelName): async def request(self, *args: Any, **kwargs: Any) -> ModelResponse: return await self.wrapped.request(*args, **kwargs) + async def count_tokens(self, *args: Any, **kwargs: Any) -> BaseCountTokensResponse: + return await self.wrapped.count_tokens(*args, **kwargs) + @asynccontextmanager async def request_stream( self, diff --git a/pydantic_ai_slim/pydantic_ai/usage.py b/pydantic_ai_slim/pydantic_ai/usage.py index c3f4c1885..39e7a15d2 100644 --- a/pydantic_ai_slim/pydantic_ai/usage.py +++ b/pydantic_ai_slim/pydantic_ai/usage.py @@ -28,6 +28,8 @@ class Usage: """Total tokens used in the whole run, should generally be equal to `request_tokens + response_tokens`.""" details: dict[str, int] | None = None """Any extra details returned by the model.""" + eager_request_tokens_check: bool = False + """Any extra details returned by the model.""" def incr(self, incr_usage: Usage) -> None: """Increment the usage in place. @@ -96,6 +98,10 @@ class UsageLimits: """The maximum number of tokens allowed in responses from the model.""" total_tokens_limit: int | None = None """The maximum number of tokens allowed in requests and responses combined.""" + pre_request_token_check_with_overhead: bool = False + """If True, perform a token counting pass before sending the request to the model, + to enforce `request_tokens_limit` ahead of time. This may incur additional overhead + (from calling the model's `count_tokens` method) and is disabled by default.""" def has_token_limits(self) -> bool: """Returns `True` if this instance places any limits on token counts. 
diff --git a/tests/evals/test_evaluators.py b/tests/evals/test_evaluators.py index 235296c4a..cfe373b97 100644 --- a/tests/evals/test_evaluators.py +++ b/tests/evals/test_evaluators.py @@ -8,7 +8,7 @@ from pydantic import BaseModel, TypeAdapter from pydantic_core import to_jsonable_python -from pydantic_ai.messages import ModelMessage, ModelResponse +from pydantic_ai.messages import BaseCountTokensResponse, ModelMessage, ModelResponse from pydantic_ai.models import Model, ModelRequestParameters from pydantic_ai.settings import ModelSettings @@ -125,6 +125,12 @@ async def request( ) -> ModelResponse: raise NotImplementedError + async def count_tokens( + self, + messages: list[ModelMessage], + ) -> BaseCountTokensResponse: + raise NotImplementedError + @property def model_name(self) -> str: return 'my-model' diff --git a/tests/models/test_google.py b/tests/models/test_google.py index 7e1f372bc..80016c530 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -1393,3 +1393,30 @@ class CountryLanguage(BaseModel): ), ] ) + + +async def test_google_model_count_tokens(allow_model_requests: None, google_provider: GoogleProvider): + model = GoogleModel('gemini-1.5-flash', provider=google_provider) + + messages = [ + ModelRequest( + parts=[ + SystemPromptPart(content='You are a helpful chatbot.', timestamp=IsDatetime()), + UserPromptPart(content='What was the temperature in London 1st January 2022?', timestamp=IsDatetime()), + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='temperature', + args={'date': '2022-01-01', 'city': 'London'}, + tool_call_id='test_id', + ) + ], + model_name='gemini-1.5-flash', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + ), + ] + result = await model.count_tokens(messages) + assert result.total_tokens == snapshot(7) From 72c812554dc9b8fc53308d4fc55fa15f7b2ae96c Mon Sep 17 00:00:00 2001 From: Abhishek Kaushik <56749351+kauabh@users.noreply.github.com> Date: Fri, 25 Jul 2025 15:11:42 +0530 Subject: [PATCH 11/15] removed extra argument --- pydantic_ai_slim/pydantic_ai/usage.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/usage.py b/pydantic_ai_slim/pydantic_ai/usage.py index 39e7a15d2..81209ffeb 100644 --- a/pydantic_ai_slim/pydantic_ai/usage.py +++ b/pydantic_ai_slim/pydantic_ai/usage.py @@ -28,8 +28,6 @@ class Usage: """Total tokens used in the whole run, should generally be equal to `request_tokens + response_tokens`.""" details: dict[str, int] | None = None """Any extra details returned by the model.""" - eager_request_tokens_check: bool = False - """Any extra details returned by the model.""" def incr(self, incr_usage: Usage) -> None: """Increment the usage in place. From fa9de61940734cf1b03cd18da8c82c20d7d239f6 Mon Sep 17 00:00:00 2001 From: Abhishek Kaushik <56749351+kauabh@users.noreply.github.com> Date: Fri, 25 Jul 2025 15:17:37 +0530 Subject: [PATCH 12/15] updated gemini, removed redundant code --- .../pydantic_ai/models/function.py | 1 - pydantic_ai_slim/pydantic_ai/models/gemini.py | 121 +----------------- 2 files changed, 2 insertions(+), 120 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index f7e583173..67376f2b6 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -16,7 +16,6 @@ from .. 
import _utils, usage from .._utils import PeekableAsyncStream from ..messages import ( - AudioUrl, BaseCountTokensResponse, BinaryContent, ModelMessage, diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 556f1040a..4f49eb702 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -163,8 +163,8 @@ async def count_tokens( self, messages: list[ModelMessage], ) -> BaseCountTokensResponse: - """Token counting is not supported by the GeminiModel.""" - raise NotImplementedError('Token counting is not supported by GeminiModel') + """Token counting is not supported by the CohereModel.""" + raise NotImplementedError('Token counting is not supported by CohereModel') @asynccontextmanager async def request_stream( @@ -390,95 +390,6 @@ def _map_response_schema(self, o: OutputObjectDefinition) -> dict[str, Any]: return response_schema - async def count_tokens( - self, - messages: list[ModelMessage], - model_settings: GeminiModelSettings | None, - model_request_parameters: ModelRequestParameters, - ) -> usage.Usage: - check_allow_model_requests() - async with self._make_count_request(messages, model_settings or {}, model_request_parameters) as http_response: - data = await http_response.aread() - response = _gemini_count_tokens_response_ta.validate_json(data) - return self._process_count_tokens_response(response) - - @asynccontextmanager - async def _make_count_request( - self, - messages: list[ModelMessage], - model_settings: GeminiModelSettings, - model_request_parameters: ModelRequestParameters, - ) -> AsyncIterator[HTTPResponse]: - tools = self._get_tools(model_request_parameters) - tool_config = self._get_tool_config(model_request_parameters, tools) - sys_prompt_parts, contents = await self._message_to_gemini_content(messages) - - request_data = _GeminiCountTokensRequest(contents=contents) - - generation_config = _settings_to_generation_config(model_settings) - if model_request_parameters.output_mode == 'native': - if tools: - raise UserError('Gemini does not support structured output and tools at the same time.') - generation_config['response_mime_type'] = 'application/json' - output_object = model_request_parameters.output_object - assert output_object is not None - generation_config['response_schema'] = self._map_response_schema(output_object) - elif model_request_parameters.output_mode == 'prompted' and not tools: - generation_config['response_mime_type'] = 'application/json' - - if generation_config: - request_data['generateContentRequest'] = { - 'contents': contents, - 'generationConfig': generation_config, - } - if sys_prompt_parts: - request_data['generateContentRequest']['systemInstruction'] = _GeminiTextContent( - role='user', parts=sys_prompt_parts - ) - if tools is not None: - request_data['generateContentRequest']['tools'] = tools - if tool_config is not None: - request_data['generateContentRequest']['toolConfig'] = tool_config - - headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()} - url = f'/models/{self._model_name}:countTokens' - - request_json = _gemini_count_tokens_request_ta.dump_json(request_data, by_alias=True) - async with self.client.stream( - 'POST', - url, - content=request_json, - headers=headers, - timeout=model_settings.get('timeout', USE_CLIENT_DEFAULT), - ) as r: - if (status_code := r.status_code) != 200: - await r.aread() - if status_code >= 400: - raise ModelHTTPError(status_code=status_code, model_name=self.model_name, 
body=r.text) - raise UnexpectedModelBehavior( # pragma: no cover - f'Unexpected response from gemini {status_code}', r.text - ) - yield r - - def _process_count_tokens_response(self, response: _GeminiCountTokensResponse) -> usage.Usage: - details: dict[str, int] = {} - if cached_content_token_count := response.get('cachedContentTokenCount'): - details['cached_content_tokens'] = cached_content_token_count - - for key, metadata_details in response.items(): - if key.endswith('TokensDetails') and metadata_details: - metadata_details = cast(list[_GeminiModalityTokenCount], metadata_details) - suffix = key.removesuffix('TokensDetails').lower() - for detail in metadata_details: - details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count'] - - return usage.Usage( - request_tokens=response.get('totalTokens', 0), - response_tokens=0, # countTokens does not provide response tokens - total_tokens=response.get('totalTokens', 0), - details=details, - ) - def _settings_to_generation_config(model_settings: GeminiModelSettings) -> _GeminiGenerationConfig: config: _GeminiGenerationConfig = {} @@ -895,34 +806,6 @@ class _GeminiResponse(TypedDict): vendor_id: NotRequired[Annotated[str, pydantic.Field(alias='responseId')]] -@pydantic.with_config(pydantic.ConfigDict(defer_build=True)) -class _GeminiCountTokensRequest(TypedDict): - """Schema for a countTokens API request to the Gemini API. - - See for API docs. - """ - - contents: NotRequired[list[_GeminiContent]] - generateContentRequest: NotRequired[_GeminiRequest] - - -@pydantic.with_config(pydantic.ConfigDict(defer_build=True)) -class _GeminiCountTokensResponse(TypedDict): - """Schema for the response from the Gemini countTokens API. - - See for API docs. - """ - - totalTokens: int - cachedContentTokenCount: NotRequired[int] - promptTokensDetails: NotRequired[list[_GeminiModalityTokenCount]] - cacheTokensDetails: NotRequired[list[_GeminiModalityTokenCount]] - - -_gemini_count_tokens_request_ta = pydantic.TypeAdapter(_GeminiCountTokensRequest) -_gemini_count_tokens_response_ta = pydantic.TypeAdapter(_GeminiCountTokensResponse) - - class _GeminiCandidates(TypedDict): """See .""" From 27b1fb6793b17c6f4ccf40754a75582e77375af1 Mon Sep 17 00:00:00 2001 From: Abhishek Kaushik <56749351+kauabh@users.noreply.github.com> Date: Thu, 31 Jul 2025 15:11:06 +0530 Subject: [PATCH 13/15] updated logic for count token --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 60 ++++++++++--------- pydantic_ai_slim/pydantic_ai/messages.py | 17 +----- .../pydantic_ai/models/__init__.py | 8 ++- .../pydantic_ai/models/anthropic.py | 8 --- .../pydantic_ai/models/bedrock.py | 8 --- pydantic_ai_slim/pydantic_ai/models/cohere.py | 8 --- .../pydantic_ai/models/fallback.py | 9 +-- .../pydantic_ai/models/function.py | 8 --- pydantic_ai_slim/pydantic_ai/models/gemini.py | 8 --- pydantic_ai_slim/pydantic_ai/models/google.py | 38 +++++++----- pydantic_ai_slim/pydantic_ai/models/groq.py | 8 --- .../pydantic_ai/models/huggingface.py | 8 --- .../pydantic_ai/models/mcp_sampling.py | 9 +-- .../pydantic_ai/models/mistral.py | 8 --- pydantic_ai_slim/pydantic_ai/models/openai.py | 15 ----- pydantic_ai_slim/pydantic_ai/models/test.py | 7 --- .../pydantic_ai/models/wrapper.py | 5 +- pydantic_ai_slim/pydantic_ai/usage.py | 4 +- tests/evals/test_evaluators.py | 8 +-- .../test_google_model_count_tokens.yaml | 52 ++++++++++++++++ tests/models/test_google.py | 38 ++++++------ 21 files changed, 142 insertions(+), 192 deletions(-) create mode 100644 
tests/models/cassettes/test_google/test_google_model_count_tokens.yaml diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index da20ec0ec..57a2493ec 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -357,11 +357,6 @@ async def _make_request( ctx.state.message_history, ctx.deps.history_processors, build_run_context(ctx) ) - if ctx.deps.usage_limits and ctx.deps.usage_limits.pre_request_token_check_with_overhead: - token_count = await ctx.deps.model.count_tokens(message_history) - - ctx.deps.usage_limits.check_tokens(_usage.Usage(request_tokens=token_count.total_tokens)) - model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters) ctx.state.usage.incr(_usage.Usage()) @@ -373,8 +368,19 @@ async def _prepare_request( ctx.state.message_history.append(self.request) # Check usage - if ctx.deps.usage_limits: # pragma: no branch - ctx.deps.usage_limits.check_before_request(ctx.state.usage) + model_request_parameters = await _prepare_request_parameters(ctx) + if ctx.deps.usage_limits: + message_history = await _process_message_history( + ctx.state.message_history, ctx.deps.history_processors, build_run_context(ctx) + ) + if ctx.deps.usage_limits.count_tokens_before_request: + token_count = await ctx.deps.model.count_tokens( + message_history, ctx.deps.model_settings, model_request_parameters + ) + ctx.state.usage.incr(token_count.to_usage()) + ctx.deps.usage_limits.check_before_request(ctx.state.usage) + else: + ctx.deps.usage_limits.check_before_request(ctx.state.usage) # Increment run_step ctx.state.run_step += 1 @@ -665,11 +671,11 @@ async def process_function_tools( # noqa: C901 for call in calls_to_run: yield _messages.FunctionToolCallEvent(call) - user_parts_by_index: dict[int, list[_messages.UserPromptPart]] = defaultdict(list) + user_parts: list[_messages.UserPromptPart] = [] if calls_to_run: # Run all tool tasks in parallel - tool_parts_by_index: dict[int, _messages.ModelRequestPart] = {} + parts_by_index: dict[int, list[_messages.ModelRequestPart]] = {} with ctx.deps.tracer.start_as_current_span( 'running tools', attributes={ @@ -687,16 +693,15 @@ async def process_function_tools( # noqa: C901 done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) for task in done: index = tasks.index(task) - tool_part, tool_user_parts = task.result() - yield _messages.FunctionToolResultEvent(tool_part) + tool_result_part, extra_parts = task.result() + yield _messages.FunctionToolResultEvent(tool_result_part) - tool_parts_by_index[index] = tool_part - user_parts_by_index[index] = tool_user_parts + parts_by_index[index] = [tool_result_part, *extra_parts] # We append the results at the end, rather than as they are received, to retain a consistent ordering # This is mostly just to simplify testing - for k in sorted(tool_parts_by_index): - output_parts.append(tool_parts_by_index[k]) + for k in sorted(parts_by_index): + output_parts.extend(parts_by_index[k]) # Finally, we handle deferred tool calls for call in tool_calls_by_kind['deferred']: @@ -711,8 +716,7 @@ async def process_function_tools( # noqa: C901 else: yield _messages.FunctionToolCallEvent(call) - for k in sorted(user_parts_by_index): - output_parts.extend(user_parts_by_index[k]) + output_parts.extend(user_parts) if final_result: output_final_result.append(final_result) @@ -721,18 +725,18 @@ async def process_function_tools( # noqa: C901 async def 
_call_function_tool( tool_manager: ToolManager[DepsT], tool_call: _messages.ToolCallPart, -) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, list[_messages.UserPromptPart]]: +) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, list[_messages.ModelRequestPart]]: try: tool_result = await tool_manager.handle_call(tool_call) except ToolRetryError as e: return (e.tool_retry, []) - tool_part = _messages.ToolReturnPart( + part = _messages.ToolReturnPart( tool_name=tool_call.tool_name, content=tool_result, tool_call_id=tool_call.tool_call_id, ) - user_parts: list[_messages.UserPromptPart] = [] + extra_parts: list[_messages.ModelRequestPart] = [] if isinstance(tool_result, _messages.ToolReturn): if ( @@ -748,12 +752,12 @@ async def _call_function_tool( f'Please use `content` instead.' ) - tool_part.content = tool_result.return_value # type: ignore - tool_part.metadata = tool_result.metadata + part.content = tool_result.return_value # type: ignore + part.metadata = tool_result.metadata if tool_result.content: - user_parts.append( + extra_parts.append( _messages.UserPromptPart( - content=tool_result.content, + content=list(tool_result.content), part_kind='user-prompt', ) ) @@ -771,7 +775,7 @@ def process_content(content: Any) -> Any: else: identifier = multi_modal_content_identifier(content.url) - user_parts.append( + extra_parts.append( _messages.UserPromptPart( content=[f'This is file {identifier}:', content], part_kind='user-prompt', @@ -783,11 +787,11 @@ def process_content(content: Any) -> Any: if isinstance(tool_result, list): contents = cast(list[Any], tool_result) - tool_part.content = [process_content(content) for content in contents] + part.content = [process_content(content) for content in contents] else: - tool_part.content = process_content(tool_result) + part.content = process_content(tool_result) - return (tool_part, user_parts) + return (part, extra_parts) @dataclasses.dataclass diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 03677e82f..51ac17bd0 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -1159,19 +1159,8 @@ class BaseCountTokensResponse: ) """Total number of tokens counted in the messages.""" - model_name: str | None = field(default=None) - """Name of the model that provided the token count.""" - - vendor_details: dict[str, Any] | None = field(default=None) - """Vendor-specific token count details (e.g., cached_content_token_count for Gemini).""" - - vendor_id: str | None = field(default=None) - """Vendor request ID for tracking the token count request.""" - - timestamp: datetime = field(default_factory=_now_utc) - """Timestamp of the token count response.""" - - error: str | None = field(default=None) - """Error message if the token count request failed.""" + def to_usage(self) -> Usage: + """Usage object conversion for compatibility with Usage.incr.""" + return Usage(request_tokens=self.total_tokens) __repr__ = _utils.dataclasses_no_defaults_repr diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index fa8e5b726..14161a188 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -390,13 +390,15 @@ async def request( """Make a request to the model.""" raise NotImplementedError() - @abstractmethod async def count_tokens( self, messages: list[ModelMessage], + model_settings: ModelSettings | None, + model_request_parameters: 
ModelRequestParameters, ) -> BaseCountTokensResponse: - """Make a request to the model.""" - raise NotImplementedError() + """Make a request to the model for counting tokens.""" + # This method is not required, but you need to implement it if you want to support token counting before making a request + raise NotImplementedError(f'Token counting API call is not supported by this {self.__class__.__name__}') @asynccontextmanager async def request_stream( diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 361ee01f4..02f9111c2 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -13,7 +13,6 @@ from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage from .._utils import guard_tool_call_id as _guard_tool_call_id from ..messages import ( - BaseCountTokensResponse, BinaryContent, DocumentUrl, ImageUrl, @@ -166,13 +165,6 @@ async def request( model_response.usage.requests = 1 return model_response - async def count_tokens( - self, - messages: list[ModelMessage], - ) -> BaseCountTokensResponse: - """Token counting is not supported by the AnthropicModel.""" - raise NotImplementedError('Token counting is not supported by AnthropicModel') - @asynccontextmanager async def request_stream( self, diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index 746e8d66d..b63ed4e1f 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -17,7 +17,6 @@ from pydantic_ai import _utils, usage from pydantic_ai.messages import ( AudioUrl, - BaseCountTokensResponse, BinaryContent, DocumentUrl, ImageUrl, @@ -259,13 +258,6 @@ async def request( model_response.usage.requests = 1 return model_response - async def count_tokens( - self, - messages: list[ModelMessage], - ) -> BaseCountTokensResponse: - """Token counting is not supported by the BedrockConverseModel.""" - raise NotImplementedError('Token counting is not supported by BedrockConverseModel') - @asynccontextmanager async def request_stream( self, diff --git a/pydantic_ai_slim/pydantic_ai/models/cohere.py b/pydantic_ai_slim/pydantic_ai/models/cohere.py index 087d47941..4243ef492 100644 --- a/pydantic_ai_slim/pydantic_ai/models/cohere.py +++ b/pydantic_ai_slim/pydantic_ai/models/cohere.py @@ -11,7 +11,6 @@ from .. import ModelHTTPError, usage from .._utils import generate_tool_call_id as _generate_tool_call_id, guard_tool_call_id as _guard_tool_call_id from ..messages import ( - BaseCountTokensResponse, ModelMessage, ModelRequest, ModelResponse, @@ -150,13 +149,6 @@ async def request( model_response.usage.requests = 1 return model_response - async def count_tokens( - self, - messages: list[ModelMessage], - ) -> BaseCountTokensResponse: - """Token counting is not supported by the CohereModel.""" - raise NotImplementedError('Token counting is not supported by CohereModel') - @property def model_name(self) -> CohereModelName: """The model name.""" diff --git a/pydantic_ai_slim/pydantic_ai/models/fallback.py b/pydantic_ai_slim/pydantic_ai/models/fallback.py index 80957d452..4455defce 100644 --- a/pydantic_ai_slim/pydantic_ai/models/fallback.py +++ b/pydantic_ai_slim/pydantic_ai/models/fallback.py @@ -13,7 +13,7 @@ from . 
import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model if TYPE_CHECKING: - from ..messages import BaseCountTokensResponse, ModelMessage, ModelResponse + from ..messages import ModelMessage, ModelResponse from ..settings import ModelSettings @@ -77,13 +77,6 @@ async def request( raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions) - async def count_tokens( - self, - messages: list[ModelMessage], - ) -> BaseCountTokensResponse: - """Token counting is not supported by the FallbackModel.""" - raise NotImplementedError('Token counting is not supported by FallbackModel') - @asynccontextmanager async def request_stream( self, diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index 67376f2b6..c48873f04 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -16,7 +16,6 @@ from .. import _utils, usage from .._utils import PeekableAsyncStream from ..messages import ( - BaseCountTokensResponse, BinaryContent, ModelMessage, ModelRequest, @@ -140,13 +139,6 @@ async def request( response.usage.requests = 1 return response - async def count_tokens( - self, - messages: list[ModelMessage], - ) -> BaseCountTokensResponse: - """Token counting is not supported by the FunctionModel.""" - raise NotImplementedError('Token counting is not supported by FunctionModel') - @asynccontextmanager async def request_stream( self, diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 4f49eb702..4ac07f8ad 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -19,7 +19,6 @@ from .._output import OutputObjectDefinition from ..exceptions import UserError from ..messages import ( - BaseCountTokensResponse, BinaryContent, FileUrl, ModelMessage, @@ -159,13 +158,6 @@ async def request( response = _gemini_response_ta.validate_json(data) return self._process_response(response) - async def count_tokens( - self, - messages: list[ModelMessage], - ) -> BaseCountTokensResponse: - """Token counting is not supported by the CohereModel.""" - raise NotImplementedError('Token counting is not supported by CohereModel') - @asynccontextmanager async def request_stream( self, diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 517d9cad4..493bc935e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -49,6 +49,7 @@ from google.genai.types import ( ContentDict, ContentUnionDict, + CountTokensConfigDict, CountTokensResponse, FunctionCallDict, FunctionCallingConfigDict, @@ -186,12 +187,33 @@ async def request( async def count_tokens( self, messages: list[ModelMessage], + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters, ) -> BaseCountTokensResponse: check_allow_model_requests() - _, contents = await self._map_messages(messages) + model_settings = cast(GoogleModelSettings, model_settings or {}) + system_instruction, contents = await self._map_messages(messages) + system_instruction = None + tools = self._get_tools(model_request_parameters) + http_options: HttpOptionsDict = { + 'headers': {'Content-Type': 'application/json', 'User-Agent': get_user_agent()} + } + if timeout := model_settings.get('timeout'): + if isinstance(timeout, (int, float)): + http_options['timeout'] = int(1000 * timeout) + else: 
+ raise UserError('Google does not support setting ModelSettings.timeout to a httpx.Timeout') + + config = CountTokensConfigDict( + http_options=http_options, + system_instruction=system_instruction, + tools=cast(list[ToolDict], tools), + ) + response = self.client.models.count_tokens( model=self._model_name, contents=contents, + config=config, ) return self._process_count_tokens_response(response) @@ -356,21 +378,9 @@ def _process_count_tokens_response( self, response: CountTokensResponse, ) -> BaseCountTokensResponse: - """Process Gemini token count response into BaseCountTokensResponse.""" if not hasattr(response, 'total_tokens') or response.total_tokens is None: raise UnexpectedModelBehavior('Total tokens missing from Gemini response', str(response)) - - vendor_details: dict[str, Any] | None = None - if hasattr(response, 'cached_content_token_count'): - vendor_details = {} - vendor_details['cached_content_token_count'] = response.cached_content_token_count - - return BaseCountTokensResponse( - total_tokens=response.total_tokens, - model_name=self._model_name, - vendor_details=vendor_details if vendor_details else None, - vendor_id=getattr(response, 'request_id', None), - ) + return BaseCountTokensResponse(total_tokens=response.total_tokens) async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict | None, list[ContentUnionDict]]: contents: list[ContentUnionDict] = [] diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index f6f28e554..ffca84b44 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -14,7 +14,6 @@ from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage from .._utils import guard_tool_call_id as _guard_tool_call_id, number_to_datetime from ..messages import ( - BaseCountTokensResponse, BinaryContent, DocumentUrl, ImageUrl, @@ -161,13 +160,6 @@ async def request( model_response.usage.requests = 1 return model_response - async def count_tokens( - self, - messages: list[ModelMessage], - ) -> BaseCountTokensResponse: - """Token counting is not supported by the GroqModel.""" - raise NotImplementedError('Token counting is not supported by GroqModel') - @asynccontextmanager async def request_stream( self, diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index 4ad9da66c..4b3c2ff40 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -16,7 +16,6 @@ from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc from ..messages import ( AudioUrl, - BaseCountTokensResponse, BinaryContent, DocumentUrl, ImageUrl, @@ -141,13 +140,6 @@ async def request( model_response.usage.requests = 1 return model_response - async def count_tokens( - self, - messages: list[ModelMessage], - ) -> BaseCountTokensResponse: - """Token counting is not supported by the HuggingFaceModel.""" - raise NotImplementedError('Token counting is not supported by HuggingFaceModel') - @asynccontextmanager async def request_stream( self, diff --git a/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py b/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py index bead387c0..ebfaac92d 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py +++ b/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, cast from .. 
import _mcp, exceptions, usage -from ..messages import BaseCountTokensResponse, ModelMessage, ModelResponse +from ..messages import ModelMessage, ModelResponse from ..settings import ModelSettings from . import Model, ModelRequestParameters, StreamedResponse @@ -70,13 +70,6 @@ async def request( f'Unexpected result from MCP sampling, expected "assistant" role, got {result.role}.' ) - async def count_tokens( - self, - messages: list[ModelMessage], - ) -> BaseCountTokensResponse: - """Token counting is not supported by the MCPSamplingModel.""" - raise NotImplementedError('Token counting is not supported by MCPSamplingModel') - @asynccontextmanager async def request_stream( self, diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 800333efe..ca73558bc 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -16,7 +16,6 @@ from .. import ModelHTTPError, UnexpectedModelBehavior, _utils from .._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc, number_to_datetime from ..messages import ( - BaseCountTokensResponse, BinaryContent, DocumentUrl, ImageUrl, @@ -168,13 +167,6 @@ async def request( model_response.usage.requests = 1 return model_response - async def count_tokens( - self, - messages: list[ModelMessage], - ) -> BaseCountTokensResponse: - """Token counting is not supported by the MistralModel.""" - raise NotImplementedError('Token counting is not supported by MistralModel') - @asynccontextmanager async def request_stream( self, diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 7d4657988..35dca2e03 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -20,7 +20,6 @@ from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc, number_to_datetime from ..messages import ( AudioUrl, - BaseCountTokensResponse, BinaryContent, DocumentUrl, ImageUrl, @@ -249,13 +248,6 @@ async def request( model_response.usage.requests = 1 return model_response - async def count_tokens( - self, - messages: list[ModelMessage], - ) -> BaseCountTokensResponse: - """Token counting is not supported by the OpenAIModel.""" - raise NotImplementedError('Token counting is not supported by OpenAIModel') - @asynccontextmanager async def request_stream( self, @@ -680,13 +672,6 @@ async def request( ) return self._process_response(response) - async def count_tokens( - self, - messages: list[ModelMessage], - ) -> BaseCountTokensResponse: - """Token counting is not supported by the OpenAIResponsesModel.""" - raise NotImplementedError('Token counting is not supported by the OpenAIResponsesModel') - @asynccontextmanager async def request_stream( self, diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index 05061ba38..eebe00d44 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -13,7 +13,6 @@ from .. 
import _utils from ..messages import ( - BaseCountTokensResponse, ModelMessage, ModelRequest, ModelResponse, @@ -113,12 +112,6 @@ async def request( model_response.usage.requests = 1 return model_response - async def count_tokens( - self, - messages: list[ModelMessage], - ) -> BaseCountTokensResponse: - return BaseCountTokensResponse(total_tokens=1) - @asynccontextmanager async def request_stream( self, diff --git a/pydantic_ai_slim/pydantic_ai/models/wrapper.py b/pydantic_ai_slim/pydantic_ai/models/wrapper.py index d05b390dc..cc91f9c72 100644 --- a/pydantic_ai_slim/pydantic_ai/models/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/models/wrapper.py @@ -6,7 +6,7 @@ from functools import cached_property from typing import Any -from ..messages import BaseCountTokensResponse, ModelMessage, ModelResponse +from ..messages import ModelMessage, ModelResponse from ..profiles import ModelProfile from ..settings import ModelSettings from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model @@ -29,9 +29,6 @@ def __init__(self, wrapped: Model | KnownModelName): async def request(self, *args: Any, **kwargs: Any) -> ModelResponse: return await self.wrapped.request(*args, **kwargs) - async def count_tokens(self, *args: Any, **kwargs: Any) -> BaseCountTokensResponse: - return await self.wrapped.count_tokens(*args, **kwargs) - @asynccontextmanager async def request_stream( self, diff --git a/pydantic_ai_slim/pydantic_ai/usage.py b/pydantic_ai_slim/pydantic_ai/usage.py index 81209ffeb..dd33d0a8e 100644 --- a/pydantic_ai_slim/pydantic_ai/usage.py +++ b/pydantic_ai_slim/pydantic_ai/usage.py @@ -96,10 +96,10 @@ class UsageLimits: """The maximum number of tokens allowed in responses from the model.""" total_tokens_limit: int | None = None """The maximum number of tokens allowed in requests and responses combined.""" - pre_request_token_check_with_overhead: bool = False + count_tokens_before_request: bool = False """If True, perform a token counting pass before sending the request to the model, to enforce `request_tokens_limit` ahead of time. This may incur additional overhead - (from calling the model's `count_tokens` method) and is disabled by default.""" + (from calling the model's `count_tokens` API before making the actual request) and is disabled by default.""" def has_token_limits(self) -> bool: """Returns `True` if this instance places any limits on token counts. 
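
Usage sketch for the new `count_tokens_before_request` flag documented in the hunk above (assumptions, not part of the patch: the import paths below match this branch's package layout, and Google provider credentials are picked up from the environment), mirroring the test_google_model_usage_limit_exceeded test added later in this series:

    import asyncio

    from pydantic_ai import Agent, UsageLimitExceeded
    from pydantic_ai.models.google import GoogleModel
    from pydantic_ai.usage import UsageLimits

    # Assumed setup: GoogleProvider is configured via environment credentials.
    agent = Agent(GoogleModel('gemini-2.5-flash'))

    async def main() -> None:
        try:
            # With count_tokens_before_request=True, the agent calls
            # Model.count_tokens (the Gemini countTokens endpoint) before the
            # generate call, so request_tokens_limit is enforced up front.
            await agent.run(
                'The quick brown fox jumps over the lazy dog.',
                usage_limits=UsageLimits(request_tokens_limit=9, count_tokens_before_request=True),
            )
        except UsageLimitExceeded as e:
            # e.g. 'Exceeded the request_tokens_limit of 9 (request_tokens=12)';
            # the exact count depends on the model's tokenizer.
            print(e)

    asyncio.run(main())

Because the limit is checked against the countTokens result before the request is sent, no generation tokens are spent when the prompt would already exceed request_tokens_limit.
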
diff --git a/tests/evals/test_evaluators.py b/tests/evals/test_evaluators.py index cfe373b97..235296c4a 100644 --- a/tests/evals/test_evaluators.py +++ b/tests/evals/test_evaluators.py @@ -8,7 +8,7 @@ from pydantic import BaseModel, TypeAdapter from pydantic_core import to_jsonable_python -from pydantic_ai.messages import BaseCountTokensResponse, ModelMessage, ModelResponse +from pydantic_ai.messages import ModelMessage, ModelResponse from pydantic_ai.models import Model, ModelRequestParameters from pydantic_ai.settings import ModelSettings @@ -125,12 +125,6 @@ async def request( ) -> ModelResponse: raise NotImplementedError - async def count_tokens( - self, - messages: list[ModelMessage], - ) -> BaseCountTokensResponse: - raise NotImplementedError - @property def model_name(self) -> str: return 'my-model' diff --git a/tests/models/cassettes/test_google/test_google_model_count_tokens.yaml b/tests/models/cassettes/test_google/test_google_model_count_tokens.yaml new file mode 100644 index 000000000..36e428311 --- /dev/null +++ b/tests/models/cassettes/test_google/test_google_model_count_tokens.yaml @@ -0,0 +1,52 @@ +interactions: +- request: + body: '{"contents": [{"parts": [{"text": "The quick brown fox jumps over the lazy + dog."}]}]}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '85' + Content-Type: + - application/json + User-Agent: + - python-requests/2.32.4 + method: POST + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:countTokens + response: + body: + string: "{\n \"totalTokens\": 10,\n \"promptTokensDetails\": [\n {\n \"modality\": + \"TEXT\",\n \"tokenCount\": 10\n }\n ]\n}\n" + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Content-Encoding: + - gzip + Content-Type: + - application/json; charset=UTF-8 + Date: + - Thu, 31 Jul 2025 06:05:35 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=1559 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/test_google.py b/tests/models/test_google.py index 80016c530..b8e9af694 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -43,7 +43,7 @@ with try_import() as imports_successful: from google.genai.types import HarmBlockThreshold, HarmCategory - from pydantic_ai.models.google import GoogleModel, GoogleModelSettings + from pydantic_ai.models.google import GoogleModel, GoogleModelSettings, ModelRequestParameters from pydantic_ai.providers.google import GoogleProvider pytestmark = [ @@ -1396,27 +1396,29 @@ class CountryLanguage(BaseModel): async def test_google_model_count_tokens(allow_model_requests: None, google_provider: GoogleProvider): - model = GoogleModel('gemini-1.5-flash', provider=google_provider) + model = GoogleModel('gemini-2.0-flash', provider=google_provider) messages = [ ModelRequest( parts=[ - SystemPromptPart(content='You are a helpful chatbot.', timestamp=IsDatetime()), - UserPromptPart(content='What was the temperature in London 1st January 2022?', timestamp=IsDatetime()), + SystemPromptPart(content='You are an expert', timestamp=IsDatetime()), + UserPromptPart( + content='The quick brown fox jumps over the lazydog.', + timestamp=IsDatetime(), + ), ] ), - ModelResponse( - parts=[ - ToolCallPart( - tool_name='temperature', - args={'date': 
'2022-01-01', 'city': 'London'}, - tool_call_id='test_id', - ) - ], - model_name='gemini-1.5-flash', - timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - ), + ModelResponse(parts=[TextPart(content="""That's a classic!""")]), ] - result = await model.count_tokens(messages) - assert result.total_tokens == snapshot(7) + result = await model.count_tokens( + messages, + model_settings={'temperature': 0.0}, + model_request_parameters=ModelRequestParameters( + function_tools=[], + allow_text_output=True, + output_tools=[], + output_mode='text', + output_object=None, + ), + ) + assert result.total_tokens == snapshot(10) From ac57d7235361bf89447de714b6c6fd13bfe9a82c Mon Sep 17 00:00:00 2001 From: Abhishek Kaushik <56749351+kauabh@users.noreply.github.com> Date: Fri, 1 Aug 2025 21:42:55 +0530 Subject: [PATCH 14/15] updated test, count token function is async --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 2 +- pydantic_ai_slim/pydantic_ai/models/google.py | 12 +++--- ...st_google_model_usage_limit_exceeded.yaml} | 34 +++++++---------- tests/models/test_google.py | 38 +++++-------------- 4 files changed, 30 insertions(+), 56 deletions(-) rename tests/models/cassettes/test_google/{test_google_model_count_tokens.yaml => test_google_model_usage_limit_exceeded.yaml} (60%) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 57a2493ec..21078203d 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -378,7 +378,7 @@ async def _prepare_request( message_history, ctx.deps.model_settings, model_request_parameters ) ctx.state.usage.incr(token_count.to_usage()) - ctx.deps.usage_limits.check_before_request(ctx.state.usage) + ctx.deps.usage_limits.check_tokens(ctx.state.usage) else: ctx.deps.usage_limits.check_before_request(ctx.state.usage) diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 763caed8b..41552c8e8 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -192,9 +192,7 @@ async def count_tokens( ) -> BaseCountTokensResponse: check_allow_model_requests() model_settings = cast(GoogleModelSettings, model_settings or {}) - system_instruction, contents = await self._map_messages(messages) - system_instruction = None - tools = self._get_tools(model_request_parameters) + _, contents = await self._map_messages(messages) http_options: HttpOptionsDict = { 'headers': {'Content-Type': 'application/json', 'User-Agent': get_user_agent()} } @@ -203,14 +201,14 @@ async def count_tokens( http_options['timeout'] = int(1000 * timeout) else: raise UserError('Google does not support setting ModelSettings.timeout to a httpx.Timeout') - + # system_instruction and tools parameter are not supported for Count Tokens https://github.com/googleapis/python-genai/blob/038ecd3375f7c63a8ee8c1afa50cff1976343625/google/genai/models.py#L1169 config = CountTokensConfigDict( http_options=http_options, - system_instruction=system_instruction, - tools=cast(list[ToolDict], tools), + system_instruction=None, + tools=None, ) - response = self.client.models.count_tokens( + response = await self.client.aio.models.count_tokens( model=self._model_name, contents=contents, config=config, diff --git a/tests/models/cassettes/test_google/test_google_model_count_tokens.yaml b/tests/models/cassettes/test_google/test_google_model_usage_limit_exceeded.yaml similarity index 60% rename from 
tests/models/cassettes/test_google/test_google_model_count_tokens.yaml rename to tests/models/cassettes/test_google/test_google_model_usage_limit_exceeded.yaml index 36e428311..0d00e3d51 100644 --- a/tests/models/cassettes/test_google/test_google_model_count_tokens.yaml +++ b/tests/models/cassettes/test_google/test_google_model_usage_limit_exceeded.yaml @@ -1,39 +1,31 @@ interactions: - request: - body: '{"contents": [{"parts": [{"text": "The quick brown fox jumps over the lazy - dog."}]}]}' + body: '{"contents": [{"parts": [{"text": "The quick brown fox jumps over the lazydog."}], + "role": "user"}]}' headers: - Accept: - - '*/*' - Accept-Encoding: - - gzip, deflate - Connection: - - keep-alive - Content-Length: - - '85' Content-Type: - application/json - User-Agent: - - python-requests/2.32.4 - method: POST - uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:countTokens + user-agent: + - google-genai-sdk/1.26.0 gl-python/3.12.7 + x-goog-api-client: + - google-genai-sdk/1.26.0 gl-python/3.12.7 + method: post + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:countTokens response: body: - string: "{\n \"totalTokens\": 10,\n \"promptTokensDetails\": [\n {\n \"modality\": - \"TEXT\",\n \"tokenCount\": 10\n }\n ]\n}\n" + string: "{\n \"totalTokens\": 12,\n \"promptTokensDetails\": [\n {\n \"modality\": + \"TEXT\",\n \"tokenCount\": 12\n }\n ]\n}\n" headers: Alt-Svc: - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 - Content-Encoding: - - gzip Content-Type: - application/json; charset=UTF-8 Date: - - Thu, 31 Jul 2025 06:05:35 GMT + - Fri, 01 Aug 2025 15:59:25 GMT Server: - scaffolding on HTTPServer2 Server-Timing: - - gfet4t7; dur=1559 + - gfet4t7; dur=1582 Transfer-Encoding: - chunked Vary: @@ -46,6 +38,8 @@ interactions: - SAMEORIGIN X-XSS-Protection: - '0' + content-length: + - '117' status: code: 200 message: OK diff --git a/tests/models/test_google.py b/tests/models/test_google.py index b8e9af694..22354ddbd 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -10,6 +10,7 @@ from pydantic import BaseModel from typing_extensions import TypedDict +from pydantic_ai import UsageLimitExceeded from pydantic_ai.agent import Agent from pydantic_ai.exceptions import ModelRetry, UnexpectedModelBehavior, UserError from pydantic_ai.messages import ( @@ -36,7 +37,7 @@ VideoUrl, ) from pydantic_ai.output import NativeOutput, PromptedOutput, TextOutput, ToolOutput -from pydantic_ai.result import Usage +from pydantic_ai.result import Usage, UsageLimits from ..conftest import IsDatetime, IsInstance, IsStr, try_import @@ -1394,31 +1395,12 @@ class CountryLanguage(BaseModel): ] ) +async def test_google_model_usage_limit_exceeded(allow_model_requests: None, google_provider: GoogleProvider): + model = GoogleModel('gemini-2.5-flash', provider=google_provider) + agent = Agent(model=model) -async def test_google_model_count_tokens(allow_model_requests: None, google_provider: GoogleProvider): - model = GoogleModel('gemini-2.0-flash', provider=google_provider) - - messages = [ - ModelRequest( - parts=[ - SystemPromptPart(content='You are an expert', timestamp=IsDatetime()), - UserPromptPart( - content='The quick brown fox jumps over the lazydog.', - timestamp=IsDatetime(), - ), - ] - ), - ModelResponse(parts=[TextPart(content="""That's a classic!""")]), - ] - result = await model.count_tokens( - messages, - model_settings={'temperature': 0.0}, - model_request_parameters=ModelRequestParameters( - function_tools=[], - 
allow_text_output=True, - output_tools=[], - output_mode='text', - output_object=None, - ), - ) - assert result.total_tokens == snapshot(10) + with pytest.raises(UsageLimitExceeded, match='Exceeded the request_tokens_limit of 10 \\(request_tokens=12\\)'): + await agent.run( + 'The quick brown fox jumps over the lazydog.', + usage_limits=UsageLimits(request_tokens_limit=9, count_tokens_before_request=True), + ) From 4d2e48076cb1ccc47a736f7ab7bfbf66f9b17e85 Mon Sep 17 00:00:00 2001 From: Abhishek Kaushik <56749351+kauabh@users.noreply.github.com> Date: Fri, 1 Aug 2025 21:57:44 +0530 Subject: [PATCH 15/15] corrected agent_graph, typo in test --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 51 ++++++++++---------- tests/models/test_google.py | 5 +- 2 files changed, 29 insertions(+), 27 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 21078203d..29c81f0c5 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -356,7 +356,6 @@ async def _make_request( message_history = await _process_message_history( ctx.state.message_history, ctx.deps.history_processors, build_run_context(ctx) ) - model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters) ctx.state.usage.incr(_usage.Usage()) @@ -368,12 +367,12 @@ async def _prepare_request( ctx.state.message_history.append(self.request) # Check usage - model_request_parameters = await _prepare_request_parameters(ctx) - if ctx.deps.usage_limits: - message_history = await _process_message_history( - ctx.state.message_history, ctx.deps.history_processors, build_run_context(ctx) - ) + if ctx.deps.usage_limits: # pragma: no branch if ctx.deps.usage_limits.count_tokens_before_request: + model_request_parameters = await _prepare_request_parameters(ctx) + message_history = await _process_message_history( + ctx.state.message_history, ctx.deps.history_processors, build_run_context(ctx) + ) token_count = await ctx.deps.model.count_tokens( message_history, ctx.deps.model_settings, model_request_parameters ) @@ -671,11 +670,11 @@ async def process_function_tools( # noqa: C901 for call in calls_to_run: yield _messages.FunctionToolCallEvent(call) - user_parts: list[_messages.UserPromptPart] = [] + user_parts_by_index: dict[int, list[_messages.UserPromptPart]] = defaultdict(list) if calls_to_run: # Run all tool tasks in parallel - parts_by_index: dict[int, list[_messages.ModelRequestPart]] = {} + tool_parts_by_index: dict[int, _messages.ModelRequestPart] = {} with ctx.deps.tracer.start_as_current_span( 'running tools', attributes={ @@ -693,15 +692,16 @@ async def process_function_tools( # noqa: C901 done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) for task in done: index = tasks.index(task) - tool_result_part, extra_parts = task.result() - yield _messages.FunctionToolResultEvent(tool_result_part) + tool_part, tool_user_parts = task.result() + yield _messages.FunctionToolResultEvent(tool_part) - parts_by_index[index] = [tool_result_part, *extra_parts] + tool_parts_by_index[index] = tool_part + user_parts_by_index[index] = tool_user_parts # We append the results at the end, rather than as they are received, to retain a consistent ordering # This is mostly just to simplify testing - for k in sorted(parts_by_index): - output_parts.extend(parts_by_index[k]) + for k in sorted(tool_parts_by_index): + output_parts.append(tool_parts_by_index[k]) # Finally, we handle deferred 
tool calls for call in tool_calls_by_kind['deferred']: @@ -716,7 +716,8 @@ async def process_function_tools( # noqa: C901 else: yield _messages.FunctionToolCallEvent(call) - output_parts.extend(user_parts) + for k in sorted(user_parts_by_index): + output_parts.extend(user_parts_by_index[k]) if final_result: output_final_result.append(final_result) @@ -725,18 +726,18 @@ async def process_function_tools( # noqa: C901 async def _call_function_tool( tool_manager: ToolManager[DepsT], tool_call: _messages.ToolCallPart, -) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, list[_messages.ModelRequestPart]]: +) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, list[_messages.UserPromptPart]]: try: tool_result = await tool_manager.handle_call(tool_call) except ToolRetryError as e: return (e.tool_retry, []) - part = _messages.ToolReturnPart( + tool_part = _messages.ToolReturnPart( tool_name=tool_call.tool_name, content=tool_result, tool_call_id=tool_call.tool_call_id, ) - extra_parts: list[_messages.ModelRequestPart] = [] + user_parts: list[_messages.UserPromptPart] = [] if isinstance(tool_result, _messages.ToolReturn): if ( @@ -752,12 +753,12 @@ async def _call_function_tool( f'Please use `content` instead.' ) - part.content = tool_result.return_value # type: ignore - part.metadata = tool_result.metadata + tool_part.content = tool_result.return_value # type: ignore + tool_part.metadata = tool_result.metadata if tool_result.content: - extra_parts.append( + user_parts.append( _messages.UserPromptPart( - content=list(tool_result.content), + content=tool_result.content, part_kind='user-prompt', ) ) @@ -775,7 +776,7 @@ def process_content(content: Any) -> Any: else: identifier = multi_modal_content_identifier(content.url) - extra_parts.append( + user_parts.append( _messages.UserPromptPart( content=[f'This is file {identifier}:', content], part_kind='user-prompt', @@ -787,11 +788,11 @@ def process_content(content: Any) -> Any: if isinstance(tool_result, list): contents = cast(list[Any], tool_result) - part.content = [process_content(content) for content in contents] + tool_part.content = [process_content(content) for content in contents] else: - part.content = process_content(tool_result) + tool_part.content = process_content(tool_result) - return (part, extra_parts) + return (tool_part, user_parts) @dataclasses.dataclass diff --git a/tests/models/test_google.py b/tests/models/test_google.py index 22354ddbd..958ff5219 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -44,7 +44,7 @@ with try_import() as imports_successful: from google.genai.types import HarmBlockThreshold, HarmCategory - from pydantic_ai.models.google import GoogleModel, GoogleModelSettings, ModelRequestParameters + from pydantic_ai.models.google import GoogleModel, GoogleModelSettings from pydantic_ai.providers.google import GoogleProvider pytestmark = [ @@ -1395,11 +1395,12 @@ class CountryLanguage(BaseModel): ] ) + async def test_google_model_usage_limit_exceeded(allow_model_requests: None, google_provider: GoogleProvider): model = GoogleModel('gemini-2.5-flash', provider=google_provider) agent = Agent(model=model) - with pytest.raises(UsageLimitExceeded, match='Exceeded the request_tokens_limit of 10 \\(request_tokens=12\\)'): + with pytest.raises(UsageLimitExceeded, match='Exceeded the request_tokens_limit of 9 \\(request_tokens=12\\)'): await agent.run( 'The quick brown fox jumps over the lazydog.', usage_limits=UsageLimits(request_tokens_limit=9, count_tokens_before_request=True),