Skip to content
Merged
74 changes: 67 additions & 7 deletions libs/genai/langchain_google_genai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,66 @@ def _parse_response_candidate(
)


def _extract_grounding_metadata(candidate: Any) -> Dict[str, Any]:
"""Extract grounding metadata from candidate.

Uses `proto.Message.to_dict()` for complete unfiltered extraction first,
falls back to custom field extraction in cases of failure for robustness.
"""
if not hasattr(candidate, "grounding_metadata") or not candidate.grounding_metadata:
return {}

grounding_metadata = candidate.grounding_metadata

try:
return proto.Message.to_dict(grounding_metadata)
except (AttributeError, TypeError):
# Fallback: field extraction
result: Dict[str, Any] = {}

# Extract grounding chunks
if hasattr(grounding_metadata, "grounding_chunks"):
grounding_chunks = []
for chunk in grounding_metadata.grounding_chunks:
chunk_data: Dict[str, Any] = {}
if hasattr(chunk, "web") and chunk.web:
chunk_data["web"] = {
"uri": chunk.web.uri if hasattr(chunk.web, "uri") else "",
"title": chunk.web.title if hasattr(chunk.web, "title") else "",
}
grounding_chunks.append(chunk_data)
result["grounding_chunks"] = grounding_chunks

# Extract grounding supports
if hasattr(grounding_metadata, "grounding_supports"):
grounding_supports = []
for support in grounding_metadata.grounding_supports:
support_data: Dict[str, Any] = {}
if hasattr(support, "segment") and support.segment:
support_data["segment"] = {
"start_index": getattr(support.segment, "start_index", 0),
"end_index": getattr(support.segment, "end_index", 0),
"text": getattr(support.segment, "text", ""),
"part_index": getattr(support.segment, "part_index", 0),
}
if hasattr(support, "grounding_chunk_indices"):
support_data["grounding_chunk_indices"] = list(
support.grounding_chunk_indices
)
if hasattr(support, "confidence_scores"):
support_data["confidence_scores"] = [
round(score, 6) for score in support.confidence_scores
]
grounding_supports.append(support_data)
result["grounding_supports"] = grounding_supports

# Extract web search queries
if hasattr(grounding_metadata, "web_search_queries"):
result["web_search_queries"] = list(grounding_metadata.web_search_queries)

return result


def _response_to_result(
response: GenerateContentResponse,
stream: bool = False,
Expand Down Expand Up @@ -800,15 +860,15 @@ def _response_to_result(
proto.Message.to_dict(safety_rating, use_integers_for_enums=False)
for safety_rating in candidate.safety_ratings
]
try:
if candidate.grounding_metadata:
generation_info["grounding_metadata"] = proto.Message.to_dict(
candidate.grounding_metadata
)
except AttributeError:
pass
grounding_metadata = _extract_grounding_metadata(candidate)
generation_info["grounding_metadata"] = grounding_metadata
message = _parse_response_candidate(candidate, streaming=stream)
message.usage_metadata = lc_usage

if not hasattr(message, "response_metadata"):
message.response_metadata = {}
message.response_metadata["grounding_metadata"] = grounding_metadata

if stream:
generations.append(
ChatGenerationChunk(
Expand Down