diff --git a/libs/vertexai/langchain_google_vertexai/chat_models.py b/libs/vertexai/langchain_google_vertexai/chat_models.py index b391f8457..71463c2a7 100644 --- a/libs/vertexai/langchain_google_vertexai/chat_models.py +++ b/libs/vertexai/langchain_google_vertexai/chat_models.py @@ -145,7 +145,10 @@ from pydantic.v1 import BaseModel as BaseModelV1 from typing_extensions import Self, is_typeddict from difflib import get_close_matches - +from langchain_google_vertexai.functions_utils import ( + _dict_to_gapic_schema_utils, + _check_v2, +) logger = logging.getLogger(__name__) @@ -2226,8 +2229,13 @@ class Explanation(BaseModel): if isinstance(schema, type) and is_basemodel_subclass(schema): if issubclass(schema, BaseModelV1): schema_json = schema.schema() + pydantic_version = "v1" else: schema_json = schema.model_json_schema() + pydantic_version = "v2" + schema_json = _dict_to_gapic_schema_utils( + schema_json, pydantic_version=pydantic_version + ) parser = PydanticOutputParser(pydantic_object=schema) else: if is_typeddict(schema): @@ -2236,6 +2244,14 @@ class Explanation(BaseModel): schema_json = schema else: raise ValueError(f"Unsupported schema type {type(schema)}") + + pydantic_version_v2 = _check_v2(schema_json) + if pydantic_version_v2: + schema_json = _dict_to_gapic_schema_utils( + schema_json, pydantic_version="v2" + ) + else: + schema_json = _dict_to_gapic_schema_utils(schema_json) parser = JsonOutputParser() # Resolve refs in schema because they are not supported diff --git a/libs/vertexai/langchain_google_vertexai/functions_utils.py b/libs/vertexai/langchain_google_vertexai/functions_utils.py index 806f633d2..e9a4e08a2 100644 --- a/libs/vertexai/langchain_google_vertexai/functions_utils.py +++ b/libs/vertexai/langchain_google_vertexai/functions_utils.py @@ -169,9 +169,10 @@ def _format_json_schema_to_gapic( return converted_schema -def _dict_to_gapic_schema( +def _dict_to_gapic_schema_utils( schema: Dict[str, Any], pydantic_version: str = "v1" -) -> gapic.Schema: +) -> Dict[str, Any]: + """Convert the schema to make gemini understand.""" # Resolve refs in schema because $refs and $defs are not supported # by the Gemini API. dereferenced_schema = dereference_refs(schema) @@ -180,6 +181,13 @@ def _dict_to_gapic_schema( formatted_schema = _format_json_schema_to_gapic_v1(dereferenced_schema) else: formatted_schema = _format_json_schema_to_gapic(dereferenced_schema) + return formatted_schema + + +def _dict_to_gapic_schema( + schema: Dict[str, Any], pydantic_version: str = "v1" +) -> gapic.Schema: + formatted_schema = _dict_to_gapic_schema_utils(schema, pydantic_version) json_schema = json.dumps(formatted_schema) return gapic.Schema.from_json(json_schema) @@ -234,25 +242,27 @@ def _format_pydantic_to_function_declaration( ) +# Ensure we send "anyOf" parameters through pydantic v2 schema parsing +def _check_v2(parameters): + properties = parameters.get("properties", {}).values() + for property in properties: + if "anyOf" in property: + return True + if "parameters" in property: + if _check_v2(property["parameters"]): + return True + if "items" in property: + if _check_v2(property["items"]): + return True + + return False + + def _format_dict_to_function_declaration( tool: Union[FunctionDescription, Dict[str, Any]], ) -> gapic.FunctionDeclaration: pydantic_version_v2 = False - # Ensure we send "anyOf" parameters through pydantic v2 schema parsing - def _check_v2(parameters): - properties = parameters.get("properties", {}).values() - for property in properties: - if "anyOf" in property: - return True - if "parameters" in property: - if _check_v2(property["parameters"]): - return True - if "items" in property: - if _check_v2(property["items"]): - return True - return False - if isinstance(tool, dict): pydantic_version_v2 = _check_v2(tool.get("parameters", {})) if pydantic_version_v2: