Skip to content

Commit 0fffb5a

Browse files
author
wmz7year
committed
Add Bedrock Cohere Command R model support.
1 parent 2c8968d commit 0fffb5a

File tree

17 files changed

+2036
-1
lines changed

17 files changed

+2036
-1
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,283 @@
1+
/*
2+
* Copyright 2023 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.bedrock.cohere;
17+
18+
import java.util.ArrayList;
19+
import java.util.HashSet;
20+
import java.util.List;
21+
import java.util.Set;
22+
23+
import org.springframework.ai.bedrock.BedrockConverseChatGenerationMetadata;
24+
import org.springframework.ai.bedrock.api.BedrockConverseApi;
25+
import org.springframework.ai.bedrock.api.BedrockConverseApiUtils;
26+
import org.springframework.ai.bedrock.api.BedrockConverseApi.BedrockConverseRequest;
27+
import org.springframework.ai.chat.model.ChatModel;
28+
import org.springframework.ai.chat.model.ChatResponse;
29+
import org.springframework.ai.chat.model.Generation;
30+
import org.springframework.ai.chat.model.StreamingChatModel;
31+
import org.springframework.ai.chat.prompt.ChatOptions;
32+
import org.springframework.ai.chat.prompt.Prompt;
33+
import org.springframework.ai.model.ModelOptionsUtils;
34+
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
35+
import org.springframework.ai.model.function.FunctionCallbackContext;
36+
import org.springframework.util.Assert;
37+
import org.springframework.util.CollectionUtils;
38+
39+
import reactor.core.publisher.Flux;
40+
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock;
41+
import software.amazon.awssdk.services.bedrockruntime.model.ConversationRole;
42+
import software.amazon.awssdk.services.bedrockruntime.model.Message;
43+
import software.amazon.awssdk.services.bedrockruntime.model.StopReason;
44+
import software.amazon.awssdk.services.bedrockruntime.model.Tool;
45+
import software.amazon.awssdk.services.bedrockruntime.model.ToolConfiguration;
46+
import software.amazon.awssdk.services.bedrockruntime.model.ToolInputSchema;
47+
import software.amazon.awssdk.services.bedrockruntime.model.ToolResultBlock;
48+
import software.amazon.awssdk.services.bedrockruntime.model.ToolResultContentBlock;
49+
import software.amazon.awssdk.services.bedrockruntime.model.ToolResultStatus;
50+
import software.amazon.awssdk.services.bedrockruntime.model.ToolSpecification;
51+
import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlock;
52+
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock.Type;
53+
54+
/**
55+
* Java {@link ChatModel} and {@link StreamingChatModel} for the Bedrock Command R chat
56+
* generative model.
57+
*
58+
* @author Wei Jiang
59+
* @since 1.0.0
60+
*/
61+
public class BedrockCohereCommandRChatModel
62+
extends AbstractFunctionCallSupport<Message, BedrockConverseRequest, ChatResponse>
63+
implements ChatModel, StreamingChatModel {
64+
65+
private final String modelId;
66+
67+
private final BedrockConverseApi converseApi;
68+
69+
private final BedrockCohereCommandRChatOptions defaultOptions;
70+
71+
public BedrockCohereCommandRChatModel(BedrockConverseApi converseApi) {
72+
this(converseApi, BedrockCohereCommandRChatOptions.builder().build());
73+
}
74+
75+
public BedrockCohereCommandRChatModel(BedrockConverseApi converseApi, BedrockCohereCommandRChatOptions options) {
76+
this(CohereCommandRChatModel.COHERE_COMMAND_R_PLUS_V1.id(), converseApi, options);
77+
}
78+
79+
public BedrockCohereCommandRChatModel(String modelId, BedrockConverseApi converseApi,
80+
BedrockCohereCommandRChatOptions options) {
81+
this(modelId, converseApi, options, null);
82+
}
83+
84+
public BedrockCohereCommandRChatModel(String modelId, BedrockConverseApi converseApi,
85+
BedrockCohereCommandRChatOptions options, FunctionCallbackContext functionCallbackContext) {
86+
super(functionCallbackContext);
87+
88+
Assert.notNull(modelId, "modelId must not be null.");
89+
Assert.notNull(converseApi, "BedrockConverseApi must not be null.");
90+
Assert.notNull(options, "BedrockCohereCommandRChatOptions must not be null.");
91+
92+
this.modelId = modelId;
93+
this.converseApi = converseApi;
94+
this.defaultOptions = options;
95+
}
96+
97+
@Override
98+
public ChatResponse call(Prompt prompt) {
99+
Assert.notNull(prompt, "Prompt must not be null.");
100+
101+
var request = createBedrockConverseRequest(prompt);
102+
103+
return this.callWithFunctionSupport(request);
104+
}
105+
106+
@Override
107+
public Flux<ChatResponse> stream(Prompt prompt) {
108+
Assert.notNull(prompt, "Prompt must not be null.");
109+
110+
var request = createBedrockConverseRequest(prompt);
111+
112+
return converseApi.converseStream(request);
113+
}
114+
115+
private BedrockConverseRequest createBedrockConverseRequest(Prompt prompt) {
116+
var request = BedrockConverseApiUtils.createBedrockConverseRequest(modelId, prompt, defaultOptions);
117+
118+
ToolConfiguration toolConfiguration = createToolConfiguration(prompt);
119+
120+
return BedrockConverseRequest.from(request).withToolConfiguration(toolConfiguration).build();
121+
}
122+
123+
private ToolConfiguration createToolConfiguration(Prompt prompt) {
124+
Set<String> functionsForThisRequest = new HashSet<>();
125+
126+
if (this.defaultOptions != null) {
127+
Set<String> promptEnabledFunctions = this.handleFunctionCallbackConfigurations(this.defaultOptions,
128+
!IS_RUNTIME_CALL);
129+
functionsForThisRequest.addAll(promptEnabledFunctions);
130+
}
131+
132+
if (prompt.getOptions() != null) {
133+
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
134+
BedrockCohereCommandRChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
135+
ChatOptions.class, BedrockCohereCommandRChatOptions.class);
136+
137+
Set<String> defaultEnabledFunctions = this.handleFunctionCallbackConfigurations(updatedRuntimeOptions,
138+
IS_RUNTIME_CALL);
139+
functionsForThisRequest.addAll(defaultEnabledFunctions);
140+
}
141+
else {
142+
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
143+
+ prompt.getOptions().getClass().getSimpleName());
144+
}
145+
}
146+
147+
if (!CollectionUtils.isEmpty(functionsForThisRequest)) {
148+
return ToolConfiguration.builder().tools(getFunctionTools(functionsForThisRequest)).build();
149+
}
150+
151+
return null;
152+
}
153+
154+
private List<Tool> getFunctionTools(Set<String> functionNames) {
155+
return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> {
156+
var description = functionCallback.getDescription();
157+
var name = functionCallback.getName();
158+
String inputSchema = functionCallback.getInputTypeSchema();
159+
160+
return Tool.builder()
161+
.toolSpec(ToolSpecification.builder()
162+
.name(name)
163+
.description(description)
164+
.inputSchema(ToolInputSchema.builder()
165+
.json(BedrockConverseApiUtils.convertObjectToDocument(ModelOptionsUtils.jsonToMap(inputSchema)))
166+
.build())
167+
.build())
168+
.build();
169+
}).toList();
170+
}
171+
172+
@Override
173+
public ChatOptions getDefaultOptions() {
174+
return BedrockCohereCommandRChatOptions.fromOptions(defaultOptions);
175+
}
176+
177+
@Override
178+
protected BedrockConverseRequest doCreateToolResponseRequest(BedrockConverseRequest previousRequest,
179+
Message responseMessage, List<Message> conversationHistory) {
180+
List<ToolUseBlock> toolToUseList = responseMessage.content()
181+
.stream()
182+
.filter(content -> content.type() == Type.TOOL_USE)
183+
.map(content -> content.toolUse())
184+
.toList();
185+
186+
List<ToolResultBlock> toolResults = new ArrayList<>();
187+
188+
for (ToolUseBlock toolToUse : toolToUseList) {
189+
var functionCallId = toolToUse.toolUseId();
190+
var functionName = toolToUse.name();
191+
var functionArguments = toolToUse.input().unwrap();
192+
193+
if (!this.functionCallbackRegister.containsKey(functionName)) {
194+
throw new IllegalStateException("No function callback found for function name: " + functionName);
195+
}
196+
197+
String functionResponse = this.functionCallbackRegister.get(functionName)
198+
.call(ModelOptionsUtils.toJsonString(functionArguments));
199+
200+
toolResults.add(ToolResultBlock.builder()
201+
.toolUseId(functionCallId)
202+
.status(ToolResultStatus.SUCCESS)
203+
.content(ToolResultContentBlock.builder().text(functionResponse).build())
204+
.build());
205+
}
206+
207+
// Add the function response to the conversation.
208+
Message toolResultMessage = Message.builder()
209+
.content(toolResults.stream().map(toolResult -> ContentBlock.fromToolResult(toolResult)).toList())
210+
.role(ConversationRole.USER)
211+
.build();
212+
conversationHistory.add(toolResultMessage);
213+
214+
// Recursively call chatCompletionWithTools until the model doesn't call a
215+
// functions anymore.
216+
return BedrockConverseRequest.from(previousRequest).withMessages(conversationHistory).build();
217+
}
218+
219+
@Override
220+
protected List<Message> doGetUserMessages(BedrockConverseRequest request) {
221+
return request.messages();
222+
}
223+
224+
@Override
225+
protected Message doGetToolResponseMessage(ChatResponse response) {
226+
Generation result = response.getResult();
227+
228+
var metadata = (BedrockConverseChatGenerationMetadata) result.getMetadata();
229+
230+
return metadata.getMessage();
231+
}
232+
233+
@Override
234+
protected ChatResponse doChatCompletion(BedrockConverseRequest request) {
235+
return converseApi.converse(request);
236+
}
237+
238+
@Override
239+
protected Flux<ChatResponse> doChatCompletionStream(BedrockConverseRequest request) {
240+
throw new UnsupportedOperationException("Streaming function calling is not supported.");
241+
}
242+
243+
@Override
244+
protected boolean isToolFunctionCall(ChatResponse response) {
245+
Generation result = response.getResult();
246+
if (result == null) {
247+
return false;
248+
}
249+
250+
return StopReason.fromValue(result.getMetadata().getFinishReason()) == StopReason.TOOL_USE;
251+
}
252+
253+
/**
254+
* Cohere command R models version.
255+
*/
256+
public enum CohereCommandRChatModel {
257+
258+
/**
259+
* cohere.command-r-v1:0
260+
*/
261+
COHERE_COMMAND_R_V1("cohere.command-r-v1:0"),
262+
263+
/**
264+
* cohere.command-r-plus-v1:0
265+
*/
266+
COHERE_COMMAND_R_PLUS_V1("cohere.command-r-plus-v1:0");
267+
268+
private final String id;
269+
270+
/**
271+
* @return The model id.
272+
*/
273+
public String id() {
274+
return id;
275+
}
276+
277+
CohereCommandRChatModel(String value) {
278+
this.id = value;
279+
}
280+
281+
}
282+
283+
}

0 commit comments

Comments
 (0)