Skip to content

Commit 5ed6b74

Browse files
wanlin31copybara-github
authored andcommitted
fix: disable AFC when there are AFC incompatible tool presented.
PiperOrigin-RevId: 812201451
1 parent 5c4d7ee commit 5ed6b74

File tree

4 files changed

+428
-0
lines changed

4 files changed

+428
-0
lines changed

google/genai/_extra_utils.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,33 @@ def format_destination(
115115
return config
116116

117117

118+
def find_afc_incompatible_tool_indexes(
119+
config: Optional[types.GenerateContentConfigOrDict] = None,
120+
) -> list[int]:
121+
"""Checks if the config contains any AFC incompatible tools.
122+
123+
A `types.Tool` object that contains `function_declarations` is considered a
124+
non-AFC tool for this execution path.
125+
126+
Returns:
127+
True if any tool is a `types.Tool` with function declarations,
128+
False otherwise.
129+
"""
130+
if not config:
131+
return []
132+
config_model = _create_generate_content_config_model(config)
133+
incompatible_tools_indexes: list[int] = []
134+
135+
if not config_model or not config_model.tools:
136+
return incompatible_tools_indexes
137+
138+
for index, tool in enumerate(config_model.tools):
139+
if isinstance(tool, types.Tool) and tool.function_declarations:
140+
incompatible_tools_indexes.append(index)
141+
142+
return incompatible_tools_indexes
143+
144+
118145
def get_function_map(
119146
config: Optional[types.GenerateContentConfigOrDict] = None,
120147
mcp_to_genai_tool_adapters: Optional[

google/genai/models.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6439,6 +6439,9 @@ def generate_content(
64396439
# scones.
64406440
"""
64416441

6442+
incompatible_tools_indexes = (
6443+
_extra_utils.find_afc_incompatible_tool_indexes(config)
6444+
)
64426445
parsed_config = _extra_utils.parse_config_for_mcp_usage(config)
64436446
if (
64446447
parsed_config
@@ -6452,6 +6455,26 @@ def generate_content(
64526455
return self._generate_content(
64536456
model=model, contents=contents, config=parsed_config
64546457
)
6458+
if incompatible_tools_indexes:
6459+
original_tools_length = 0
6460+
if isinstance(config, types.GenerateContentConfig):
6461+
if config.tools:
6462+
original_tools_length = len(config.tools)
6463+
elif isinstance(config, dict):
6464+
tools = config.get('tools', [])
6465+
if tools:
6466+
original_tools_length = len(tools)
6467+
if len(incompatible_tools_indexes) != original_tools_length:
6468+
indices_str = ', '.join(map(str, incompatible_tools_indexes))
6469+
logger.warning(
6470+
'Tools at indices [%s] are not compatible with automatic function '
6471+
'calling. AFC will be disabled.',
6472+
indices_str,
6473+
)
6474+
return self._generate_content(
6475+
model=model, contents=contents, config=parsed_config
6476+
)
6477+
64556478
remaining_remote_calls_afc = _extra_utils.get_max_remote_calls_afc(
64566479
parsed_config
64576480
)
@@ -6575,6 +6598,9 @@ def generate_content_stream(
65756598
# scones.
65766599
"""
65776600

6601+
incompatible_tools_indexes = (
6602+
_extra_utils.find_afc_incompatible_tool_indexes(config)
6603+
)
65786604
parsed_config = _extra_utils.parse_config_for_mcp_usage(config)
65796605
if (
65806606
parsed_config
@@ -6590,6 +6616,27 @@ def generate_content_stream(
65906616
)
65916617
return
65926618

6619+
if incompatible_tools_indexes:
6620+
original_tools_length = 0
6621+
if isinstance(config, types.GenerateContentConfig):
6622+
if config.tools:
6623+
original_tools_length = len(config.tools)
6624+
elif isinstance(config, dict):
6625+
tools = config.get('tools', [])
6626+
if tools:
6627+
original_tools_length = len(tools)
6628+
if len(incompatible_tools_indexes) != original_tools_length:
6629+
indices_str = ', '.join(map(str, incompatible_tools_indexes))
6630+
logger.warning(
6631+
'Tools at indices [%s] are not compatible with automatic function '
6632+
'calling. AFC will be disabled.',
6633+
indices_str,
6634+
)
6635+
yield from self._generate_content_stream(
6636+
model=model, contents=contents, config=parsed_config
6637+
)
6638+
return
6639+
65936640
remaining_remote_calls_afc = _extra_utils.get_max_remote_calls_afc(
65946641
parsed_config
65956642
)
@@ -8172,13 +8219,35 @@ async def generate_content(
81728219
# J'aime les bagels.
81738220
"""
81748221
# Retrieve and cache any MCP sessions if provided.
8222+
incompatible_tools_indexes = (
8223+
_extra_utils.find_afc_incompatible_tool_indexes(config)
8224+
)
81758225
parsed_config, mcp_to_genai_tool_adapters = (
81768226
await _extra_utils.parse_config_for_mcp_sessions(config)
81778227
)
81788228
if _extra_utils.should_disable_afc(parsed_config):
81798229
return await self._generate_content(
81808230
model=model, contents=contents, config=parsed_config
81818231
)
8232+
if incompatible_tools_indexes:
8233+
original_tools_length = 0
8234+
if isinstance(config, types.GenerateContentConfig):
8235+
if config.tools:
8236+
original_tools_length = len(config.tools)
8237+
elif isinstance(config, dict):
8238+
tools = config.get('tools', [])
8239+
if tools:
8240+
original_tools_length = len(tools)
8241+
if len(incompatible_tools_indexes) != original_tools_length:
8242+
indices_str = ', '.join(map(str, incompatible_tools_indexes))
8243+
logger.warning(
8244+
'Tools at indices [%s] are not compatible with automatic function '
8245+
'calling. AFC will be disabled.',
8246+
indices_str,
8247+
)
8248+
return await self._generate_content(
8249+
model=model, contents=contents, config=parsed_config
8250+
)
81828251
remaining_remote_calls_afc = _extra_utils.get_max_remote_calls_afc(
81838252
parsed_config
81848253
)
@@ -8303,6 +8372,10 @@ async def generate_content_stream(
83038372
# scones.
83048373
"""
83058374

8375+
# Retrieve and cache any MCP sessions if provided.
8376+
incompatible_tools_indexes = (
8377+
_extra_utils.find_afc_incompatible_tool_indexes(config)
8378+
)
83068379
# Retrieve and cache any MCP sessions if provided.
83078380
parsed_config, mcp_to_genai_tool_adapters = (
83088381
await _extra_utils.parse_config_for_mcp_sessions(config)
@@ -8318,6 +8391,32 @@ async def base_async_generator(model, contents, config): # type: ignore[no-unty
83188391

83198392
return base_async_generator(model, contents, parsed_config) # type: ignore[no-untyped-call, no-any-return]
83208393

8394+
if incompatible_tools_indexes:
8395+
original_tools_length = 0
8396+
if isinstance(config, types.GenerateContentConfig):
8397+
if config.tools:
8398+
original_tools_length = len(config.tools)
8399+
elif isinstance(config, dict):
8400+
tools = config.get('tools', [])
8401+
if tools:
8402+
original_tools_length = len(tools)
8403+
if len(incompatible_tools_indexes) != original_tools_length:
8404+
indices_str = ', '.join(map(str, incompatible_tools_indexes))
8405+
logger.warning(
8406+
'Tools at indices [%s] are not compatible with automatic function '
8407+
'calling. AFC will be disabled.',
8408+
indices_str,
8409+
)
8410+
response = await self._generate_content_stream(
8411+
model=model, contents=contents, config=parsed_config
8412+
)
8413+
8414+
async def base_async_generator(model, contents, config): # type: ignore[no-untyped-def]
8415+
async for chunk in response: # type: ignore[attr-defined]
8416+
yield chunk
8417+
8418+
return base_async_generator(model, contents, parsed_config) # type: ignore[no-untyped-call, no-any-return]
8419+
83218420
async def async_generator(model, contents, config): # type: ignore[no-untyped-def]
83228421
remaining_remote_calls_afc = _extra_utils.get_max_remote_calls_afc(config)
83238422
logger.info(

0 commit comments

Comments
 (0)