From a91cfd1e76ebf86abff359e6a96bb0e484107f3f Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 26 Aug 2025 17:39:06 +0200 Subject: [PATCH 1/9] Add `cost` to `RunUsage` --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 3 +++ pydantic_ai_slim/pydantic_ai/usage.py | 12 ++++++++++++ .../test_model_names/test_known_model_names.yaml | 16 ++++++++-------- tests/models/test_openai.py | 15 +++++++-------- 4 files changed, 30 insertions(+), 16 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 1e8beaec87..150000680c 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -10,6 +10,7 @@ from dataclasses import field from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast +from genai_prices import calc_price from opentelemetry.trace import Tracer from typing_extensions import TypeGuard, TypeVar, assert_never @@ -309,6 +310,7 @@ async def stream( ) as streamed_response: self._did_stream = True ctx.state.usage.requests += 1 + ctx.state.usage.cost += calc_price(streamed_response.usage(), ctx.deps.model.model_name).total_price agent_stream = result.AgentStream[DepsT, T]( streamed_response, ctx.deps.output_schema, @@ -338,6 +340,7 @@ async def _make_request( model_settings, model_request_parameters, message_history, _ = await self._prepare_request(ctx) model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters) ctx.state.usage.requests += 1 + ctx.state.usage.cost += calc_price(model_response.usage, ctx.deps.model.model_name).total_price return self._finish_handling(ctx, model_response) diff --git a/pydantic_ai_slim/pydantic_ai/usage.py b/pydantic_ai_slim/pydantic_ai/usage.py index 2879d38e6a..2823d812c1 100644 --- a/pydantic_ai_slim/pydantic_ai/usage.py +++ b/pydantic_ai_slim/pydantic_ai/usage.py @@ -3,6 +3,7 @@ import dataclasses from copy import copy from dataclasses import dataclass, fields +from decimal import Decimal from typing_extensions import deprecated, overload @@ -19,6 +20,7 @@ class UsageBase: cache_write_tokens: int = 0 """Number of tokens written to the cache.""" + cache_read_tokens: int = 0 """Number of tokens read from the cache.""" @@ -27,8 +29,10 @@ class UsageBase: input_audio_tokens: int = 0 """Number of audio input tokens.""" + cache_audio_read_tokens: int = 0 """Number of audio tokens read from the cache.""" + output_audio_tokens: int = 0 """Number of audio output tokens.""" @@ -122,17 +126,22 @@ class RunUsage(UsageBase): cache_write_tokens: int = 0 """Total number of tokens written to the cache.""" + cache_read_tokens: int = 0 """Total number of tokens read from the cache.""" input_audio_tokens: int = 0 """Total number of audio input tokens.""" + cache_audio_read_tokens: int = 0 """Total number of audio tokens read from the cache.""" output_tokens: int = 0 """Total number of text output/completion tokens.""" + cost: Decimal = Decimal('0.0') + """Total cost of the run.""" + details: dict[str, int] = dataclasses.field(default_factory=dict) """Any extra details returned by the model.""" @@ -170,6 +179,9 @@ def _incr_usage_tokens(slf: RunUsage | RequestUsage, incr_usage: RunUsage | Requ slf.cache_audio_read_tokens += incr_usage.cache_audio_read_tokens slf.output_tokens += incr_usage.output_tokens + if isinstance(slf, RunUsage) and isinstance(incr_usage, RunUsage): + slf.cost += incr_usage.cost + for key, value in incr_usage.details.items(): slf.details[key] = 
slf.details.get(key, 0) + value diff --git a/tests/models/cassettes/test_model_names/test_known_model_names.yaml b/tests/models/cassettes/test_model_names/test_known_model_names.yaml index 67e098800b..6ba2701391 100644 --- a/tests/models/cassettes/test_model_names/test_known_model_names.yaml +++ b/tests/models/cassettes/test_model_names/test_known_model_names.yaml @@ -109,23 +109,23 @@ interactions: parsed_body: data: - created: 0 - id: llama-4-maverick-17b-128e-instruct + id: llama-4-scout-17b-16e-instruct object: model owned_by: Cerebras - created: 0 - id: qwen-3-32b + id: gpt-oss-120b object: model owned_by: Cerebras - created: 0 - id: qwen-3-235b-a22b-instruct-2507 + id: qwen-3-235b-a22b-thinking-2507 object: model owned_by: Cerebras - created: 0 - id: llama-4-scout-17b-16e-instruct + id: qwen-3-32b object: model owned_by: Cerebras - created: 0 - id: gpt-oss-120b + id: llama-3.3-70b object: model owned_by: Cerebras - created: 0 @@ -133,15 +133,15 @@ interactions: object: model owned_by: Cerebras - created: 0 - id: llama-3.3-70b + id: llama3.1-8b object: model owned_by: Cerebras - created: 0 - id: llama3.1-8b + id: qwen-3-235b-a22b-instruct-2507 object: model owned_by: Cerebras - created: 0 - id: qwen-3-235b-a22b-thinking-2507 + id: llama-4-maverick-17b-128e-instruct object: model owned_by: Cerebras object: list diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 2fe56ee16e..982706c327 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -4,6 +4,7 @@ from collections.abc import Sequence from dataclasses import dataclass, field from datetime import datetime, timezone +from decimal import Decimal from enum import Enum from functools import cached_property from typing import Annotated, Any, Callable, Literal, Union, cast @@ -236,13 +237,7 @@ async def test_request_simple_usage(allow_model_requests: None): result = await agent.run('Hello') assert result.output == 'world' - assert result.usage() == snapshot( - RunUsage( - requests=1, - input_tokens=2, - output_tokens=1, - ) - ) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=2, output_tokens=1, cost=Decimal('0.000015'))) async def test_request_structured_response(allow_model_requests: None): @@ -423,7 +418,9 @@ async def get_location(loc_name: str) -> str: ), ] ) - assert result.usage() == snapshot(RunUsage(requests=3, cache_read_tokens=3, input_tokens=5, output_tokens=3)) + assert result.usage() == snapshot( + RunUsage(requests=3, cache_read_tokens=3, input_tokens=5, output_tokens=3, cost=Decimal('0.00004625')) + ) FinishReason = Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call'] @@ -826,6 +823,7 @@ async def test_openai_audio_url_input(allow_model_requests: None, openai_api_key 'text_tokens': 72, }, requests=1, + cost=Decimal('0.0008925'), ) ) @@ -1024,6 +1022,7 @@ async def test_audio_as_binary_content_input( 'text_tokens': 9, }, requests=1, + cost=Decimal('0.00020'), ) ) From dd883fd24e507f93766619e278d371953da06534 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 26 Aug 2025 17:47:42 +0200 Subject: [PATCH 2/9] Suppress LookupError --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 22 +++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 150000680c..74d64a7b4e 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -5,7 +5,7 @@ import hashlib from 
collections import defaultdict, deque from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence -from contextlib import asynccontextmanager, contextmanager +from contextlib import asynccontextmanager, contextmanager, suppress from contextvars import ContextVar from dataclasses import field from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast @@ -310,7 +310,15 @@ async def stream( ) as streamed_response: self._did_stream = True ctx.state.usage.requests += 1 - ctx.state.usage.cost += calc_price(streamed_response.usage(), ctx.deps.model.model_name).total_price + + # If we can't calculate the price, we don't want to fail the run. + with suppress(LookupError): + ctx.state.usage.cost += calc_price( + streamed_response.usage(), + ctx.deps.model.model_name, + provider_id=streamed_response.provider_name, + genai_request_timestamp=streamed_response.timestamp, + ).total_price agent_stream = result.AgentStream[DepsT, T]( streamed_response, ctx.deps.output_schema, @@ -340,7 +348,15 @@ async def _make_request( model_settings, model_request_parameters, message_history, _ = await self._prepare_request(ctx) model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters) ctx.state.usage.requests += 1 - ctx.state.usage.cost += calc_price(model_response.usage, ctx.deps.model.model_name).total_price + + # If we can't calculate the price, we don't want to fail the run. + with suppress(LookupError): + ctx.state.usage.cost += calc_price( + model_response.usage, + ctx.deps.model.model_name, + provider_id=model_response.provider_name, + genai_request_timestamp=model_response.timestamp, + ).total_price return self._finish_handling(ctx, model_response) From 40041e6ee4e4a3859ab79a90389b1736e39344b5 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 27 Aug 2025 10:56:53 +0200 Subject: [PATCH 3/9] Pass pipeline --- docs/agents.md | 2 +- docs/multi-agent-applications.md | 8 ++++++-- docs/output.md | 2 +- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 7 +------ .../test_model_names/test_known_model_names.yaml | 16 ++++++++-------- tests/models/test_anthropic.py | 4 ++++ tests/models/test_cohere.py | 2 ++ tests/models/test_gemini.py | 7 ++++--- tests/models/test_google.py | 3 +++ tests/models/test_openai.py | 6 ++---- 10 files changed, 32 insertions(+), 25 deletions(-) diff --git a/docs/agents.md b/docs/agents.md index 9a64255324..b352f23687 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -563,7 +563,7 @@ result_sync = agent.run_sync( print(result_sync.output) #> Rome print(result_sync.usage()) -#> RunUsage(input_tokens=62, output_tokens=1, requests=1) +#> RunUsage(input_tokens=62, output_tokens=1, requests=1, cost=Decimal('0.000201')) try: result_sync = agent.run_sync( diff --git a/docs/multi-agent-applications.md b/docs/multi-agent-applications.md index 281316bd46..b7f17aaef7 100644 --- a/docs/multi-agent-applications.md +++ b/docs/multi-agent-applications.md @@ -53,7 +53,7 @@ result = joke_selection_agent.run_sync( print(result.output) #> Did you hear about the toothpaste scandal? They called it Colgate. print(result.usage()) -#> RunUsage(input_tokens=204, output_tokens=24, requests=3) +#> RunUsage(input_tokens=204, output_tokens=24, requests=3, cost=Decimal('0.0003475')) ``` 1. The "parent" or controlling agent. @@ -144,7 +144,11 @@ async def main(): print(result.output) #> Did you hear about the toothpaste scandal? They called it Colgate. print(result.usage()) # (6)! 
- #> RunUsage(input_tokens=309, output_tokens=32, requests=4) + """ + RunUsage( + input_tokens=309, output_tokens=32, requests=4, cost=Decimal('0.00036') + ) + """ ``` 1. Define a dataclass to hold the client and API key dependencies. diff --git a/docs/output.md b/docs/output.md index d0ba4ff06a..fe4a1e4945 100644 --- a/docs/output.md +++ b/docs/output.md @@ -24,7 +24,7 @@ result = agent.run_sync('Where were the olympics held in 2012?') print(result.output) #> city='London' country='United Kingdom' print(result.usage()) -#> RunUsage(input_tokens=57, output_tokens=8, requests=1) +#> RunUsage(input_tokens=57, output_tokens=8, requests=1, cost=Decimal('0.000006675')) ``` _(This example is complete, it can be run "as is")_ diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 74d64a7b4e..aa683094a3 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -351,12 +351,7 @@ async def _make_request( # If we can't calculate the price, we don't want to fail the run. with suppress(LookupError): - ctx.state.usage.cost += calc_price( - model_response.usage, - ctx.deps.model.model_name, - provider_id=model_response.provider_name, - genai_request_timestamp=model_response.timestamp, - ).total_price + ctx.state.usage.cost = model_response.price().total_price return self._finish_handling(ctx, model_response) diff --git a/tests/models/cassettes/test_model_names/test_known_model_names.yaml b/tests/models/cassettes/test_model_names/test_known_model_names.yaml index 6ba2701391..ae29ad62b8 100644 --- a/tests/models/cassettes/test_model_names/test_known_model_names.yaml +++ b/tests/models/cassettes/test_model_names/test_known_model_names.yaml @@ -109,35 +109,35 @@ interactions: parsed_body: data: - created: 0 - id: llama-4-scout-17b-16e-instruct + id: llama-3.3-70b object: model owned_by: Cerebras - created: 0 - id: gpt-oss-120b + id: qwen-3-235b-a22b-instruct-2507 object: model owned_by: Cerebras - created: 0 - id: qwen-3-235b-a22b-thinking-2507 + id: gpt-oss-120b object: model owned_by: Cerebras - created: 0 - id: qwen-3-32b + id: qwen-3-235b-a22b-thinking-2507 object: model owned_by: Cerebras - created: 0 - id: llama-3.3-70b + id: qwen-3-coder-480b object: model owned_by: Cerebras - created: 0 - id: qwen-3-coder-480b + id: llama3.1-8b object: model owned_by: Cerebras - created: 0 - id: llama3.1-8b + id: qwen-3-32b object: model owned_by: Cerebras - created: 0 - id: qwen-3-235b-a22b-instruct-2507 + id: llama-4-scout-17b-16e-instruct object: model owned_by: Cerebras - created: 0 diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index 38f8adb1b4..0e80886645 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -181,6 +181,7 @@ async def test_sync_request_text_response(allow_model_requests: None): input_tokens=5, output_tokens=10, details={'input_tokens': 5, 'output_tokens': 10}, + cost=Decimal('0.000044'), ) ) # reset the index so we get the same response again @@ -194,6 +195,7 @@ async def test_sync_request_text_response(allow_model_requests: None): input_tokens=5, output_tokens=10, details={'input_tokens': 5, 'output_tokens': 10}, + cost=Decimal('0.000044'), ) ) assert result.all_messages() == snapshot( @@ -249,6 +251,7 @@ async def test_async_request_prompt_caching(allow_model_requests: None): 'cache_creation_input_tokens': 4, 'cache_read_input_tokens': 6, }, + cost=Decimal('0.00003488'), ) ) last_message = result.all_messages()[-1] @@ -273,6 
+276,7 @@ async def test_async_request_text_response(allow_model_requests: None): input_tokens=3, output_tokens=5, details={'input_tokens': 3, 'output_tokens': 5}, + cost=Decimal('0.0000224'), ) ) diff --git a/tests/models/test_cohere.py b/tests/models/test_cohere.py index 9291ff73d6..f8fe99a054 100644 --- a/tests/models/test_cohere.py +++ b/tests/models/test_cohere.py @@ -4,6 +4,7 @@ from collections.abc import Sequence from dataclasses import dataclass from datetime import timezone +from decimal import Decimal from typing import Any, Union, cast import pytest @@ -158,6 +159,7 @@ async def test_request_simple_usage(allow_model_requests: None): 'input_tokens': 1, 'output_tokens': 1, }, + cost=Decimal('1.875E-7'), ) ) diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 8332affc78..ee0bce1125 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -7,6 +7,7 @@ from collections.abc import AsyncIterator, Callable, Sequence from dataclasses import dataclass from datetime import timezone +from decimal import Decimal from enum import IntEnum from typing import Annotated @@ -629,7 +630,7 @@ async def test_text_success(get_gemini_client: GetGeminiClient): ), ] ) - assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=1, output_tokens=2)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=1, output_tokens=2, cost=Decimal('6.75E-7'))) result = await agent.run('Hello', message_history=result.new_messages()) assert result.output == 'Hello world' @@ -783,7 +784,7 @@ async def get_location(loc_name: str) -> str: ), ] ) - assert result.usage() == snapshot(RunUsage(requests=3, input_tokens=3, output_tokens=6)) + assert result.usage() == snapshot(RunUsage(requests=3, input_tokens=3, output_tokens=6, cost=Decimal('6.75E-7'))) async def test_unexpected_response(client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None): @@ -1654,7 +1655,7 @@ async def test_response_with_thought_part(get_gemini_client: GetGeminiClient): result = await agent.run('Test with thought') assert result.output == 'Hello from thought test' - assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=1, output_tokens=2)) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=1, output_tokens=2, cost=Decimal('6.75E-7'))) @pytest.mark.vcr() diff --git a/tests/models/test_google.py b/tests/models/test_google.py index 86ed9cc850..e67b0b6d5e 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -2,6 +2,7 @@ import datetime import os +from decimal import Decimal from typing import Any import pytest @@ -88,6 +89,7 @@ async def test_google_model(allow_model_requests: None, google_provider: GoogleP input_tokens=7, output_tokens=11, details={'text_prompt_tokens': 7, 'text_candidates_tokens': 11}, + cost=Decimal('0.000003825'), ) ) assert result.all_messages() == snapshot( @@ -149,6 +151,7 @@ async def temperature(city: str, date: datetime.date) -> str: input_tokens=224, output_tokens=35, details={'text_prompt_tokens': 224, 'text_candidates_tokens': 35}, + cost=Decimal('0.000015525'), ) ) assert result.all_messages() == snapshot( diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 982706c327..8763ee5c2a 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -237,7 +237,7 @@ async def test_request_simple_usage(allow_model_requests: None): result = await agent.run('Hello') assert result.output == 'world' - assert result.usage() == 
snapshot(RunUsage(requests=1, input_tokens=2, output_tokens=1, cost=Decimal('0.000015'))) + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=2, output_tokens=1)) async def test_request_structured_response(allow_model_requests: None): @@ -418,9 +418,7 @@ async def get_location(loc_name: str) -> str: ), ] ) - assert result.usage() == snapshot( - RunUsage(requests=3, cache_read_tokens=3, input_tokens=5, output_tokens=3, cost=Decimal('0.00004625')) - ) + assert result.usage() == snapshot(RunUsage(requests=3, cache_read_tokens=3, input_tokens=5, output_tokens=3)) FinishReason = Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call'] From 46629d1e4000fbe6c29071542ac126343b2eba52 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 27 Aug 2025 11:09:54 +0200 Subject: [PATCH 4/9] Pass pipeline --- tests/models/test_openai_responses.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/models/test_openai_responses.py b/tests/models/test_openai_responses.py index 6c838f4e09..2692a81923 100644 --- a/tests/models/test_openai_responses.py +++ b/tests/models/test_openai_responses.py @@ -1,5 +1,6 @@ import json from dataclasses import replace +from decimal import Decimal from typing import Any import pytest @@ -1087,5 +1088,7 @@ async def test_openai_responses_usage_without_tokens_details(allow_model_request result = await agent.run('What is 2+2?') assert result.usage() == snapshot( - RunUsage(input_tokens=14, output_tokens=9, details={'reasoning_tokens': 0}, requests=1) + RunUsage( + input_tokens=14, output_tokens=9, details={'reasoning_tokens': 0}, requests=1, cost=Decimal('0.000125') + ) ) From 30d41be90109297b68f74b1eb58dc2a90ec4bfca Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 27 Aug 2025 11:16:52 +0200 Subject: [PATCH 5/9] Pass pipeline --- .../pydantic_ai/durable_exec/temporal/_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_model.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_model.py index 27b2d67782..5b8f1d6b08 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_model.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_model.py @@ -49,7 +49,7 @@ def get(self) -> ModelResponse: return self.response def usage(self) -> RequestUsage: - return self.response.usage # pragma: no cover + return self.response.usage @property def model_name(self) -> str: @@ -57,11 +57,11 @@ def model_name(self) -> str: @property def provider_name(self) -> str: - return self.response.provider_name or '' # pragma: no cover + return self.response.provider_name or '' @property def timestamp(self) -> datetime: - return self.response.timestamp # pragma: no cover + return self.response.timestamp class TemporalModel(WrapperModel): From 91552634e9557087977f2a6e9b8d98408c391104 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 27 Aug 2025 23:34:04 +0200 Subject: [PATCH 6/9] Add cost function --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 38 ++++++++++++------- .../pydantic_ai/agent/__init__.py | 6 +++ pydantic_ai_slim/pydantic_ai/messages.py | 9 ++--- .../pydantic_ai/models/__init__.py | 14 +++++++ tests/models/test_anthropic.py | 2 +- tests/models/test_cohere.py | 1 + tests/models/test_gemini.py | 4 +- tests/models/test_google.py | 2 +- 8 files changed, 53 insertions(+), 23 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 
aa683094a3..72f87f3244 100644
--- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py
+++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -3,14 +3,15 @@
 import asyncio
 import dataclasses
 import hashlib
+import warnings
 from collections import defaultdict, deque
 from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
-from contextlib import asynccontextmanager, contextmanager, suppress
+from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
 from dataclasses import field
+from decimal import Decimal
 from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
 
-from genai_prices import calc_price
 from opentelemetry.trace import Tracer
 from typing_extensions import TypeGuard, TypeVar, assert_never
 
@@ -80,6 +81,7 @@ class GraphAgentState:
     usage: _usage.RunUsage
     retries: int
     run_step: int
+    ignore_warning_cost: bool
 
     def increment_retries(self, max_result_retries: int, error: BaseException | None = None) -> None:
         self.retries += 1
@@ -310,15 +312,8 @@ async def stream(
         ) as streamed_response:
             self._did_stream = True
             ctx.state.usage.requests += 1
+            ctx.state.usage.cost += cost(streamed_response, ctx.state.ignore_warning_cost)
 
-            # If we can't calculate the price, we don't want to fail the run.
-            with suppress(LookupError):
-                ctx.state.usage.cost += calc_price(
-                    streamed_response.usage(),
-                    ctx.deps.model.model_name,
-                    provider_id=streamed_response.provider_name,
-                    genai_request_timestamp=streamed_response.timestamp,
-                ).total_price
             agent_stream = result.AgentStream[DepsT, T](
                 streamed_response,
                 ctx.deps.output_schema,
@@ -348,10 +343,7 @@ async def _make_request(
         model_settings, model_request_parameters, message_history, _ = await self._prepare_request(ctx)
         model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters)
         ctx.state.usage.requests += 1
-
-        # If we can't calculate the price, we don't want to fail the run.
-        with suppress(LookupError):
-            ctx.state.usage.cost = model_response.price().total_price
+        ctx.state.usage.cost += cost(model_response, ctx.state.ignore_warning_cost)
 
         return self._finish_handling(ctx, model_response)
 
@@ -574,6 +566,24 @@ async def _handle_text_response(
         return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), [])
 
 
+def cost(response: _messages.ModelResponse | models.StreamedResponse, ignore_warning_cost: bool) -> Decimal:
+    # If we can't calculate the price, we don't want to fail the run.
+    try:
+        cost = response.cost().total_price
+    except LookupError:
+        # NOTE(Marcelo): We can allow some kind of hook on the provider level, which we could retrieve via
+        # `ctx.deps.model.provider.calculate_cost`, but I'm not sure what the API would look like. Maybe a new parameter
+        # on the `Provider` classes; that parameter would be a callable that receives the same parameters as `genai_prices`.
+        if response.model_name not in ('test', 'function') and not ignore_warning_cost:
+            warnings.warn(
+                f'The costs with provider "{response.provider_name}" and model "{response.model_name}" '
+                "couldn't be calculated. Please report this on GitHub https://github.com/pydantic/genai-prices. "
+                'If you want to ignore this warning, please pass the `ignore_warning_cost=True` parameter to the `Agent`.'
+ ) + return Decimal(0) + return cost + + def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]: """Build a `RunContext` object from the current agent graph run context.""" return RunContext[DepsT]( diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py index a22bf11908..fb32cff33f 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py @@ -182,6 +182,7 @@ def __init__( instrument: InstrumentationSettings | bool | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + ignore_warning_cost: bool = False, ) -> None: ... @overload @@ -211,6 +212,7 @@ def __init__( instrument: InstrumentationSettings | bool | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + ignore_warning_cost: bool = False, ) -> None: ... def __init__( @@ -238,6 +240,7 @@ def __init__( instrument: InstrumentationSettings | bool | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + ignore_warning_cost: bool = False, **_deprecated_kwargs: Any, ): """Create an agent. @@ -358,6 +361,8 @@ def __init__( self._event_stream_handler = event_stream_handler + self._ignore_warning_cost = ignore_warning_cost + self._override_deps: ContextVar[_utils.Option[AgentDepsT]] = ContextVar('_override_deps', default=None) self._override_model: ContextVar[_utils.Option[models.Model]] = ContextVar('_override_model', default=None) self._override_toolsets: ContextVar[_utils.Option[Sequence[AbstractToolset[AgentDepsT]]]] = ContextVar( @@ -579,6 +584,7 @@ async def main(): usage=usage, retries=0, run_step=0, + ignore_warning_cost=self._ignore_warning_cost, ) # Merge model settings in order of precedence: run > agent > model diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 01dc44d55c..84da19a25a 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -15,10 +15,7 @@ from typing_extensions import TypeAlias, deprecated from . import _otel_messages, _utils -from ._utils import ( - generate_tool_call_id as _generate_tool_call_id, - now_utc as _now_utc, -) +from ._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc from .exceptions import UnexpectedModelBehavior from .usage import RequestUsage @@ -941,8 +938,8 @@ class ModelResponse: provider_request_id: str | None = None """request ID as specified by the model provider. This can be used to track the specific request to the model.""" - def price(self) -> genai_types.PriceCalculation: - """Calculate the price of the usage. + def cost(self) -> genai_types.PriceCalculation: + """Calculate the cost of the usage. Uses [`genai-prices`](https://github.com/pydantic/genai-prices). 
""" diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 907e781ad1..16a64e42f8 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -16,6 +16,7 @@ from typing import Any, Generic, TypeVar, overload import httpx +from genai_prices import calc_price, types as genai_types from typing_extensions import Literal, TypeAliasType, TypedDict from .. import _utils @@ -613,6 +614,19 @@ def usage(self) -> RequestUsage: """Get the usage of the response so far. This will not be the final usage until the stream is exhausted.""" return self._usage + def cost(self) -> genai_types.PriceCalculation: + """Calculate the cost of the usage. + + Uses [`genai-prices`](https://github.com/pydantic/genai-prices). + """ + assert self.model_name, 'Model name is required to calculate price' + return calc_price( + self._usage, + self.model_name, + provider_id=self.provider_name, + genai_request_timestamp=self.timestamp, + ) + @property @abstractmethod def model_name(self) -> str: diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index 0e80886645..e58bfd0004 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -256,7 +256,7 @@ async def test_async_request_prompt_caching(allow_model_requests: None): ) last_message = result.all_messages()[-1] assert isinstance(last_message, ModelResponse) - assert last_message.price().total_price == snapshot(Decimal('0.00003488')) + assert last_message.cost().total_price == snapshot(Decimal('0.00003488')) async def test_async_request_text_response(allow_model_requests: None): diff --git a/tests/models/test_cohere.py b/tests/models/test_cohere.py index f8fe99a054..e62013ced4 100644 --- a/tests/models/test_cohere.py +++ b/tests/models/test_cohere.py @@ -333,6 +333,7 @@ async def get_location(loc_name: str) -> str: input_tokens=5, output_tokens=3, details={'input_tokens': 4, 'output_tokens': 2}, + cost=Decimal('6.375E-7'), ) ) diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index ee0bce1125..326c1772af 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -784,7 +784,9 @@ async def get_location(loc_name: str) -> str: ), ] ) - assert result.usage() == snapshot(RunUsage(requests=3, input_tokens=3, output_tokens=6, cost=Decimal('6.75E-7'))) + assert result.usage() == snapshot( + RunUsage(requests=3, input_tokens=3, output_tokens=6, cost=Decimal('0.000002025')) + ) async def test_unexpected_response(client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None): diff --git a/tests/models/test_google.py b/tests/models/test_google.py index e67b0b6d5e..0b0178ef5b 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -151,7 +151,7 @@ async def temperature(city: str, date: datetime.date) -> str: input_tokens=224, output_tokens=35, details={'text_prompt_tokens': 224, 'text_candidates_tokens': 35}, - cost=Decimal('0.000015525'), + cost=Decimal('0.000027300'), ) ) assert result.all_messages() == snapshot( From 5943f56e543b750c4bf2fe51b2c6d761cacbec56 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Thu, 28 Aug 2025 08:29:10 -0400 Subject: [PATCH 7/9] fix more tests --- .../pydantic_ai/models/function.py | 7 +- pydantic_ai_slim/pydantic_ai/models/test.py | 2 - .../test_known_model_names.yaml | 18 ++--- .../test_request_simple_usage.yaml | 80 +++++++++++++++++++ tests/models/mock_openai.py | 4 +- tests/models/test_openai.py | 
54 +++++++++---- tests/models/test_openai_responses.py | 2 +- 7 files changed, 134 insertions(+), 33 deletions(-) create mode 100644 tests/models/cassettes/test_openai/test_request_simple_usage.yaml diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index 9c1936e1cb..16a1c49ec2 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -135,6 +135,7 @@ async def request( assert isinstance(response_, ModelResponse), response_ response = response_ response.model_name = self._model_name + response.provider_name = self._system # Add usage data if not already present if not response.usage.has_values(): # pragma: no branch response.usage = _estimate_usage(chain(messages, [response])) @@ -169,6 +170,7 @@ async def request_stream( model_request_parameters=model_request_parameters, _model_name=self._model_name, _iter=response_stream, + _provider_name=self._system, ) @property @@ -261,6 +263,7 @@ class FunctionStreamedResponse(StreamedResponse): _model_name: str _iter: AsyncIterator[str | DeltaToolCalls | DeltaThinkingCalls] _timestamp: datetime = field(default_factory=_utils.now_utc) + _provider_name: str def __post_init__(self): self._usage += _estimate_usage([]) @@ -305,9 +308,9 @@ def model_name(self) -> str: return self._model_name @property - def provider_name(self) -> None: + def provider_name(self) -> str: """Get the provider name.""" - return None + return self._provider_name @property def timestamp(self) -> datetime: diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index 60d9ca19dd..1006d6a9e6 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -100,8 +100,6 @@ def __init__( self.custom_output_args = custom_output_args self.seed = seed self.last_model_request_parameters = None - self._model_name = 'test' - self._system = 'test' super().__init__(settings=settings, profile=profile) async def request( diff --git a/tests/models/cassettes/test_model_names/test_known_model_names.yaml b/tests/models/cassettes/test_model_names/test_known_model_names.yaml index ae29ad62b8..85d5d2e198 100644 --- a/tests/models/cassettes/test_model_names/test_known_model_names.yaml +++ b/tests/models/cassettes/test_model_names/test_known_model_names.yaml @@ -109,39 +109,39 @@ interactions: parsed_body: data: - created: 0 - id: llama-3.3-70b + id: qwen-3-235b-a22b-thinking-2507 object: model owned_by: Cerebras - created: 0 - id: qwen-3-235b-a22b-instruct-2507 + id: qwen-3-coder-480b object: model owned_by: Cerebras - created: 0 - id: gpt-oss-120b + id: llama3.1-8b object: model owned_by: Cerebras - created: 0 - id: qwen-3-235b-a22b-thinking-2507 + id: llama-4-scout-17b-16e-instruct object: model owned_by: Cerebras - created: 0 - id: qwen-3-coder-480b + id: qwen-3-235b-a22b-instruct-2507 object: model owned_by: Cerebras - created: 0 - id: llama3.1-8b + id: qwen-3-32b object: model owned_by: Cerebras - created: 0 - id: qwen-3-32b + id: llama-3.3-70b object: model owned_by: Cerebras - created: 0 - id: llama-4-scout-17b-16e-instruct + id: llama-4-maverick-17b-128e-instruct object: model owned_by: Cerebras - created: 0 - id: llama-4-maverick-17b-128e-instruct + id: gpt-oss-120b object: model owned_by: Cerebras object: list diff --git a/tests/models/cassettes/test_openai/test_request_simple_usage.yaml b/tests/models/cassettes/test_openai/test_request_simple_usage.yaml new file mode 100644 index 
0000000000..7f42cf0ed5 --- /dev/null +++ b/tests/models/cassettes/test_openai/test_request_simple_usage.yaml @@ -0,0 +1,80 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '100' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: Hello! How are you doing? + role: user + model: gpt-4o + stream: false + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '927' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '825' + openai-project: + - proj_dKobscVY9YJxeEaDJen54e3d + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + annotations: [] + content: Hello! I'm just a computer program, so I don't have feelings, but I'm here and ready to help you. How can + I assist you today? + refusal: null + role: assistant + created: 1756380852 + id: chatcmpl-C9VBMiEq0GYAxsZn9U6FURAbS9Lau + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_07871e2ad8 + usage: + completion_tokens: 30 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 14 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 44 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/mock_openai.py b/tests/models/mock_openai.py index d6aebe116f..f57b8b30be 100644 --- a/tests/models/mock_openai.py +++ b/tests/models/mock_openai.py @@ -87,7 +87,7 @@ def completion_message( id='123', choices=choices, created=1704067200, # 2024-01-01 - model='gpt-4o-123', + model='gpt-4o', object='chat.completion', usage=usage, ) @@ -150,7 +150,7 @@ def response_message( ) -> responses.Response: return responses.Response( id='123', - model='gpt-4o-123', + model='gpt-4o', object='response', created_at=1704067200, # 2024-01-01 output=list(output_items), diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 2ee02df53f..1a627a3ad5 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -87,6 +87,12 @@ pytest.mark.skipif(not imports_successful(), reason='openai not installed'), pytest.mark.anyio, pytest.mark.vcr, + # TODO(Marcelo): genai-prices needs to include Cerebras prices: https://github.com/pydantic/genai-prices/issues/132 + pytest.mark.filterwarnings('ignore:The costs with provider "cerebras" and model:UserWarning'), + # NOTE(Marcelo): The following model is old, so we are probably not including it on `genai-prices`. 
+ pytest.mark.filterwarnings( + 'ignore:The costs with provider "openai" and model "gpt-4o-search-preview-2025-03-11":UserWarning' + ), ] @@ -120,7 +126,7 @@ async def test_request_simple_success(allow_model_requests: None): ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='world')], - model_name='gpt-4o-123', + model_name='gpt-4o', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), provider_name='openai', provider_request_id='123', @@ -128,7 +134,7 @@ async def test_request_simple_success(allow_model_requests: None): ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='world')], - model_name='gpt-4o-123', + model_name='gpt-4o', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), provider_name='openai', provider_request_id='123', @@ -155,18 +161,28 @@ async def test_request_simple_success(allow_model_requests: None): ] -async def test_request_simple_usage(allow_model_requests: None): - c = completion_message( - ChatCompletionMessage(content='world', role='assistant'), - usage=CompletionUsage(completion_tokens=1, prompt_tokens=2, total_tokens=3), - ) - mock_client = MockOpenAI.create_mock(c) - m = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) +async def test_request_simple_usage(allow_model_requests: None, openai_api_key: str): + m = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) agent = Agent(m) - result = await agent.run('Hello') - assert result.output == 'world' - assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=2, output_tokens=1)) + result = await agent.run('Hello! How are you doing?') + assert result.output == snapshot( + "Hello! I'm just a computer program, so I don't have feelings, but I'm here and ready to help you. How can I assist you today?" 
+ ) + assert result.usage() == snapshot( + RunUsage( + requests=1, + input_tokens=14, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + }, + output_tokens=30, + cost=Decimal('0.000335'), + ) + ) async def test_request_structured_response(allow_model_requests: None): @@ -200,7 +216,7 @@ async def test_request_structured_response(allow_model_requests: None): tool_call_id='123', ) ], - model_name='gpt-4o-123', + model_name='gpt-4o', timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc), provider_name='openai', provider_request_id='123', @@ -295,7 +311,7 @@ async def get_location(loc_name: str) -> str: cache_read_tokens=1, output_tokens=1, ), - model_name='gpt-4o-123', + model_name='gpt-4o', timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc), provider_name='openai', provider_request_id='123', @@ -323,7 +339,7 @@ async def get_location(loc_name: str) -> str: cache_read_tokens=2, output_tokens=2, ), - model_name='gpt-4o-123', + model_name='gpt-4o', timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc), provider_name='openai', provider_request_id='123', @@ -340,14 +356,16 @@ async def get_location(loc_name: str) -> str: ), ModelResponse( parts=[TextPart(content='final response')], - model_name='gpt-4o-123', + model_name='gpt-4o', timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc), provider_name='openai', provider_request_id='123', ), ] ) - assert result.usage() == snapshot(RunUsage(requests=3, cache_read_tokens=3, input_tokens=5, output_tokens=3)) + assert result.usage() == snapshot( + RunUsage(requests=3, cache_read_tokens=3, input_tokens=5, output_tokens=3, cost=Decimal('0.00004625')) + ) FinishReason = Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call'] @@ -2273,6 +2291,8 @@ def test_model_profile_strict_not_supported(): ) +# NOTE(Marcelo): You wouldn't do this because you'd use the GoogleModel. I'm unsure if this test brings any value. 
+@pytest.mark.filterwarnings('ignore:The costs with provider "openai" and model "gemini-2.5-pro:UserWarning') async def test_compatible_api_with_tool_calls_without_id(allow_model_requests: None, gemini_api_key: str): provider = OpenAIProvider( openai_client=AsyncOpenAI( diff --git a/tests/models/test_openai_responses.py b/tests/models/test_openai_responses.py index beeb5c5a21..047dfea9a1 100644 --- a/tests/models/test_openai_responses.py +++ b/tests/models/test_openai_responses.py @@ -1115,7 +1115,7 @@ async def test_openai_responses_usage_without_tokens_details(allow_model_request ModelResponse( parts=[TextPart(content='4')], usage=RequestUsage(input_tokens=14, output_tokens=1, details={'reasoning_tokens': 0}), - model_name='gpt-4o-123', + model_name='gpt-4o', timestamp=IsDatetime(), provider_name='openai', provider_request_id='123', From 0969e0afe0d9de639eb16c3bb5db942b35faacfe Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Mon, 1 Sep 2025 15:40:24 +0200 Subject: [PATCH 8/9] fix tests --- .../pydantic_ai/agent/__init__.py | 2 +- tests/models/test_fallback.py | 12 +++++ tests/models/test_model_function.py | 7 +++ tests/models/test_openai_responses.py | 2 +- tests/test_a2a.py | 1 + tests/test_agent.py | 44 +++++++++++++++++++ tests/test_history_processor.py | 4 ++ tests/test_streaming.py | 10 +++++ tests/test_tools.py | 8 ++++ 9 files changed, 88 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py index 2779860854..4e0d663a0b 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py @@ -245,7 +245,7 @@ def __init__( instrument: InstrumentationSettings | bool | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, - ignore_warning_cost: bool = False, + ignore_warning_cost: bool = True, **_deprecated_kwargs: Any, ): """Create an agent. 
diff --git a/tests/models/test_fallback.py b/tests/models/test_fallback.py index 484a73ac37..a99cae3ff0 100644 --- a/tests/models/test_fallback.py +++ b/tests/models/test_fallback.py @@ -68,6 +68,7 @@ def test_first_successful() -> None: usage=RequestUsage(input_tokens=51, output_tokens=1), model_name='function:success_response:', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ] ) @@ -93,6 +94,7 @@ def test_first_failed() -> None: usage=RequestUsage(input_tokens=51, output_tokens=1), model_name='function:success_response:', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ] ) @@ -119,6 +121,7 @@ def test_first_failed_instrumented(capfire: CaptureLogfire) -> None: usage=RequestUsage(input_tokens=51, output_tokens=1), model_name='function:success_response:', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ] ) @@ -204,18 +207,21 @@ async def test_first_failed_instrumented_stream(capfire: CaptureLogfire) -> None usage=RequestUsage(input_tokens=50, output_tokens=1), model_name='function::success_response_stream', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ModelResponse( parts=[TextPart(content='hello world')], usage=RequestUsage(input_tokens=50, output_tokens=2), model_name='function::success_response_stream', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ModelResponse( parts=[TextPart(content='hello world')], usage=RequestUsage(input_tokens=50, output_tokens=2), model_name='function::success_response_stream', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ] ) @@ -431,18 +437,21 @@ async def test_first_success_streaming() -> None: usage=RequestUsage(input_tokens=50, output_tokens=1), model_name='function::success_response_stream', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ModelResponse( parts=[TextPart(content='hello world')], usage=RequestUsage(input_tokens=50, output_tokens=2), model_name='function::success_response_stream', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ModelResponse( parts=[TextPart(content='hello world')], usage=RequestUsage(input_tokens=50, output_tokens=2), model_name='function::success_response_stream', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ] ) @@ -460,18 +469,21 @@ async def test_first_failed_streaming() -> None: usage=RequestUsage(input_tokens=50, output_tokens=1), model_name='function::success_response_stream', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ModelResponse( parts=[TextPart(content='hello world')], usage=RequestUsage(input_tokens=50, output_tokens=2), model_name='function::success_response_stream', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ModelResponse( parts=[TextPart(content='hello world')], usage=RequestUsage(input_tokens=50, output_tokens=2), model_name='function::success_response_stream', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ] ) diff --git a/tests/models/test_model_function.py b/tests/models/test_model_function.py index c51d9edf1f..8b6ab39cf0 100644 --- a/tests/models/test_model_function.py +++ b/tests/models/test_model_function.py @@ -70,6 +70,7 @@ def test_simple(): usage=RequestUsage(input_tokens=51, output_tokens=3), model_name='function:return_last:', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ] ) @@ -84,6 +85,7 @@ def test_simple(): usage=RequestUsage(input_tokens=51, output_tokens=3), model_name='function:return_last:', timestamp=IsNow(tz=timezone.utc), + 
provider_name='function', ), ModelRequest(parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( @@ -91,6 +93,7 @@ def test_simple(): usage=RequestUsage(input_tokens=52, output_tokens=6), model_name='function:return_last:', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ] ) @@ -161,6 +164,7 @@ def test_weather(): usage=RequestUsage(input_tokens=51, output_tokens=5), model_name='function:weather_model:', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ModelRequest( parts=[ @@ -177,6 +181,7 @@ def test_weather(): usage=RequestUsage(input_tokens=56, output_tokens=11), model_name='function:weather_model:', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ModelRequest( parts=[ @@ -193,6 +198,7 @@ def test_weather(): usage=RequestUsage(input_tokens=57, output_tokens=14), model_name='function:weather_model:', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ] ) @@ -455,6 +461,7 @@ async def test_stream_text(): usage=RequestUsage(input_tokens=50, output_tokens=2), model_name='function::stream_text_function', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ] ) diff --git a/tests/models/test_openai_responses.py b/tests/models/test_openai_responses.py index ec27afc738..f41c395d8f 100644 --- a/tests/models/test_openai_responses.py +++ b/tests/models/test_openai_responses.py @@ -1125,6 +1125,6 @@ async def test_openai_responses_usage_without_tokens_details(allow_model_request assert result.usage() == snapshot( RunUsage( - input_tokens=14, output_tokens=9, details={'reasoning_tokens': 0}, requests=1, cost=Decimal('0.000125') + input_tokens=14, output_tokens=1, details={'reasoning_tokens': 0}, requests=1, cost=Decimal('0.000045') ) ) diff --git a/tests/test_a2a.py b/tests/test_a2a.py index f72227f8bc..fbf4ccd9dd 100644 --- a/tests/test_a2a.py +++ b/tests/test_a2a.py @@ -615,6 +615,7 @@ def track_messages(messages: list[ModelMessage], info: AgentInfo) -> ModelRespon usage=RequestUsage(input_tokens=52, output_tokens=7), model_name='function:track_messages:', timestamp=IsDatetime(), + provider_name='function', ), ModelRequest( parts=[ diff --git a/tests/test_agent.py b/tests/test_agent.py index dc2c004ec7..f3e951cb65 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -115,6 +115,7 @@ def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse usage=RequestUsage(input_tokens=51, output_tokens=7), model_name='function:return_model:', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ModelRequest( parts=[ @@ -138,6 +139,7 @@ def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse usage=RequestUsage(input_tokens=87, output_tokens=14), model_name='function:return_model:', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ModelRequest( parts=[ @@ -238,6 +240,7 @@ def validate_output(ctx: RunContext[None], o: Foo) -> Foo: usage=RequestUsage(input_tokens=51, output_tokens=7), model_name='function:return_model:', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ModelRequest( parts=[ @@ -254,6 +257,7 @@ def validate_output(ctx: RunContext[None], o: Foo) -> Foo: usage=RequestUsage(input_tokens=63, output_tokens=14), model_name='function:return_model:', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ModelRequest( parts=[ @@ -296,6 +300,7 @@ def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: usage=RequestUsage(input_tokens=51, output_tokens=1), 
model_name='function:return_tuple:', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ModelRequest( parts=[ @@ -313,6 +318,7 @@ def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: usage=RequestUsage(input_tokens=74, output_tokens=8), model_name='function:return_tuple:', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ModelRequest( parts=[ @@ -856,6 +862,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: usage=RequestUsage(input_tokens=53, output_tokens=7), model_name='function:call_tool:', timestamp=IsDatetime(), + provider_name='function', ), ModelRequest( parts=[ @@ -878,6 +885,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: usage=RequestUsage(input_tokens=68, output_tokens=13), model_name='function:call_tool:', timestamp=IsDatetime(), + provider_name='function', ), ModelRequest( parts=[ @@ -932,6 +940,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: usage=RequestUsage(input_tokens=53, output_tokens=3), model_name='function:call_tool:', timestamp=IsDatetime(), + provider_name='function', ), ModelRequest( parts=[ @@ -947,6 +956,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: usage=RequestUsage(input_tokens=70, output_tokens=5), model_name='function:call_tool:', timestamp=IsDatetime(), + provider_name='function', ), ] ) @@ -1120,6 +1130,7 @@ def say_world(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: usage=RequestUsage(input_tokens=51, output_tokens=1), model_name='function:say_world:', timestamp=IsDatetime(), + provider_name='function', ), ] ) @@ -1180,6 +1191,7 @@ def call_handoff_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelRes usage=RequestUsage(input_tokens=52, output_tokens=6), model_name='function:call_handoff_tool:', timestamp=IsDatetime(), + provider_name='function', ), ModelRequest( parts=[ @@ -1215,6 +1227,7 @@ def call_handoff_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelRes usage=RequestUsage(input_tokens=57, output_tokens=6), model_name='function:call_tool:', timestamp=IsDatetime(), + provider_name='function', ), ModelRequest( parts=[ @@ -1428,6 +1441,7 @@ class CityLocation(BaseModel): usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='function:return_city_location:', timestamp=IsDatetime(), + provider_name='function', ), ] ) @@ -1467,6 +1481,7 @@ class Foo(BaseModel): usage=RequestUsage(input_tokens=56, output_tokens=4), model_name='function:return_foo:', timestamp=IsDatetime(), + provider_name='function', ), ] ) @@ -1541,6 +1556,7 @@ def return_foo_bar(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: usage=RequestUsage(input_tokens=53, output_tokens=17), model_name='function:return_foo_bar:', timestamp=IsDatetime(), + provider_name='function', ), ] ) @@ -1582,6 +1598,7 @@ class CityLocation(BaseModel): usage=RequestUsage(input_tokens=56, output_tokens=5), model_name='function:return_city_location:', timestamp=IsDatetime(), + provider_name='function', ), ModelRequest( parts=[ @@ -1604,6 +1621,7 @@ class CityLocation(BaseModel): usage=RequestUsage(input_tokens=85, output_tokens=12), model_name='function:return_city_location:', timestamp=IsDatetime(), + provider_name='function', ), ] ) @@ -1665,6 +1683,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: usage=RequestUsage(input_tokens=53, output_tokens=6), model_name='function:call_tool:', timestamp=IsDatetime(), + 
provider_name='function', ), ModelRequest( parts=[ @@ -1680,6 +1699,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: usage=RequestUsage(input_tokens=70, output_tokens=11), model_name='function:call_tool:', timestamp=IsDatetime(), + provider_name='function', ), ] ) @@ -2046,6 +2066,7 @@ def test_tool() -> str: usage=RequestUsage(input_tokens=53, output_tokens=4), model_name='function:simple_response:', timestamp=IsDatetime(), + provider_name='function', ), ] ) @@ -2130,6 +2151,7 @@ def empty(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: usage=RequestUsage(input_tokens=51, output_tokens=2), model_name='function:empty:', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ModelRequest( parts=[ @@ -2146,6 +2168,7 @@ def empty(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: usage=RequestUsage(input_tokens=65, output_tokens=4), model_name='function:empty:', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ] ) @@ -2170,6 +2193,7 @@ def empty(m: list[ModelMessage], _info: AgentInfo) -> ModelResponse: usage=RequestUsage(input_tokens=51, output_tokens=2), model_name='function:empty:', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ModelRequest( parts=[ @@ -2186,6 +2210,7 @@ def empty(m: list[ModelMessage], _info: AgentInfo) -> ModelResponse: usage=RequestUsage(input_tokens=65, output_tokens=3), model_name='function:empty:', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ] ) @@ -2484,6 +2509,7 @@ def deferred_tool(x: int) -> int: # pragma: no cover usage=RequestUsage(input_tokens=53, output_tokens=27), model_name='function:return_model:', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ModelRequest( parts=[ @@ -2592,6 +2618,7 @@ def deferred_tool(x: int) -> int: # pragma: no cover usage=RequestUsage(input_tokens=58, output_tokens=22), model_name='function:return_model:', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ModelRequest( parts=[ @@ -2750,6 +2777,7 @@ async def get_location(loc_name: str) -> str: usage=RequestUsage(input_tokens=51, output_tokens=6), model_name='function:return_model:', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ModelRequest( parts=[ @@ -2766,6 +2794,7 @@ async def get_location(loc_name: str) -> str: usage=RequestUsage(input_tokens=56, output_tokens=8), model_name='function:return_model:', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ] ) @@ -3495,6 +3524,7 @@ def my_tool(x: int) -> int: usage=RequestUsage(input_tokens=51, output_tokens=5), model_name='function:llm:', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ModelRequest( parts=[ @@ -3511,6 +3541,7 @@ def my_tool(x: int) -> int: usage=RequestUsage(input_tokens=52, output_tokens=10), model_name='function:llm:', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ModelRequest( parts=[ @@ -3524,6 +3555,7 @@ def my_tool(x: int) -> int: usage=RequestUsage(input_tokens=53, output_tokens=10), model_name='function:llm:', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ] ) @@ -3663,6 +3695,7 @@ def analyze_data() -> ToolReturn: usage=RequestUsage(input_tokens=54, output_tokens=4), model_name='function:llm:', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ModelRequest( parts=[ @@ -3688,6 +3721,7 @@ def analyze_data() -> ToolReturn: usage=RequestUsage(input_tokens=70, output_tokens=6), model_name='function:llm:', timestamp=IsNow(tz=timezone.utc), + 
provider_name='function', ), ] ) @@ -3738,6 +3772,7 @@ def analyze_data() -> ToolReturn: usage=RequestUsage(input_tokens=54, output_tokens=4), model_name='function:llm:', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ModelRequest( parts=[ @@ -3755,6 +3790,7 @@ def analyze_data() -> ToolReturn: usage=RequestUsage(input_tokens=58, output_tokens=6), model_name='function:llm:', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ] ) @@ -4031,6 +4067,7 @@ def respond(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: usage=RequestUsage(input_tokens=57, output_tokens=2), model_name='function:respond:', timestamp=IsDatetime(), + provider_name='function', ), ModelRequest( parts=[ @@ -4047,6 +4084,7 @@ def respond(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: usage=RequestUsage(input_tokens=60, output_tokens=4), model_name='function:respond:', timestamp=IsDatetime(), + provider_name='function', ), ModelRequest( parts=[ @@ -4063,6 +4101,7 @@ def respond(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: usage=RequestUsage(input_tokens=63, output_tokens=5), model_name='function:respond:', timestamp=IsDatetime(), + provider_name='function', ), ] ) @@ -4355,6 +4394,7 @@ def model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelRespon usage=RequestUsage(input_tokens=57, output_tokens=6), model_name='function:model_function:', timestamp=IsDatetime(), + provider_name='function', ), ModelRequest( parts=[ @@ -4370,6 +4410,7 @@ def model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelRespon usage=RequestUsage(input_tokens=75, output_tokens=8), model_name='function:model_function:', timestamp=IsDatetime(), + provider_name='function', ), ] ) @@ -4437,6 +4478,7 @@ def create_file(path: str, content: str) -> str: usage=RequestUsage(input_tokens=60, output_tokens=23), model_name='function:model_function:', timestamp=IsDatetime(), + provider_name='function', ), ModelRequest( parts=[ @@ -4492,6 +4534,7 @@ def create_file(path: str, content: str) -> str: usage=RequestUsage(input_tokens=60, output_tokens=23), model_name='function:model_function:', timestamp=IsDatetime(), + provider_name='function', ), ModelRequest( parts=[ @@ -4520,6 +4563,7 @@ def create_file(path: str, content: str) -> str: usage=RequestUsage(input_tokens=78, output_tokens=24), model_name='function:model_function:', timestamp=IsDatetime(), + provider_name='function', ), ] ) diff --git a/tests/test_history_processor.py b/tests/test_history_processor.py index 1aa138935e..3753c3e6c4 100644 --- a/tests/test_history_processor.py +++ b/tests/test_history_processor.py @@ -66,6 +66,7 @@ def no_op_history_processor(messages: list[ModelMessage]) -> list[ModelMessage]: usage=RequestUsage(input_tokens=54, output_tokens=4), model_name='function:capture_model_function:capture_model_stream_function', timestamp=IsDatetime(), + provider_name='function', ), ] ) @@ -104,6 +105,7 @@ def process_previous_answers(messages: list[ModelMessage]) -> list[ModelMessage] usage=RequestUsage(input_tokens=54, output_tokens=2), model_name='function:capture_model_function:capture_model_stream_function', timestamp=IsDatetime(), + provider_name='function', ), ] ) @@ -138,6 +140,7 @@ def process_previous_answers(messages: list[ModelMessage]) -> list[ModelMessage] usage=RequestUsage(input_tokens=50, output_tokens=1), model_name='function:capture_model_function:capture_model_stream_function', timestamp=IsDatetime(), + provider_name='function', ), ] ) @@ -169,6 +172,7 @@ def 
capture_messages_processor(messages: list[ModelMessage]) -> list[ModelMessag usage=RequestUsage(input_tokens=54, output_tokens=2), model_name='function:capture_model_function:capture_model_stream_function', timestamp=IsDatetime(), + provider_name='function', ), ] ) diff --git a/tests/test_streaming.py b/tests/test_streaming.py index fb54483d5c..b68c8e1dca 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -339,6 +339,7 @@ async def ret_a(x: str) -> str: usage=RequestUsage(input_tokens=50, output_tokens=5), model_name='function::stream_structured_function', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ModelRequest( parts=[ @@ -361,6 +362,7 @@ async def ret_a(x: str) -> str: usage=RequestUsage(input_tokens=50, output_tokens=5), model_name='function::stream_structured_function', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ModelRequest( parts=[ @@ -383,6 +385,7 @@ async def ret_a(x: str) -> str: usage=RequestUsage(input_tokens=50, output_tokens=7), model_name='function::stream_structured_function', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ModelRequest( parts=[ @@ -436,6 +439,7 @@ async def ret_a(x: str) -> str: # pragma: no cover usage=RequestUsage(input_tokens=50, output_tokens=1), model_name='function::stream_structured_function', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ] ) @@ -492,6 +496,7 @@ def another_tool(y: int) -> int: # pragma: no cover usage=RequestUsage(input_tokens=50, output_tokens=10), model_name='function::sf', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ModelRequest( parts=[ @@ -548,6 +553,7 @@ async def sf(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str | Delt usage=RequestUsage(input_tokens=50, output_tokens=8), model_name='function::sf', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ModelRequest( parts=[ @@ -615,6 +621,7 @@ def another_tool(y: int) -> int: usage=RequestUsage(input_tokens=50, output_tokens=18), model_name='function::sf', timestamp=IsNow(tz=timezone.utc), + provider_name='function', ), ModelRequest( parts=[ @@ -724,6 +731,7 @@ def another_tool(y: int) -> int: # pragma: no cover usage=RequestUsage(input_tokens=50, output_tokens=14), model_name='function::sf', timestamp=IsNow(tz=datetime.timezone.utc), + provider_name='function', kind='response', ), ModelRequest( @@ -1255,6 +1263,7 @@ def my_tool(ctx: RunContext[None], x: int) -> int: usage=RequestUsage(input_tokens=50, output_tokens=3), model_name='function::llm', timestamp=IsDatetime(), + provider_name='function', ), ModelRequest( parts=[ @@ -1271,6 +1280,7 @@ def my_tool(ctx: RunContext[None], x: int) -> int: usage=RequestUsage(input_tokens=50, output_tokens=1), model_name='function::llm', timestamp=IsDatetime(), + provider_name='function', ), ] ) diff --git a/tests/test_tools.py b/tests/test_tools.py index eba764376f..44f8e22f2f 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1321,6 +1321,7 @@ def my_tool(ctx: RunContext[None], x: int) -> int: usage=RequestUsage(input_tokens=51, output_tokens=4), model_name='function:llm:', timestamp=IsDatetime(), + provider_name='function', ), ModelRequest( parts=[ @@ -1337,6 +1338,7 @@ def my_tool(ctx: RunContext[None], x: int) -> int: usage=RequestUsage(input_tokens=52, output_tokens=5), model_name='function:llm:', timestamp=IsDatetime(), + provider_name='function', ), ] ) @@ -1479,6 +1481,7 @@ def buy(fruit: str): usage=RequestUsage(input_tokens=68, output_tokens=30), model_name='function:llm:', 
timestamp=IsDatetime(), + provider_name='function', ), ModelRequest( parts=[ @@ -1564,6 +1567,7 @@ def buy(fruit: str): usage=RequestUsage(input_tokens=68, output_tokens=30), model_name='function:llm:', timestamp=IsDatetime(), + provider_name='function', ), ModelRequest( parts=[ @@ -1624,6 +1628,7 @@ def buy(fruit: str): usage=RequestUsage(input_tokens=124, output_tokens=31), model_name='function:llm:', timestamp=IsDatetime(), + provider_name='function', ), ] ) @@ -1715,6 +1720,7 @@ def bar(x: int) -> int: usage=RequestUsage(input_tokens=51, output_tokens=12), model_name='function:llm:', timestamp=IsDatetime(), + provider_name='function', ), ModelRequest( parts=[ @@ -1765,6 +1771,7 @@ def bar(x: int) -> int: usage=RequestUsage(input_tokens=51, output_tokens=12), model_name='function:llm:', timestamp=IsDatetime(), + provider_name='function', ), ModelRequest( parts=[ @@ -1793,6 +1800,7 @@ def bar(x: int) -> int: usage=RequestUsage(input_tokens=59, output_tokens=13), model_name='function:llm:', timestamp=IsDatetime(), + provider_name='function', ), ] ) From 76390226e754cd349fe5d9b35092333643ee8afd Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Mon, 1 Sep 2025 15:51:03 +0200 Subject: [PATCH 9/9] fix tests --- docs/agents.md | 6 +++++- docs/direct.md | 1 + docs/message-history.md | 6 ++++++ docs/multi-agent-applications.md | 8 ++------ docs/output.md | 2 +- docs/tools.md | 7 +++++++ pydantic_ai_slim/pydantic_ai/agent/__init__.py | 1 + pydantic_ai_slim/pydantic_ai/agent/abstract.py | 1 + pydantic_ai_slim/pydantic_ai/agent/wrapper.py | 1 + pydantic_ai_slim/pydantic_ai/direct.py | 2 ++ .../pydantic_ai/durable_exec/temporal/_agent.py | 1 + pydantic_ai_slim/pydantic_ai/run.py | 2 ++ 12 files changed, 30 insertions(+), 8 deletions(-) diff --git a/docs/agents.md b/docs/agents.md index 9c91b59938..883608a16f 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -302,6 +302,7 @@ async def main(): usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='gpt-4o', timestamp=datetime.datetime(...), + provider_name='function', ) ), End(data=FinalResult(output='The capital of France is Paris.')), @@ -365,6 +366,7 @@ async def main(): usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='gpt-4o', timestamp=datetime.datetime(...), + provider_name='function', ) ), End(data=FinalResult(output='The capital of France is Paris.')), @@ -557,7 +559,7 @@ result_sync = agent.run_sync( print(result_sync.output) #> Rome print(result_sync.usage()) -#> RunUsage(input_tokens=62, output_tokens=1, requests=1, cost=Decimal('0.000201')) +#> RunUsage(input_tokens=62, output_tokens=1, requests=1) try: result_sync = agent.run_sync( @@ -1006,6 +1008,7 @@ with capture_run_messages() as messages: # (2)! usage=RequestUsage(input_tokens=62, output_tokens=4), model_name='gpt-4o', timestamp=datetime.datetime(...), + provider_name='function', ), ModelRequest( parts=[ @@ -1028,6 +1031,7 @@ with capture_run_messages() as messages: # (2)! 
usage=RequestUsage(input_tokens=72, output_tokens=8), model_name='gpt-4o', timestamp=datetime.datetime(...), + provider_name='function', ), ] """ diff --git a/docs/direct.md b/docs/direct.md index 117180245f..79b34df133 100644 --- a/docs/direct.md +++ b/docs/direct.md @@ -87,6 +87,7 @@ async def main(): usage=RequestUsage(input_tokens=55, output_tokens=7), model_name='gpt-4.1-nano', timestamp=datetime.datetime(...), + provider_name='function', ) """ ``` diff --git a/docs/message-history.md b/docs/message-history.md index 401a5cecf2..d5aa966cce 100644 --- a/docs/message-history.md +++ b/docs/message-history.md @@ -61,6 +61,7 @@ print(result.all_messages()) usage=RequestUsage(input_tokens=60, output_tokens=12), model_name='gpt-4o', timestamp=datetime.datetime(...), + provider_name='function', ), ] """ @@ -129,6 +130,7 @@ async def main(): usage=RequestUsage(input_tokens=50, output_tokens=12), model_name='gpt-4o', timestamp=datetime.datetime(...), + provider_name='function', ), ] """ @@ -183,6 +185,7 @@ print(result2.all_messages()) usage=RequestUsage(input_tokens=60, output_tokens=12), model_name='gpt-4o', timestamp=datetime.datetime(...), + provider_name='function', ), ModelRequest( parts=[ @@ -201,6 +204,7 @@ print(result2.all_messages()) usage=RequestUsage(input_tokens=61, output_tokens=26), model_name='gpt-4o', timestamp=datetime.datetime(...), + provider_name='function', ), ] """ @@ -302,6 +306,7 @@ print(result2.all_messages()) usage=RequestUsage(input_tokens=60, output_tokens=12), model_name='gpt-4o', timestamp=datetime.datetime(...), + provider_name='function', ), ModelRequest( parts=[ @@ -320,6 +325,7 @@ print(result2.all_messages()) usage=RequestUsage(input_tokens=61, output_tokens=26), model_name='gemini-1.5-pro', timestamp=datetime.datetime(...), + provider_name='function', ), ] """ diff --git a/docs/multi-agent-applications.md b/docs/multi-agent-applications.md index a46afff2d8..151b6b67cf 100644 --- a/docs/multi-agent-applications.md +++ b/docs/multi-agent-applications.md @@ -52,7 +52,7 @@ result = joke_selection_agent.run_sync( print(result.output) #> Did you hear about the toothpaste scandal? They called it Colgate. print(result.usage()) -#> RunUsage(input_tokens=204, output_tokens=24, requests=3, cost=Decimal('0.0003475')) +#> RunUsage(input_tokens=204, output_tokens=24, requests=3) ``` 1. The "parent" or controlling agent. @@ -143,11 +143,7 @@ async def main(): print(result.output) #> Did you hear about the toothpaste scandal? They called it Colgate. print(result.usage()) # (6)! - """ - RunUsage( - input_tokens=309, output_tokens=32, requests=4, cost=Decimal('0.00036') - ) - """ + #> RunUsage(input_tokens=309, output_tokens=32, requests=4) ``` 1. Define a dataclass to hold the client and API key dependencies. 
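The multi-agent-applications.md hunk above now prints a single combined `RunUsage(..., requests=3)` / `RunUsage(..., requests=4)` for the whole delegation chain. As a rough illustration of why one object covers both agents, here is a minimal sketch of the delegation pattern those docs describe; the agent names, the `'test'` model, and the tool body are illustrative rather than taken from the patch, and it assumes the delegate run is passed `usage=ctx.usage` so its requests and tokens accrue to the parent run.

```python
# Minimal sketch (illustrative names, not from the patch): a parent agent
# delegates to another agent inside a tool, and the delegate run reuses the
# parent's usage object so everything rolls up into one RunUsage.
from pydantic_ai import Agent, RunContext

parent_agent = Agent('test', system_prompt='Use the `tell_joke` tool to answer.')
joke_agent = Agent('test')


@parent_agent.tool
async def tell_joke(ctx: RunContext[None], subject: str) -> str:
    # usage=ctx.usage makes the delegate's requests/tokens count toward the
    # parent run, which is why the docs print a single combined RunUsage.
    result = await joke_agent.run(f'Tell a joke about {subject}', usage=ctx.usage)
    return result.output


result = parent_agent.run_sync('Tell me a joke.')
print(result.usage())  # one RunUsage covering both agents' requests
```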
diff --git a/docs/output.md b/docs/output.md index 783762d4de..57f0479e22 100644 --- a/docs/output.md +++ b/docs/output.md @@ -24,7 +24,7 @@ result = agent.run_sync('Where were the olympics held in 2012?') print(result.output) #> city='London' country='United Kingdom' print(result.usage()) -#> RunUsage(input_tokens=57, output_tokens=8, requests=1, cost=Decimal('0.000006675')) +#> RunUsage(input_tokens=57, output_tokens=8, requests=1) ``` _(This example is complete, it can be run "as is")_ diff --git a/docs/tools.md b/docs/tools.md index 34c56f89fc..ef48f1df88 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -98,6 +98,7 @@ print(dice_result.all_messages()) usage=RequestUsage(input_tokens=90, output_tokens=2), model_name='gemini-1.5-flash', timestamp=datetime.datetime(...), + provider_name='function', ), ModelRequest( parts=[ @@ -118,6 +119,7 @@ print(dice_result.all_messages()) usage=RequestUsage(input_tokens=91, output_tokens=4), model_name='gemini-1.5-flash', timestamp=datetime.datetime(...), + provider_name='function', ), ModelRequest( parts=[ @@ -138,6 +140,7 @@ print(dice_result.all_messages()) usage=RequestUsage(input_tokens=92, output_tokens=12), model_name='gemini-1.5-flash', timestamp=datetime.datetime(...), + provider_name='function', ), ] """ @@ -827,6 +830,7 @@ print(result.all_messages()) usage=RequestUsage(input_tokens=63, output_tokens=21), model_name='gpt-5', timestamp=datetime.datetime(...), + provider_name='function', ), ModelRequest( parts=[ @@ -859,6 +863,7 @@ print(result.all_messages()) usage=RequestUsage(input_tokens=79, output_tokens=39), model_name='gpt-5', timestamp=datetime.datetime(...), + provider_name='function', ), ] """ @@ -984,6 +989,7 @@ async def main(): usage=RequestUsage(input_tokens=63, output_tokens=13), model_name='gpt-5', timestamp=datetime.datetime(...), + provider_name='function', ), ModelRequest( parts=[ @@ -1004,6 +1010,7 @@ async def main(): usage=RequestUsage(input_tokens=64, output_tokens=28), model_name='gpt-5', timestamp=datetime.datetime(...), + provider_name='function', ), ] """ diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py index 4e0d663a0b..69307cd1d5 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py @@ -533,6 +533,7 @@ async def main(): usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='gpt-4o', timestamp=datetime.datetime(...), + provider_name='function', ) ), End(data=FinalResult(output='The capital of France is Paris.')), diff --git a/pydantic_ai_slim/pydantic_ai/agent/abstract.py b/pydantic_ai_slim/pydantic_ai/agent/abstract.py index fe84555652..44900bc7b7 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/abstract.py +++ b/pydantic_ai_slim/pydantic_ai/agent/abstract.py @@ -642,6 +642,7 @@ async def main(): usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='gpt-4o', timestamp=datetime.datetime(...), + provider_name='function', ) ), End(data=FinalResult(output='The capital of France is Paris.')), diff --git a/pydantic_ai_slim/pydantic_ai/agent/wrapper.py b/pydantic_ai_slim/pydantic_ai/agent/wrapper.py index e53ead8cef..0bb9b36adf 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/agent/wrapper.py @@ -166,6 +166,7 @@ async def main(): usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='gpt-4o', timestamp=datetime.datetime(...), + provider_name='function', ) ), End(data=FinalResult(output='The capital of France is Paris.')), diff 
--git a/pydantic_ai_slim/pydantic_ai/direct.py b/pydantic_ai_slim/pydantic_ai/direct.py index 0e71cd7410..97755da49b 100644 --- a/pydantic_ai_slim/pydantic_ai/direct.py +++ b/pydantic_ai_slim/pydantic_ai/direct.py @@ -60,6 +60,7 @@ async def main(): usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='claude-3-5-haiku-latest', timestamp=datetime.datetime(...), + provider_name='function', ) ''' ``` @@ -113,6 +114,7 @@ def model_request_sync( usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='claude-3-5-haiku-latest', timestamp=datetime.datetime(...), + provider_name='function', ) ''' ``` diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py index 9906dffc96..5f775f0dd7 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py @@ -621,6 +621,7 @@ async def main(): usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='gpt-4o', timestamp=datetime.datetime(...), + provider_name='function', ) ), End(data=FinalResult(output='The capital of France is Paris.')), diff --git a/pydantic_ai_slim/pydantic_ai/run.py b/pydantic_ai_slim/pydantic_ai/run.py index 7ed6b848c0..e688494545 100644 --- a/pydantic_ai_slim/pydantic_ai/run.py +++ b/pydantic_ai_slim/pydantic_ai/run.py @@ -70,6 +70,7 @@ async def main(): usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='gpt-4o', timestamp=datetime.datetime(...), + provider_name='function', ) ), End(data=FinalResult(output='The capital of France is Paris.')), @@ -205,6 +206,7 @@ async def main(): usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='gpt-4o', timestamp=datetime.datetime(...), + provider_name='function', ) ), End(data=FinalResult(output='The capital of France is Paris.')),
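Every snapshot and docstring update in this patch adds `provider_name='function'` to the `ModelResponse` reprs produced by `FunctionModel`-backed runs. A minimal sketch of what those assertions boil down to follows; the function body, prompt, and printed output are illustrative, under the assumption (shown by the diffs above) that responses from a `FunctionModel` now report `provider_name='function'` next to the synthesized `model_name`.

```python
# Minimal sketch: a FunctionModel-backed agent whose last response carries both
# the synthesized model_name ('function:<name>:') and provider_name='function',
# matching the updated snapshots above. The function body is illustrative.
from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.function import AgentInfo, FunctionModel


def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    return ModelResponse(parts=[TextPart('world')])


agent = Agent(FunctionModel(llm))
result = agent.run_sync('hello')
response = result.all_messages()[-1]
print(response.model_name, response.provider_name)
#> function:llm: function
```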