21
21
import org .springframework .ai .tool .ToolCallback ;
22
22
import org .springframework .lang .Nullable ;
23
23
import org .springframework .util .Assert ;
24
- import org .springframework .util .CollectionUtils ;
25
- import org .springframework .util .StringUtils ;
26
24
27
25
import java .util .ArrayList ;
28
26
import java .util .HashMap ;
@@ -46,7 +44,7 @@ public class DefaultToolCallingChatOptions implements ToolCallingChatOptions {
46
44
private Map <String , Object > toolContext = new HashMap <>();
47
45
48
46
@ Nullable
49
- private Boolean toolCallReturnDirect ;
47
+ private Boolean toolExecutionEnabled ;
50
48
51
49
@ Nullable
52
50
private String model ;
@@ -123,13 +121,13 @@ public void setToolContext(Map<String, Object> toolContext) {
123
121
124
122
@ Override
125
123
@ Nullable
126
- public Boolean getToolCallReturnDirect () {
127
- return this .toolCallReturnDirect ;
124
+ public Boolean isToolExecutionEnabled () {
125
+ return this .toolExecutionEnabled ;
128
126
}
129
127
130
128
@ Override
131
- public void setToolCallReturnDirect (@ Nullable Boolean toolCallReturnDirect ) {
132
- this .toolCallReturnDirect = toolCallReturnDirect ;
129
+ public void setToolExecutionEnabled (@ Nullable Boolean toolExecutionEnabled ) {
130
+ this .toolExecutionEnabled = toolExecutionEnabled ;
133
131
}
134
132
135
133
@ Override
@@ -139,7 +137,12 @@ public List<FunctionCallback> getFunctionCallbacks() {
139
137
140
138
@ Override
141
139
public void setFunctionCallbacks (List <FunctionCallback > functionCallbacks ) {
142
- throw new UnsupportedOperationException ("Not supported. Call setToolCallbacks instead." );
140
+ if (functionCallbacks .stream ().allMatch (ToolCallback .class ::isInstance )) {
141
+ setToolCallbacks (functionCallbacks .stream ().map (ToolCallback .class ::cast ).toList ());
142
+ }
143
+ else {
144
+ throw new IllegalArgumentException ("functionCallbacks must be instances of ToolCallback" );
145
+ }
143
146
}
144
147
145
148
@ Override
@@ -155,12 +158,12 @@ public void setFunctions(Set<String> functions) {
155
158
@ Override
156
159
@ Nullable
157
160
public Boolean getProxyToolCalls () {
158
- return getToolCallReturnDirect ();
161
+ return ! isToolExecutionEnabled ();
159
162
}
160
163
161
164
@ Override
162
165
public void setProxyToolCalls (@ Nullable Boolean proxyToolCalls ) {
163
- setToolCallReturnDirect (proxyToolCalls != null && proxyToolCalls );
166
+ setToolExecutionEnabled (proxyToolCalls == null || ! proxyToolCalls );
164
167
}
165
168
166
169
@ Override
@@ -250,7 +253,7 @@ public <T extends ChatOptions> T copy() {
250
253
options .setToolCallbacks (getToolCallbacks ());
251
254
options .setTools (getTools ());
252
255
options .setToolContext (getToolContext ());
253
- options .setToolCallReturnDirect ( getToolCallReturnDirect ());
256
+ options .setToolExecutionEnabled ( isToolExecutionEnabled ());
254
257
options .setModel (getModel ());
255
258
options .setFrequencyPenalty (getFrequencyPenalty ());
256
259
options .setMaxTokens (getMaxTokens ());
@@ -262,55 +265,6 @@ public <T extends ChatOptions> T copy() {
262
265
return (T ) options ;
263
266
}
264
267
265
- /**
266
- * Merge the given {@link ChatOptions} into this instance.
267
- */
268
- public ToolCallingChatOptions merge (ChatOptions options ) {
269
- ToolCallingChatOptions .Builder builder = ToolCallingChatOptions .builder ();
270
- builder .model (StringUtils .hasText (options .getModel ()) ? options .getModel () : this .getModel ());
271
- builder .frequencyPenalty (
272
- options .getFrequencyPenalty () != null ? options .getFrequencyPenalty () : this .getFrequencyPenalty ());
273
- builder .maxTokens (options .getMaxTokens () != null ? options .getMaxTokens () : this .getMaxTokens ());
274
- builder .presencePenalty (
275
- options .getPresencePenalty () != null ? options .getPresencePenalty () : this .getPresencePenalty ());
276
- builder .stopSequences (options .getStopSequences () != null ? new ArrayList <>(options .getStopSequences ())
277
- : this .getStopSequences ());
278
- builder .temperature (options .getTemperature () != null ? options .getTemperature () : this .getTemperature ());
279
- builder .topK (options .getTopK () != null ? options .getTopK () : this .getTopK ());
280
- builder .topP (options .getTopP () != null ? options .getTopP () : this .getTopP ());
281
-
282
- if (options instanceof ToolCallingChatOptions toolOptions ) {
283
- List <ToolCallback > toolCallbacks = new ArrayList <>(this .toolCallbacks );
284
- if (!CollectionUtils .isEmpty (toolOptions .getToolCallbacks ())) {
285
- toolCallbacks .addAll (toolOptions .getToolCallbacks ());
286
- }
287
- builder .toolCallbacks (toolCallbacks );
288
-
289
- Set <String > tools = new HashSet <>(this .tools );
290
- if (!CollectionUtils .isEmpty (toolOptions .getTools ())) {
291
- tools .addAll (toolOptions .getTools ());
292
- }
293
- builder .tools (tools );
294
-
295
- Map <String , Object > toolContext = new HashMap <>(this .toolContext );
296
- if (!CollectionUtils .isEmpty (toolOptions .getToolContext ())) {
297
- toolContext .putAll (toolOptions .getToolContext ());
298
- }
299
- builder .toolContext (toolContext );
300
-
301
- builder .toolCallReturnDirect (toolOptions .getToolCallReturnDirect () != null
302
- ? toolOptions .getToolCallReturnDirect () : this .getToolCallReturnDirect ());
303
- }
304
- else {
305
- builder .toolCallbacks (this .toolCallbacks );
306
- builder .tools (this .tools );
307
- builder .toolContext (this .toolContext );
308
- builder .toolCallReturnDirect (this .toolCallReturnDirect );
309
- }
310
-
311
- return builder .build ();
312
- }
313
-
314
268
public static Builder builder () {
315
269
return new Builder ();
316
270
}
@@ -363,16 +317,21 @@ public ToolCallingChatOptions.Builder toolContext(String key, Object value) {
363
317
}
364
318
365
319
@ Override
366
- public ToolCallingChatOptions .Builder toolCallReturnDirect (@ Nullable Boolean toolCallReturnDirect ) {
367
- this .options .setToolCallReturnDirect ( toolCallReturnDirect );
320
+ public ToolCallingChatOptions .Builder toolExecutionEnabled (@ Nullable Boolean toolExecutionEnabled ) {
321
+ this .options .setToolExecutionEnabled ( toolExecutionEnabled );
368
322
return this ;
369
323
}
370
324
371
325
@ Override
372
326
@ Deprecated // Use toolCallbacks() instead
373
327
public ToolCallingChatOptions .Builder functionCallbacks (List <FunctionCallback > functionCallbacks ) {
374
328
Assert .notNull (functionCallbacks , "functionCallbacks cannot be null" );
375
- return toolCallbacks (functionCallbacks .stream ().map (ToolCallback .class ::cast ).toList ());
329
+ if (functionCallbacks .stream ().allMatch (ToolCallback .class ::isInstance )) {
330
+ return toolCallbacks (functionCallbacks .stream ().map (ToolCallback .class ::cast ).toList ());
331
+ }
332
+ else {
333
+ throw new IllegalArgumentException ("functionCallbacks must be instances of ToolCallback" );
334
+ }
376
335
}
377
336
378
337
@ Override
@@ -395,9 +354,9 @@ public ToolCallingChatOptions.Builder function(String function) {
395
354
}
396
355
397
356
@ Override
398
- @ Deprecated // Use toolCallReturnDirect () instead
357
+ @ Deprecated // Use toolExecutionEnabled () instead
399
358
public ToolCallingChatOptions .Builder proxyToolCalls (@ Nullable Boolean proxyToolCalls ) {
400
- return toolCallReturnDirect (proxyToolCalls != null && proxyToolCalls );
359
+ return toolExecutionEnabled (proxyToolCalls == null || ! proxyToolCalls );
401
360
}
402
361
403
362
@ Override
0 commit comments