310 changes: 254 additions & 56 deletions libs/genai/langchain_google_genai/chat_models.py
@@ -6,6 +6,7 @@
import json
import logging
import mimetypes
import re
import time
import uuid
import warnings
@@ -74,7 +75,9 @@
ToolMessage,
is_data_content_block,
)
from langchain_core.messages.ai import UsageMetadata, add_usage, subtract_usage
from langchain_core.messages.ai import (
UsageMetadata,
)
from langchain_core.messages.tool import invalid_tool_call, tool_call, tool_call_chunk
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.output_parsers.base import OutputParserLike
@@ -798,6 +801,234 @@ def _extract_grounding_metadata(candidate: Any) -> Dict[str, Any]:
return result


def _sanitize_token_detail_key(raw_key: str) -> str:
"""Convert provider detail labels into snake_case keys."""
sanitized = re.sub(r"[^0-9a-zA-Z]+", "_", raw_key.strip().lower()).strip("_")
return sanitized or "unknown"
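
# Hypothetical usage (illustrative inputs, not part of the diff): arbitrary
# provider labels are normalized into snake_case detail keys.
#   >>> _sanitize_token_detail_key("Cache Tokens (Audio)")
#   'cache_tokens_audio'
#   >>> _sanitize_token_detail_key("   ")
#   'unknown'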


def _extract_token_detail_counts(
entries: Sequence[Mapping[str, Any]] | None,
*,
prefix: str | None = None,
) -> dict[str, int]:
"""Convert modality/token entries into a token detail mapping."""
if not entries:
return {}
detail_counts: dict[str, int] = {}
for entry in entries:
if not isinstance(entry, Mapping):
continue
raw_key = entry.get("modality") or entry.get("type") or entry.get("name")
if not raw_key:
continue
raw_value = (
entry.get("token_count")
or entry.get("tokenCount")
or entry.get("tokens_count")
or entry.get("tokensCount")
or entry.get("count")
)
try:
value_int = int(raw_value or 0)
except (TypeError, ValueError):
value_int = 0
if value_int == 0:
continue
key = _sanitize_token_detail_key(str(raw_key))
if prefix:
key = f"{prefix}{key}"
detail_counts[key] = detail_counts.get(key, 0) + value_int
return detail_counts
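
# Hypothetical usage (illustrative entries, not part of the diff): modality
# entries collapse into a {detail_key: count} mapping, with an optional prefix
# applied and zero counts dropped.
#   >>> _extract_token_detail_counts(
#   ...     [{"modality": "AUDIO", "token_count": 10},
#   ...      {"modality": "TEXT", "tokenCount": 5},
#   ...      {"modality": "IMAGE", "token_count": 0}],
#   ...     prefix="cache_",
#   ... )
#   {'cache_audio': 10, 'cache_text': 5}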


def _merge_detail_counts(target: dict[str, int], new_entries: dict[str, int]) -> None:
"""Accumulate modality detail counts into the provided target mapping."""
for key, value in new_entries.items():
target[key] = target.get(key, 0) + value


def _usage_proto_to_dict(raw_usage: Any) -> dict[str, Any]:
"""Coerce proto UsageMetadata (or dict) into a plain dictionary."""
if raw_usage is None:
return {}
if isinstance(raw_usage, Mapping):
return dict(raw_usage)
try:
return proto.Message.to_dict(raw_usage)
except Exception: # pragma: no cover - best effort fallback
try:
return dict(raw_usage)
except Exception: # pragma: no cover - final fallback
return {}


def _coerce_usage_metadata(raw_usage: Any) -> Optional[UsageMetadata]:
"""Normalize Gemini usage metadata into LangChain's UsageMetadata."""
usage_dict = _usage_proto_to_dict(raw_usage)
if not usage_dict:
return None

def _get_int(name: str) -> int:
value = usage_dict.get(name)
try:
return int(value or 0)
except (TypeError, ValueError):
return 0

prompt_tokens = _get_int("prompt_token_count")
response_tokens = (
_get_int("candidates_token_count")
or _get_int("response_token_count")
or _get_int("output_token_count")
)
tool_prompt_tokens = _get_int("tool_use_prompt_token_count")
reasoning_tokens = _get_int("thoughts_token_count") or _get_int(
"reasoning_token_count"
)
cache_read_tokens = _get_int("cached_content_token_count")

input_tokens = prompt_tokens + tool_prompt_tokens
output_tokens = response_tokens + reasoning_tokens
total_tokens = _get_int("total_token_count") or _get_int("total_tokens")
if total_tokens == 0:
total_tokens = input_tokens + output_tokens
if total_tokens != input_tokens + output_tokens:
total_tokens = input_tokens + output_tokens

input_details: dict[str, int] = {}
if cache_read_tokens:
input_details["cache_read"] = cache_read_tokens
if tool_prompt_tokens:
input_details["tool_use_prompt"] = tool_prompt_tokens
_merge_detail_counts(
input_details,
_extract_token_detail_counts(
usage_dict.get("prompt_tokens_details")
or usage_dict.get("promptTokensDetails"),
),
)
_merge_detail_counts(
input_details,
_extract_token_detail_counts(
usage_dict.get("tool_use_prompt_tokens_details")
or usage_dict.get("toolUsePromptTokensDetails"),
prefix="tool_use_prompt_",
),
)
_merge_detail_counts(
input_details,
_extract_token_detail_counts(
usage_dict.get("cache_tokens_details")
or usage_dict.get("cacheTokensDetails"),
prefix="cache_",
),
)

output_details: dict[str, int] = {}
if reasoning_tokens:
output_details["reasoning"] = reasoning_tokens
for key in (
"candidates_tokens_details",
"candidatesTokensDetails",
"response_tokens_details",
"responseTokensDetails",
"output_tokens_details",
"outputTokensDetails",
"total_tokens_details",
"totalTokensDetails",
):
_merge_detail_counts(
output_details, _extract_token_detail_counts(usage_dict.get(key))
)

for alt_key in ("thought", "thoughts", "reasoning_tokens"):
if alt_key in output_details:
output_details["reasoning"] = output_details.get(
"reasoning", 0
) + output_details.pop(alt_key)

usage_payload: dict[str, Any] = {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": total_tokens,
}
if input_details:
usage_payload["input_token_details"] = cast(Any, input_details)
if output_details:
usage_payload["output_token_details"] = cast(Any, output_details)

return cast(UsageMetadata, usage_payload)
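
# Hypothetical usage (assumed raw payload, not part of the diff): a cumulative
# Gemini usage dict is folded into LangChain's UsageMetadata shape, with cache
# and reasoning counts surfaced as token details. Output wrapped here for
# readability.
#   >>> _coerce_usage_metadata({
#   ...     "prompt_token_count": 10,
#   ...     "candidates_token_count": 5,
#   ...     "thoughts_token_count": 2,
#   ...     "cached_content_token_count": 3,
#   ...     "total_token_count": 17,
#   ... })
#   {'input_tokens': 10, 'output_tokens': 7, 'total_tokens': 17,
#    'input_token_details': {'cache_read': 3},
#    'output_token_details': {'reasoning': 2}}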


def _diff_token_details(
current: Mapping[str, Any] | None,
previous: Mapping[str, Any] | None,
) -> dict[str, int]:
"""Compute detail deltas between cumulative usage payloads."""
if not current and not previous:
return {}
current = current or {}
previous = previous or {}
diff: dict[str, int] = {}
for key in set(current).union(previous):
current_value = current.get(key, 0)
previous_value = previous.get(key, 0)
if isinstance(current_value, Mapping) or isinstance(previous_value, Mapping):
nested = _diff_token_details(
current_value if isinstance(current_value, Mapping) else None,
previous_value if isinstance(previous_value, Mapping) else None,
)
if nested:
diff[key] = nested # type: ignore[assignment]
continue
try:
current_int = int(current_value or 0)
except (TypeError, ValueError):
current_int = 0
try:
previous_int = int(previous_value or 0)
except (TypeError, ValueError):
previous_int = 0
delta = current_int - previous_int
if delta != 0:
diff[key] = delta
return diff
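
# Hypothetical usage (illustrative mappings, not part of the diff): only
# non-zero deltas between two cumulative detail mappings are kept.
#   >>> _diff_token_details({"text": 30, "cache_read": 5},
#   ...                     {"text": 12, "cache_read": 5})
#   {'text': 18}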


def _diff_usage_metadata(
current: UsageMetadata, previous: UsageMetadata
) -> UsageMetadata:
"""Return chunk-level usage delta between cumulative UsageMetadata values."""

input_delta = current.get("input_tokens", 0) - previous.get("input_tokens", 0)
output_delta = current.get("output_tokens", 0) - previous.get("output_tokens", 0)
total_delta = current.get("total_tokens", 0) - previous.get("total_tokens", 0)
expected_total = input_delta + output_delta
if total_delta != expected_total:
total_delta = expected_total

diff_payload: dict[str, Any] = {
"input_tokens": input_delta,
"output_tokens": output_delta,
"total_tokens": total_delta,
}

input_detail_delta = _diff_token_details(
current.get("input_token_details"), previous.get("input_token_details")
)
if input_detail_delta:
diff_payload["input_token_details"] = cast(Any, input_detail_delta)

output_detail_delta = _diff_token_details(
current.get("output_token_details"), previous.get("output_token_details")
)
if output_detail_delta:
diff_payload["output_token_details"] = cast(Any, output_detail_delta)

return cast(UsageMetadata, diff_payload)
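
# Hypothetical usage (illustrative snapshots, not part of the diff): two
# cumulative usage snapshots are converted into the per-chunk delta that is
# attached to streamed message chunks.
#   >>> _diff_usage_metadata(
#   ...     {"input_tokens": 10, "output_tokens": 7, "total_tokens": 17},
#   ...     {"input_tokens": 10, "output_tokens": 3, "total_tokens": 13},
#   ... )
#   {'input_tokens': 0, 'output_tokens': 4, 'total_tokens': 4}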


def _response_to_result(
response: GenerateContentResponse,
stream: bool = False,
@@ -806,47 +1037,16 @@ def _response_to_result(
"""Converts a PaLM API response into a LangChain ChatResult."""
llm_output = {"prompt_feedback": proto.Message.to_dict(response.prompt_feedback)}

# Get usage metadata
try:
input_tokens = response.usage_metadata.prompt_token_count
thought_tokens = response.usage_metadata.thoughts_token_count
output_tokens = response.usage_metadata.candidates_token_count + thought_tokens
total_tokens = response.usage_metadata.total_token_count
cache_read_tokens = response.usage_metadata.cached_content_token_count
if input_tokens + output_tokens + cache_read_tokens + total_tokens > 0:
if thought_tokens > 0:
cumulative_usage = UsageMetadata(
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=total_tokens,
input_token_details={"cache_read": cache_read_tokens},
output_token_details={"reasoning": thought_tokens},
)
else:
cumulative_usage = UsageMetadata(
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=total_tokens,
input_token_details={"cache_read": cache_read_tokens},
)
# previous usage metadata needs to be subtracted because gemini api returns
# already-accumulated token counts with each chunk
lc_usage = subtract_usage(cumulative_usage, prev_usage)
if prev_usage and cumulative_usage["input_tokens"] < prev_usage.get(
"input_tokens", 0
):
# Gemini 1.5 and 2.0 return a lower cumulative count of prompt tokens
# in the final chunk. We take this count to be ground truth because
# it's consistent with the reported total tokens. So we need to
# ensure this chunk compensates (the subtract_usage function floors
# at zero).
lc_usage["input_tokens"] = cumulative_usage[
"input_tokens"
] - prev_usage.get("input_tokens", 0)
else:
lc_usage = None
except AttributeError:
lc_usage = None
cumulative_usage = _coerce_usage_metadata(response.usage_metadata)
if cumulative_usage:
llm_output["usage_metadata"] = cumulative_usage

if stream and cumulative_usage and prev_usage:
lc_usage: Optional[UsageMetadata] = _diff_usage_metadata(
cumulative_usage, prev_usage
)
else:
lc_usage = cumulative_usage

generations: List[ChatGeneration] = []

@@ -1961,19 +2161,18 @@ def _stream(
metadata=self.default_metadata,
)

prev_usage_metadata: UsageMetadata | None = None # cumulative usage
prev_usage_metadata: UsageMetadata | None = None
for chunk in response:
_chat_result = _response_to_result(
chunk, stream=True, prev_usage=prev_usage_metadata
)
gen = cast("ChatGenerationChunk", _chat_result.generations[0])
message = cast("AIMessageChunk", gen.message)

prev_usage_metadata = (
message.usage_metadata
if prev_usage_metadata is None
else add_usage(prev_usage_metadata, message.usage_metadata)
llm_output = _chat_result.llm_output or {}
cumulative_usage = cast(
Optional[UsageMetadata], llm_output.get("usage_metadata")
)
if cumulative_usage is not None:
prev_usage_metadata = cumulative_usage

if run_manager:
run_manager.on_llm_new_token(gen.text, chunk=gen)
@@ -2024,7 +2223,7 @@ async def _astream(
kwargs["timeout"] = self.timeout
if "max_retries" not in kwargs:
kwargs["max_retries"] = self.max_retries
prev_usage_metadata: UsageMetadata | None = None # cumulative usage
prev_usage_metadata: UsageMetadata | None = None
async for chunk in await _achat_with_retry(
request=request,
generation_method=self.async_client.stream_generate_content,
Expand All @@ -2035,13 +2234,12 @@ async def _astream(
chunk, stream=True, prev_usage=prev_usage_metadata
)
gen = cast("ChatGenerationChunk", _chat_result.generations[0])
message = cast("AIMessageChunk", gen.message)

prev_usage_metadata = (
message.usage_metadata
if prev_usage_metadata is None
else add_usage(prev_usage_metadata, message.usage_metadata)
llm_output = _chat_result.llm_output or {}
cumulative_usage = cast(
Optional[UsageMetadata], llm_output.get("usage_metadata")
)
if cumulative_usage is not None:
prev_usage_metadata = cumulative_usage

if run_manager:
await run_manager.on_llm_new_token(gen.text, chunk=gen)