Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,18 @@
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.boot.http.client.ClientHttpRequestFactoryBuilder;
import org.springframework.boot.http.client.HttpClientSettings;
import org.springframework.boot.http.client.autoconfigure.HttpClientSettingsPropertyMapper;
import org.springframework.boot.http.client.reactive.ClientHttpConnectorBuilder;
import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration;
import org.springframework.boot.ssl.SslBundles;
import org.springframework.boot.webclient.autoconfigure.WebClientAutoConfiguration;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
import org.springframework.core.retry.RetryTemplate;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.reactive.ClientHttpConnector;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;
import org.springframework.web.reactive.function.client.WebClient;
Expand All @@ -65,15 +72,30 @@ public class AnthropicChatAutoConfiguration {
@ConditionalOnMissingBean
public AnthropicApi anthropicApi(AnthropicConnectionProperties connectionProperties,
ObjectProvider<RestClient.Builder> restClientBuilderProvider,
ObjectProvider<WebClient.Builder> webClientBuilderProvider, ResponseErrorHandler responseErrorHandler) {
ObjectProvider<WebClient.Builder> webClientBuilderProvider, ResponseErrorHandler responseErrorHandler,
ObjectProvider<SslBundles> sslBundles, ObjectProvider<HttpClientSettings> globalHttpClientSettings,
ObjectProvider<ClientHttpRequestFactoryBuilder<?>> factoryBuilder,
ObjectProvider<ClientHttpConnectorBuilder<?>> webConnectorBuilderProvider) {

HttpClientSettingsPropertyMapper mapper = new HttpClientSettingsPropertyMapper(sslBundles.getIfAvailable(),
globalHttpClientSettings.getIfAvailable());
HttpClientSettings httpClientSettings = mapper.map(connectionProperties.getHttp());

RestClient.Builder restClientBuilder = restClientBuilderProvider.getIfAvailable(RestClient::builder);
applyRestClientSettings(restClientBuilder, httpClientSettings,
factoryBuilder.getIfAvailable(ClientHttpRequestFactoryBuilder::detect));

WebClient.Builder webClientBuilder = webClientBuilderProvider.getIfAvailable(WebClient::builder);
applyWebClientSettings(webClientBuilder, httpClientSettings,
webConnectorBuilderProvider.getIfAvailable(ClientHttpConnectorBuilder::detect));

return AnthropicApi.builder()
.baseUrl(connectionProperties.getBaseUrl())
.completionsPath(connectionProperties.getCompletionsPath())
.apiKey(connectionProperties.getApiKey())
.anthropicVersion(connectionProperties.getVersion())
.restClientBuilder(restClientBuilderProvider.getIfAvailable(RestClient::builder))
.webClientBuilder(webClientBuilderProvider.getIfAvailable(WebClient::builder))
.restClientBuilder(restClientBuilder)
.webClientBuilder(webClientBuilder)
.responseErrorHandler(responseErrorHandler)
.anthropicBetaFeatures(connectionProperties.getBetaVersion())
.build();
Expand Down Expand Up @@ -102,4 +124,16 @@ public AnthropicChatModel anthropicChatModel(AnthropicApi anthropicApi, Anthropi
return chatModel;
}

private void applyRestClientSettings(RestClient.Builder builder, HttpClientSettings httpClientSettings,
ClientHttpRequestFactoryBuilder<?> factoryBuilder) {
ClientHttpRequestFactory requestFactory = factoryBuilder.build(httpClientSettings);
builder.requestFactory(requestFactory);
}

private void applyWebClientSettings(WebClient.Builder builder, HttpClientSettings httpClientSettings,
ClientHttpConnectorBuilder<?> connectorBuilder) {
ClientHttpConnector connector = connectorBuilder.build(httpClientSettings);
builder.clientConnector(connector);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,15 @@

package org.springframework.ai.model.anthropic.autoconfigure;

import java.time.Duration;

import jakarta.annotation.Nullable;

import org.springframework.ai.anthropic.api.AnthropicApi;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
import org.springframework.boot.http.client.HttpRedirects;
import org.springframework.boot.http.client.autoconfigure.HttpClientSettingsProperties;

/**
* Anthropic API connection properties.
Expand Down Expand Up @@ -56,6 +63,10 @@ public class AnthropicConnectionProperties {
*/
private String betaVersion = AnthropicApi.DEFAULT_ANTHROPIC_BETA_VERSION;

@NestedConfigurationProperty
private final HttpClientSettingsProperties http = new HttpClientSettingsProperties() {
};

public String getApiKey() {
return this.apiKey;
}
Expand Down Expand Up @@ -96,4 +107,39 @@ public void setBetaVersion(String betaVersion) {
this.betaVersion = betaVersion;
}

@Nullable
public HttpRedirects getRedirects() {
return this.http.getRedirects();
}

public void setRedirects(HttpRedirects redirects) {
this.http.setRedirects(redirects);
}

@Nullable
public Duration getConnectTimeout() {
return this.http.getConnectTimeout();
}

public void setConnectTimeout(Duration connectTimeout) {
this.http.setConnectTimeout(connectTimeout);
}

@Nullable
public Duration getReadTimeout() {
return this.http.getReadTimeout();
}

public void setReadTimeout(Duration readTimeout) {
this.http.setReadTimeout(readTimeout);
}

public HttpClientSettingsProperties.Ssl getSsl() {
return this.http.getSsl();
}

public HttpClientSettingsProperties getHttp() {
return this.http;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

package org.springframework.ai.model.anthropic.autoconfigure;

import java.time.Duration;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

import org.apache.commons.logging.Log;
Expand Down Expand Up @@ -91,4 +93,51 @@ void stream() {
});
}

@Test
void generateWithCustomTimeout() {
new ApplicationContextRunner()
.withPropertyValues("spring.ai.anthropic.apiKey=" + System.getenv("ANTHROPIC_API_KEY"),
"spring.ai.deepseek.connect-timeout=1ms", "spring.ai.deepseek.read-timeout=1ms")
.withConfiguration(SpringAiTestAutoConfigurations.of(AnthropicChatAutoConfiguration.class))
.run(context -> {
AnthropicChatModel client = context.getBean(AnthropicChatModel.class);

// Verify that the HTTP client configuration is applied
var connectionProperties = context.getBean(AnthropicConnectionProperties.class);
assertThat(connectionProperties.getConnectTimeout()).isEqualTo(Duration.ofMillis(1));
assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMillis(1));

// Verify that the client can actually make requests with the configured
// timeout
String response = client.call("Hello");
assertThat(response).isNotEmpty();
logger.info("Response with custom timeout: " + response);
});
}

@Test
void generateStreamingWithCustomTimeout() {
new ApplicationContextRunner()
.withPropertyValues("spring.ai.deepseek.apiKey=" + "sk-2567813d742c40e79fa6f1f2ee2f830c",
"spring.ai.deepseek.connect-timeout=1s", "spring.ai.deepseek.read-timeout=1s")
.withConfiguration(SpringAiTestAutoConfigurations.of(AnthropicChatAutoConfiguration.class))
.run(context -> {
AnthropicChatModel client = context.getBean(AnthropicChatModel.class);

// Verify that the HTTP client configuration is applied
var connectionProperties = context.getBean(AnthropicConnectionProperties.class);
assertThat(connectionProperties.getConnectTimeout()).isEqualTo(Duration.ofSeconds(1));
assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofSeconds(1));

Flux<ChatResponse> responseFlux = client.stream(new Prompt(new UserMessage("Hello")));
String response = Objects.requireNonNull(responseFlux.collectList().block())
.stream()
.map(chatResponse -> chatResponse.getResults().get(0).getOutput().getText())
.collect(Collectors.joining());

assertThat(response).isNotEmpty();
logger.info("Response with custom timeout: " + response);
});
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.core.credential.AzureKeyCredential;
import com.azure.core.credential.KeyCredential;
import com.azure.core.util.ClientOptions;
import com.azure.core.util.Header;
import com.azure.core.util.HttpClientOptions;
import com.azure.identity.DefaultAzureCredentialBuilder;

import org.springframework.beans.factory.ObjectProvider;
Expand Down Expand Up @@ -56,23 +56,17 @@ public OpenAIClientBuilder openAIClientBuilder(AzureOpenAiConnectionProperties c

final OpenAIClientBuilder clientBuilder;

HttpClientOptions clientOptions = createHttpClientOptions(connectionProperties);

// Connect to OpenAI (e.g. not the Azure OpenAI). The deploymentName property is
// used as OpenAI model name.
if (StringUtils.hasText(connectionProperties.getOpenAiApiKey())) {
clientBuilder = new OpenAIClientBuilder().endpoint("https://api.openai.com/v1")
.credential(new KeyCredential(connectionProperties.getOpenAiApiKey()))
.clientOptions(new ClientOptions().setApplicationId(APPLICATION_ID));
.clientOptions(clientOptions);
applyOpenAIClientBuilderCustomizers(clientBuilder, customizers);
return clientBuilder;
}

Map<String, String> customHeaders = connectionProperties.getCustomHeaders();
List<Header> headers = customHeaders.entrySet()
.stream()
.map(entry -> new Header(entry.getKey(), entry.getValue()))
.collect(Collectors.toList());
ClientOptions clientOptions = new ClientOptions().setApplicationId(APPLICATION_ID).setHeaders(headers);

Assert.hasText(connectionProperties.getEndpoint(), "Endpoint must not be empty");

if (!StringUtils.hasText(connectionProperties.getApiKey())) {
Expand All @@ -96,4 +90,44 @@ private void applyOpenAIClientBuilderCustomizers(OpenAIClientBuilder clientBuild
customizers.orderedStream().forEach(customizer -> customizer.customize(clientBuilder));
}

/**
* Create HttpClientOptions
*/
private HttpClientOptions createHttpClientOptions(AzureOpenAiConnectionProperties connectionProperties) {
// Create HttpClientOptions and apply the configuration
HttpClientOptions options = new HttpClientOptions();

options.setApplicationId(APPLICATION_ID);

Map<String, String> customHeaders = connectionProperties.getCustomHeaders();
List<Header> headers = customHeaders.entrySet()
.stream()
.map(entry -> new Header(entry.getKey(), entry.getValue()))
.collect(Collectors.toList());

options.setHeaders(headers);

if (connectionProperties.getConnectTimeout() != null) {
options.setConnectTimeout(connectionProperties.getConnectTimeout());
}

if (connectionProperties.getReadTimeout() != null) {
options.setReadTimeout(connectionProperties.getReadTimeout());
}

if (connectionProperties.getWriteTimeout() != null) {
options.setWriteTimeout(connectionProperties.getWriteTimeout());
}

if (connectionProperties.getResponseTimeout() != null) {
options.setResponseTimeout(connectionProperties.getResponseTimeout());
}

if (connectionProperties.getMaximumConnectionPoolSize() != null) {
options.setMaximumConnectionPoolSize(connectionProperties.getMaximumConnectionPoolSize());
}

return options;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.springframework.ai.model.azure.openai.autoconfigure;

import java.time.Duration;
import java.util.HashMap;
import java.util.Map;

Expand Down Expand Up @@ -46,6 +47,31 @@ public class AzureOpenAiConnectionProperties {

private Map<String, String> customHeaders = new HashMap<>();

/**
* HTTP connection timeout
*/
private Duration connectTimeout;

/**
* HTTP read timeout
*/
private Duration readTimeout;

/**
* HTTP write timeout
*/
private Duration writeTimeout;

/**
* HTTP response timeout
*/
private Duration responseTimeout;

/**
* The maximum number of connections in the HTTP connection pool
*/
private Integer maximumConnectionPoolSize;

public String getEndpoint() {
return this.endpoint;
}
Expand Down Expand Up @@ -78,4 +104,44 @@ public void setCustomHeaders(Map<String, String> customHeaders) {
this.customHeaders = customHeaders;
}

public Duration getConnectTimeout() {
return this.connectTimeout;
}

public void setConnectTimeout(Duration connectTimeout) {
this.connectTimeout = connectTimeout;
}

public Duration getReadTimeout() {
return this.readTimeout;
}

public void setReadTimeout(Duration readTimeout) {
this.readTimeout = readTimeout;
}

public Duration getWriteTimeout() {
return this.writeTimeout;
}

public void setWriteTimeout(Duration writeTimeout) {
this.writeTimeout = writeTimeout;
}

public Duration getResponseTimeout() {
return this.responseTimeout;
}

public void setResponseTimeout(Duration responseTimeout) {
this.responseTimeout = responseTimeout;
}

public Integer getMaximumConnectionPoolSize() {
return this.maximumConnectionPoolSize;
}

public void setMaximumConnectionPoolSize(Integer maximumConnectionPoolSize) {
this.maximumConnectionPoolSize = maximumConnectionPoolSize;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import org.springframework.util.ReflectionUtils;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

/**
* @author Christian Tzolov
Expand Down Expand Up @@ -288,4 +289,15 @@ void openAIClientBuilderCustomizer() {
});
}

@Test
void connectTimeoutShouldTakeEffect() {
new ApplicationContextRunner().withPropertyValues("spring.ai.azure.openai.connect-timeout=1ms")
.withConfiguration(SpringAiTestAutoConfigurations.of(AzureOpenAiChatAutoConfiguration.class))
.run(context -> {
AzureOpenAiChatModel chatModel = context.getBean(AzureOpenAiChatModel.class);

assertThatThrownBy(() -> chatModel.call(new Prompt("Hello"))).isInstanceOf(Exception.class);
});
}

}
Loading