Skip to content

Advancing Tool Support - Part 3 #2121

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackResolver;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

Expand All @@ -44,7 +45,9 @@
* @author Thomas Vitale
* @author Jihoon Kim
* @since 1.0.0
* @deprecated Use {@link ToolCallingManager} instead.
*/
@Deprecated
public abstract class AbstractToolCallSupport {

protected static final boolean IS_RUNTIME_CALL = true;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -33,6 +33,7 @@
* @author Soby Chacko
* @author John Blum
* @author Alexandros Pappas
* @author Thomas Vitale
*/
public class ChatResponse implements ModelResponse<Generation> {

Expand Down Expand Up @@ -100,6 +101,16 @@ public ChatResponseMetadata getMetadata() {
return this.chatResponseMetadata;
}

/**
* Whether the model has requested the execution of a tool.
*/
public boolean hasToolCalls() {
if (CollectionUtils.isEmpty(generations)) {
return false;
}
return generations.stream().anyMatch(generation -> generation.getOutput().hasToolCalls());
}

@Override
public String toString() {
return "ChatResponse [metadata=" + this.chatResponseMetadata + ", generations=" + this.generations + "]";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@

import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.model.function.FunctionCallback.SchemaType;
import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver;
import org.springframework.ai.tool.resolution.TypeResolverHelper;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
Expand Down Expand Up @@ -55,7 +57,9 @@
* @author Christian Tzolov
* @author Christopher Smith
* @author Sebastien Deleuze
* @deprecated Use {@link SpringBeanToolCallbackResolver} instead.
*/
@Deprecated
public class DefaultFunctionCallbackResolver implements ApplicationContextAware, FunctionCallbackResolver {

private GenericApplicationContext applicationContext;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@

package org.springframework.ai.model.function;

import org.springframework.ai.tool.resolution.ToolCallbackResolver;
import org.springframework.lang.NonNull;

/**
* Strategy interface for resolving {@link FunctionCallback} instances.
*
* @author Christian Tzolov
* @since 1.0.0
* @deprecated Use {@link ToolCallbackResolver} instead.
*/
@Deprecated
public interface FunctionCallbackResolver {

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.Set;
import java.util.function.Function;

import org.springframework.ai.model.tool.ToolCallingManager;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.messages.AssistantMessage;
Expand All @@ -40,7 +41,10 @@
* Helper class that reuses the {@link AbstractToolCallSupport} to implement the function
* call handling logic on the client side. Used when the withProxyToolCalls(true) option
* is enabled.
*
* @deprecated Use {@link ToolCallingManager} instead.
*/
@Deprecated
public class FunctionCallingHelper extends AbstractToolCallSupport {

public FunctionCallingHelper() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,11 @@

import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
Expand All @@ -39,14 +37,14 @@
*/
public class DefaultToolCallingChatOptions implements ToolCallingChatOptions {

private List<ToolCallback> toolCallbacks = new ArrayList<>();
private List<FunctionCallback> toolCallbacks = new ArrayList<>();

private Set<String> tools = new HashSet<>();

private Map<String, Object> toolContext = new HashMap<>();

@Nullable
private Boolean toolCallReturnDirect;
private Boolean toolExecutionEnabled;

@Nullable
private String model;
Expand All @@ -73,23 +71,17 @@ public class DefaultToolCallingChatOptions implements ToolCallingChatOptions {
private Double topP;

@Override
public List<ToolCallback> getToolCallbacks() {
public List<FunctionCallback> getToolCallbacks() {
return List.copyOf(this.toolCallbacks);
}

@Override
public void setToolCallbacks(List<ToolCallback> toolCallbacks) {
public void setToolCallbacks(List<FunctionCallback> toolCallbacks) {
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
this.toolCallbacks = new ArrayList<>(toolCallbacks);
}

@Override
public void setToolCallbacks(ToolCallback... toolCallbacks) {
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
setToolCallbacks(List.of(toolCallbacks));
}

@Override
public Set<String> getTools() {
return Set.copyOf(this.tools);
Expand All @@ -103,12 +95,6 @@ public void setTools(Set<String> tools) {
this.tools = new HashSet<>(tools);
}

@Override
public void setTools(String... tools) {
Assert.notNull(tools, "tools cannot be null");
setTools(Set.of(tools));
}

@Override
public Map<String, Object> getToolContext() {
return Map.copyOf(this.toolContext);
Expand All @@ -123,23 +109,23 @@ public void setToolContext(Map<String, Object> toolContext) {

@Override
@Nullable
public Boolean getToolCallReturnDirect() {
return this.toolCallReturnDirect;
public Boolean isToolExecutionEnabled() {
return this.toolExecutionEnabled;
}

@Override
public void setToolCallReturnDirect(@Nullable Boolean toolCallReturnDirect) {
this.toolCallReturnDirect = toolCallReturnDirect;
public void setToolExecutionEnabled(@Nullable Boolean toolExecutionEnabled) {
this.toolExecutionEnabled = toolExecutionEnabled;
}

@Override
public List<FunctionCallback> getFunctionCallbacks() {
return getToolCallbacks().stream().map(FunctionCallback.class::cast).toList();
return getToolCallbacks();
}

@Override
public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
throw new UnsupportedOperationException("Not supported. Call setToolCallbacks instead.");
setToolCallbacks(functionCallbacks);
}

@Override
Expand All @@ -155,12 +141,12 @@ public void setFunctions(Set<String> functions) {
@Override
@Nullable
public Boolean getProxyToolCalls() {
return getToolCallReturnDirect();
return isToolExecutionEnabled() != null ? !isToolExecutionEnabled() : null;
}

@Override
public void setProxyToolCalls(@Nullable Boolean proxyToolCalls) {
setToolCallReturnDirect(proxyToolCalls != null && proxyToolCalls);
setToolExecutionEnabled(proxyToolCalls == null || !proxyToolCalls);
}

@Override
Expand Down Expand Up @@ -250,7 +236,7 @@ public <T extends ChatOptions> T copy() {
options.setToolCallbacks(getToolCallbacks());
options.setTools(getTools());
options.setToolContext(getToolContext());
options.setToolCallReturnDirect(getToolCallReturnDirect());
options.setToolExecutionEnabled(isToolExecutionEnabled());
options.setModel(getModel());
options.setFrequencyPenalty(getFrequencyPenalty());
options.setMaxTokens(getMaxTokens());
Expand All @@ -262,55 +248,6 @@ public <T extends ChatOptions> T copy() {
return (T) options;
}

/**
* Merge the given {@link ChatOptions} into this instance.
*/
public ToolCallingChatOptions merge(ChatOptions options) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method would not by called by any class since we always use the logic in ModelOptionsUtil to merge options, so I removed this.

ToolCallingChatOptions.Builder builder = ToolCallingChatOptions.builder();
builder.model(StringUtils.hasText(options.getModel()) ? options.getModel() : this.getModel());
builder.frequencyPenalty(
options.getFrequencyPenalty() != null ? options.getFrequencyPenalty() : this.getFrequencyPenalty());
builder.maxTokens(options.getMaxTokens() != null ? options.getMaxTokens() : this.getMaxTokens());
builder.presencePenalty(
options.getPresencePenalty() != null ? options.getPresencePenalty() : this.getPresencePenalty());
builder.stopSequences(options.getStopSequences() != null ? new ArrayList<>(options.getStopSequences())
: this.getStopSequences());
builder.temperature(options.getTemperature() != null ? options.getTemperature() : this.getTemperature());
builder.topK(options.getTopK() != null ? options.getTopK() : this.getTopK());
builder.topP(options.getTopP() != null ? options.getTopP() : this.getTopP());

if (options instanceof ToolCallingChatOptions toolOptions) {
List<ToolCallback> toolCallbacks = new ArrayList<>(this.toolCallbacks);
if (!CollectionUtils.isEmpty(toolOptions.getToolCallbacks())) {
toolCallbacks.addAll(toolOptions.getToolCallbacks());
}
builder.toolCallbacks(toolCallbacks);

Set<String> tools = new HashSet<>(this.tools);
if (!CollectionUtils.isEmpty(toolOptions.getTools())) {
tools.addAll(toolOptions.getTools());
}
builder.tools(tools);

Map<String, Object> toolContext = new HashMap<>(this.toolContext);
if (!CollectionUtils.isEmpty(toolOptions.getToolContext())) {
toolContext.putAll(toolOptions.getToolContext());
}
builder.toolContext(toolContext);

builder.toolCallReturnDirect(toolOptions.getToolCallReturnDirect() != null
? toolOptions.getToolCallReturnDirect() : this.getToolCallReturnDirect());
}
else {
builder.toolCallbacks(this.toolCallbacks);
builder.tools(this.tools);
builder.toolContext(this.toolContext);
builder.toolCallReturnDirect(this.toolCallReturnDirect);
}

return builder.build();
}

public static Builder builder() {
return new Builder();
}
Expand All @@ -323,14 +260,15 @@ public static class Builder implements ToolCallingChatOptions.Builder {
private final DefaultToolCallingChatOptions options = new DefaultToolCallingChatOptions();

@Override
public ToolCallingChatOptions.Builder toolCallbacks(List<ToolCallback> toolCallbacks) {
public ToolCallingChatOptions.Builder toolCallbacks(List<FunctionCallback> toolCallbacks) {
this.options.setToolCallbacks(toolCallbacks);
return this;
}

@Override
public ToolCallingChatOptions.Builder toolCallbacks(ToolCallback... toolCallbacks) {
this.options.setToolCallbacks(toolCallbacks);
public ToolCallingChatOptions.Builder toolCallbacks(FunctionCallback... toolCallbacks) {
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
this.options.setToolCallbacks(Arrays.asList(toolCallbacks));
return this;
}

Expand All @@ -342,7 +280,8 @@ public ToolCallingChatOptions.Builder tools(Set<String> toolNames) {

@Override
public ToolCallingChatOptions.Builder tools(String... toolNames) {
this.options.setTools(toolNames);
Assert.notNull(toolNames, "toolNames cannot be null");
this.options.setTools(Set.of(toolNames));
return this;
}

Expand All @@ -363,16 +302,15 @@ public ToolCallingChatOptions.Builder toolContext(String key, Object value) {
}

@Override
public ToolCallingChatOptions.Builder toolCallReturnDirect(@Nullable Boolean toolCallReturnDirect) {
this.options.setToolCallReturnDirect(toolCallReturnDirect);
public ToolCallingChatOptions.Builder toolExecutionEnabled(@Nullable Boolean toolExecutionEnabled) {
this.options.setToolExecutionEnabled(toolExecutionEnabled);
return this;
}

@Override
@Deprecated // Use toolCallbacks() instead
public ToolCallingChatOptions.Builder functionCallbacks(List<FunctionCallback> functionCallbacks) {
Assert.notNull(functionCallbacks, "functionCallbacks cannot be null");
return toolCallbacks(functionCallbacks.stream().map(ToolCallback.class::cast).toList());
return toolCallbacks(functionCallbacks);
}

@Override
Expand All @@ -395,9 +333,9 @@ public ToolCallingChatOptions.Builder function(String function) {
}

@Override
@Deprecated // Use toolCallReturnDirect() instead
@Deprecated // Use toolExecutionEnabled() instead
public ToolCallingChatOptions.Builder proxyToolCalls(@Nullable Boolean proxyToolCalls) {
return toolCallReturnDirect(proxyToolCalls != null && proxyToolCalls);
return toolExecutionEnabled(proxyToolCalls == null || !proxyToolCalls);
}

@Override
Expand Down
Loading
Loading