Adding CountToken to Gemini #2137
base: main
Conversation
Gemini provides an endpoint to count tokens before sending a request: https://ai.google.dev/api/tokens#method:-models.counttokens
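For reference, this is what the endpoint looks like through the google-genai SDK (a minimal sketch; the model name is illustrative):

```python
from google import genai

client = genai.Client()  # reads the API key from the environment

# countTokens is a separate endpoint: it counts the prompt's tokens
# without generating a response.
response = client.models.count_tokens(
    model='gemini-2.0-flash',  # illustrative model name
    contents='Why is the sky blue?',
)
print(response.total_tokens)
```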
added type adaptor
Removed extra assignment
Linting
Linting
Linting
Removed White Space
@kauabh I agree that if a model API has a method to count tokens, it would be nice to expose that on the model class. But I don't think we should automatically use it whenever usage limits are set; that check would need to be implemented here, just before we call the model:

pydantic-ai/pydantic_ai_slim/pydantic_ai/_agent_graph.py, lines 379 to 393 in b31c77d

This would require a method that exists on every model, so it'd be implemented as an abstract method on the base model class. As for the concrete implementation, I recommend adding it to the Google model.
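A rough sketch of the abstract method being suggested, assuming the `BaseCountTokensResponse` type this PR introduces:

```python
from abc import ABC, abstractmethod

from pydantic_ai.messages import ModelMessage


class Model(ABC):
    """Existing abstract base class for models; only the new method is shown."""

    @abstractmethod
    async def count_tokens(
        self,
        messages: list[ModelMessage],
    ) -> BaseCountTokensResponse:  # response type added in this PR
        """Count the tokens the given messages would use, without generating a response."""
        raise NotImplementedError()
```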
@DouweM Makes sense, let me rework this. Thanks for the detailed input, I appreciate your time.
* adding count token for google * resolved conflicts
Hey @DouweM, I have made the changes as per your comments; quite a few files got touched, so it would be great if you could provide some feedback on the changes so far. Also, could you share some thoughts on updating "instrumented.py" with count_tokens?
@kauabh Thanks! We're almost there :)
```
@@ -382,6 +390,14 @@ async def request(
        """Make a request to the model."""
        raise NotImplementedError()

    @abstractmethod
```
Since this method is not required to be implemented, we can do the same thing we do in `request_stream`: not mark it as `@abstractmethod` (so we can drop all the empty implementations from the model classes) and put the `Token counting is not supported by <X>` error message here.
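Following the `request_stream` pattern described above, the non-abstract default might look like this (a sketch):

```python
async def count_tokens(
    self,
    messages: list[ModelMessage],
) -> BaseCountTokensResponse:
    """Count the tokens a request would use; overridden by providers that support it."""
    raise NotImplementedError(f'Token counting is not supported by {self.__class__.__name__}')
```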
```
        self,
        messages: list[ModelMessage],
    ) -> BaseCountTokensResponse:
        """Make a request to the model."""
```
This needs an update!
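Perhaps something along these lines (the wording is only a suggestion):

```python
) -> BaseCountTokensResponse:
    """Count the number of tokens the given messages would use, without generating a response."""
```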
```
@@ -77,6 +77,13 @@ async def request(

        raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions)

    async def count_tokens(
```
We'll want to forward this to the models in question.
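A sketch of what that forwarding could look like, mirroring how `request` collects exceptions (assuming FallbackModel's existing `self.models` list and the `FallbackExceptionGroup` shown above):

```python
async def count_tokens(self, messages: list[ModelMessage]) -> BaseCountTokensResponse:
    """Try each model in turn and return the first successful token count."""
    exceptions: list[Exception] = []
    for model in self.models:
        try:
            return await model.count_tokens(messages)
        except Exception as exc:  # includes NotImplementedError from models without support
            exceptions.append(exc)
    raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions)
```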
```
        messages: list[ModelMessage],
    ) -> BaseCountTokensResponse:
        """Token counting is not supported by the CohereModel."""
        raise NotImplementedError('Token counting is not supported by CohereModel')
```
Wrong model name, but this should be fixed anyway by making the super method non-abstract and dropping the definition here.
```
        _, contents = await self._map_messages(messages)
        response = self.client.models.count_tokens(
            model=self._model_name,
            contents=contents,
```
We should include not just the messages but the entire `generateContentRequest`, as function definitions etc. also count as tokens.
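The REST endpoint accepts a full `generateContentRequest`; in the Python SDK that roughly corresponds to passing a `CountTokensConfig`. A sketch (field support varies between the Gemini API and Vertex AI backends, so treat this as an assumption to verify):

```python
from google.genai import types

# Count tokens for the whole request: the same system instruction and tool
# definitions we would pass to generate_content, not just the message contents.
response = self.client.models.count_tokens(
    model=self._model_name,
    contents=contents,
    config=types.CountTokensConfig(
        system_instruction=system_instruction,
        tools=tools,
    ),
)
```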
```
@dataclass(repr=False)
class BaseCountTokensResponse:
```
Do we need this entire response? Or could `count_tokens` just return the total tokens counted, or even better, a `Usage` object we can use with `Usage.incr`?
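For illustration, mapping the provider's count to the existing `Usage` type might look like this (a sketch; `_map_count_tokens_response` is a hypothetical helper name):

```python
from pydantic_ai.usage import Usage


def _map_count_tokens_response(total_tokens: int) -> Usage:
    # Only the prompt side is known before the request, so record the
    # counted tokens as request tokens.
    return Usage(request_tokens=total_tokens, total_tokens=total_tokens)
```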
```
        if ctx.deps.usage_limits and ctx.deps.usage_limits.pre_request_token_check_with_overhead:
            token_count = await ctx.deps.model.count_tokens(message_history)

            ctx.deps.usage_limits.check_tokens(_usage.Usage(request_tokens=token_count.total_tokens))
```
I think we should use `check_before_request` instead of `check_tokens`.

Also, I think this should live in `prepare_request`, where we currently call `check_before_request`. I'd likely move the `model_request_parameters` and `message_history` stuff there as well, and reduce the duplication between this method and `_stream`.
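A sketch of where that could sit, using the flag name proposed later in this review (the exact structure of `_agent_graph` may differ):

```python
# Inside the request-preparation step, where check_before_request is called today:
if ctx.deps.usage_limits:
    ctx.deps.usage_limits.check_before_request(ctx.state.usage)
    if ctx.deps.usage_limits.count_tokens_before_request:
        token_count = await ctx.deps.model.count_tokens(message_history)
        # See the copy-and-increment check sketched below.
```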
Also, instead of checking whether this particular request's total tokens exceed the limit, shouldn't we check all the tokens so far plus the newly counted tokens? That'd be consistent with what `_finish_handling` currently does:

```python
ctx.state.usage.incr(response.usage)
if ctx.deps.usage_limits:  # pragma: no branch
    ctx.deps.usage_limits.check_tokens(ctx.state.usage)
```

So we'd want to copy `ctx.state.usage`, call `incr` with the new usage, and then run the check against that.
```
@@ -96,6 +96,10 @@ class UsageLimits:
    """The maximum number of tokens allowed in responses from the model."""
    total_tokens_limit: int | None = None
    """The maximum number of tokens allowed in requests and responses combined."""
    pre_request_token_check_with_overhead: bool = False
```
Let's call it `count_tokens_before_request`, and clarify the description slightly to say this typically requires an extra API call.
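Something like this, with the docstring wording as a suggestion:

```python
count_tokens_before_request: bool = False
"""If `True`, check token limits before each model request by counting the request's tokens first.

This typically requires an extra API call to the provider's token-counting endpoint.
"""
```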
```
        ),
    ]
    result = await model.count_tokens(messages)
    assert result.total_tokens == snapshot(7)
```
We'll want to test not just the `count_tokens` method but the actual usage limit enforcement!
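For example, a test roughly along these lines (a sketch: the `google_model` fixture, prompt, and limit values are illustrative, and the async test setup should follow the existing suite):

```python
import pytest

from pydantic_ai import Agent
from pydantic_ai.exceptions import UsageLimitExceeded
from pydantic_ai.usage import UsageLimits


async def test_count_tokens_before_request_enforces_limit(google_model):
    agent = Agent(google_model)
    # The pre-request token count should exceed the limit and abort before generation.
    with pytest.raises(UsageLimitExceeded):
        await agent.run(
            'A prompt that is longer than five tokens',
            usage_limits=UsageLimits(total_tokens_limit=5, count_tokens_before_request=True),
        )
```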
Gemini provides an endpoint to count tokens: https://ai.google.dev/api/tokens#method:-models.counttokens.

I think it will be useful and addresses some of the concerns in issue #1794 (at least for Gemini).

@DouweM Wanted to check whether this would be helpful. If yes, and if the approach is right, I'd appreciate some pointers around adding it to usage_limits for Gemini. Happy to work on other models too, if this one makes it through.