From b61309f9d28be3d5a0ce917d5ac85fa9436bf213 Mon Sep 17 00:00:00 2001 From: Ilayaperumal Gopinathan Date: Fri, 13 Dec 2024 13:54:41 +0000 Subject: [PATCH 1/2] Fix Moonshot Chat model toolcalling token usage - Accumulate the token usage when toolcalling is invoked - Fix both call() and stream() methods - Add `usage` field to the Chat completion choice as the usage is returned via Choice - Add Mootshot chatmodel ITs for functioncalling tests --- .../ai/moonshot/MoonshotChatModel.java | 44 ++++- .../ai/moonshot/api/MoonshotApi.java | 3 +- .../MoonshotStreamFunctionCallingHelper.java | 4 +- .../ai/moonshot/MoonShotChatModelIT.java | 172 ++++++++++++++++++ .../ai/moonshot/MoonshotRetryTests.java | 2 +- 5 files changed, 214 insertions(+), 11 deletions(-) create mode 100644 models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonShotChatModelIT.java diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java index 0f41afbdd33..cce99e93efa 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java @@ -36,6 +36,8 @@ import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.metadata.EmptyUsage; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.metadata.UsageUtils; import org.springframework.ai.chat.model.AbstractToolCallSupport; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; @@ -75,6 +77,7 @@ * * @author Geng Rong * @author Alexandros Pappas + * @author Ilayaperumal Gopinathan */ public class MoonshotChatModel extends AbstractToolCallSupport implements ChatModel, StreamingChatModel { @@ -180,6 +183,10 @@ private static Generation buildGeneration(Choice choice, Map met @Override public ChatResponse call(Prompt prompt) { + return this.internalCall(prompt, null); + } + + public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) { ChatCompletionRequest request = createRequest(prompt, false); ChatModelObservationContext observationContext = ChatModelObservationContext.builder() @@ -218,8 +225,11 @@ public ChatResponse call(Prompt prompt) { // @formatter:on return buildGeneration(choice, metadata); }).toList(); - - ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody())); + MoonshotApi.Usage usage = completionEntity.getBody().usage(); + Usage currentUsage = (usage != null) ? MoonshotUsage.from(usage) : new EmptyUsage(); + Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse); + ChatResponse chatResponse = new ChatResponse(generations, + from(completionEntity.getBody(), cumulativeUsage)); observationContext.setResponse(chatResponse); @@ -232,7 +242,7 @@ && isToolCall(response, Set.of(MoonshotApi.ChatCompletionFinishReason.TOOL_CALLS var toolCallConversation = handleToolCalls(prompt, response); // Recursively call the call method with the tool call message // conversation that contains the call responses. - return this.call(new Prompt(toolCallConversation, prompt.getOptions())); + return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response); } return response; } @@ -244,6 +254,10 @@ public ChatOptions getDefaultOptions() { @Override public Flux stream(Prompt prompt) { + return this.internalStream(prompt, null); + } + + public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse) { return Flux.deferContextual(contextView -> { ChatCompletionRequest request = createRequest(prompt, true); @@ -287,8 +301,11 @@ public Flux stream(Prompt prompt) { // @formatter:on return buildGeneration(choice, metadata); }).toList(); + MoonshotApi.Usage usage = chatCompletion2.usage(); + Usage currentUsage = (usage != null) ? MoonshotUsage.from(usage) : new EmptyUsage(); + Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse); - return new ChatResponse(generations, from(chatCompletion2)); + return new ChatResponse(generations, from(chatCompletion2, cumulativeUsage)); } catch (Exception e) { logger.error("Error processing chat completion", e); @@ -303,7 +320,7 @@ public Flux stream(Prompt prompt) { var toolCallConversation = handleToolCalls(prompt, response); // Recursively call the stream method with the tool call message // conversation that contains the call responses. - return this.stream(new Prompt(toolCallConversation, prompt.getOptions())); + return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), response); } return Flux.just(response); }) @@ -325,6 +342,16 @@ private ChatResponseMetadata from(ChatCompletion result) { .build(); } + private ChatResponseMetadata from(ChatCompletion result, Usage usage) { + Assert.notNull(result, "Moonshot ChatCompletionResult must not be null"); + return ChatResponseMetadata.builder() + .withId(result.id() != null ? result.id() : "") + .withUsage(usage) + .withModel(result.model() != null ? result.model() : "") + .withKeyValue("created", result.created() != null ? result.created() : 0L) + .build(); + } + /** * Convert the ChatCompletionChunk into a ChatCompletion. The Usage is set to null. * @param chunk the ChatCompletionChunk to convert @@ -336,10 +363,11 @@ private ChatCompletion chunkToChatCompletion(ChatCompletionChunk chunk) { if (delta == null) { delta = new ChatCompletionMessage("", ChatCompletionMessage.Role.ASSISTANT); } - return new ChatCompletion.Choice(cc.index(), delta, cc.finishReason()); + return new ChatCompletion.Choice(cc.index(), delta, cc.finishReason(), cc.usage()); }).toList(); - - return new ChatCompletion(chunk.id(), "chat.completion", chunk.created(), chunk.model(), choices, null); + // Get the usage from the latest choice + MoonshotApi.Usage usage = choices.get(choices.size() - 1).usage(); + return new ChatCompletion(chunk.id(), "chat.completion", chunk.created(), chunk.model(), choices, usage); } /** diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotApi.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotApi.java index b4a2162e28b..532fb851b8b 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotApi.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotApi.java @@ -532,7 +532,8 @@ public record Choice( // @formatter:off @JsonProperty("index") Integer index, @JsonProperty("message") ChatCompletionMessage message, - @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason) { + @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason, + @JsonProperty("usage") Usage usage) { // @formatter:on } diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotStreamFunctionCallingHelper.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotStreamFunctionCallingHelper.java index 06f1dc7655d..df03cbb8015 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotStreamFunctionCallingHelper.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotStreamFunctionCallingHelper.java @@ -64,8 +64,10 @@ private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) { : previous.finishReason()); Integer index = (current.index() != null ? current.index() : previous.index()); + MoonshotApi.Usage usage = current.usage() != null ? current.usage() : previous.usage(); + ChatCompletionMessage message = merge(previous.delta(), current.delta()); - return new ChunkChoice(index, message, finishReason, null); + return new ChunkChoice(index, message, finishReason, usage); } private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompletionMessage current) { diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonShotChatModelIT.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonShotChatModelIT.java new file mode 100644 index 00000000000..f7948cd418d --- /dev/null +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonShotChatModelIT.java @@ -0,0 +1,172 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.moonshot; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.moonshot.api.MockWeatherService; +import org.springframework.ai.moonshot.api.MoonshotApi; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.util.StringUtils; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Ilayaperumal Gopinathan + */ +@SpringBootTest +@EnabledIfEnvironmentVariable(named = "MOONSHOT_API_KEY", matches = ".+") +public class MoonShotChatModelIT { + + @Autowired + private MoonshotChatModel chatModel; + + private static final MoonshotApi.FunctionTool FUNCTION_TOOL = new MoonshotApi.FunctionTool( + MoonshotApi.FunctionTool.Type.FUNCTION, new MoonshotApi.FunctionTool.Function( + "Get the weather in location. Return temperature in 30°F or 30°C format.", "getCurrentWeather", """ + { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state e.g. San Francisco, CA" + }, + "lat": { + "type": "number", + "description": "The city latitude" + }, + "lon": { + "type": "number", + "description": "The city longitude" + }, + "unit": { + "type": "string", + "enum": ["C", "F"] + } + }, + "required": ["location", "lat", "lon", "unit"] + } + """)); + + @Test + public void toolFunctionCall() { + var promptOptions = MoonshotChatOptions.builder() + .withModel(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue()) + .withTools(Arrays.asList(FUNCTION_TOOL)) + .withFunctionCallbacks(List.of(FunctionCallback.builder() + .function("getCurrentWeather", new MockWeatherService()) + .description("Get the weather in location. Return temperature in 36°F or 36°C format.") + .inputType(MockWeatherService.Request.class) + .build())) + .build(); + Prompt prompt = new Prompt("What's the weather like in San Francisco? Return the temperature in Celsius.", + promptOptions); + + ChatResponse chatResponse = this.chatModel.call(prompt); + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.getResult().getOutput()); + assertThat(chatResponse.getResult().getOutput().getText()).contains("San Francisco"); + assertThat(chatResponse.getResult().getOutput().getText()).contains("30.0"); + assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isLessThan(450).isGreaterThan(280); + } + + @Test + public void testStreamFunctionCall() { + var promptOptions = MoonshotChatOptions.builder() + .withModel(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue()) + .withTools(Arrays.asList(FUNCTION_TOOL)) + .withFunctionCallbacks(List.of(FunctionCallback.builder() + .function("getCurrentWeather", new MockWeatherService()) + .description("Get the weather in location. Return temperature in 36°F or 36°C format.") + .inputType(MockWeatherService.Request.class) + .build())) + .build(); + Prompt prompt = new Prompt("What's the weather like in San Francisco? Return the temperature in Celsius.", + promptOptions); + + Flux chatResponse = this.chatModel.stream(prompt); + String content = chatResponse.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getText) + .collect(Collectors.joining()); + assertThat(content).contains("San Francisco"); + assertThat(content).contains("30.0"); + } + + @Test + public void testStreamFunctionCallUsage() { + var promptOptions = MoonshotChatOptions.builder() + .withModel(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue()) + .withTools(Arrays.asList(FUNCTION_TOOL)) + .withFunctionCallbacks(List.of(FunctionCallback.builder() + .function("getCurrentWeather", new MockWeatherService()) + .description("Get the weather in location. Return temperature in 36°F or 36°C format.") + .inputType(MockWeatherService.Request.class) + .build())) + .build(); + Prompt prompt = new Prompt("What's the weather like in San Francisco? Return the temperature in Celsius.", + promptOptions); + + ChatResponse chatResponse = this.chatModel.stream(prompt).blockLast(); + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.getMetadata()).isNotNull(); + assertThat(chatResponse.getMetadata().getUsage()).isNotNull(); + assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isLessThan(450).isGreaterThan(280); + } + + @SpringBootConfiguration + public static class Config { + + @Bean + public MoonshotApi moonshotApi() { + return new MoonshotApi(getApiKey()); + } + + private String getApiKey() { + String apiKey = System.getenv("MOONSHOT_API_KEY"); + if (!StringUtils.hasText(apiKey)) { + throw new IllegalArgumentException( + "You must provide an API key. Put it in an environment variable under the name MOONSHOT_API_KEY"); + } + return apiKey; + } + + @Bean + public MoonshotChatModel moonshotChatModel(MoonshotApi api) { + return new MoonshotChatModel(api); + } + + } + +} diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotRetryTests.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotRetryTests.java index 33ef4855623..af8f4c71319 100644 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotRetryTests.java +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotRetryTests.java @@ -80,7 +80,7 @@ public void beforeEach() { public void moonshotChatTransientError() { var choice = new ChatCompletion.Choice(0, new ChatCompletionMessage("Response", Role.ASSISTANT), - ChatCompletionFinishReason.STOP); + ChatCompletionFinishReason.STOP, null); ChatCompletion expectedChatCompletion = new ChatCompletion("id", "chat.completion", 789L, "model", List.of(choice), new MoonshotApi.Usage(10, 10, 10)); From 4f4c2d0a1f3cc0b030a72fa7d3ad93d30b7e9ca0 Mon Sep 17 00:00:00 2001 From: Ilayaperumal Gopinathan Date: Thu, 2 Jan 2025 17:08:08 +0000 Subject: [PATCH 2/2] Move the tests into MoonshotChatModelFunctionCallingIT --- .../ai/moonshot/MoonshotChatModel.java | 8 +- .../ai/moonshot/MoonShotChatModelIT.java | 172 ------------------ .../MoonshotChatModelFunctionCallingIT.java | 74 +++++++- 3 files changed, 77 insertions(+), 177 deletions(-) delete mode 100644 models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonShotChatModelIT.java diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java index cce99e93efa..9fd7dff2599 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java @@ -345,10 +345,10 @@ private ChatResponseMetadata from(ChatCompletion result) { private ChatResponseMetadata from(ChatCompletion result, Usage usage) { Assert.notNull(result, "Moonshot ChatCompletionResult must not be null"); return ChatResponseMetadata.builder() - .withId(result.id() != null ? result.id() : "") - .withUsage(usage) - .withModel(result.model() != null ? result.model() : "") - .withKeyValue("created", result.created() != null ? result.created() : 0L) + .id(result.id() != null ? result.id() : "") + .usage(usage) + .model(result.model() != null ? result.model() : "") + .keyValue("created", result.created() != null ? result.created() : 0L) .build(); } diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonShotChatModelIT.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonShotChatModelIT.java deleted file mode 100644 index f7948cd418d..00000000000 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonShotChatModelIT.java +++ /dev/null @@ -1,172 +0,0 @@ -/* - * Copyright 2023-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.ai.moonshot; - -import java.util.Arrays; -import java.util.List; -import java.util.stream.Collectors; - -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import reactor.core.publisher.Flux; - -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.moonshot.api.MockWeatherService; -import org.springframework.ai.moonshot.api.MoonshotApi; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.boot.SpringBootConfiguration; -import org.springframework.boot.test.context.SpringBootTest; -import org.springframework.context.annotation.Bean; -import org.springframework.util.StringUtils; - -import static org.assertj.core.api.Assertions.assertThat; - -/** - * @author Ilayaperumal Gopinathan - */ -@SpringBootTest -@EnabledIfEnvironmentVariable(named = "MOONSHOT_API_KEY", matches = ".+") -public class MoonShotChatModelIT { - - @Autowired - private MoonshotChatModel chatModel; - - private static final MoonshotApi.FunctionTool FUNCTION_TOOL = new MoonshotApi.FunctionTool( - MoonshotApi.FunctionTool.Type.FUNCTION, new MoonshotApi.FunctionTool.Function( - "Get the weather in location. Return temperature in 30°F or 30°C format.", "getCurrentWeather", """ - { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "lat": { - "type": "number", - "description": "The city latitude" - }, - "lon": { - "type": "number", - "description": "The city longitude" - }, - "unit": { - "type": "string", - "enum": ["C", "F"] - } - }, - "required": ["location", "lat", "lon", "unit"] - } - """)); - - @Test - public void toolFunctionCall() { - var promptOptions = MoonshotChatOptions.builder() - .withModel(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue()) - .withTools(Arrays.asList(FUNCTION_TOOL)) - .withFunctionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) - .description("Get the weather in location. Return temperature in 36°F or 36°C format.") - .inputType(MockWeatherService.Request.class) - .build())) - .build(); - Prompt prompt = new Prompt("What's the weather like in San Francisco? Return the temperature in Celsius.", - promptOptions); - - ChatResponse chatResponse = this.chatModel.call(prompt); - assertThat(chatResponse).isNotNull(); - assertThat(chatResponse.getResult().getOutput()); - assertThat(chatResponse.getResult().getOutput().getText()).contains("San Francisco"); - assertThat(chatResponse.getResult().getOutput().getText()).contains("30.0"); - assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isLessThan(450).isGreaterThan(280); - } - - @Test - public void testStreamFunctionCall() { - var promptOptions = MoonshotChatOptions.builder() - .withModel(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue()) - .withTools(Arrays.asList(FUNCTION_TOOL)) - .withFunctionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) - .description("Get the weather in location. Return temperature in 36°F or 36°C format.") - .inputType(MockWeatherService.Request.class) - .build())) - .build(); - Prompt prompt = new Prompt("What's the weather like in San Francisco? Return the temperature in Celsius.", - promptOptions); - - Flux chatResponse = this.chatModel.stream(prompt); - String content = chatResponse.collectList() - .block() - .stream() - .map(ChatResponse::getResults) - .flatMap(List::stream) - .map(Generation::getOutput) - .map(AssistantMessage::getText) - .collect(Collectors.joining()); - assertThat(content).contains("San Francisco"); - assertThat(content).contains("30.0"); - } - - @Test - public void testStreamFunctionCallUsage() { - var promptOptions = MoonshotChatOptions.builder() - .withModel(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue()) - .withTools(Arrays.asList(FUNCTION_TOOL)) - .withFunctionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) - .description("Get the weather in location. Return temperature in 36°F or 36°C format.") - .inputType(MockWeatherService.Request.class) - .build())) - .build(); - Prompt prompt = new Prompt("What's the weather like in San Francisco? Return the temperature in Celsius.", - promptOptions); - - ChatResponse chatResponse = this.chatModel.stream(prompt).blockLast(); - assertThat(chatResponse).isNotNull(); - assertThat(chatResponse.getMetadata()).isNotNull(); - assertThat(chatResponse.getMetadata().getUsage()).isNotNull(); - assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isLessThan(450).isGreaterThan(280); - } - - @SpringBootConfiguration - public static class Config { - - @Bean - public MoonshotApi moonshotApi() { - return new MoonshotApi(getApiKey()); - } - - private String getApiKey() { - String apiKey = System.getenv("MOONSHOT_API_KEY"); - if (!StringUtils.hasText(apiKey)) { - throw new IllegalArgumentException( - "You must provide an API key. Put it in an environment variable under the name MOONSHOT_API_KEY"); - } - return apiKey; - } - - @Bean - public MoonshotChatModel moonshotChatModel(MoonshotApi api) { - return new MoonshotChatModel(api); - } - - } - -} diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelFunctionCallingIT.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelFunctionCallingIT.java index 6b4d5ba19b7..f24600653a4 100644 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelFunctionCallingIT.java +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelFunctionCallingIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.ai.moonshot.chat; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Objects; import java.util.stream.Collectors; @@ -53,6 +54,33 @@ class MoonshotChatModelFunctionCallingIT { @Autowired ChatModel chatModel; + private static final MoonshotApi.FunctionTool FUNCTION_TOOL = new MoonshotApi.FunctionTool( + MoonshotApi.FunctionTool.Type.FUNCTION, new MoonshotApi.FunctionTool.Function( + "Get the weather in location. Return temperature in 30°F or 30°C format.", "getCurrentWeather", """ + { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state e.g. San Francisco, CA" + }, + "lat": { + "type": "number", + "description": "The city latitude" + }, + "lon": { + "type": "number", + "description": "The city longitude" + }, + "unit": { + "type": "string", + "enum": ["C", "F"] + } + }, + "required": ["location", "lat", "lon", "unit"] + } + """)); + @Test void functionCallTest() { @@ -89,6 +117,7 @@ void streamFunctionCallTest() { .functionCallbacks(List.of(FunctionCallback.builder() .function("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") + .inputType(MockWeatherService.Request.class) .build())) .build(); @@ -108,4 +137,47 @@ void streamFunctionCallTest() { assertThat(content).contains("30", "10", "15"); } + @Test + public void toolFunctionCallWithUsage() { + var promptOptions = MoonshotChatOptions.builder() + .model(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue()) + .tools(Arrays.asList(FUNCTION_TOOL)) + .functionCallbacks(List.of(FunctionCallback.builder() + .function("getCurrentWeather", new MockWeatherService()) + .description("Get the weather in location. Return temperature in 36°F or 36°C format.") + .inputType(MockWeatherService.Request.class) + .build())) + .build(); + Prompt prompt = new Prompt("What's the weather like in San Francisco? Return the temperature in Celsius.", + promptOptions); + + ChatResponse chatResponse = this.chatModel.call(prompt); + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.getResult().getOutput()); + assertThat(chatResponse.getResult().getOutput().getText()).contains("San Francisco"); + assertThat(chatResponse.getResult().getOutput().getText()).contains("30.0"); + assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isLessThan(450).isGreaterThan(280); + } + + @Test + public void testStreamFunctionCallUsage() { + var promptOptions = MoonshotChatOptions.builder() + .model(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue()) + .tools(Arrays.asList(FUNCTION_TOOL)) + .functionCallbacks(List.of(FunctionCallback.builder() + .function("getCurrentWeather", new MockWeatherService()) + .description("Get the weather in location. Return temperature in 36°F or 36°C format.") + .inputType(MockWeatherService.Request.class) + .build())) + .build(); + Prompt prompt = new Prompt("What's the weather like in San Francisco? Return the temperature in Celsius.", + promptOptions); + + ChatResponse chatResponse = this.chatModel.stream(prompt).blockLast(); + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.getMetadata()).isNotNull(); + assertThat(chatResponse.getMetadata().getUsage()).isNotNull(); + assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isLessThan(450).isGreaterThan(280); + } + }