From b1b628bb738f7bcb2ac3400efa2cf55c325276a1 Mon Sep 17 00:00:00 2001 From: zy Date: Wed, 27 Nov 2019 23:44:40 +0800 Subject: [PATCH] 'update' --- input_data.py | 6 +++--- model.py | 2 +- word2vec.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/input_data.py b/input_data.py index bb47ed4..a029ea7 100644 --- a/input_data.py +++ b/input_data.py @@ -23,7 +23,7 @@ def __init__(self, file_name, min_count): print('Sentence Length: %d' % (self.sentence_length)) def get_words(self, min_count): - self.input_file = open(self.input_file_name) + self.input_file = open(self.input_file_name, 'r', encoding='utf-8', errors='ignore') self.sentence_length = 0 self.sentence_count = 0 word_frequency = dict() @@ -66,7 +66,7 @@ def get_batch_pairs(self, batch_size, window_size): while len(self.word_pair_catch) < batch_size: sentence = self.input_file.readline() if sentence is None or sentence == '': - self.input_file = open(self.input_file_name) + self.input_file = open(self.input_file_name, 'r', encoding='utf-8', errors='ignore') sentence = self.input_file.readline() word_ids = [] for word in sentence.strip().split(' '): @@ -79,7 +79,7 @@ def get_batch_pairs(self, batch_size, window_size): word_ids[max(i - window_size, 0):i + window_size]): assert u < self.word_count assert v < self.word_count - if i == j: + if u == v: continue self.word_pair_catch.append((u, v)) batch_pairs = [] diff --git a/model.py b/model.py index 8f8bcc1..fcde0db 100644 --- a/model.py +++ b/model.py @@ -85,7 +85,7 @@ def save_embedding(self, id2word, file_name, use_cuda): embedding = self.u_embeddings.weight.cpu().data.numpy() else: embedding = self.u_embeddings.weight.data.numpy() - fout = open(file_name, 'w') + fout = open(file_name, 'w',encoding='utf-8',errors='ignore') fout.write('%d %d\n' % (len(id2word), self.emb_dimension)) for wid, w in id2word.items(): e = embedding[wid] diff --git a/word2vec.py b/word2vec.py index 6b844d9..968239e 100644 --- a/word2vec.py +++ b/word2vec.py @@ -81,7 +81,7 @@ def train(self): self.optimizer.step() process_bar.set_description("Loss: %0.8f, lr: %0.6f" % - (loss.data[0], + (loss, self.optimizer.param_groups[0]['lr'])) if i * self.batch_size % 100000 == 0: lr = self.initial_lr * (1.0 - 1.0 * i / batch_count)