Skip to content

Commit 783d40d

Browse files
committed
refactor : refactor MessageAggregator to include toolCalls
1 parent 161c437 commit 783d40d

File tree

1 file changed

+34
-2
lines changed

1 file changed

+34
-2
lines changed

spring-ai-model/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package org.springframework.ai.chat.model;
1818

19+
import java.util.ArrayList;
1920
import java.util.HashMap;
2021
import java.util.List;
2122
import java.util.Map;
@@ -24,6 +25,7 @@
2425

2526
import org.slf4j.Logger;
2627
import org.slf4j.LoggerFactory;
28+
import org.springframework.util.CollectionUtils;
2729
import reactor.core.publisher.Flux;
2830

2931
import org.springframework.ai.chat.messages.AssistantMessage;
@@ -35,13 +37,16 @@
3537
import org.springframework.ai.chat.metadata.Usage;
3638
import org.springframework.util.StringUtils;
3739

40+
import static org.springframework.ai.chat.messages.AssistantMessage.*;
41+
3842
/**
3943
* Helper that for streaming chat responses, aggregate the chat response messages into a
4044
* single AssistantMessage. Job is performed in parallel to the chat response processing.
4145
*
4246
* @author Christian Tzolov
4347
* @author Alexandros Pappas
4448
* @author Thomas Vitale
49+
* @author Heonwoo Kim
4550
* @since 1.0.0
4651
*/
4752
public class MessageAggregator {
@@ -54,6 +59,7 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
5459
// Assistant Message
5560
AtomicReference<StringBuilder> messageTextContentRef = new AtomicReference<>(new StringBuilder());
5661
AtomicReference<Map<String, Object>> messageMetadataMapRef = new AtomicReference<>();
62+
AtomicReference<List<ToolCall>> toolCallsRef = new AtomicReference<>(new ArrayList<>());
5763

5864
// ChatGeneration Metadata
5965
AtomicReference<ChatGenerationMetadata> generationMetadataRef = new AtomicReference<>(
@@ -73,6 +79,7 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
7379
return fluxChatResponse.doOnSubscribe(subscription -> {
7480
messageTextContentRef.set(new StringBuilder());
7581
messageMetadataMapRef.set(new HashMap<>());
82+
toolCallsRef.set(new ArrayList<>());
7683
metadataIdRef.set("");
7784
metadataModelRef.set("");
7885
metadataUsagePromptTokensRef.set(0);
@@ -94,6 +101,11 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
94101
if (chatResponse.getResult().getOutput().getMetadata() != null) {
95102
messageMetadataMapRef.get().putAll(chatResponse.getResult().getOutput().getMetadata());
96103
}
104+
AssistantMessage outputMessage = chatResponse.getResult().getOutput();
105+
if (!CollectionUtils.isEmpty(outputMessage.getToolCalls())) {
106+
toolCallsRef.get().addAll(outputMessage.getToolCalls());
107+
}
108+
97109
}
98110
if (chatResponse.getMetadata() != null) {
99111
if (chatResponse.getMetadata().getUsage() != null) {
@@ -119,6 +131,13 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
119131
if (StringUtils.hasText(chatResponse.getMetadata().getModel())) {
120132
metadataModelRef.set(chatResponse.getMetadata().getModel());
121133
}
134+
Object toolCallsFromMetadata = chatResponse.getMetadata().get("toolCalls");
135+
if (toolCallsFromMetadata instanceof List) {
136+
@SuppressWarnings("unchecked")
137+
List<ToolCall> toolCallsList = (List<ToolCall>) toolCallsFromMetadata;
138+
toolCallsRef.get().addAll(toolCallsList);
139+
}
140+
122141
}
123142
}).doOnComplete(() -> {
124143

@@ -133,12 +152,25 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
133152
.promptMetadata(metadataPromptMetadataRef.get())
134153
.build();
135154

136-
onAggregationComplete.accept(new ChatResponse(List.of(new Generation(
137-
new AssistantMessage(messageTextContentRef.get().toString(), messageMetadataMapRef.get()),
155+
AssistantMessage finalAssistantMessage;
156+
List<ToolCall> collectedToolCalls = toolCallsRef.get();
157+
158+
if (!CollectionUtils.isEmpty(collectedToolCalls)) {
159+
160+
finalAssistantMessage = new AssistantMessage(messageTextContentRef.get().toString(),
161+
messageMetadataMapRef.get(), collectedToolCalls);
162+
}
163+
else {
164+
finalAssistantMessage = new AssistantMessage(messageTextContentRef.get().toString(),
165+
messageMetadataMapRef.get());
166+
}
167+
onAggregationComplete.accept(new ChatResponse(List.of(new Generation(finalAssistantMessage,
168+
138169
generationMetadataRef.get())), chatResponseMetadata));
139170

140171
messageTextContentRef.set(new StringBuilder());
141172
messageMetadataMapRef.set(new HashMap<>());
173+
toolCallsRef.set(new ArrayList<>());
142174
metadataIdRef.set("");
143175
metadataModelRef.set("");
144176
metadataUsagePromptTokensRef.set(0);

0 commit comments

Comments
 (0)