Skip to content

Commit 89f9c89

Browse files
author
wmz7year
committed
Amazon Bedrock Chat adds tool support.
1 parent 49b3326 commit 89f9c89

File tree

15 files changed

+1205
-83
lines changed

15 files changed

+1205
-83
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/*
2+
* Copyright 2023 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.bedrock;
17+
18+
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
19+
20+
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
21+
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput;
22+
import software.amazon.awssdk.services.bedrockruntime.model.Message;
23+
import software.amazon.awssdk.services.bedrockruntime.model.MessageStopEvent;
24+
25+
/**
26+
* Amazon Bedrock Chat model converse interface generation metadata, encapsulating
27+
* information on the completion.
28+
*
29+
* @author Wei Jiang
30+
* @since 1.0.0
31+
*/
32+
public class BedrockConverseChatGenerationMetadata implements ChatGenerationMetadata {
33+
34+
private String stopReason;
35+
36+
private Message message;
37+
38+
private ConverseStreamOutput event;
39+
40+
public BedrockConverseChatGenerationMetadata(String stopReason, ConverseStreamOutput event) {
41+
super();
42+
43+
this.stopReason = stopReason;
44+
this.event = event;
45+
}
46+
47+
public BedrockConverseChatGenerationMetadata(String stopReason, Message message) {
48+
super();
49+
50+
this.stopReason = stopReason;
51+
this.message = message;
52+
}
53+
54+
public static BedrockConverseChatGenerationMetadata from(ConverseResponse response, Message message) {
55+
return new BedrockConverseChatGenerationMetadata(response.stopReasonAsString(), message);
56+
}
57+
58+
public static BedrockConverseChatGenerationMetadata from(ConverseStreamOutput event) {
59+
String stopReason = null;
60+
61+
if (event instanceof MessageStopEvent messageStopEvent) {
62+
stopReason = messageStopEvent.stopReasonAsString();
63+
}
64+
65+
return new BedrockConverseChatGenerationMetadata(stopReason, event);
66+
}
67+
68+
@Override
69+
public <T> T getContentFilterMetadata() {
70+
return null;
71+
}
72+
73+
@Override
74+
public String getFinishReason() {
75+
return stopReason;
76+
}
77+
78+
public Message getMessage() {
79+
return message;
80+
}
81+
82+
public ConverseStreamOutput getEvent() {
83+
return event;
84+
}
85+
86+
}

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/Anthropic3ChatOptions.java

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,20 @@
1515
*/
1616
package org.springframework.ai.bedrock.anthropic3;
1717

18+
import com.fasterxml.jackson.annotation.JsonIgnore;
1819
import com.fasterxml.jackson.annotation.JsonInclude;
1920
import com.fasterxml.jackson.annotation.JsonInclude.Include;
2021
import com.fasterxml.jackson.annotation.JsonProperty;
2122
import org.springframework.ai.chat.prompt.ChatOptions;
23+
import org.springframework.ai.model.function.FunctionCallback;
24+
import org.springframework.ai.model.function.FunctionCallingOptions;
25+
import org.springframework.boot.context.properties.NestedConfigurationProperty;
26+
import org.springframework.util.Assert;
2227

28+
import java.util.ArrayList;
29+
import java.util.HashSet;
2330
import java.util.List;
31+
import java.util.Set;
2432

2533
/**
2634
* Java {@link ChatOptions} for the Bedrock Anthropic chat generative model chat options.
@@ -31,7 +39,7 @@
3139
* @since 1.0.0
3240
*/
3341
@JsonInclude(Include.NON_NULL)
34-
public class Anthropic3ChatOptions implements ChatOptions {
42+
public class Anthropic3ChatOptions implements ChatOptions, FunctionCallingOptions {
3543

3644
// @formatter:off
3745
/**
@@ -66,6 +74,31 @@ public class Anthropic3ChatOptions implements ChatOptions {
6674
*/
6775
private @JsonProperty("stop_sequences") List<String> stopSequences;
6876

77+
/**
78+
* Tool Function Callbacks to register with the ChatModel. For Prompt
79+
* Options the functionCallbacks are automatically enabled for the duration of the
80+
* prompt execution. For Default Options the functionCallbacks are registered but
81+
* disabled by default. Use the enableFunctions to set the functions from the registry
82+
* to be used by the ChatModel chat completion requests.
83+
*/
84+
@NestedConfigurationProperty
85+
@JsonIgnore
86+
private List<FunctionCallback> functionCallbacks = new ArrayList<>();
87+
88+
/**
89+
* List of functions, identified by their names, to configure for function calling in
90+
* the chat completion requests. Functions with those names must exist in the
91+
* functionCallbacks registry. The {@link #functionCallbacks} from the PromptOptions
92+
* are automatically enabled for the duration of the prompt execution.
93+
*
94+
* Note that function enabled with the default options are enabled for all chat
95+
* completion requests. This could impact the token count and the billing. If the
96+
* functions is set in a prompt options, then the enabled functions are only active
97+
* for the duration of this prompt execution.
98+
*/
99+
@NestedConfigurationProperty
100+
@JsonIgnore
101+
private Set<String> functions = new HashSet<>();
69102
// @formatter:on
70103

71104
public static Builder builder() {
@@ -101,6 +134,23 @@ public Builder withStopSequences(List<String> stopSequences) {
101134
return this;
102135
}
103136

137+
public Builder withFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
138+
this.options.functionCallbacks = functionCallbacks;
139+
return this;
140+
}
141+
142+
public Builder withFunctions(Set<String> functionNames) {
143+
Assert.notNull(functionNames, "Function names must not be null");
144+
this.options.functions = functionNames;
145+
return this;
146+
}
147+
148+
public Builder withFunction(String functionName) {
149+
Assert.hasText(functionName, "Function name must not be empty");
150+
this.options.functions.add(functionName);
151+
return this;
152+
}
153+
104154
public Anthropic3ChatOptions build() {
105155
return this.options;
106156
}
@@ -150,12 +200,36 @@ public void setStopSequences(List<String> stopSequences) {
150200
this.stopSequences = stopSequences;
151201
}
152202

203+
@Override
204+
public List<FunctionCallback> getFunctionCallbacks() {
205+
return this.functionCallbacks;
206+
}
207+
208+
@Override
209+
public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
210+
Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null");
211+
this.functionCallbacks = functionCallbacks;
212+
}
213+
214+
@Override
215+
public Set<String> getFunctions() {
216+
return this.functions;
217+
}
218+
219+
@Override
220+
public void setFunctions(Set<String> functions) {
221+
Assert.notNull(functions, "Function must not be null");
222+
this.functions = functions;
223+
}
224+
153225
public static Anthropic3ChatOptions fromOptions(Anthropic3ChatOptions fromOptions) {
154226
return builder().withTemperature(fromOptions.getTemperature())
155227
.withMaxTokens(fromOptions.getMaxTokens())
156228
.withTopK(fromOptions.getTopK())
157229
.withTopP(fromOptions.getTopP())
158230
.withStopSequences(fromOptions.getStopSequences())
231+
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
232+
.withFunctions(fromOptions.getFunctions())
159233
.build();
160234
}
161235

0 commit comments

Comments
 (0)