Skip to content

Commit e557081

Browse files
rodrigomalarailayaperumalg
authored andcommitted
[GH-3723] Vertex AI Gemini logprobs support
Signed-off-by: Rodrigo Malara <[email protected]>
1 parent 656395e commit e557081

File tree

5 files changed

+129
-4
lines changed

5 files changed

+129
-4
lines changed

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585
import org.springframework.ai.retry.RetryUtils;
8686
import org.springframework.ai.support.UsageCalculator;
8787
import org.springframework.ai.tool.definition.ToolDefinition;
88+
import org.springframework.ai.vertexai.gemini.api.VertexAiGeminiApi;
8889
import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiConstants;
8990
import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiSafetySetting;
9091
import org.springframework.ai.vertexai.gemini.schema.VertexToolCallingManager;
@@ -587,8 +588,28 @@ protected List<Generation> responseCandidateToGeneration(Candidate candidate) {
587588
int candidateIndex = candidate.getIndex();
588589
FinishReason candidateFinishReason = candidate.getFinishReason();
589590

591+
// Convert from VertexAI protobuf to VertexAiGeminiApi DTOs
592+
List<VertexAiGeminiApi.LogProbs.TopContent> topCandidates = candidate.getLogprobsResult()
593+
.getTopCandidatesList()
594+
.stream()
595+
.filter(topCandidate -> !topCandidate.getCandidatesList().isEmpty())
596+
.map(topCandidate -> new VertexAiGeminiApi.LogProbs.TopContent(topCandidate.getCandidatesList()
597+
.stream()
598+
.map(c -> new VertexAiGeminiApi.LogProbs.Content(c.getToken(), c.getLogProbability(), c.getTokenId()))
599+
.toList()))
600+
.toList();
601+
602+
List<VertexAiGeminiApi.LogProbs.Content> chosenCandidates = candidate.getLogprobsResult()
603+
.getChosenCandidatesList()
604+
.stream()
605+
.map(c -> new VertexAiGeminiApi.LogProbs.Content(c.getToken(), c.getLogProbability(), c.getTokenId()))
606+
.toList();
607+
608+
VertexAiGeminiApi.LogProbs logprobs = new VertexAiGeminiApi.LogProbs(candidate.getAvgLogprobs(), topCandidates,
609+
chosenCandidates);
610+
590611
Map<String, Object> messageMetadata = Map.of("candidateIndex", candidateIndex, "finishReason",
591-
candidateFinishReason);
612+
candidateFinishReason, "logprobs", logprobs);
592613

593614
ChatGenerationMetadata chatGenerationMetadata = ChatGenerationMetadata.builder()
594615
.finishReason(candidateFinishReason.name())
@@ -744,6 +765,10 @@ private GenerationConfig toGenerationConfig(VertexAiGeminiChatOptions options) {
744765
if (options.getPresencePenalty() != null) {
745766
generationConfigBuilder.setPresencePenalty(options.getPresencePenalty().floatValue());
746767
}
768+
if (options.getLogprobs() != null) {
769+
generationConfigBuilder.setLogprobs(options.getLogprobs());
770+
}
771+
generationConfigBuilder.setResponseLogprobs(options.getResponseLogprobs());
747772

748773
return generationConfigBuilder.build();
749774
}

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,20 @@ public class VertexAiGeminiChatOptions implements ToolCallingChatOptions {
6464
*/
6565
private @JsonProperty("temperature") Double temperature;
6666

67+
/**
68+
* Optional. Enable returning the log probabilities of the top candidate tokens at each generation step.
69+
* The model's chosen token might not be the same as the top candidate token at each step.
70+
* Specify the number of candidates to return by using an integer value in the range of 1-20.
71+
* Should not be set unless responseLogprobs is set to true.
72+
*/
73+
private @JsonProperty("logprobs") Integer logprobs;
74+
75+
/**
76+
* Optional. If true, returns the log probabilities of the tokens that were chosen by the model at each step.
77+
* By default, this parameter is set to false.
78+
*/
79+
private @JsonProperty("responseLogprobs") boolean responseLogprobs;
80+
6781
/**
6882
* Optional. If specified, nucleus sampling will be used.
6983
*/
@@ -162,6 +176,8 @@ public static VertexAiGeminiChatOptions fromOptions(VertexAiGeminiChatOptions fr
162176
options.setSafetySettings(fromOptions.getSafetySettings());
163177
options.setInternalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled());
164178
options.setToolContext(fromOptions.getToolContext());
179+
options.setLogprobs(fromOptions.getLogprobs());
180+
options.setResponseLogprobs(fromOptions.getResponseLogprobs());
165181
return options;
166182
}
167183

@@ -183,6 +199,10 @@ public void setTemperature(Double temperature) {
183199
this.temperature = temperature;
184200
}
185201

202+
public void setResponseLogprobs(boolean responseLogprobs) {
203+
this.responseLogprobs = responseLogprobs;
204+
}
205+
186206
@Override
187207
public Double getTopP() {
188208
return this.topP;
@@ -326,6 +346,18 @@ public void setToolContext(Map<String, Object> toolContext) {
326346
this.toolContext = toolContext;
327347
}
328348

349+
public Integer getLogprobs() {
350+
return logprobs;
351+
}
352+
353+
public void setLogprobs(Integer logprobs) {
354+
this.logprobs = logprobs;
355+
}
356+
357+
public boolean getResponseLogprobs() {
358+
return responseLogprobs;
359+
}
360+
329361
@Override
330362
public boolean equals(Object o) {
331363
if (this == o) {
@@ -346,15 +378,16 @@ public boolean equals(Object o) {
346378
&& Objects.equals(this.toolNames, that.toolNames)
347379
&& Objects.equals(this.safetySettings, that.safetySettings)
348380
&& Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled)
349-
&& Objects.equals(this.toolContext, that.toolContext);
381+
&& Objects.equals(this.toolContext, that.toolContext) && Objects.equals(this.logprobs, that.logprobs)
382+
&& Objects.equals(this.responseLogprobs, that.responseLogprobs);
350383
}
351384

352385
@Override
353386
public int hashCode() {
354387
return Objects.hash(this.stopSequences, this.temperature, this.topP, this.topK, this.candidateCount,
355388
this.frequencyPenalty, this.presencePenalty, this.maxOutputTokens, this.model, this.responseMimeType,
356389
this.toolCallbacks, this.toolNames, this.googleSearchRetrieval, this.safetySettings,
357-
this.internalToolExecutionEnabled, this.toolContext);
390+
this.internalToolExecutionEnabled, this.toolContext, this.logprobs, this.responseLogprobs);
358391
}
359392

360393
@Override
@@ -365,7 +398,8 @@ public String toString() {
365398
+ this.candidateCount + ", maxOutputTokens=" + this.maxOutputTokens + ", model='" + this.model + '\''
366399
+ ", responseMimeType='" + this.responseMimeType + '\'' + ", toolCallbacks=" + this.toolCallbacks
367400
+ ", toolNames=" + this.toolNames + ", googleSearchRetrieval=" + this.googleSearchRetrieval
368-
+ ", safetySettings=" + this.safetySettings + '}';
401+
+ ", safetySettings=" + this.safetySettings + ", logProbs=" + this.logprobs + ", responseLogprobs="
402+
+ this.responseLogprobs + '}';
369403
}
370404

371405
@Override
@@ -488,6 +522,16 @@ public Builder toolContext(Map<String, Object> toolContext) {
488522
return this;
489523
}
490524

525+
public Builder logprobs(Integer logprobs) {
526+
this.options.setLogprobs(logprobs);
527+
return this;
528+
}
529+
530+
public Builder responseLogprobs(Boolean responseLogprobs) {
531+
this.options.setResponseLogprobs(responseLogprobs);
532+
return this;
533+
}
534+
491535
public VertexAiGeminiChatOptions build() {
492536
return this.options;
493537
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.vertexai.gemini.api;
17+
18+
import java.util.List;
19+
20+
public class VertexAiGeminiApi {
21+
22+
public record LogProbs(Double avgLogprobs, List<TopContent> topCandidates,
23+
List<LogProbs.Content> chosenCandidates) {
24+
public record Content(String token, Float logprob, Integer id) {
25+
}
26+
27+
public record TopContent(List<Content> candidates) {
28+
}
29+
}
30+
31+
}

models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,8 @@ public void createRequestWithGenerationConfigOptions() {
262262
.stopSequences(List.of("stop1", "stop2"))
263263
.candidateCount(1)
264264
.responseMimeType("application/json")
265+
.responseLogprobs(true)
266+
.logprobs(2)
265267
.build())
266268
.build();
267269

@@ -280,6 +282,8 @@ public void createRequestWithGenerationConfigOptions() {
280282
assertThat(request.model().getGenerationConfig().getStopSequences(0)).isEqualTo("stop1");
281283
assertThat(request.model().getGenerationConfig().getStopSequences(1)).isEqualTo("stop2");
282284
assertThat(request.model().getGenerationConfig().getResponseMimeType()).isEqualTo("application/json");
285+
assertThat(request.model().getGenerationConfig().getLogprobs()).isEqualTo(2);
286+
assertThat(request.model().getGenerationConfig().getResponseLogprobs()).isEqualTo(true);
283287
}
284288

285289
}

models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModelIT.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
import org.springframework.ai.model.tool.ToolCallingManager;
4747
import org.springframework.ai.tool.annotation.Tool;
4848
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel.ChatModel;
49+
import org.springframework.ai.vertexai.gemini.api.VertexAiGeminiApi;
4950
import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiSafetySetting;
5051
import org.springframework.beans.factory.annotation.Autowired;
5152
import org.springframework.beans.factory.annotation.Value;
@@ -226,6 +227,26 @@ void textStream() {
226227
assertThat(generationTextFromStream).isNotEmpty();
227228
}
228229

230+
@Test
231+
void logprobs() {
232+
VertexAiGeminiChatOptions chatOptions = VertexAiGeminiChatOptions.builder()
233+
.logprobs(1)
234+
.responseLogprobs(true)
235+
.build();
236+
237+
var logprobs = (VertexAiGeminiApi.LogProbs) this.chatModel
238+
.call(new Prompt("Explain Bulgaria? Answer in 10 paragraphs.", chatOptions))
239+
.getResult()
240+
.getOutput()
241+
.getMetadata()
242+
.get("logprobs");
243+
244+
assertThat(logprobs).isNotNull();
245+
assertThat(logprobs.avgLogprobs()).isNotZero();
246+
assertThat(logprobs.topCandidates()).isNotEmpty();
247+
assertThat(logprobs.chosenCandidates()).isNotEmpty();
248+
}
249+
229250
@Test
230251
void beanStreamOutputConverterRecords() {
231252

0 commit comments

Comments
 (0)