@@ -39,12 +39,12 @@ static std::ostringstream * g_output_ss;
 static std::vector<llama_token> * g_output_tokens;
 static bool is_interacting = false;
 
-static bool file_exists(const std::string &path) {
+static bool file_exists(const std::string & path) {
     std::ifstream f(path.c_str());
     return f.good();
 }
 
-static bool file_is_empty(const std::string &path) {
+static bool file_is_empty(const std::string & path) {
     std::ifstream f;
     f.exceptions(std::ifstream::failbit | std::ifstream::badbit);
     f.open(path.c_str(), std::ios::in | std::ios::binary | std::ios::ate);
@@ -117,6 +117,14 @@ static void llama_log_callback_logTee(ggml_log_level level, const char * text, v
     LOG_TEE("%s", text);
 }
 
+static std::string chat_add_and_format(struct llama_model * model, std::vector<llama_chat_msg> & chat_msgs, std::string role, std::string content) {
+    llama_chat_msg new_msg{role, content};
+    auto formatted = llama_chat_format_single(
+        model, g_params->chat_template, chat_msgs, new_msg, role == "user");
+    chat_msgs.push_back({role, content});
+    return formatted;
+}
+
 int main(int argc, char ** argv) {
     gpt_params params;
     g_params = &params;
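For context on the helper above: it relies on `llama_chat_format_single` returning only the newly appended part of the formatted conversation. The sketch below is not part of the diff; it assumes the `llama_chat_msg` / `llama_chat_format_single` helpers from `common.h` and that a built-in template name such as `"chatml"` can be resolved without a loaded model.

```cpp
// Illustrative sketch only: exercises the incremental formatting idea behind
// chat_add_and_format(). Assumes common.h from this change and that a nullptr
// model is acceptable when an explicit template name is supplied.
#include "common.h"

#include <cstdio>
#include <string>
#include <vector>

int main() {
    std::vector<llama_chat_msg> chat_msgs;
    const std::string tmpl = "chatml"; // assumed built-in template name

    // Format the system prompt; add_ass=false because no assistant turn follows yet.
    llama_chat_msg sys_msg{"system", "You are a helpful assistant."};
    std::string sys_part = llama_chat_format_single(nullptr, tmpl, chat_msgs, sys_msg, false);
    chat_msgs.push_back(sys_msg);

    // Format a user turn; add_ass=true appends the assistant prefix so
    // generation can start right after it.
    llama_chat_msg usr_msg{"user", "Hello!"};
    std::string usr_part = llama_chat_format_single(nullptr, tmpl, chat_msgs, usr_msg, true);
    chat_msgs.push_back(usr_msg);

    // Each call returns only the delta for the new message, so the pieces can
    // be tokenized and fed to the context incrementally.
    printf("%s%s", sys_part.c_str(), usr_part.c_str());
    return 0;
}
```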
@@ -190,6 +198,7 @@ int main(int argc, char ** argv) {
     llama_model * model;
     llama_context * ctx;
     llama_context * ctx_guidance = NULL;
+    std::vector<llama_chat_msg> chat_msgs;
     g_model = &model;
     g_ctx = &ctx;
@@ -215,6 +224,8 @@ int main(int argc, char ** argv) {
                 __func__, n_ctx_train, n_ctx);
     }
 
+    LOG_TEE("%s: chat template example: %s\n", __func__, llama_chat_format_example(model, params.chat_template).c_str());
+
     // print system information
     {
         LOG_TEE("\n");
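The example string logged here can also be previewed outside of `main`. The snippet below is an assumption-laden sketch, not part of the diff; it presumes `llama_chat_format_example` accepts a null model when a named template is given.

```cpp
// Hypothetical standalone preview of a chat template, assuming the
// llama_chat_format_example helper used above and that a named template
// ("chatml") does not require a loaded model to be resolved.
#include "common.h"

#include <cstdio>

int main() {
    printf("%s\n", llama_chat_format_example(nullptr, "chatml").c_str());
    return 0;
}
```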
@@ -249,16 +260,21 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_token> embd_inp;
 
-    if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
-        LOG("tokenize the prompt\n");
-        embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);
-    } else {
-        LOG("use session tokens\n");
-        embd_inp = session_tokens;
-    }
+    {
+        auto prompt = params.conversation
+            ? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode
+            : params.prompt;
+        if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
+            LOG("tokenize the prompt\n");
+            embd_inp = ::llama_tokenize(ctx, prompt, true, true);
+        } else {
+            LOG("use session tokens\n");
+            embd_inp = session_tokens;
+        }
 
-    LOG("prompt: \"%s\"\n", log_tostr(params.prompt));
-    LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
+        LOG("prompt: \"%s\"\n", log_tostr(prompt));
+        LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
+    }
 
     // Should not run without any tokens
     if (embd_inp.empty()) {
@@ -478,6 +494,7 @@ int main(int argc, char ** argv) {
     std::vector<int> input_tokens;  g_input_tokens  = &input_tokens;
     std::vector<int> output_tokens; g_output_tokens = &output_tokens;
     std::ostringstream output_ss;   g_output_ss     = &output_ss;
+    std::ostringstream assistant_ss; // for storing current assistant message, used in conversation mode
 
     // the first thing we will do is to output the prompt, so set color accordingly
     console::set_display(console::prompt);
@@ -793,11 +810,18 @@ int main(int argc, char ** argv) {
                     is_antiprompt = true;
                 }
 
+                chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str());
                 is_interacting = true;
                 printf("\n");
             }
         }
 
+        // if current token is not EOG, we add it to current assistant message
+        if (params.conversation) {
+            auto id = llama_sampling_last(ctx_sampling);
+            assistant_ss << llama_token_to_piece(ctx, id, false);
+        }
+
         if (n_past > 0 && is_interacting) {
             LOG("waiting for user input\n");
 
@@ -848,8 +872,12 @@ int main(int argc, char ** argv) {
                         string_process_escapes(buffer);
                     }
 
+                    std::string user_inp = params.conversation
+                        ? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
+                        : std::move(buffer);
+                    // TODO: one inconvenience of the current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
                     const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
-                    const auto line_inp = ::llama_tokenize(ctx, buffer,   false, false);
+                    const auto line_inp = ::llama_tokenize(ctx, user_inp, false, params.conversation);
                     const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true);
 
                     LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
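The fourth argument of `::llama_tokenize` (`parse_special`) is tied to conversation mode because template-formatted user input contains special tokens such as the chat prefix/suffix markers. The sketch below is not from the diff; it assumes a ChatML-style GGUF model whose path is passed in `argv[1]` and uses the common-library tokenize helper to show the difference.

```cpp
// Illustrative sketch only: shows why parse_special must be true for
// template-formatted input. Assumes a model whose vocabulary defines
// <|im_start|>/<|im_end|> as special tokens, and a model path in argv[1].
#include "common.h"
#include "llama.h"

#include <cstdio>
#include <string>

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s <model.gguf>\n", argv[0]);
        return 1;
    }

    llama_backend_init();
    llama_model * model = llama_load_model_from_file(argv[1], llama_model_default_params());
    llama_context * ctx = llama_new_context_with_model(model, llama_context_default_params());

    const std::string text = "<|im_start|>user\nHello<|im_end|>\n";

    // parse_special=true: the markers map to their single special-token ids
    const auto as_special = ::llama_tokenize(ctx, text, false, true);
    // parse_special=false: the markers are tokenized as ordinary text
    const auto as_text    = ::llama_tokenize(ctx, text, false, false);

    printf("parse_special=true  -> %zu tokens\n", as_special.size());
    printf("parse_special=false -> %zu tokens\n", as_text.size());

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}
```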
@@ -864,6 +892,9 @@ int main(int argc, char ** argv) {
                     output_ss << llama_token_to_piece(ctx, token);
                 }
 
+                // reset assistant message
+                assistant_ss.str("");
+
                 n_remain -= line_inp.size();
                 LOG("n_remain: %d\n", n_remain);
             } else {