Skip to content

Commit d67a3cd

Browse files
alxkmmarkpollack
authored andcommitted
Enhance test coverage across Spring AI modules with comprehensive edge cases
Co-authored-by: Oleksandr Klymenko <[email protected]> Signed-off-by: Oleksandr Klymenko <[email protected]> Auto-cherry-pick to 1.0.x Fixes #4197
1 parent 35486e9 commit d67a3cd

File tree

8 files changed

+710
-0
lines changed

8 files changed

+710
-0
lines changed

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelTests.java

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.springframework.ai.ollama.api.OllamaOptions;
3838

3939
import static org.assertj.core.api.Assertions.assertThat;
40+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
4041
import static org.mockito.BDDMockito.given;
4142

4243
/**
@@ -115,4 +116,143 @@ public void options() {
115116

116117
}
117118

119+
@Test
120+
public void singleInputEmbedding() {
121+
given(this.ollamaApi.embed(this.embeddingsRequestCaptor.capture()))
122+
.willReturn(new EmbeddingsResponse("TEST_MODEL", List.of(new float[] { 0.1f, 0.2f, 0.3f }), 10L, 5L, 1));
123+
124+
var embeddingModel = OllamaEmbeddingModel.builder()
125+
.ollamaApi(this.ollamaApi)
126+
.defaultOptions(OllamaOptions.builder().model("TEST_MODEL").build())
127+
.build();
128+
129+
EmbeddingResponse response = embeddingModel
130+
.call(new EmbeddingRequest(List.of("Single input text"), EmbeddingOptionsBuilder.builder().build()));
131+
132+
assertThat(response.getResults()).hasSize(1);
133+
assertThat(response.getResults().get(0).getIndex()).isEqualTo(0);
134+
assertThat(response.getResults().get(0).getOutput()).isEqualTo(new float[] { 0.1f, 0.2f, 0.3f });
135+
assertThat(response.getMetadata().getModel()).isEqualTo("TEST_MODEL");
136+
137+
assertThat(this.embeddingsRequestCaptor.getValue().input()).isEqualTo(List.of("Single input text"));
138+
assertThat(this.embeddingsRequestCaptor.getValue().model()).isEqualTo("TEST_MODEL");
139+
}
140+
141+
@Test
142+
public void embeddingWithNullOptions() {
143+
given(this.ollamaApi.embed(this.embeddingsRequestCaptor.capture()))
144+
.willReturn(new EmbeddingsResponse("NULL_OPTIONS_MODEL", List.of(new float[] { 0.5f }), 5L, 2L, 1));
145+
146+
var embeddingModel = OllamaEmbeddingModel.builder()
147+
.ollamaApi(this.ollamaApi)
148+
.defaultOptions(OllamaOptions.builder().model("NULL_OPTIONS_MODEL").build())
149+
.build();
150+
151+
EmbeddingResponse response = embeddingModel.call(new EmbeddingRequest(List.of("Null options test"), null));
152+
153+
assertThat(response.getResults()).hasSize(1);
154+
assertThat(response.getMetadata().getModel()).isEqualTo("NULL_OPTIONS_MODEL");
155+
156+
assertThat(this.embeddingsRequestCaptor.getValue().model()).isEqualTo("NULL_OPTIONS_MODEL");
157+
assertThat(this.embeddingsRequestCaptor.getValue().options()).isEqualTo(Map.of());
158+
}
159+
160+
@Test
161+
public void embeddingWithMultipleLargeInputs() {
162+
List<String> largeInputs = List.of(
163+
"This is a very long text input that might be used for document embedding scenarios",
164+
"Another substantial piece of text content that could represent a paragraph or section",
165+
"A third lengthy input to test batch processing capabilities of the embedding model");
166+
167+
given(this.ollamaApi.embed(this.embeddingsRequestCaptor.capture()))
168+
.willReturn(new EmbeddingsResponse(
169+
"BATCH_MODEL", List.of(new float[] { 0.1f, 0.2f, 0.3f, 0.4f },
170+
new float[] { 0.5f, 0.6f, 0.7f, 0.8f }, new float[] { 0.9f, 1.0f, 1.1f, 1.2f }),
171+
150L, 75L, 3));
172+
173+
var embeddingModel = OllamaEmbeddingModel.builder()
174+
.ollamaApi(this.ollamaApi)
175+
.defaultOptions(OllamaOptions.builder().model("BATCH_MODEL").build())
176+
.build();
177+
178+
EmbeddingResponse response = embeddingModel
179+
.call(new EmbeddingRequest(largeInputs, EmbeddingOptionsBuilder.builder().build()));
180+
181+
assertThat(response.getResults()).hasSize(3);
182+
assertThat(response.getResults().get(0).getOutput()).hasSize(4);
183+
assertThat(response.getResults().get(1).getOutput()).hasSize(4);
184+
assertThat(response.getResults().get(2).getOutput()).hasSize(4);
185+
186+
assertThat(this.embeddingsRequestCaptor.getValue().input()).isEqualTo(largeInputs);
187+
}
188+
189+
@Test
190+
public void embeddingWithCustomKeepAliveFormats() {
191+
given(this.ollamaApi.embed(this.embeddingsRequestCaptor.capture()))
192+
.willReturn(new EmbeddingsResponse("KEEPALIVE_MODEL", List.of(new float[] { 1.0f }), 5L, 2L, 1));
193+
194+
var embeddingModel = OllamaEmbeddingModel.builder()
195+
.ollamaApi(this.ollamaApi)
196+
.defaultOptions(OllamaOptions.builder().model("KEEPALIVE_MODEL").build())
197+
.build();
198+
199+
// Test with seconds format
200+
var secondsOptions = OllamaOptions.builder().model("KEEPALIVE_MODEL").keepAlive("300s").build();
201+
202+
embeddingModel.call(new EmbeddingRequest(List.of("Keep alive seconds"), secondsOptions));
203+
assertThat(this.embeddingsRequestCaptor.getValue().keepAlive()).isEqualTo(Duration.ofSeconds(300));
204+
205+
// Test with hours format
206+
var hoursOptions = OllamaOptions.builder().model("KEEPALIVE_MODEL").keepAlive("2h").build();
207+
208+
embeddingModel.call(new EmbeddingRequest(List.of("Keep alive hours"), hoursOptions));
209+
assertThat(this.embeddingsRequestCaptor.getValue().keepAlive()).isEqualTo(Duration.ofHours(2));
210+
}
211+
212+
@Test
213+
public void embeddingResponseMetadata() {
214+
given(this.ollamaApi.embed(this.embeddingsRequestCaptor.capture()))
215+
.willReturn(new EmbeddingsResponse("METADATA_MODEL", List.of(new float[] { 0.1f, 0.2f }), 100L, 50L, 25));
216+
217+
var embeddingModel = OllamaEmbeddingModel.builder()
218+
.ollamaApi(this.ollamaApi)
219+
.defaultOptions(OllamaOptions.builder().model("METADATA_MODEL").build())
220+
.build();
221+
222+
EmbeddingResponse response = embeddingModel
223+
.call(new EmbeddingRequest(List.of("Metadata test"), EmbeddingOptionsBuilder.builder().build()));
224+
225+
assertThat(response.getMetadata().getModel()).isEqualTo("METADATA_MODEL");
226+
assertThat(response.getResults()).hasSize(1);
227+
assertThat(response.getResults().get(0).getMetadata()).isEqualTo(EmbeddingResultMetadata.EMPTY);
228+
}
229+
230+
@Test
231+
public void embeddingWithZeroLengthVectors() {
232+
given(this.ollamaApi.embed(this.embeddingsRequestCaptor.capture()))
233+
.willReturn(new EmbeddingsResponse("ZERO_MODEL", List.of(new float[] {}), 0L, 0L, 1));
234+
235+
var embeddingModel = OllamaEmbeddingModel.builder()
236+
.ollamaApi(this.ollamaApi)
237+
.defaultOptions(OllamaOptions.builder().model("ZERO_MODEL").build())
238+
.build();
239+
240+
EmbeddingResponse response = embeddingModel
241+
.call(new EmbeddingRequest(List.of("Zero length test"), EmbeddingOptionsBuilder.builder().build()));
242+
243+
assertThat(response.getResults()).hasSize(1);
244+
assertThat(response.getResults().get(0).getOutput()).isEmpty();
245+
}
246+
247+
@Test
248+
public void builderValidation() {
249+
// Test that builder requires ollamaApi
250+
assertThatThrownBy(() -> OllamaEmbeddingModel.builder().build()).isInstanceOf(IllegalArgumentException.class);
251+
252+
// Test successful builder with minimal required parameters
253+
var model = OllamaEmbeddingModel.builder().ollamaApi(this.ollamaApi).build();
254+
255+
assertThat(model).isNotNull();
256+
}
257+
118258
}

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaRetryTests.java

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,34 @@
11
package org.springframework.ai.ollama;
22

33
import java.time.Instant;
4+
import java.util.List;
45

56
import org.junit.jupiter.api.BeforeEach;
67
import org.junit.jupiter.api.Test;
78
import org.junit.jupiter.api.extension.ExtendWith;
89
import org.mockito.Mock;
910
import org.mockito.junit.jupiter.MockitoExtension;
1011

12+
import org.springframework.ai.chat.messages.Message;
13+
import org.springframework.ai.chat.messages.UserMessage;
1114
import org.springframework.ai.chat.prompt.Prompt;
1215
import org.springframework.ai.ollama.api.OllamaApi;
1316
import org.springframework.ai.ollama.api.OllamaModel;
1417
import org.springframework.ai.ollama.api.OllamaOptions;
18+
import org.springframework.ai.retry.NonTransientAiException;
1519
import org.springframework.ai.retry.RetryUtils;
1620
import org.springframework.ai.retry.TransientAiException;
1721
import org.springframework.retry.RetryCallback;
1822
import org.springframework.retry.RetryContext;
1923
import org.springframework.retry.RetryListener;
2024
import org.springframework.retry.support.RetryTemplate;
25+
import org.springframework.web.client.ResourceAccessException;
2126

2227
import static org.assertj.core.api.Assertions.assertThat;
28+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
2329
import static org.mockito.ArgumentMatchers.isA;
30+
import static org.mockito.Mockito.times;
31+
import static org.mockito.Mockito.verify;
2432
import static org.mockito.Mockito.when;
2533

2634
/**
@@ -75,6 +83,101 @@ void ollamaChatTransientError() {
7583
assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2);
7684
}
7785

86+
@Test
87+
void ollamaChatSuccessOnFirstAttempt() {
88+
String promptText = "Simple question";
89+
var expectedChatResponse = new OllamaApi.ChatResponse("CHAT_COMPLETION_ID", Instant.now(),
90+
OllamaApi.Message.builder(OllamaApi.Message.Role.ASSISTANT).content("Quick response").build(), null,
91+
true, null, null, null, null, null, null);
92+
93+
when(this.ollamaApi.chat(isA(OllamaApi.ChatRequest.class))).thenReturn(expectedChatResponse);
94+
95+
var result = this.chatModel.call(new Prompt(promptText));
96+
97+
assertThat(result).isNotNull();
98+
assertThat(result.getResult().getOutput().getText()).isEqualTo("Quick response");
99+
assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(0);
100+
assertThat(this.retryListener.onErrorRetryCount).isEqualTo(0);
101+
verify(this.ollamaApi, times(1)).chat(isA(OllamaApi.ChatRequest.class));
102+
}
103+
104+
@Test
105+
void ollamaChatNonTransientErrorShouldNotRetry() {
106+
String promptText = "Invalid request";
107+
108+
when(this.ollamaApi.chat(isA(OllamaApi.ChatRequest.class)))
109+
.thenThrow(new NonTransientAiException("Model not found"));
110+
111+
assertThatThrownBy(() -> this.chatModel.call(new Prompt(promptText)))
112+
.isInstanceOf(NonTransientAiException.class)
113+
.hasMessage("Model not found");
114+
115+
assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(0);
116+
assertThat(this.retryListener.onErrorRetryCount).isEqualTo(1);
117+
verify(this.ollamaApi, times(1)).chat(isA(OllamaApi.ChatRequest.class));
118+
}
119+
120+
@Test
121+
void ollamaChatWithMultipleMessages() {
122+
List<Message> messages = List.of(new UserMessage("What is AI?"), new UserMessage("Explain machine learning"));
123+
Prompt prompt = new Prompt(messages);
124+
125+
var expectedChatResponse = new OllamaApi.ChatResponse("CHAT_COMPLETION_ID", Instant.now(),
126+
OllamaApi.Message.builder(OllamaApi.Message.Role.ASSISTANT)
127+
.content("AI is artificial intelligence...")
128+
.build(),
129+
null, true, null, null, null, null, null, null);
130+
131+
when(this.ollamaApi.chat(isA(OllamaApi.ChatRequest.class)))
132+
.thenThrow(new TransientAiException("Temporary overload"))
133+
.thenReturn(expectedChatResponse);
134+
135+
var result = this.chatModel.call(prompt);
136+
137+
assertThat(result).isNotNull();
138+
assertThat(result.getResult().getOutput().getText()).isEqualTo("AI is artificial intelligence...");
139+
assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1);
140+
assertThat(this.retryListener.onErrorRetryCount).isEqualTo(1);
141+
}
142+
143+
@Test
144+
void ollamaChatWithCustomOptions() {
145+
String promptText = "Custom temperature request";
146+
OllamaOptions customOptions = OllamaOptions.builder().model(MODEL).temperature(0.1).topP(0.9).build();
147+
148+
var expectedChatResponse = new OllamaApi.ChatResponse("CHAT_COMPLETION_ID", Instant.now(),
149+
OllamaApi.Message.builder(OllamaApi.Message.Role.ASSISTANT).content("Deterministic response").build(),
150+
null, true, null, null, null, null, null, null);
151+
152+
when(this.ollamaApi.chat(isA(OllamaApi.ChatRequest.class)))
153+
.thenThrow(new ResourceAccessException("Connection timeout"))
154+
.thenReturn(expectedChatResponse);
155+
156+
var result = this.chatModel.call(new Prompt(promptText, customOptions));
157+
158+
assertThat(result).isNotNull();
159+
assertThat(result.getResult().getOutput().getText()).isEqualTo("Deterministic response");
160+
assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1);
161+
}
162+
163+
@Test
164+
void ollamaChatWithEmptyResponse() {
165+
String promptText = "Edge case request";
166+
var expectedChatResponse = new OllamaApi.ChatResponse("CHAT_COMPLETION_ID", Instant.now(),
167+
OllamaApi.Message.builder(OllamaApi.Message.Role.ASSISTANT).content("").build(), null, true, null, null,
168+
null, null, null, null);
169+
170+
when(this.ollamaApi.chat(isA(OllamaApi.ChatRequest.class)))
171+
.thenThrow(new TransientAiException("Rate limit exceeded"))
172+
.thenReturn(expectedChatResponse);
173+
174+
var result = this.chatModel.call(new Prompt(promptText));
175+
176+
assertThat(result).isNotNull();
177+
assertThat(result.getResult().getOutput().getText()).isEmpty();
178+
assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1);
179+
}
180+
78181
private static class TestRetryListener implements RetryListener {
79182

80183
int onErrorRetryCount = 0;

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/moderation/api/OpenAiModerationApiBuilderTests.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,19 @@ void dynamicApiKeyRestClient() throws InterruptedException {
171171
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2");
172172
}
173173

174+
@Test
175+
void testBuilderMethodsReturnNewInstances() {
176+
OpenAiModerationApi.Builder builder1 = OpenAiModerationApi.builder();
177+
OpenAiModerationApi.Builder builder2 = builder1.apiKey(TEST_API_KEY);
178+
OpenAiModerationApi.Builder builder3 = builder2.baseUrl(TEST_BASE_URL);
179+
180+
assertThat(builder2).isNotNull();
181+
assertThat(builder3).isNotNull();
182+
183+
OpenAiModerationApi api = builder3.build();
184+
assertThat(api).isNotNull();
185+
}
186+
174187
}
175188

176189
}

models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ChatCompletionRequestTests.java

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,4 +107,51 @@ public void defaultOptionsTools() {
107107
assertThat(request.model()).isEqualTo("DEFAULT_MODEL");
108108
}
109109

110+
@Test
111+
public void promptOptionsOverrideDefaultOptions() {
112+
var client = new ZhiPuAiChatModel(ZhiPuAiApi.builder().apiKey("TEST").build(),
113+
ZhiPuAiChatOptions.builder().model("DEFAULT_MODEL").temperature(10.0).build());
114+
115+
var request = client.createRequest(new Prompt("Test", ZhiPuAiChatOptions.builder().temperature(90.0).build()),
116+
false);
117+
118+
assertThat(request.model()).isEqualTo("DEFAULT_MODEL");
119+
assertThat(request.temperature()).isEqualTo(90.0);
120+
}
121+
122+
@Test
123+
public void defaultOptionsToolsWithAssertion() {
124+
final String TOOL_FUNCTION_NAME = "CurrentWeather";
125+
126+
var client = new ZhiPuAiChatModel(ZhiPuAiApi.builder().apiKey("TEST").build(),
127+
ZhiPuAiChatOptions.builder()
128+
.model("DEFAULT_MODEL")
129+
.toolCallbacks(List.of(FunctionToolCallback.builder(TOOL_FUNCTION_NAME, new MockWeatherService())
130+
.description("Get the weather in location")
131+
.inputType(MockWeatherService.Request.class)
132+
.build()))
133+
.build());
134+
135+
var prompt = client.buildRequestPrompt(new Prompt("Test message content"));
136+
var request = client.createRequest(prompt, false);
137+
138+
assertThat(request.messages()).hasSize(1);
139+
assertThat(request.stream()).isFalse();
140+
assertThat(request.model()).isEqualTo("DEFAULT_MODEL");
141+
assertThat(request.tools()).hasSize(1);
142+
assertThat(request.tools().get(0).getFunction().getName()).isEqualTo(TOOL_FUNCTION_NAME);
143+
}
144+
145+
@Test
146+
public void createRequestWithStreamingEnabled() {
147+
var client = new ZhiPuAiChatModel(ZhiPuAiApi.builder().apiKey("TEST").build(),
148+
ZhiPuAiChatOptions.builder().model("DEFAULT_MODEL").build());
149+
150+
var prompt = client.buildRequestPrompt(new Prompt("Test streaming"));
151+
var request = client.createRequest(prompt, true);
152+
153+
assertThat(request.stream()).isTrue();
154+
assertThat(request.messages()).hasSize(1);
155+
}
156+
110157
}

0 commit comments

Comments
 (0)