Commit 986ba5b

Fix Gemini usage metadata handling
1 parent 0d98d7a commit 986ba5b

File tree

5 files changed: +759 -93 lines changed

libs/genai/langchain_google_genai/chat_models.py

Lines changed: 248 additions & 56 deletions
@@ -6,6 +6,7 @@
 import json
 import logging
 import mimetypes
+import re
 import time
 import uuid
 import warnings
@@ -74,7 +75,9 @@
     ToolMessage,
     is_data_content_block,
 )
-from langchain_core.messages.ai import UsageMetadata, add_usage, subtract_usage
+from langchain_core.messages.ai import (
+    UsageMetadata,
+)
 from langchain_core.messages.tool import invalid_tool_call, tool_call, tool_call_chunk
 from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
 from langchain_core.output_parsers.base import OutputParserLike
@@ -798,6 +801,228 @@ def _extract_grounding_metadata(candidate: Any) -> Dict[str, Any]:
     return result


+def _sanitize_token_detail_key(raw_key: str) -> str:
+    """Convert provider detail labels into snake_case keys."""
+    sanitized = re.sub(r"[^0-9a-zA-Z]+", "_", raw_key.strip().lower()).strip("_")
+    return sanitized or "unknown"
+
+
+def _extract_token_detail_counts(
+    entries: Sequence[Mapping[str, Any]] | None,
+    *,
+    prefix: str | None = None,
+) -> dict[str, int]:
+    """Convert modality/token entries into a token detail mapping."""
+    if not entries:
+        return {}
+    detail_counts: dict[str, int] = {}
+    for entry in entries:
+        if not isinstance(entry, Mapping):
+            continue
+        raw_key = entry.get("modality") or entry.get("type") or entry.get("name")
+        if not raw_key:
+            continue
+        raw_value = (
+            entry.get("token_count")
+            or entry.get("tokenCount")
+            or entry.get("tokens_count")
+            or entry.get("tokensCount")
+            or entry.get("count")
+        )
+        try:
+            value_int = int(raw_value or 0)
+        except (TypeError, ValueError):
+            value_int = 0
+        if value_int == 0:
+            continue
+        key = _sanitize_token_detail_key(str(raw_key))
+        if prefix:
+            key = f"{prefix}{key}"
+        detail_counts[key] = detail_counts.get(key, 0) + value_int
+    return detail_counts
+
+
+def _merge_detail_counts(target: dict[str, int], new_entries: dict[str, int]) -> None:
+    """Accumulate modality detail counts into the provided target mapping."""
+    for key, value in new_entries.items():
+        target[key] = target.get(key, 0) + value
+
+
+def _usage_proto_to_dict(raw_usage: Any) -> dict[str, Any]:
+    """Coerce proto UsageMetadata (or dict) into a plain dictionary."""
+    if raw_usage is None:
+        return {}
+    if isinstance(raw_usage, Mapping):
+        return dict(raw_usage)
+    try:
+        return proto.Message.to_dict(raw_usage)
+    except Exception:  # pragma: no cover - best effort fallback
+        try:
+            return dict(raw_usage)
+        except Exception:  # pragma: no cover - final fallback
+            return {}
+
+
+def _coerce_usage_metadata(raw_usage: Any) -> Optional[UsageMetadata]:
+    """Normalize Gemini usage metadata into LangChain's UsageMetadata."""
+    usage_dict = _usage_proto_to_dict(raw_usage)
+    if not usage_dict:
+        return None
+
+    def _get_int(name: str) -> int:
+        value = usage_dict.get(name)
+        try:
+            return int(value or 0)
+        except (TypeError, ValueError):
+            return 0
+
+    prompt_tokens = _get_int("prompt_token_count")
+    response_tokens = (
+        _get_int("candidates_token_count")
+        or _get_int("response_token_count")
+        or _get_int("output_token_count")
+    )
+    tool_prompt_tokens = _get_int("tool_use_prompt_token_count")
+    reasoning_tokens = _get_int("thoughts_token_count") or _get_int(
+        "reasoning_token_count"
+    )
+    cache_read_tokens = _get_int("cached_content_token_count")
+
+    input_tokens = prompt_tokens + tool_prompt_tokens
+    output_tokens = response_tokens + reasoning_tokens
+    total_tokens = _get_int("total_token_count") or _get_int("total_tokens")
+    if total_tokens == 0:
+        total_tokens = input_tokens + output_tokens
+    if total_tokens != input_tokens + output_tokens:
+        total_tokens = input_tokens + output_tokens
+
+    input_details: dict[str, int] = {}
+    if cache_read_tokens:
+        input_details["cache_read"] = cache_read_tokens
+    if tool_prompt_tokens:
+        input_details["tool_use_prompt"] = tool_prompt_tokens
+    _merge_detail_counts(
+        input_details,
+        _extract_token_detail_counts(
+            usage_dict.get("prompt_tokens_details")
+            or usage_dict.get("promptTokensDetails"),
+        ),
+    )
+    _merge_detail_counts(
+        input_details,
+        _extract_token_detail_counts(
+            usage_dict.get("tool_use_prompt_tokens_details")
+            or usage_dict.get("toolUsePromptTokensDetails"),
+            prefix="tool_use_prompt_",
+        ),
+    )
+    _merge_detail_counts(
+        input_details,
+        _extract_token_detail_counts(
+            usage_dict.get("cache_tokens_details")
+            or usage_dict.get("cacheTokensDetails"),
+            prefix="cache_",
+        ),
+    )
+
+    output_details: dict[str, int] = {}
+    if reasoning_tokens:
+        output_details["reasoning"] = reasoning_tokens
+    for key in (
+        "candidates_tokens_details",
+        "candidatesTokensDetails",
+        "response_tokens_details",
+        "responseTokensDetails",
+        "output_tokens_details",
+        "outputTokensDetails",
+        "total_tokens_details",
+        "totalTokensDetails",
+    ):
+        _merge_detail_counts(
+            output_details, _extract_token_detail_counts(usage_dict.get(key))
+        )
+
+    usage_payload: dict[str, Any] = {
+        "input_tokens": input_tokens,
+        "output_tokens": output_tokens,
+        "total_tokens": total_tokens,
+    }
+    if input_details:
+        usage_payload["input_token_details"] = cast(Any, input_details)
+    if output_details:
+        usage_payload["output_token_details"] = cast(Any, output_details)
+
+    return cast(UsageMetadata, usage_payload)
+
+
+def _diff_token_details(
+    current: Mapping[str, Any] | None,
+    previous: Mapping[str, Any] | None,
+) -> dict[str, int]:
+    """Compute detail deltas between cumulative usage payloads."""
+    if not current and not previous:
+        return {}
+    current = current or {}
+    previous = previous or {}
+    diff: dict[str, int] = {}
+    for key in set(current).union(previous):
+        current_value = current.get(key, 0)
+        previous_value = previous.get(key, 0)
+        if isinstance(current_value, Mapping) or isinstance(previous_value, Mapping):
+            nested = _diff_token_details(
+                current_value if isinstance(current_value, Mapping) else None,
+                previous_value if isinstance(previous_value, Mapping) else None,
+            )
+            if nested:
+                diff[key] = nested  # type: ignore[assignment]
+            continue
+        try:
+            current_int = int(current_value or 0)
+        except (TypeError, ValueError):
+            current_int = 0
+        try:
+            previous_int = int(previous_value or 0)
+        except (TypeError, ValueError):
+            previous_int = 0
+        delta = current_int - previous_int
+        if delta != 0:
+            diff[key] = delta
+    return diff
+
+
+def _diff_usage_metadata(
+    current: UsageMetadata, previous: UsageMetadata
+) -> UsageMetadata:
+    """Return chunk-level usage delta between cumulative UsageMetadata values."""
+
+    input_delta = current.get("input_tokens", 0) - previous.get("input_tokens", 0)
+    output_delta = current.get("output_tokens", 0) - previous.get("output_tokens", 0)
+    total_delta = current.get("total_tokens", 0) - previous.get("total_tokens", 0)
+    expected_total = input_delta + output_delta
+    if total_delta != expected_total:
+        total_delta = expected_total
+
+    diff_payload: dict[str, Any] = {
+        "input_tokens": input_delta,
+        "output_tokens": output_delta,
+        "total_tokens": total_delta,
+    }
+
+    input_detail_delta = _diff_token_details(
+        current.get("input_token_details"), previous.get("input_token_details")
+    )
+    if input_detail_delta:
+        diff_payload["input_token_details"] = cast(Any, input_detail_delta)
+
+    output_detail_delta = _diff_token_details(
+        current.get("output_token_details"), previous.get("output_token_details")
+    )
+    if output_detail_delta:
+        diff_payload["output_token_details"] = cast(Any, output_detail_delta)
+
+    return cast(UsageMetadata, diff_payload)
+
+
 def _response_to_result(
     response: GenerateContentResponse,
     stream: bool = False,
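
Editor's illustration (not part of the commit): a minimal sketch of how _coerce_usage_metadata normalizes a dict-shaped usage payload. The input dict below is invented; a Mapping input skips the proto conversion, and the expected values follow the merging rules shown above.

raw_usage = {
    "prompt_token_count": 10,
    "candidates_token_count": 5,
    "thoughts_token_count": 3,
    "cached_content_token_count": 4,
    "prompt_tokens_details": [{"modality": "TEXT", "token_count": 10}],
}
usage = _coerce_usage_metadata(raw_usage)
# Expected shape under the rules above:
# {
#     "input_tokens": 10,                      # prompt tokens; no tool-use prompt tokens here
#     "output_tokens": 8,                       # candidates (5) + thoughts (3)
#     "total_tokens": 18,                       # recomputed as input + output when no total is reported
#     "input_token_details": {"cache_read": 4, "text": 10},
#     "output_token_details": {"reasoning": 3},
# }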
@@ -806,47 +1031,16 @@ def _response_to_result(
     """Converts a PaLM API response into a LangChain ChatResult."""
     llm_output = {"prompt_feedback": proto.Message.to_dict(response.prompt_feedback)}

-    # Get usage metadata
-    try:
-        input_tokens = response.usage_metadata.prompt_token_count
-        thought_tokens = response.usage_metadata.thoughts_token_count
-        output_tokens = response.usage_metadata.candidates_token_count + thought_tokens
-        total_tokens = response.usage_metadata.total_token_count
-        cache_read_tokens = response.usage_metadata.cached_content_token_count
-        if input_tokens + output_tokens + cache_read_tokens + total_tokens > 0:
-            if thought_tokens > 0:
-                cumulative_usage = UsageMetadata(
-                    input_tokens=input_tokens,
-                    output_tokens=output_tokens,
-                    total_tokens=total_tokens,
-                    input_token_details={"cache_read": cache_read_tokens},
-                    output_token_details={"reasoning": thought_tokens},
-                )
-            else:
-                cumulative_usage = UsageMetadata(
-                    input_tokens=input_tokens,
-                    output_tokens=output_tokens,
-                    total_tokens=total_tokens,
-                    input_token_details={"cache_read": cache_read_tokens},
-                )
-            # previous usage metadata needs to be subtracted because gemini api returns
-            # already-accumulated token counts with each chunk
-            lc_usage = subtract_usage(cumulative_usage, prev_usage)
-            if prev_usage and cumulative_usage["input_tokens"] < prev_usage.get(
-                "input_tokens", 0
-            ):
-                # Gemini 1.5 and 2.0 return a lower cumulative count of prompt tokens
-                # in the final chunk. We take this count to be ground truth because
-                # it's consistent with the reported total tokens. So we need to
-                # ensure this chunk compensates (the subtract_usage funcction floors
-                # at zero).
-                lc_usage["input_tokens"] = cumulative_usage[
-                    "input_tokens"
-                ] - prev_usage.get("input_tokens", 0)
-        else:
-            lc_usage = None
-    except AttributeError:
-        lc_usage = None
+    cumulative_usage = _coerce_usage_metadata(response.usage_metadata)
+    if cumulative_usage:
+        llm_output["usage_metadata"] = cumulative_usage
+
+    if stream and cumulative_usage and prev_usage:
+        lc_usage: Optional[UsageMetadata] = _diff_usage_metadata(
+            cumulative_usage, prev_usage
+        )
+    else:
+        lc_usage = cumulative_usage

     generations: List[ChatGeneration] = []
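
Editor's illustration (not part of the diff): with stream=True the branch above hands the previous and current cumulative snapshots to _diff_usage_metadata, so each chunk reports only its increment. A hedged sketch with two invented cumulative snapshots:

first = cast(UsageMetadata, {"input_tokens": 10, "output_tokens": 4, "total_tokens": 14})
second = cast(
    UsageMetadata,
    {
        "input_tokens": 10,
        "output_tokens": 9,
        "total_tokens": 19,
        "output_token_details": {"reasoning": 2},
    },
)
_diff_usage_metadata(second, first)
# -> {"input_tokens": 0, "output_tokens": 5, "total_tokens": 5,
#     "output_token_details": {"reasoning": 2}}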

@@ -1961,19 +2155,18 @@ def _stream(
                 metadata=self.default_metadata,
             )

-            prev_usage_metadata: UsageMetadata | None = None  # cumulative usage
+            prev_usage_metadata: UsageMetadata | None = None
             for chunk in response:
                 _chat_result = _response_to_result(
                     chunk, stream=True, prev_usage=prev_usage_metadata
                 )
                 gen = cast("ChatGenerationChunk", _chat_result.generations[0])
-                message = cast("AIMessageChunk", gen.message)
-
-                prev_usage_metadata = (
-                    message.usage_metadata
-                    if prev_usage_metadata is None
-                    else add_usage(prev_usage_metadata, message.usage_metadata)
+                llm_output = _chat_result.llm_output or {}
+                cumulative_usage = cast(
+                    Optional[UsageMetadata], llm_output.get("usage_metadata")
                 )
+                if cumulative_usage is not None:
+                    prev_usage_metadata = cumulative_usage

                 if run_manager:
                     run_manager.on_llm_new_token(gen.text, chunk=gen)
@@ -2024,7 +2217,7 @@ async def _astream(
             kwargs["timeout"] = self.timeout
         if "max_retries" not in kwargs:
             kwargs["max_retries"] = self.max_retries
-        prev_usage_metadata: UsageMetadata | None = None  # cumulative usage
+        prev_usage_metadata: UsageMetadata | None = None
         async for chunk in await _achat_with_retry(
             request=request,
             generation_method=self.async_client.stream_generate_content,
@@ -2035,13 +2228,12 @@
                 chunk, stream=True, prev_usage=prev_usage_metadata
             )
             gen = cast("ChatGenerationChunk", _chat_result.generations[0])
-            message = cast("AIMessageChunk", gen.message)
-
-            prev_usage_metadata = (
-                message.usage_metadata
-                if prev_usage_metadata is None
-                else add_usage(prev_usage_metadata, message.usage_metadata)
+            llm_output = _chat_result.llm_output or {}
+            cumulative_usage = cast(
+                Optional[UsageMetadata], llm_output.get("usage_metadata")
             )
+            if cumulative_usage is not None:
+                prev_usage_metadata = cumulative_usage

             if run_manager:
                 await run_manager.on_llm_new_token(gen.text, chunk=gen)
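
Editor's illustration (not part of the commit): the sync and async loops above no longer re-accumulate usage with add_usage; they simply remember the latest cumulative snapshot exposed via llm_output. A hedged sketch, with invented snapshots, of why the emitted per-chunk deltas still sum to the provider's final totals:

snapshots = [  # cumulative usage as reported by successive Gemini chunks (made up)
    cast(UsageMetadata, {"input_tokens": 12, "output_tokens": 3, "total_tokens": 15}),
    cast(UsageMetadata, {"input_tokens": 12, "output_tokens": 7, "total_tokens": 19}),
    cast(UsageMetadata, {"input_tokens": 12, "output_tokens": 11, "total_tokens": 23}),
]
prev: UsageMetadata | None = None
per_chunk = []
for snapshot in snapshots:
    # First chunk carries the full cumulative count; later chunks carry only the delta.
    per_chunk.append(_diff_usage_metadata(snapshot, prev) if prev else snapshot)
    prev = snapshot
assert sum(u["output_tokens"] for u in per_chunk) == snapshots[-1]["output_tokens"]
assert sum(u["input_tokens"] for u in per_chunk) == snapshots[-1]["input_tokens"]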
