diff --git a/libs/genai/langchain_google_genai/_function_utils.py b/libs/genai/langchain_google_genai/_function_utils.py index 598cf75a0..1d3ecbb07 100644 --- a/libs/genai/langchain_google_genai/_function_utils.py +++ b/libs/genai/langchain_google_genai/_function_utils.py @@ -283,36 +283,36 @@ def _format_base_tool_to_function_declaration( ) -def _convert_pydantic_to_genai_function( - pydantic_model: Type[BaseModel], - tool_name: Optional[str] = None, - tool_description: Optional[str] = None, -) -> gapic.FunctionDeclaration: - if issubclass(pydantic_model, BaseModel): - schema = pydantic_model.model_json_schema() - elif issubclass(pydantic_model, BaseModelV1): - schema = pydantic_model.schema() - else: - raise NotImplementedError( - f"pydantic_model must be a Pydantic BaseModel, got {pydantic_model}" - ) - schema = dereference_refs(schema) - schema.pop("definitions", None) - function_declaration = gapic.FunctionDeclaration( - name=tool_name if tool_name else schema.get("title"), - description=tool_description if tool_description else schema.get("description"), - parameters={ - "properties": _get_properties_from_schema_any( - schema.get("properties") - ), # TODO: use _dict_to_gapic_schema() if possible - # "items": _get_items_from_schema_any( - # schema - # ), # TODO: fix it https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/function-calling?hl#schema - "required": schema.get("required", []), - "type_": TYPE_ENUM[schema["type"]], - }, - ) - return function_declaration +def _convert_pydantic_to_genai_function(model: BaseModel, function_name: str = None) -> dict: + """ + Converts a Pydantic BaseModel into a dictionary representing a Gemini function tool. + + Args: + model: The Pydantic BaseModel defining the function's parameters. + function_name: Optional. The name of the function. If not provided, + the Pydantic model's class name will be used (converted to snake_case). + + Returns: + A dictionary formatted as a Gemini function tool. + + Notes: + This follows the Gemini function tool schema: + https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/function-calling#schema + """ + if function_name is None: + # Convert CamelCase Pydantic model name to snake_case function name + function_name = ''.join(['_' + i.lower() if i.isupper() else i for i in model.__name__]).lstrip('_') + + return { + "function_declarations": [ + { + "name": function_name, + "description": model.__doc__.strip() if model.__doc__ else "", + "parameters": model.model_json_schema(), + } + ] + } + def _get_properties_from_schema_any(schema: Any) -> Dict[str, Any]: diff --git a/libs/genai/tests/unit_tests/test_function_utils.py b/libs/genai/tests/unit_tests/test_function_utils.py index df9e552ed..08c995074 100644 --- a/libs/genai/tests/unit_tests/test_function_utils.py +++ b/libs/genai/tests/unit_tests/test_function_utils.py @@ -1404,3 +1404,133 @@ class ToolInfo(BaseModel): items_property = kind_property["items"] assert items_property["type_"] == glm.Type.STRING assert items_property["enum"] == ["foo", "bar"] + + +def test_convert_pydantic_to_genai_function() -> None: + """ + Test the convert_pydantic_to_genai_function function with various Pydantic models. + """ + from typing import Optional, List, Dict, Union + from pydantic import BaseModel, Field + + # Test 1: Simple model with basic fields + class SimpleModel(BaseModel): + """A simple test model.""" + name: str + age: int + active: bool = True + + result = _convert_pydantic_to_genai_function(SimpleModel) + + assert isinstance(result, dict), "Expected result to be a dict" + assert "function_declarations" in result, "Expected 'function_declarations' key" + assert len(result["function_declarations"]) == 1, "Expected one function declaration" + + fn_decl = result["function_declarations"][0] + assert fn_decl["name"] == "simple_model", "Expected snake_case name conversion" + assert fn_decl["description"] == "A simple test model.", "Expected correct description" + assert "parameters" in fn_decl, "Expected 'parameters' key" + + schema = fn_decl["parameters"] + assert "properties" in schema, "Expected 'properties' in schema" + assert "name" in schema["properties"], "Expected 'name' property" + assert "age" in schema["properties"], "Expected 'age' property" + assert "active" in schema["properties"], "Expected 'active' property" + + # Test 2: Custom function name + result_custom = _convert_pydantic_to_genai_function(SimpleModel, "custom_function") + fn_decl_custom = result_custom["function_declarations"][0] + assert fn_decl_custom["name"] == "custom_function", "Expected custom function name" + + # Test 3: Model with Field descriptions + class ModelWithFields(BaseModel): + """Model with field descriptions.""" + location: str = Field(..., description="The location name") + radius: Optional[float] = Field(None, description="Search radius in km") + + result_fields = _convert_pydantic_to_genai_function(ModelWithFields) + fn_decl_fields = result_fields["function_declarations"][0] + + assert fn_decl_fields["name"] == "model_with_fields", "Expected correct snake_case conversion" + assert "Model with field descriptions." in fn_decl_fields["description"], "Expected correct description" + + # Test 4: Model with complex types + class ComplexModel(BaseModel): + """Model with complex field types.""" + tags: List[str] = [] + metadata: Dict[str, str] = {} + settings: Union[str, Dict[str, str]] = "default" + + result_complex = _convert_pydantic_to_genai_function(ComplexModel) + fn_decl_complex = result_complex["function_declarations"][0] + + assert fn_decl_complex["name"] == "complex_model", "Expected correct name" + schema_complex = fn_decl_complex["parameters"] + assert "tags" in schema_complex["properties"], "Expected 'tags' property" + assert "metadata" in schema_complex["properties"], "Expected 'metadata' property" + assert "settings" in schema_complex["properties"], "Expected 'settings' property" + + # Test 5: Model without docstring + class NoDocModel(BaseModel): + value: int + + result_no_doc = _convert_pydantic_to_genai_function(NoDocModel) + fn_decl_no_doc = result_no_doc["function_declarations"][0] + assert fn_decl_no_doc["description"] == "", "Expected empty description for model without docstring" + + +def test_convert_pydantic_to_genai_function_with_nested_models() -> None: + """ + Test convert_pydantic_to_genai_function with nested Pydantic models. + """ + from typing import List + from pydantic import BaseModel + + class Address(BaseModel): + """Address information.""" + street: str + city: str + zip_code: str + + class Person(BaseModel): + """Person with address.""" + name: str + address: Address + addresses: List[Address] = [] + + result = _convert_pydantic_to_genai_function(Person) + + assert isinstance(result, dict), "Expected result to be a dict" + fn_decl = result["function_declarations"][0] + assert fn_decl["name"] == "person", "Expected correct function name" + + schema = fn_decl["parameters"] + assert "address" in schema["properties"], "Expected 'address' property" + assert "addresses" in schema["properties"], "Expected 'addresses' property" + + +def test_convert_pydantic_to_genai_function_integration() -> None: + """ + Integration test to ensure the function works with the existing convert_to_genai_function_declarations. + """ + from pydantic import BaseModel + + class SearchQuery(BaseModel): + """Search for information.""" + query: str + limit: int = 10 + + # Test that our function output can be used with existing infrastructure + result = _convert_pydantic_to_genai_function(SearchQuery) + + # This should work with the existing convert_to_genai_function_declarations function + try: + genai_tool = convert_to_genai_function_declarations([result]) + assert genai_tool is not None, "Expected successful conversion" + assert len(genai_tool.function_declarations) == 1, "Expected one function declaration" + + fn_decl = genai_tool.function_declarations[0] + assert fn_decl.name == "search_query", "Expected correct function name" + assert "Search for information." in fn_decl.description, "Expected correct description" + except Exception as e: + pytest.fail(f"Integration with existing functions failed: {e}")