Skip to content

Commit 2c8968d

Browse files
author
wmz7year
committed
Add Amazon Bedrock Mistral model support.
1 parent 216df35 commit 2c8968d

File tree

20 files changed

+1794
-3
lines changed

20 files changed

+1794
-3
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ You can find more details in the [Reference Documentation](https://docs.spring.i
8888
Spring AI supports many AI models. For an overview see here. Specific models currently supported are
8989
* OpenAI
9090
* Azure OpenAI
91-
* Amazon Bedrock (Anthropic, Llama, Cohere, Titan, Jurassic2)
91+
* Amazon Bedrock (Anthropic, Llama, Cohere, Titan, Jurassic2, Mistral)
9292
* HuggingFace
9393
* Google VertexAI (PaLM2, Gemini)
9494
* Mistral AI

models/spring-ai-bedrock/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@
88
- [Titan Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-titan.html)
99
- [Titan Embedding Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/embeddings/bedrock-titan-embedding.html)
1010
- [Jurassic2 Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-jurassic2.html)
11+
- [Mistral Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-mistral.html)
1112

Lines changed: 297 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,297 @@
1+
/*
2+
* Copyright 2023 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.bedrock.mistral;
17+
18+
import java.util.ArrayList;
19+
import java.util.HashSet;
20+
import java.util.List;
21+
import java.util.Set;
22+
23+
import org.springframework.ai.bedrock.BedrockConverseChatGenerationMetadata;
24+
import org.springframework.ai.bedrock.api.BedrockConverseApi;
25+
import org.springframework.ai.bedrock.api.BedrockConverseApiUtils;
26+
import org.springframework.ai.bedrock.api.BedrockConverseApi.BedrockConverseRequest;
27+
import org.springframework.ai.chat.model.ChatModel;
28+
import org.springframework.ai.chat.model.ChatResponse;
29+
import org.springframework.ai.chat.model.Generation;
30+
import org.springframework.ai.chat.model.StreamingChatModel;
31+
import org.springframework.ai.chat.prompt.ChatOptions;
32+
import org.springframework.ai.chat.prompt.Prompt;
33+
import org.springframework.ai.model.ModelDescription;
34+
import org.springframework.ai.model.ModelOptionsUtils;
35+
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
36+
import org.springframework.ai.model.function.FunctionCallbackContext;
37+
import org.springframework.util.Assert;
38+
import org.springframework.util.CollectionUtils;
39+
40+
import reactor.core.publisher.Flux;
41+
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock;
42+
import software.amazon.awssdk.services.bedrockruntime.model.ConversationRole;
43+
import software.amazon.awssdk.services.bedrockruntime.model.Message;
44+
import software.amazon.awssdk.services.bedrockruntime.model.StopReason;
45+
import software.amazon.awssdk.services.bedrockruntime.model.Tool;
46+
import software.amazon.awssdk.services.bedrockruntime.model.ToolConfiguration;
47+
import software.amazon.awssdk.services.bedrockruntime.model.ToolInputSchema;
48+
import software.amazon.awssdk.services.bedrockruntime.model.ToolResultBlock;
49+
import software.amazon.awssdk.services.bedrockruntime.model.ToolResultContentBlock;
50+
import software.amazon.awssdk.services.bedrockruntime.model.ToolResultStatus;
51+
import software.amazon.awssdk.services.bedrockruntime.model.ToolSpecification;
52+
import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlock;
53+
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock.Type;
54+
55+
/**
56+
* Java {@link ChatModel} and {@link StreamingChatModel} for the Bedrock Mistral chat
57+
* generative model.
58+
*
59+
* @author Wei Jiang
60+
* @since 1.0.0
61+
*/
62+
public class BedrockMistralChatModel extends AbstractFunctionCallSupport<Message, BedrockConverseRequest, ChatResponse>
63+
implements ChatModel, StreamingChatModel {
64+
65+
private final String modelId;
66+
67+
private final BedrockConverseApi converseApi;
68+
69+
private final BedrockMistralChatOptions defaultOptions;
70+
71+
public BedrockMistralChatModel(BedrockConverseApi converseApi) {
72+
this(converseApi, BedrockMistralChatOptions.builder().build());
73+
}
74+
75+
public BedrockMistralChatModel(BedrockConverseApi converseApi, BedrockMistralChatOptions options) {
76+
this(MistralChatModel.MISTRAL_LARGE.id(), converseApi, options);
77+
}
78+
79+
public BedrockMistralChatModel(String modelId, BedrockConverseApi converseApi, BedrockMistralChatOptions options) {
80+
this(modelId, converseApi, options, null);
81+
}
82+
83+
public BedrockMistralChatModel(String modelId, BedrockConverseApi converseApi, BedrockMistralChatOptions options,
84+
FunctionCallbackContext functionCallbackContext) {
85+
super(functionCallbackContext);
86+
87+
Assert.notNull(modelId, "modelId must not be null.");
88+
Assert.notNull(converseApi, "BedrockConverseApi must not be null.");
89+
Assert.notNull(options, "BedrockMistralChatOptions must not be null.");
90+
91+
this.modelId = modelId;
92+
this.converseApi = converseApi;
93+
this.defaultOptions = options;
94+
}
95+
96+
@Override
97+
public ChatResponse call(Prompt prompt) {
98+
Assert.notNull(prompt, "Prompt must not be null.");
99+
100+
var request = createBedrockConverseRequest(prompt);
101+
102+
return this.callWithFunctionSupport(request);
103+
}
104+
105+
@Override
106+
public Flux<ChatResponse> stream(Prompt prompt) {
107+
Assert.notNull(prompt, "Prompt must not be null.");
108+
109+
var request = createBedrockConverseRequest(prompt);
110+
111+
return converseApi.converseStream(request);
112+
}
113+
114+
private BedrockConverseRequest createBedrockConverseRequest(Prompt prompt) {
115+
var request = BedrockConverseApiUtils.createBedrockConverseRequest(modelId, prompt, defaultOptions);
116+
117+
ToolConfiguration toolConfiguration = createToolConfiguration(prompt);
118+
119+
return BedrockConverseRequest.from(request).withToolConfiguration(toolConfiguration).build();
120+
}
121+
122+
private ToolConfiguration createToolConfiguration(Prompt prompt) {
123+
Set<String> functionsForThisRequest = new HashSet<>();
124+
125+
if (this.defaultOptions != null) {
126+
Set<String> promptEnabledFunctions = this.handleFunctionCallbackConfigurations(this.defaultOptions,
127+
!IS_RUNTIME_CALL);
128+
functionsForThisRequest.addAll(promptEnabledFunctions);
129+
}
130+
131+
if (prompt.getOptions() != null) {
132+
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
133+
BedrockMistralChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
134+
ChatOptions.class, BedrockMistralChatOptions.class);
135+
136+
Set<String> defaultEnabledFunctions = this.handleFunctionCallbackConfigurations(updatedRuntimeOptions,
137+
IS_RUNTIME_CALL);
138+
functionsForThisRequest.addAll(defaultEnabledFunctions);
139+
}
140+
else {
141+
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
142+
+ prompt.getOptions().getClass().getSimpleName());
143+
}
144+
}
145+
146+
if (!CollectionUtils.isEmpty(functionsForThisRequest)) {
147+
return ToolConfiguration.builder().tools(getFunctionTools(functionsForThisRequest)).build();
148+
}
149+
150+
return null;
151+
}
152+
153+
private List<Tool> getFunctionTools(Set<String> functionNames) {
154+
return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> {
155+
var description = functionCallback.getDescription();
156+
var name = functionCallback.getName();
157+
String inputSchema = functionCallback.getInputTypeSchema();
158+
159+
return Tool.builder()
160+
.toolSpec(ToolSpecification.builder()
161+
.name(name)
162+
.description(description)
163+
.inputSchema(ToolInputSchema.builder()
164+
.json(BedrockConverseApiUtils.convertObjectToDocument(ModelOptionsUtils.jsonToMap(inputSchema)))
165+
.build())
166+
.build())
167+
.build();
168+
}).toList();
169+
}
170+
171+
@Override
172+
public ChatOptions getDefaultOptions() {
173+
return defaultOptions;
174+
}
175+
176+
@Override
177+
protected BedrockConverseRequest doCreateToolResponseRequest(BedrockConverseRequest previousRequest,
178+
Message responseMessage, List<Message> conversationHistory) {
179+
List<ToolUseBlock> toolToUseList = responseMessage.content()
180+
.stream()
181+
.filter(content -> content.type() == Type.TOOL_USE)
182+
.map(content -> content.toolUse())
183+
.toList();
184+
185+
List<ToolResultBlock> toolResults = new ArrayList<>();
186+
187+
for (ToolUseBlock toolToUse : toolToUseList) {
188+
var functionCallId = toolToUse.toolUseId();
189+
var functionName = toolToUse.name();
190+
var functionArguments = toolToUse.input().unwrap();
191+
192+
if (!this.functionCallbackRegister.containsKey(functionName)) {
193+
throw new IllegalStateException("No function callback found for function name: " + functionName);
194+
}
195+
196+
String functionResponse = this.functionCallbackRegister.get(functionName)
197+
.call(ModelOptionsUtils.toJsonString(functionArguments));
198+
199+
toolResults.add(ToolResultBlock.builder()
200+
.toolUseId(functionCallId)
201+
.status(ToolResultStatus.SUCCESS)
202+
.content(ToolResultContentBlock.builder().text(functionResponse).build())
203+
.build());
204+
}
205+
206+
// Add the function response to the conversation.
207+
Message toolResultMessage = Message.builder()
208+
.content(toolResults.stream().map(toolResult -> ContentBlock.fromToolResult(toolResult)).toList())
209+
.role(ConversationRole.USER)
210+
.build();
211+
conversationHistory.add(toolResultMessage);
212+
213+
// Recursively call chatCompletionWithTools until the model doesn't call a
214+
// functions anymore.
215+
return BedrockConverseRequest.from(previousRequest).withMessages(conversationHistory).build();
216+
}
217+
218+
@Override
219+
protected List<Message> doGetUserMessages(BedrockConverseRequest request) {
220+
return request.messages();
221+
}
222+
223+
@Override
224+
protected Message doGetToolResponseMessage(ChatResponse response) {
225+
Generation result = response.getResult();
226+
227+
var metadata = (BedrockConverseChatGenerationMetadata) result.getMetadata();
228+
229+
return metadata.getMessage();
230+
}
231+
232+
@Override
233+
protected ChatResponse doChatCompletion(BedrockConverseRequest request) {
234+
return converseApi.converse(request);
235+
}
236+
237+
@Override
238+
protected Flux<ChatResponse> doChatCompletionStream(BedrockConverseRequest request) {
239+
throw new UnsupportedOperationException("Streaming function calling is not supported.");
240+
}
241+
242+
@Override
243+
protected boolean isToolFunctionCall(ChatResponse response) {
244+
Generation result = response.getResult();
245+
if (result == null) {
246+
return false;
247+
}
248+
249+
return StopReason.fromValue(result.getMetadata().getFinishReason()) == StopReason.TOOL_USE;
250+
}
251+
252+
/**
253+
* Mistral models version.
254+
*/
255+
public enum MistralChatModel implements ModelDescription {
256+
257+
/**
258+
* mistral.mistral-7b-instruct-v0:2
259+
*/
260+
MISTRAL_7B_INSTRUCT("mistral.mistral-7b-instruct-v0:2"),
261+
262+
/**
263+
* mistral.mixtral-8x7b-instruct-v0:1
264+
*/
265+
MISTRAL_8X7B_INSTRUCT("mistral.mixtral-8x7b-instruct-v0:1"),
266+
267+
/**
268+
* mistral.mistral-large-2402-v1:0
269+
*/
270+
MISTRAL_LARGE("mistral.mistral-large-2402-v1:0"),
271+
272+
/**
273+
* mistral.mistral-small-2402-v1:0
274+
*/
275+
MISTRAL_SMALL("mistral.mistral-small-2402-v1:0");
276+
277+
private final String id;
278+
279+
/**
280+
* @return The model id.
281+
*/
282+
public String id() {
283+
return id;
284+
}
285+
286+
MistralChatModel(String value) {
287+
this.id = value;
288+
}
289+
290+
@Override
291+
public String getModelName() {
292+
return this.id;
293+
}
294+
295+
}
296+
297+
}

0 commit comments

Comments
 (0)