Skip to content

Commit 6021943

Browse files
committed
feat: support chat_template_kwargs for use with the vLLM OpenAI-Compatible Server (https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters_3), related: (#3409)
Signed-off-by: 家娃 <[email protected]>
1 parent 15e1eaa commit 6021943

File tree

4 files changed

+36
-10
lines changed

4 files changed

+36
-10
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,12 @@ public class OpenAiChatOptions implements ToolCallingChatOptions {
201201
*/
202202
private @JsonProperty("web_search_options") WebSearchOptions webSearchOptions;
203203

204+
/**
205+
* Extra request-body parameter, serialized as {@code chat_template_kwargs}, for OpenAI-compatible servers; for example {@code enable_thinking} to toggle the model's thinking.
206+
*/
207+
private @JsonProperty("chat_template_kwargs") Map<String,Object> chatTemplateKwargs;
208+
209+
204210
/**
205211
* Collection of {@link ToolCallback}s to be used for tool calling in the chat completion requests.
206212
*/
@@ -268,6 +274,7 @@ public static OpenAiChatOptions fromOptions(OpenAiChatOptions fromOptions) {
268274
.metadata(fromOptions.getMetadata())
269275
.reasoningEffort(fromOptions.getReasoningEffort())
270276
.webSearchOptions(fromOptions.getWebSearchOptions())
277+
.chatTemplateKwargs(fromOptions.chatTemplateKwargs)
271278
.build();
272279
}
273280

@@ -564,6 +571,14 @@ public void setWebSearchOptions(WebSearchOptions webSearchOptions) {
564571
this.webSearchOptions = webSearchOptions;
565572
}
566573

574+
public Map<String, Object> getChatTemplateKwargs() {
575+
return chatTemplateKwargs;
576+
}
577+
578+
public void setChatTemplateKwargs(Map<String, Object> chatTemplateKwargs) {
579+
this.chatTemplateKwargs = chatTemplateKwargs;
580+
}
581+
567582
@Override
568583
public OpenAiChatOptions copy() {
569584
return OpenAiChatOptions.fromOptions(this);
@@ -576,7 +591,7 @@ public int hashCode() {
576591
this.streamOptions, this.seed, this.stop, this.temperature, this.topP, this.tools, this.toolChoice,
577592
this.user, this.parallelToolCalls, this.toolCallbacks, this.toolNames, this.httpHeaders,
578593
this.internalToolExecutionEnabled, this.toolContext, this.outputModalities, this.outputAudio,
579-
this.store, this.metadata, this.reasoningEffort, this.webSearchOptions);
594+
this.store, this.metadata, this.reasoningEffort, this.webSearchOptions, this.chatTemplateKwargs);
580595
}
581596

582597
@Override
@@ -609,7 +624,8 @@ public boolean equals(Object o) {
609624
&& Objects.equals(this.outputAudio, other.outputAudio) && Objects.equals(this.store, other.store)
610625
&& Objects.equals(this.metadata, other.metadata)
611626
&& Objects.equals(this.reasoningEffort, other.reasoningEffort)
612-
&& Objects.equals(this.webSearchOptions, other.webSearchOptions);
627+
&& Objects.equals(this.webSearchOptions, other.webSearchOptions)
628+
&& Objects.equals(this.chatTemplateKwargs, other.chatTemplateKwargs);
613629
}
614630

615631
@Override
@@ -802,6 +818,11 @@ public Builder webSearchOptions(WebSearchOptions webSearchOptions) {
802818
return this;
803819
}
804820

821+
public Builder chatTemplateKwargs(Map<String, Object> chatTemplateKwargs) {
822+
this.options.chatTemplateKwargs = chatTemplateKwargs;
823+
return this;
824+
}
825+
805826
public OpenAiChatOptions build() {
806827
return this.options;
807828
}

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1057,7 +1057,8 @@ public record ChatCompletionRequest(// @formatter:off
10571057
@JsonProperty("parallel_tool_calls") Boolean parallelToolCalls,
10581058
@JsonProperty("user") String user,
10591059
@JsonProperty("reasoning_effort") String reasoningEffort,
1060-
@JsonProperty("web_search_options") WebSearchOptions webSearchOptions) {
1060+
@JsonProperty("web_search_options") WebSearchOptions webSearchOptions,
1061+
@JsonProperty("chat_template_kwargs") Map<String,Object> chatTemplateKwargs) {
10611062

10621063
/**
10631064
* Shortcut constructor for a chat completion request with the given messages, model and temperature.
@@ -1069,7 +1070,7 @@ public record ChatCompletionRequest(// @formatter:off
10691070
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature) {
10701071
this(messages, model, null, null, null, null, null, null, null, null, null, null, null, null, null,
10711072
null, null, null, false, null, temperature, null,
1072-
null, null, null, null, null, null);
1073+
null, null, null, null, null, null, null);
10731074
}
10741075

10751076
/**
@@ -1083,7 +1084,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
10831084
this(messages, model, null, null, null, null, null, null,
10841085
null, null, null, List.of(OutputModality.AUDIO, OutputModality.TEXT), audio, null, null,
10851086
null, null, null, stream, null, null, null,
1086-
null, null, null, null, null, null);
1087+
null, null, null, null, null, null, null);
10871088
}
10881089

10891090
/**
@@ -1098,7 +1099,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
10981099
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature, boolean stream) {
10991100
this(messages, model, null, null, null, null, null, null, null, null, null,
11001101
null, null, null, null, null, null, null, stream, null, temperature, null,
1101-
null, null, null, null, null, null);
1102+
null, null, null, null, null, null, null);
11021103
}
11031104

11041105
/**
@@ -1114,7 +1115,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
11141115
List<FunctionTool> tools, Object toolChoice) {
11151116
this(messages, model, null, null, null, null, null, null, null, null, null,
11161117
null, null, null, null, null, null, null, false, null, 0.8, null,
1117-
tools, toolChoice, null, null, null, null);
1118+
tools, toolChoice, null, null, null, null, null);
11181119
}
11191120

11201121
/**
@@ -1127,7 +1128,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
11271128
public ChatCompletionRequest(List<ChatCompletionMessage> messages, Boolean stream) {
11281129
this(messages, null, null, null, null, null, null, null, null, null, null,
11291130
null, null, null, null, null, null, null, stream, null, null, null,
1130-
null, null, null, null, null, null);
1131+
null, null, null, null, null, null, null);
11311132
}
11321133

11331134
/**
@@ -1140,7 +1141,7 @@ public ChatCompletionRequest streamOptions(StreamOptions streamOptions) {
11401141
return new ChatCompletionRequest(this.messages, this.model, this.store, this.metadata, this.frequencyPenalty, this.logitBias, this.logprobs,
11411142
this.topLogprobs, this.maxTokens, this.maxCompletionTokens, this.n, this.outputModalities, this.audioParameters, this.presencePenalty,
11421143
this.responseFormat, this.seed, this.serviceTier, this.stop, this.stream, streamOptions, this.temperature, this.topP,
1143-
this.tools, this.toolChoice, this.parallelToolCalls, this.user, this.reasoningEffort, this.webSearchOptions);
1144+
this.tools, this.toolChoice, this.parallelToolCalls, this.user, this.reasoningEffort, this.webSearchOptions, this.chatTemplateKwargs);
11441145
}
11451146

11461147
/**

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiChatOptionsTests.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ void testCopy() {
141141
.reasoningEffort("low")
142142
.internalToolExecutionEnabled(true)
143143
.httpHeaders(Map.of("header1", "value1"))
144+
.chatTemplateKwargs(Map.of("enable_thinking", true))
144145
.build();
145146

146147
OpenAiChatOptions copiedOptions = originalOptions.copy();
@@ -189,6 +190,7 @@ void testSetters() {
189190
options.setReasoningEffort("high");
190191
options.setInternalToolExecutionEnabled(false);
191192
options.setHttpHeaders(Map.of("header2", "value2"));
193+
options.setChatTemplateKwargs(Map.of("enable_thinking", true));
192194

193195
assertThat(options.getModel()).isEqualTo("test-model");
194196
assertThat(options.getFrequencyPenalty()).isEqualTo(0.5);
@@ -223,6 +225,7 @@ void testSetters() {
223225
options.setStopSequences(List.of("s1", "s2"));
224226
assertThat(options.getStopSequences()).isEqualTo(List.of("s1", "s2"));
225227
assertThat(options.getStop()).isEqualTo(List.of("s1", "s2"));
228+
assertThat(options.getChatTemplateKwargs()).isEqualTo(Map.of("enable_thinking", true));
226229
}
227230

228231
@Test
@@ -258,6 +261,7 @@ void testDefaultValues() {
258261
assertThat(options.getToolContext()).isEqualTo(new HashMap<>());
259262
assertThat(options.getStreamUsage()).isFalse();
260263
assertThat(options.getStopSequences()).isNull();
264+
assertThat(options.getChatTemplateKwargs()).isNull();
261265
}
262266

263267
@Test

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ void validateReasoningTokens() {
7575
"If a train travels 100 miles in 2 hours, what is its average speed?", ChatCompletionMessage.Role.USER);
7676
ChatCompletionRequest request = new ChatCompletionRequest(List.of(userMessage), "o1", null, null, null, null,
7777
null, null, null, null, null, null, null, null, null, null, null, null, false, null, null, null, null,
78-
null, null, null, "low", null);
78+
null, null, null, "low", null, null);
7979
ResponseEntity<ChatCompletion> response = this.openAiApi.chatCompletionEntity(request);
8080

8181
assertThat(response).isNotNull();

0 commit comments

Comments
 (0)