4 changes: 4 additions & 0 deletions docs/agents.md
@@ -302,6 +302,7 @@ async def main():
usage=RequestUsage(input_tokens=56, output_tokens=7),
model_name='gpt-4o',
timestamp=datetime.datetime(...),
provider_name='function',
)
),
End(data=FinalResult(output='The capital of France is Paris.')),
@@ -365,6 +366,7 @@ async def main():
usage=RequestUsage(input_tokens=56, output_tokens=7),
model_name='gpt-4o',
timestamp=datetime.datetime(...),
provider_name='function',
)
),
End(data=FinalResult(output='The capital of France is Paris.')),
@@ -1006,6 +1008,7 @@ with capture_run_messages() as messages:  # (2)!
usage=RequestUsage(input_tokens=62, output_tokens=4),
model_name='gpt-4o',
timestamp=datetime.datetime(...),
provider_name='function',
),
ModelRequest(
parts=[
@@ -1028,6 +1031,7 @@ with capture_run_messages() as messages:  # (2)!
usage=RequestUsage(input_tokens=72, output_tokens=8),
model_name='gpt-4o',
timestamp=datetime.datetime(...),
provider_name='function',
),
]
"""
1 change: 1 addition & 0 deletions docs/direct.md
@@ -87,6 +87,7 @@ async def main():
usage=RequestUsage(input_tokens=55, output_tokens=7),
model_name='gpt-4.1-nano',
timestamp=datetime.datetime(...),
provider_name='function',
)
"""
```
6 changes: 6 additions & 0 deletions docs/message-history.md
@@ -61,6 +61,7 @@ print(result.all_messages())
usage=RequestUsage(input_tokens=60, output_tokens=12),
model_name='gpt-4o',
timestamp=datetime.datetime(...),
provider_name='function',
),
]
"""
@@ -129,6 +130,7 @@ async def main():
usage=RequestUsage(input_tokens=50, output_tokens=12),
model_name='gpt-4o',
timestamp=datetime.datetime(...),
provider_name='function',
),
]
"""
@@ -183,6 +185,7 @@ print(result2.all_messages())
usage=RequestUsage(input_tokens=60, output_tokens=12),
model_name='gpt-4o',
timestamp=datetime.datetime(...),
provider_name='function',
),
ModelRequest(
parts=[
@@ -201,6 +204,7 @@ print(result2.all_messages())
usage=RequestUsage(input_tokens=61, output_tokens=26),
model_name='gpt-4o',
timestamp=datetime.datetime(...),
provider_name='function',
),
]
"""
@@ -302,6 +306,7 @@ print(result2.all_messages())
usage=RequestUsage(input_tokens=60, output_tokens=12),
model_name='gpt-4o',
timestamp=datetime.datetime(...),
provider_name='function',
),
ModelRequest(
parts=[
@@ -320,6 +325,7 @@ print(result2.all_messages())
usage=RequestUsage(input_tokens=61, output_tokens=26),
model_name='gemini-1.5-pro',
timestamp=datetime.datetime(...),
provider_name='function',
),
]
"""
7 changes: 7 additions & 0 deletions docs/tools.md
@@ -98,6 +98,7 @@ print(dice_result.all_messages())
usage=RequestUsage(input_tokens=90, output_tokens=2),
model_name='gemini-1.5-flash',
timestamp=datetime.datetime(...),
provider_name='function',
),
ModelRequest(
parts=[
@@ -118,6 +119,7 @@ print(dice_result.all_messages())
usage=RequestUsage(input_tokens=91, output_tokens=4),
model_name='gemini-1.5-flash',
timestamp=datetime.datetime(...),
provider_name='function',
),
ModelRequest(
parts=[
@@ -138,6 +140,7 @@ print(dice_result.all_messages())
usage=RequestUsage(input_tokens=92, output_tokens=12),
model_name='gemini-1.5-flash',
timestamp=datetime.datetime(...),
provider_name='function',
),
]
"""
@@ -827,6 +830,7 @@ print(result.all_messages())
usage=RequestUsage(input_tokens=63, output_tokens=21),
model_name='gpt-5',
timestamp=datetime.datetime(...),
provider_name='function',
),
ModelRequest(
parts=[
@@ -859,6 +863,7 @@ print(result.all_messages())
usage=RequestUsage(input_tokens=79, output_tokens=39),
model_name='gpt-5',
timestamp=datetime.datetime(...),
provider_name='function',
),
]
"""
@@ -984,6 +989,7 @@ async def main():
usage=RequestUsage(input_tokens=63, output_tokens=13),
model_name='gpt-5',
timestamp=datetime.datetime(...),
provider_name='function',
),
ModelRequest(
parts=[
@@ -1004,6 +1010,7 @@ async def main():
usage=RequestUsage(input_tokens=64, output_tokens=28),
model_name='gpt-5',
timestamp=datetime.datetime(...),
provider_name='function',
),
]
"""
24 changes: 24 additions & 0 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -3,11 +3,13 @@
import asyncio
import dataclasses
import hashlib
import warnings
from collections import defaultdict, deque
from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Sequence
from contextlib import asynccontextmanager, contextmanager
from contextvars import ContextVar
from dataclasses import field
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast

from opentelemetry.trace import Tracer
@@ -86,6 +88,7 @@ class GraphAgentState:
usage: _usage.RunUsage
retries: int
run_step: int
ignore_warning_cost: bool

def increment_retries(self, max_result_retries: int, error: BaseException | None = None) -> None:
self.retries += 1
@@ -390,6 +393,8 @@ async def stream(
) as streamed_response:
self._did_stream = True
ctx.state.usage.requests += 1
ctx.state.usage.cost += cost(streamed_response, ctx.state.ignore_warning_cost)

agent_stream = result.AgentStream[DepsT, T](
_raw_stream_response=streamed_response,
_output_schema=ctx.deps.output_schema,
@@ -419,6 +424,7 @@ async def _make_request(
model_settings, model_request_parameters, message_history, _ = await self._prepare_request(ctx)
model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters)
ctx.state.usage.requests += 1
ctx.state.usage.cost += cost(model_response, ctx.state.ignore_warning_cost)

return self._finish_handling(ctx, model_response)

@@ -634,6 +640,24 @@ async def _handle_text_response(
return self._handle_final_result(ctx, result.FinalResult(result_data), [])


def cost(response: _messages.ModelResponse | models.StreamedResponse, ignore_warning_cost: bool) -> Decimal:
# If we can't calculate the price, we don't want to fail the run.
try:
cost = response.cost().total_price
except LookupError:
# NOTE(Marcelo): We could allow some kind of hook at the provider level, which we could retrieve via
# `ctx.deps.model.provider.calculate_cost`, but I'm not sure what the API would look like. Maybe a new
# parameter on the `Provider` classes: a callable that receives the same parameters as `genai_prices`.
if response.model_name not in ('test', 'function') and not ignore_warning_cost:
Contributor: I was thinking that test/function provider names would somehow make this special-casing easier; I didn't even think about the model names. Since we can just use the model names, maybe the provider names aren't worth it. I especially don't like how the examples and docs now show `model_name='gpt-4o'` but also `provider_name='function'`.

warnings.warn(
f'The cost for provider "{response.provider_name}" and model "{response.model_name}" '
"couldn't be calculated. Please report this on GitHub: https://github.com/pydantic/genai-prices. "
'To silence this warning, pass `ignore_warning_cost=True` to the `Agent`.'
)
return Decimal(0)
return cost


def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
"""Build a `RunContext` object from the current agent graph run context."""
return RunContext[DepsT](
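To make the new helper's contract concrete, here is a minimal sketch of how it behaves for a model that genai-prices has no data for. The `FakeResponse` stand-in is hypothetical, and importing the private `pydantic_ai._agent_graph.cost` helper directly is only for illustration:

```python
import warnings
from decimal import Decimal

from pydantic_ai._agent_graph import cost  # private helper added by this PR


class FakeResponse:
    """Stand-in for a response whose price genai-prices can't look up (hypothetical)."""

    model_name = 'my-unlisted-model'
    provider_name = 'my-provider'

    def cost(self):
        raise LookupError('no pricing data')  # what genai-prices raises for unknown models


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    total = cost(FakeResponse(), ignore_warning_cost=False)

assert total == Decimal(0)  # unknown prices contribute nothing to usage.cost
assert len(caught) == 1  # and warn once, unless ignore_warning_cost=True
```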
8 changes: 8 additions & 0 deletions pydantic_ai_slim/pydantic_ai/agent/__init__.py
@@ -187,6 +187,7 @@ def __init__(
instrument: InstrumentationSettings | bool | None = None,
history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
ignore_warning_cost: bool = False,
) -> None: ...

@overload
@@ -216,6 +217,7 @@ def __init__(
instrument: InstrumentationSettings | bool | None = None,
history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
ignore_warning_cost: bool = False,
) -> None: ...

def __init__(
@@ -243,6 +245,7 @@ def __init__(
instrument: InstrumentationSettings | bool | None = None,
history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
ignore_warning_cost: bool = True,
Contributor: I changed this to `True` to make tests pass easily, since many things were failing where we didn't know the price.

I think that `Agent.run` itself shouldn't create a warning by default, but it should note when a price is unknown. The warning should be emitted when the user explicitly requests the cost of a run.

Member (author): I left it `False` because I think it's the right behavior, but also to help me understand what we didn't yet support.

Member (author): Is the warning useful if you have to opt in to actually get it?

Contributor: The parameter wouldn't be here; the run would always collect and store the information needed for future warnings. There would be a parameter on some other method for retrieving the cost, where the warning is emitted by default.

Member (author): Makes sense.

**_deprecated_kwargs: Any,
):
"""Create an agent.
Expand Down Expand Up @@ -295,6 +298,7 @@ def __init__(
Each processor takes a list of messages and returns a modified list of messages.
Processors can be sync or async and are applied in sequence.
event_stream_handler: Optional handler for events from the model's streaming response and the agent's execution of tools.
ignore_warning_cost: If `True`, suppress the warning emitted when the cost of a run cannot be calculated.
"""
if model is None or defer_model_check:
self._model = model
@@ -365,6 +369,8 @@ def __init__(

self._event_stream_handler = event_stream_handler

self._ignore_warning_cost = ignore_warning_cost

self._override_deps: ContextVar[_utils.Option[AgentDepsT]] = ContextVar('_override_deps', default=None)
self._override_model: ContextVar[_utils.Option[models.Model]] = ContextVar('_override_model', default=None)
self._override_toolsets: ContextVar[_utils.Option[Sequence[AbstractToolset[AgentDepsT]]]] = ContextVar(
@@ -527,6 +533,7 @@ async def main():
usage=RequestUsage(input_tokens=56, output_tokens=7),
model_name='gpt-4o',
timestamp=datetime.datetime(...),
provider_name='function',
)
),
End(data=FinalResult(output='The capital of France is Paris.')),
@@ -590,6 +597,7 @@ async def main():
usage=usage,
retries=0,
run_step=0,
ignore_warning_cost=self._ignore_warning_cost,
)

# Merge model settings in order of precedence: run > agent > model
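Following the review thread above on `ignore_warning_cost`, here is a rough sketch of the deferred-warning design the reviewers converge on: record unpriced responses while the run executes, and only warn when the user explicitly asks for the cost. Everything here (`RunCost`, `unpriced`, the `cost()` accessor) is hypothetical, not API from this PR:

```python
import warnings
from dataclasses import dataclass, field
from decimal import Decimal


@dataclass
class RunCost:
    """Hypothetical accumulator: collects cost during a run, defers warnings."""

    total: Decimal = Decimal(0)
    unpriced: list[tuple[str, str]] = field(default_factory=list)  # (provider, model)

    def add(self, provider: str, model: str, price: Decimal | None) -> None:
        if price is None:
            self.unpriced.append((provider, model))  # record, don't warn yet
        else:
            self.total += price

    def cost(self, *, warn_unknown: bool = True) -> Decimal:
        # The warning fires here, when the cost is explicitly requested.
        if self.unpriced and warn_unknown:
            pairs = ', '.join(f'{p}/{m}' for p, m in self.unpriced)
            warnings.warn(f'Cost is incomplete; no pricing data for: {pairs}')
        return self.total
```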
1 change: 1 addition & 0 deletions pydantic_ai_slim/pydantic_ai/agent/abstract.py
@@ -642,6 +642,7 @@ async def main():
usage=RequestUsage(input_tokens=56, output_tokens=7),
model_name='gpt-4o',
timestamp=datetime.datetime(...),
provider_name='function',
)
),
End(data=FinalResult(output='The capital of France is Paris.')),
1 change: 1 addition & 0 deletions pydantic_ai_slim/pydantic_ai/agent/wrapper.py
@@ -166,6 +166,7 @@ async def main():
usage=RequestUsage(input_tokens=56, output_tokens=7),
model_name='gpt-4o',
timestamp=datetime.datetime(...),
provider_name='function',
)
),
End(data=FinalResult(output='The capital of France is Paris.')),
2 changes: 2 additions & 0 deletions pydantic_ai_slim/pydantic_ai/direct.py
@@ -60,6 +60,7 @@ async def main():
usage=RequestUsage(input_tokens=56, output_tokens=7),
model_name='claude-3-5-haiku-latest',
timestamp=datetime.datetime(...),
provider_name='function',
)
'''
```
@@ -113,6 +114,7 @@ def model_request_sync(
usage=RequestUsage(input_tokens=56, output_tokens=7),
model_name='claude-3-5-haiku-latest',
timestamp=datetime.datetime(...),
provider_name='function',
)
'''
```
@@ -621,6 +621,7 @@ async def main():
usage=RequestUsage(input_tokens=56, output_tokens=7),
model_name='gpt-4o',
timestamp=datetime.datetime(...),
provider_name='function',
)
),
End(data=FinalResult(output='The capital of France is Paris.')),
6 changes: 3 additions & 3 deletions pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_model.py
@@ -49,19 +49,19 @@ def get(self) -> ModelResponse:
return self.response

def usage(self) -> RequestUsage:
-    return self.response.usage # pragma: no cover
+    return self.response.usage

@property
def model_name(self) -> str:
    return self.response.model_name or '' # pragma: no cover

@property
def provider_name(self) -> str:
-    return self.response.provider_name or '' # pragma: no cover
+    return self.response.provider_name or ''

@property
def timestamp(self) -> datetime:
-    return self.response.timestamp # pragma: no cover
+    return self.response.timestamp


class TemporalModel(WrapperModel):
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/messages.py
@@ -977,8 +977,8 @@ class ModelResponse:
provider_response_id: str | None = None
"""request ID as specified by the model provider. This can be used to track the specific request to the model."""

-def price(self) -> genai_types.PriceCalculation:
-    """Calculate the price of the usage.
+def cost(self) -> genai_types.PriceCalculation:
+    """Calculate the cost of the usage.

Uses [`genai-prices`](https://github.com/pydantic/genai-prices).
"""
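A short usage sketch for the renamed method: summing `total_price` over a run's message history. `ModelMessage`, `ModelResponse`, and `PriceCalculation.total_price` are existing APIs per this diff; the helper function itself is illustrative:

```python
from decimal import Decimal

from pydantic_ai.messages import ModelMessage, ModelResponse


def run_cost(messages: list[ModelMessage]) -> Decimal:
    """Sum the cost of every model response in a run's message history."""
    total = Decimal(0)
    for message in messages:
        if isinstance(message, ModelResponse):
            # cost() may raise LookupError for unknown models, which is
            # exactly the case the `cost` helper in _agent_graph.py handles.
            total += message.cost().total_price
    return total
```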
14 changes: 14 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -17,6 +17,7 @@
from typing import Any, Generic, Literal, TypeVar, overload

import httpx
from genai_prices import calc_price, types as genai_types
from typing_extensions import TypeAliasType, TypedDict

from .. import _utils
@@ -615,6 +616,19 @@ def usage(self) -> RequestUsage:
"""Get the usage of the response so far. This will not be the final usage until the stream is exhausted."""
return self._usage

def cost(self) -> genai_types.PriceCalculation:
"""Calculate the cost of the usage.

Uses [`genai-prices`](https://github.com/pydantic/genai-prices).
"""
assert self.model_name, 'Model name is required to calculate cost'
return calc_price(
self._usage,
self.model_name,
provider_id=self.provider_name,
genai_request_timestamp=self.timestamp,
)

@property
@abstractmethod
def model_name(self) -> str:
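For reference, the new `StreamedResponse.cost()` is a thin wrapper over genai-prices. A minimal sketch of the equivalent direct call, assuming the public `pydantic_ai.usage.RequestUsage` import path and illustrative model/provider ids:

```python
from datetime import datetime, timezone

from genai_prices import calc_price

from pydantic_ai.usage import RequestUsage

usage = RequestUsage(input_tokens=56, output_tokens=7)
price = calc_price(
    usage,
    'gpt-4o',  # model name (illustrative)
    provider_id='openai',  # illustrative
    genai_request_timestamp=datetime.now(timezone.utc),
)
print(price.total_price)
```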