Skip to content

Commit 37fa0dd

Browse files
committed
test: Add unit test for MessageAggregator tool call aggregation
1 parent 783d40d commit 37fa0dd

File tree

1 file changed

+47
-2
lines changed

1 file changed

+47
-2
lines changed

spring-ai-model/src/test/java/org/springframework/ai/chat/model/ChatResponseTests.java

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,18 @@
1919
import java.util.List;
2020
import java.util.Map;
2121
import java.util.Set;
22+
import java.util.concurrent.atomic.AtomicReference;
2223

2324
import org.junit.jupiter.api.Test;
2425

2526
import org.springframework.ai.chat.messages.AssistantMessage;
2627
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
28+
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
29+
import reactor.core.publisher.Flux;
2730

2831
import static org.assertj.core.api.Assertions.assertThat;
2932
import static org.assertj.core.api.Assertions.assertThatThrownBy;
33+
import static org.springframework.ai.chat.messages.AssistantMessage.*;
3034

3135
/**
3236
* Unit tests for {@link ChatResponse}.
@@ -38,8 +42,8 @@ class ChatResponseTests {
3842
@Test
3943
void whenToolCallsArePresentThenReturnTrue() {
4044
ChatResponse chatResponse = ChatResponse.builder()
41-
.generations(List.of(new Generation(new AssistantMessage("", Map.of(),
42-
List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"))))))
45+
.generations(List.of(new Generation(
46+
new AssistantMessage("", Map.of(), List.of(new ToolCall("toolA", "function", "toolA", "{}"))))))
4347
.build();
4448
assertThat(chatResponse.hasToolCalls()).isTrue();
4549
}
@@ -80,4 +84,45 @@ void whenFinishReasonIsNotPresent() {
8084
assertThat(chatResponse.hasFinishReasons(Set.of("completed"))).isFalse();
8185
}
8286

87+
@Test
88+
void messageAggregatorShouldCorrectlyAggregateToolCallsFromStream() {
89+
90+
MessageAggregator aggregator = new MessageAggregator();
91+
92+
ChatResponse chunk1 = new ChatResponse(
93+
List.of(new Generation(new AssistantMessage("Thinking about the weather... "))));
94+
95+
ToolCall weatherToolCall = new ToolCall("tool-id-123", "function", "getCurrentWeather",
96+
"{\"location\": \"Seoul\"}");
97+
98+
Map<String, Object> metadataWithToolCall = Map.of("toolCalls", List.of(weatherToolCall));
99+
ChatResponseMetadata responseMetadataForChunk2 = ChatResponseMetadata.builder()
100+
.metadata(metadataWithToolCall)
101+
.build();
102+
103+
ChatResponse chunk2 = new ChatResponse(List.of(new Generation(new AssistantMessage(""))),
104+
responseMetadataForChunk2);
105+
106+
Flux<ChatResponse> streamingResponse = Flux.just(chunk1, chunk2);
107+
108+
AtomicReference<ChatResponse> aggregatedResponseRef = new AtomicReference<>();
109+
110+
aggregator.aggregate(streamingResponse, aggregatedResponseRef::set).blockLast();
111+
112+
ChatResponse finalResponse = aggregatedResponseRef.get();
113+
assertThat(finalResponse).isNotNull();
114+
115+
AssistantMessage finalAssistantMessage = finalResponse.getResult().getOutput();
116+
117+
assertThat(finalAssistantMessage).isNotNull();
118+
assertThat(finalAssistantMessage.getText()).isEqualTo("Thinking about the weather... ");
119+
assertThat(finalAssistantMessage.hasToolCalls()).isTrue();
120+
assertThat(finalAssistantMessage.getToolCalls()).hasSize(1);
121+
122+
ToolCall resultToolCall = finalAssistantMessage.getToolCalls().get(0);
123+
assertThat(resultToolCall.id()).isEqualTo("tool-id-123");
124+
assertThat(resultToolCall.name()).isEqualTo("getCurrentWeather");
125+
assertThat(resultToolCall.arguments()).isEqualTo("{\"location\": \"Seoul\"}");
126+
}
127+
83128
}

0 commit comments

Comments
 (0)