1616
1717package  org .springframework .ai .chat .model ;
1818
19+ import  java .util .ArrayList ;
1920import  java .util .HashMap ;
2021import  java .util .List ;
2122import  java .util .Map ;
2425
2526import  org .slf4j .Logger ;
2627import  org .slf4j .LoggerFactory ;
28+ import  org .springframework .util .CollectionUtils ;
2729import  reactor .core .publisher .Flux ;
2830
2931import  org .springframework .ai .chat .messages .AssistantMessage ;
3537import  org .springframework .ai .chat .metadata .Usage ;
3638import  org .springframework .util .StringUtils ;
3739
40+ import  static  org .springframework .ai .chat .messages .AssistantMessage .*;
41+ 
3842/** 
3943 * Helper that for streaming chat responses, aggregate the chat response messages into a 
4044 * single AssistantMessage. Job is performed in parallel to the chat response processing. 
4145 * 
4246 * @author Christian Tzolov 
4347 * @author Alexandros Pappas 
4448 * @author Thomas Vitale 
49+  * @author Heonwoo Kim 
4550 * @since 1.0.0 
4651 */ 
4752public  class  MessageAggregator  {
@@ -54,6 +59,7 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
5459		// Assistant Message 
5560		AtomicReference <StringBuilder > messageTextContentRef  = new  AtomicReference <>(new  StringBuilder ());
5661		AtomicReference <Map <String , Object >> messageMetadataMapRef  = new  AtomicReference <>();
62+ 		AtomicReference <List <ToolCall >> toolCallsRef  = new  AtomicReference <>(new  ArrayList <>());
5763
5864		// ChatGeneration Metadata 
5965		AtomicReference <ChatGenerationMetadata > generationMetadataRef  = new  AtomicReference <>(
@@ -73,6 +79,7 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
7379		return  fluxChatResponse .doOnSubscribe (subscription  -> {
7480			messageTextContentRef .set (new  StringBuilder ());
7581			messageMetadataMapRef .set (new  HashMap <>());
82+ 			toolCallsRef .set (new  ArrayList <>());
7683			metadataIdRef .set ("" );
7784			metadataModelRef .set ("" );
7885			metadataUsagePromptTokensRef .set (0 );
@@ -94,6 +101,11 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
94101				if  (chatResponse .getResult ().getOutput ().getMetadata () != null ) {
95102					messageMetadataMapRef .get ().putAll (chatResponse .getResult ().getOutput ().getMetadata ());
96103				}
104+ 				AssistantMessage  outputMessage  = chatResponse .getResult ().getOutput ();
105+ 				if  (!CollectionUtils .isEmpty (outputMessage .getToolCalls ())) {
106+ 					toolCallsRef .get ().addAll (outputMessage .getToolCalls ());
107+ 				}
108+ 
97109			}
98110			if  (chatResponse .getMetadata () != null ) {
99111				if  (chatResponse .getMetadata ().getUsage () != null ) {
@@ -119,6 +131,13 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
119131				if  (StringUtils .hasText (chatResponse .getMetadata ().getModel ())) {
120132					metadataModelRef .set (chatResponse .getMetadata ().getModel ());
121133				}
134+ 				Object  toolCallsFromMetadata  = chatResponse .getMetadata ().get ("toolCalls" );
135+ 				if  (toolCallsFromMetadata  instanceof  List ) {
136+ 					@ SuppressWarnings ("unchecked" )
137+ 					List <ToolCall > toolCallsList  = (List <ToolCall >) toolCallsFromMetadata ;
138+ 					toolCallsRef .get ().addAll (toolCallsList );
139+ 				}
140+ 
122141			}
123142		}).doOnComplete (() -> {
124143
@@ -133,12 +152,25 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
133152				.promptMetadata (metadataPromptMetadataRef .get ())
134153				.build ();
135154
136- 			onAggregationComplete .accept (new  ChatResponse (List .of (new  Generation (
137- 					new  AssistantMessage (messageTextContentRef .get ().toString (), messageMetadataMapRef .get ()),
155+ 			AssistantMessage  finalAssistantMessage ;
156+ 			List <ToolCall > collectedToolCalls  = toolCallsRef .get ();
157+ 
158+ 			if  (!CollectionUtils .isEmpty (collectedToolCalls )) {
159+ 
160+ 				finalAssistantMessage  = new  AssistantMessage (messageTextContentRef .get ().toString (),
161+ 						messageMetadataMapRef .get (), collectedToolCalls );
162+ 			}
163+ 			else  {
164+ 				finalAssistantMessage  = new  AssistantMessage (messageTextContentRef .get ().toString (),
165+ 						messageMetadataMapRef .get ());
166+ 			}
167+ 			onAggregationComplete .accept (new  ChatResponse (List .of (new  Generation (finalAssistantMessage ,
168+ 
138169					generationMetadataRef .get ())), chatResponseMetadata ));
139170
140171			messageTextContentRef .set (new  StringBuilder ());
141172			messageMetadataMapRef .set (new  HashMap <>());
173+ 			toolCallsRef .set (new  ArrayList <>());
142174			metadataIdRef .set ("" );
143175			metadataModelRef .set ("" );
144176			metadataUsagePromptTokensRef .set (0 );
0 commit comments