Skip to content

Commit b09c776

Browse files
committed
bpe learner refactor
1 parent f39f517 commit b09c776

File tree

1 file changed

+77
-47
lines changed

1 file changed

+77
-47
lines changed

python_autocomplete/bpe.py

Lines changed: 77 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -8,24 +8,51 @@
88

99
class BPE:
1010
def __init__(self):
11-
path = lab.get_data_path() / 'train.py'
11+
self.char_itos = []
12+
self.char_stoi = {}
13+
self.bpe_itos = []
14+
self.bpe = []
15+
self.common = {}
16+
17+
self.bpe_itos = self.calc_bpe_itos()
18+
19+
def to_char_stoi(self, w: str):
20+
return [self.char_stoi[c] for c in w]
21+
22+
def calc_bpe_itos(self):
23+
itos = list(self.char_itos)
24+
itos += [itos[p1] + itos[p2] for p1, p2 in self.bpe[len(self.char_itos):]]
25+
return itos
1226

13-
with open(str(path), 'r') as f:
14-
self.data = f.read() # [:100_000]
1527

28+
class BPELearner:
29+
def __init__(self, data: str):
30+
self.data = data
1631
self.words = {}
1732
self.heap = []
1833
self.heap_modified = set()
19-
self.itos = []
20-
self.vocab = {}
34+
self.char_itos = []
35+
self.char_stoi = {}
2136
self.bpe = []
22-
self.word_codes = {}
37+
self.word_codes = []
2338
self.word_code_prev = {}
2439
self.word_code_next = {}
2540

2641
self.counts = {}
2742
self.locations = {}
2843

44+
self.collect_words()
45+
self.build_vocab()
46+
self.build_word_arrays()
47+
self.collect_pairs()
48+
49+
def learn(self, merges: int):
50+
for i in monit.iterate('BPE', merges):
51+
while True:
52+
res = self.merge_pair()
53+
if res is not None:
54+
break
55+
2956
def add_word(self, word):
3057
if not word:
3158
return
@@ -52,32 +79,38 @@ def collect_words(self):
5279
is_id = False
5380

5481
self.add_word(self.data[last_idx:])
82+
words_list = [(f, w) for w, f in self.words.items()]
83+
words_list.sort(key=lambda x: -x[0])
84+
85+
self.words_list = [w for _, w in words_list]
86+
self.word_freq = [f for f, _ in words_list]
5587

5688
def build_vocab(self):
5789
vocab = set()
58-
for k in self.words:
90+
for k in self.words_list:
5991
for c in k:
6092
vocab.add(c)
6193

62-
self.itos = list(sorted(vocab))
63-
self.vocab = {c: i for i, c in enumerate(self.itos)}
94+
self.char_itos = list(sorted(vocab))
95+
self.char_stoi = {c: i for i, c in enumerate(self.char_itos)}
6496

65-
self.bpe = [i for i in range(len(self.vocab))]
97+
self.bpe = [i for i in range(len(self.char_stoi))]
6698

67-
def build_word_arrays(self):
68-
words = {}
69-
for k in self.words:
70-
a = []
71-
for c in k:
72-
a.append(self.vocab[c])
73-
words[k] = a
99+
def to_char_stoi(self, w: str):
100+
return [self.char_stoi[c] for c in w]
101+
102+
@staticmethod
103+
def default_next_pointers(length: int):
104+
return [i + 1 for i in range(length - 1)] + [-1]
74105

75-
self.word_codes = words
106+
@staticmethod
107+
def default_prev_pointers(length: int):
108+
return [i - 1 for i in range(length)]
76109

77-
for k, v in self.word_codes.items():
78-
self.word_code_next[k] = [i + 1 for i in range(len(v))]
79-
self.word_code_prev[k] = [i - 1 for i in range(len(v))]
80-
self.word_code_next[k][-1] = -1
110+
def build_word_arrays(self):
111+
self.word_codes = [self.to_char_stoi(w) for w in self.words_list]
112+
self.word_code_next = [self.default_next_pointers(len(w)) for w in self.word_codes]
113+
self.word_code_prev = [self.default_prev_pointers(len(w)) for w in self.word_codes]
81114

82115
def heap_add_all(self):
83116
for pair in self.heap_modified:
@@ -95,14 +128,14 @@ def add_pair(self, w, i, nxt):
95128
if w not in self.locations[pair]:
96129
self.locations[pair][w] = set()
97130

98-
self.counts[pair] += self.words[w]
131+
self.counts[pair] += self.word_freq[w]
99132
self.locations[pair][w].add(i)
100133

101134
self.heap_modified.add(pair)
102135

103136
def collect_pairs(self):
104-
for w, v in monit.iterate('Collect pairs', self.word_codes.items()):
105-
f = self.words[w]
137+
for w, v in monit.enum('Collect pairs', self.word_codes):
138+
f = self.word_freq[w]
106139

107140
for i in range(len(v) - 1):
108141
self.add_pair(w, i, i + 1)
@@ -114,12 +147,8 @@ def remove_pair(self, w, i, nxt):
114147
assert pair[0] != -1 and pair[1] != -1
115148
if pair not in self.counts:
116149
return
117-
try:
118-
self.locations[pair][w].remove(i)
119-
except:
120-
print(pair, f"|{w}|", i)
121-
raise
122-
self.counts[pair] -= self.words[w]
150+
self.locations[pair][w].remove(i)
151+
self.counts[pair] -= self.word_freq[w]
123152
self.heap_modified.add(pair)
124153

125154
def merge_pair(self):
@@ -177,35 +206,36 @@ def merge_pair(self):
177206
return pair
178207

179208
def bpe_itos(self):
180-
itos = list(self.itos)
181-
for p1, p2 in self.bpe[len(self.itos):]:
209+
itos = list(self.char_itos)
210+
for p1, p2 in self.bpe[len(self.char_itos):]:
182211
itos.append(itos[p1] + itos[p2])
183212

184213
return itos
185214

186215
def get_length(self):
187216
res = 0
188-
for w, v in self.word_codes:
217+
for w, v in enumerate(self.word_codes):
189218
cnt = 0
190219
for idx in v:
191220
if idx != -1:
192221
cnt += 1
193-
res += cnt * self.words[w]
222+
res += cnt * self.word_freq[w]
194223

195224
return res
196225

197226

198-
if __name__ == '__main__':
199-
bpe = BPE()
200-
bpe.collect_words()
201-
bpe.build_vocab()
202-
bpe.build_word_arrays()
203-
bpe.collect_pairs()
204-
for i in monit.iterate('BPE', 1_000):
205-
while True:
206-
res = bpe.merge_pair()
207-
if res is not None:
208-
break
227+
def main():
228+
path = lab.get_data_path() / 'train.py'
229+
230+
with open(str(path), 'r') as f:
231+
data = f.read()[:100_000]
232+
233+
bpe = BPELearner(data)
234+
bpe.learn(1000)
209235
print(len(bpe.bpe))
210-
print(bpe.bpe_itos()[len(bpe.itos):])
236+
print(bpe.bpe_itos()[len(bpe.char_itos):])
211237
print(len(bpe.data), bpe.get_length())
238+
239+
240+
if __name__ == '__main__':
241+
main()

0 commit comments

Comments
 (0)