Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/client_backend/client_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,13 @@ class InferResult {
return Error("InferResult::IsNullResponse() not implemented");
};

/// Get stream response bool for this response.
/// This base implementation does not provide the flag: it always returns
/// an error and never writes to the output argument, so callers must check
/// the returned Error before reading the value.
/// \param is_stream_response Returns the stream response bool on success.
/// Not modified by this base implementation.
/// \return Error object indicating the success or failure.
virtual Error IsStreamResponse(bool* is_stream_response) const
{
  return Error("InferResult::IsStreamResponse() not implemented");
};

/// Returns the response timestamps of the streaming request.
/// \return Error object indicating the success or failure.
virtual Error ResponseTimestamps(
Expand Down
2 changes: 1 addition & 1 deletion src/client_backend/openai/openai_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ ChatCompletionRequest::SendResponse(bool is_final, bool is_null)
{
final_response_sent_ = is_final;
response_callback_(new ChatCompletionResult(
http_code_, std::move(response_buffer_), is_final, is_null, request_id_));
http_code_, std::move(response_buffer_), is_final, is_null, is_stream_, request_id_));
}

ChatCompletionClient::ChatCompletionClient(
Expand Down
14 changes: 12 additions & 2 deletions src/client_backend/openai/openai_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@ class ChatCompletionResult : public InferResult {
public:
ChatCompletionResult(
uint32_t http_code, std::string&& serialized_response, bool is_final,
bool is_null, const std::string& request_id)
bool is_null, bool is_stream, const std::string& request_id)
: http_code_(http_code),
serialized_response_(std::move(serialized_response)),
is_final_(is_final), is_null_(is_null), request_id_(request_id)
is_final_(is_final), is_null_(is_null), is_stream_(is_stream),
request_id_(request_id)
{
}
virtual ~ChatCompletionResult() = default;
Expand Down Expand Up @@ -99,11 +100,20 @@ class ChatCompletionResult : public InferResult {
return Error::Success;
};

/// Report the stream flag stored in this result.
/// \param is_stream_response Receives the value of is_stream_. Must point
/// to a writable bool; it is dereferenced unconditionally.
/// \return Error::Success in every case.
Error IsStreamResponse(bool* is_stream_response) const override
{
  *is_stream_response = is_stream_;
  return Error::Success;
}

private:
// All members are const: they are set once from the constructor arguments
// and never modified afterwards.

// HTTP status code passed to the constructor (200 unless overridden).
const uint32_t http_code_{200};
// Serialized response body, moved in from the constructor argument.
const std::string serialized_response_;
// Constructor-supplied `is_final` flag.
const bool is_final_{false};
// Constructor-supplied `is_null` flag.
const bool is_null_{false};
// Constructor-supplied `is_stream` flag; reported by IsStreamResponse().
const bool is_stream_{false};
// Identifier of the request, copied from the constructor argument.
const std::string request_id_;
};

Expand Down
62 changes: 61 additions & 1 deletion src/session_concurrency/payload_json_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

#include <stdexcept>
#include <string>
#include <vector>

#include "../rapidjson_utils.h"

Expand Down Expand Up @@ -95,14 +96,73 @@ PayloadJsonUtils::ValidatePayloadMessages(
}
}

/// Append `buffer` to the "content" member of a chat message.
///
/// A missing or null "content" is treated as an empty string, so that a
/// message whose first chunk carried no text can still receive the text of
/// the chunks that followed it. When the member is missing it is added.
///
/// \param item Chat message object to update in place.
/// \param buffer Text to append; it is only read by this function.
/// \param allocator Allocator that owns the storage of the new string. It
/// must be the allocator of the document that owns `item`.
/// \throws std::runtime_error if `item` is not an object, or if its
/// "content" member is neither a string nor null.
void
PayloadJsonUtils::UpdateContent(
    rapidjson::Value& item, std::string& buffer,
    rapidjson::Document::AllocatorType& allocator)
{
  if (!item.IsObject()) {
    throw std::runtime_error(
        "Chat history entry must be an object to have its 'content' field "
        "updated.");
  }

  const auto content_it = item.FindMember("content");
  const bool has_content{content_it != item.MemberEnd()};

  std::string merged_content{};
  if (has_content) {
    const auto& content{content_it->value};
    if (content.IsString()) {
      // Use the explicit length so embedded NUL characters are preserved.
      merged_content.assign(content.GetString(), content.GetStringLength());
    } else if (!content.IsNull()) {
      // Overwriting a structured value would silently drop data.
      throw std::runtime_error(
          "Chat history entry 'content' field must be a string or null to "
          "be merged with streamed chunks.");
    }
  }
  merged_content += buffer;

  const auto merged_size{
      static_cast<rapidjson::SizeType>(merged_content.size())};

  if (has_content) {
    content_it->value.SetString(
        merged_content.c_str(), merged_size, allocator);
  } else {
    rapidjson::Value merged_value{};
    merged_value.SetString(merged_content.c_str(), merged_size, allocator);
    item.AddMember("content", merged_value, allocator);
  }
}

/// Replace the messages of `payload_document` with the chat history, after
/// merging streamed response chunks back into complete messages.
///
/// Every entry of `chat_history` that carries a non-null "role" starts a new
/// message. Every entry whose "role" is null or absent is a continuation
/// chunk: its "content" text is appended to the most recent message.
///
/// \param payload_document Request payload whose messages are overwritten.
/// \param chat_history Array of messages and streamed chunks, in arrival
/// order.
/// \throws std::runtime_error if a continuation chunk with text appears
/// before any message it could be merged into.
void
PayloadJsonUtils::SetPayloadToChatHistory(
    rapidjson::Document& payload_document,
    const rapidjson::Document& chat_history)
{
  auto& payload_messages{GetPayloadMessages(payload_document)};

  // Merge chunked responses in streaming mode.
  rapidjson::Document merged_history{};
  merged_history.SetArray();
  auto& allocator = merged_history.GetAllocator();
  std::string content_buffer{};

  // Appends the buffered chunk text to the most recently stored message.
  const auto flush_content_buffer = [&]() {
    if (content_buffer.empty()) {
      return;
    }
    if (merged_history.Empty()) {
      // There is no message to attach the text to; fail loudly instead of
      // indexing an empty array.
      throw std::runtime_error(
          "Chat history must not start with a streaming chunk that has no "
          "preceding message to be merged into.");
    }
    UpdateContent(
        merged_history[merged_history.Size() - 1], content_buffer, allocator);
    content_buffer.clear();
  };

  for (const auto& h : chat_history.GetArray()) {
    // This merge sequence assumes that:
    // 1. the order of arrivals is preserved in chat_history,
    // 2. for request payload and non-streaming response,
    //    each entry in chat_history includes the entire text which is not
    //    chunked,
    // 3. for streaming response, only the first chunk has a non-null "role";
    //    in the following chunks the "role" field is null or omitted,
    // 4. each chunk doesn't have inconsistent value,
    //    that is, "role" and/or "function_call" field don't have
    //    different values for one sequence.
    //    (e.g., the situation, chunks[0]["role"]: "assistant" and
    //    chunks[1]["role"]: "user", never happens)
    const auto role_it = h.FindMember("role");
    const bool is_continuation_chunk{
        role_it == h.MemberEnd() || role_it->value.IsNull()};

    if (is_continuation_chunk) {
      // Intermediate streaming chunks corresponding to one request. A chunk
      // without text (e.g. an empty delta) contributes nothing.
      const auto content_it = h.FindMember("content");
      if (content_it != h.MemberEnd() && content_it->value.IsString()) {
        content_buffer.append(
            content_it->value.GetString(),
            content_it->value.GetStringLength());
      }
    } else {
      // A new message begins: finish the previous one first.
      flush_content_buffer();

      // First streaming chunk or Request payload.
      rapidjson::Value new_item{};
      new_item.CopyFrom(h, allocator);
      merged_history.PushBack(new_item, allocator);
    }
  }

  // Store the text of the final entry if it exists.
  flush_content_buffer();

  payload_messages.CopyFrom(merged_history, payload_document.GetAllocator());
}

std::string
Expand Down
4 changes: 4 additions & 0 deletions src/session_concurrency/payload_json_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ class PayloadJsonUtils {
static void ValidatePayloadMessages(
const rapidjson::Document& payload_document);

/// Append `buffer` to the string held in the "content" member of `item`
/// and store the concatenated text back into that member.
/// \param item JSON value whose "content" member is rewritten in place.
/// \param buffer Text appended after the current "content" text.
/// \param allocator Allocator used to copy the concatenated string into
/// `item`.
static void UpdateContent(
rapidjson::Value& item,
std::string& buffer,
rapidjson::Document::AllocatorType& allocator);
/// Overwrite the messages of `payload_document` with the entries of
/// `chat_history`. Entries whose "role" is null are treated as streaming
/// chunks and their "content" is concatenated onto the preceding entry
/// that has a non-null "role".
/// \param payload_document Payload whose messages are replaced.
/// \param chat_history Array of chat messages and streamed chunks, in
/// arrival order.
static void SetPayloadToChatHistory(
rapidjson::Document& payload_document,
const rapidjson::Document& chat_history);
Expand Down
52 changes: 44 additions & 8 deletions src/session_concurrency/request_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#include "../client_backend/client_backend.h"
#include "../model_parser.h"
#include "../request_record.h"
#include "../rapidjson_utils.h"
#include "payload_dataset_manager.h"
#include "payload_json_utils.h"
#include "response_json_utils.h"
Expand Down Expand Up @@ -197,16 +198,51 @@ RequestHandler::PrepareCallback(
const auto& response_document{
ResponseJsonUtils::GetResponseDocument(response_buffer)};

const auto& response_message{
ResponseJsonUtils::GetMessage(response_document)};

rapidjson::Value response_message_copy{};
response_message_copy.CopyFrom(
response_message, chat_history.GetAllocator());
bool is_stream{false};
auto error = infer_result->IsStreamResponse(&is_stream);
if (!error.IsOk()) {
// Forcibly set false to `is_stream` because
// this `infer_result` object is a subclass which
// does not implement `IsStreamResponse()`.
is_stream = false;
}
bool is_final{false};
error = infer_result->IsFinalResponse(&is_final);
if (!error.IsOk()) {
// Forcibly set false to `is_final`.
is_final = false;
}

chat_history.PushBack(response_message_copy, chat_history.GetAllocator());
if (!response_document.IsNull()) {
// `response_document` should not be null
// when the response text is empty ("").
// Null can happen only when response is `data: [DONE]`.

if (is_stream && is_final) {
// Unexpected response.
throw std::runtime_error(
"In the case of streaming and the last chunk, response object must be null:\n\n" +
RapidJsonUtils::Serialize(response_document) + "\n\n\n"
);
}

rapidjson::Value response_message_copy{};
if (is_stream) {
response_message_copy.CopyFrom(
ResponseJsonUtils::GetDelta(response_document),
chat_history.GetAllocator());
} else {
response_message_copy.CopyFrom(
ResponseJsonUtils::GetMessage(response_document),
chat_history.GetAllocator());
}

chat_history.PushBack(response_message_copy, chat_history.GetAllocator());
}

response_promise->set_value();
if (is_final) {
response_promise->set_value();
}
};
}

Expand Down
61 changes: 49 additions & 12 deletions src/session_concurrency/response_json_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,18 @@ ResponseJsonUtils::GetResponseDocument(
const std::string response_buffer_str(
response_buffer.begin(), response_buffer.end());

response_document.Parse(response_buffer_str.c_str(), response_buffer.size());
if (response_buffer_str.find("data: [DONE]") != std::string::npos) {
// This response is the last message indicating
// the termination of the streaming response.
return response_document;
}

size_t substr_index = 0;
if (response_buffer_str.starts_with("data: ")) {
substr_index = 6;
}
const auto response_substr = response_buffer_str.substr(substr_index);
response_document.Parse(response_substr.c_str(), response_substr.size());

if (response_document.HasParseError()) {
throw std::runtime_error(
Expand All @@ -62,17 +73,7 @@ ResponseJsonUtils::GetResponseDocument(
const rapidjson::Value&
ResponseJsonUtils::GetMessage(const rapidjson::Document& response_document)
{
if (!response_document.IsObject() ||
!response_document.HasMember("choices") ||
!response_document["choices"].IsArray() ||
response_document["choices"].Empty()) {
throw std::runtime_error(
"Response body must be an object and have a 'choices' field that is "
"an array with at least one element. Response body:\n\n" +
RapidJsonUtils::Serialize(response_document) + "\n\n\n");
}

const auto& response_first_choice{response_document["choices"][0]};
const auto& response_first_choice = GetChoices(response_document);

if (!response_first_choice.IsObject() ||
!response_first_choice.HasMember("message") ||
Expand All @@ -86,4 +87,40 @@ ResponseJsonUtils::GetMessage(const rapidjson::Document& response_document)
return response_first_choice["message"];
}

/// Fetch the "delta" object of the first choice in a streaming response.
/// \param response_document Parsed response body.
/// \return Reference to the "delta" member; it refers into
/// `response_document`.
/// \throws std::runtime_error if the first choice is not an object or has
/// no "delta" member that is an object.
const rapidjson::Value&
ResponseJsonUtils::GetDelta(const rapidjson::Document& response_document)
{
  const rapidjson::Value& first_choice{GetChoices(response_document)};

  const bool has_delta_object{
      first_choice.IsObject() && first_choice.HasMember("delta") &&
      first_choice["delta"].IsObject()};

  if (has_delta_object) {
    return first_choice["delta"];
  }

  throw std::runtime_error(
      "In streaming mode, response body 'choices' field's first element "
      "must be an object and have a 'delta' field that is an object. "
      "Response body:\n\n" +
      RapidJsonUtils::Serialize(response_document) + "\n\n\n");
}

/// Fetch one element of the "choices" array of a response body.
/// \param response_document Parsed response body.
/// \param choices_index Zero-based index of the requested choice.
/// \return Reference to the requested choice; it refers into
/// `response_document`.
/// \throws std::runtime_error if the body is not an object with a non-empty
/// "choices" array, or if `choices_index` is outside that array.
const rapidjson::Value&
ResponseJsonUtils::GetChoices(
    const rapidjson::Document& response_document, const int choices_index)
{
  if (!response_document.IsObject() ||
      !response_document.HasMember("choices") ||
      !response_document["choices"].IsArray() ||
      response_document["choices"].Empty()) {
    throw std::runtime_error(
        "Response body must be an object and have a 'choices' field that is "
        "an array with at least one element. Response body:\n\n" +
        RapidJsonUtils::Serialize(response_document) + "\n\n\n");
  }

  const auto& choices{response_document["choices"]};

  // A non-empty array does not guarantee that the requested index exists;
  // indexing past the end is not checked by rapidjson in release builds.
  if (choices_index < 0 ||
      static_cast<rapidjson::SizeType>(choices_index) >= choices.Size()) {
    throw std::runtime_error(
        "Requested 'choices' index " + std::to_string(choices_index) +
        " is out of range for a 'choices' array with " +
        std::to_string(choices.Size()) + " element(s). Response body:\n\n" +
        RapidJsonUtils::Serialize(response_document) + "\n\n\n");
  }

  return choices[static_cast<rapidjson::SizeType>(choices_index)];
}

} // namespace triton::perfanalyzer
7 changes: 7 additions & 0 deletions src/session_concurrency/response_json_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ class ResponseJsonUtils {

static const rapidjson::Value& GetMessage(
const rapidjson::Document& response_document);
/// Return the "delta" object of the first element of the response's
/// "choices" array. Throws std::runtime_error if that element is not an
/// object or has no "delta" member that is an object.
/// \param response_document Parsed response body.
/// \return Reference to the "delta" value inside `response_document`.
static const rapidjson::Value& GetDelta(
const rapidjson::Document& response_document);

private:
/// Return the element of the response's "choices" array at
/// `choices_index`. Throws std::runtime_error if the response body is not
/// an object with a non-empty "choices" array.
/// \param response_document Parsed response body.
/// \param choices_index Index into the "choices" array; defaults to the
/// first element.
/// \return Reference to the selected element inside `response_document`.
static const rapidjson::Value& GetChoices(
const rapidjson::Document& response_document,
const int choices_index=0);
};

} // namespace triton::perfanalyzer
Loading