Adding CountToken to Gemini #2137


Open · wants to merge 20 commits into main
13 changes: 12 additions & 1 deletion pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -368,7 +368,18 @@ async def _prepare_request(

# Check usage
if ctx.deps.usage_limits: # pragma: no branch
if ctx.deps.usage_limits.count_tokens_before_request:
model_request_parameters = await _prepare_request_parameters(ctx)
message_history = await _process_message_history(
ctx.state.message_history, ctx.deps.history_processors, build_run_context(ctx)
)
token_count = await ctx.deps.model.count_tokens(
message_history, ctx.deps.model_settings, model_request_parameters
)
ctx.state.usage.incr(token_count.to_usage())
ctx.deps.usage_limits.check_tokens(ctx.state.usage)
else:
ctx.deps.usage_limits.check_before_request(ctx.state.usage)

# Increment run_step
ctx.state.run_step += 1
16 changes: 16 additions & 0 deletions pydantic_ai_slim/pydantic_ai/messages.py
@@ -1148,3 +1148,19 @@ def tool_call_id(self) -> str:
HandleResponseEvent = Annotated[
Union[FunctionToolCallEvent, FunctionToolResultEvent], pydantic.Discriminator('event_kind')
]


@dataclass(repr=False)
class BaseCountTokensResponse:
"""Structured response for token count API calls from various model providers."""

total_tokens: int | None = field(
default=None, metadata={'description': 'Total number of tokens counted in the messages.'}
)
"""Total number of tokens counted in the messages."""

def to_usage(self) -> Usage:
"""Usage object conversion for compatibility with Usage.incr."""
return Usage(request_tokens=self.total_tokens)

__repr__ = _utils.dataclasses_no_defaults_repr
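
For reference, a minimal sketch of how a count result feeds into usage tracking (the values are illustrative, and it assumes `Usage.incr` accumulates `request_tokens`, as the `_agent_graph.py` change above relies on):

from pydantic_ai.messages import BaseCountTokensResponse
from pydantic_ai.usage import Usage

# The provider reports 12 prompt tokens before the real request is sent.
count = BaseCountTokensResponse(total_tokens=12)

usage = Usage()
usage.incr(count.to_usage())  # folds the count into usage.request_tokens
assert usage.request_tokens == 12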
20 changes: 19 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -24,7 +24,15 @@
from .._output import OutputObjectDefinition
from .._parts_manager import ModelResponsePartsManager
from ..exceptions import UserError
from ..messages import (
BaseCountTokensResponse,
FileUrl,
ModelMessage,
ModelRequest,
ModelResponse,
ModelResponseStreamEvent,
VideoUrl,
)
from ..output import OutputMode
from ..profiles._json_schema import JsonSchemaTransformer
from ..settings import ModelSettings
@@ -382,6 +390,16 @@ async def request(
"""Make a request to the model."""
raise NotImplementedError()

async def count_tokens(
self,
messages: list[ModelMessage],
model_settings: ModelSettings | None,
model_request_parameters: ModelRequestParameters,
) -> BaseCountTokensResponse:
"""Make a request to the model for counting tokens."""
# This method is not required, but must be implemented if you want to support token counting before making a request
raise NotImplementedError(f'Token counting is not supported by {self.__class__.__name__}')

@asynccontextmanager
async def request_stream(
self,
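For a provider that does support counting, the override looks roughly like this hypothetical subclass (the other abstract members of Model are omitted for brevity, so this is a sketch rather than a runnable model):

from pydantic_ai.messages import BaseCountTokensResponse, ModelMessage
from pydantic_ai.models import Model, ModelRequestParameters
from pydantic_ai.settings import ModelSettings

class MyCountingModel(Model):
    # request(), model_name, system, etc. omitted for brevity

    async def count_tokens(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> BaseCountTokensResponse:
        # Call the provider's token-counting endpoint here; this stub
        # returns a fixed value purely for illustration.
        return BaseCountTokensResponse(total_tokens=42)
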
42 changes: 42 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/google.py
@@ -14,6 +14,7 @@
from .._output import OutputObjectDefinition
from ..exceptions import UserError
from ..messages import (
BaseCountTokensResponse,
BinaryContent,
FileUrl,
ModelMessage,
@@ -48,6 +49,8 @@
from google.genai.types import (
ContentDict,
ContentUnionDict,
CountTokensConfigDict,
CountTokensResponse,
FunctionCallDict,
FunctionCallingConfigDict,
FunctionCallingConfigMode,
@@ -181,6 +184,37 @@ async def request(
response = await self._generate_content(messages, False, model_settings, model_request_parameters)
return self._process_response(response)

async def count_tokens(
self,
messages: list[ModelMessage],
model_settings: ModelSettings | None,
model_request_parameters: ModelRequestParameters,
) -> BaseCountTokensResponse:
check_allow_model_requests()
model_settings = cast(GoogleModelSettings, model_settings or {})
_, contents = await self._map_messages(messages)
http_options: HttpOptionsDict = {
'headers': {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
}
if timeout := model_settings.get('timeout'):
if isinstance(timeout, (int, float)):
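# google-genai's HttpOptions expects the timeout in milliseconds, hence the conversion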
http_options['timeout'] = int(1000 * timeout)
else:
raise UserError('Google does not support setting ModelSettings.timeout to a httpx.Timeout')
# The system_instruction and tools parameters are not supported by the countTokens API: https://github.com/googleapis/python-genai/blob/038ecd3375f7c63a8ee8c1afa50cff1976343625/google/genai/models.py#L1169
config = CountTokensConfigDict(
http_options=http_options,
system_instruction=None,
tools=None,
)

response = await self.client.aio.models.count_tokens(
model=self._model_name,
contents=contents,
config=config,
)
return self._process_count_tokens_response(response)

Contributor: We should include not just the messages but the entire generateContentRequest, since function definitions etc. also count as tokens.

Contributor (Author): Makes sense. There are some issues with count_tokens: it raises an error when tools or a system instruction are passed for Gemini models (I've linked the relevant google-genai code above), which is why system_instruction and tools are set to None here. Can you share your views on how to go about this?

@asynccontextmanager
async def request_stream(
self,
@@ -338,6 +372,14 @@ async def _process_streamed_response(self, response: AsyncIterator[GenerateContentResponse])
_timestamp=first_chunk.create_time or _utils.now_utc(),
)

def _process_count_tokens_response(
self,
response: CountTokensResponse,
) -> BaseCountTokensResponse:
if not hasattr(response, 'total_tokens') or response.total_tokens is None:
raise UnexpectedModelBehavior('Total tokens missing from Gemini response', str(response))
return BaseCountTokensResponse(total_tokens=response.total_tokens)

async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict | None, list[ContentUnionDict]]:
contents: list[ContentUnionDict] = []
system_parts: list[PartDict] = []
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/usage.py
@@ -96,6 +96,10 @@ class UsageLimits:
"""The maximum number of tokens allowed in responses from the model."""
total_tokens_limit: int | None = None
"""The maximum number of tokens allowed in requests and responses combined."""
count_tokens_before_request: bool = False
"""If True, perform a token counting pass before sending the request to the model,
to enforce `request_tokens_limit` ahead of time. This may incur additional overhead
(from calling the model's `count_tokens` API before making the actual request) and is disabled by default."""

def has_token_limits(self) -> bool:
"""Returns `True` if this instance places any limits on token counts.
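Putting the pieces together from a user's perspective, enabling the flag makes the agent count tokens before each request; a rough sketch (the model string and limit are illustrative):

import asyncio

from pydantic_ai import Agent
from pydantic_ai.usage import UsageLimits

agent = Agent('google-gla:gemini-2.5-flash')

async def main():
    # If the pre-flight count already exceeds request_tokens_limit,
    # UsageLimitExceeded is raised before the generation request is sent.
    result = await agent.run(
        'Summarise this long document ...',
        usage_limits=UsageLimits(
            request_tokens_limit=1000,
            count_tokens_before_request=True,
        ),
    )
    print(result.output)

asyncio.run(main())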
@@ -0,0 +1,46 @@
interactions:
- request:
body: '{"contents": [{"parts": [{"text": "The quick brown fox jumps over the lazydog."}],
"role": "user"}]}'
headers:
Content-Type:
- application/json
user-agent:
- google-genai-sdk/1.26.0 gl-python/3.12.7
x-goog-api-client:
- google-genai-sdk/1.26.0 gl-python/3.12.7
method: post
uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:countTokens
response:
body:
string: "{\n \"totalTokens\": 12,\n \"promptTokensDetails\": [\n {\n \"modality\":
\"TEXT\",\n \"tokenCount\": 12\n }\n ]\n}\n"
headers:
Alt-Svc:
- h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
Content-Type:
- application/json; charset=UTF-8
Date:
- Fri, 01 Aug 2025 15:59:25 GMT
Server:
- scaffolding on HTTPServer2
Server-Timing:
- gfet4t7; dur=1582
Transfer-Encoding:
- chunked
Vary:
- Origin
- X-Origin
- Referer
X-Content-Type-Options:
- nosniff
X-Frame-Options:
- SAMEORIGIN
X-XSS-Protection:
- '0'
content-length:
- '117'
status:
code: 200
message: OK
version: 1
14 changes: 13 additions & 1 deletion tests/models/test_google.py
@@ -10,6 +10,7 @@
from pydantic import BaseModel
from typing_extensions import TypedDict

from pydantic_ai import UsageLimitExceeded
from pydantic_ai.agent import Agent
from pydantic_ai.exceptions import ModelRetry, UnexpectedModelBehavior, UserError
from pydantic_ai.messages import (
@@ -36,7 +37,7 @@
VideoUrl,
)
from pydantic_ai.output import NativeOutput, PromptedOutput, TextOutput, ToolOutput
from pydantic_ai.result import Usage, UsageLimits

from ..conftest import IsDatetime, IsInstance, IsStr, try_import

@@ -1393,3 +1394,14 @@ class CountryLanguage(BaseModel):
),
]
)


async def test_google_model_usage_limit_exceeded(allow_model_requests: None, google_provider: GoogleProvider):
model = GoogleModel('gemini-2.5-flash', provider=google_provider)
agent = Agent(model=model)

with pytest.raises(UsageLimitExceeded, match='Exceeded the request_tokens_limit of 9 \\(request_tokens=12\\)'):
await agent.run(
'The quick brown fox jumps over the lazydog.',
usage_limits=UsageLimits(request_tokens_limit=9, count_tokens_before_request=True),
)