diff --git a/libs/genai/langchain_google_genai/chat_models.py b/libs/genai/langchain_google_genai/chat_models.py
index f8725b9c6..c9cf47758 100644
--- a/libs/genai/langchain_google_genai/chat_models.py
+++ b/libs/genai/langchain_google_genai/chat_models.py
@@ -6,6 +6,7 @@
 import json
 import logging
 import mimetypes
+import re
 import time
 import uuid
 import warnings
@@ -74,7 +75,9 @@
     ToolMessage,
     is_data_content_block,
 )
-from langchain_core.messages.ai import UsageMetadata, add_usage, subtract_usage
+from langchain_core.messages.ai import (
+    UsageMetadata,
+)
 from langchain_core.messages.tool import invalid_tool_call, tool_call, tool_call_chunk
 from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
 from langchain_core.output_parsers.base import OutputParserLike
@@ -798,6 +801,234 @@ def _extract_grounding_metadata(candidate: Any) -> Dict[str, Any]:
     return result
 
 
+def _sanitize_token_detail_key(raw_key: str) -> str:
+    """Convert provider detail labels into snake_case keys."""
+    sanitized = re.sub(r"[^0-9a-zA-Z]+", "_", raw_key.strip().lower()).strip("_")
+    return sanitized or "unknown"
+
+
+def _extract_token_detail_counts(
+    entries: Sequence[Mapping[str, Any]] | None,
+    *,
+    prefix: str | None = None,
+) -> dict[str, int]:
+    """Convert modality/token entries into a token detail mapping."""
+    if not entries:
+        return {}
+    detail_counts: dict[str, int] = {}
+    for entry in entries:
+        if not isinstance(entry, Mapping):
+            continue
+        raw_key = entry.get("modality") or entry.get("type") or entry.get("name")
+        if not raw_key:
+            continue
+        raw_value = (
+            entry.get("token_count")
+            or entry.get("tokenCount")
+            or entry.get("tokens_count")
+            or entry.get("tokensCount")
+            or entry.get("count")
+        )
+        try:
+            value_int = int(raw_value or 0)
+        except (TypeError, ValueError):
+            value_int = 0
+        if value_int == 0:
+            continue
+        key = _sanitize_token_detail_key(str(raw_key))
+        if prefix:
+            key = f"{prefix}{key}"
+        detail_counts[key] = detail_counts.get(key, 0) + value_int
+    return detail_counts
+
+
+def _merge_detail_counts(target: dict[str, int], new_entries: dict[str, int]) -> None:
+    """Accumulate modality detail counts into the provided target mapping."""
+    for key, value in new_entries.items():
+        target[key] = target.get(key, 0) + value
+
+
+def _usage_proto_to_dict(raw_usage: Any) -> dict[str, Any]:
+    """Coerce proto UsageMetadata (or dict) into a plain dictionary."""
+    if raw_usage is None:
+        return {}
+    if isinstance(raw_usage, Mapping):
+        return dict(raw_usage)
+    try:
+        return proto.Message.to_dict(raw_usage)
+    except Exception:  # pragma: no cover - best effort fallback
+        try:
+            return dict(raw_usage)
+        except Exception:  # pragma: no cover - final fallback
+            return {}
+
+
+def _coerce_usage_metadata(raw_usage: Any) -> Optional[UsageMetadata]:
+    """Normalize Gemini usage metadata into LangChain's UsageMetadata."""
+    usage_dict = _usage_proto_to_dict(raw_usage)
+    if not usage_dict:
+        return None
+
+    def _get_int(name: str) -> int:
+        value = usage_dict.get(name)
+        try:
+            return int(value or 0)
+        except (TypeError, ValueError):
+            return 0
+
+    prompt_tokens = _get_int("prompt_token_count")
+    response_tokens = (
+        _get_int("candidates_token_count")
+        or _get_int("response_token_count")
+        or _get_int("output_token_count")
+    )
+    tool_prompt_tokens = _get_int("tool_use_prompt_token_count")
+    reasoning_tokens = _get_int("thoughts_token_count") or _get_int(
+        "reasoning_token_count"
+    )
+    cache_read_tokens = _get_int("cached_content_token_count")
+
+    if all(
+        count == 0
+        for count in (
+            prompt_tokens,
+            response_tokens,
+            tool_prompt_tokens,
+            reasoning_tokens,
+            cache_read_tokens,
+        )
+    ):
+        return None
+
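+    # Tool-use prompt tokens count toward input and "thoughts" (reasoning) tokens
+    # count toward output, so the LangChain totals cover everything reported.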
+    input_tokens = prompt_tokens + tool_prompt_tokens
+    output_tokens = response_tokens + reasoning_tokens
+    total_tokens = _get_int("total_token_count") or _get_int("total_tokens")
+    if total_tokens == 0:
+        total_tokens = input_tokens + output_tokens
+    if total_tokens != input_tokens + output_tokens:
+        total_tokens = input_tokens + output_tokens
+
+    input_details: dict[str, int] = {}
+    if cache_read_tokens:
+        input_details["cache_read"] = cache_read_tokens
+    if tool_prompt_tokens:
+        input_details["tool_use_prompt"] = tool_prompt_tokens
+    _merge_detail_counts(
+        input_details,
+        _extract_token_detail_counts(
+            usage_dict.get("prompt_tokens_details")
+            or usage_dict.get("promptTokensDetails"),
+        ),
+    )
+    _merge_detail_counts(
+        input_details,
+        _extract_token_detail_counts(
+            usage_dict.get("tool_use_prompt_tokens_details")
+            or usage_dict.get("toolUsePromptTokensDetails"),
+            prefix="tool_use_prompt_",
+        ),
+    )
+    _merge_detail_counts(
+        input_details,
+        _extract_token_detail_counts(
+            usage_dict.get("cache_tokens_details")
+            or usage_dict.get("cacheTokensDetails"),
+            prefix="cache_",
+        ),
+    )
+
+    output_details: dict[str, int] = {}
+    if reasoning_tokens:
+        output_details["reasoning"] = reasoning_tokens
+    for key in (
+        "candidates_tokens_details",
+        "candidatesTokensDetails",
+        "response_tokens_details",
+        "responseTokensDetails",
+        "output_tokens_details",
+        "outputTokensDetails",
+        "total_tokens_details",
+        "totalTokensDetails",
+    ):
+        _merge_detail_counts(
+            output_details, _extract_token_detail_counts(usage_dict.get(key))
+        )
+
+    for alt_key in ("thought", "thoughts", "reasoning_tokens"):
+        if alt_key in output_details:
+            output_details["reasoning"] = output_details.get(
+                "reasoning", 0
+            ) + output_details.pop(alt_key)
+
+    usage_payload: dict[str, Any] = {
+        "input_tokens": input_tokens,
+        "output_tokens": output_tokens,
+        "total_tokens": total_tokens,
+    }
+    if input_details:
+        usage_payload["input_token_details"] = cast(Any, input_details)
+    if output_details:
+        usage_payload["output_token_details"] = cast(Any, output_details)
+
+    return cast(UsageMetadata, usage_payload)
+
+
+def _diff_token_details(
+    current: Mapping[str, Any] | None,
+    previous: Mapping[str, Any] | None,
+) -> dict[str, int]:
+    """Compute detail deltas between cumulative usage payloads."""
+    if not current and not previous:
+        return {}
+    current = current or {}
+    previous = previous or {}
+    diff: dict[str, int] = {}
+    for key in set(current).union(previous):
+        current_value = current.get(key, 0)
+        previous_value = previous.get(key, 0)
+        if isinstance(current_value, Mapping) or isinstance(previous_value, Mapping):
+            nested = _diff_token_details(
+                current_value if isinstance(current_value, Mapping) else None,
+                previous_value if isinstance(previous_value, Mapping) else None,
+            )
+            if nested:
+                diff[key] = nested  # type: ignore[assignment]
+            continue
+        try:
+            current_int = int(current_value or 0)
+        except (TypeError, ValueError):
+            current_int = 0
+        try:
+            previous_int = int(previous_value or 0)
+        except (TypeError, ValueError):
+            previous_int = 0
+        delta = current_int - previous_int
+        if delta != 0:
+            diff[key] = delta
+    return diff
+
+
+def _diff_usage_metadata(
+    current: UsageMetadata, previous: UsageMetadata
+) -> UsageMetadata:
+    """Return chunk-level usage delta between cumulative UsageMetadata values."""
+
+    input_delta = current.get("input_tokens", 0) - previous.get("input_tokens", 0)
+    output_delta = current.get("output_tokens", 0) - previous.get("output_tokens", 0)
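+    # Keep the reported total consistent with the input/output deltas if the
+    # provider-reported totals drift.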
previous.get("total_tokens", 0) + expected_total = input_delta + output_delta + if total_delta != expected_total: + total_delta = expected_total + + diff_payload: dict[str, Any] = { + "input_tokens": input_delta, + "output_tokens": output_delta, + "total_tokens": total_delta, + } + + input_detail_delta = _diff_token_details( + current.get("input_token_details"), previous.get("input_token_details") + ) + if input_detail_delta: + diff_payload["input_token_details"] = cast(Any, input_detail_delta) + + output_detail_delta = _diff_token_details( + current.get("output_token_details"), previous.get("output_token_details") + ) + if output_detail_delta: + diff_payload["output_token_details"] = cast(Any, output_detail_delta) + + return cast(UsageMetadata, diff_payload) + + def _response_to_result( response: GenerateContentResponse, stream: bool = False, @@ -806,47 +1037,16 @@ def _response_to_result( """Converts a PaLM API response into a LangChain ChatResult.""" llm_output = {"prompt_feedback": proto.Message.to_dict(response.prompt_feedback)} - # Get usage metadata - try: - input_tokens = response.usage_metadata.prompt_token_count - thought_tokens = response.usage_metadata.thoughts_token_count - output_tokens = response.usage_metadata.candidates_token_count + thought_tokens - total_tokens = response.usage_metadata.total_token_count - cache_read_tokens = response.usage_metadata.cached_content_token_count - if input_tokens + output_tokens + cache_read_tokens + total_tokens > 0: - if thought_tokens > 0: - cumulative_usage = UsageMetadata( - input_tokens=input_tokens, - output_tokens=output_tokens, - total_tokens=total_tokens, - input_token_details={"cache_read": cache_read_tokens}, - output_token_details={"reasoning": thought_tokens}, - ) - else: - cumulative_usage = UsageMetadata( - input_tokens=input_tokens, - output_tokens=output_tokens, - total_tokens=total_tokens, - input_token_details={"cache_read": cache_read_tokens}, - ) - # previous usage metadata needs to be subtracted because gemini api returns - # already-accumulated token counts with each chunk - lc_usage = subtract_usage(cumulative_usage, prev_usage) - if prev_usage and cumulative_usage["input_tokens"] < prev_usage.get( - "input_tokens", 0 - ): - # Gemini 1.5 and 2.0 return a lower cumulative count of prompt tokens - # in the final chunk. We take this count to be ground truth because - # it's consistent with the reported total tokens. So we need to - # ensure this chunk compensates (the subtract_usage funcction floors - # at zero). 
-                lc_usage["input_tokens"] = cumulative_usage[
-                    "input_tokens"
-                ] - prev_usage.get("input_tokens", 0)
-        else:
-            lc_usage = None
-    except AttributeError:
-        lc_usage = None
+    cumulative_usage = _coerce_usage_metadata(response.usage_metadata)
+    if cumulative_usage:
+        llm_output["usage_metadata"] = cumulative_usage
+
+    if stream and cumulative_usage and prev_usage:
+        lc_usage: Optional[UsageMetadata] = _diff_usage_metadata(
+            cumulative_usage, prev_usage
+        )
+    else:
+        lc_usage = cumulative_usage
 
     generations: List[ChatGeneration] = []
 
@@ -1961,19 +2161,18 @@ def _stream(
             metadata=self.default_metadata,
         )
 
-        prev_usage_metadata: UsageMetadata | None = None  # cumulative usage
+        prev_usage_metadata: UsageMetadata | None = None
         for chunk in response:
             _chat_result = _response_to_result(
                 chunk, stream=True, prev_usage=prev_usage_metadata
            )
             gen = cast("ChatGenerationChunk", _chat_result.generations[0])
-            message = cast("AIMessageChunk", gen.message)
-
-            prev_usage_metadata = (
-                message.usage_metadata
-                if prev_usage_metadata is None
-                else add_usage(prev_usage_metadata, message.usage_metadata)
+            llm_output = _chat_result.llm_output or {}
+            cumulative_usage = cast(
+                Optional[UsageMetadata], llm_output.get("usage_metadata")
             )
+            if cumulative_usage is not None:
+                prev_usage_metadata = cumulative_usage
 
             if run_manager:
                 run_manager.on_llm_new_token(gen.text, chunk=gen)
@@ -2024,7 +2223,7 @@ async def _astream(
         kwargs["timeout"] = self.timeout
         if "max_retries" not in kwargs:
             kwargs["max_retries"] = self.max_retries
-        prev_usage_metadata: UsageMetadata | None = None  # cumulative usage
+        prev_usage_metadata: UsageMetadata | None = None
         async for chunk in await _achat_with_retry(
             request=request,
             generation_method=self.async_client.stream_generate_content,
@@ -2035,13 +2234,12 @@
                 chunk, stream=True, prev_usage=prev_usage_metadata
             )
             gen = cast("ChatGenerationChunk", _chat_result.generations[0])
-            message = cast("AIMessageChunk", gen.message)
-
-            prev_usage_metadata = (
-                message.usage_metadata
-                if prev_usage_metadata is None
-                else add_usage(prev_usage_metadata, message.usage_metadata)
+            llm_output = _chat_result.llm_output or {}
+            cumulative_usage = cast(
+                Optional[UsageMetadata], llm_output.get("usage_metadata")
             )
+            if cumulative_usage is not None:
+                prev_usage_metadata = cumulative_usage
 
             if run_manager:
                 await run_manager.on_llm_new_token(gen.text, chunk=gen)
diff --git a/libs/genai/tests/unit_tests/test_chat_models.py b/libs/genai/tests/unit_tests/test_chat_models.py
index c1512d477..0046a4d63 100644
--- a/libs/genai/tests/unit_tests/test_chat_models.py
+++ b/libs/genai/tests/unit_tests/test_chat_models.py
@@ -6,7 +6,8 @@
 import warnings
 from collections.abc import Iterator
 from concurrent.futures import ThreadPoolExecutor
-from typing import Optional, Union
+from types import SimpleNamespace
+from typing import Optional, Union, cast
 from unittest.mock import ANY, Mock, patch
 
 import google.ai.generativelanguage as glm
@@ -968,6 +969,194 @@ def test_response_to_result_grounding_metadata(
     assert grounding_metadata == expected_grounding_metadata
 
 
+def test_response_to_result_usage_details() -> None:
+    """Ensure usage metadata includes tool and modality details."""
+    base_response = GenerateContentResponse()
+    response = cast(
+        GenerateContentResponse,
+        SimpleNamespace(
+            prompt_feedback=base_response.prompt_feedback,
+            candidates=[
+                Candidate(
+                    content=Content(parts=[Part(text="Hello")]),
+                    finish_reason=Candidate.FinishReason.STOP,
+                )
+            ],
+            model_version="models/gemini-2.5-pro",
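+            # Plain dict stands in for proto usage metadata (same field names as
+            # proto.Message.to_dict output).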
+            usage_metadata={
+                "prompt_token_count": 12,
+                "tool_use_prompt_token_count": 8,
+                "candidates_token_count": 20,
+                "thoughts_token_count": 5,
+                "cached_content_token_count": 3,
+                "total_token_count": 45,
+                "prompt_tokens_details": [
+                    {"modality": "TEXT", "token_count": 12},
+                    {"modality": "AUDIO", "token_count": 2},
+                ],
+                "tool_use_prompt_tokens_details": [
+                    {
+                        "modality": "TEXT",
+                        "token_count": 8,
+                    }
+                ],
+                "candidates_tokens_details": [
+                    {
+                        "modality": "TEXT",
+                        "token_count": 20,
+                    }
+                ],
+            },
+        ),
+    )
+
+    result = _response_to_result(response, stream=False)
+    message = cast(AIMessage, result.generations[0].message)
+
+    assert message.usage_metadata is not None
+    usage_metadata = message.usage_metadata
+    assert usage_metadata["input_tokens"] == 20
+    assert usage_metadata["output_tokens"] == 25
+    assert usage_metadata["total_tokens"] == 45
+
+    input_details = cast(
+        dict[str, int], usage_metadata.get("input_token_details", {}) or {}
+    )
+    assert input_details["tool_use_prompt"] == 8
+    assert input_details["cache_read"] == 3
+    assert input_details["text"] == 12
+    assert input_details["audio"] == 2
+    assert input_details["tool_use_prompt_text"] == 8
+
+    output_details = cast(
+        dict[str, int], usage_metadata.get("output_token_details", {}) or {}
+    )
+    assert output_details["reasoning"] == 5
+    assert output_details["text"] == 20
+
+    llm_output = result.llm_output or {}
+    assert llm_output.get("usage_metadata") == usage_metadata
+
+
+def test_response_to_result_streaming_delta_details() -> None:
+    """Streaming responses should expose per-chunk usage deltas."""
+    base_response = GenerateContentResponse()
+    chunk_one = cast(
+        GenerateContentResponse,
+        SimpleNamespace(
+            prompt_feedback=base_response.prompt_feedback,
+            model_version="models/gemini-stream",
+            candidates=[
+                Candidate(
+                    content=Content(parts=[Part(text="First")]),
+                    finish_reason=Candidate.FinishReason.STOP,
+                )
+            ],
+            usage_metadata={
+                "prompt_token_count": 10,
+                "tool_use_prompt_token_count": 2,
+                "candidates_token_count": 6,
+                "thoughts_token_count": 4,
+                "cached_content_token_count": 3,
+                "total_token_count": 22,
+                "prompt_tokens_details": [
+                    {
+                        "modality": "TEXT",
+                        "token_count": 10,
+                    }
+                ],
+                "candidates_tokens_details": [
+                    {
+                        "modality": "TEXT",
+                        "token_count": 6,
+                    }
+                ],
+            },
+        ),
+    )
+
+    first_result = _response_to_result(chunk_one, stream=True)
+    first_message = cast(AIMessage, first_result.generations[0].message)
+    first_usage = first_message.usage_metadata
+    assert first_usage is not None
+    assert first_usage["input_tokens"] == 12
+    assert first_usage["output_tokens"] == 10
+    assert first_usage["total_tokens"] == 22
+    first_input_details = cast(
+        dict[str, int], first_usage.get("input_token_details", {}) or {}
+    )
+    first_output_details = cast(
+        dict[str, int], first_usage.get("output_token_details", {}) or {}
+    )
+    assert first_input_details["tool_use_prompt"] == 2
+    assert first_output_details["reasoning"] == 4
+    llm_output_first = first_result.llm_output or {}
+    prev_cumulative_usage = llm_output_first.get("usage_metadata")
+    assert prev_cumulative_usage is not None
+
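+    # The second chunk reports larger cumulative totals; the resulting message
+    # should only carry the per-chunk delta.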
"cached_content_token_count": 5, + "total_token_count": 40, + "prompt_tokens_details": [ + { + "modality": "TEXT", + "token_count": 14, + } + ], + "candidates_tokens_details": [ + { + "modality": "TEXT", + "token_count": 12, + } + ], + }, + ), + ) + + second_result = _response_to_result( + chunk_two, + stream=True, + prev_usage=prev_cumulative_usage, + ) + second_message = cast(AIMessage, second_result.generations[0].message) + second_usage = second_message.usage_metadata + assert second_usage is not None + assert second_usage["input_tokens"] == 8 + assert second_usage["output_tokens"] == 10 + assert second_usage["total_tokens"] == 18 + + input_details_delta = cast( + dict[str, int], second_usage.get("input_token_details", {}) or {} + ) + assert input_details_delta["tool_use_prompt"] == 2 + assert input_details_delta["cache_read"] == 2 + assert input_details_delta["text"] == 4 + + output_details_delta = cast( + dict[str, int], second_usage.get("output_token_details", {}) or {} + ) + assert output_details_delta["reasoning"] == 4 + assert output_details_delta["text"] == 6 + + llm_output = second_result.llm_output or {} + assert llm_output.get("usage_metadata") is not None + + @pytest.mark.parametrize( "is_async,mock_target,method_name", [ diff --git a/libs/vertexai/langchain_google_vertexai/_usage.py b/libs/vertexai/langchain_google_vertexai/_usage.py new file mode 100644 index 000000000..daca26282 --- /dev/null +++ b/libs/vertexai/langchain_google_vertexai/_usage.py @@ -0,0 +1,227 @@ +"""Usage metadata helpers for Vertex AI Gemini models.""" + +from __future__ import annotations + +import re +from typing import Any, Mapping, Optional, Sequence, cast + +import proto # type: ignore[import-untyped] +from langchain_core.messages.ai import ( + UsageMetadata, +) + + +def _sanitize_token_detail_key(raw_key: str) -> str: + sanitized = re.sub(r"[^0-9a-zA-Z]+", "_", raw_key.strip().lower()).strip("_") + return sanitized or "unknown" + + +def _extract_token_detail_counts( + entries: Sequence[Mapping[str, Any]] | None, + *, + prefix: str | None = None, +) -> dict[str, int]: + if not entries: + return {} + detail_counts: dict[str, int] = {} + for entry in entries: + raw_key = entry.get("modality") or entry.get("type") or entry.get("name") + if not raw_key: + continue + raw_value = ( + entry.get("token_count") + or entry.get("tokenCount") + or entry.get("tokens_count") + or entry.get("tokensCount") + or entry.get("count") + ) + try: + value_int = int(raw_value or 0) + except (TypeError, ValueError): + value_int = 0 + if value_int == 0: + continue + key = _sanitize_token_detail_key(str(raw_key)) + if prefix: + key = f"{prefix}{key}" + detail_counts[key] = detail_counts.get(key, 0) + value_int + return detail_counts + + +def _merge_detail_counts(target: dict[str, int], new_entries: dict[str, int]) -> None: + for key, value in new_entries.items(): + target[key] = target.get(key, 0) + value + + +def _usage_proto_to_dict(raw_usage: Any) -> dict[str, Any]: + if raw_usage is None: + return {} + if isinstance(raw_usage, Mapping): + return dict(raw_usage) + try: + return proto.Message.to_dict(raw_usage) + except Exception: # pragma: no cover + try: + return dict(raw_usage) + except Exception: # pragma: no cover + return {} + + +def coerce_usage_metadata(raw_usage: Any) -> Optional[UsageMetadata]: + usage_dict = _usage_proto_to_dict(raw_usage) + if not usage_dict: + return None + + def _get_int(name: str) -> int: + value = usage_dict.get(name) + try: + return int(value or 0) + except (TypeError, ValueError): + 
+            return 0
+
+    prompt_tokens = _get_int("prompt_token_count")
+    response_tokens = (
+        _get_int("candidates_token_count")
+        or _get_int("response_token_count")
+        or _get_int("output_token_count")
+    )
+    tool_prompt_tokens = _get_int("tool_use_prompt_token_count")
+    reasoning_tokens = _get_int("thoughts_token_count") or _get_int(
+        "reasoning_token_count"
+    )
+    cache_read_tokens = _get_int("cached_content_token_count")
+
+    if all(
+        count == 0
+        for count in (
+            prompt_tokens,
+            response_tokens,
+            tool_prompt_tokens,
+            reasoning_tokens,
+            cache_read_tokens,
+        )
+    ):
+        return None
+
+    input_tokens = prompt_tokens + tool_prompt_tokens
+    output_tokens = response_tokens + reasoning_tokens
+    total_tokens = _get_int("total_token_count") or _get_int("total_tokens")
+    if total_tokens == 0 or total_tokens != input_tokens + output_tokens:
+        total_tokens = input_tokens + output_tokens
+
+    input_details: dict[str, int] = {}
+    if cache_read_tokens:
+        input_details["cache_read"] = cache_read_tokens
+    if tool_prompt_tokens:
+        input_details["tool_use_prompt"] = tool_prompt_tokens
+    _merge_detail_counts(
+        input_details,
+        _extract_token_detail_counts(
+            usage_dict.get("prompt_tokens_details")
+            or usage_dict.get("promptTokensDetails"),
+        ),
+    )
+    _merge_detail_counts(
+        input_details,
+        _extract_token_detail_counts(
+            usage_dict.get("tool_use_prompt_tokens_details")
+            or usage_dict.get("toolUsePromptTokensDetails"),
+            prefix="tool_use_prompt_",
+        ),
+    )
+    _merge_detail_counts(
+        input_details,
+        _extract_token_detail_counts(
+            usage_dict.get("cache_tokens_details")
+            or usage_dict.get("cacheTokensDetails"),
+            prefix="cache_",
+        ),
+    )
+
+    output_details: dict[str, int] = {}
+    if reasoning_tokens:
+        output_details["reasoning"] = reasoning_tokens
+    for key in (
+        "candidates_tokens_details",
+        "candidatesTokensDetails",
+        "response_tokens_details",
+        "responseTokensDetails",
+        "output_tokens_details",
+        "outputTokensDetails",
+        "total_tokens_details",
+        "totalTokensDetails",
+    ):
+        _merge_detail_counts(
+            output_details, _extract_token_detail_counts(usage_dict.get(key))
+        )
+
+    # Normalize alternate reasoning keys if the API provides e.g. "thought" buckets.
+    for alt_key in ("thought", "thoughts", "reasoning_tokens"):
+        if alt_key in output_details:
+            output_details["reasoning"] = output_details.get(
+                "reasoning", 0
+            ) + output_details.pop(alt_key)
+
+    payload: dict[str, Any] = {
+        "input_tokens": input_tokens,
+        "output_tokens": output_tokens,
+        "total_tokens": total_tokens,
+    }
+    if input_details:
+        payload["input_token_details"] = input_details
+    if output_details:
+        payload["output_token_details"] = output_details
+
+    return cast(UsageMetadata, payload)
+
+
+def diff_usage_metadata(
+    current: Optional[UsageMetadata], previous: Optional[UsageMetadata]
+) -> Optional[UsageMetadata]:
+    if not current:
+        return None
+    if not previous:
+        return current
+
+    input_delta = current.get("input_tokens", 0) - previous.get("input_tokens", 0)
+    output_delta = current.get("output_tokens", 0) - previous.get("output_tokens", 0)
+    total_delta = current.get("total_tokens", 0) - previous.get("total_tokens", 0)
+    expected_total = input_delta + output_delta
+    if total_delta != expected_total:
+        total_delta = expected_total
+
+    payload: dict[str, Any] = {
+        "input_tokens": input_delta,
+        "output_tokens": output_delta,
+        "total_tokens": total_delta,
+    }
+
+    prev_input_details = cast(
+        dict[str, int], previous.get("input_token_details", {}) or {}
+    )
+    curr_input_details = cast(
+        dict[str, int], current.get("input_token_details", {}) or {}
+    )
+    input_detail_delta = {
+        key: curr_input_details.get(key, 0) - prev_input_details.get(key, 0)
+        for key in set(prev_input_details).union(curr_input_details)
+    }
+    input_detail_delta = {k: v for k, v in input_detail_delta.items() if v != 0}
+    if input_detail_delta:
+        payload["input_token_details"] = input_detail_delta
+
+    prev_output_details = cast(
+        dict[str, int], previous.get("output_token_details", {}) or {}
+    )
+    curr_output_details = cast(
+        dict[str, int], current.get("output_token_details", {}) or {}
+    )
+    output_detail_delta = {
+        key: curr_output_details.get(key, 0) - prev_output_details.get(key, 0)
+        for key in set(prev_output_details).union(curr_output_details)
+    }
+    output_detail_delta = {k: v for k, v in output_detail_delta.items() if v != 0}
+    if output_detail_delta:
+        payload["output_token_details"] = output_detail_delta
+
+    return cast(UsageMetadata, payload)
diff --git a/libs/vertexai/langchain_google_vertexai/chat_models.py b/libs/vertexai/langchain_google_vertexai/chat_models.py
index 5dc5bca3a..b59c148b6 100644
--- a/libs/vertexai/langchain_google_vertexai/chat_models.py
+++ b/libs/vertexai/langchain_google_vertexai/chat_models.py
@@ -3,6 +3,7 @@ from __future__ import annotations  # noqa
 
 import ast
 import base64
+import math
 from functools import cached_property
 import json
 import logging
@@ -24,9 +25,15 @@
     TypedDict,
     overload,
 )
-from collections.abc import AsyncIterator, Iterator, Sequence
+from collections.abc import AsyncIterator, Iterator, Sequence, Mapping
 
 import proto  # type: ignore[import-untyped]
+from google.protobuf.json_format import SerializeToJsonError
+from google.protobuf.struct_pb2 import (
+    ListValue,
+    Struct,
+    Value,
+)
 
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
@@ -51,7 +58,9 @@
     convert_to_openai_image_block,
     is_data_content_block,
 )
-from langchain_core.messages.ai import UsageMetadata
+from langchain_core.messages.ai import (
+    UsageMetadata,
+)
 from langchain_core.messages.tool import (
     tool_call_chunk,
     tool_call as create_tool_call,
@@ -74,6 +83,10 @@
 )
 from langchain_core.utils.pydantic import is_basemodel_subclass
 from langchain_core.utils.utils import _build_model_kwargs
+from langchain_google_vertexai._usage import (
+    coerce_usage_metadata,
+    diff_usage_metadata,
+)
 from vertexai.generative_models import (
     Tool as VertexTool,  # TODO: migrate to google-genai since this is deprecated
 )
@@ -608,6 +621,80 @@ def _append_to_content(
     raise TypeError(msg)
 
 
+def _json_safe_number(value: float) -> Union[str, float]:
+    if math.isnan(value):
+        return "NaN"
+    if math.isinf(value):
+        return "Infinity" if value > 0 else "-Infinity"
+    return value
+
+
+def _struct_value_to_jsonable(value: Value) -> Any:
+    kind = value.WhichOneof("kind")
+    if kind == "number_value":
+        return _json_safe_number(value.number_value)
+    if kind == "string_value":
+        return value.string_value
+    if kind == "bool_value":
+        return bool(value.bool_value)
+    if kind == "null_value":
+        return None
+    if kind == "struct_value":
+        return _struct_to_jsonable(value.struct_value)
+    if kind == "list_value":
+        return [_struct_value_to_jsonable(item) for item in value.list_value.values]
+    return None
+
+
+def _struct_to_jsonable(struct: Struct) -> dict[str, Any]:
+    return {key: _struct_value_to_jsonable(val) for key, val in struct.fields.items()}
+
+
+def _make_jsonable(value: Any) -> Any:
+    if isinstance(value, Value):
+        return _struct_value_to_jsonable(value)
+    if isinstance(value, Struct):
+        return _struct_to_jsonable(value)
+    if isinstance(value, ListValue):
+        return [_struct_value_to_jsonable(item) for item in value.values]
+    if isinstance(value, Mapping):
+        return {str(key): _make_jsonable(val) for key, val in value.items()}
+    if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
+        return [_make_jsonable(item) for item in value]
+    if isinstance(value, float):
+        return _json_safe_number(value)
+    return value
+
+
+def _coerce_function_call_args(function_call: FunctionCall) -> dict[str, Any]:
+    try:
+        fc_dict = proto.Message.to_dict(function_call)
+    except SerializeToJsonError:
+        fc_dict = {}
+    except TypeError:
+        fc_dict = {}
+
+    args_dict: Any = fc_dict.get("args") if isinstance(fc_dict, dict) else None
+    if isinstance(args_dict, dict):
+        return dict(args_dict)
+
+    struct_args = getattr(function_call, "args", None)
+    if isinstance(struct_args, Struct):
+        return _struct_to_jsonable(struct_args)
+    if isinstance(struct_args, Mapping):
+        return {str(key): _make_jsonable(val) for key, val in struct_args.items()}
+
+    if struct_args is not None:
+        try:
+            fallback_dict = proto.Message.to_dict(struct_args)
+        except Exception:
+            fallback_dict = {}
+        if isinstance(fallback_dict, dict):
+            return fallback_dict
+
+    return {}
+
+
 @overload
 def _parse_response_candidate(
     response_candidate: Candidate, streaming: Literal[False] = False
@@ -651,9 +738,10 @@ def _parse_response_candidate(
             # but in general the full set of function calls is stored in tool_calls.
             function_call = {"name": part.function_call.name}
             # dump to match other function calling llm for now
-            function_call_args_dict = proto.Message.to_dict(part.function_call)["args"]
+            function_call_args_dict = _coerce_function_call_args(part.function_call)
             function_call["arguments"] = json.dumps(
-                {k: function_call_args_dict[k] for k in function_call_args_dict}
+                function_call_args_dict,
+                allow_nan=False,
             )
             additional_kwargs["function_call"] = function_call
 
@@ -2697,17 +2785,7 @@ def _gemini_chunk_to_generation_chunk(
     # Note: some models (e.g., gemini-1.5-pro with image inputs) return
     # cumulative sums of token counts.
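+    # Per-chunk usage is derived by diffing consecutive cumulative snapshots.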
     total_lc_usage = _get_usage_metadata_gemini(usage_metadata)
-    if total_lc_usage and prev_total_usage:
-        lc_usage: Optional[UsageMetadata] = UsageMetadata(
-            input_tokens=total_lc_usage["input_tokens"]
-            - prev_total_usage["input_tokens"],
-            output_tokens=total_lc_usage["output_tokens"]
-            - prev_total_usage["output_tokens"],
-            total_tokens=total_lc_usage["total_tokens"]
-            - prev_total_usage["total_tokens"],
-        )
-    else:
-        lc_usage = total_lc_usage
+    lc_usage = diff_usage_metadata(total_lc_usage, prev_total_usage)
     if not response_chunk.candidates:
         message = AIMessageChunk(content="")
         if lc_usage:
@@ -2738,30 +2816,7 @@ def _get_usage_metadata_gemini(raw_metadata: dict) -> Optional[UsageMetadata]:
     """Get UsageMetadata from raw response metadata."""
-    input_tokens = raw_metadata.get("prompt_token_count", 0)
-    output_tokens = raw_metadata.get("candidates_token_count", 0)
-    total_tokens = raw_metadata.get("total_token_count", 0)
-    thought_tokens = raw_metadata.get("thoughts_token_count", 0)
-    cache_read_tokens = raw_metadata.get("cached_content_token_count", 0)
-    if all(
-        count == 0
-        for count in [input_tokens, output_tokens, total_tokens, cache_read_tokens]
-    ):
-        return None
-    if thought_tokens > 0:
-        return UsageMetadata(
-            input_tokens=input_tokens,
-            output_tokens=output_tokens,
-            total_tokens=total_tokens,
-            input_token_details={"cache_read": cache_read_tokens},
-            output_token_details={"reasoning": thought_tokens},
-        )
-    return UsageMetadata(
-        input_tokens=input_tokens,
-        output_tokens=output_tokens,
-        total_tokens=total_tokens,
-        input_token_details={"cache_read": cache_read_tokens},
-    )
+    return coerce_usage_metadata(raw_metadata)
 
 
 def _get_tool_name(tool: _ToolType) -> str:
diff --git a/libs/vertexai/tests/integration_tests/test_chat_models.py b/libs/vertexai/tests/integration_tests/test_chat_models.py
index 39b939519..68e58b868 100644
--- a/libs/vertexai/tests/integration_tests/test_chat_models.py
+++ b/libs/vertexai/tests/integration_tests/test_chat_models.py
@@ -896,7 +896,7 @@ def test_chat_vertexai_gemini_thinking_auto() -> None:
     assert response.usage_metadata["output_token_details"]["reasoning"] > 0
     assert (
         response.usage_metadata["total_tokens"]
-        > response.usage_metadata["input_tokens"]
+        == response.usage_metadata["input_tokens"]
         + response.usage_metadata["output_tokens"]
     )
 
@@ -912,7 +912,7 @@ def test_chat_vertexai_gemini_thinking_configured() -> None:
     assert response.usage_metadata["output_token_details"]["reasoning"] <= 100
     assert (
         response.usage_metadata["total_tokens"]
-        > response.usage_metadata["input_tokens"]
+        == response.usage_metadata["input_tokens"]
         + response.usage_metadata["output_tokens"]
     )
 
@@ -943,7 +943,7 @@ def test_chat_vertexai_gemini_thinking_auto_include_thoughts() -> None:
     assert response.usage_metadata["output_token_details"]["reasoning"] > 0
     assert (
         response.usage_metadata["total_tokens"]
-        > response.usage_metadata["input_tokens"]
+        == response.usage_metadata["input_tokens"]
         + response.usage_metadata["output_tokens"]
     )
 
diff --git a/libs/vertexai/tests/unit_tests/test_chat_models.py b/libs/vertexai/tests/unit_tests/test_chat_models.py
index 539fde32c..7f78f0575 100644
--- a/libs/vertexai/tests/unit_tests/test_chat_models.py
+++ b/libs/vertexai/tests/unit_tests/test_chat_models.py
@@ -1016,6 +1016,51 @@ def test_default_params_gemini() -> None:
                 },
             ),
         ),
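+        # Non-finite float arguments are expected to surface as JSON-safe strings.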
float("inf"), + "negative": float("-inf"), + "not_a_number": float("nan"), + }, + ), + ), + ], + ) + ), + AIMessage( + content="", + tool_calls=[ + create_tool_call( + name="handle_numbers", + args={ + "positive": "Infinity", + "negative": "-Infinity", + "not_a_number": "NaN", + }, + id="00000000-0000-0000-0000-00000000000", + ), + ], + additional_kwargs={ + "function_call": { + "name": "handle_numbers", + "arguments": json.dumps( + { + "positive": "Infinity", + "negative": "-Infinity", + "not_a_number": "NaN", + } + ), + } + }, + ), + ), ], ) def test_parse_response_candidate(raw_candidate, expected) -> None: diff --git a/libs/vertexai/tests/unit_tests/test_usage_metadata.py b/libs/vertexai/tests/unit_tests/test_usage_metadata.py new file mode 100644 index 000000000..fb2d59f1d --- /dev/null +++ b/libs/vertexai/tests/unit_tests/test_usage_metadata.py @@ -0,0 +1,92 @@ +"""Usage metadata tests for Vertex AI Gemini helpers.""" + +import importlib.util +from pathlib import Path + +USAGE_HELPERS_PATH = ( + Path(__file__).resolve().parents[2] / "langchain_google_vertexai" / "_usage.py" +) + +spec = importlib.util.spec_from_file_location("vertex_usage", USAGE_HELPERS_PATH) +assert spec and spec.loader +usage_module = importlib.util.module_from_spec(spec) +spec.loader.exec_module(usage_module) + +coerce_usage_metadata = usage_module.coerce_usage_metadata +diff_usage_metadata = usage_module.diff_usage_metadata + + +def test_get_usage_metadata_gemini_details() -> None: + raw_metadata = { + "prompt_token_count": 12, + "tool_use_prompt_token_count": 8, + "candidates_token_count": 20, + "thoughts_token_count": 5, + "cached_content_token_count": 3, + "total_token_count": 45, + "prompt_tokens_details": [ + {"modality": "TEXT", "token_count": 12}, + {"modality": "AUDIO", "token_count": 2}, + ], + "tool_use_prompt_tokens_details": [{"modality": "TEXT", "token_count": 8}], + "candidates_tokens_details": [{"modality": "TEXT", "token_count": 20}], + } + + usage = coerce_usage_metadata(raw_metadata) + assert usage is not None + assert usage["input_tokens"] == 20 + assert usage["output_tokens"] == 25 + assert usage["total_tokens"] == 45 + + input_details = usage.get("input_token_details", {}) or {} + assert input_details["tool_use_prompt"] == 8 + assert input_details["cache_read"] == 3 + assert input_details["text"] == 12 + assert input_details["audio"] == 2 + assert input_details["tool_use_prompt_text"] == 8 + + output_details = usage.get("output_token_details", {}) or {} + assert output_details["reasoning"] == 5 + assert output_details["text"] == 20 + + +def test_get_usage_metadata_gemini_delta() -> None: + first_raw = { + "prompt_token_count": 10, + "tool_use_prompt_token_count": 2, + "candidates_token_count": 6, + "thoughts_token_count": 4, + "cached_content_token_count": 3, + "total_token_count": 22, + "prompt_tokens_details": [{"modality": "TEXT", "token_count": 10}], + "candidates_tokens_details": [{"modality": "TEXT", "token_count": 6}], + } + + second_raw = { + "prompt_token_count": 16, + "tool_use_prompt_token_count": 4, + "candidates_token_count": 12, + "thoughts_token_count": 8, + "cached_content_token_count": 5, + "total_token_count": 40, + "prompt_tokens_details": [{"modality": "TEXT", "token_count": 14}], + "candidates_tokens_details": [{"modality": "TEXT", "token_count": 12}], + } + + first_usage = coerce_usage_metadata(first_raw) + second_usage = coerce_usage_metadata(second_raw) + assert first_usage is not None + assert second_usage is not None + + delta = diff_usage_metadata(second_usage, 
+    delta = diff_usage_metadata(second_usage, first_usage)
+    assert delta is not None
+
+    assert delta["input_tokens"] == 8
+    assert delta["output_tokens"] == 10
+    assert delta["total_tokens"] == 18
+
+    assert delta["input_token_details"]["tool_use_prompt"] == 2
+    assert delta["input_token_details"]["cache_read"] == 2
+    assert delta["input_token_details"]["text"] == 4
+    assert delta["output_token_details"]["reasoning"] == 4
+    assert delta["output_token_details"]["text"] == 6