Skip to content

Commit 134ec0d

Browse files
seanzhougooglecopybara-github
authored andcommitted
fix: Fix the long running function response event merge logic
1) raise explicit error if the response event contains responses against multiple function call events 2) merge all function responses for the corresponding function call event PiperOrigin-RevId: 782154577
1 parent a8fcc1b commit 134ec0d

File tree

3 files changed

+194
-12
lines changed

3 files changed

+194
-12
lines changed

src/google/adk/flows/llm_flows/contents.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -157,12 +157,21 @@ def _rearrange_events_for_latest_function_response(
157157
for function_call in function_calls:
158158
if function_call.id in function_responses_ids:
159159
function_call_event_idx = idx
160-
break
161-
if function_call_event_idx != -1:
162-
# in case the last response event only have part of the responses
163-
# for the function calls in the function call event
164-
for function_call in function_calls:
165-
function_responses_ids.add(function_call.id)
160+
function_call_ids = {
161+
function_call.id for function_call in function_calls
162+
}
163+
# last response event should only contain the responses for the
164+
# function calls in the same function call event
165+
if not function_responses_ids.issubset(function_call_ids):
166+
raise ValueError(
167+
'Last response event should only contain the responses for the'
168+
' function calls in the same function call event. Function'
169+
f' call ids found : {function_call_ids}, function response'
170+
f' ids provided: {function_responses_ids}'
171+
)
172+
# collect all function responses from the function call event to
173+
# the last response event
174+
function_responses_ids = function_call_ids
166175
break
167176

168177
if function_call_event_idx == -1:
@@ -363,10 +372,7 @@ def _merge_function_response_events(
363372
list is in increasing order of timestamp; 2. the first event is the
364373
initial function_response event; 3. all later events should contain at
365374
least one function_response part that related to the function_call
366-
event. (Note, 3. may not be true when aync function return some
367-
intermediate response, there could also be some intermediate model
368-
response event without any function_response and such event will be
369-
ignored.)
375+
event.
370376
Caveat: This implementation doesn't support when a parallel function_call
371377
event contains async function_call of the same name.
372378

tests/unittests/flows/llm_flows/test_contents.py

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,3 +359,179 @@ def test_rearrange_events_for_latest_function_response():
359359
# Should remove intermediate events and merge responses
360360
assert len(rearranged) == 2
361361
assert rearranged[0] == call_event
362+
assert rearranged[1] == response_event
363+
364+
365+
def test_rearrange_events_for_latest_function_response_multiple_calls():
366+
"""Test _rearrange_events_for_latest_function_response with multiple function calls."""
367+
# Create function call event with multiple calls
368+
function_call1 = types.FunctionCall(
369+
id="func_123", name="test_function", args={"param": "value1"}
370+
)
371+
function_call2 = types.FunctionCall(
372+
id="func_456", name="test_function2", args={"param": "value2"}
373+
)
374+
375+
call_event = Event(
376+
invocation_id="test_inv1",
377+
author="agent",
378+
content=types.Content(
379+
role="model",
380+
parts=[
381+
types.Part(function_call=function_call1),
382+
types.Part(function_call=function_call2),
383+
],
384+
),
385+
)
386+
387+
# Create intermediate event
388+
intermediate_event = Event(
389+
invocation_id="test_inv2",
390+
author="agent",
391+
content=types.Content(
392+
role="model", parts=[types.Part.from_text(text="Processing...")]
393+
),
394+
)
395+
396+
# Create function response event with only one response
397+
function_response = types.FunctionResponse(
398+
id="func_123", name="test_function", response={"result": "success"}
399+
)
400+
401+
response_event = Event(
402+
invocation_id="test_inv3",
403+
author="user",
404+
content=types.Content(
405+
role="user", parts=[types.Part(function_response=function_response)]
406+
),
407+
)
408+
409+
# Test with matching function call and response
410+
events = [call_event, intermediate_event, response_event]
411+
rearranged = _rearrange_events_for_latest_function_response(events)
412+
413+
# Should remove intermediate events and merge responses
414+
assert len(rearranged) == 2
415+
assert rearranged[0] == call_event
416+
assert rearranged[1] == response_event
417+
418+
419+
def test_rearrange_events_for_latest_function_response_validation_error():
420+
"""Test _rearrange_events_for_latest_function_response with validation error."""
421+
# Create function call event with one function call
422+
function_call = types.FunctionCall(
423+
id="func_123", name="test_function", args={"param": "value"}
424+
)
425+
426+
call_event = Event(
427+
invocation_id="test_inv1",
428+
author="agent",
429+
content=types.Content(
430+
role="model", parts=[types.Part(function_call=function_call)]
431+
),
432+
)
433+
434+
# Create intermediate event
435+
intermediate_event = Event(
436+
invocation_id="test_inv2",
437+
author="agent",
438+
content=types.Content(
439+
role="model", parts=[types.Part.from_text(text="Processing...")]
440+
),
441+
)
442+
443+
# Create function response event with the matching function call AND an extra one
444+
function_response1 = types.FunctionResponse(
445+
id="func_123", name="test_function", response={"result": "success"}
446+
)
447+
function_response2 = types.FunctionResponse(
448+
id="func_456", name="other_function", response={"result": "other"}
449+
)
450+
451+
response_event = Event(
452+
invocation_id="test_inv3",
453+
author="user",
454+
content=types.Content(
455+
role="user",
456+
parts=[
457+
types.Part(function_response=function_response1),
458+
types.Part(function_response=function_response2),
459+
],
460+
),
461+
)
462+
463+
# Test with mismatched function call and response
464+
events = [call_event, intermediate_event, response_event]
465+
466+
with pytest.raises(
467+
ValueError,
468+
match=(
469+
"Last response event should only contain the responses for the"
470+
" function calls in the same function call event"
471+
),
472+
):
473+
_rearrange_events_for_latest_function_response(events)
474+
475+
476+
def test_rearrange_events_for_latest_function_response_mixed_responses():
477+
"""Test _rearrange_events_for_latest_function_response with mixed function responses."""
478+
# Create function call event with two calls
479+
function_call1 = types.FunctionCall(
480+
id="func_123", name="test_function", args={"param": "value1"}
481+
)
482+
function_call2 = types.FunctionCall(
483+
id="func_456", name="test_function2", args={"param": "value2"}
484+
)
485+
486+
call_event = Event(
487+
invocation_id="test_inv1",
488+
author="agent",
489+
content=types.Content(
490+
role="model",
491+
parts=[
492+
types.Part(function_call=function_call1),
493+
types.Part(function_call=function_call2),
494+
],
495+
),
496+
)
497+
498+
# Create intermediate event
499+
intermediate_event = Event(
500+
invocation_id="test_inv2",
501+
author="agent",
502+
content=types.Content(
503+
role="model", parts=[types.Part.from_text(text="Processing...")]
504+
),
505+
)
506+
507+
# Create function response event with one matching and one non-matching response
508+
function_response1 = types.FunctionResponse(
509+
id="func_123", name="test_function", response={"result": "success"}
510+
)
511+
function_response2 = types.FunctionResponse(
512+
id="func_789", name="test_function3", response={"result": "other"}
513+
)
514+
515+
response_event = Event(
516+
invocation_id="test_inv3",
517+
author="user",
518+
content=types.Content(
519+
role="user",
520+
parts=[
521+
types.Part(function_response=function_response1),
522+
types.Part(function_response=function_response2),
523+
],
524+
),
525+
)
526+
527+
# Test with mixed function responses
528+
events = [call_event, intermediate_event, response_event]
529+
530+
with pytest.raises(
531+
ValueError,
532+
match=(
533+
"Last response event should only contain the responses for the"
534+
" function calls in the same function call event"
535+
),
536+
):
537+
_rearrange_events_for_latest_function_response(events)

tests/unittests/flows/llm_flows/test_functions_request_euc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -549,13 +549,13 @@ def call_external_api2(tool_context: ToolContext) -> int:
549549
],
550550
),
551551
)
552-
# assert function_invoked == 4
552+
assert function_invoked == 4
553553
assert len(mock_model.requests) == 4
554554
request = mock_model.requests[-1]
555555
content = request.contents[-1]
556556
parts = content.parts
557557
assert len(parts) == 2
558558
assert parts[0].function_response.name == 'call_external_api1'
559-
assert parts[0].function_response.response == {'result': None}
559+
assert parts[0].function_response.response == {'result': 1}
560560
assert parts[1].function_response.name == 'call_external_api2'
561561
assert parts[1].function_response.response == {'result': 2}

0 commit comments

Comments
 (0)