Skip to content

Commit 75d6866

Browse files
QuentinLowetzolov
authored andcommitted
Change the Prompt's modelOptions() to return ChatOptions instead of ModelOptions.
1 parent 7d83ed6 commit 75d6866

File tree

20 files changed

+104
-192
lines changed

20 files changed

+104
-192
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
import org.springframework.ai.anthropic.metadata.AnthropicChatResponseMetadata;
4242
import org.springframework.ai.chat.model.ChatResponse;
4343
import org.springframework.ai.chat.model.Generation;
44-
import org.springframework.ai.chat.model.StreamingChatModel;
4544
import org.springframework.ai.chat.messages.MessageType;
4645
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
4746
import org.springframework.ai.chat.prompt.ChatOptions;
@@ -59,11 +58,12 @@
5958
* The {@link ChatModel} implementation for the Anthropic service.
6059
*
6160
* @author Christian Tzolov
61+
* @author luocongqiu
6262
* @since 1.0.0
6363
*/
6464
public class AnthropicChatModel extends
6565
AbstractFunctionCallSupport<AnthropicApi.RequestMessage, AnthropicApi.ChatCompletionRequest, ResponseEntity<AnthropicApi.ChatCompletion>>
66-
implements ChatModel, StreamingChatModel {
66+
implements ChatModel {
6767

6868
private static final Logger logger = LoggerFactory.getLogger(AnthropicChatModel.class);
6969

@@ -81,7 +81,7 @@ public class AnthropicChatModel extends
8181
/**
8282
* The default options used for the chat completion requests.
8383
*/
84-
private AnthropicChatOptions defaultOptions;
84+
private final AnthropicChatOptions defaultOptions;
8585

8686
/**
8787
* The retry template used to retry the OpenAI API calls.
@@ -280,20 +280,14 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
280280
systemPrompt, this.defaultOptions.getMaxTokens(), this.defaultOptions.getTemperature(), stream);
281281

282282
if (prompt.getOptions() != null) {
283-
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
284-
AnthropicChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
285-
ChatOptions.class, AnthropicChatOptions.class);
283+
AnthropicChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(),
284+
ChatOptions.class, AnthropicChatOptions.class);
286285

287-
Set<String> promptEnabledFunctions = this.handleFunctionCallbackConfigurations(updatedRuntimeOptions,
288-
IS_RUNTIME_CALL);
289-
functionsForThisRequest.addAll(promptEnabledFunctions);
286+
Set<String> promptEnabledFunctions = this.handleFunctionCallbackConfigurations(updatedRuntimeOptions,
287+
IS_RUNTIME_CALL);
288+
functionsForThisRequest.addAll(promptEnabledFunctions);
290289

291-
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, ChatCompletionRequest.class);
292-
}
293-
else {
294-
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
295-
+ prompt.getOptions().getClass().getSimpleName());
296-
}
290+
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, ChatCompletionRequest.class);
297291
}
298292

299293
if (this.defaultOptions != null) {

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import com.fasterxml.jackson.annotation.JsonInclude.Include;
2626
import com.fasterxml.jackson.annotation.JsonProperty;
2727

28+
import org.springframework.ai.anthropic.api.AnthropicApi;
2829
import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest;
2930
import org.springframework.ai.chat.prompt.ChatOptions;
3031
import org.springframework.ai.model.function.FunctionCallback;
@@ -90,6 +91,11 @@ public Builder withModel(String model) {
9091
return this;
9192
}
9293

94+
public Builder withModel(AnthropicApi.ChatModel model) {
95+
this.options.model = model.getValue();
96+
return this;
97+
}
98+
9399
public Builder withMaxTokens(Integer maxTokens) {
94100
this.options.maxTokens = maxTokens;
95101
return this;

models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -202,10 +202,11 @@ void functionCallTest() {
202202
List<Message> messages = new ArrayList<>(List.of(userMessage));
203203

204204
var promptOptions = AnthropicChatOptions.builder()
205-
.withModel(AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue())
205+
.withModel(AnthropicApi.ChatModel.CLAUDE_3_OPUS)
206206
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
207207
.withName("getCurrentWeather")
208-
.withDescription("Get the weather in location. Return temperature in 36°F or 36°C format.")
208+
.withDescription(
209+
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
209210
.build()))
210211
.build();
211212

@@ -214,9 +215,7 @@ void functionCallTest() {
214215
logger.info("Response: {}", response);
215216

216217
Generation generation = response.getResult();
217-
assertThat(generation.getOutput().getContent()).containsAnyOf("30.0", "30");
218-
assertThat(generation.getOutput().getContent()).containsAnyOf("10.0", "10");
219-
assertThat(generation.getOutput().getContent()).containsAnyOf("15.0", "15");
218+
assertThat(generation.getOutput().getContent()).contains("30", "10", "15");
220219
}
221220

222221
}

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@
5151
import org.springframework.ai.chat.model.ChatModel;
5252
import org.springframework.ai.chat.model.ChatResponse;
5353
import org.springframework.ai.chat.model.Generation;
54-
import org.springframework.ai.chat.model.StreamingChatModel;
5554
import org.springframework.ai.chat.prompt.ChatOptions;
5655
import org.springframework.ai.chat.prompt.Prompt;
5756
import org.springframework.ai.model.ModelOptionsUtils;
@@ -79,12 +78,12 @@
7978
* @author Christian Tzolov
8079
* @author Grogdunn
8180
* @author Benoit Moussaud
81+
* @author luocongqiu
8282
* @see ChatModel
8383
* @see com.azure.ai.openai.OpenAIClient
8484
*/
85-
public class AzureOpenAiChatModel
86-
extends AbstractFunctionCallSupport<ChatRequestMessage, ChatCompletionsOptions, ChatCompletions>
87-
implements ChatModel, StreamingChatModel {
85+
public class AzureOpenAiChatModel extends
86+
AbstractFunctionCallSupport<ChatRequestMessage, ChatCompletionsOptions, ChatCompletions> implements ChatModel {
8887

8988
private static final String DEFAULT_DEPLOYMENT_NAME = "gpt-35-turbo";
9089

@@ -233,24 +232,17 @@ ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) {
233232
}
234233

235234
if (prompt.getOptions() != null) {
236-
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
237-
AzureOpenAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
238-
ChatOptions.class, AzureOpenAiChatOptions.class);
239-
// JSON merge doesn't due to Azure OpenAI service bug:
240-
// https://github.com/Azure/azure-sdk-for-java/issues/38183
241-
// options = ModelOptionsUtils.merge(runtimeOptions, options,
242-
// ChatCompletionsOptions.class);
243-
options = merge(updatedRuntimeOptions, options);
244-
245-
Set<String> promptEnabledFunctions = this.handleFunctionCallbackConfigurations(updatedRuntimeOptions,
246-
IS_RUNTIME_CALL);
247-
functionsForThisRequest.addAll(promptEnabledFunctions);
235+
AzureOpenAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(),
236+
ChatOptions.class, AzureOpenAiChatOptions.class);
237+
// JSON merge doesn't due to Azure OpenAI service bug:
238+
// https://github.com/Azure/azure-sdk-for-java/issues/38183
239+
// options = ModelOptionsUtils.merge(runtimeOptions, options,
240+
// ChatCompletionsOptions.class);
241+
options = merge(updatedRuntimeOptions, options);
248242

249-
}
250-
else {
251-
throw new IllegalArgumentException("Prompt options are not of type ChatCompletionsOptions:"
252-
+ prompt.getOptions().getClass().getSimpleName());
253-
}
243+
Set<String> promptEnabledFunctions = this.handleFunctionCallbackConfigurations(updatedRuntimeOptions,
244+
IS_RUNTIME_CALL);
245+
functionsForThisRequest.addAll(promptEnabledFunctions);
254246
}
255247

256248
// Add the enabled functions definitions to the request's tools parameter.

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModel.java

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -103,15 +103,9 @@ AnthropicChatRequest createRequest(Prompt prompt) {
103103
}
104104

105105
if (prompt.getOptions() != null) {
106-
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
107-
AnthropicChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
108-
ChatOptions.class, AnthropicChatOptions.class);
109-
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, AnthropicChatRequest.class);
110-
}
111-
else {
112-
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
113-
+ prompt.getOptions().getClass().getSimpleName());
114-
}
106+
AnthropicChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(),
107+
ChatOptions.class, AnthropicChatOptions.class);
108+
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, AnthropicChatRequest.class);
115109
}
116110

117111
return request;

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -122,15 +122,9 @@ AnthropicChatRequest createRequest(Prompt prompt) {
122122
}
123123

124124
if (prompt.getOptions() != null) {
125-
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
126-
Anthropic3ChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
127-
ChatOptions.class, Anthropic3ChatOptions.class);
128-
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, AnthropicChatRequest.class);
129-
}
130-
else {
131-
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
132-
+ prompt.getOptions().getClass().getSimpleName());
133-
}
125+
Anthropic3ChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(),
126+
ChatOptions.class, Anthropic3ChatOptions.class);
127+
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, AnthropicChatRequest.class);
134128
}
135129

136130
return request;

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.java

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -100,15 +100,9 @@ CohereChatRequest createRequest(Prompt prompt, boolean stream) {
100100
.build();
101101

102102
if (prompt.getOptions() != null) {
103-
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
104-
BedrockCohereChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
105-
ChatOptions.class, BedrockCohereChatOptions.class);
106-
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, CohereChatRequest.class);
107-
}
108-
else {
109-
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
110-
+ prompt.getOptions().getClass().getSimpleName());
111-
}
103+
BedrockCohereChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(),
104+
ChatOptions.class, BedrockCohereChatOptions.class);
105+
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, CohereChatRequest.class);
112106
}
113107

114108
return request;

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModel.java

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,9 @@ private Ai21Jurassic2ChatRequest createRequest(Prompt prompt) {
7676
Ai21Jurassic2ChatRequest request = Ai21Jurassic2ChatRequest.builder(promptValue).build();
7777

7878
if (prompt.getOptions() != null) {
79-
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
80-
BedrockAi21Jurassic2ChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
81-
ChatOptions.class, BedrockAi21Jurassic2ChatOptions.class);
82-
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, Ai21Jurassic2ChatRequest.class);
83-
}
84-
else {
85-
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
86-
+ prompt.getOptions().getClass().getSimpleName());
87-
}
79+
BedrockAi21Jurassic2ChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(),
80+
ChatOptions.class, BedrockAi21Jurassic2ChatOptions.class);
81+
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, Ai21Jurassic2ChatRequest.class);
8882
}
8983

9084
if (this.defaultOptions != null) {

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModel.java

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -115,16 +115,10 @@ LlamaChatRequest createRequest(Prompt prompt) {
115115
}
116116

117117
if (prompt.getOptions() != null) {
118-
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
119-
BedrockLlamaChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
120-
ChatOptions.class, BedrockLlamaChatOptions.class);
118+
BedrockLlamaChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(),
119+
ChatOptions.class, BedrockLlamaChatOptions.class);
121120

122-
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, LlamaChatRequest.class);
123-
}
124-
else {
125-
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
126-
+ prompt.getOptions().getClass().getSimpleName());
127-
}
121+
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, LlamaChatRequest.class);
128122
}
129123

130124
return request;

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -100,16 +100,10 @@ TitanChatRequest createRequest(Prompt prompt) {
100100
}
101101

102102
if (prompt.getOptions() != null) {
103-
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
104-
BedrockTitanChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
105-
ChatOptions.class, BedrockTitanChatOptions.class);
103+
BedrockTitanChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(),
104+
ChatOptions.class, BedrockTitanChatOptions.class);
106105

107-
requestBuilder = update(requestBuilder, updatedRuntimeOptions);
108-
}
109-
else {
110-
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
111-
+ prompt.getOptions().getClass().getSimpleName());
112-
}
106+
requestBuilder = update(requestBuilder, updatedRuntimeOptions);
113107
}
114108

115109
return requestBuilder.build();

0 commit comments

Comments
 (0)