From 98566bc9abf96ea9adfa45cc54104f9b9be46375 Mon Sep 17 00:00:00 2001 From: Rainbow_piggy Date: Tue, 11 Aug 2020 11:55:44 +0800 Subject: [PATCH] Update hmm.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1.修改注释中Viterbi的递推公式中的错误(参见统计学习方法(第二版)209页);2.优化decoding获取最大值部分,矩阵加法替代循环,时间复杂度从O(n^2)到O(n)。 --- models/hmm.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/models/hmm.py b/models/hmm.py index 54e7b8a..55e5c73 100644 --- a/models/hmm.py +++ b/models/hmm.py @@ -104,7 +104,7 @@ def decoding(self, word_list, word2id, tag2id): backpointer[:, 0] = -1 # 递推公式: - # viterbi[tag_id, step] = max(viterbi[:, step-1]* self.A.t()[tag_id] * Bt[word]) + # viterbi[tag_id, step] = max(viterbi[:, step-1]* self.A.t()[tag_id]) * Bt[wordid] # 这里取最大值的范围应该是viterbi * A.t()的,不包括Bt[wordid] # 其中word是step时刻对应的字 # 由上述递推公式求后续各步 for step in range(1, seq_len): @@ -116,13 +116,17 @@ def decoding(self, word_list, word2id, tag2id): bt = torch.log(torch.ones(self.N) / self.N) else: bt = Bt[wordid] # 否则从观测概率矩阵中取bt - for tag_id in range(len(tag2id)): - max_prob, max_id = torch.max( - viterbi[:, step-1] + A[:, tag_id], - dim=0 - ) - viterbi[tag_id, step] = max_prob + bt[tag_id] - backpointer[tag_id, step] = max_id +# for tag_id in range(len(tag2id)): +# max_prob, max_id = torch.max( +# viterbi[:, step-1] + A[:, tag_id], +# dim=0 +# ) +# viterbi[tag_id, step] = max_prob + bt[tag_id] +# backpointer[tag_id, step] = max_id + # 修改为直接使用矩阵加法运算,使程序时间复杂度从O(n^2)到O(n) + max_probs, max_ids = torch.max(viterbi[:, step-1] + A.t(), dim=1) + viterbi[:, step] = max_probs + Bt[wordid] + backpointer[:, step] = max_ids # 终止, t=seq_len 即 viterbi[:, seq_len]中的最大概率,就是最优路径的概率 best_path_prob, best_path_pointer = torch.max(