Skip to content

Fix Moonshot Chat model toolcalling token usage #1927

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.metadata.UsageUtils;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
Expand Down Expand Up @@ -75,6 +77,7 @@
*
* @author Geng Rong
* @author Alexandros Pappas
* @author Ilayaperumal Gopinathan
*/
public class MoonshotChatModel extends AbstractToolCallSupport implements ChatModel, StreamingChatModel {

Expand Down Expand Up @@ -180,6 +183,10 @@ private static Generation buildGeneration(Choice choice, Map<String, Object> met

@Override
public ChatResponse call(Prompt prompt) {
return this.internalCall(prompt, null);
}

public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
ChatCompletionRequest request = createRequest(prompt, false);

ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
Expand Down Expand Up @@ -218,8 +225,11 @@ public ChatResponse call(Prompt prompt) {
// @formatter:on
return buildGeneration(choice, metadata);
}).toList();

ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody()));
MoonshotApi.Usage usage = completionEntity.getBody().usage();
Usage currentUsage = (usage != null) ? MoonshotUsage.from(usage) : new EmptyUsage();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this line of code seems unnecessary. MoonshotApi.Usage can directly implement the Usage interface, and UsageUtils.getCumulativeUsage already handles null checks for the input. This approach ensures that callers don’t need to repeatedly check for currentUsage in every call, avoiding redundant code and improving cohesion.
If an EmptyUsage is truly needed, it could simply be a static constant like Usage.empty() or EmptyUsage.instance().

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand your perspective but we wanted to address this as part of #1487 - not just for OpenAI usage but across the models. But for now, I think we can go with the consistent approach across the models by passing the EmptyUsage object when the ChatModel API usage is null. I hope that is ok with you.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i understand, i think is ok

Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse);
ChatResponse chatResponse = new ChatResponse(generations,
from(completionEntity.getBody(), cumulativeUsage));

observationContext.setResponse(chatResponse);

Expand All @@ -232,7 +242,7 @@ && isToolCall(response, Set.of(MoonshotApi.ChatCompletionFinishReason.TOOL_CALLS
var toolCallConversation = handleToolCalls(prompt, response);
// Recursively call the call method with the tool call message
// conversation that contains the call responses.
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response);
}
return response;
}
Expand All @@ -244,6 +254,10 @@ public ChatOptions getDefaultOptions() {

@Override
public Flux<ChatResponse> stream(Prompt prompt) {
return this.internalStream(prompt, null);
}

public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
return Flux.deferContextual(contextView -> {
ChatCompletionRequest request = createRequest(prompt, true);

Expand Down Expand Up @@ -287,8 +301,11 @@ public Flux<ChatResponse> stream(Prompt prompt) {
// @formatter:on
return buildGeneration(choice, metadata);
}).toList();
MoonshotApi.Usage usage = chatCompletion2.usage();
Usage currentUsage = (usage != null) ? MoonshotUsage.from(usage) : new EmptyUsage();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as above

Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse);

return new ChatResponse(generations, from(chatCompletion2));
return new ChatResponse(generations, from(chatCompletion2, cumulativeUsage));
}
catch (Exception e) {
logger.error("Error processing chat completion", e);
Expand All @@ -303,7 +320,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
var toolCallConversation = handleToolCalls(prompt, response);
// Recursively call the stream method with the tool call message
// conversation that contains the call responses.
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), response);
}
return Flux.just(response);
})
Expand All @@ -325,6 +342,16 @@ private ChatResponseMetadata from(ChatCompletion result) {
.build();
}

private ChatResponseMetadata from(ChatCompletion result, Usage usage) {
Assert.notNull(result, "Moonshot ChatCompletionResult must not be null");
return ChatResponseMetadata.builder()
.id(result.id() != null ? result.id() : "")
.usage(usage)
.model(result.model() != null ? result.model() : "")
.keyValue("created", result.created() != null ? result.created() : 0L)
.build();
}

/**
* Convert the ChatCompletionChunk into a ChatCompletion. The Usage is set to null.
* @param chunk the ChatCompletionChunk to convert
Expand All @@ -336,10 +363,11 @@ private ChatCompletion chunkToChatCompletion(ChatCompletionChunk chunk) {
if (delta == null) {
delta = new ChatCompletionMessage("", ChatCompletionMessage.Role.ASSISTANT);
}
return new ChatCompletion.Choice(cc.index(), delta, cc.finishReason());
return new ChatCompletion.Choice(cc.index(), delta, cc.finishReason(), cc.usage());
}).toList();

return new ChatCompletion(chunk.id(), "chat.completion", chunk.created(), chunk.model(), choices, null);
// Get the usage from the latest choice
MoonshotApi.Usage usage = choices.get(choices.size() - 1).usage();
return new ChatCompletion(chunk.id(), "chat.completion", chunk.created(), chunk.model(), choices, usage);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,8 @@ public record Choice(
// @formatter:off
@JsonProperty("index") Integer index,
@JsonProperty("message") ChatCompletionMessage message,
@JsonProperty("finish_reason") ChatCompletionFinishReason finishReason) {
@JsonProperty("finish_reason") ChatCompletionFinishReason finishReason,
@JsonProperty("usage") Usage usage) {
// @formatter:on
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,10 @@ private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) {
: previous.finishReason());
Integer index = (current.index() != null ? current.index() : previous.index());

MoonshotApi.Usage usage = current.usage() != null ? current.usage() : previous.usage();

ChatCompletionMessage message = merge(previous.delta(), current.delta());
return new ChunkChoice(index, message, finishReason, null);
return new ChunkChoice(index, message, finishReason, usage);
}

private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompletionMessage current) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public void beforeEach() {
public void moonshotChatTransientError() {

var choice = new ChatCompletion.Choice(0, new ChatCompletionMessage("Response", Role.ASSISTANT),
ChatCompletionFinishReason.STOP);
ChatCompletionFinishReason.STOP, null);
ChatCompletion expectedChatCompletion = new ChatCompletion("id", "chat.completion", 789L, "model",
List.of(choice), new MoonshotApi.Usage(10, 10, 10));

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,7 @@
package org.springframework.ai.moonshot.chat;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -53,6 +54,33 @@ class MoonshotChatModelFunctionCallingIT {
@Autowired
ChatModel chatModel;

private static final MoonshotApi.FunctionTool FUNCTION_TOOL = new MoonshotApi.FunctionTool(
MoonshotApi.FunctionTool.Type.FUNCTION, new MoonshotApi.FunctionTool.Function(
"Get the weather in location. Return temperature in 30°F or 30°C format.", "getCurrentWeather", """
{
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state e.g. San Francisco, CA"
},
"lat": {
"type": "number",
"description": "The city latitude"
},
"lon": {
"type": "number",
"description": "The city longitude"
},
"unit": {
"type": "string",
"enum": ["C", "F"]
}
},
"required": ["location", "lat", "lon", "unit"]
}
"""));

@Test
void functionCallTest() {

Expand Down Expand Up @@ -89,6 +117,7 @@ void streamFunctionCallTest() {
.functionCallbacks(List.of(FunctionCallback.builder()
.function("getCurrentWeather", new MockWeatherService())
.description("Get the weather in location")
.inputType(MockWeatherService.Request.class)
.build()))
.build();

Expand All @@ -108,4 +137,47 @@ void streamFunctionCallTest() {
assertThat(content).contains("30", "10", "15");
}

@Test
public void toolFunctionCallWithUsage() {
var promptOptions = MoonshotChatOptions.builder()
.model(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue())
.tools(Arrays.asList(FUNCTION_TOOL))
.functionCallbacks(List.of(FunctionCallback.builder()
.function("getCurrentWeather", new MockWeatherService())
.description("Get the weather in location. Return temperature in 36°F or 36°C format.")
.inputType(MockWeatherService.Request.class)
.build()))
.build();
Prompt prompt = new Prompt("What's the weather like in San Francisco? Return the temperature in Celsius.",
promptOptions);

ChatResponse chatResponse = this.chatModel.call(prompt);
assertThat(chatResponse).isNotNull();
assertThat(chatResponse.getResult().getOutput());
assertThat(chatResponse.getResult().getOutput().getText()).contains("San Francisco");
assertThat(chatResponse.getResult().getOutput().getText()).contains("30.0");
assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isLessThan(450).isGreaterThan(280);
}

@Test
public void testStreamFunctionCallUsage() {
var promptOptions = MoonshotChatOptions.builder()
.model(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue())
.tools(Arrays.asList(FUNCTION_TOOL))
.functionCallbacks(List.of(FunctionCallback.builder()
.function("getCurrentWeather", new MockWeatherService())
.description("Get the weather in location. Return temperature in 36°F or 36°C format.")
.inputType(MockWeatherService.Request.class)
.build()))
.build();
Prompt prompt = new Prompt("What's the weather like in San Francisco? Return the temperature in Celsius.",
promptOptions);

ChatResponse chatResponse = this.chatModel.stream(prompt).blockLast();
assertThat(chatResponse).isNotNull();
assertThat(chatResponse.getMetadata()).isNotNull();
assertThat(chatResponse.getMetadata().getUsage()).isNotNull();
assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isLessThan(450).isGreaterThan(280);
}

}
Loading