From f678969f5a10b01dd7c4420f2a9dff7dcf504249 Mon Sep 17 00:00:00 2001 From: Thomas Vitale Date: Mon, 27 Jan 2025 07:38:57 +0100 Subject: [PATCH] Advancing Tool Support - Part 3 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Introduced ToolCallingManager to manage the tool calling activities for resolving and executing tools. A default implementation is provided. It can be used to handle explicit tool execution on the client-side, superseding the previous FunctionCallingHelper class. It’s ready to be instrumented via Micrometer, and support exception handling when tool calls fail. * Introduced ToolCallExceptionConverter to handle exceptions in tool calling, and provided a default implementation propagating the error message to the chat morel. * Introduced ToolCallbackResolver to resolve ToolCallback instances. A default implementation is provided (DelegatingToolCallbackResolver), capable of delegating the resolution to a series of resolvers, including static resolution (StaticToolCallbackResolver) and dynamic resolution from the Spring context (SpringBeanToolCallbackResolver). * Improved configuration in ToolCallingChatOptions to enable/disable the tool execution within a ChatModel (superseding the previous proxyToolCalls option). * Added unit and integration tests to cover all the new use cases and existing functionality which was not covered by autotests (tool resolution from Spring context). * Deprecated FunctionCallbackResolver, AbstractToolCallSupport, and FunctionCallingHelper. Relates to gh-2049 Signed-off-by: Thomas Vitale --- .../chat/model/AbstractToolCallSupport.java | 3 + .../ai/chat/model/ChatResponse.java | 13 +- .../DefaultFunctionCallbackResolver.java | 4 + .../function/FunctionCallbackResolver.java | 3 + .../model/function/FunctionCallingHelper.java | 4 + .../tool/DefaultToolCallingChatOptions.java | 112 ++------ .../model/tool/DefaultToolCallingManager.java | 249 +++++++++++++++++ .../ai/model/tool/ToolCallingChatOptions.java | 67 +++-- .../ai/model/tool/ToolCallingManager.java | 51 ++++ .../definition/DefaultToolDefinition.java | 4 +- .../DefaultToolCallExceptionConverter.java | 65 +++++ .../execution/ToolCallExceptionConverter.java | 35 +++ .../ai/tool/metadata/DefaultToolMetadata.java | 2 +- .../DelegatingToolCallbackResolver.java | 54 ++++ .../SpringBeanToolCallbackResolver.java | 240 ++++++++++++++++ .../StaticToolCallbackResolver.java | 49 ++++ .../tool/resolution/ToolCallbackResolver.java | 36 +++ .../resolution}/TypeResolverHelper.java | 6 +- .../ai/tool/resolution/package-info.java | 22 ++ .../ai/util/json/SchemaType.java | 37 +++ .../ai/chat/model/ChatResponseTests.java | 51 ++++ .../DefaultToolCallingChatOptionsTests.java | 73 ++--- .../tool/DefaultToolCallingManagerTests.java | 259 ++++++++++++++++++ .../tool/ToolCallingChatOptionsTests.java | 70 +++++ ...efaultToolCallExceptionConverterTests.java | 61 +++++ .../DelegatingToolCallbackResolverTests.java | 64 +++++ .../SpringBeanToolCallbackResolverTests.java | 175 ++++++++++++ .../StandaloneWeatherFunction.java | 8 +- .../StaticToolCallbackResolverTests.java | 66 +++++ .../resolution}/TypeResolverHelperIT.java | 9 +- .../resolution}/TypeResolverHelperTests.java | 9 +- .../component/ComponentWeatherFunction.java | 8 +- .../TypeResolverHelperConfiguration.java | 6 +- .../StandaloneWeatherKotlinFunction.kt | 8 - .../FunctionCallbackExtensionsTests.kt | 6 +- ...ringBeanToolCallbackResolverKotlinTests.kt | 114 ++++++++ .../StandaloneWeatherKotlinFunction.kt | 24 ++ .../resolution}/TypeResolverHelperKotlinIT.kt | 6 +- .../TypeResolverHelperKotlinConfiguration.kt | 6 +- .../RewriteQueryTransformerIT.java | 2 +- .../tests/tool/FunctionToolCallbackTests.java | 40 +-- .../tests/tool/MethodToolCallbackTests.java | 37 +-- .../tests/tool/ToolCallingManagerTests.java | 158 +++++++++++ .../integration/tests/tool/domain/Author.java | 23 ++ .../integration/tests/tool/domain/Book.java | 23 ++ .../tests/tool/domain/BookService.java | 50 ++++ 46 files changed, 2123 insertions(+), 289 deletions(-) create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/model/tool/ToolCallingManager.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/tool/execution/DefaultToolCallExceptionConverter.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/tool/execution/ToolCallExceptionConverter.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/DelegatingToolCallbackResolver.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/SpringBeanToolCallbackResolver.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/StaticToolCallbackResolver.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/ToolCallbackResolver.java rename spring-ai-core/src/main/java/org/springframework/ai/{model/function => tool/resolution}/TypeResolverHelper.java (98%) create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/package-info.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/util/json/SchemaType.java create mode 100644 spring-ai-core/src/test/java/org/springframework/ai/chat/model/ChatResponseTests.java create mode 100644 spring-ai-core/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTests.java create mode 100644 spring-ai-core/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java create mode 100644 spring-ai-core/src/test/java/org/springframework/ai/tool/execution/DefaultToolCallExceptionConverterTests.java create mode 100644 spring-ai-core/src/test/java/org/springframework/ai/tool/resolution/DelegatingToolCallbackResolverTests.java create mode 100644 spring-ai-core/src/test/java/org/springframework/ai/tool/resolution/SpringBeanToolCallbackResolverTests.java rename spring-ai-core/src/test/java/org/springframework/ai/{model/function => tool/resolution}/StandaloneWeatherFunction.java (76%) create mode 100644 spring-ai-core/src/test/java/org/springframework/ai/tool/resolution/StaticToolCallbackResolverTests.java rename spring-ai-core/src/test/java/org/springframework/ai/{model/function => tool/resolution}/TypeResolverHelperIT.java (92%) rename spring-ai-core/src/test/java/org/springframework/ai/{model/function => tool/resolution}/TypeResolverHelperTests.java (91%) rename spring-ai-core/src/test/java/org/springframework/ai/{model/function => tool/resolution}/component/ComponentWeatherFunction.java (76%) rename spring-ai-core/src/test/java/org/springframework/ai/{model/function => tool/resolution}/config/TypeResolverHelperConfiguration.java (82%) delete mode 100644 spring-ai-core/src/test/kotlin/org/springframework/ai/model/function/StandaloneWeatherKotlinFunction.kt rename spring-ai-core/src/test/kotlin/org/springframework/ai/{model/function => tool/resolution}/FunctionCallbackExtensionsTests.kt (82%) create mode 100644 spring-ai-core/src/test/kotlin/org/springframework/ai/tool/resolution/SpringBeanToolCallbackResolverKotlinTests.kt create mode 100644 spring-ai-core/src/test/kotlin/org/springframework/ai/tool/resolution/StandaloneWeatherKotlinFunction.kt rename spring-ai-core/src/test/kotlin/org/springframework/ai/{model/function => tool/resolution}/TypeResolverHelperKotlinIT.kt (93%) rename spring-ai-core/src/test/kotlin/org/springframework/ai/{model/function => tool/resolution}/kotlinconfig/TypeResolverHelperKotlinConfiguration.kt (82%) create mode 100644 spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/ToolCallingManagerTests.java create mode 100644 spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/domain/Author.java create mode 100644 spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/domain/Book.java create mode 100644 spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/domain/BookService.java diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java index eddd1d28bda..01f694c8aaa 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java @@ -32,6 +32,7 @@ import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackResolver; import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -44,7 +45,9 @@ * @author Thomas Vitale * @author Jihoon Kim * @since 1.0.0 + * @deprecated Use {@link ToolCallingManager} instead. */ +@Deprecated public abstract class AbstractToolCallSupport { protected static final boolean IS_RUNTIME_CALL = true; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatResponse.java index 3b82319de6b..adf3075706c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatResponse.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatResponse.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -33,6 +33,7 @@ * @author Soby Chacko * @author John Blum * @author Alexandros Pappas + * @author Thomas Vitale */ public class ChatResponse implements ModelResponse { @@ -100,6 +101,16 @@ public ChatResponseMetadata getMetadata() { return this.chatResponseMetadata; } + /** + * Whether the model has requested the execution of a tool. + */ + public boolean hasToolCalls() { + if (CollectionUtils.isEmpty(generations)) { + return false; + } + return generations.stream().anyMatch(generation -> generation.getOutput().hasToolCalls()); + } + @Override public String toString() { return "ChatResponse [metadata=" + this.chatResponseMetadata + ", generations=" + this.generations + "]"; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallbackResolver.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallbackResolver.java index 9e9e0fd1a60..6278a450be3 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallbackResolver.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallbackResolver.java @@ -28,6 +28,8 @@ import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.model.function.FunctionCallback.SchemaType; +import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver; +import org.springframework.ai.tool.resolution.TypeResolverHelper; import org.springframework.beans.BeansException; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; @@ -55,7 +57,9 @@ * @author Christian Tzolov * @author Christopher Smith * @author Sebastien Deleuze + * @deprecated Use {@link SpringBeanToolCallbackResolver} instead. */ +@Deprecated public class DefaultFunctionCallbackResolver implements ApplicationContextAware, FunctionCallbackResolver { private GenericApplicationContext applicationContext; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackResolver.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackResolver.java index cb7a96cdc25..6da86bd6e01 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackResolver.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackResolver.java @@ -16,6 +16,7 @@ package org.springframework.ai.model.function; +import org.springframework.ai.tool.resolution.ToolCallbackResolver; import org.springframework.lang.NonNull; /** @@ -23,7 +24,9 @@ * * @author Christian Tzolov * @since 1.0.0 + * @deprecated Use {@link ToolCallbackResolver} instead. */ +@Deprecated public interface FunctionCallbackResolver { /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingHelper.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingHelper.java index 4f50e9b582e..ececf6fd7f4 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingHelper.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingHelper.java @@ -23,6 +23,7 @@ import java.util.Set; import java.util.function.Function; +import org.springframework.ai.model.tool.ToolCallingManager; import reactor.core.publisher.Flux; import org.springframework.ai.chat.messages.AssistantMessage; @@ -40,7 +41,10 @@ * Helper class that reuses the {@link AbstractToolCallSupport} to implement the function * call handling logic on the client side. Used when the withProxyToolCalls(true) option * is enabled. + * + * @deprecated Use {@link ToolCallingManager} instead. */ +@Deprecated public class FunctionCallingHelper extends AbstractToolCallSupport { public FunctionCallingHelper() { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java index 15a8f015917..7e4471867c8 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java @@ -18,13 +18,11 @@ import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.tool.ToolCallback; import org.springframework.lang.Nullable; import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; -import org.springframework.util.StringUtils; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -39,14 +37,14 @@ */ public class DefaultToolCallingChatOptions implements ToolCallingChatOptions { - private List toolCallbacks = new ArrayList<>(); + private List toolCallbacks = new ArrayList<>(); private Set tools = new HashSet<>(); private Map toolContext = new HashMap<>(); @Nullable - private Boolean toolCallReturnDirect; + private Boolean toolExecutionEnabled; @Nullable private String model; @@ -73,23 +71,17 @@ public class DefaultToolCallingChatOptions implements ToolCallingChatOptions { private Double topP; @Override - public List getToolCallbacks() { + public List getToolCallbacks() { return List.copyOf(this.toolCallbacks); } @Override - public void setToolCallbacks(List toolCallbacks) { + public void setToolCallbacks(List toolCallbacks) { Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); this.toolCallbacks = new ArrayList<>(toolCallbacks); } - @Override - public void setToolCallbacks(ToolCallback... toolCallbacks) { - Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); - setToolCallbacks(List.of(toolCallbacks)); - } - @Override public Set getTools() { return Set.copyOf(this.tools); @@ -103,12 +95,6 @@ public void setTools(Set tools) { this.tools = new HashSet<>(tools); } - @Override - public void setTools(String... tools) { - Assert.notNull(tools, "tools cannot be null"); - setTools(Set.of(tools)); - } - @Override public Map getToolContext() { return Map.copyOf(this.toolContext); @@ -123,23 +109,23 @@ public void setToolContext(Map toolContext) { @Override @Nullable - public Boolean getToolCallReturnDirect() { - return this.toolCallReturnDirect; + public Boolean isToolExecutionEnabled() { + return this.toolExecutionEnabled; } @Override - public void setToolCallReturnDirect(@Nullable Boolean toolCallReturnDirect) { - this.toolCallReturnDirect = toolCallReturnDirect; + public void setToolExecutionEnabled(@Nullable Boolean toolExecutionEnabled) { + this.toolExecutionEnabled = toolExecutionEnabled; } @Override public List getFunctionCallbacks() { - return getToolCallbacks().stream().map(FunctionCallback.class::cast).toList(); + return getToolCallbacks(); } @Override public void setFunctionCallbacks(List functionCallbacks) { - throw new UnsupportedOperationException("Not supported. Call setToolCallbacks instead."); + setToolCallbacks(functionCallbacks); } @Override @@ -155,12 +141,12 @@ public void setFunctions(Set functions) { @Override @Nullable public Boolean getProxyToolCalls() { - return getToolCallReturnDirect(); + return isToolExecutionEnabled() != null ? !isToolExecutionEnabled() : null; } @Override public void setProxyToolCalls(@Nullable Boolean proxyToolCalls) { - setToolCallReturnDirect(proxyToolCalls != null && proxyToolCalls); + setToolExecutionEnabled(proxyToolCalls == null || !proxyToolCalls); } @Override @@ -250,7 +236,7 @@ public T copy() { options.setToolCallbacks(getToolCallbacks()); options.setTools(getTools()); options.setToolContext(getToolContext()); - options.setToolCallReturnDirect(getToolCallReturnDirect()); + options.setToolExecutionEnabled(isToolExecutionEnabled()); options.setModel(getModel()); options.setFrequencyPenalty(getFrequencyPenalty()); options.setMaxTokens(getMaxTokens()); @@ -262,55 +248,6 @@ public T copy() { return (T) options; } - /** - * Merge the given {@link ChatOptions} into this instance. - */ - public ToolCallingChatOptions merge(ChatOptions options) { - ToolCallingChatOptions.Builder builder = ToolCallingChatOptions.builder(); - builder.model(StringUtils.hasText(options.getModel()) ? options.getModel() : this.getModel()); - builder.frequencyPenalty( - options.getFrequencyPenalty() != null ? options.getFrequencyPenalty() : this.getFrequencyPenalty()); - builder.maxTokens(options.getMaxTokens() != null ? options.getMaxTokens() : this.getMaxTokens()); - builder.presencePenalty( - options.getPresencePenalty() != null ? options.getPresencePenalty() : this.getPresencePenalty()); - builder.stopSequences(options.getStopSequences() != null ? new ArrayList<>(options.getStopSequences()) - : this.getStopSequences()); - builder.temperature(options.getTemperature() != null ? options.getTemperature() : this.getTemperature()); - builder.topK(options.getTopK() != null ? options.getTopK() : this.getTopK()); - builder.topP(options.getTopP() != null ? options.getTopP() : this.getTopP()); - - if (options instanceof ToolCallingChatOptions toolOptions) { - List toolCallbacks = new ArrayList<>(this.toolCallbacks); - if (!CollectionUtils.isEmpty(toolOptions.getToolCallbacks())) { - toolCallbacks.addAll(toolOptions.getToolCallbacks()); - } - builder.toolCallbacks(toolCallbacks); - - Set tools = new HashSet<>(this.tools); - if (!CollectionUtils.isEmpty(toolOptions.getTools())) { - tools.addAll(toolOptions.getTools()); - } - builder.tools(tools); - - Map toolContext = new HashMap<>(this.toolContext); - if (!CollectionUtils.isEmpty(toolOptions.getToolContext())) { - toolContext.putAll(toolOptions.getToolContext()); - } - builder.toolContext(toolContext); - - builder.toolCallReturnDirect(toolOptions.getToolCallReturnDirect() != null - ? toolOptions.getToolCallReturnDirect() : this.getToolCallReturnDirect()); - } - else { - builder.toolCallbacks(this.toolCallbacks); - builder.tools(this.tools); - builder.toolContext(this.toolContext); - builder.toolCallReturnDirect(this.toolCallReturnDirect); - } - - return builder.build(); - } - public static Builder builder() { return new Builder(); } @@ -323,14 +260,15 @@ public static class Builder implements ToolCallingChatOptions.Builder { private final DefaultToolCallingChatOptions options = new DefaultToolCallingChatOptions(); @Override - public ToolCallingChatOptions.Builder toolCallbacks(List toolCallbacks) { + public ToolCallingChatOptions.Builder toolCallbacks(List toolCallbacks) { this.options.setToolCallbacks(toolCallbacks); return this; } @Override - public ToolCallingChatOptions.Builder toolCallbacks(ToolCallback... toolCallbacks) { - this.options.setToolCallbacks(toolCallbacks); + public ToolCallingChatOptions.Builder toolCallbacks(FunctionCallback... toolCallbacks) { + Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); + this.options.setToolCallbacks(Arrays.asList(toolCallbacks)); return this; } @@ -342,7 +280,8 @@ public ToolCallingChatOptions.Builder tools(Set toolNames) { @Override public ToolCallingChatOptions.Builder tools(String... toolNames) { - this.options.setTools(toolNames); + Assert.notNull(toolNames, "toolNames cannot be null"); + this.options.setTools(Set.of(toolNames)); return this; } @@ -363,16 +302,15 @@ public ToolCallingChatOptions.Builder toolContext(String key, Object value) { } @Override - public ToolCallingChatOptions.Builder toolCallReturnDirect(@Nullable Boolean toolCallReturnDirect) { - this.options.setToolCallReturnDirect(toolCallReturnDirect); + public ToolCallingChatOptions.Builder toolExecutionEnabled(@Nullable Boolean toolExecutionEnabled) { + this.options.setToolExecutionEnabled(toolExecutionEnabled); return this; } @Override @Deprecated // Use toolCallbacks() instead public ToolCallingChatOptions.Builder functionCallbacks(List functionCallbacks) { - Assert.notNull(functionCallbacks, "functionCallbacks cannot be null"); - return toolCallbacks(functionCallbacks.stream().map(ToolCallback.class::cast).toList()); + return toolCallbacks(functionCallbacks); } @Override @@ -395,9 +333,9 @@ public ToolCallingChatOptions.Builder function(String function) { } @Override - @Deprecated // Use toolCallReturnDirect() instead + @Deprecated // Use toolExecutionEnabled() instead public ToolCallingChatOptions.Builder proxyToolCalls(@Nullable Boolean proxyToolCalls) { - return toolCallReturnDirect(proxyToolCalls != null && proxyToolCalls); + return toolExecutionEnabled(proxyToolCalls == null || !proxyToolCalls); } @Override diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java b/spring-ai-core/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java new file mode 100644 index 00000000000..f350b7c85ef --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java @@ -0,0 +1,249 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.tool; + +import io.micrometer.observation.ObservationRegistry; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.model.ToolContext; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.DefaultToolCallExceptionConverter; +import org.springframework.ai.tool.execution.ToolCallExceptionConverter; +import org.springframework.ai.tool.execution.ToolExecutionException; +import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver; +import org.springframework.ai.tool.resolution.ToolCallbackResolver; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/** + * Default implementation of {@link ToolCallingManager}. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public class DefaultToolCallingManager implements ToolCallingManager { + + // @formatter:off + + private static final ObservationRegistry DEFAULT_OBSERVATION_REGISTRY + = ObservationRegistry.NOOP; + + private static final ToolCallbackResolver DEFAULT_TOOL_CALLBACK_RESOLVER + = new DelegatingToolCallbackResolver(List.of()); + + private static final ToolCallExceptionConverter DEFAULT_TOOL_CALL_EXCEPTION_CONVERTER + = DefaultToolCallExceptionConverter.builder().build(); + + // @formatter:on + + private final ObservationRegistry observationRegistry; + + private final ToolCallbackResolver toolCallbackResolver; + + private final ToolCallExceptionConverter toolCallExceptionConverter; + + public DefaultToolCallingManager(ObservationRegistry observationRegistry, ToolCallbackResolver toolCallbackResolver, + ToolCallExceptionConverter toolCallExceptionConverter) { + Assert.notNull(observationRegistry, "observationRegistry cannot be null"); + Assert.notNull(toolCallbackResolver, "toolCallbackResolver cannot be null"); + Assert.notNull(toolCallExceptionConverter, "toolCallExceptionConverter cannot be null"); + + this.observationRegistry = observationRegistry; + this.toolCallbackResolver = toolCallbackResolver; + this.toolCallExceptionConverter = toolCallExceptionConverter; + } + + @Override + public List resolveToolDefinitions(ToolCallingChatOptions chatOptions) { + Assert.notNull(chatOptions, "chatOptions cannot be null"); + + List toolCallbacks = new ArrayList<>(chatOptions.getToolCallbacks()); + for (String toolName : chatOptions.getTools()) { + ToolCallback toolCallback = toolCallbackResolver.resolve(toolName); + if (toolCallback == null) { + throw new IllegalStateException("No ToolCallback found for tool name: " + toolName); + } + toolCallbacks.add(toolCallback); + } + + return toolCallbacks.stream().map(functionCallback -> { + if (functionCallback instanceof ToolCallback toolCallback) { + return toolCallback.getToolDefinition(); + } + else { + return ToolDefinition.builder() + .name(functionCallback.getName()) + .description(functionCallback.getDescription()) + .inputSchema(functionCallback.getInputTypeSchema()) + .build(); + } + }).toList(); + } + + @Override + public List executeToolCalls(Prompt prompt, ChatResponse chatResponse) { + Assert.notNull(prompt, "prompt cannot be null"); + Assert.notNull(chatResponse, "chatResponse cannot be null"); + + Optional toolCallGeneration = chatResponse.getResults() + .stream() + .filter(g -> !CollectionUtils.isEmpty(g.getOutput().getToolCalls())) + .findFirst(); + + if (toolCallGeneration.isEmpty()) { + throw new IllegalStateException("No tool call requested by the chat model"); + } + + AssistantMessage assistantMessage = toolCallGeneration.get().getOutput(); + + ToolContext toolContext = buildToolContext(prompt, assistantMessage); + + ToolResponseMessage toolMessageResponse = executeToolCall(prompt, assistantMessage, toolContext); + + return buildConversationHistoryAfterToolExecution(prompt.getInstructions(), assistantMessage, + toolMessageResponse); + } + + private static ToolContext buildToolContext(Prompt prompt, AssistantMessage assistantMessage) { + Map toolContextMap = Map.of(); + + if (prompt.getOptions() instanceof FunctionCallingOptions functionOptions + && !CollectionUtils.isEmpty(functionOptions.getToolContext())) { + toolContextMap = new HashMap<>(functionOptions.getToolContext()); + + List messageHistory = new ArrayList<>(prompt.copy().getInstructions()); + messageHistory.add(new AssistantMessage(assistantMessage.getText(), assistantMessage.getMetadata(), + assistantMessage.getToolCalls())); + + toolContextMap.put(ToolContext.TOOL_CALL_HISTORY, + buildConversationHistoryBeforeToolExecution(prompt, assistantMessage)); + } + + return new ToolContext(toolContextMap); + } + + private static List buildConversationHistoryBeforeToolExecution(Prompt prompt, + AssistantMessage assistantMessage) { + List messageHistory = new ArrayList<>(prompt.copy().getInstructions()); + messageHistory.add(new AssistantMessage(assistantMessage.getText(), assistantMessage.getMetadata(), + assistantMessage.getToolCalls())); + return messageHistory; + } + + /** + * Execute the tool call and return the response message. To ensure backward + * compatibility, both {@link ToolCallback} and {@link FunctionCallback} are + * supported. + */ + private ToolResponseMessage executeToolCall(Prompt prompt, AssistantMessage assistantMessage, + ToolContext toolContext) { + List toolCallbacks = List.of(); + if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) { + toolCallbacks = toolCallingChatOptions.getToolCallbacks(); + } + else if (prompt.getOptions() instanceof FunctionCallingOptions functionOptions) { + toolCallbacks = functionOptions.getFunctionCallbacks(); + } + + List toolResponses = new ArrayList<>(); + + for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) { + + String toolName = toolCall.name(); + String toolInputArguments = toolCall.arguments(); + + FunctionCallback toolCallback = toolCallbacks.stream() + .filter(tool -> toolName.equals(tool.getName())) + .findFirst() + .orElse(toolCallbackResolver.resolve(toolName)); + + if (toolCallback == null) { + throw new IllegalStateException("No ToolCallback found for tool name: " + toolName); + } + + String toolResult; + try { + toolResult = toolCallback.call(toolInputArguments, toolContext); + } + catch (ToolExecutionException ex) { + toolResult = toolCallExceptionConverter.convert(ex); + } + + toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), toolName, toolResult)); + } + + return new ToolResponseMessage(toolResponses, Map.of()); + } + + private List buildConversationHistoryAfterToolExecution(List previousMessages, + AssistantMessage assistantMessage, ToolResponseMessage toolResponseMessage) { + List messages = new ArrayList<>(previousMessages); + messages.add(assistantMessage); + messages.add(toolResponseMessage); + return messages; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private ObservationRegistry observationRegistry = DEFAULT_OBSERVATION_REGISTRY; + + private ToolCallbackResolver toolCallbackResolver = DEFAULT_TOOL_CALLBACK_RESOLVER; + + private ToolCallExceptionConverter toolCallExceptionConverter = DEFAULT_TOOL_CALL_EXCEPTION_CONVERTER; + + private Builder() { + } + + public Builder observationRegistry(ObservationRegistry observationRegistry) { + this.observationRegistry = observationRegistry; + return this; + } + + public Builder toolCallbackResolver(ToolCallbackResolver toolCallbackResolver) { + this.toolCallbackResolver = toolCallbackResolver; + return this; + } + + public Builder toolCallExceptionConverter(ToolCallExceptionConverter toolCallExceptionConverter) { + this.toolCallExceptionConverter = toolCallExceptionConverter; + return this; + } + + public DefaultToolCallingManager build() { + return new DefaultToolCallingManager(observationRegistry, toolCallbackResolver, toolCallExceptionConverter); + } + + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java index 19ab72055c2..9bfad2a69f2 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java @@ -16,11 +16,12 @@ package org.springframework.ai.model.tool; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallingOptions; -import org.springframework.ai.tool.ToolCallback; -import org.springframework.ai.tool.metadata.ToolMetadata; import org.springframework.lang.Nullable; +import org.springframework.util.Assert; import java.util.List; import java.util.Map; @@ -35,20 +36,17 @@ */ public interface ToolCallingChatOptions extends FunctionCallingOptions { - /** - * ToolCallbacks to be registered with the ChatModel. - */ - List getToolCallbacks(); + boolean DEFAULT_TOOL_EXECUTION_ENABLED = true; /** - * Set the ToolCallbacks to be registered with the ChatModel. + * ToolCallbacks to be registered with the ChatModel. */ - void setToolCallbacks(List toolCallbacks); + List getToolCallbacks(); /** * Set the ToolCallbacks to be registered with the ChatModel. */ - void setToolCallbacks(ToolCallback... toolCallbacks); + void setToolCallbacks(List toolCallbacks); /** * Names of the tools to register with the ChatModel. @@ -58,27 +56,20 @@ public interface ToolCallingChatOptions extends FunctionCallingOptions { /** * Set the names of the tools to register with the ChatModel. */ - void setTools(Set tools); + void setTools(Set toolNames); /** - * Set the names of the tools to register with the ChatModel. - */ - void setTools(String... tools); - - /** - * Whether the result of each tool call should be returned directly or passed back to - * the model. It can be overridden for each {@link ToolCallback} instance via - * {@link ToolMetadata#returnDirect()}. + * Whether the {@link ChatModel} is responsible for executing the tools requested by + * the model or if the tools should be executed directly by the caller. */ @Nullable - Boolean getToolCallReturnDirect(); + Boolean isToolExecutionEnabled(); /** - * Set whether the result of each tool call should be returned directly or passed back - * to the model. It can be overridden for each {@link ToolCallback} instance via - * {@link ToolMetadata#returnDirect()}. + * Set whether the {@link ChatModel} is responsible for executing the tools requested + * by the model or if the tools should be executed directly by the caller. */ - void setToolCallReturnDirect(@Nullable Boolean toolCallReturnDirect); + void setToolExecutionEnabled(@Nullable Boolean toolExecutionEnabled); /** * A builder to create a new {@link ToolCallingChatOptions} instance. @@ -95,12 +86,12 @@ interface Builder extends FunctionCallingOptions.Builder { /** * ToolCallbacks to be registered with the ChatModel. */ - Builder toolCallbacks(List functionCallbacks); + Builder toolCallbacks(List functionCallbacks); /** * ToolCallbacks to be registered with the ChatModel. */ - Builder toolCallbacks(ToolCallback... functionCallbacks); + Builder toolCallbacks(FunctionCallback... functionCallbacks); /** * Names of the tools to register with the ChatModel. @@ -113,11 +104,10 @@ interface Builder extends FunctionCallingOptions.Builder { Builder tools(String... toolNames); /** - * Whether the result of each tool call should be returned directly or passed back - * to the model. It can be overridden for each {@link ToolCallback} instance via - * {@link ToolMetadata#returnDirect()}. + * Whether the {@link ChatModel} is responsible for executing the tools requested + * by the model or if the tools should be executed directly by the caller. */ - Builder toolCallReturnDirect(@Nullable Boolean toolCallReturnDirect); + Builder toolExecutionEnabled(@Nullable Boolean toolExecutionEnabled); // FunctionCallingOptions.Builder methods @@ -144,7 +134,7 @@ interface Builder extends FunctionCallingOptions.Builder { Builder function(String function); @Override - @Deprecated // Use toolCallReturnDirect() instead + @Deprecated // Use toolExecutionEnabled() instead Builder proxyToolCalls(@Nullable Boolean proxyToolCalls); // ChatOptions.Builder methods @@ -178,4 +168,21 @@ interface Builder extends FunctionCallingOptions.Builder { } + static boolean isToolExecutionEnabled(ChatOptions chatOptions) { + Assert.notNull(chatOptions, "chatOptions cannot be null"); + boolean toolExecutionEnabled; + if (chatOptions instanceof ToolCallingChatOptions toolCallingChatOptions + && toolCallingChatOptions.isToolExecutionEnabled() != null) { + toolExecutionEnabled = Boolean.TRUE.equals(toolCallingChatOptions.isToolExecutionEnabled()); + } + else if (chatOptions instanceof FunctionCallingOptions functionCallingOptions + && functionCallingOptions.getProxyToolCalls() != null) { + toolExecutionEnabled = Boolean.TRUE.equals(!functionCallingOptions.getProxyToolCalls()); + } + else { + toolExecutionEnabled = DEFAULT_TOOL_EXECUTION_ENABLED; + } + return toolExecutionEnabled; + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/tool/ToolCallingManager.java b/spring-ai-core/src/main/java/org/springframework/ai/model/tool/ToolCallingManager.java new file mode 100644 index 00000000000..63e7ba5cbe6 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/tool/ToolCallingManager.java @@ -0,0 +1,51 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.tool; + +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.tool.definition.ToolDefinition; + +import java.util.List; + +/** + * Service responsible for managing the tool calling process for a chat model. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public interface ToolCallingManager { + + /** + * Resolve the tool definitions from the model's tool calling options. + */ + List resolveToolDefinitions(ToolCallingChatOptions chatOptions); + + /** + * Execute the tool calls requested by the model. + */ + List executeToolCalls(Prompt prompt, ChatResponse chatResponse); + + /** + * Create a default {@link ToolCallingManager} builder. + */ + static DefaultToolCallingManager.Builder builder() { + return DefaultToolCallingManager.builder(); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/definition/DefaultToolDefinition.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/definition/DefaultToolDefinition.java index 3c2e9056249..7623a4bc05f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/tool/definition/DefaultToolDefinition.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/definition/DefaultToolDefinition.java @@ -64,9 +64,9 @@ public Builder inputSchema(String inputSchema) { return this; } - public DefaultToolDefinition build() { + public ToolDefinition build() { if (!StringUtils.hasText(description)) { - description = ToolUtils.getToolDescriptionFromName(description); + description = ToolUtils.getToolDescriptionFromName(name); } return new DefaultToolDefinition(name, description, inputSchema); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/DefaultToolCallExceptionConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/DefaultToolCallExceptionConverter.java new file mode 100644 index 00000000000..30ec1947e87 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/DefaultToolCallExceptionConverter.java @@ -0,0 +1,65 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.execution; + +import org.springframework.util.Assert; + +/** + * Default implementation of {@link ToolCallExceptionConverter}. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public class DefaultToolCallExceptionConverter implements ToolCallExceptionConverter { + + private static final boolean DEFAULT_ALWAYS_THROW = false; + + private final boolean alwaysThrow; + + public DefaultToolCallExceptionConverter(boolean alwaysThrow) { + this.alwaysThrow = alwaysThrow; + } + + @Override + public String convert(ToolExecutionException exception) { + Assert.notNull(exception, "exception cannot be null"); + if (alwaysThrow) { + throw exception; + } + return exception.getMessage(); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private boolean alwaysThrow = DEFAULT_ALWAYS_THROW; + + public Builder alwaysThrow(boolean alwaysThrow) { + this.alwaysThrow = alwaysThrow; + return this; + } + + public DefaultToolCallExceptionConverter build() { + return new DefaultToolCallExceptionConverter(alwaysThrow); + } + + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/ToolCallExceptionConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/ToolCallExceptionConverter.java new file mode 100644 index 00000000000..76975b3d62e --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/ToolCallExceptionConverter.java @@ -0,0 +1,35 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.execution; + +/** + * A functional interface to convert a tool call exception to a String that can be sent + * back to the AI model. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +@FunctionalInterface +public interface ToolCallExceptionConverter { + + /** + * Convert an exception thrown by a tool to a String that can be sent back to the AI + * model. + */ + String convert(ToolExecutionException exception); + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/metadata/DefaultToolMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/metadata/DefaultToolMetadata.java index 2956a28b307..64bc6414069 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/tool/metadata/DefaultToolMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/metadata/DefaultToolMetadata.java @@ -40,7 +40,7 @@ public Builder returnDirect(boolean returnDirect) { return this; } - public DefaultToolMetadata build() { + public ToolMetadata build() { return new DefaultToolMetadata(returnDirect); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/DelegatingToolCallbackResolver.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/DelegatingToolCallbackResolver.java new file mode 100644 index 00000000000..afbf3e48231 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/DelegatingToolCallbackResolver.java @@ -0,0 +1,54 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.resolution; + +import org.springframework.ai.tool.ToolCallback; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +import java.util.List; + +/** + * A {@link ToolCallbackResolver} that delegates to a list of {@link ToolCallbackResolver} + * instances. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public class DelegatingToolCallbackResolver implements ToolCallbackResolver { + + private final List toolCallbackResolvers; + + public DelegatingToolCallbackResolver(List toolCallbackResolvers) { + Assert.notNull(toolCallbackResolvers, "toolCallbackResolvers cannot be null"); + Assert.noNullElements(toolCallbackResolvers, "toolCallbackResolvers cannot contain null elements"); + this.toolCallbackResolvers = toolCallbackResolvers; + } + + @Override + @Nullable + public ToolCallback resolve(String toolName) { + for (ToolCallbackResolver toolCallbackResolver : toolCallbackResolvers) { + ToolCallback toolCallback = toolCallbackResolver.resolve(toolName); + if (toolCallback != null) { + return toolCallback; + } + } + return null; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/SpringBeanToolCallbackResolver.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/SpringBeanToolCallbackResolver.java new file mode 100644 index 00000000000..e5c51379023 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/SpringBeanToolCallbackResolver.java @@ -0,0 +1,240 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.resolution; + +import com.fasterxml.jackson.annotation.JsonClassDescription; +import kotlin.jvm.functions.Function0; +import kotlin.jvm.functions.Function1; +import kotlin.jvm.functions.Function2; +import org.springframework.ai.chat.model.ToolContext; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.function.FunctionToolCallback; +import org.springframework.ai.tool.util.ToolUtils; +import org.springframework.ai.util.json.JsonSchemaGenerator; +import org.springframework.ai.util.json.SchemaType; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Description; +import org.springframework.context.support.GenericApplicationContext; +import org.springframework.core.KotlinDetector; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.ResolvableType; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; + +/** + * A Spring {@link ApplicationContext}-based implementation that provides a way to + * retrieve a bean from the Spring context and wrap it into a {@link ToolCallback}. + * + * @author Christian Tzolov + * @author Christopher Smith + * @author Sebastien Deleuze + * @author Thomas Vitale + * @since 1.0.0 + */ +public class SpringBeanToolCallbackResolver implements ToolCallbackResolver { + + private static final Map toolCallbacksCache = new HashMap<>(); + + private static final SchemaType DEFAULT_SCHEMA_TYPE = SchemaType.JSON_SCHEMA; + + private final GenericApplicationContext applicationContext; + + private final SchemaType schemaType; + + public SpringBeanToolCallbackResolver(GenericApplicationContext applicationContext, + @Nullable SchemaType schemaType) { + Assert.notNull(applicationContext, "applicationContext cannot be null"); + + this.applicationContext = applicationContext; + this.schemaType = schemaType != null ? schemaType : DEFAULT_SCHEMA_TYPE; + } + + @Override + public ToolCallback resolve(String toolName) { + Assert.hasText(toolName, "toolName cannot be null or empty"); + + ToolCallback resolvedToolCallback = toolCallbacksCache.get(toolName); + + if (resolvedToolCallback != null) { + return resolvedToolCallback; + } + + ResolvableType toolType = TypeResolverHelper.resolveBeanType(applicationContext, toolName); + ResolvableType toolInputType = (ResolvableType.forType(Supplier.class).isAssignableFrom(toolType)) + ? ResolvableType.forType(Void.class) : TypeResolverHelper.getFunctionArgumentType(toolType, 0); + + String toolDescription = resolveToolDescription(toolName, toolInputType.toClass()); + Object bean = applicationContext.getBean(toolName); + + resolvedToolCallback = buildToolCallback(toolName, toolType, toolInputType, toolDescription, bean); + + toolCallbacksCache.put(toolName, resolvedToolCallback); + + return resolvedToolCallback; + } + + public SchemaType getSchemaType() { + return schemaType; + } + + private String resolveToolDescription(String toolName, Class toolInputType) { + Description descriptionAnnotation = applicationContext.findAnnotationOnBean(toolName, Description.class); + if (descriptionAnnotation != null && StringUtils.hasText(descriptionAnnotation.value())) { + return descriptionAnnotation.value(); + } + + JsonClassDescription jsonClassDescriptionAnnotation = toolInputType.getAnnotation(JsonClassDescription.class); + if (jsonClassDescriptionAnnotation != null && StringUtils.hasText(jsonClassDescriptionAnnotation.value())) { + return jsonClassDescriptionAnnotation.value(); + } + + return ToolUtils.getToolDescriptionFromName(toolName); + } + + private ToolCallback buildToolCallback(String toolName, ResolvableType toolType, ResolvableType toolInputType, + String toolDescription, Object bean) { + if (KotlinDetector.isKotlinPresent()) { + if (KotlinDelegate.isKotlinFunction(toolType.toClass())) { + return FunctionToolCallback.builder(toolName, KotlinDelegate.wrapKotlinFunction(bean)) + .description(toolDescription) + .inputSchema(generateSchema(toolInputType)) + .inputType(ParameterizedTypeReference.forType(toolInputType.getType())) + .build(); + } + if (KotlinDelegate.isKotlinBiFunction(toolType.toClass())) { + return FunctionToolCallback.builder(toolName, KotlinDelegate.wrapKotlinBiFunction(bean)) + .description(toolDescription) + .inputSchema(generateSchema(toolInputType)) + .inputType(ParameterizedTypeReference.forType(toolInputType.getType())) + .build(); + } + if (KotlinDelegate.isKotlinSupplier(toolType.toClass())) { + return FunctionToolCallback.builder(toolName, KotlinDelegate.wrapKotlinSupplier(bean)) + .description(toolDescription) + .inputSchema(generateSchema(toolInputType)) + .inputType(ParameterizedTypeReference.forType(toolInputType.getType())) + .build(); + } + } + + if (bean instanceof Function function) { + return FunctionToolCallback.builder(toolName, function) + .description(toolDescription) + .inputSchema(generateSchema(toolInputType)) + .inputType(ParameterizedTypeReference.forType(toolInputType.getType())) + .build(); + } + if (bean instanceof BiFunction) { + return FunctionToolCallback.builder(toolName, (BiFunction) bean) + .description(toolDescription) + .inputSchema(generateSchema(toolInputType)) + .inputType(ParameterizedTypeReference.forType(toolInputType.getType())) + .build(); + } + if (bean instanceof Supplier supplier) { + return FunctionToolCallback.builder(toolName, supplier) + .description(toolDescription) + .inputSchema(generateSchema(toolInputType)) + .inputType(ParameterizedTypeReference.forType(toolInputType.getType())) + .build(); + } + if (bean instanceof Consumer consumer) { + return FunctionToolCallback.builder(toolName, consumer) + .description(toolDescription) + .inputSchema(generateSchema(toolInputType)) + .inputType(ParameterizedTypeReference.forType(toolInputType.getType())) + .build(); + } + + throw new IllegalStateException( + "Unsupported bean type. Support types: Function, BiFunction, Supplier, Consumer."); + } + + private String generateSchema(ResolvableType toolInputType) { + if (schemaType == SchemaType.OPEN_API_SCHEMA) { + return JsonSchemaGenerator.generateForType(toolInputType.getType(), + JsonSchemaGenerator.SchemaOption.UPPER_CASE_TYPE_VALUES); + } + return JsonSchemaGenerator.generateForType(toolInputType.getType()); + } + + private static final class KotlinDelegate { + + public static boolean isKotlinSupplier(Class clazz) { + return Function0.class.isAssignableFrom(clazz); + } + + @SuppressWarnings("unchecked") + public static Supplier wrapKotlinSupplier(Object bean) { + return () -> ((Function0) bean).invoke(); + } + + public static boolean isKotlinFunction(Class clazz) { + return Function1.class.isAssignableFrom(clazz); + } + + @SuppressWarnings("unchecked") + public static Function wrapKotlinFunction(Object bean) { + return t -> ((Function1) bean).invoke(t); + } + + public static boolean isKotlinBiFunction(Class clazz) { + return Function2.class.isAssignableFrom(clazz); + } + + @SuppressWarnings("unchecked") + public static BiFunction wrapKotlinBiFunction(Object bean) { + return (t, u) -> ((Function2) bean).invoke(t, u); + } + + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private GenericApplicationContext applicationContext; + + private SchemaType schemaType; + + public Builder applicationContext(GenericApplicationContext applicationContext) { + this.applicationContext = applicationContext; + return this; + } + + public Builder schemaType(SchemaType schemaType) { + this.schemaType = schemaType; + return this; + } + + public SpringBeanToolCallbackResolver build() { + return new SpringBeanToolCallbackResolver(applicationContext, schemaType); + } + + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/StaticToolCallbackResolver.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/StaticToolCallbackResolver.java new file mode 100644 index 00000000000..24d0d14b32c --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/StaticToolCallbackResolver.java @@ -0,0 +1,49 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.resolution; + +import org.springframework.ai.tool.ToolCallback; +import org.springframework.util.Assert; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * A {@link ToolCallbackResolver} that resolves tool callbacks from a static registry. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public class StaticToolCallbackResolver implements ToolCallbackResolver { + + private final Map toolCallbacks = new HashMap<>(); + + public StaticToolCallbackResolver(List toolCallbacks) { + Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); + Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); + + toolCallbacks + .forEach(toolCallback -> this.toolCallbacks.put(toolCallback.getToolDefinition().name(), toolCallback)); + } + + @Override + public ToolCallback resolve(String toolName) { + return toolCallbacks.get(toolName); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/ToolCallbackResolver.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/ToolCallbackResolver.java new file mode 100644 index 00000000000..8efa01e9ccd --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/ToolCallbackResolver.java @@ -0,0 +1,36 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.resolution; + +import org.springframework.ai.tool.ToolCallback; +import org.springframework.lang.Nullable; + +/** + * A resolver for {@link ToolCallback} instances. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public interface ToolCallbackResolver { + + /** + * Resolve the {@link ToolCallback} for the given tool name. + */ + @Nullable + ToolCallback resolve(String toolName); + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/TypeResolverHelper.java similarity index 98% rename from spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java rename to spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/TypeResolverHelper.java index 33fe38e27bf..ebdc3c4d878 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/TypeResolverHelper.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.model.function; +package org.springframework.ai.tool.resolution; import java.lang.reflect.Method; import java.lang.reflect.Modifier; @@ -45,7 +45,7 @@ * @author Christian Tzolov * @author Sebastien Dekeuze */ -public abstract class TypeResolverHelper { +public final class TypeResolverHelper { /** * Returns the input class of a given Consumer class. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/package-info.java new file mode 100644 index 00000000000..2ec725256de --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@NonNullApi +@NonNullFields +package org.springframework.ai.tool.resolution; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/util/json/SchemaType.java b/spring-ai-core/src/main/java/org/springframework/ai/util/json/SchemaType.java new file mode 100644 index 00000000000..9a9c988ff56 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/util/json/SchemaType.java @@ -0,0 +1,37 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.util.json; + +/** + * The type of schema to generate for a given Java type. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public enum SchemaType { + + /** + * JSON schema. + */ + JSON_SCHEMA, + + /** + * Open API schema. + */ + OPEN_API_SCHEMA; + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/model/ChatResponseTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/model/ChatResponseTests.java new file mode 100644 index 00000000000..2cdd976a815 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/model/ChatResponseTests.java @@ -0,0 +1,51 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.model; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.AssistantMessage; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for {@link ChatResponse}. + * + * @author Thomas Vitale + */ +class ChatResponseTests { + + @Test + void whenToolCallsArePresentThenReturnTrue() { + ChatResponse chatResponse = ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("", Map.of(), + List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}")))))) + .build(); + assertThat(chatResponse.hasToolCalls()).isTrue(); + } + + @Test + void whenNoToolCallsArePresentThenReturnFalse() { + ChatResponse chatResponse = ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("Result")))) + .build(); + assertThat(chatResponse.hasToolCalls()).isFalse(); + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java b/spring-ai-core/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java index 154a31fa28f..e7a50dd71eb 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java @@ -17,7 +17,6 @@ package org.springframework.ai.model.tool; import org.junit.jupiter.api.Test; -import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.tool.ToolCallback; @@ -32,6 +31,8 @@ /** * Unit tests for {@link DefaultToolCallingChatOptions}. + * + * @author Thomas Vitale */ class DefaultToolCallingChatOptionsTests { @@ -40,7 +41,7 @@ void setToolCallbacksShouldStoreToolCallbacks() { DefaultToolCallingChatOptions options = new DefaultToolCallingChatOptions(); ToolCallback callback1 = mock(ToolCallback.class); ToolCallback callback2 = mock(ToolCallback.class); - List callbacks = List.of(callback1, callback2); + List callbacks = List.of(callback1, callback2); options.setToolCallbacks(callbacks); @@ -53,7 +54,7 @@ void setToolCallbacksWithVarargsShouldStoreToolCallbacks() { ToolCallback callback1 = mock(ToolCallback.class); ToolCallback callback2 = mock(ToolCallback.class); - options.setToolCallbacks(callback1, callback2); + options.setToolCallbacks(List.of(callback1, callback2)); assertThat(options.getToolCallbacks()).hasSize(2).containsExactly(callback1, callback2); } @@ -62,8 +63,7 @@ void setToolCallbacksWithVarargsShouldStoreToolCallbacks() { void setToolCallbacksShouldRejectNullList() { DefaultToolCallingChatOptions options = new DefaultToolCallingChatOptions(); - assertThatThrownBy(() -> options.setToolCallbacks((List) null)) - .isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> options.setToolCallbacks(null)).isInstanceOf(IllegalArgumentException.class) .hasMessage("toolCallbacks cannot be null"); } @@ -81,7 +81,7 @@ void setToolsShouldStoreTools() { void setToolsWithVarargsShouldStoreTools() { DefaultToolCallingChatOptions options = new DefaultToolCallingChatOptions(); - options.setTools("tool1", "tool2"); + options.setTools(Set.of("tool1", "tool2")); assertThat(options.getTools()).hasSize(2).containsExactlyInAnyOrder("tool1", "tool2"); } @@ -139,7 +139,7 @@ void copyShouldCreateNewInstanceWithSameValues() { original.setToolCallbacks(List.of(callback)); original.setTools(Set.of("tool1")); original.setToolContext(Map.of("key", "value")); - original.setToolCallReturnDirect(true); + original.setToolExecutionEnabled(true); original.setModel("gpt-4"); original.setTemperature(0.7); @@ -149,7 +149,7 @@ void copyShouldCreateNewInstanceWithSameValues() { assertThat(c.getToolCallbacks()).isEqualTo(original.getToolCallbacks()); assertThat(c.getTools()).isEqualTo(original.getTools()); assertThat(c.getToolContext()).isEqualTo(original.getToolContext()); - assertThat(c.getToolCallReturnDirect()).isEqualTo(original.getToolCallReturnDirect()); + assertThat(c.isToolExecutionEnabled()).isEqualTo(original.isToolExecutionEnabled()); assertThat(c.getModel()).isEqualTo(original.getModel()); assertThat(c.getTemperature()).isEqualTo(original.getTemperature()); }); @@ -170,45 +170,6 @@ void gettersShouldReturnImmutableCollections() { .isInstanceOf(UnsupportedOperationException.class); } - @Test - void mergeShouldCombineWithNonToolCallingChatOptions() { - DefaultToolCallingChatOptions original = new DefaultToolCallingChatOptions(); - original.setToolCallbacks(List.of(mock(ToolCallback.class))); - original.setTools(Set.of("tool1")); - original.setModel("gpt-3.5"); - - ChatOptions toMerge = ChatOptions.builder().model("gpt-4").build(); - - ToolCallingChatOptions merged = original.merge(toMerge); - - assertThat(merged.getToolCallbacks()).hasSize(1); - assertThat(merged.getTools()).containsExactly("tool1"); - assertThat(merged.getModel()).isEqualTo("gpt-4"); - } - - @Test - void mergeShouldCombineOptionsCorrectly() { - DefaultToolCallingChatOptions original = new DefaultToolCallingChatOptions(); - original.setToolCallbacks(List.of(mock(ToolCallback.class))); - original.setTools(Set.of("tool1")); - original.setToolContext(Map.of("key1", "value1")); - original.setModel("gpt-3.5"); - - DefaultToolCallingChatOptions toMerge = new DefaultToolCallingChatOptions(); - toMerge.setToolCallbacks(List.of(mock(ToolCallback.class))); - toMerge.setTools(Set.of("tool2")); - toMerge.setToolContext(Map.of("key2", "value2")); - toMerge.setTemperature(0.8); - - ToolCallingChatOptions merged = original.merge(toMerge); - - assertThat(merged.getToolCallbacks()).hasSize(2); - assertThat(merged.getTools()).containsExactlyInAnyOrder("tool1", "tool2"); - assertThat(merged.getToolContext()).containsEntry("key1", "value1").containsEntry("key2", "value2"); - assertThat(merged.getModel()).isEqualTo("gpt-3.5"); - assertThat(merged.getTemperature()).isEqualTo(0.8); - } - @Test void builderShouldCreateOptionsWithAllProperties() { ToolCallback callback = mock(ToolCallback.class); @@ -218,7 +179,7 @@ void builderShouldCreateOptionsWithAllProperties() { .toolCallbacks(List.of(callback)) .tools(Set.of("tool1")) .toolContext(context) - .toolCallReturnDirect(true) + .toolExecutionEnabled(true) .model("gpt-4") .temperature(0.7) .maxTokens(100) @@ -233,7 +194,7 @@ void builderShouldCreateOptionsWithAllProperties() { assertThat(o.getToolCallbacks()).containsExactly(callback); assertThat(o.getTools()).containsExactly("tool1"); assertThat(o.getToolContext()).isEqualTo(context); - assertThat(o.getToolCallReturnDirect()).isTrue(); + assertThat(o.isToolExecutionEnabled()).isTrue(); assertThat(o.getModel()).isEqualTo("gpt-4"); assertThat(o.getTemperature()).isEqualTo(0.7); assertThat(o.getMaxTokens()).isEqualTo(100); @@ -258,11 +219,11 @@ void builderShouldSupportToolContextAddition() { @Test void deprecatedMethodsShouldWorkCorrectly() { DefaultToolCallingChatOptions options = new DefaultToolCallingChatOptions(); - FunctionCallback callback = mock(FunctionCallback.class); - assertThatThrownBy(() -> options.setFunctionCallbacks(List.of(callback))) - .isInstanceOf(UnsupportedOperationException.class) - .hasMessage("Not supported. Call setToolCallbacks instead."); + FunctionCallback callback1 = mock(FunctionCallback.class); + ToolCallback callback2 = mock(ToolCallback.class); + options.setFunctionCallbacks(List.of(callback1, callback2)); + assertThat(options.getFunctionCallbacks()).hasSize(2); options.setTools(Set.of("tool1")); assertThat(options.getFunctions()).containsExactly("tool1"); @@ -270,11 +231,11 @@ void deprecatedMethodsShouldWorkCorrectly() { options.setFunctions(Set.of("function1")); assertThat(options.getTools()).containsExactly("function1"); - options.setToolCallReturnDirect(true); - assertThat(options.getProxyToolCalls()).isTrue(); + options.setToolExecutionEnabled(true); + assertThat(options.getProxyToolCalls()).isFalse(); options.setProxyToolCalls(true); - assertThat(options.getToolCallReturnDirect()).isTrue(); + assertThat(options.isToolExecutionEnabled()).isFalse(); } } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTests.java b/spring-ai-core/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTests.java new file mode 100644 index 00000000000..5529b8f7b6c --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTests.java @@ -0,0 +1,259 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.tool; + +import io.micrometer.observation.ObservationRegistry; +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.ToolCallExceptionConverter; +import org.springframework.ai.tool.execution.ToolExecutionException; +import org.springframework.ai.tool.resolution.StaticToolCallbackResolver; +import org.springframework.ai.tool.resolution.ToolCallbackResolver; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +/** + * Unit tests for {@link DefaultToolCallingManager}. + * + * @author Thomas Vitale + */ +class DefaultToolCallingManagerTests { + + // BUILD + + @Test + void whenDefaultArgumentsThenReturn() { + DefaultToolCallingManager defaultToolExecutor = DefaultToolCallingManager.builder().build(); + assertThat(defaultToolExecutor).isNotNull(); + } + + @Test + void whenObservationRegistryIsNullThenThrow() { + assertThatThrownBy(() -> DefaultToolCallingManager.builder() + .observationRegistry(null) + .toolCallbackResolver(mock(ToolCallbackResolver.class)) + .toolCallExceptionConverter(mock(ToolCallExceptionConverter.class)) + .build()).isInstanceOf(IllegalArgumentException.class).hasMessage("observationRegistry cannot be null"); + } + + @Test + void whenToolCallbackResolverIsNullThenThrow() { + assertThatThrownBy(() -> DefaultToolCallingManager.builder() + .observationRegistry(mock(ObservationRegistry.class)) + .toolCallbackResolver(null) + .toolCallExceptionConverter(mock(ToolCallExceptionConverter.class)) + .build()).isInstanceOf(IllegalArgumentException.class).hasMessage("toolCallbackResolver cannot be null"); + } + + @Test + void whenToolCallExceptionConverterIsNullThenThrow() { + assertThatThrownBy(() -> DefaultToolCallingManager.builder() + .observationRegistry(mock(ObservationRegistry.class)) + .toolCallbackResolver(mock(ToolCallbackResolver.class)) + .toolCallExceptionConverter(null) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessage("toolCallExceptionConverter cannot be null"); + } + + // RESOLVE TOOL DEFINITIONS + + @Test + void whenChatOptionsIsNullThenThrow() { + DefaultToolCallingManager defaultToolExecutor = DefaultToolCallingManager.builder().build(); + assertThatThrownBy(() -> defaultToolExecutor.resolveToolDefinitions(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("chatOptions cannot be null"); + } + + @Test + void whenToolCallbackExistsThenResolve() { + ToolCallback toolCallback = new TestToolCallback("toolA"); + ToolCallbackResolver toolCallbackResolver = new StaticToolCallbackResolver(List.of(toolCallback)); + ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder() + .toolCallbackResolver(toolCallbackResolver) + .build(); + + List toolDefinitions = toolCallingManager + .resolveToolDefinitions(ToolCallingChatOptions.builder().tools("toolA").build()); + + assertThat(toolDefinitions).containsExactly(toolCallback.getToolDefinition()); + } + + @Test + void whenToolCallbackDoesNotExistThenThrow() { + ToolCallbackResolver toolCallbackResolver = new StaticToolCallbackResolver(List.of()); + ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder() + .toolCallbackResolver(toolCallbackResolver) + .build(); + + assertThatThrownBy(() -> toolCallingManager + .resolveToolDefinitions(ToolCallingChatOptions.builder().tools("toolB").build())) + .isInstanceOf(IllegalStateException.class) + .hasMessage("No ToolCallback found for tool name: toolB"); + } + + // EXECUTE TOOL CALLS + + @Test + void whenPromptIsNullThenThrow() { + DefaultToolCallingManager defaultToolExecutor = DefaultToolCallingManager.builder().build(); + assertThatThrownBy(() -> defaultToolExecutor.executeToolCalls(null, mock(ChatResponse.class))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("prompt cannot be null"); + } + + @Test + void whenChatResponseIsNullThenThrow() { + DefaultToolCallingManager defaultToolExecutor = DefaultToolCallingManager.builder().build(); + assertThatThrownBy(() -> defaultToolExecutor.executeToolCalls(mock(Prompt.class), null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("chatResponse cannot be null"); + } + + @Test + void whenNoToolCallInChatResponseThenThrow() { + DefaultToolCallingManager defaultToolExecutor = DefaultToolCallingManager.builder().build(); + assertThatThrownBy(() -> defaultToolExecutor.executeToolCalls(mock(Prompt.class), + ChatResponse.builder().generations(List.of()).build())) + .isInstanceOf(IllegalStateException.class) + .hasMessage("No tool call requested by the chat model"); + } + + @Test + void whenSingleToolCallInChatResponseThenExecute() { + ToolCallback toolCallback = new TestToolCallback("toolA"); + ToolCallbackResolver toolCallbackResolver = new StaticToolCallbackResolver(List.of(toolCallback)); + ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder() + .toolCallbackResolver(toolCallbackResolver) + .build(); + + Prompt prompt = new Prompt(new UserMessage("Hello"), ToolCallingChatOptions.builder().build()); + ChatResponse chatResponse = ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("", Map.of(), + List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}")))))) + .build(); + + ToolResponseMessage expectedToolResponse = new ToolResponseMessage( + List.of(new ToolResponseMessage.ToolResponse("toolA", "toolA", "Mission accomplished!"))); + + List toolCallHistory = toolCallingManager.executeToolCalls(prompt, chatResponse); + + assertThat(toolCallHistory).contains(expectedToolResponse); + } + + @Test + void whenMultipleToolCallsInChatResponseThenExecute() { + ToolCallback toolCallbackA = new TestToolCallback("toolA"); + ToolCallback toolCallbackB = new TestToolCallback("toolB"); + ToolCallbackResolver toolCallbackResolver = new StaticToolCallbackResolver( + List.of(toolCallbackA, toolCallbackB)); + ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder() + .toolCallbackResolver(toolCallbackResolver) + .build(); + + Prompt prompt = new Prompt(new UserMessage("Hello"), ToolCallingChatOptions.builder().build()); + ChatResponse chatResponse = ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("", Map.of(), + List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"), + new AssistantMessage.ToolCall("toolB", "function", "toolB", "{}")))))) + .build(); + + ToolResponseMessage expectedToolResponse = new ToolResponseMessage( + List.of(new ToolResponseMessage.ToolResponse("toolA", "toolA", "Mission accomplished!"), + new ToolResponseMessage.ToolResponse("toolB", "toolB", "Mission accomplished!"))); + + List toolCallHistory = toolCallingManager.executeToolCalls(prompt, chatResponse); + + assertThat(toolCallHistory).contains(expectedToolResponse); + } + + @Test + void whenToolCallWithExceptionThenReturnError() { + ToolCallback toolCallback = new FailingToolCallback("toolC"); + ToolCallbackResolver toolCallbackResolver = new StaticToolCallbackResolver(List.of(toolCallback)); + ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder() + .toolCallbackResolver(toolCallbackResolver) + .build(); + + Prompt prompt = new Prompt(new UserMessage("Hello"), ToolCallingChatOptions.builder().build()); + ChatResponse chatResponse = ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("", Map.of(), + List.of(new AssistantMessage.ToolCall("toolC", "function", "toolC", "{}")))))) + .build(); + + ToolResponseMessage expectedToolResponse = new ToolResponseMessage( + List.of(new ToolResponseMessage.ToolResponse("toolC", "toolC", "You failed this city!"))); + + List toolCallHistory = toolCallingManager.executeToolCalls(prompt, chatResponse); + + assertThat(toolCallHistory).contains(expectedToolResponse); + } + + static class TestToolCallback implements ToolCallback { + + private final ToolDefinition toolDefinition; + + public TestToolCallback(String name) { + this.toolDefinition = ToolDefinition.builder().name(name).inputSchema("{}").build(); + } + + @Override + public ToolDefinition getToolDefinition() { + return toolDefinition; + } + + @Override + public String call(String toolInput) { + return "Mission accomplished!"; + } + + } + + static class FailingToolCallback implements ToolCallback { + + private final ToolDefinition toolDefinition; + + public FailingToolCallback(String name) { + this.toolDefinition = ToolDefinition.builder().name(name).inputSchema("{}").build(); + } + + @Override + public ToolDefinition getToolDefinition() { + return toolDefinition; + } + + @Override + public String call(String toolInput) { + throw new ToolExecutionException(toolDefinition, new IllegalStateException("You failed this city!")); + } + + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java b/spring-ai-core/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java new file mode 100644 index 00000000000..26647933c1e --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java @@ -0,0 +1,70 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.model.tool; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.model.function.FunctionCallingOptions; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +/** + * Unit tests for {@link ToolCallingChatOptions}. + * + * @author Thomas Vitale + */ +class ToolCallingChatOptionsTests { + + @Test + void whenToolCallingChatOptionsAndExecutionEnabledTrue() { + ToolCallingChatOptions options = new DefaultToolCallingChatOptions(); + options.setToolExecutionEnabled(true); + assertThat(ToolCallingChatOptions.isToolExecutionEnabled(options)).isTrue(); + } + + @Test + void whenToolCallingChatOptionsAndExecutionEnabledFalse() { + ToolCallingChatOptions options = new DefaultToolCallingChatOptions(); + options.setToolExecutionEnabled(false); + assertThat(ToolCallingChatOptions.isToolExecutionEnabled(options)).isFalse(); + } + + @Test + void whenToolCallingChatOptionsAndExecutionEnabledDefault() { + ToolCallingChatOptions options = new DefaultToolCallingChatOptions(); + assertThat(ToolCallingChatOptions.isToolExecutionEnabled(options)).isTrue(); + } + + @Test + void whenFunctionCallingOptionsAndExecutionEnabledTrue() { + FunctionCallingOptions options = FunctionCallingOptions.builder().build(); + options.setProxyToolCalls(false); + assertThat(ToolCallingChatOptions.isToolExecutionEnabled(options)).isTrue(); + } + + @Test + void whenFunctionCallingOptionsAndExecutionEnabledFalse() { + FunctionCallingOptions options = FunctionCallingOptions.builder().build(); + options.setProxyToolCalls(true); + assertThat(ToolCallingChatOptions.isToolExecutionEnabled(options)).isFalse(); + } + + @Test + void whenFunctionCallingOptionsAndExecutionEnabledDefault() { + FunctionCallingOptions options = FunctionCallingOptions.builder().build(); + assertThat(ToolCallingChatOptions.isToolExecutionEnabled(options)).isTrue(); + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/tool/execution/DefaultToolCallExceptionConverterTests.java b/spring-ai-core/src/test/java/org/springframework/ai/tool/execution/DefaultToolCallExceptionConverterTests.java new file mode 100644 index 00000000000..f44b2239934 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/tool/execution/DefaultToolCallExceptionConverterTests.java @@ -0,0 +1,61 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.execution; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.tool.definition.DefaultToolDefinition; +import org.springframework.ai.tool.definition.ToolDefinition; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link DefaultToolCallExceptionConverter}. + * + * @author Thomas Vitale + */ +class DefaultToolCallExceptionConverterTests { + + @Test + void whenDefaultThenReturnMessage() { + ToolCallExceptionConverter converter = DefaultToolCallExceptionConverter.builder().build(); + ToolExecutionException exception = new ToolExecutionException(generateTestDefinition(), + new RuntimeException("Test")); + assertThat(converter.convert(exception)).isEqualTo("Test"); + } + + @Test + void whenNotAlwaysThrowThenReturnMessage() { + ToolCallExceptionConverter converter = DefaultToolCallExceptionConverter.builder().alwaysThrow(false).build(); + ToolExecutionException exception = new ToolExecutionException(generateTestDefinition(), + new RuntimeException("Test")); + assertThat(converter.convert(exception)).isEqualTo("Test"); + } + + @Test + void whenAlwaysThrowThenThrow() { + ToolCallExceptionConverter converter = DefaultToolCallExceptionConverter.builder().alwaysThrow(true).build(); + ToolExecutionException exception = new ToolExecutionException(generateTestDefinition(), + new RuntimeException("Test")); + assertThatThrownBy(() -> converter.convert(exception)).isInstanceOf(ToolExecutionException.class); + } + + private ToolDefinition generateTestDefinition() { + return DefaultToolDefinition.builder().name("test").inputSchema("{}").build(); + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/tool/resolution/DelegatingToolCallbackResolverTests.java b/spring-ai-core/src/test/java/org/springframework/ai/tool/resolution/DelegatingToolCallbackResolverTests.java new file mode 100644 index 00000000000..597247c7a47 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/tool/resolution/DelegatingToolCallbackResolverTests.java @@ -0,0 +1,64 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.resolution; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.ToolDefinition; + +import java.util.ArrayList; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Unit tests for {@link DelegatingToolCallbackResolver}. + * + * @author Thomas Vitale + */ +class DelegatingToolCallbackResolverTests { + + @Test + void whenToolCallbackResolversAreNullThenThrowException() { + assertThatThrownBy(() -> new DelegatingToolCallbackResolver(null)).isInstanceOf(IllegalArgumentException.class); + } + + @Test + void whenToolCallbackResolversContainNullElementsThenThrowException() { + var toolCallbackResolvers = new ArrayList(); + toolCallbackResolvers.add(null); + assertThatThrownBy(() -> new DelegatingToolCallbackResolver(toolCallbackResolvers)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + void whenToolCallbacksAreProvidedThenResolveToolCallback() { + ToolCallback toolCallback = mock(ToolCallback.class); + when(toolCallback.getToolDefinition()) + .thenReturn(ToolDefinition.builder().name("myTool").inputSchema("{}").build()); + StaticToolCallbackResolver staticToolCallbackResolver = new StaticToolCallbackResolver(List.of(toolCallback)); + + DelegatingToolCallbackResolver delegatingToolCallbackResolver = new DelegatingToolCallbackResolver( + List.of(staticToolCallbackResolver)); + + assertThat(delegatingToolCallbackResolver.resolve("myTool")).isEqualTo(toolCallback); + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/tool/resolution/SpringBeanToolCallbackResolverTests.java b/spring-ai-core/src/test/java/org/springframework/ai/tool/resolution/SpringBeanToolCallbackResolverTests.java new file mode 100644 index 00000000000..f6b5049bfe9 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/tool/resolution/SpringBeanToolCallbackResolverTests.java @@ -0,0 +1,175 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.resolution; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.util.json.SchemaType; +import org.springframework.context.annotation.AnnotationConfigApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Description; +import org.springframework.context.support.GenericApplicationContext; + +import java.util.List; +import java.util.function.Consumer; +import java.util.function.Function; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link SpringBeanToolCallbackResolver}. + * + * @author Thomas Vitale + */ +class SpringBeanToolCallbackResolverTests { + + @Test + void whenApplicationContextIsNullThenThrow() { + assertThatThrownBy(() -> new SpringBeanToolCallbackResolver(null, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("applicationContext cannot be null"); + + assertThatThrownBy(() -> SpringBeanToolCallbackResolver.builder().build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("applicationContext cannot be null"); + } + + @Test + void whenSchemaTypeIsNullThenUseDefault() { + SpringBeanToolCallbackResolver resolver = new SpringBeanToolCallbackResolver(new GenericApplicationContext(), + null); + assertThat(resolver.getSchemaType()).isEqualTo(SchemaType.JSON_SCHEMA); + + SpringBeanToolCallbackResolver resolver2 = SpringBeanToolCallbackResolver.builder() + .applicationContext(new GenericApplicationContext()) + .build(); + assertThat(resolver2.getSchemaType()).isEqualTo(SchemaType.JSON_SCHEMA); + } + + @Test + void whenSchemaTypeIsNotNullThenUseIt() { + SchemaType schemaType = SchemaType.OPEN_API_SCHEMA; + SpringBeanToolCallbackResolver resolver = new SpringBeanToolCallbackResolver(new GenericApplicationContext(), + schemaType); + assertThat(resolver.getSchemaType()).isEqualTo(schemaType); + + SpringBeanToolCallbackResolver resolver2 = SpringBeanToolCallbackResolver.builder() + .applicationContext(new GenericApplicationContext()) + .schemaType(schemaType) + .build(); + assertThat(resolver2.getSchemaType()).isEqualTo(schemaType); + } + + @Test + void whenRequiredArgumentsAreProvidedThenCreateInstance() { + GenericApplicationContext applicationContext = new GenericApplicationContext(); + SchemaType schemaType = SchemaType.OPEN_API_SCHEMA; + SpringBeanToolCallbackResolver resolver = new SpringBeanToolCallbackResolver(applicationContext, schemaType); + assertThat(resolver).isNotNull(); + + SpringBeanToolCallbackResolver resolver2 = SpringBeanToolCallbackResolver.builder() + .applicationContext(applicationContext) + .schemaType(schemaType) + .build(); + assertThat(resolver2).isNotNull(); + } + + @Test + void whenToolCallbackWithVoidConsumerIsResolvedThenReturnIt() { + GenericApplicationContext applicationContext = new AnnotationConfigApplicationContext(Functions.class); + SpringBeanToolCallbackResolver resolver = new SpringBeanToolCallbackResolver(applicationContext, + SchemaType.JSON_SCHEMA); + ToolCallback resolvedToolCallback = resolver.resolve(Functions.WELCOME_TOOL_NAME); + assertThat(resolvedToolCallback).isNotNull(); + assertThat(resolvedToolCallback.getToolDefinition().name()).isEqualTo(Functions.WELCOME_TOOL_NAME); + assertThat(resolvedToolCallback.getToolDefinition().description()) + .isEqualTo(Functions.WELCOME_TOOL_DESCRIPTION); + } + + @Test + void whenToolCallbackWithConsumerIsResolvedThenReturnIt() { + GenericApplicationContext applicationContext = new AnnotationConfigApplicationContext(Functions.class); + SpringBeanToolCallbackResolver resolver = new SpringBeanToolCallbackResolver(applicationContext, + SchemaType.JSON_SCHEMA); + ToolCallback resolvedToolCallback = resolver.resolve(Functions.WELCOME_USER_TOOL_NAME); + assertThat(resolvedToolCallback).isNotNull(); + assertThat(resolvedToolCallback.getToolDefinition().name()).isEqualTo(Functions.WELCOME_USER_TOOL_NAME); + assertThat(resolvedToolCallback.getToolDefinition().description()) + .isEqualTo(Functions.WELCOME_USER_TOOL_DESCRIPTION); + } + + @Test + void whenToolCallbackWithFunctionIsResolvedThenReturnIt() { + GenericApplicationContext applicationContext = new AnnotationConfigApplicationContext(Functions.class); + SpringBeanToolCallbackResolver resolver = new SpringBeanToolCallbackResolver(applicationContext, + SchemaType.JSON_SCHEMA); + ToolCallback resolvedToolCallback = resolver.resolve(Functions.BOOKS_BY_AUTHOR_TOOL_NAME); + assertThat(resolvedToolCallback).isNotNull(); + assertThat(resolvedToolCallback.getToolDefinition().name()).isEqualTo(Functions.BOOKS_BY_AUTHOR_TOOL_NAME); + assertThat(resolvedToolCallback.getToolDefinition().description()) + .isEqualTo(Functions.BOOKS_BY_AUTHOR_TOOL_DESCRIPTION); + } + + @Configuration(proxyBeanMethods = false) + static class Functions { + + public static final String BOOKS_BY_AUTHOR_TOOL_NAME = "booksByAuthor"; + + public static final String BOOKS_BY_AUTHOR_TOOL_DESCRIPTION = "Get the list of books written by the given author available in the library"; + + public static final String WELCOME_TOOL_NAME = "welcome"; + + public static final String WELCOME_TOOL_DESCRIPTION = "Welcome users to the library"; + + public static final String WELCOME_USER_TOOL_NAME = "welcomeUser"; + + public static final String WELCOME_USER_TOOL_DESCRIPTION = "Welcome a specific user to the library"; + + @Bean(WELCOME_TOOL_NAME) + @Description(WELCOME_TOOL_DESCRIPTION) + Consumer welcome() { + return (input) -> { + }; + } + + @Bean(WELCOME_USER_TOOL_NAME) + @Description(WELCOME_USER_TOOL_DESCRIPTION) + Consumer welcomeUser() { + return user -> { + }; + } + + @Bean(BOOKS_BY_AUTHOR_TOOL_NAME) + @Description(BOOKS_BY_AUTHOR_TOOL_DESCRIPTION) + Function> booksByAuthor() { + return author -> List.of(new Book("Book 1", author.name()), new Book("Book 2", author.name())); + } + + public record User(String name) { + } + + public record Author(String name) { + } + + public record Book(String title, String author) { + } + + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/function/StandaloneWeatherFunction.java b/spring-ai-core/src/test/java/org/springframework/ai/tool/resolution/StandaloneWeatherFunction.java similarity index 76% rename from spring-ai-core/src/test/java/org/springframework/ai/model/function/StandaloneWeatherFunction.java rename to spring-ai-core/src/test/java/org/springframework/ai/tool/resolution/StandaloneWeatherFunction.java index 69b60b35519..aa124fa3f89 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/model/function/StandaloneWeatherFunction.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/tool/resolution/StandaloneWeatherFunction.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,12 +14,12 @@ * limitations under the License. */ -package org.springframework.ai.model.function; +package org.springframework.ai.tool.resolution; import java.util.function.Function; -import org.springframework.ai.model.function.TypeResolverHelperIT.WeatherRequest; -import org.springframework.ai.model.function.TypeResolverHelperIT.WeatherResponse; +import org.springframework.ai.tool.resolution.TypeResolverHelperIT.WeatherRequest; +import org.springframework.ai.tool.resolution.TypeResolverHelperIT.WeatherResponse; /** * @author Christian Tzolov diff --git a/spring-ai-core/src/test/java/org/springframework/ai/tool/resolution/StaticToolCallbackResolverTests.java b/spring-ai-core/src/test/java/org/springframework/ai/tool/resolution/StaticToolCallbackResolverTests.java new file mode 100644 index 00000000000..c7398a9b131 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/tool/resolution/StaticToolCallbackResolverTests.java @@ -0,0 +1,66 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.resolution; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.ToolDefinition; + +import java.util.ArrayList; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Unit tests for {@link StaticToolCallbackResolver}. + * + * @author Thomas Vitale + */ +class StaticToolCallbackResolverTests { + + @Test + void whenToolCallbacksAreNullThenThrowException() { + assertThatThrownBy(() -> new StaticToolCallbackResolver(null)).isInstanceOf(IllegalArgumentException.class); + } + + @Test + void whenToolCallbacksContainNullElementsThenThrowException() { + var toolCallbacks = new ArrayList(); + toolCallbacks.add(null); + assertThatThrownBy(() -> new StaticToolCallbackResolver(toolCallbacks)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + void whenToolCallbacksAreEmptyThenReturn() { + StaticToolCallbackResolver resolver = new StaticToolCallbackResolver(List.of()); + assertThat(resolver).isNotNull(); + } + + @Test + void whenToolCallbacksAreProvidedThenResolveToolCallback() { + ToolCallback toolCallback = mock(ToolCallback.class); + when(toolCallback.getToolDefinition()) + .thenReturn(ToolDefinition.builder().name("myTool").inputSchema("{}").build()); + StaticToolCallbackResolver resolver = new StaticToolCallbackResolver(List.of(toolCallback)); + assertThat(resolver.resolve("myTool")).isEqualTo(toolCallback); + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperIT.java b/spring-ai-core/src/test/java/org/springframework/ai/tool/resolution/TypeResolverHelperIT.java similarity index 92% rename from spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperIT.java rename to spring-ai-core/src/test/java/org/springframework/ai/tool/resolution/TypeResolverHelperIT.java index e2cb9d48753..fa02472b4bd 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperIT.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/tool/resolution/TypeResolverHelperIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,14 +14,13 @@ * limitations under the License. */ -package org.springframework.ai.model.function; +package org.springframework.ai.tool.resolution; import java.util.function.Consumer; import java.util.function.Function; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; - import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; @@ -71,8 +70,8 @@ public WeatherResponse apply(WeatherRequest weatherRequest) { } @Configuration - @ComponentScan({ "org.springframework.ai.model.function.config", - "org.springframework.ai.model.function.component" }) + @ComponentScan({ "org.springframework.ai.tool.resolution.config", + "org.springframework.ai.tool.resolution.component" }) public static class TypeResolverHelperConfiguration { @Bean diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperTests.java b/spring-ai-core/src/test/java/org/springframework/ai/tool/resolution/TypeResolverHelperTests.java similarity index 91% rename from spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperTests.java rename to spring-ai-core/src/test/java/org/springframework/ai/tool/resolution/TypeResolverHelperTests.java index 45e94a1e67f..a5185bee6b7 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/tool/resolution/TypeResolverHelperTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.model.function; +package org.springframework.ai.tool.resolution; import java.util.function.Consumer; import java.util.function.Function; @@ -25,9 +25,8 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonPropertyDescription; import org.junit.jupiter.api.Test; - -import org.springframework.ai.model.function.TypeResolverHelperTests.MockWeatherService.Request; -import org.springframework.ai.model.function.TypeResolverHelperTests.MockWeatherService.Response; +import org.springframework.ai.tool.resolution.TypeResolverHelperTests.MockWeatherService.Request; +import org.springframework.ai.tool.resolution.TypeResolverHelperTests.MockWeatherService.Response; import static org.assertj.core.api.Assertions.assertThat; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/function/component/ComponentWeatherFunction.java b/spring-ai-core/src/test/java/org/springframework/ai/tool/resolution/component/ComponentWeatherFunction.java similarity index 76% rename from spring-ai-core/src/test/java/org/springframework/ai/model/function/component/ComponentWeatherFunction.java rename to spring-ai-core/src/test/java/org/springframework/ai/tool/resolution/component/ComponentWeatherFunction.java index 704606d531f..2a96dce139e 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/model/function/component/ComponentWeatherFunction.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/tool/resolution/component/ComponentWeatherFunction.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,12 +14,12 @@ * limitations under the License. */ -package org.springframework.ai.model.function.component; +package org.springframework.ai.tool.resolution.component; import java.util.function.Function; -import org.springframework.ai.model.function.TypeResolverHelperIT.WeatherRequest; -import org.springframework.ai.model.function.TypeResolverHelperIT.WeatherResponse; +import org.springframework.ai.tool.resolution.TypeResolverHelperIT.WeatherRequest; +import org.springframework.ai.tool.resolution.TypeResolverHelperIT.WeatherResponse; import org.springframework.stereotype.Component; /** diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/function/config/TypeResolverHelperConfiguration.java b/spring-ai-core/src/test/java/org/springframework/ai/tool/resolution/config/TypeResolverHelperConfiguration.java similarity index 82% rename from spring-ai-core/src/test/java/org/springframework/ai/model/function/config/TypeResolverHelperConfiguration.java rename to spring-ai-core/src/test/java/org/springframework/ai/tool/resolution/config/TypeResolverHelperConfiguration.java index 588db50b51d..7632f317f12 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/model/function/config/TypeResolverHelperConfiguration.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/tool/resolution/config/TypeResolverHelperConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,9 +14,9 @@ * limitations under the License. */ -package org.springframework.ai.model.function.config; +package org.springframework.ai.tool.resolution.config; -import org.springframework.ai.model.function.StandaloneWeatherFunction; +import org.springframework.ai.tool.resolution.StandaloneWeatherFunction; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; diff --git a/spring-ai-core/src/test/kotlin/org/springframework/ai/model/function/StandaloneWeatherKotlinFunction.kt b/spring-ai-core/src/test/kotlin/org/springframework/ai/model/function/StandaloneWeatherKotlinFunction.kt deleted file mode 100644 index 97ccdbe5c4f..00000000000 --- a/spring-ai-core/src/test/kotlin/org/springframework/ai/model/function/StandaloneWeatherKotlinFunction.kt +++ /dev/null @@ -1,8 +0,0 @@ -package org.springframework.ai.model.function - -class StandaloneWeatherKotlinFunction : Function1 { - - override fun invoke(weatherRequest: WeatherRequest): WeatherResponse { - return WeatherResponse(42.0f) - } -} diff --git a/spring-ai-core/src/test/kotlin/org/springframework/ai/model/function/FunctionCallbackExtensionsTests.kt b/spring-ai-core/src/test/kotlin/org/springframework/ai/tool/resolution/FunctionCallbackExtensionsTests.kt similarity index 82% rename from spring-ai-core/src/test/kotlin/org/springframework/ai/model/function/FunctionCallbackExtensionsTests.kt rename to spring-ai-core/src/test/kotlin/org/springframework/ai/tool/resolution/FunctionCallbackExtensionsTests.kt index 2ab01f2f62f..b32e0a4db25 100644 --- a/spring-ai-core/src/test/kotlin/org/springframework/ai/model/function/FunctionCallbackExtensionsTests.kt +++ b/spring-ai-core/src/test/kotlin/org/springframework/ai/tool/resolution/FunctionCallbackExtensionsTests.kt @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,12 +14,14 @@ * limitations under the License. */ -package org.springframework.ai.model.function +package org.springframework.ai.tool.resolution import io.mockk.every import io.mockk.mockk import io.mockk.verify import org.junit.jupiter.api.Test +import org.springframework.ai.model.function.FunctionCallback +import org.springframework.ai.model.function.inputType class FunctionCallbackExtensionsTests { diff --git a/spring-ai-core/src/test/kotlin/org/springframework/ai/tool/resolution/SpringBeanToolCallbackResolverKotlinTests.kt b/spring-ai-core/src/test/kotlin/org/springframework/ai/tool/resolution/SpringBeanToolCallbackResolverKotlinTests.kt new file mode 100644 index 00000000000..c9088ea5e54 --- /dev/null +++ b/spring-ai-core/src/test/kotlin/org/springframework/ai/tool/resolution/SpringBeanToolCallbackResolverKotlinTests.kt @@ -0,0 +1,114 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.resolution + +import org.assertj.core.api.Assertions +import org.junit.jupiter.api.Test +import org.springframework.ai.util.json.SchemaType +import org.springframework.context.annotation.AnnotationConfigApplicationContext +import org.springframework.context.annotation.Bean +import org.springframework.context.annotation.Configuration +import org.springframework.context.annotation.Description +import org.springframework.context.support.GenericApplicationContext +import java.util.function.Consumer +import java.util.function.Function + +/** + * Unit tests for {@link SpringBeanToolCallbackResolver}. + * + * @author Thomas Vitale + */ +class SpringBeanToolCallbackResolverKotlinTests { + + @Test + fun whenToolCallbackWithVoidConsumerIsResolvedThenReturnIt() { + val applicationContext: GenericApplicationContext = AnnotationConfigApplicationContext(Functions::class.java) + val resolver = SpringBeanToolCallbackResolver(applicationContext, SchemaType.JSON_SCHEMA) + val resolvedToolCallback = resolver.resolve(Functions.WELCOME_TOOL_NAME) + Assertions.assertThat(resolvedToolCallback).isNotNull() + Assertions.assertThat(resolvedToolCallback.toolDefinition.name()).isEqualTo(Functions.WELCOME_TOOL_NAME) + Assertions.assertThat(resolvedToolCallback.toolDefinition.description()) + .isEqualTo(Functions.WELCOME_TOOL_DESCRIPTION) + } + + @Test + fun whenToolCallbackWithConsumerIsResolvedThenReturnIt() { + val applicationContext: GenericApplicationContext = AnnotationConfigApplicationContext(Functions::class.java) + val resolver = SpringBeanToolCallbackResolver(applicationContext, SchemaType.JSON_SCHEMA) + val resolvedToolCallback = resolver.resolve(Functions.WELCOME_USER_TOOL_NAME) + Assertions.assertThat(resolvedToolCallback).isNotNull() + Assertions.assertThat(resolvedToolCallback.toolDefinition.name()).isEqualTo(Functions.WELCOME_USER_TOOL_NAME) + Assertions.assertThat(resolvedToolCallback.toolDefinition.description()) + .isEqualTo(Functions.WELCOME_USER_TOOL_DESCRIPTION) + } + + @Test + fun whenToolCallbackWithFunctionIsResolvedThenReturnIt() { + val applicationContext: GenericApplicationContext = AnnotationConfigApplicationContext(Functions::class.java) + val resolver = SpringBeanToolCallbackResolver(applicationContext, SchemaType.JSON_SCHEMA) + val resolvedToolCallback = resolver.resolve(Functions.BOOKS_BY_AUTHOR_TOOL_NAME) + Assertions.assertThat(resolvedToolCallback).isNotNull() + Assertions.assertThat(resolvedToolCallback.toolDefinition.name()).isEqualTo(Functions.BOOKS_BY_AUTHOR_TOOL_NAME) + Assertions.assertThat(resolvedToolCallback.toolDefinition.description()) + .isEqualTo(Functions.BOOKS_BY_AUTHOR_TOOL_DESCRIPTION) + } + + @Configuration(proxyBeanMethods = false) + open class Functions { + + @Bean(WELCOME_TOOL_NAME) + @Description(WELCOME_TOOL_DESCRIPTION) + open fun welcome(): Consumer { + return Consumer { input: Void? -> } + } + + @Bean(WELCOME_USER_TOOL_NAME) + @Description(WELCOME_USER_TOOL_DESCRIPTION) + open fun welcomeUser(): Consumer { + return Consumer { user: User? -> } + } + + @Bean(BOOKS_BY_AUTHOR_TOOL_NAME) + @Description(BOOKS_BY_AUTHOR_TOOL_DESCRIPTION) + open fun booksByAuthor(): Function> { + return Function { author: Author -> + java.util.List.of( + Book("Book 1", author.name), + Book("Book 2", author.name) + ) + } + } + + data class User(val name: String) + + data class Author(val name: String) + + data class Book(val title: String, val author: String) + + companion object { + const val BOOKS_BY_AUTHOR_TOOL_NAME: String = "booksByAuthor" + const val BOOKS_BY_AUTHOR_TOOL_DESCRIPTION: String = "Get the list of books written by the given author available in the library" + + const val WELCOME_TOOL_NAME: String = "welcome" + const val WELCOME_TOOL_DESCRIPTION: String = "Welcome users to the library" + + const val WELCOME_USER_TOOL_NAME: String = "welcomeUser" + const val WELCOME_USER_TOOL_DESCRIPTION: String = "Welcome a specific user to the library" + } + } + +} diff --git a/spring-ai-core/src/test/kotlin/org/springframework/ai/tool/resolution/StandaloneWeatherKotlinFunction.kt b/spring-ai-core/src/test/kotlin/org/springframework/ai/tool/resolution/StandaloneWeatherKotlinFunction.kt new file mode 100644 index 00000000000..2a76fb80eea --- /dev/null +++ b/spring-ai-core/src/test/kotlin/org/springframework/ai/tool/resolution/StandaloneWeatherKotlinFunction.kt @@ -0,0 +1,24 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.resolution + +class StandaloneWeatherKotlinFunction : Function1 { + + override fun invoke(weatherRequest: WeatherRequest): WeatherResponse { + return WeatherResponse(42.0f) + } +} diff --git a/spring-ai-core/src/test/kotlin/org/springframework/ai/model/function/TypeResolverHelperKotlinIT.kt b/spring-ai-core/src/test/kotlin/org/springframework/ai/tool/resolution/TypeResolverHelperKotlinIT.kt similarity index 93% rename from spring-ai-core/src/test/kotlin/org/springframework/ai/model/function/TypeResolverHelperKotlinIT.kt rename to spring-ai-core/src/test/kotlin/org/springframework/ai/tool/resolution/TypeResolverHelperKotlinIT.kt index 8ced2732fce..b7671a89b98 100644 --- a/spring-ai-core/src/test/kotlin/org/springframework/ai/model/function/TypeResolverHelperKotlinIT.kt +++ b/spring-ai-core/src/test/kotlin/org/springframework/ai/tool/resolution/TypeResolverHelperKotlinIT.kt @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.model.function +package org.springframework.ai.tool.resolution import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.params.ParameterizedTest @@ -53,7 +53,7 @@ class TypeResolverHelperKotlinIT { } @Configuration - @ComponentScan("org.springframework.ai.model.function.kotlinconfig") + @ComponentScan("org.springframework.ai.tool.resolution.kotlinconfig") open class TypeResolverHelperConfiguration { @Bean diff --git a/spring-ai-core/src/test/kotlin/org/springframework/ai/model/function/kotlinconfig/TypeResolverHelperKotlinConfiguration.kt b/spring-ai-core/src/test/kotlin/org/springframework/ai/tool/resolution/kotlinconfig/TypeResolverHelperKotlinConfiguration.kt similarity index 82% rename from spring-ai-core/src/test/kotlin/org/springframework/ai/model/function/kotlinconfig/TypeResolverHelperKotlinConfiguration.kt rename to spring-ai-core/src/test/kotlin/org/springframework/ai/tool/resolution/kotlinconfig/TypeResolverHelperKotlinConfiguration.kt index 5724a0d2c65..7ea2504f4d4 100644 --- a/spring-ai-core/src/test/kotlin/org/springframework/ai/model/function/kotlinconfig/TypeResolverHelperKotlinConfiguration.kt +++ b/spring-ai-core/src/test/kotlin/org/springframework/ai/tool/resolution/kotlinconfig/TypeResolverHelperKotlinConfiguration.kt @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,9 +14,9 @@ * limitations under the License. */ -package org.springframework.ai.model.function.kotlinconfig +package org.springframework.ai.tool.resolution.kotlinconfig -import org.springframework.ai.model.function.StandaloneWeatherKotlinFunction +import org.springframework.ai.tool.resolution.StandaloneWeatherKotlinFunction import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Configuration diff --git a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/preretrieval/query/transformation/RewriteQueryTransformerIT.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/preretrieval/query/transformation/RewriteQueryTransformerIT.java index fddb7b4b9f6..32c93fc6523 100644 --- a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/preretrieval/query/transformation/RewriteQueryTransformerIT.java +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/preretrieval/query/transformation/RewriteQueryTransformerIT.java @@ -52,7 +52,7 @@ void whenTransformerWithDefaults() { assertThat(transformedQuery).isNotNull(); System.out.println(transformedQuery); - assertThat(transformedQuery.text()).containsIgnoringCase("Large Language Model"); + assertThat(transformedQuery.text()).containsIgnoringCase("model"); } } diff --git a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/FunctionToolCallbackTests.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/FunctionToolCallbackTests.java index 27ee2e0612f..3cdeb44a90f 100644 --- a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/FunctionToolCallbackTests.java +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/FunctionToolCallbackTests.java @@ -22,6 +22,9 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.integration.tests.TestApplication; +import org.springframework.ai.integration.tests.tool.domain.Author; +import org.springframework.ai.integration.tests.tool.domain.Book; +import org.springframework.ai.integration.tests.tool.domain.BookService; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.beans.factory.annotation.Autowired; @@ -32,8 +35,6 @@ import org.springframework.context.annotation.Import; import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; import java.util.function.Consumer; import java.util.function.Function; @@ -232,44 +233,9 @@ Function> authorsByBooks() { public record User(String name) { } - public record Author(String name) { - } - - public record Authors(List authors) { - } - - public record Book(String title, String author) { - } - public record Books(List books) { } - static class BookService { - - private static final Map books = new ConcurrentHashMap<>(); - - static { - books.put(1, new Book("His Dark Materials", "Philip Pullman")); - books.put(2, new Book("Narnia", "C.S. Lewis")); - books.put(3, new Book("The Hobbit", "J.R.R. Tolkien")); - books.put(4, new Book("The Lord of The Rings", "J.R.R. Tolkien")); - books.put(5, new Book("The Silmarillion", "J.R.R. Tolkien")); - } - - public List getBooksByAuthor(Author author) { - return books.values().stream().filter(book -> author.name().equals(book.author())).toList(); - } - - public List getAuthorsByBook(List booksToSearch) { - return books.values() - .stream() - .filter(book -> booksToSearch.stream().anyMatch(b -> b.title().equals(book.title()))) - .map(book -> new Author(book.author())) - .toList(); - } - - } - // @formatter:on } diff --git a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/MethodToolCallbackTests.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/MethodToolCallbackTests.java index d1bc8fb5689..29737fa293c 100644 --- a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/MethodToolCallbackTests.java +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/MethodToolCallbackTests.java @@ -22,6 +22,9 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.integration.tests.TestApplication; +import org.springframework.ai.integration.tests.tool.domain.Author; +import org.springframework.ai.integration.tests.tool.domain.Book; +import org.springframework.ai.integration.tests.tool.domain.BookService; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.tool.ToolCallbacks; import org.springframework.ai.tool.annotation.Tool; @@ -30,8 +33,6 @@ import org.springframework.boot.test.context.SpringBootTest; import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; import static org.assertj.core.api.Assertions.assertThat; @@ -158,36 +159,4 @@ List authorsByBooks(List books) { } - public record Author(String name) { - } - - public record Book(String title, String author) { - } - - static class BookService { - - private static final Map books = new ConcurrentHashMap<>(); - - static { - books.put(1, new Book("His Dark Materials", "Philip Pullman")); - books.put(2, new Book("Narnia", "C.S. Lewis")); - books.put(3, new Book("The Hobbit", "J.R.R. Tolkien")); - books.put(4, new Book("The Lord of The Rings", "J.R.R. Tolkien")); - books.put(5, new Book("The Silmarillion", "J.R.R. Tolkien")); - } - - public List getBooksByAuthor(Author author) { - return books.values().stream().filter(book -> author.name().equals(book.author())).toList(); - } - - public List getAuthorsByBook(List booksToSearch) { - return books.values() - .stream() - .filter(book -> booksToSearch.stream().anyMatch(b -> b.title().equals(book.title()))) - .map(book -> new Author(book.author())) - .toList(); - } - - } - } diff --git a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/ToolCallingManagerTests.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/ToolCallingManagerTests.java new file mode 100644 index 00000000000..936b9368f18 --- /dev/null +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/ToolCallingManagerTests.java @@ -0,0 +1,158 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.integration.tests.tool; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.integration.tests.TestApplication; +import org.springframework.ai.integration.tests.tool.domain.Author; +import org.springframework.ai.integration.tests.tool.domain.Book; +import org.springframework.ai.integration.tests.tool.domain.BookService; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.openai.OpenAiChatModel; +import org.springframework.ai.tool.ToolCallbacks; +import org.springframework.ai.tool.annotation.Tool; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import reactor.core.publisher.Flux; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link ToolCallingManager}. + * + * @author Thomas Vitale + */ +@SpringBootTest(classes = TestApplication.class) +@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") +public class ToolCallingManagerTests { + + private final Tools tools = new Tools(); + + private final ToolCallingManager toolCallingManager = ToolCallingManager.builder().build(); + + @Autowired + OpenAiChatModel openAiChatModel; + + @Test + void explicitToolCallingExecutionWithNewOptions() { + ChatOptions chatOptions = ToolCallingChatOptions.builder() + .toolCallbacks(ToolCallbacks.from(tools)) + .toolExecutionEnabled(false) + .build(); + Prompt prompt = new Prompt( + new UserMessage("What books written by %s are available in the library?".formatted("J.R.R. Tolkien")), + chatOptions); + runExplicitToolCallingExecutionWithOptions(chatOptions, prompt); + } + + @Test + void explicitToolCallingExecutionWithLegacyOptions() { + ChatOptions chatOptions = FunctionCallingOptions.builder() + .functionCallbacks(ToolCallbacks.from(tools)) + .proxyToolCalls(true) + .build(); + Prompt prompt = new Prompt( + new UserMessage("What books written by %s are available in the library?".formatted("J.R.R. Tolkien")), + chatOptions); + runExplicitToolCallingExecutionWithOptions(chatOptions, prompt); + } + + @Test + void explicitToolCallingExecutionWithNewOptionsStream() { + ChatOptions chatOptions = ToolCallingChatOptions.builder() + .toolCallbacks(ToolCallbacks.from(tools)) + .toolExecutionEnabled(false) + .build(); + Prompt prompt = new Prompt(new UserMessage("What books written by %s, %s, and %s are available in the library?" + .formatted("J.R.R. Tolkien", "Philip Pullman", "C.S. Lewis")), chatOptions); + runExplicitToolCallingExecutionWithOptionsStream(chatOptions, prompt); + } + + private void runExplicitToolCallingExecutionWithOptions(ChatOptions chatOptions, Prompt prompt) { + ChatResponse chatResponse = openAiChatModel.call(prompt); + + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.hasToolCalls()).isTrue(); + + List messages = toolCallingManager.executeToolCalls(prompt, chatResponse); + + assertThat(messages).isNotEmpty(); + assertThat(messages.stream().anyMatch(m -> m instanceof ToolResponseMessage)).isTrue(); + + Prompt secondPrompt = new Prompt(messages, chatOptions); + + ChatResponse secondChatResponse = openAiChatModel.call(secondPrompt); + + assertThat(secondChatResponse).isNotNull(); + assertThat(secondChatResponse.getResult().getOutput().getText()).isNotEmpty() + .contains("The Hobbit") + .contains("The Lord of The Rings") + .contains("The Silmarillion"); + } + + private void runExplicitToolCallingExecutionWithOptionsStream(ChatOptions chatOptions, Prompt prompt) { + ChatResponse chatResponse = openAiChatModel.stream(prompt).flatMap(response -> { + if (response.hasToolCalls()) { + List messages = toolCallingManager.executeToolCalls(prompt, response); + + assertThat(messages).isNotEmpty(); + assertThat(messages.stream().anyMatch(m -> m instanceof ToolResponseMessage)).isTrue(); + + Prompt secondPrompt = new Prompt(messages, chatOptions); + // return openAiChatModel.stream(secondPrompt); + return Flux.just(openAiChatModel.call(secondPrompt)); + } + return Flux.just(response); + }).blockLast(); + + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.getResult().getOutput().getText()).isNotEmpty() + .contains("His Dark Materials") + .contains("Narnia") + .contains("The Hobbit") + .contains("The Lord of The Rings") + .contains("The Silmarillion"); + } + + static class Tools { + + private static final Logger logger = LoggerFactory.getLogger(Tools.class); + + private final BookService bookService = new BookService(); + + @Tool(description = "Get the list of books written by the given author available in the library") + List booksByAuthor(String author) { + logger.info("Getting books by author: {}", author); + return bookService.getBooksByAuthor(new Author(author)); + } + + } + +} diff --git a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/domain/Author.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/domain/Author.java new file mode 100644 index 00000000000..69d6374ea68 --- /dev/null +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/domain/Author.java @@ -0,0 +1,23 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.integration.tests.tool.domain; + +/** + * @author Thomas Vitale + */ +public record Author(String name) { +} diff --git a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/domain/Book.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/domain/Book.java new file mode 100644 index 00000000000..7f43a307032 --- /dev/null +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/domain/Book.java @@ -0,0 +1,23 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.integration.tests.tool.domain; + +/** + * @author Thomas Vitale + */ +public record Book(String title, String author) { +} diff --git a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/domain/BookService.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/domain/BookService.java new file mode 100644 index 00000000000..240a3134388 --- /dev/null +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/domain/BookService.java @@ -0,0 +1,50 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.integration.tests.tool.domain; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * @author Thomas Vitale + */ +public class BookService { + + private static final Map books = new ConcurrentHashMap<>(); + + static { + books.put(1, new Book("His Dark Materials", "Philip Pullman")); + books.put(2, new Book("Narnia", "C.S. Lewis")); + books.put(3, new Book("The Hobbit", "J.R.R. Tolkien")); + books.put(4, new Book("The Lord of The Rings", "J.R.R. Tolkien")); + books.put(5, new Book("The Silmarillion", "J.R.R. Tolkien")); + } + + public List getBooksByAuthor(Author author) { + return books.values().stream().filter(book -> author.name().equals(book.author())).toList(); + } + + public List getAuthorsByBook(List booksToSearch) { + return books.values() + .stream() + .filter(book -> booksToSearch.stream().anyMatch(b -> b.title().equals(book.title()))) + .map(book -> new Author(book.author())) + .toList(); + } + +}