From 793120837f70eec9417a76f44ab4eed2c3af24c5 Mon Sep 17 00:00:00 2001
From: Aria Haghighi
Date: Wed, 29 Dec 2021 16:10:15 -0800
Subject: [PATCH] truncate decode_forward pass to skip spans longer than any word in SentencePieceModel vocab

---
 src/statistical/unigram.jl | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/src/statistical/unigram.jl b/src/statistical/unigram.jl
index 2901b1d..47eb576 100644
--- a/src/statistical/unigram.jl
+++ b/src/statistical/unigram.jl
@@ -8,6 +8,14 @@ structure, To hold unknown token index and map of vocabulary to log probability
 struct SentencePieceModel
     vocab_map::Dict{String, Tuple{Float64, Int}}
     unk_id::Int
+    # used to make decoding more efficient
+    max_vocab_codeunit_len::Int
+
+    function SentencePieceModel(vocab_map, unk_id)
+        num_codeunits(idx) = isassigned(vocab_map.keys, idx) ? ncodeunits(vocab_map.keys[idx]) : 0
+        max_vocab_codeunit_len, _ = findmax(num_codeunits.(eachindex(vocab_map.keys)))
+        new(vocab_map, unk_id, max_vocab_codeunit_len)
+    end
 end
 
 """
@@ -97,14 +105,16 @@ function decode_forward(sp::SentencePieceModel, text::String)
     scores = fill(-Inf, lastindex(text))
     scores[1] = 0
     for char_end in eachindex(text)
-        for char_start in eachindex(text)
-            char_start > char_end && break
+        min_start = max(firstindex(text), char_end - sp.max_vocab_codeunit_len + 1)
+        candidate_text_substr = SubString(text, thisind(text, min_start), char_end)
+        for relative_idx in eachindex(candidate_text_substr)
+            char_start = candidate_text_substr.offset + relative_idx
             subtoken = SubString(text, char_start:char_end)
             if haskey(sp.vocab_map, subtoken)
                 subtokenid = sp.vocab_map[subtoken][2]
                 local_score = scores[char_start] + sp.vocab_map[subtoken][1]
                 if local_score > scores[char_end]
-                    results[char_end] = Nodes(SubString(text, char_start:char_end), local_score, subtokenid, char_start, char_end)
+                    results[char_end] = Nodes(subtoken, local_score, subtokenid, char_start, char_end)
                     scores[char_end] = local_score
                 end
             end
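
The change above bounds the inner loop of the Viterbi-style forward pass by the length of the longest vocabulary entry, measured in codeunits: any span longer than that can never be a key of vocab_map, so scoring it is wasted work. Below is a minimal, self-contained sketch of that idea against a toy vocabulary; it is not the package's decode_forward, the names toy_vocab, segment, and max_len are illustrative only, and it assumes the vocabulary fully covers the input (no unknown-token handling).

# A sketch of the lookback-window pruning, assuming a toy vocabulary of
# piece => log-probability pairs (all names here are hypothetical).
toy_vocab = Dict("▁he" => -1.2, "llo" => -1.5, "▁hello" => -2.0,
                 "l" => -3.0, "o" => -3.0)

max_len = maximum(ncodeunits, keys(toy_vocab))   # longest key, in codeunits (8 here)

function segment(text::String, vocab::Dict{String,Float64}, max_len::Int)
    n = ncodeunits(text)
    scores = fill(-Inf, n + 1); scores[1] = 0.0   # scores[i+1]: best log-prob covering bytes 1:i
    prev = zeros(Int, n + 1)                      # backpointer: start byte of the winning piece
    for last_char in eachindex(text)              # byte index of each character start
        stop = nextind(text, last_char) - 1       # last byte of that character
        # Pruning step, as in the patch: only look back max_len codeunits;
        # thisind snaps the byte position onto a character boundary.
        start = thisind(text, max(firstindex(text), stop - max_len + 1))
        while start <= last_char
            piece = SubString(text, start, last_char)
            if haskey(vocab, piece)
                cand = scores[start] + vocab[piece]
                if cand > scores[stop + 1]
                    scores[stop + 1] = cand
                    prev[stop + 1] = start
                end
            end
            start = nextind(text, start)
        end
    end
    pieces = String[]                             # walk the backpointers to recover the pieces
    i = n + 1
    while i > 1
        s = prev[i]
        pushfirst!(pieces, String(SubString(text, s, thisind(text, i - 1))))
        i = s
    end
    return pieces
end

segment("▁hello", toy_vocab, max_len)   # -> ["▁hello"] with these toy scores

With these toy scores the single piece "▁hello" (-2.0) beats "▁he" + "llo" (-2.7), and the lookback window keeps the work per position proportional to the longest vocabulary entry rather than to the length of the text.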