28
28
import com .google .cloud .vertexai .api .Part ;
29
29
import com .google .cloud .vertexai .api .Schema ;
30
30
import com .google .cloud .vertexai .api .Tool ;
31
+ import com .google .cloud .vertexai .generativeai .ContentMaker ;
31
32
import com .google .cloud .vertexai .generativeai .GenerativeModel ;
32
33
import com .google .cloud .vertexai .generativeai .PartMaker ;
33
34
import com .google .cloud .vertexai .generativeai .ResponseStream ;
@@ -258,6 +259,16 @@ private GeminiRequest createGeminiRequest(Prompt prompt) {
258
259
259
260
GenerativeModel generativeModel = generativeModelBuilder .build ();
260
261
262
+ String systemContext = prompt .getInstructions ()
263
+ .stream ()
264
+ .filter (m -> m .getMessageType () == MessageType .SYSTEM )
265
+ .map (m -> m .getContent ())
266
+ .collect (Collectors .joining (System .lineSeparator ()));
267
+
268
+ if (StringUtils .hasText (systemContext )) {
269
+ generativeModel .withSystemInstruction (ContentMaker .fromString (systemContext ));
270
+ }
271
+
261
272
return new GeminiRequest (toGeminiContent (prompt ), generativeModel );
262
273
}
263
274
@@ -289,18 +300,12 @@ private GenerationConfig toGenerationConfig(VertexAiGeminiChatOptions options) {
289
300
290
301
private List <Content > toGeminiContent (Prompt prompt ) {
291
302
292
- String systemContext = prompt .getInstructions ()
293
- .stream ()
294
- .filter (m -> m .getMessageType () == MessageType .SYSTEM )
295
- .map (m -> m .getContent ())
296
- .collect (Collectors .joining (System .lineSeparator ()));
297
-
298
303
List <Content > contents = prompt .getInstructions ()
299
304
.stream ()
300
305
.filter (m -> m .getMessageType () == MessageType .USER || m .getMessageType () == MessageType .ASSISTANT )
301
306
.map (message -> Content .newBuilder ()
302
307
.setRole (toGeminiMessageType (message .getMessageType ()).getValue ())
303
- .addAllParts (messageToGeminiParts (message , systemContext ))
308
+ .addAllParts (messageToGeminiParts (message ))
304
309
.build ())
305
310
.toList ();
306
311
@@ -321,14 +326,11 @@ private static GeminiMessageType toGeminiMessageType(@NonNull MessageType type)
321
326
}
322
327
}
323
328
324
- static List <Part > messageToGeminiParts (Message message , String systemContext ) {
329
+ static List <Part > messageToGeminiParts (Message message ) {
325
330
326
331
if (message instanceof UserMessage userMessage ) {
327
332
328
333
String messageTextContent = (userMessage .getContent () == null ) ? "null" : userMessage .getContent ();
329
- if (StringUtils .hasText (systemContext )) {
330
- messageTextContent = systemContext + "\n \n " + messageTextContent ;
331
- }
332
334
Part textPart = Part .newBuilder ().setText (messageTextContent ).build ();
333
335
334
336
List <Part > parts = new ArrayList <>(List .of (textPart ));
0 commit comments