Skip to content

Commit f39f517

Browse files
committed
bpe
1 parent c1e535d commit f39f517

File tree

1 file changed

+211
-0
lines changed

1 file changed

+211
-0
lines changed

python_autocomplete/bpe.py

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
import string
2+
from heapq import heappush, heappop
3+
4+
from labml import lab, monit
5+
6+
ID_CHARS = set(string.ascii_letters + string.digits + '_')
7+
8+
9+
class BPE:
10+
def __init__(self):
11+
path = lab.get_data_path() / 'train.py'
12+
13+
with open(str(path), 'r') as f:
14+
self.data = f.read() # [:100_000]
15+
16+
self.words = {}
17+
self.heap = []
18+
self.heap_modified = set()
19+
self.itos = []
20+
self.vocab = {}
21+
self.bpe = []
22+
self.word_codes = {}
23+
self.word_code_prev = {}
24+
self.word_code_next = {}
25+
26+
self.counts = {}
27+
self.locations = {}
28+
29+
def add_word(self, word):
30+
if not word:
31+
return
32+
33+
if word not in self.words:
34+
self.words[word] = 1
35+
else:
36+
self.words[word] += 1
37+
38+
def collect_words(self):
39+
last_idx = 0
40+
is_id = False
41+
42+
for i, c in monit.enum('Collect words', self.data):
43+
if c in ID_CHARS:
44+
if not is_id:
45+
self.add_word(self.data[last_idx:i])
46+
last_idx = i
47+
is_id = True
48+
else:
49+
if is_id:
50+
self.add_word(self.data[last_idx:i])
51+
last_idx = i
52+
is_id = False
53+
54+
self.add_word(self.data[last_idx:])
55+
56+
def build_vocab(self):
57+
vocab = set()
58+
for k in self.words:
59+
for c in k:
60+
vocab.add(c)
61+
62+
self.itos = list(sorted(vocab))
63+
self.vocab = {c: i for i, c in enumerate(self.itos)}
64+
65+
self.bpe = [i for i in range(len(self.vocab))]
66+
67+
def build_word_arrays(self):
68+
words = {}
69+
for k in self.words:
70+
a = []
71+
for c in k:
72+
a.append(self.vocab[c])
73+
words[k] = a
74+
75+
self.word_codes = words
76+
77+
for k, v in self.word_codes.items():
78+
self.word_code_next[k] = [i + 1 for i in range(len(v))]
79+
self.word_code_prev[k] = [i - 1 for i in range(len(v))]
80+
self.word_code_next[k][-1] = -1
81+
82+
def heap_add_all(self):
83+
for pair in self.heap_modified:
84+
if pair in self.counts:
85+
heappush(self.heap, (-self.counts[pair], pair))
86+
87+
def add_pair(self, w, i, nxt):
88+
pair = self.word_codes[w][i], self.word_codes[w][nxt]
89+
assert pair[0] != -1 and pair[1] != -1
90+
91+
if pair not in self.counts:
92+
self.counts[pair] = 0
93+
self.locations[pair] = {}
94+
95+
if w not in self.locations[pair]:
96+
self.locations[pair][w] = set()
97+
98+
self.counts[pair] += self.words[w]
99+
self.locations[pair][w].add(i)
100+
101+
self.heap_modified.add(pair)
102+
103+
def collect_pairs(self):
104+
for w, v in monit.iterate('Collect pairs', self.word_codes.items()):
105+
f = self.words[w]
106+
107+
for i in range(len(v) - 1):
108+
self.add_pair(w, i, i + 1)
109+
110+
self.heap_add_all()
111+
112+
def remove_pair(self, w, i, nxt):
113+
pair = self.word_codes[w][i], self.word_codes[w][nxt]
114+
assert pair[0] != -1 and pair[1] != -1
115+
if pair not in self.counts:
116+
return
117+
try:
118+
self.locations[pair][w].remove(i)
119+
except:
120+
print(pair, f"|{w}|", i)
121+
raise
122+
self.counts[pair] -= self.words[w]
123+
self.heap_modified.add(pair)
124+
125+
def merge_pair(self):
126+
cnt, pair = heappop(self.heap)
127+
if pair not in self.counts or self.counts[pair] != -cnt:
128+
return None
129+
130+
n = len(self.bpe)
131+
self.bpe.append(pair)
132+
del self.counts[pair]
133+
for w, locs in self.locations[pair].items():
134+
locs = list(reversed(sorted(locs)))
135+
prev = None
136+
merged = []
137+
for p2 in locs:
138+
p1 = self.word_code_prev[w][p2]
139+
p3 = self.word_code_next[w][p2]
140+
assert p3 != -1
141+
if p3 == prev:
142+
continue
143+
p4 = self.word_code_next[w][p3]
144+
145+
if p1 != -1:
146+
self.remove_pair(w, p1, p2)
147+
if p4 != -1 and p4 != prev:
148+
self.remove_pair(w, p3, p4)
149+
150+
prev = p2
151+
merged.append(p2)
152+
153+
for p2 in merged:
154+
p3 = self.word_code_next[w][p2]
155+
p4 = self.word_code_next[w][p3]
156+
self.word_codes[w][p2] = n
157+
self.word_codes[w][p3] = -1
158+
self.word_code_next[w][p3] = -1
159+
self.word_code_prev[w][p3] = -1
160+
if p4 != -1:
161+
self.word_code_next[w][p2] = p4
162+
self.word_code_prev[w][p4] = p2
163+
else:
164+
self.word_code_next[w][p2] = -1
165+
166+
for p2 in merged:
167+
p1 = self.word_code_prev[w][p2]
168+
p3 = self.word_code_next[w][p2]
169+
170+
if p1 != -1:
171+
self.add_pair(w, p1, p2)
172+
if p3 != -1:
173+
self.add_pair(w, p2, p3)
174+
175+
self.heap_add_all()
176+
177+
return pair
178+
179+
def bpe_itos(self):
180+
itos = list(self.itos)
181+
for p1, p2 in self.bpe[len(self.itos):]:
182+
itos.append(itos[p1] + itos[p2])
183+
184+
return itos
185+
186+
def get_length(self):
187+
res = 0
188+
for w, v in self.word_codes:
189+
cnt = 0
190+
for idx in v:
191+
if idx != -1:
192+
cnt += 1
193+
res += cnt * self.words[w]
194+
195+
return res
196+
197+
198+
if __name__ == '__main__':
199+
bpe = BPE()
200+
bpe.collect_words()
201+
bpe.build_vocab()
202+
bpe.build_word_arrays()
203+
bpe.collect_pairs()
204+
for i in monit.iterate('BPE', 1_000):
205+
while True:
206+
res = bpe.merge_pair()
207+
if res is not None:
208+
break
209+
print(len(bpe.bpe))
210+
print(bpe.bpe_itos()[len(bpe.itos):])
211+
print(len(bpe.data), bpe.get_length())

0 commit comments

Comments
 (0)