13
13
#include < string>
14
14
#include < sstream>
15
15
#include < mutex>
16
- #include < iostream>
16
+ #include < iostream>
17
+ #include < yyjson.hpp>
17
18
18
19
19
20
namespace duckdb {
21
+ struct OpenPromptData : FunctionData {
22
+ unique_ptr<FunctionData> Copy () const {
23
+ throw std::runtime_error (" OpenPromptData::Copy" );
24
+ };
25
+ bool Equals (const FunctionData &other) const {
26
+ throw std::runtime_error (" OpenPromptData::Equals" );
27
+ };
28
+ };
20
29
21
30
// Helper function to parse URL and setup client
22
31
static std::pair<duckdb_httplib_openssl::Client, std::string> SetupHttpClient (const std::string &url) {
@@ -93,38 +102,71 @@ static void HandleHttpError(const duckdb_httplib_openssl::Result &res, const std
93
102
// Open Prompt
94
103
// Global settings
95
104
// Default chat-completions endpoint (a server on localhost, port 11434).
static std::string api_url = " http://localhost:11434/v1/chat/completions" ;
96
- static std::string api_token = " " ; // Store your API token here
105
+ static std::string api_token; // Store your API token here
97
106
static std::string model_name = " qwen2.5:0.5b" ; // Default model
98
107
// NOTE(review): the previous std::string-based setters locked this mutex before writing
// the settings above; confirm the remaining accessors still take it.
static std::mutex settings_mutex;
99
108
100
109
// Function to set API token
101
- void SetApiToken (const std::string &token) {
102
- std::lock_guard<std::mutex> guard (settings_mutex);
103
- if (token.empty ()) {
104
- throw std::invalid_argument (" API token cannot be empty." );
105
- }
106
- api_token = token;
107
- std::cerr << " API token set to: " << api_token << std::endl; // Debugging output
110
+ void SetApiToken (DataChunk &args, ExpressionState &state, Vector &result) {
111
+ UnaryExecutor::Execute<string_t , string_t >(args.data [0 ], result, args.size (),
112
+ [&](string_t token) {
113
+ try {
114
+ auto _token = token.GetData ();
115
+ if (token.Empty ()) {
116
+ throw std::invalid_argument (" API token cannot be empty." );
117
+ }
118
+ ClientConfig::GetConfig (state.GetContext ()).SetUserVariable (
119
+ " openprompt_api_token" ,
120
+ Value::CreateValue (token.GetString ()));
121
+ return StringVector::AddString (result, string (" token : " ) + string (_token, token.GetSize ()));
122
+ } catch (std::exception &e) {
123
+ string_t res (e.what ());
124
+ res.Finalize ();
125
+ return res;
126
+ }
127
+ });
108
128
}
109
129
110
130
// Function to set API URL
111
- void SetApiUrl (const std::string &url) {
112
- std::lock_guard<std::mutex> guard (settings_mutex);
113
- if (url.empty ()) {
114
- throw std::invalid_argument (" URL cannot be empty." );
115
- }
116
- api_url = url;
117
- std::cerr << " API URL set to: " << api_url << std::endl; // Debugging output
131
+ void SetApiUrl (DataChunk &args, ExpressionState &state, Vector &result) {
132
+ UnaryExecutor::Execute<string_t , string_t >(args.data [0 ], result, args.size (),
133
+ [&](string_t token) {
134
+ try {
135
+ auto _token = token.GetData ();
136
+ if (token.Empty ()) {
137
+ throw std::invalid_argument (" API token cannot be empty." );
138
+ }
139
+ ClientConfig::GetConfig (state.GetContext ()).SetUserVariable (
140
+ " openprompt_api_url" ,
141
+ Value::CreateValue (token.GetString ()));
142
+ return StringVector::AddString (result, string (" url : " ) + string (_token, token.GetSize ()));
143
+ } catch (std::exception &e) {
144
+ string_t res (e.what ());
145
+ res.Finalize ();
146
+ return res;
147
+ }
148
+ });
118
149
}
119
150
120
151
// Function to set model name
121
- void SetModelName (const std::string &model) {
122
- std::lock_guard<std::mutex> guard (settings_mutex);
123
- if (model.empty ()) {
124
- throw std::invalid_argument (" Model name cannot be empty." );
125
- }
126
- model_name = model;
127
- std::cerr << " Model name set to: " << model_name << std::endl; // Debugging output
152
+ void SetModelName (DataChunk &args, ExpressionState &state, Vector &result) {
153
+ UnaryExecutor::Execute<string_t , string_t >(args.data [0 ], result, args.size (),
154
+ [&](string_t token) {
155
+ try {
156
+ auto _token = token.GetData ();
157
+ if (token.Empty ()) {
158
+ throw std::invalid_argument (" API token cannot be empty." );
159
+ }
160
+ ClientConfig::GetConfig (state.GetContext ()).SetUserVariable (
161
+ " openprompt_model_name" ,
162
+ Value::CreateValue (token.GetString ()));
163
+ return StringVector::AddString (result, string (" name : " ) + string (_token, token.GetSize ()));
164
+ } catch (std::exception &e) {
165
+ string_t res (e.what ());
166
+ res.Finalize ();
167
+ return res;
168
+ }
169
+ });
128
170
}
129
171
130
172
// Retrieve the API URL from the stored settings
@@ -145,68 +187,65 @@ static void HandleHttpError(const duckdb_httplib_openssl::Result &res, const std
145
187
return model_name.empty () ? " qwen2.5:0.5b" : model_name;
146
188
}
147
189
190
// Passes `val` through unchanged when it is not null.
// Despite the name, a null value is the failure case: it raises std::runtime_error with
// the "failed to parse the API response" message used by the response-parsing code.
template <typename a>
a assert_null(a val) {
    if (val != nullptr) {
        return val;
    }
    throw std::runtime_error("Failed to parse the first message content in the API response.");
}
148
196
// Open Prompt Function
149
197
static void OpenPromptRequestFunction (DataChunk &args, ExpressionState &state, Vector &result) {
150
- D_ASSERT (args.data .size () == 2 ); // Expecting the prompt and model name
151
-
152
198
UnaryExecutor::Execute<string_t , string_t >(args.data [0 ], result, args.size (),
153
199
[&](string_t user_prompt) {
154
- std::string api_url = GetApiUrl (); // Retrieve the API URL from settings
155
- std::string api_token = GetApiToken (); // Retrieve the API Token from settings
156
- std::string model_name;
157
-
158
- if (!args.data [1 ].GetValue (0 ).IsNull ()) {
159
- model_name = args.data [1 ].GetValue (0 ).ToString (); // Use passed model name
160
- } else {
161
- model_name = GetModelName (); // Use the default model if none is provided
162
- }
200
+ duckdb_yyjson::yyjson_doc *doc = nullptr ;
201
+ auto &conf = ClientConfig::GetConfig (state.GetContext ());
202
+ Value api_url;
203
+ Value api_token;
204
+ Value model_name;
205
+ conf.GetUserVariable (" openprompt_api_url" , api_url);
206
+ conf.GetUserVariable (" openprompt_api_token" , api_token);
207
+ conf.GetUserVariable (" openprompt_model_name" , model_name);
163
208
164
209
// Manually construct the JSON body as a string. TODO use json parser from extension.
165
210
std::string request_body = " {" ;
166
- request_body += " \" model\" :\" " + model_name + " \" ," ;
211
+ request_body += " \" model\" :\" " + model_name. ToString () + " \" ," ;
167
212
request_body += " \" messages\" :[" ;
168
213
request_body += " {\" role\" :\" system\" ,\" content\" :\" You are a helpful assistant.\" }," ;
169
214
request_body += " {\" role\" :\" user\" ,\" content\" :\" " + user_prompt.GetString () + " \" }" ;
170
215
request_body += " ]}" ;
171
216
172
217
try {
173
218
// Make the POST request
174
- auto client_and_path = SetupHttpClient (api_url);
219
+ auto client_and_path = SetupHttpClient (api_url. ToString () );
175
220
auto &client = client_and_path.first ;
176
221
auto &path = client_and_path.second ;
177
222
178
223
// Setup headers
179
224
duckdb_httplib_openssl::Headers header_map;
180
225
header_map.emplace (" Content-Type" , " application/json" );
181
- if (!api_token.empty ()) {
182
- header_map.emplace (" Authorization" , " Bearer " + api_token);
226
+ if (!api_token.ToString (). empty ()) {
227
+ header_map.emplace (" Authorization" , " Bearer " + api_token. ToString () );
183
228
}
184
229
185
230
// Send the request
186
231
auto res = client.Post (path.c_str (), header_map, request_body, " application/json" );
187
232
if (res && res->status == 200 ) {
188
233
// Extract the first choice's message content from the response
189
234
std::string response_body = res->body ;
190
- size_t choices_pos = response_body.find (" \" choices\" :" );
191
- if (choices_pos != std::string::npos) {
192
- size_t message_pos = response_body.find (" \" message\" :" , choices_pos);
193
- size_t content_pos = response_body.find (" \" content\" :\" " , message_pos);
194
- if (content_pos != std::string::npos) {
195
- content_pos += 11 ; // Move to the start of the content value
196
- size_t content_end = response_body.find (" \" " , content_pos);
197
- if (content_end != std::string::npos) {
198
- std::string first_message_content = response_body.substr (content_pos, content_end - content_pos);
199
- return StringVector::AddString (result, first_message_content);
200
- }
201
- }
202
- }
203
- throw std::runtime_error (" Failed to parse the first message content in the API response." );
204
- } else {
205
- throw std::runtime_error (" HTTP POST error: " + std::to_string (res->status ) + " - " + res->reason );
235
+ doc = duckdb_yyjson::yyjson_read (
236
+ response_body.c_str (), response_body.length (), 0 );
237
+ auto root = assert_null (duckdb_yyjson::yyjson_doc_get_root (doc));
238
+ auto choices = assert_null (duckdb_yyjson::yyjson_obj_get (root, " choices" ));
239
+ auto choices_0 = assert_null (duckdb_yyjson::yyjson_arr_get_first (choices));
240
+ auto message = assert_null (duckdb_yyjson::yyjson_obj_get (choices_0, " message" ));
241
+ auto content = assert_null (duckdb_yyjson::yyjson_obj_get (message, " content" ));
242
+ auto c_content = assert_null (duckdb_yyjson::yyjson_get_str (content));
243
+ return StringVector::AddString (result, c_content);
206
244
}
245
+ throw std::runtime_error (" HTTP POST error: " + std::to_string (res->status ) + " - " + res->reason );
207
246
} catch (std::exception &e) {
208
247
// In case of any error, return the original input text to avoid disruption
209
- return StringVector::AddString (result, user_prompt );
248
+ return StringVector::AddString (result, e. what () );
210
249
}
211
250
});
212
251
}
@@ -216,45 +255,21 @@ static void LoadInternal(DatabaseInstance &instance) {
216
255
// Register open_prompt function with a single argument: the prompt
217
256
ScalarFunctionSet open_prompt (" open_prompt" );
218
257
open_prompt.AddFunction (ScalarFunction (
219
- {LogicalType::VARCHAR, LogicalType::VARCHAR }, LogicalType::VARCHAR, OpenPromptRequestFunction));
258
+ {LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction));
220
259
ExtensionUtil::RegisterFunction (instance, open_prompt);
221
260
222
261
// Other set_* functions remain the same as before
223
262
ExtensionUtil::RegisterFunction (instance, ScalarFunction (
224
263
" set_api_token" , {LogicalType::VARCHAR}, LogicalType::VARCHAR,
225
- [](DataChunk &args, ExpressionState &state, Vector &result) {
226
- try {
227
- auto token = args.data [0 ].GetValue (0 ).ToString ();
228
- SetApiToken (token);
229
- return StringVector::AddString (result, " API token set successfully." );
230
- } catch (std::exception &e) {
231
- return StringVector::AddString (result, " Failed to set API token: " + std::string (e.what ()));
232
- }
233
- }));
264
+ SetApiToken));
234
265
235
266
ExtensionUtil::RegisterFunction (instance, ScalarFunction (
236
267
" set_api_url" , {LogicalType::VARCHAR}, LogicalType::VARCHAR,
237
- [](DataChunk &args, ExpressionState &state, Vector &result) {
238
- try {
239
- auto new_url = args.data [0 ].GetValue (0 ).ToString ();
240
- SetApiUrl (new_url);
241
- return StringVector::AddString (result, " API URL set successfully." );
242
- } catch (std::exception &e) {
243
- return StringVector::AddString (result, " Failed to set API URL: " + std::string (e.what ()));
244
- }
245
- }));
268
+ SetApiUrl));
246
269
247
270
ExtensionUtil::RegisterFunction (instance, ScalarFunction (
248
- " set_model_name" , {LogicalType::VARCHAR}, LogicalType::VARCHAR,
249
- [](DataChunk &args, ExpressionState &state, Vector &result) {
250
- try {
251
- auto model = args.data [0 ].GetValue (0 ).ToString ();
252
- SetModelName (model);
253
- return StringVector::AddString (result, " Model name set successfully." );
254
- } catch (std::exception &e) {
255
- return StringVector::AddString (result, " Failed to set model name: " + std::string (e.what ()));
256
- }
257
- }));
271
+ " set_model_name" , {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetModelName
272
+ ));
258
273
}
259
274
260
275
0 commit comments