Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion libs/vertexai/langchain_google_vertexai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,10 @@
from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import Self, is_typeddict
from difflib import get_close_matches

from langchain_google_vertexai.functions_utils import (
_dict_to_gapic_schema_utils,
_check_v2,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -2226,8 +2229,13 @@ class Explanation(BaseModel):
if isinstance(schema, type) and is_basemodel_subclass(schema):
if issubclass(schema, BaseModelV1):
schema_json = schema.schema()
pydantic_version = "v1"
else:
schema_json = schema.model_json_schema()
pydantic_version = "v2"
schema_json = _dict_to_gapic_schema_utils(
schema_json, pydantic_version=pydantic_version
)
parser = PydanticOutputParser(pydantic_object=schema)
else:
if is_typeddict(schema):
Expand All @@ -2236,6 +2244,14 @@ class Explanation(BaseModel):
schema_json = schema
else:
raise ValueError(f"Unsupported schema type {type(schema)}")

pydantic_version_v2 = _check_v2(schema_json)
if pydantic_version_v2:
schema_json = _dict_to_gapic_schema_utils(
schema_json, pydantic_version="v2"
)
else:
schema_json = _dict_to_gapic_schema_utils(schema_json)
parser = JsonOutputParser()

# Resolve refs in schema because they are not supported
Expand Down
42 changes: 26 additions & 16 deletions libs/vertexai/langchain_google_vertexai/functions_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,10 @@ def _format_json_schema_to_gapic(
return converted_schema


def _dict_to_gapic_schema(
def _dict_to_gapic_schema_utils(
schema: Dict[str, Any], pydantic_version: str = "v1"
) -> gapic.Schema:
) -> Dict[str, Any]:
"""Convert the schema to make gemini understand."""
# Resolve refs in schema because $refs and $defs are not supported
# by the Gemini API.
dereferenced_schema = dereference_refs(schema)
Expand All @@ -180,6 +181,13 @@ def _dict_to_gapic_schema(
formatted_schema = _format_json_schema_to_gapic_v1(dereferenced_schema)
else:
formatted_schema = _format_json_schema_to_gapic(dereferenced_schema)
return formatted_schema


def _dict_to_gapic_schema(
schema: Dict[str, Any], pydantic_version: str = "v1"
) -> gapic.Schema:
formatted_schema = _dict_to_gapic_schema_utils(schema, pydantic_version)
json_schema = json.dumps(formatted_schema)
return gapic.Schema.from_json(json_schema)

Expand Down Expand Up @@ -234,25 +242,27 @@ def _format_pydantic_to_function_declaration(
)


# Ensure we send "anyOf" parameters through pydantic v2 schema parsing
def _check_v2(parameters):
properties = parameters.get("properties", {}).values()
for property in properties:
if "anyOf" in property:
return True
if "parameters" in property:
if _check_v2(property["parameters"]):
return True
if "items" in property:
if _check_v2(property["items"]):
return True

return False


def _format_dict_to_function_declaration(
tool: Union[FunctionDescription, Dict[str, Any]],
) -> gapic.FunctionDeclaration:
pydantic_version_v2 = False

# Ensure we send "anyOf" parameters through pydantic v2 schema parsing
def _check_v2(parameters):
properties = parameters.get("properties", {}).values()
for property in properties:
if "anyOf" in property:
return True
if "parameters" in property:
if _check_v2(property["parameters"]):
return True
if "items" in property:
if _check_v2(property["items"]):
return True
return False

if isinstance(tool, dict):
pydantic_version_v2 = _check_v2(tool.get("parameters", {}))
if pydantic_version_v2:
Expand Down