Skip to content

Commit 74526a2

Browse files
committed
add <|eom_id|> to stopTokens
1 parent 935a923 commit 74526a2

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

llama-stack-client-kotlin-client-local/src/main/kotlin/com/llama/llamastack/client/local/InferenceServiceLocalImpl.kt

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ constructor(
2929
private var sequenceLengthKey: String = "seq_len"
3030

3131
override fun onResult(p0: String?) {
32-
if (p0.equals(PromptFormatLocal.getStopToken(modelName))) {
32+
if (PromptFormatLocal.getStopTokens(modelName).any { it == p0 }) {
3333
onResultComplete = true
3434
return
3535
}
@@ -62,8 +62,9 @@ constructor(
6262
PromptFormatLocal.getTotalFormattedPrompt(params.messages(), modelName)
6363

6464
// Developer can pass in their sequence length but if not then it will default to a
65-
// particular dynamic value. This is to ensure enough value is provided to give a reasonably complete response.
66-
// 0.75 is the approximate words per token. And 64 is buffer for tokens for generate response.
65+
// particular dynamic value. This is to ensure enough value is provided to give a reasonably
66+
// complete response. 0.75 is the approximate words per token. And 64 is buffer for tokens
67+
// for generate response.
6768
val seqLength =
6869
params._additionalQueryParams().values(sequenceLengthKey).lastOrNull()?.toInt()
6970
?: ((formattedPrompt.length * 0.75) + 64).toInt()

llama-stack-client-kotlin-client-local/src/main/kotlin/com/llama/llamastack/client/local/util/PromptFormatLocal.kt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,13 @@ object PromptFormatLocal {
4040
}
4141
}
4242

43-
fun getStopToken(modelName: String?): String {
43+
fun getStopTokens(modelName: String?): List<String> {
4444
return when (modelName) {
4545
"LLAMA_3",
4646
"LLAMA_3_1",
4747
"LLAMA_3_2",
48-
"LLAMA_GUARD_3" -> "<|eot_id|>"
49-
else -> ""
48+
"LLAMA_GUARD_3" -> listOf("<|eot_id|>", "<|eom_id|>")
49+
else -> listOf("")
5050
}
5151
}
5252

0 commit comments

Comments
 (0)