@@ -260,23 +260,44 @@ static size_t validate_utf8(const std::string& text) {
260
260
// template utils
261
261
//
262
262
263
- // format rerank task: [BOS]query[EOS][SEP]doc[EOS]
264
- static llama_tokens format_rerank (const struct llama_vocab * vocab, const llama_tokens & query, const llama_tokens & doc) {
263
+ // format rerank task:
264
+ // - using SEP token: [BOS]query[EOS][SEP]doc[EOS]
265
+ // - using prompt: <rerank_prefix>query<rerank_suffix>doc
266
+ static llama_tokens format_rerank (const struct llama_model * model, const llama_tokens & query, const llama_tokens & doc) {
267
+ const llama_vocab * vocab = llama_model_get_vocab (model);
265
268
llama_tokens result;
266
269
267
- // Get EOS token - use SEP token as fallback if EOS is not available
268
- llama_token eos_token = llama_vocab_eos (vocab);
269
- if (eos_token == LLAMA_TOKEN_NULL) {
270
- eos_token = llama_vocab_sep (vocab);
271
- }
270
+ if (llama_vocab_sep (vocab) != LLAMA_TOKEN_NULL) {
271
+ // Get EOS token - use SEP token as fallback if EOS is not available
272
+ llama_token eos_token = llama_vocab_eos (vocab);
273
+ if (eos_token == LLAMA_TOKEN_NULL) {
274
+ eos_token = llama_vocab_sep (vocab);
275
+ }
276
+
277
+ result.reserve (doc.size () + query.size () + 4 );
278
+ result.push_back (llama_vocab_bos (vocab));
279
+ result.insert (result.end (), query.begin (), query.end ());
280
+ result.push_back (eos_token);
281
+ result.push_back (llama_vocab_sep (vocab));
282
+ result.insert (result.end (), doc.begin (), doc.end ());
283
+ result.push_back (eos_token);
284
+ } else {
285
+ // using prompt template
286
+ const char * prefix = llama_model_chat_template (model, " rerank_prefix" );
287
+ const char * suffix = llama_model_chat_template (model, " rerank_suffix" );
288
+
289
+ if (prefix == NULL && suffix == NULL ) {
290
+ throw std::runtime_error (" Rerank prompt template not found in the model\n " );
291
+ }
272
292
273
- result.reserve (doc.size () + query.size () + 4 );
274
- result.push_back (llama_vocab_bos (vocab));
275
- result.insert (result.end (), query.begin (), query.end ());
276
- result.push_back (eos_token);
277
- result.push_back (llama_vocab_sep (vocab));
278
- result.insert (result.end (), doc.begin (), doc.end ());
279
- result.push_back (eos_token);
293
+ const llama_tokens prefix_tokens = prefix ? common_tokenize (vocab, prefix, true , false ) : llama_tokens ();
294
+ const llama_tokens suffix_tokens = suffix ? common_tokenize (vocab, suffix, false , false ) : llama_tokens ();
295
+ result.reserve (prefix_tokens.size () + query.size () + suffix_tokens.size () + doc.size ());
296
+ result.insert (result.end (), prefix_tokens.begin (), prefix_tokens.end ());
297
+ result.insert (result.end (), query.begin (), query.end ());
298
+ result.insert (result.end (), suffix_tokens.begin (), suffix_tokens.end ());
299
+ result.insert (result.end (), doc.begin (), doc.end ());
300
+ }
280
301
281
302
return result;
282
303
}
0 commit comments