Skip to content

Commit de9a356

Browse files
committed
Fix Moonshot Chat model toolcalling token usage
- Accumulate the token usage when tool calling is invoked - Fix both call() and stream() methods - Add `usage` field to the Chat completion choice as the usage is returned via Choice - Add Moonshot chat model ITs for function calling tests
1 parent 218c967 commit de9a356

File tree

5 files changed

+214
-11
lines changed

5 files changed

+214
-11
lines changed

models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
3737
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
3838
import org.springframework.ai.chat.metadata.EmptyUsage;
39+
import org.springframework.ai.chat.metadata.Usage;
40+
import org.springframework.ai.chat.metadata.UsageUtils;
3941
import org.springframework.ai.chat.model.AbstractToolCallSupport;
4042
import org.springframework.ai.chat.model.ChatModel;
4143
import org.springframework.ai.chat.model.ChatResponse;
@@ -74,6 +76,7 @@
7476
* MoonshotChatModel is a {@link ChatModel} implementation that uses the Moonshot API.
7577
*
7678
* @author Geng Rong
79+
* @author Ilayaperumal Gopinathan
7780
*/
7881
public class MoonshotChatModel extends AbstractToolCallSupport implements ChatModel, StreamingChatModel {
7982

@@ -179,6 +182,10 @@ private static Generation buildGeneration(Choice choice, Map<String, Object> met
179182

180183
@Override
181184
public ChatResponse call(Prompt prompt) {
185+
return this.internalCall(prompt, null);
186+
}
187+
188+
public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
182189
ChatCompletionRequest request = createRequest(prompt, false);
183190

184191
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
@@ -217,8 +224,11 @@ public ChatResponse call(Prompt prompt) {
217224
// @formatter:on
218225
return buildGeneration(choice, metadata);
219226
}).toList();
220-
221-
ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody()));
227+
MoonshotApi.Usage usage = completionEntity.getBody().usage();
228+
Usage currentUsage = (usage != null) ? MoonshotUsage.from(usage) : new EmptyUsage();
229+
Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse);
230+
ChatResponse chatResponse = new ChatResponse(generations,
231+
from(completionEntity.getBody(), cumulativeUsage));
222232

223233
observationContext.setResponse(chatResponse);
224234

@@ -231,7 +241,7 @@ && isToolCall(response, Set.of(MoonshotApi.ChatCompletionFinishReason.TOOL_CALLS
231241
var toolCallConversation = handleToolCalls(prompt, response);
232242
// Recursively call the call method with the tool call message
233243
// conversation that contains the call responses.
234-
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
244+
return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response);
235245
}
236246
return response;
237247
}
@@ -243,6 +253,10 @@ public ChatOptions getDefaultOptions() {
243253

244254
@Override
245255
public Flux<ChatResponse> stream(Prompt prompt) {
256+
return this.internalStream(prompt, null);
257+
}
258+
259+
public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
246260
return Flux.deferContextual(contextView -> {
247261
ChatCompletionRequest request = createRequest(prompt, true);
248262

@@ -286,8 +300,11 @@ public Flux<ChatResponse> stream(Prompt prompt) {
286300
// @formatter:on
287301
return buildGeneration(choice, metadata);
288302
}).toList();
303+
MoonshotApi.Usage usage = chatCompletion2.usage();
304+
Usage currentUsage = (usage != null) ? MoonshotUsage.from(usage) : new EmptyUsage();
305+
Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse);
289306

290-
return new ChatResponse(generations, from(chatCompletion2));
307+
return new ChatResponse(generations, from(chatCompletion2, cumulativeUsage));
291308
}
292309
catch (Exception e) {
293310
logger.error("Error processing chat completion", e);
@@ -302,7 +319,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
302319
var toolCallConversation = handleToolCalls(prompt, response);
303320
// Recursively call the stream method with the tool call message
304321
// conversation that contains the call responses.
305-
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
322+
return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), response);
306323
}
307324
return Flux.just(response);
308325
})
@@ -324,6 +341,16 @@ private ChatResponseMetadata from(ChatCompletion result) {
324341
.build();
325342
}
326343

344+
private ChatResponseMetadata from(ChatCompletion result, Usage usage) {
345+
Assert.notNull(result, "Moonshot ChatCompletionResult must not be null");
346+
return ChatResponseMetadata.builder()
347+
.withId(result.id() != null ? result.id() : "")
348+
.withUsage(usage)
349+
.withModel(result.model() != null ? result.model() : "")
350+
.withKeyValue("created", result.created() != null ? result.created() : 0L)
351+
.build();
352+
}
353+
327354
/**
328355
* Convert the ChatCompletionChunk into a ChatCompletion. The Usage is taken from the last choice of the chunk.
329356
* @param chunk the ChatCompletionChunk to convert
@@ -335,10 +362,11 @@ private ChatCompletion chunkToChatCompletion(ChatCompletionChunk chunk) {
335362
if (delta == null) {
336363
delta = new ChatCompletionMessage("", ChatCompletionMessage.Role.ASSISTANT);
337364
}
338-
return new ChatCompletion.Choice(cc.index(), delta, cc.finishReason());
365+
return new ChatCompletion.Choice(cc.index(), delta, cc.finishReason(), cc.usage());
339366
}).toList();
340-
341-
return new ChatCompletion(chunk.id(), "chat.completion", chunk.created(), chunk.model(), choices, null);
367+
// Get the usage from the latest choice
368+
MoonshotApi.Usage usage = choices.get(choices.size() - 1).usage();
369+
return new ChatCompletion(chunk.id(), "chat.completion", chunk.created(), chunk.model(), choices, usage);
342370
}
343371

344372
/**

models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotApi.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,8 @@ public record Choice(
532532
// @formatter:off
533533
@JsonProperty("index") Integer index,
534534
@JsonProperty("message") ChatCompletionMessage message,
535-
@JsonProperty("finish_reason") ChatCompletionFinishReason finishReason) {
535+
@JsonProperty("finish_reason") ChatCompletionFinishReason finishReason,
536+
@JsonProperty("usage") Usage usage) {
536537
// @formatter:on
537538
}
538539

models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotStreamFunctionCallingHelper.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,10 @@ private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) {
6464
: previous.finishReason());
6565
Integer index = (current.index() != null ? current.index() : previous.index());
6666

67+
MoonshotApi.Usage usage = current.usage() != null ? current.usage() : previous.usage();
68+
6769
ChatCompletionMessage message = merge(previous.delta(), current.delta());
68-
return new ChunkChoice(index, message, finishReason, null);
70+
return new ChunkChoice(index, message, finishReason, usage);
6971
}
7072

7173
private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompletionMessage current) {
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.moonshot;
17+
18+
import java.util.Arrays;
19+
import java.util.List;
20+
import java.util.stream.Collectors;
21+
22+
import org.junit.jupiter.api.Test;
23+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
24+
import reactor.core.publisher.Flux;
25+
26+
import org.springframework.ai.chat.messages.AssistantMessage;
27+
import org.springframework.ai.chat.model.ChatResponse;
28+
import org.springframework.ai.chat.model.Generation;
29+
import org.springframework.ai.chat.prompt.Prompt;
30+
import org.springframework.ai.model.function.FunctionCallback;
31+
import org.springframework.ai.moonshot.api.MockWeatherService;
32+
import org.springframework.ai.moonshot.api.MoonshotApi;
33+
import org.springframework.beans.factory.annotation.Autowired;
34+
import org.springframework.boot.SpringBootConfiguration;
35+
import org.springframework.boot.test.context.SpringBootTest;
36+
import org.springframework.context.annotation.Bean;
37+
import org.springframework.util.StringUtils;
38+
39+
import static org.assertj.core.api.Assertions.assertThat;
40+
41+
/**
42+
* @author Ilayaperumal Gopinathan
43+
*/
44+
@SpringBootTest
45+
@EnabledIfEnvironmentVariable(named = "MOONSHOT_API_KEY", matches = ".+")
46+
public class MoonShotChatModelIT {
47+
48+
@Autowired
49+
private MoonshotChatModel chatModel;
50+
51+
private static final MoonshotApi.FunctionTool FUNCTION_TOOL = new MoonshotApi.FunctionTool(
52+
MoonshotApi.FunctionTool.Type.FUNCTION, new MoonshotApi.FunctionTool.Function(
53+
"Get the weather in location. Return temperature in 30°F or 30°C format.", "getCurrentWeather", """
54+
{
55+
"type": "object",
56+
"properties": {
57+
"location": {
58+
"type": "string",
59+
"description": "The city and state e.g. San Francisco, CA"
60+
},
61+
"lat": {
62+
"type": "number",
63+
"description": "The city latitude"
64+
},
65+
"lon": {
66+
"type": "number",
67+
"description": "The city longitude"
68+
},
69+
"unit": {
70+
"type": "string",
71+
"enum": ["C", "F"]
72+
}
73+
},
74+
"required": ["location", "lat", "lon", "unit"]
75+
}
76+
"""));
77+
78+
@Test
79+
public void toolFunctionCall() {
80+
var promptOptions = MoonshotChatOptions.builder()
81+
.withModel(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue())
82+
.withTools(Arrays.asList(FUNCTION_TOOL))
83+
.withFunctionCallbacks(List.of(FunctionCallback.builder()
84+
.function("getCurrentWeather", new MockWeatherService())
85+
.description("Get the weather in location. Return temperature in 36°F or 36°C format.")
86+
.inputType(MockWeatherService.Request.class)
87+
.build()))
88+
.build();
89+
Prompt prompt = new Prompt("What's the weather like in San Francisco? Return the temperature in Celsius.",
90+
promptOptions);
91+
92+
ChatResponse chatResponse = this.chatModel.call(prompt);
93+
assertThat(chatResponse).isNotNull();
94+
assertThat(chatResponse.getResult().getOutput());
95+
assertThat(chatResponse.getResult().getOutput().getText()).contains("San Francisco");
96+
assertThat(chatResponse.getResult().getOutput().getText()).contains("30.0");
97+
assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isLessThan(450).isGreaterThan(280);
98+
}
99+
100+
@Test
101+
public void testStreamFunctionCall() {
102+
var promptOptions = MoonshotChatOptions.builder()
103+
.withModel(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue())
104+
.withTools(Arrays.asList(FUNCTION_TOOL))
105+
.withFunctionCallbacks(List.of(FunctionCallback.builder()
106+
.function("getCurrentWeather", new MockWeatherService())
107+
.description("Get the weather in location. Return temperature in 36°F or 36°C format.")
108+
.inputType(MockWeatherService.Request.class)
109+
.build()))
110+
.build();
111+
Prompt prompt = new Prompt("What's the weather like in San Francisco? Return the temperature in Celsius.",
112+
promptOptions);
113+
114+
Flux<ChatResponse> chatResponse = this.chatModel.stream(prompt);
115+
String content = chatResponse.collectList()
116+
.block()
117+
.stream()
118+
.map(ChatResponse::getResults)
119+
.flatMap(List::stream)
120+
.map(Generation::getOutput)
121+
.map(AssistantMessage::getText)
122+
.collect(Collectors.joining());
123+
assertThat(content).contains("San Francisco");
124+
assertThat(content).contains("30.0");
125+
}
126+
127+
@Test
128+
public void testStreamFunctionCallUsage() {
129+
var promptOptions = MoonshotChatOptions.builder()
130+
.withModel(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue())
131+
.withTools(Arrays.asList(FUNCTION_TOOL))
132+
.withFunctionCallbacks(List.of(FunctionCallback.builder()
133+
.function("getCurrentWeather", new MockWeatherService())
134+
.description("Get the weather in location. Return temperature in 36°F or 36°C format.")
135+
.inputType(MockWeatherService.Request.class)
136+
.build()))
137+
.build();
138+
Prompt prompt = new Prompt("What's the weather like in San Francisco? Return the temperature in Celsius.",
139+
promptOptions);
140+
141+
ChatResponse chatResponse = this.chatModel.stream(prompt).blockLast();
142+
assertThat(chatResponse).isNotNull();
143+
assertThat(chatResponse.getMetadata()).isNotNull();
144+
assertThat(chatResponse.getMetadata().getUsage()).isNotNull();
145+
assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isLessThan(450).isGreaterThan(280);
146+
}
147+
148+
@SpringBootConfiguration
149+
public static class Config {
150+
151+
@Bean
152+
public MoonshotApi moonshotApi() {
153+
return new MoonshotApi(getApiKey());
154+
}
155+
156+
private String getApiKey() {
157+
String apiKey = System.getenv("MOONSHOT_API_KEY");
158+
if (!StringUtils.hasText(apiKey)) {
159+
throw new IllegalArgumentException(
160+
"You must provide an API key. Put it in an environment variable under the name MOONSHOT_API_KEY");
161+
}
162+
return apiKey;
163+
}
164+
165+
@Bean
166+
public MoonshotChatModel moonshotChatModel(MoonshotApi api) {
167+
return new MoonshotChatModel(api);
168+
}
169+
170+
}
171+
172+
}

models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotRetryTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ public void beforeEach() {
7979
public void moonshotChatTransientError() {
8080

8181
var choice = new ChatCompletion.Choice(0, new ChatCompletionMessage("Response", Role.ASSISTANT),
82-
ChatCompletionFinishReason.STOP);
82+
ChatCompletionFinishReason.STOP, null);
8383
ChatCompletion expectedChatCompletion = new ChatCompletion("id", "chat.completion", 789L, "model",
8484
List.of(choice), new MoonshotApi.Usage(10, 10, 10));
8585

0 commit comments

Comments
 (0)