Skip to content

⚡️ Speed up function tool_from_langchain by 64% #30

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 90 commits into
base: try-refinement
Choose a base branch
from

Conversation

codeflash-ai[bot]
Copy link

@codeflash-ai codeflash-ai bot commented Jul 22, 2025

📄 64% (0.64x) speedup for tool_from_langchain in pydantic_ai_slim/pydantic_ai/ext/langchain.py

⏱️ Runtime : 406 microseconds 248 microseconds (best of 158 runs)

📝 Explanation and details

Here is an optimized version of your program. The main runtime cost in your code is in the tool_from_langchain function, especially dictionary copying and sorting for the required fields and defaults. I've minimized unnecessary work, memory allocations, and dict operations by.

  • Avoiding a set for required names just to immediately sort it - instead, produce an ordered list once.
  • Building required, defaults in a single iteration.
  • Moved schema['required'] = required before setting 'additionalProperties' so overwrites don't happen.

Summary of high-impact changes:

  • Only one iteration through args to build both required/defaults.
  • No unnecessary dict copy when possible.
  • Efficient and minimal dictionary operations for merging kwargs/defaults.
  • No refactor of class design or signature, so full compatibility.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 8 Passed
🌀 Generated Regression Tests 29 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
⚙️ Existing Unit Tests and Runtime
Test File::Test Function Original ⏱️ Optimized ⏱️ Speedup
ext/test_langchain.py::test_langchain_tool_conversion 5.42μs 4.96μs ✅9.24%
ext/test_langchain.py::test_langchain_tool_conversion_no_defaults 5.42μs 5.00μs ✅8.32%
ext/test_langchain.py::test_langchain_tool_conversion_no_required 4.92μs 4.79μs ✅2.61%
ext/test_langchain.py::test_langchain_tool_default_override 5.79μs 5.25μs ✅10.3%
ext/test_langchain.py::test_langchain_tool_defaults 5.58μs 5.29μs ✅5.52%
ext/test_langchain.py::test_langchain_tool_no_additional_properties 6.00μs 5.12μs ✅17.1%
ext/test_langchain.py::test_langchain_tool_positional 5.12μs 4.92μs ✅4.23%
🌀 Generated Regression Tests and Runtime
from typing import Any, Dict, Protocol

# imports
import pytest  # used for our unit tests
from pydantic_ai.ext.langchain import tool_from_langchain


# Helper: Minimal mock object for LangChainTool
class MockLangChainTool:
    def __init__(self, name, description, args, input_jsonschema, run_return_func):
        self.name = name
        self.description = description
        self.args = args
        self._input_jsonschema = input_jsonschema
        self._run_return_func = run_return_func
        self.run_calls = []  # For inspection in tests

    def get_input_jsonschema(self):
        return self._input_jsonschema.copy()

    def run(self, kwargs):
        self.run_calls.append(kwargs.copy())
        return self._run_return_func(kwargs)

# ----------------- UNIT TESTS -----------------

# 1. Basic Test Cases

def test_basic_required_and_optional_args():
    """
    Test with a tool that has both required and optional arguments.
    """
    args = {
        'a': {},
        'b': {'default': 10},
    }
    schema = {
        'type': 'object',
        'properties': {
            'a': {'type': 'integer'},
            'b': {'type': 'integer'},
        }
    }
    def run_fn(kwargs):
        return f"{kwargs['a']},{kwargs['b']}"
    tool = MockLangChainTool(
        name='mytool',
        description='desc',
        args=args,
        input_jsonschema=schema,
        run_return_func=run_fn
    )
    codeflash_output = tool_from_langchain(tool); t = codeflash_output # 5.42μs -> 5.25μs (3.16% faster)

    # Proxy should call run with defaults merged
    result = t.function(a=5)

    # Override default
    result2 = t.function(a=1, b=2)

def test_basic_no_optional_args():
    """
    Tool with all required arguments.
    """
    args = {'x': {}, 'y': {}}
    schema = {'type': 'object', 'properties': {'x': {'type': 'string'}, 'y': {'type': 'string'}}}
    def run_fn(kwargs): return kwargs['x'] + kwargs['y']
    tool = MockLangChainTool('t', 'd', args, schema, run_fn)
    codeflash_output = tool_from_langchain(tool); t = codeflash_output # 5.17μs -> 4.75μs (8.78% faster)
    res = t.function(x='a', y='b')

def test_basic_no_required_args():
    """
    Tool with all optional arguments.
    """
    args = {'foo': {'default': 1}, 'bar': {'default': 2}}
    schema = {'type': 'object', 'properties': {'foo': {'type': 'integer'}, 'bar': {'type': 'integer'}}}
    def run_fn(kwargs): return kwargs['foo'] + kwargs['bar']
    tool = MockLangChainTool('opt', 'desc', args, schema, run_fn)
    codeflash_output = tool_from_langchain(tool); t = codeflash_output # 5.04μs -> 4.67μs (8.04% faster)

def test_basic_schema_preserves_properties():
    """
    The returned schema should not mutate the input schema dict.
    """
    args = {'a': {}}
    orig_schema = {'type': 'object', 'properties': {'a': {'type': 'integer'}}}
    tool = MockLangChainTool('t', 'd', args, orig_schema, lambda k: "ok")
    codeflash_output = tool_from_langchain(tool); t = codeflash_output # 4.62μs -> 4.46μs (3.75% faster)

def test_basic_proxy_asserts_no_args():
    """
    The proxy should assert that no positional args are passed.
    """
    args = {'a': {}}
    schema = {'type': 'object', 'properties': {'a': {'type': 'integer'}}}
    tool = MockLangChainTool('t', 'd', args, schema, lambda k: "ok")
    codeflash_output = tool_from_langchain(tool); t = codeflash_output # 4.79μs -> 4.38μs (9.53% faster)
    with pytest.raises(AssertionError):
        t.function(1, a=2)

# 2. Edge Test Cases

def test_edge_empty_args_and_schema():
    """
    Tool with no arguments at all.
    """
    args = {}
    schema = {'type': 'object', 'properties': {}}
    tool = MockLangChainTool('empty', 'desc', args, schema, lambda k: "empty")
    codeflash_output = tool_from_langchain(tool); t = codeflash_output # 4.62μs -> 4.12μs (12.1% faster)
    res = t.function()

def test_edge_input_schema_with_additional_properties_true():
    """
    If input schema already has 'additionalProperties', do not overwrite.
    """
    args = {'x': {}}
    schema = {'type': 'object', 'properties': {'x': {'type': 'string'}}, 'additionalProperties': True}
    tool = MockLangChainTool('t', 'd', args, schema, lambda k: "ok")
    codeflash_output = tool_from_langchain(tool); t = codeflash_output # 4.88μs -> 4.29μs (13.6% faster)

def test_edge_args_with_weird_names():
    """
    Arguments with names that are Python keywords or odd strings.
    """
    args = {'class': {}, 'def': {'default': 5}}
    schema = {'type': 'object', 'properties': {'class': {'type': 'integer'}, 'def': {'type': 'integer'}}}
    def run_fn(kwargs): return kwargs['class'] + kwargs['def']
    tool = MockLangChainTool('kw', 'desc', args, schema, run_fn)
    codeflash_output = tool_from_langchain(tool); t = codeflash_output # 4.96μs -> 4.38μs (13.3% faster)

def test_edge_default_value_is_none():
    """
    Optional argument with default=None should be included in defaults.
    """
    args = {'foo': {'default': None}}
    schema = {'type': 'object', 'properties': {'foo': {'type': 'integer'}}}
    def run_fn(kwargs): return kwargs['foo']
    tool = MockLangChainTool('none', 'desc', args, schema, run_fn)
    codeflash_output = tool_from_langchain(tool); t = codeflash_output # 4.83μs -> 4.25μs (13.7% faster)

def test_edge_required_sorting():
    """
    Required should be sorted alphabetically.
    """
    args = {'b': {}, 'a': {}, 'c': {'default': 1}}
    schema = {'type': 'object', 'properties': {'a': {}, 'b': {}, 'c': {}}}
    tool = MockLangChainTool('sort', 'desc', args, schema, lambda k: "ok")
    codeflash_output = tool_from_langchain(tool); t = codeflash_output # 5.12μs -> 4.42μs (16.0% faster)


def test_edge_run_kwargs_are_copied():
    """
    The proxy should not mutate the input kwargs dict.
    """
    args = {'x': {}, 'y': {'default': 2}}
    schema = {'type': 'object', 'properties': {'x': {}, 'y': {}}}
    captured = []
    def run_fn(kwargs):
        captured.append(kwargs.copy())
        return kwargs['x'] + kwargs['y']
    tool = MockLangChainTool('t', 'd', args, schema, run_fn)
    codeflash_output = tool_from_langchain(tool); t = codeflash_output # 5.83μs -> 5.04μs (15.7% faster)
    kw = {'x': 1}
    t.function(**kw)

def test_edge_tool_with_many_defaults_and_required():
    """
    Tool with many required and optional args, check correct required and defaults.
    """
    args = {f'k{i}': {} if i % 2 == 0 else {'default': i} for i in range(10)}
    schema = {'type': 'object', 'properties': {f'k{i}': {'type': 'integer'} for i in range(10)}}
    def run_fn(kwargs): return sum(kwargs.values())
    tool = MockLangChainTool('many', 'desc', args, schema, run_fn)
    codeflash_output = tool_from_langchain(tool); t = codeflash_output # 6.50μs -> 5.46μs (19.1% faster)
    required = sorted(f'k{i}' for i in range(10) if i % 2 == 0)
    # Provide only required args
    call_args = {k: 1 for k in required}

def test_edge_tool_with_non_str_keys_in_args():
    """
    If args dict has non-str keys, should still work (though not expected in real use).
    """
    args = {1: {}, 'x': {'default': 5}}
    schema = {'type': 'object', 'properties': {1: {'type': 'integer'}, 'x': {'type': 'integer'}}}
    def run_fn(kwargs): return kwargs.get(1, 0) + kwargs.get('x', 0)
    tool = MockLangChainTool('weird', 'desc', args, schema, run_fn)
    codeflash_output = tool_from_langchain(tool); t = codeflash_output # 5.71μs -> 4.71μs (21.2% faster)

# 3. Large Scale Test Cases

def test_large_many_args_and_defaults():
    """
    Tool with 500 arguments, half required, half optional.
    """
    N = 500
    args = {f'k{i}': {} if i % 2 == 0 else {'default': i} for i in range(N)}
    schema = {'type': 'object', 'properties': {f'k{i}': {'type': 'integer'} for i in range(N)}}
    def run_fn(kwargs): return sum(kwargs.values())
    tool = MockLangChainTool('bigtool', 'desc', args, schema, run_fn)
    codeflash_output = tool_from_langchain(tool); t = codeflash_output # 53.2μs -> 27.7μs (92.5% faster)
    required = sorted(f'k{i}' for i in range(N) if i % 2 == 0)
    # Provide only required args
    call_args = {k: 1 for k in required}
    result = t.function(**call_args)
    # sum: N/2 times 1 + sum of odd numbers from 1 to N-1
    odd_sum = sum(i for i in range(1, N, 2))


def test_large_tool_with_long_arg_names():
    """
    Tool with very long argument names.
    """
    long_name = 'x' * 200
    args = {long_name: {}, 'y': {'default': 3}}
    schema = {'type': 'object', 'properties': {long_name: {'type': 'integer'}, 'y': {'type': 'integer'}}}
    def run_fn(kwargs): return kwargs[long_name] + kwargs['y']
    tool = MockLangChainTool('long', 'desc', args, schema, run_fn)
    codeflash_output = tool_from_langchain(tool); t = codeflash_output # 6.12μs -> 5.04μs (21.5% faster)

def test_large_tool_with_max_properties():
    """
    Tool with 999 arguments, all required.
    """
    N = 999
    args = {f'k{i}': {} for i in range(N)}
    schema = {'type': 'object', 'properties': {f'k{i}': {'type': 'integer'} for i in range(N)}}
    def run_fn(kwargs): return sum(kwargs.values())
    tool = MockLangChainTool('huge', 'desc', args, schema, run_fn)
    codeflash_output = tool_from_langchain(tool); t = codeflash_output # 144μs -> 38.9μs (271% faster)
    call_args = {f'k{i}': 1 for i in range(N)}
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

from dataclasses import dataclass, field
from typing import (Any, Callable, Concatenate, Dict, Generic, Literal,
                    ParamSpec, Protocol, TypeVar, Union)

# imports
import pytest  # used for our unit tests
# function to test and dependencies (as provided above)
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
from pydantic_ai.ext.langchain import tool_from_langchain
from pydantic_core import SchemaValidator, core_schema
from typing_extensions import Self, TypeAlias


# --- Test helpers ---
class DummyLangChainTool:
    """A dummy LangChainTool for testing."""

    def __init__(self, name, description, args, input_schema, run_func):
        self.name = name
        self.description = description
        self.args = args  # dict of arg name -> {type, default?}
        self._input_schema = input_schema
        self._run_func = run_func

    def get_input_jsonschema(self):
        return dict(self._input_schema)  # defensive copy

    def run(self, kwargs):
        return self._run_func(kwargs)

# --- Unit Tests ---

# 1. BASIC TEST CASES

def test_basic_tool_with_required_and_optional_args():
    """Test a tool with both required and optional arguments."""
    args = {
        "x": {"type": "integer"},
        "y": {"type": "integer", "default": 10}
    }
    input_schema = {
        "type": "object",
        "properties": {
            "x": {"type": "integer"},
            "y": {"type": "integer"}
        }
    }
    # run function returns sum as string
    def run_func(kwargs):
        return str(kwargs["x"] + kwargs.get("y", 10))

    tool = DummyLangChainTool(
        name="add",
        description="Add two numbers",
        args=args,
        input_schema=input_schema,
        run_func=run_func
    )
    codeflash_output = tool_from_langchain(tool); ptool = codeflash_output # 5.62μs -> 4.88μs (15.4% faster)
    # Check function
    result = ptool.function(x=3, y=7)
    # y is optional
    result2 = ptool.function(x=5)

def test_basic_tool_with_only_required_args():
    """Test a tool with only required arguments."""
    args = {
        "a": {"type": "string"},
        "b": {"type": "string"}
    }
    input_schema = {
        "type": "object",
        "properties": {
            "a": {"type": "string"},
            "b": {"type": "string"}
        }
    }
    def run_func(kwargs):
        return kwargs["a"] + kwargs["b"]
    tool = DummyLangChainTool(
        name="concat",
        description="Concatenate two strings",
        args=args,
        input_schema=input_schema,
        run_func=run_func
    )
    codeflash_output = tool_from_langchain(tool); ptool = codeflash_output # 5.25μs -> 4.75μs (10.5% faster)

def test_basic_tool_with_only_optional_args():
    """Test a tool with only optional arguments (all have defaults)."""
    args = {
        "flag": {"type": "boolean", "default": True},
        "count": {"type": "integer", "default": 2}
    }
    input_schema = {
        "type": "object",
        "properties": {
            "flag": {"type": "boolean"},
            "count": {"type": "integer"}
        }
    }
    def run_func(kwargs):
        return f"{kwargs['flag']}-{kwargs['count']}"
    tool = DummyLangChainTool(
        name="opt",
        description="All optional",
        args=args,
        input_schema=input_schema,
        run_func=run_func
    )
    codeflash_output = tool_from_langchain(tool); ptool = codeflash_output # 5.12μs -> 4.67μs (9.81% faster)

def test_basic_tool_with_no_args():
    """Test a tool with no arguments at all."""
    args = {}
    input_schema = {
        "type": "object",
        "properties": {}
    }
    def run_func(kwargs):
        return "no-args"
    tool = DummyLangChainTool(
        name="noargs",
        description="No arguments",
        args=args,
        input_schema=input_schema,
        run_func=run_func
    )
    codeflash_output = tool_from_langchain(tool); ptool = codeflash_output # 4.92μs -> 4.38μs (12.4% faster)

def test_tool_proxy_asserts_on_args():
    """Test that the proxy function asserts if called with positional args."""
    args = {"x": {"type": "integer"}}
    input_schema = {"type": "object", "properties": {"x": {"type": "integer"}}}
    def run_func(kwargs): return str(kwargs["x"])
    tool = DummyLangChainTool(
        name="foo",
        description="desc",
        args=args,
        input_schema=input_schema,
        run_func=run_func
    )
    codeflash_output = tool_from_langchain(tool); ptool = codeflash_output # 5.33μs -> 4.71μs (13.3% faster)
    with pytest.raises(AssertionError):
        ptool.function(1)

# 2. EDGE TEST CASES

def test_tool_with_empty_required_list():
    """Test that required is omitted or empty if no required args."""
    args = {"x": {"type": "integer", "default": 1}}
    input_schema = {"type": "object", "properties": {"x": {"type": "integer"}}}
    def run_func(kwargs): return str(kwargs.get("x"))
    tool = DummyLangChainTool(
        name="defonly",
        description="defaults only",
        args=args,
        input_schema=input_schema,
        run_func=run_func
    )
    codeflash_output = tool_from_langchain(tool); ptool = codeflash_output # 4.83μs -> 4.33μs (11.5% faster)

def test_tool_with_additional_properties_in_schema():
    """Test that if additionalProperties is already set, it is not overwritten."""
    args = {"x": {"type": "integer"}}
    input_schema = {
        "type": "object",
        "properties": {"x": {"type": "integer"}},
        "additionalProperties": True
    }
    def run_func(kwargs): return str(kwargs["x"])
    tool = DummyLangChainTool(
        name="aprop",
        description="desc",
        args=args,
        input_schema=input_schema,
        run_func=run_func
    )
    codeflash_output = tool_from_langchain(tool); ptool = codeflash_output # 5.21μs -> 4.54μs (14.7% faster)

def test_tool_with_unusual_arg_names_and_types():
    """Test support for unusual argument names and types."""
    args = {
        "spaced name": {"type": "string"},
        "unicode_ß": {"type": "number", "default": 3.14}
    }
    input_schema = {
        "type": "object",
        "properties": {
            "spaced name": {"type": "string"},
            "unicode_ß": {"type": "number"}
        }
    }
    def run_func(kwargs):
        return f"{kwargs['spaced name']}-{kwargs['unicode_ß']}"
    tool = DummyLangChainTool(
        name="weird",
        description="strange names",
        args=args,
        input_schema=input_schema,
        run_func=run_func
    )
    codeflash_output = tool_from_langchain(tool); ptool = codeflash_output # 5.12μs -> 4.50μs (13.9% faster)

def test_tool_with_large_default_value():
    """Test a tool with a large default value (long string)."""
    big_str = "x" * 512
    args = {"msg": {"type": "string", "default": big_str}}
    input_schema = {"type": "object", "properties": {"msg": {"type": "string"}}}
    def run_func(kwargs): return kwargs["msg"]
    tool = DummyLangChainTool(
        name="bigdefault",
        description="Large default",
        args=args,
        input_schema=input_schema,
        run_func=run_func
    )
    codeflash_output = tool_from_langchain(tool); ptool = codeflash_output # 4.79μs -> 4.33μs (10.6% faster)

def test_tool_with_missing_properties_in_schema():
    """Test tool where input_schema lacks properties key."""
    args = {"foo": {"type": "integer"}}
    input_schema = {"type": "object"}
    def run_func(kwargs): return str(kwargs["foo"])
    tool = DummyLangChainTool(
        name="noproperties",
        description="no properties",
        args=args,
        input_schema=input_schema,
        run_func=run_func
    )
    codeflash_output = tool_from_langchain(tool); ptool = codeflash_output # 4.83μs -> 4.21μs (14.8% faster)


def test_tool_with_many_args_and_defaults():
    """Test a tool with a large number of arguments and defaults."""
    N = 500
    args = {f"arg{i}": {"type": "integer", "default": i} for i in range(N)}
    input_schema = {
        "type": "object",
        "properties": {f"arg{i}": {"type": "integer"} for i in range(N)}
    }
    def run_func(kwargs):
        # sum all args
        return str(sum(kwargs[f"arg{i}"] for i in range(N)))
    tool = DummyLangChainTool(
        name="bigtool",
        description="Lots of args",
        args=args,
        input_schema=input_schema,
        run_func=run_func
    )
    codeflash_output = tool_from_langchain(tool); ptool = codeflash_output # 34.4μs -> 25.7μs (33.9% faster)
    # Override some
    result = ptool.function(**{f"arg{i}": 1000 for i in range(10)})
    expected = sum([1000]*10 + list(range(10, N)))


def test_tool_with_large_schema_properties():
    """Test a tool with a large schema 'properties' dict, but few args."""
    # The schema has many unused properties
    args = {"foo": {"type": "integer"}}
    input_schema = {
        "type": "object",
        "properties": {f"p{i}": {"type": "integer"} for i in range(1000)}
    }
    input_schema["properties"]["foo"] = {"type": "integer"}
    def run_func(kwargs): return str(kwargs["foo"])
    tool = DummyLangChainTool(
        name="largeschema",
        description="Big schema",
        args=args,
        input_schema=input_schema,
        run_func=run_func
    )
    codeflash_output = tool_from_langchain(tool); ptool = codeflash_output # 5.96μs -> 5.62μs (5.92% faster)

def test_tool_with_large_string_args():
    """Test performance with large string argument values."""
    args = {"text": {"type": "string"}}
    input_schema = {"type": "object", "properties": {"text": {"type": "string"}}}
    def run_func(kwargs): return str(len(kwargs["text"]))
    tool = DummyLangChainTool(
        name="strlen",
        description="String length",
        args=args,
        input_schema=input_schema,
        run_func=run_func
    )
    codeflash_output = tool_from_langchain(tool); ptool = codeflash_output # 5.08μs -> 4.54μs (11.9% faster)
    big_text = "z" * 1000
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

from pydantic_ai.ext.langchain import tool_from_langchain

To edit these changes git checkout codeflash/optimize-tool_from_langchain-mdexqq2r and push.

Codeflash

fswair and others added 30 commits July 10, 2025 15:51
Co-authored-by: Marcelo Trylesinski <[email protected]>
Co-authored-by: burtenshaw <[email protected]>
…pydantic#2196)

Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com>
…2198)

Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com>
joshualipman123 and others added 29 commits July 24, 2025 13:31
… with output tools (pydantic#2314)

Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
Co-authored-by: Douwe Maan <[email protected]>
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
Co-authored-by: Douwe Maan <[email protected]>
Signed-off-by: Saurabh Misra <[email protected]>
Copy link
Author

codeflash-ai bot commented Jul 30, 2025

This PR is now faster! 🚀 Saurabh Misra accepted my code suggestion above.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
⚡️ codeflash Optimization PR opened by Codeflash AI
Projects
None yet
Development

Successfully merging this pull request may close these issues.