1919import java .util .List ;
2020import java .util .Map ;
2121import java .util .Set ;
22+ import java .util .concurrent .atomic .AtomicReference ;
2223
2324import org .junit .jupiter .api .Test ;
2425
2526import org .springframework .ai .chat .messages .AssistantMessage ;
2627import org .springframework .ai .chat .metadata .ChatGenerationMetadata ;
28+ import org .springframework .ai .chat .metadata .ChatResponseMetadata ;
29+ import reactor .core .publisher .Flux ;
2730
2831import static org .assertj .core .api .Assertions .assertThat ;
2932import static org .assertj .core .api .Assertions .assertThatThrownBy ;
33+ import static org .springframework .ai .chat .messages .AssistantMessage .*;
3034
3135/**
3236 * Unit tests for {@link ChatResponse}.
@@ -38,8 +42,8 @@ class ChatResponseTests {
3842 @ Test
3943 void whenToolCallsArePresentThenReturnTrue () {
4044 ChatResponse chatResponse = ChatResponse .builder ()
41- .generations (List .of (new Generation (new AssistantMessage ( "" , Map . of (),
42- List .of (new AssistantMessage . ToolCall ("toolA" , "function" , "toolA" , "{}" ))))))
45+ .generations (List .of (new Generation (
46+ new AssistantMessage ( "" , Map . of (), List .of (new ToolCall ("toolA" , "function" , "toolA" , "{}" ))))))
4347 .build ();
4448 assertThat (chatResponse .hasToolCalls ()).isTrue ();
4549 }
@@ -80,4 +84,45 @@ void whenFinishReasonIsNotPresent() {
8084 assertThat (chatResponse .hasFinishReasons (Set .of ("completed" ))).isFalse ();
8185 }
8286
87+ @ Test
88+ void messageAggregatorShouldCorrectlyAggregateToolCallsFromStream () {
89+
90+ MessageAggregator aggregator = new MessageAggregator ();
91+
92+ ChatResponse chunk1 = new ChatResponse (
93+ List .of (new Generation (new AssistantMessage ("Thinking about the weather... " ))));
94+
95+ ToolCall weatherToolCall = new ToolCall ("tool-id-123" , "function" , "getCurrentWeather" ,
96+ "{\" location\" : \" Seoul\" }" );
97+
98+ Map <String , Object > metadataWithToolCall = Map .of ("toolCalls" , List .of (weatherToolCall ));
99+ ChatResponseMetadata responseMetadataForChunk2 = ChatResponseMetadata .builder ()
100+ .metadata (metadataWithToolCall )
101+ .build ();
102+
103+ ChatResponse chunk2 = new ChatResponse (List .of (new Generation (new AssistantMessage ("" ))),
104+ responseMetadataForChunk2 );
105+
106+ Flux <ChatResponse > streamingResponse = Flux .just (chunk1 , chunk2 );
107+
108+ AtomicReference <ChatResponse > aggregatedResponseRef = new AtomicReference <>();
109+
110+ aggregator .aggregate (streamingResponse , aggregatedResponseRef ::set ).blockLast ();
111+
112+ ChatResponse finalResponse = aggregatedResponseRef .get ();
113+ assertThat (finalResponse ).isNotNull ();
114+
115+ AssistantMessage finalAssistantMessage = finalResponse .getResult ().getOutput ();
116+
117+ assertThat (finalAssistantMessage ).isNotNull ();
118+ assertThat (finalAssistantMessage .getText ()).isEqualTo ("Thinking about the weather... " );
119+ assertThat (finalAssistantMessage .hasToolCalls ()).isTrue ();
120+ assertThat (finalAssistantMessage .getToolCalls ()).hasSize (1 );
121+
122+ ToolCall resultToolCall = finalAssistantMessage .getToolCalls ().get (0 );
123+ assertThat (resultToolCall .id ()).isEqualTo ("tool-id-123" );
124+ assertThat (resultToolCall .name ()).isEqualTo ("getCurrentWeather" );
125+ assertThat (resultToolCall .arguments ()).isEqualTo ("{\" location\" : \" Seoul\" }" );
126+ }
127+
83128}
0 commit comments