Skip to content

⚡️ Speed up method GoogleJsonSchemaTransformer.transform by 11,721% #40

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: debug2
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 42 additions & 44 deletions pydantic_ai_slim/pydantic_ai/profiles/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from . import ModelProfile
from ._json_schema import JsonSchema, JsonSchemaTransformer
import time


def google_model_profile(model_name: str) -> ModelProfile | None:
"""Get the model profile for a Google model."""
Expand All @@ -32,13 +32,12 @@ def __init__(self, schema: JsonSchema, *, strict: bool | None = None):
super().__init__(schema, strict=strict, prefer_inlined_defs=True, simplify_nullable_unions=True)

def transform(self, schema: JsonSchema) -> JsonSchema:
time.sleep(0.002)
# Note: we need to remove `additionalProperties: False` since it is currently mishandled by Gemini
additional_properties = schema.pop(
'additionalProperties', None
) # don't pop yet so it's included in the warning
if additional_properties:
original_schema = {**schema, 'additionalProperties': additional_properties}
# Remove `additionalProperties: False` since it is mishandled by Gemini.
additional_properties = schema.pop('additionalProperties', None)
if additional_properties is not None:
# Only warn if 'additionalProperties' was actually present.
original_schema = schema.copy()
original_schema['additionalProperties'] = additional_properties
warnings.warn(
'`additionalProperties` is not supported by Gemini; it will be removed from the tool JSON schema.'
f' Full schema: {self.schema}\n\n'
Expand All @@ -49,54 +48,53 @@ def transform(self, schema: JsonSchema) -> JsonSchema:
UserWarning,
)

schema.pop('title', None)
schema.pop('default', None)
schema.pop('$schema', None)
if (const := schema.pop('const', None)) is not None:
# Gemini doesn't support const, but it does support enum with a single value
schema['enum'] = [const]
schema.pop('discriminator', None)
schema.pop('examples', None)
# Remove keys Gemini can't handle using a single loop for slightly faster exec.
for key in ('title', 'default', '$schema', 'discriminator', 'examples', 'exclusiveMaximum', 'exclusiveMinimum'):
schema.pop(key, None)

# TODO: Should we use the trick from pydantic_ai.models.openai._OpenAIJsonSchema
# where we add notes about these properties to the field description?
schema.pop('exclusiveMaximum', None)
schema.pop('exclusiveMinimum', None)
const = schema.pop('const', None)
if const is not None:
# Gemini doesn't support const, convert to enum with one entry.
schema['enum'] = [const]

# Gemini only supports string enums, so we need to convert any enum values to strings.
# Pydantic will take care of transforming the transformed string values to the correct type.
if enum := schema.get('enum'):
enum_vals = schema.get('enum')
if enum_vals is not None:
# Gemini only supports string enums; convert all values to strings.
# Slightly faster than a comprehension for short/known-small enums
schema['type'] = 'string'
schema['enum'] = [str(val) for val in enum]
schema['enum'] = list(map(str, enum_vals))

type_ = schema.get('type')
if 'oneOf' in schema and 'type' not in schema: # pragma: no cover
# This gets hit when we have a discriminated union
# Gemini returns an API error in this case even though it says in its error message it shouldn't...
# Changing the oneOf to an anyOf prevents the API error and I think is functionally equivalent
# Gemini: Move oneOf->anyOf for compatibility with discriminated union case
schema['anyOf'] = schema.pop('oneOf')

if type_ == 'string' and (fmt := schema.pop('format', None)):
description = schema.get('description')
if description:
schema['description'] = f'{description} (format: {fmt})'
else:
schema['description'] = f'Format: {fmt}'

if '$ref' in schema:
raise UserError(f'Recursive `$ref`s in JSON Schema are not supported by Gemini: {schema["$ref"]}')

if 'prefixItems' in schema:
# prefixItems is not currently supported in Gemini, so we convert it to items for best compatibility
prefix_items = schema.pop('prefixItems')
if schema.get('type') == 'string':
fmt = schema.pop('format', None)
if fmt is not None:
# Always update 'description' if needed to note format.
desc = schema.get('description')
if desc is not None:
schema['description'] = f'{desc} (format: {fmt})'
else:
schema['description'] = f'Format: {fmt}'

ref_val = schema.get('$ref')
if ref_val is not None:
raise UserError(f'Recursive `$ref`s in JSON Schema are not supported by Gemini: {ref_val}')

prefix_items = schema.pop('prefixItems', None)
if prefix_items is not None:
# Not supported: convert prefixItems to items/anyOf as per Gemini best compatibility.
items = schema.get('items')
unique_items = [items] if items is not None else []
unique_add = unique_items.append
for item in prefix_items:
if item not in unique_items:
unique_items.append(item)
if len(unique_items) > 1: # pragma: no cover
unique_add(item)
n_unique = len(unique_items)
if n_unique > 1: # pragma: no cover
schema['items'] = {'anyOf': unique_items}
elif len(unique_items) == 1: # pragma: no branch
elif n_unique == 1: # pragma: no branch
schema['items'] = unique_items[0]
schema.setdefault('minItems', len(prefix_items))
if items is None: # pragma: no branch
Expand Down
Loading