Adding CountToken to Gemini #2137
@@ -14,6 +14,7 @@
 from .._output import OutputObjectDefinition
 from ..exceptions import UserError
 from ..messages import (
+    BaseCountTokensResponse,
     BinaryContent,
     FileUrl,
     ModelMessage,
@@ -48,6 +49,8 @@
 from google.genai.types import (
     ContentDict,
     ContentUnionDict,
+    CountTokensConfigDict,
+    CountTokensResponse,
     FunctionCallDict,
     FunctionCallingConfigDict,
     FunctionCallingConfigMode,
@@ -181,6 +184,37 @@ async def request(
         response = await self._generate_content(messages, False, model_settings, model_request_parameters)
         return self._process_response(response)

+    async def count_tokens(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> BaseCountTokensResponse:
+        check_allow_model_requests()
+        model_settings = cast(GoogleModelSettings, model_settings or {})
+        _, contents = await self._map_messages(messages)
+        http_options: HttpOptionsDict = {
+            'headers': {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
+        }
+        if timeout := model_settings.get('timeout'):
+            if isinstance(timeout, (int, float)):
+                http_options['timeout'] = int(1000 * timeout)
+            else:
+                raise UserError('Google does not support setting ModelSettings.timeout to a httpx.Timeout')
+        # The system_instruction and tools parameters are not supported by countTokens:
+        # https://github.com/googleapis/python-genai/blob/038ecd3375f7c63a8ee8c1afa50cff1976343625/google/genai/models.py#L1169
+        config = CountTokensConfigDict(
+            http_options=http_options,
+            system_instruction=None,
+            tools=None,
+        )
+
+        response = await self.client.aio.models.count_tokens(
+            model=self._model_name,
+            contents=contents,
+            config=config,
+        )
+        return self._process_count_tokens_response(response)
+
     @asynccontextmanager
     async def request_stream(
         self,

Review comment on `contents=contents,`:

    We should include not just the messages but the entire …

Author reply:

    Makes sense. There are some issues with count_tokens: it raises an error when tools or a system instruction are passed for Gemini models, which is why system_instruction and tools are set to None here (I added a link to the relevant google-genai code above). Can you share your views on how to go about this?
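For context on the discussion above, here is a minimal standalone sketch of the raw google-genai call that `count_tokens` wraps. It is illustrative only: the model name is an example, and it assumes an API key is available in the environment; per the linked source, passing `system_instruction` or `tools` in the config raises on the Gemini API, so neither is set here.

```python
# Minimal sketch of the raw countTokens call wrapped by GoogleModel.count_tokens.
# Assumes GEMINI_API_KEY (or GOOGLE_API_KEY) is set in the environment.
from google import genai

client = genai.Client()  # picks up the API key from the environment
response = client.models.count_tokens(
    model='gemini-2.5-flash',  # example model name
    contents='The quick brown fox jumps over the lazy dog.',
)
print(response.total_tokens)  # e.g. 12, matching the cassette below
```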
@@ -338,6 +372,14 @@ async def _process_streamed_response(self, response: AsyncIterator[GenerateContentResponse]):
             _timestamp=first_chunk.create_time or _utils.now_utc(),
         )

+    def _process_count_tokens_response(
+        self,
+        response: CountTokensResponse,
+    ) -> BaseCountTokensResponse:
+        if not hasattr(response, 'total_tokens') or response.total_tokens is None:
+            raise UnexpectedModelBehavior('Total tokens missing from Gemini response', str(response))
+        return BaseCountTokensResponse(total_tokens=response.total_tokens)
+
     async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict | None, list[ContentUnionDict]]:
         contents: list[ContentUnionDict] = []
         system_parts: list[PartDict] = []
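The mapping is deliberately defensive, since `CountTokensResponse.total_tokens` is optional in google-genai. A hedged standalone sketch of the same guard, using a stand-in dataclass in place of pydantic-ai's `BaseCountTokensResponse` (the real class lives in `pydantic_ai.messages`):

```python
from dataclasses import dataclass

from google.genai.types import CountTokensResponse


@dataclass
class StubCountTokensResponse:
    """Stand-in for pydantic_ai's BaseCountTokensResponse (illustrative only)."""
    total_tokens: int


def process(response: CountTokensResponse) -> StubCountTokensResponse:
    # total_tokens is Optional on the SDK model, so None must be rejected
    # before it can propagate into usage accounting.
    if response.total_tokens is None:
        raise RuntimeError(f'Total tokens missing from Gemini response: {response}')
    return StubCountTokensResponse(total_tokens=response.total_tokens)


print(process(CountTokensResponse(total_tokens=12)).total_tokens)  # -> 12
```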
New file (VCR cassette for the countTokens request):

@@ -0,0 +1,46 @@
interactions:
- request:
    body: '{"contents": [{"parts": [{"text": "The quick brown fox jumps over the lazy dog."}], "role": "user"}]}'
    headers:
      Content-Type:
      - application/json
      user-agent:
      - google-genai-sdk/1.26.0 gl-python/3.12.7
      x-goog-api-client:
      - google-genai-sdk/1.26.0 gl-python/3.12.7
    method: post
    uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:countTokens
  response:
    body:
      string: "{\n  \"totalTokens\": 12,\n  \"promptTokensDetails\": [\n    {\n      \"modality\": \"TEXT\",\n      \"tokenCount\": 12\n    }\n  ]\n}\n"
    headers:
      Alt-Svc:
      - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
      Content-Type:
      - application/json; charset=UTF-8
      Date:
      - Fri, 01 Aug 2025 15:59:25 GMT
      Server:
      - scaffolding on HTTPServer2
      Server-Timing:
      - gfet4t7; dur=1582
      Transfer-Encoding:
      - chunked
      Vary:
      - Origin
      - X-Origin
      - Referer
      X-Content-Type-Options:
      - nosniff
      X-Frame-Options:
      - SAMEORIGIN
      X-XSS-Protection:
      - '0'
      content-length:
      - '117'
    status:
      code: 200
      message: OK
version: 1
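A cassette like this would be exercised by a replay test along the following lines. This is a hypothetical sketch: the marker and fixture names follow pytest-recording conventions and are assumptions, not pydantic-ai's actual test setup.

```python
# Hypothetical replay test for the cassette above; marker names are
# assumptions (pytest-recording style), not pydantic-ai's actual wiring.
import pytest
from google import genai


@pytest.mark.vcr()  # replay the recorded countTokens interaction
@pytest.mark.anyio  # assumes an async pytest plugin such as anyio
async def test_count_tokens_replay():
    client = genai.Client(api_key='test-key')  # key is never sent during replay
    response = await client.aio.models.count_tokens(
        model='gemini-2.5-flash',
        contents='The quick brown fox jumps over the lazy dog.',
    )
    assert response.total_tokens == 12
```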