Skip to content

Commit 1b855b3

Browse files
committed
debug
1 parent 0023b1f commit 1b855b3

File tree

1 file changed

+99
-84
lines changed

1 file changed

+99
-84
lines changed

src/open_prompt_extension.cpp

Lines changed: 99 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,19 @@
1313
#include <string>
1414
#include <sstream>
1515
#include <mutex>
16-
#include <iostream>
16+
#include <iostream>
17+
#include <yyjson.hpp>
1718

1819

1920
namespace duckdb {
21+
struct OpenPromptData: FunctionData {
22+
unique_ptr<FunctionData> Copy() const {
23+
throw std::runtime_error("OpenPromptData::Copy");
24+
};
25+
bool Equals(const FunctionData &other) const {
26+
throw std::runtime_error("OpenPromptData::Equals");
27+
};
28+
};
2029

2130
// Helper function to parse URL and setup client
2231
static std::pair<duckdb_httplib_openssl::Client, std::string> SetupHttpClient(const std::string &url) {
@@ -93,38 +102,71 @@ static void HandleHttpError(const duckdb_httplib_openssl::Result &res, const std
93102
// Open Prompt
94103
// Global settings
95104
static std::string api_url = "http://localhost:11434/v1/chat/completions";
96-
static std::string api_token = ""; // Store your API token here
105+
static std::string api_token; // Store your API token here
97106
static std::string model_name = "qwen2.5:0.5b"; // Default model
98107
static std::mutex settings_mutex;
99108

100109
// Function to set API token
101-
void SetApiToken(const std::string &token) {
102-
std::lock_guard<std::mutex> guard(settings_mutex);
103-
if (token.empty()) {
104-
throw std::invalid_argument("API token cannot be empty.");
105-
}
106-
api_token = token;
107-
std::cerr << "API token set to: " << api_token << std::endl; // Debugging output
110+
void SetApiToken(DataChunk &args, ExpressionState &state, Vector &result) {
111+
UnaryExecutor::Execute<string_t, string_t>(args.data[0], result, args.size(),
112+
[&](string_t token) {
113+
try {
114+
auto _token = token.GetData();
115+
if (token.Empty()) {
116+
throw std::invalid_argument("API token cannot be empty.");
117+
}
118+
ClientConfig::GetConfig(state.GetContext()).SetUserVariable(
119+
"openprompt_api_token",
120+
Value::CreateValue(token.GetString()));
121+
return StringVector::AddString(result, string("token : ") + string(_token, token.GetSize()));
122+
} catch (std::exception &e) {
123+
string_t res(e.what());
124+
res.Finalize();
125+
return res;
126+
}
127+
});
108128
}
109129

110130
// Function to set API URL
111-
void SetApiUrl(const std::string &url) {
112-
std::lock_guard<std::mutex> guard(settings_mutex);
113-
if (url.empty()) {
114-
throw std::invalid_argument("URL cannot be empty.");
115-
}
116-
api_url = url;
117-
std::cerr << "API URL set to: " << api_url << std::endl; // Debugging output
131+
void SetApiUrl(DataChunk &args, ExpressionState &state, Vector &result) {
132+
UnaryExecutor::Execute<string_t, string_t>(args.data[0], result, args.size(),
133+
[&](string_t token) {
134+
try {
135+
auto _token = token.GetData();
136+
if (token.Empty()) {
137+
throw std::invalid_argument("API token cannot be empty.");
138+
}
139+
ClientConfig::GetConfig(state.GetContext()).SetUserVariable(
140+
"openprompt_api_url",
141+
Value::CreateValue(token.GetString()));
142+
return StringVector::AddString(result, string("url : ") + string(_token, token.GetSize()));
143+
} catch (std::exception &e) {
144+
string_t res(e.what());
145+
res.Finalize();
146+
return res;
147+
}
148+
});
118149
}
119150

120151
// Function to set model name
121-
void SetModelName(const std::string &model) {
122-
std::lock_guard<std::mutex> guard(settings_mutex);
123-
if (model.empty()) {
124-
throw std::invalid_argument("Model name cannot be empty.");
125-
}
126-
model_name = model;
127-
std::cerr << "Model name set to: " << model_name << std::endl; // Debugging output
152+
void SetModelName(DataChunk &args, ExpressionState &state, Vector &result) {
153+
UnaryExecutor::Execute<string_t, string_t>(args.data[0], result, args.size(),
154+
[&](string_t token) {
155+
try {
156+
auto _token = token.GetData();
157+
if (token.Empty()) {
158+
throw std::invalid_argument("API token cannot be empty.");
159+
}
160+
ClientConfig::GetConfig(state.GetContext()).SetUserVariable(
161+
"openprompt_model_name",
162+
Value::CreateValue(token.GetString()));
163+
return StringVector::AddString(result, string("name : ") + string(_token, token.GetSize()));
164+
} catch (std::exception &e) {
165+
string_t res(e.what());
166+
res.Finalize();
167+
return res;
168+
}
169+
});
128170
}
129171

130172
// Retrieve the API URL from the stored settings
@@ -145,68 +187,65 @@ static void HandleHttpError(const duckdb_httplib_openssl::Result &res, const std
145187
return model_name.empty() ? "qwen2.5:0.5b" : model_name;
146188
}
147189

190+
template<typename a> a assert_null(a val) {
191+
if (val == nullptr) {
192+
throw std::runtime_error("Failed to parse the first message content in the API response.");
193+
}
194+
return val;
195+
}
148196
// Open Prompt Function
149197
static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, Vector &result) {
150-
D_ASSERT(args.data.size() == 2); // Expecting the prompt and model name
151-
152198
UnaryExecutor::Execute<string_t, string_t>(args.data[0], result, args.size(),
153199
[&](string_t user_prompt) {
154-
std::string api_url = GetApiUrl(); // Retrieve the API URL from settings
155-
std::string api_token = GetApiToken(); // Retrieve the API Token from settings
156-
std::string model_name;
157-
158-
if (!args.data[1].GetValue(0).IsNull()) {
159-
model_name = args.data[1].GetValue(0).ToString(); // Use passed model name
160-
} else {
161-
model_name = GetModelName(); // Use the default model if none is provided
162-
}
200+
duckdb_yyjson::yyjson_doc *doc = nullptr;
201+
auto &conf = ClientConfig::GetConfig(state.GetContext());
202+
Value api_url;
203+
Value api_token;
204+
Value model_name;
205+
conf.GetUserVariable("openprompt_api_url", api_url);
206+
conf.GetUserVariable("openprompt_api_token", api_token);
207+
conf.GetUserVariable("openprompt_model_name", model_name);
163208

164209
// Manually construct the JSON body as a string. TODO use json parser from extension.
165210
std::string request_body = "{";
166-
request_body += "\"model\":\"" + model_name + "\",";
211+
request_body += "\"model\":\"" + model_name.ToString() + "\",";
167212
request_body += "\"messages\":[";
168213
request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},";
169214
request_body += "{\"role\":\"user\",\"content\":\"" + user_prompt.GetString() + "\"}";
170215
request_body += "]}";
171216

172217
try {
173218
// Make the POST request
174-
auto client_and_path = SetupHttpClient(api_url);
219+
auto client_and_path = SetupHttpClient(api_url.ToString());
175220
auto &client = client_and_path.first;
176221
auto &path = client_and_path.second;
177222

178223
// Setup headers
179224
duckdb_httplib_openssl::Headers header_map;
180225
header_map.emplace("Content-Type", "application/json");
181-
if (!api_token.empty()) {
182-
header_map.emplace("Authorization", "Bearer " + api_token);
226+
if (!api_token.ToString().empty()) {
227+
header_map.emplace("Authorization", "Bearer " + api_token.ToString());
183228
}
184229

185230
// Send the request
186231
auto res = client.Post(path.c_str(), header_map, request_body, "application/json");
187232
if (res && res->status == 200) {
188233
// Extract the first choice's message content from the response
189234
std::string response_body = res->body;
190-
size_t choices_pos = response_body.find("\"choices\":");
191-
if (choices_pos != std::string::npos) {
192-
size_t message_pos = response_body.find("\"message\":", choices_pos);
193-
size_t content_pos = response_body.find("\"content\":\"", message_pos);
194-
if (content_pos != std::string::npos) {
195-
content_pos += 11; // Move to the start of the content value
196-
size_t content_end = response_body.find("\"", content_pos);
197-
if (content_end != std::string::npos) {
198-
std::string first_message_content = response_body.substr(content_pos, content_end - content_pos);
199-
return StringVector::AddString(result, first_message_content);
200-
}
201-
}
202-
}
203-
throw std::runtime_error("Failed to parse the first message content in the API response.");
204-
} else {
205-
throw std::runtime_error("HTTP POST error: " + std::to_string(res->status) + " - " + res->reason);
235+
doc = duckdb_yyjson::yyjson_read(
236+
response_body.c_str(), response_body.length(), 0);
237+
auto root = assert_null(duckdb_yyjson::yyjson_doc_get_root(doc));
238+
auto choices = assert_null(duckdb_yyjson::yyjson_obj_get(root, "choices"));
239+
auto choices_0 = assert_null(duckdb_yyjson::yyjson_arr_get_first(choices));
240+
auto message = assert_null(duckdb_yyjson::yyjson_obj_get(choices_0, "message"));
241+
auto content = assert_null(duckdb_yyjson::yyjson_obj_get(message, "content"));
242+
auto c_content = assert_null(duckdb_yyjson::yyjson_get_str(content));
243+
return StringVector::AddString(result, c_content);
206244
}
245+
throw std::runtime_error("HTTP POST error: " + std::to_string(res->status) + " - " + res->reason);
207246
} catch (std::exception &e) {
208247
// In case of any error, return the original input text to avoid disruption
209-
return StringVector::AddString(result, user_prompt);
248+
return StringVector::AddString(result, e.what());
210249
}
211250
});
212251
}
@@ -216,45 +255,21 @@ static void LoadInternal(DatabaseInstance &instance) {
216255
// Register open_prompt function with two arguments: prompt and model
217256
ScalarFunctionSet open_prompt("open_prompt");
218257
open_prompt.AddFunction(ScalarFunction(
219-
{LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction));
258+
{LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction));
220259
ExtensionUtil::RegisterFunction(instance, open_prompt);
221260

222261
// Other set_* functions remain the same as before
223262
ExtensionUtil::RegisterFunction(instance, ScalarFunction(
224263
"set_api_token", {LogicalType::VARCHAR}, LogicalType::VARCHAR,
225-
[](DataChunk &args, ExpressionState &state, Vector &result) {
226-
try {
227-
auto token = args.data[0].GetValue(0).ToString();
228-
SetApiToken(token);
229-
return StringVector::AddString(result, "API token set successfully.");
230-
} catch (std::exception &e) {
231-
return StringVector::AddString(result, "Failed to set API token: " + std::string(e.what()));
232-
}
233-
}));
264+
SetApiToken));
234265

235266
ExtensionUtil::RegisterFunction(instance, ScalarFunction(
236267
"set_api_url", {LogicalType::VARCHAR}, LogicalType::VARCHAR,
237-
[](DataChunk &args, ExpressionState &state, Vector &result) {
238-
try {
239-
auto new_url = args.data[0].GetValue(0).ToString();
240-
SetApiUrl(new_url);
241-
return StringVector::AddString(result, "API URL set successfully.");
242-
} catch (std::exception &e) {
243-
return StringVector::AddString(result, "Failed to set API URL: " + std::string(e.what()));
244-
}
245-
}));
268+
SetApiUrl));
246269

247270
ExtensionUtil::RegisterFunction(instance, ScalarFunction(
248-
"set_model_name", {LogicalType::VARCHAR}, LogicalType::VARCHAR,
249-
[](DataChunk &args, ExpressionState &state, Vector &result) {
250-
try {
251-
auto model = args.data[0].GetValue(0).ToString();
252-
SetModelName(model);
253-
return StringVector::AddString(result, "Model name set successfully.");
254-
} catch (std::exception &e) {
255-
return StringVector::AddString(result, "Failed to set model name: " + std::string(e.what()));
256-
}
257-
}));
271+
"set_model_name", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetModelName
272+
));
258273
}
259274

260275

0 commit comments

Comments
 (0)