From 01ccb0d0e3f10d5e83b0622a348dee6c1ce84ab0 Mon Sep 17 00:00:00 2001 From: simhuang Date: Sat, 23 Dec 2023 01:08:41 +0800 Subject: [PATCH] stream client for chat completion, entities update and some junit --- pom.xml | 5 + .../ClientAutoConfiguration.java | 5 +- .../client/chat/ChatCompletionChunk.java | 13 ++ .../client/chat/ChatService.java | 12 +- .../client/chat/ChatServiceImpl.java | 42 ++++- .../client/chat/ChoiceData.java | 4 +- .../client/chat/Content.java | 11 ++ .../chat/CreateChatCompletionRequest.java | 19 ++- .../client/chat/Logprobs.java | 8 + .../client/chat/ChatServiceImplTest.java | 150 ++++++++++++++++++ 10 files changed, 260 insertions(+), 9 deletions(-) create mode 100644 src/main/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/ChatCompletionChunk.java create mode 100644 src/main/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/Content.java create mode 100644 src/main/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/Logprobs.java create mode 100644 src/test/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/ChatServiceImplTest.java diff --git a/pom.xml b/pom.xml index eada6ca..a583295 100644 --- a/pom.xml +++ b/pom.xml @@ -67,6 +67,11 @@ reactor-test test + + org.mockito + mockito-core + test + diff --git a/src/main/java/io/github/reactiveclown/openaiwebfluxclient/ClientAutoConfiguration.java b/src/main/java/io/github/reactiveclown/openaiwebfluxclient/ClientAutoConfiguration.java index b8345c7..d371178 100644 --- a/src/main/java/io/github/reactiveclown/openaiwebfluxclient/ClientAutoConfiguration.java +++ b/src/main/java/io/github/reactiveclown/openaiwebfluxclient/ClientAutoConfiguration.java @@ -1,5 +1,6 @@ package io.github.reactiveclown.openaiwebfluxclient; +import com.fasterxml.jackson.databind.ObjectMapper; import io.github.reactiveclown.openaiwebfluxclient.client.audio.AudioService; import io.github.reactiveclown.openaiwebfluxclient.client.audio.AudioServiceImpl; import io.github.reactiveclown.openaiwebfluxclient.client.chat.ChatService; @@ -78,8 +79,8 @@ public AudioService audioService(@Qualifier("OpenAIClient") WebClient client) { @Bean @ConditionalOnMissingBean - public ChatService chatService(@Qualifier("OpenAIClient") WebClient client) { - return new ChatServiceImpl(client); + public ChatService chatService(@Qualifier("OpenAIClient") WebClient client, ObjectMapper objectMapper) { + return new ChatServiceImpl(client, objectMapper); } @Bean diff --git a/src/main/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/ChatCompletionChunk.java b/src/main/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/ChatCompletionChunk.java new file mode 100644 index 0000000..34f356e --- /dev/null +++ b/src/main/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/ChatCompletionChunk.java @@ -0,0 +1,13 @@ +package io.github.reactiveclown.openaiwebfluxclient.client.chat; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; + +public record ChatCompletionChunk(@JsonProperty("id") String id, + @JsonProperty("object") String object, + @JsonProperty("created") Long created, + @JsonProperty("model") String model, + @JsonProperty("system_fingerprint") String systemFingerprint, + @JsonProperty("choices") List choices) { +} diff --git a/src/main/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/ChatService.java b/src/main/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/ChatService.java index 700a908..806894d 100644 --- a/src/main/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/ChatService.java +++ b/src/main/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/ChatService.java @@ -1,5 +1,6 @@ package io.github.reactiveclown.openaiwebfluxclient.client.chat; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; public interface ChatService { @@ -8,7 +9,16 @@ public interface ChatService { * Creates a completion for the chat message. * * @param request {@link CreateChatCompletionRequest } - * @return A Mono of {@link CreateChatCompletionResponse} + * @return A {@link Mono} of {@link CreateChatCompletionResponse} */ Mono createChatCompletion(CreateChatCompletionRequest request); + + /** + * Creates a completion for the chat message, but with stream. + * The method returns a Flux with chucks of the chat completion response. + * + * @param request {@link CreateChatCompletionRequest } + * @return A {@link Flux} of {@link ChatCompletionChunk} + */ + Flux createStreamChatCompletion(CreateChatCompletionRequest request); } diff --git a/src/main/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/ChatServiceImpl.java b/src/main/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/ChatServiceImpl.java index bbb6fa4..58740eb 100644 --- a/src/main/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/ChatServiceImpl.java +++ b/src/main/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/ChatServiceImpl.java @@ -1,24 +1,60 @@ package io.github.reactiveclown.openaiwebfluxclient.client.chat; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; import org.springframework.stereotype.Service; import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @Service public class ChatServiceImpl implements ChatService{ private final WebClient client; - public ChatServiceImpl(WebClient client){ + private final ObjectMapper objectMapper; + private static final String CREATE_CHAT_COMPLETION_URL = "/chat/completions"; + public ChatServiceImpl(WebClient client, ObjectMapper objectMapper){ this.client = client; + this.objectMapper = objectMapper; } @Override public Mono createChatCompletion(CreateChatCompletionRequest request) { - String createChatCompletionUrl = "/chat/completions"; return client.post() - .uri(createChatCompletionUrl) + .uri(CREATE_CHAT_COMPLETION_URL) .bodyValue(request) .retrieve() .bodyToMono(CreateChatCompletionResponse.class); } + + @Override + public Flux createStreamChatCompletion(CreateChatCompletionRequest request) { + if (request.stream() == null || !request.stream()) { + request = request.withStream(); + } + return client.post() + .uri(CREATE_CHAT_COMPLETION_URL) + .accept(MediaType.TEXT_EVENT_STREAM) + .bodyValue(request) + .retrieve() + // transfer to String first to handle the "[DONE]" + .bodyToFlux(new ParameterizedTypeReference>() { + }) + .flatMap(serverSentEvent -> { + String data = serverSentEvent.data(); + // ignore the done text + if (data == null || data.equals("[DONE]")) { + return Mono.empty(); + } + try { + ChatCompletionChunk parsedResponse = objectMapper.readValue(data, ChatCompletionChunk.class); + return Mono.justOrEmpty(parsedResponse); + } catch (JsonProcessingException e) { + return Mono.error(e); + } + }); + } } diff --git a/src/main/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/ChoiceData.java b/src/main/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/ChoiceData.java index f93bd7d..2da9dff 100644 --- a/src/main/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/ChoiceData.java +++ b/src/main/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/ChoiceData.java @@ -1,9 +1,11 @@ package io.github.reactiveclown.openaiwebfluxclient.client.chat; +import com.fasterxml.jackson.annotation.JsonAlias; import com.fasterxml.jackson.annotation.JsonProperty; public record ChoiceData(@JsonProperty("index") Integer index, - @JsonProperty("message") MessageData message, + @JsonProperty("logprobs") Logprobs logprobs, + @JsonAlias("delta") @JsonProperty("message") MessageData message, @JsonProperty("finish_reason") String finishReason) { } diff --git a/src/main/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/Content.java b/src/main/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/Content.java new file mode 100644 index 0000000..7fa429e --- /dev/null +++ b/src/main/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/Content.java @@ -0,0 +1,11 @@ +package io.github.reactiveclown.openaiwebfluxclient.client.chat; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; + +public record Content(@JsonProperty("token") String token, + @JsonProperty("logprob") Integer logprob, + @JsonProperty("bytes") List bytes, + @JsonProperty("top_logprobs") List topLogprobs) { +} diff --git a/src/main/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/CreateChatCompletionRequest.java b/src/main/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/CreateChatCompletionRequest.java index 87da6a9..7b317ee 100644 --- a/src/main/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/CreateChatCompletionRequest.java +++ b/src/main/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/CreateChatCompletionRequest.java @@ -58,7 +58,8 @@ public record CreateChatCompletionRequest(@JsonProperty("model") String model, @JsonProperty("presence_penalty") Double presencePenalty, @JsonProperty("frequency_penalty") Double frequencyPenalty, @JsonProperty("logit_bias") Map logitBias, - @JsonProperty("user") String user) { + @JsonProperty("user") String user, + @JsonProperty("stream") Boolean stream) { public CreateChatCompletionRequest { if (model == null || model.isBlank()) throw new IllegalArgumentException("model value can't be null or blank"); @@ -67,6 +68,14 @@ public record CreateChatCompletionRequest(@JsonProperty("model") String model, throw new IllegalArgumentException("messages can't be null or empty"); } + public CreateChatCompletionRequest withStream() { + return new CreateChatCompletionRequest( + model, messages, temperature, + topP, n, stop, maxTokens, + presencePenalty, frequencyPenalty, logitBias, + user, true); + } + public static Builder builder(String model, List messages) { return new Builder(model, messages); } @@ -83,13 +92,14 @@ public static final class Builder { private Double frequencyPenalty; private Map logitBias; private String user; + private Boolean stream; public CreateChatCompletionRequest build() { return new CreateChatCompletionRequest( model, messages, temperature, topP, n, stop, maxTokens, presencePenalty, frequencyPenalty, logitBias, - user); + user, stream); } public Builder(String model, List messages) { @@ -147,5 +157,10 @@ public Builder user(String user) { return this; } + public Builder stream(Boolean stream) { + this.stream = stream; + return this; + } + } } \ No newline at end of file diff --git a/src/main/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/Logprobs.java b/src/main/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/Logprobs.java new file mode 100644 index 0000000..725fab8 --- /dev/null +++ b/src/main/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/Logprobs.java @@ -0,0 +1,8 @@ +package io.github.reactiveclown.openaiwebfluxclient.client.chat; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; + +public record Logprobs(@JsonProperty("content") List content) { +} diff --git a/src/test/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/ChatServiceImplTest.java b/src/test/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/ChatServiceImplTest.java new file mode 100644 index 0000000..e72dc87 --- /dev/null +++ b/src/test/java/io/github/reactiveclown/openaiwebfluxclient/client/chat/ChatServiceImplTest.java @@ -0,0 +1,150 @@ +package io.github.reactiveclown.openaiwebfluxclient.client.chat; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.github.reactiveclown.openaiwebfluxclient.client.UsageData; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; +import org.springframework.test.util.ReflectionTestUtils; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import java.util.List; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +public class ChatServiceImplTest { + + + @Mock + WebClient.RequestBodyUriSpec requestBodyUriSpec; + @Mock + WebClient.RequestBodySpec requestBodySpec; + @Mock + WebClient.RequestHeadersSpec requestHeadersSpec; + @Mock + WebClient.ResponseSpec responseSpec; + @Mock + private WebClient webClient; + + private ObjectMapper objectMapper; + + @InjectMocks + private ChatServiceImpl chatService; + + @BeforeEach + void setUp() { + objectMapper = new ObjectMapper().setSerializationInclusion(JsonInclude.Include.NON_NULL); + ReflectionTestUtils.setField(chatService, "objectMapper", objectMapper); + } + + @Test + public void createChatCompletion() { + // Arrange + CreateChatCompletionRequest request = CreateChatCompletionRequest + .builder( + "model", + List.of(new MessageData( + "role", + "content"))) + .stream(true) + .build(); + CreateChatCompletionResponse expectedResponse = new CreateChatCompletionResponse( + "id", + "object", + 1L, + "model", + List.of(new ChoiceData( + 1, + null, + new MessageData( + "role", + "content"), + "finishReason")), + new UsageData( + 1, + 1, + 2)); + when(webClient.post()).thenReturn(requestBodyUriSpec); + when(requestBodyUriSpec.uri(anyString())).thenReturn(requestBodySpec); + when(requestBodySpec.bodyValue(any())).thenReturn(requestHeadersSpec); + when(requestHeadersSpec.retrieve()).thenReturn(responseSpec); + when(responseSpec.bodyToMono(CreateChatCompletionResponse.class)).thenReturn(Mono.just(expectedResponse)); + + // Assert + StepVerifier.create(chatService.createChatCompletion(request)) + .expectNext(expectedResponse) + .verifyComplete(); + } + + @Test + public void createStreamChatCompletion() { + // Arrange + CreateChatCompletionRequest request = CreateChatCompletionRequest + .builder( + "model", + List.of(new MessageData( + "role", + "content"))) + .stream(true) + .build(); + when(webClient.post()).thenReturn(requestBodyUriSpec); + when(requestBodyUriSpec.uri(anyString())).thenReturn(requestBodySpec); + when(requestBodySpec.accept(any(MediaType.class))).thenReturn(requestBodySpec); + when(requestBodySpec.bodyValue(any())).thenReturn(requestHeadersSpec); + when(requestHeadersSpec.retrieve()).thenReturn(responseSpec); + + // case [DONE] + ServerSentEvent mockEvent = ServerSentEvent.builder("[DONE]").build(); + Flux> responseFlux = Flux.just(mockEvent); + when(responseSpec.bodyToFlux(new ParameterizedTypeReference>() { + })) + .thenReturn(responseFlux); + + // Assert + StepVerifier.create(chatService.createStreamChatCompletion(request)) + .expectNextCount(0) // Since the data is "[DONE]", it should return an empty flux + .verifyComplete(); + + // case delta + mockEvent = ServerSentEvent.builder(""" + { + "id": "chatcmpl-mock", + "object": "chat.completion.chunk", + "created": 1703257917, + "model": "gpt-3.5-turbo-0613", + "system_fingerprint": null, + "choices": [ + { + "index": 0, + "delta": { + "content": "mock" + }, + "logprobs": null, + "finish_reason": null + } + ] + }""").build(); + responseFlux = Flux.just(mockEvent); + when(responseSpec.bodyToFlux(new ParameterizedTypeReference>() { + })) + .thenReturn(responseFlux); + + // Assert + StepVerifier.create(chatService.createStreamChatCompletion(request)) + .expectNextMatches(chatCompletionChunk -> chatCompletionChunk.choices().get(0).message().content().equals("mock")) + .verifyComplete(); + } +} \ No newline at end of file