Skip to content

Advancing Tool Support - Part 5 #2162

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.springframework.ai.model.tool.LegacyToolCallingManager;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.util.json.JsonParser;
import reactor.core.publisher.Flux;
Expand Down Expand Up @@ -271,10 +272,19 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon

if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response != null
&& response.hasToolCalls()) {
var toolCallConversation = this.toolCallingManager.executeToolCalls(prompt, response);
// Recursively call the call method with the tool call message
// conversation that contains the call responses.
return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response);
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
if (toolExecutionResult.returnDirect()) {
// Return tool execution result directly to the client.
return ChatResponse.builder()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When a given tool is marked with the option returnDirect, the tool call result is sent back to the user directly, wrapping each tool call result in a Generation object.

That is especially useful when building agents. For example, if I have a RAG tool, I might want to get the result back directly instead of having the main agent model post-process the RAG answer. Or I might want certain tools to end the reasoning loop of an agent.

.from(response)
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
.build();
}
else {
// Send the tool execution result back to the model.
return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
response);
}
}

return response;
Expand Down Expand Up @@ -335,10 +345,17 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCh
// @formatter:off
Flux<ChatResponse> chatResponseFlux = chatResponse.flatMap(response -> {
if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response.hasToolCalls()) {
var toolCallConversation = this.toolCallingManager.executeToolCalls(prompt, response);
// Recursively call the stream method with the tool call message
// conversation that contains the call responses.
return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), response);
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
if (toolExecutionResult.returnDirect()) {
// Return tool execution result directly to the client.
return Flux.just(ChatResponse.builder().from(response)
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
.build());
} else {
// Send the tool execution result back to the model.
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
response);
}
}
else {
return Flux.just(response);
Expand Down Expand Up @@ -379,13 +396,13 @@ else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOp
// Merge tool names and tool callbacks explicitly since they are ignored by
// Jackson, used by ModelOptionsUtils.
if (runtimeOptions != null) {
requestOptions.setTools(
ToolCallingChatOptions.mergeToolNames(runtimeOptions.getTools(), this.defaultOptions.getTools()));
requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(),
this.defaultOptions.getToolNames()));
requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(),
this.defaultOptions.getToolCallbacks()));
}
else {
requestOptions.setTools(this.defaultOptions.getTools());
requestOptions.setToolNames(this.defaultOptions.getToolNames());
requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ public static OllamaOptions fromOptions(OllamaOptions fromOptions) {
.mirostatEta(fromOptions.getMirostatEta())
.penalizeNewline(fromOptions.getPenalizeNewline())
.stop(fromOptions.getStop())
.tools(fromOptions.getTools())
.toolNames(fromOptions.getToolNames())
.internalToolExecutionEnabled(fromOptions.isInternalToolExecutionEnabled())
.toolCallbacks(fromOptions.getToolCallbacks())
.toolContext(fromOptions.getToolContext()).build();
Expand Down Expand Up @@ -700,13 +700,13 @@ public void setToolCallbacks(List<FunctionCallback> toolCallbacks) {

@Override
@JsonIgnore
public Set<String> getTools() {
public Set<String> getToolNames() {
return this.toolNames;
}

@Override
@JsonIgnore
public void setTools(Set<String> toolNames) {
public void setToolNames(Set<String> toolNames) {
Assert.notNull(toolNames, "toolNames cannot be null");
Assert.noNullElements(toolNames, "toolNames cannot contain null elements");
toolNames.forEach(tool -> Assert.hasText(tool, "toolNames cannot contain empty elements"));
Expand Down Expand Up @@ -744,14 +744,14 @@ public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
@Deprecated
@JsonIgnore
public Set<String> getFunctions() {
return this.getTools();
return this.getToolNames();
}

@Override
@Deprecated
@JsonIgnore
public void setFunctions(Set<String> functions) {
this.setTools(functions);
this.setToolNames(functions);
}

@Override
Expand Down Expand Up @@ -1028,12 +1028,12 @@ public Builder toolCallbacks(FunctionCallback... toolCallbacks) {
return this;
}

public Builder tools(Set<String> toolNames) {
this.options.setTools(toolNames);
public Builder toolNames(Set<String> toolNames) {
this.options.setToolNames(toolNames);
return this;
}

public Builder tools(String... toolNames) {
public Builder toolNames(String... toolNames) {
Assert.notNull(toolNames, "toolNames cannot be null");
this.options.toolNames.addAll(Set.of(toolNames));
return this;
Expand All @@ -1051,12 +1051,12 @@ public Builder functionCallbacks(List<FunctionCallback> functionCallbacks) {

@Deprecated
public Builder functions(Set<String> functions) {
return tools(functions);
return toolNames(functions);
}

@Deprecated
public Builder function(String functionName) {
return tools(functionName);
return toolNames(functionName);
}

@Deprecated
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -34,7 +34,7 @@
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.tool.function.FunctionToolCallback;
import org.springframework.ai.util.json.JsonSchemaGenerator;
import org.springframework.ai.util.json.schema.JsonSchemaGenerator;
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel;
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatOptions;
import org.springframework.beans.factory.annotation.Autowired;
Expand Down
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@
<com.google.cloud.version>26.48.0</com.google.cloud.version>
<qdrant.version>1.9.1</qdrant.version>
<ibm.sdk.version>9.20.0</ibm.sdk.version>
<jsonschema.version>4.35.0</jsonschema.version>
<jsonschema.version>4.37.0</jsonschema.version>
<swagger-annotations.version>2.2.25</swagger-annotations.version>
<spring-cloud-bindings.version>2.0.3</spring-cloud-bindings.version>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,9 @@ interface ChatClientRequestSpec {

ChatClientRequestSpec tools(String... toolNames);

ChatClientRequestSpec tools(Object... toolObjects);
ChatClientRequestSpec tools(FunctionCallback... toolCallbacks);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Java compiler complaints when trying to pass an array of FunctionCallback and requires casting. In this way, we keep the same API, but it's more explicit and no casting is required.


// ChatClientRequestSpec toolCallbacks(FunctionCallback... toolCallbacks);
ChatClientRequestSpec tools(Object... toolObjects);

@Deprecated
<I, O> ChatClientRequestSpec functions(FunctionCallback... functionCallbacks);
Expand Down Expand Up @@ -281,6 +281,8 @@ interface Builder {

Builder defaultTools(String... toolNames);

Builder defaultTools(FunctionCallback... toolCallbacks);

Builder defaultTools(Object... toolObjects);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -845,42 +845,28 @@ public ChatClientRequestSpec tools(String... toolNames) {
return this;
}

@Override
public ChatClientRequestSpec tools(FunctionCallback... toolCallbacks) {
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
this.functionCallbacks.addAll(List.of(toolCallbacks));
return this;
}

@Override
public ChatClientRequestSpec tools(Object... toolObjects) {
Assert.notNull(toolObjects, "toolObjects cannot be null");
Assert.noNullElements(toolObjects, "toolObjects cannot contain null elements");

List<FunctionCallback> functionCallbacks = new ArrayList<>();
List<Object> nonFunctinCallbacks = new ArrayList<>();
for (Object toolObject : toolObjects) {
if (toolObject instanceof FunctionCallback) {
functionCallbacks.add((FunctionCallback) toolObject);
}
else {
nonFunctinCallbacks.add(toolObject);
}
}
this.functionCallbacks.addAll(functionCallbacks);
this.functionCallbacks.addAll(Arrays
.asList(ToolCallbacks.from(nonFunctinCallbacks.toArray(new Object[nonFunctinCallbacks.size()]))));
this.functionCallbacks.addAll(Arrays.asList(ToolCallbacks.from(toolObjects)));
return this;
}

// @Override
// public ChatClientRequestSpec toolCallbacks(FunctionCallback... toolCallbacks) {
// Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
// Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null
// elements");
// this.functionCallbacks.addAll(Arrays.asList(toolCallbacks));
// return this;
// }

@Deprecated
@Deprecated // Use tools()
public ChatClientRequestSpec functions(String... functionBeanNames) {
return tools(functionBeanNames);
}

@Deprecated
@Deprecated // Use tools()
public ChatClientRequestSpec functions(FunctionCallback... functionCallbacks) {
Assert.notNull(functionCallbacks, "functionCallbacks cannot be null");
Assert.noNullElements(functionCallbacks, "functionCallbacks cannot contain null elements");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.tool.ToolCallbacks;
import org.springframework.core.io.Resource;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -151,7 +150,13 @@ public Builder defaultSystem(Consumer<PromptSystemSpec> systemSpecConsumer) {

@Override
public Builder defaultTools(String... toolNames) {
this.defaultRequest.functions(toolNames);
this.defaultRequest.tools(toolNames);
return this;
}

@Override
public Builder defaultTools(FunctionCallback... toolCallbacks) {
this.defaultRequest.tools(toolCallbacks);
return this;
}

Expand All @@ -161,24 +166,28 @@ public Builder defaultTools(Object... toolObjects) {
return this;
}

@Deprecated // Use defaultTools()
public <I, O> Builder defaultFunction(String name, String description, java.util.function.Function<I, O> function) {
this.defaultRequest
.functions(FunctionCallback.builder().function(name, function).description(description).build());
return this;
}

@Deprecated // Use defaultTools()
public <I, O> Builder defaultFunction(String name, String description,
java.util.function.BiFunction<I, ToolContext, O> biFunction) {
this.defaultRequest
.functions(FunctionCallback.builder().function(name, biFunction).description(description).build());
return this;
}

@Deprecated // Use defaultTools()
public Builder defaultFunctions(String... functionNames) {
this.defaultRequest.functions(functionNames);
return this;
}

@Deprecated // Use defaultTools()
public Builder defaultFunctions(FunctionCallback... functionCallbacks) {
this.defaultRequest.functions(functionCallbacks);
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class DefaultToolCallingChatOptions implements ToolCallingChatOptions {

private List<FunctionCallback> toolCallbacks = new ArrayList<>();

private Set<String> tools = new HashSet<>();
private Set<String> toolNames = new HashSet<>();

private Map<String, Object> toolContext = new HashMap<>();

Expand Down Expand Up @@ -83,16 +83,16 @@ public void setToolCallbacks(List<FunctionCallback> toolCallbacks) {
}

@Override
public Set<String> getTools() {
return Set.copyOf(this.tools);
public Set<String> getToolNames() {
return Set.copyOf(this.toolNames);
}

@Override
public void setTools(Set<String> tools) {
Assert.notNull(tools, "tools cannot be null");
Assert.noNullElements(tools, "tools cannot contain null elements");
tools.forEach(tool -> Assert.hasText(tool, "tools cannot contain empty elements"));
this.tools = new HashSet<>(tools);
public void setToolNames(Set<String> toolNames) {
Assert.notNull(toolNames, "toolNames cannot be null");
Assert.noNullElements(toolNames, "toolNames cannot contain null elements");
toolNames.forEach(toolName -> Assert.hasText(toolName, "toolNames cannot contain empty elements"));
this.toolNames = new HashSet<>(toolNames);
}

@Override
Expand Down Expand Up @@ -130,12 +130,12 @@ public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {

@Override
public Set<String> getFunctions() {
return getTools();
return getToolNames();
}

@Override
public void setFunctions(Set<String> functions) {
setTools(functions);
setToolNames(functions);
}

@Override
Expand Down Expand Up @@ -234,7 +234,7 @@ public void setTopP(@Nullable Double topP) {
public <T extends ChatOptions> T copy() {
DefaultToolCallingChatOptions options = new DefaultToolCallingChatOptions();
options.setToolCallbacks(getToolCallbacks());
options.setTools(getTools());
options.setToolNames(getToolNames());
options.setToolContext(getToolContext());
options.setInternalToolExecutionEnabled(isInternalToolExecutionEnabled());
options.setModel(getModel());
Expand Down Expand Up @@ -273,15 +273,15 @@ public ToolCallingChatOptions.Builder toolCallbacks(FunctionCallback... toolCall
}

@Override
public ToolCallingChatOptions.Builder tools(Set<String> toolNames) {
this.options.setTools(toolNames);
public ToolCallingChatOptions.Builder toolNames(Set<String> toolNames) {
this.options.setToolNames(toolNames);
return this;
}

@Override
public ToolCallingChatOptions.Builder tools(String... toolNames) {
public ToolCallingChatOptions.Builder toolNames(String... toolNames) {
Assert.notNull(toolNames, "toolNames cannot be null");
this.options.setTools(Set.of(toolNames));
this.options.setToolNames(Set.of(toolNames));
return this;
}

Expand Down Expand Up @@ -322,15 +322,15 @@ public ToolCallingChatOptions.Builder functionCallbacks(FunctionCallback... func
}

@Override
@Deprecated // Use tools() instead
@Deprecated // Use toolNames() instead
public ToolCallingChatOptions.Builder functions(Set<String> functions) {
return tools(functions);
return toolNames(functions);
}

@Override
@Deprecated // Use tools() instead
@Deprecated // Use toolNames() instead
public ToolCallingChatOptions.Builder function(String function) {
return tools(function);
return toolNames(function);
}

@Override
Expand Down
Loading