Skip to content

Commit 400d0d1

Browse files
committed
Add VertexAI Gemini system-instruction support
Resolves #906
1 parent 75d6866 commit 400d0d1

File tree

2 files changed

+13
-12
lines changed

2 files changed

+13
-12
lines changed

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import com.google.cloud.vertexai.api.Part;
2929
import com.google.cloud.vertexai.api.Schema;
3030
import com.google.cloud.vertexai.api.Tool;
31+
import com.google.cloud.vertexai.generativeai.ContentMaker;
3132
import com.google.cloud.vertexai.generativeai.GenerativeModel;
3233
import com.google.cloud.vertexai.generativeai.PartMaker;
3334
import com.google.cloud.vertexai.generativeai.ResponseStream;
@@ -258,6 +259,16 @@ private GeminiRequest createGeminiRequest(Prompt prompt) {
258259

259260
GenerativeModel generativeModel = generativeModelBuilder.build();
260261

262+
String systemContext = prompt.getInstructions()
263+
.stream()
264+
.filter(m -> m.getMessageType() == MessageType.SYSTEM)
265+
.map(m -> m.getContent())
266+
.collect(Collectors.joining(System.lineSeparator()));
267+
268+
if (StringUtils.hasText(systemContext)) {
269+
generativeModel.withSystemInstruction(ContentMaker.fromString(systemContext));
270+
}
271+
261272
return new GeminiRequest(toGeminiContent(prompt), generativeModel);
262273
}
263274

@@ -289,18 +300,12 @@ private GenerationConfig toGenerationConfig(VertexAiGeminiChatOptions options) {
289300

290301
private List<Content> toGeminiContent(Prompt prompt) {
291302

292-
String systemContext = prompt.getInstructions()
293-
.stream()
294-
.filter(m -> m.getMessageType() == MessageType.SYSTEM)
295-
.map(m -> m.getContent())
296-
.collect(Collectors.joining(System.lineSeparator()));
297-
298303
List<Content> contents = prompt.getInstructions()
299304
.stream()
300305
.filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT)
301306
.map(message -> Content.newBuilder()
302307
.setRole(toGeminiMessageType(message.getMessageType()).getValue())
303-
.addAllParts(messageToGeminiParts(message, systemContext))
308+
.addAllParts(messageToGeminiParts(message))
304309
.build())
305310
.toList();
306311

@@ -321,14 +326,11 @@ private static GeminiMessageType toGeminiMessageType(@NonNull MessageType type)
321326
}
322327
}
323328

324-
static List<Part> messageToGeminiParts(Message message, String systemContext) {
329+
static List<Part> messageToGeminiParts(Message message) {
325330

326331
if (message instanceof UserMessage userMessage) {
327332

328333
String messageTextContent = (userMessage.getContent() == null) ? "null" : userMessage.getContent();
329-
if (StringUtils.hasText(systemContext)) {
330-
messageTextContent = systemContext + "\n\n" + messageTextContent;
331-
}
332334
Part textPart = Part.newBuilder().setText(messageTextContent).build();
333335

334336
List<Part> parts = new ArrayList<>(List.of(textPart));

models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiChatModelFunctionCallingIT.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
import org.slf4j.LoggerFactory;
2929
import org.springframework.ai.chat.messages.AssistantMessage;
3030
import org.springframework.ai.chat.messages.Message;
31-
import org.springframework.ai.chat.messages.SystemMessage;
3231
import org.springframework.ai.chat.messages.UserMessage;
3332
import org.springframework.ai.chat.model.ChatResponse;
3433
import org.springframework.ai.chat.model.Generation;

0 commit comments

Comments
 (0)