8
8
9
9
class BPE:
    """Container for a character vocabulary plus a table of BPE merges.

    Attributes:
        char_itos: list mapping character id -> character.
        char_stoi: dict mapping character -> character id.
        bpe: merge table; entries past ``len(char_itos)`` are ``(p1, p2)``
            pairs of token ids that were merged into a new token.
        bpe_itos: list mapping token id -> token string, derived from
            ``char_itos`` and ``bpe``.
        common: starts as an empty dict; its use is not visible in this
            class (NOTE(review): confirm intended contents).
    """

    def __init__(self):
        self.char_itos = []
        self.char_stoi = {}
        self.bpe = []
        self.common = {}

        # Derived from the (currently empty) tables above.
        self.bpe_itos = self.calc_bpe_itos()

    def to_char_stoi(self, w: str):
        """Return the list of character ids for the characters of ``w``.

        Raises:
            KeyError: if a character of ``w`` is not in ``char_stoi``.
        """
        return [self.char_stoi[c] for c in w]

    def calc_bpe_itos(self):
        """Return the token strings for every character and every merge.

        The list starts as a copy of ``char_itos``; each merge ``(p1, p2)``
        appends ``itos[p1] + itos[p2]``.
        """
        itos = list(self.char_itos)
        # Append one merge at a time: a merge may refer to a token created by
        # an earlier merge, so that token must already be in ``itos``.
        for p1, p2 in self.bpe[len(self.char_itos):]:
            itos.append(itos[p1] + itos[p2])
        return itos
12
26
13
- with open (str (path ), 'r' ) as f :
14
- self .data = f .read () # [:100_000]
15
27
28
+ class BPELearner :
29
def __init__(self, data: str):
    """Set up the learner state for ``data`` and run the preparation steps.

    Args:
        data: the text to learn merges from.
    """
    self.data = data
    self.words = {}
    self.heap = []
    self.heap_modified = set()
    self.char_itos = []
    self.char_stoi = {}
    self.bpe = []
    # Per-word arrays, indexed by word position; all three are replaced by
    # lists in build_word_arrays(), so they start out as lists too.
    self.word_codes = []
    self.word_code_prev = []
    self.word_code_next = []

    self.counts = {}
    self.locations = {}

    # Order matters: build_vocab() reads the word list produced by
    # collect_words(), and build_word_arrays() reads the vocabulary.
    self.collect_words()
    self.build_vocab()
    self.build_word_arrays()
    self.collect_pairs()
48
+
49
def learn(self, merges: int):
    """Run ``merges`` merge steps, reporting progress through ``monit``.

    Each step calls ``merge_pair()`` repeatedly until it returns something
    other than ``None``.
    """
    for _ in monit.iterate('BPE', merges):
        # A None result means this attempt produced no merge; try again.
        while self.merge_pair() is None:
            pass
55
+
29
56
def add_word (self , word ):
30
57
if not word :
31
58
return
@@ -52,32 +79,38 @@ def collect_words(self):
52
79
is_id = False
53
80
54
81
self .add_word (self .data [last_idx :])
82
+ words_list = [(f , w ) for w , f in self .words .items ()]
83
+ words_list .sort (key = lambda x : - x [0 ])
84
+
85
+ self .words_list = [w for _ , w in words_list ]
86
+ self .word_freq = [f for f , _ in words_list ]
55
87
56
88
def build_vocab (self ):
57
89
vocab = set ()
58
- for k in self .words :
90
+ for k in self .words_list :
59
91
for c in k :
60
92
vocab .add (c )
61
93
62
- self .itos = list (sorted (vocab ))
63
- self .vocab = {c : i for i , c in enumerate (self .itos )}
94
+ self .char_itos = list (sorted (vocab ))
95
+ self .char_stoi = {c : i for i , c in enumerate (self .char_itos )}
64
96
65
- self .bpe = [i for i in range (len (self .vocab ))]
97
+ self .bpe = [i for i in range (len (self .char_stoi ))]
66
98
67
- def build_word_arrays (self ):
68
- words = {}
69
- for k in self .words :
70
- a = []
71
- for c in k :
72
- a .append (self .vocab [c ])
73
- words [k ] = a
99
+ def to_char_stoi (self , w : str ):
100
+ return [self .char_stoi [c ] for c in w ]
101
+
102
+ @staticmethod
103
+ def default_next_pointers (length : int ):
104
+ return [i + 1 for i in range (length - 1 )] + [- 1 ]
74
105
75
- self .word_codes = words
106
+ @staticmethod
107
+ def default_prev_pointers (length : int ):
108
+ return [i - 1 for i in range (length )]
76
109
77
- for k , v in self . word_codes . items ( ):
78
- self .word_code_next [ k ] = [i + 1 for i in range ( len ( v )) ]
79
- self .word_code_prev [ k ] = [i - 1 for i in range ( len ( v )) ]
80
- self .word_code_next [ k ][ - 1 ] = - 1
110
+ def build_word_arrays ( self ):
111
+ self .word_codes = [self . to_char_stoi ( w ) for w in self . words_list ]
112
+ self .word_code_next = [self . default_next_pointers ( len ( w )) for w in self . word_codes ]
113
+ self . word_code_prev = [ self . default_prev_pointers ( len ( w )) for w in self .word_codes ]
81
114
82
115
def heap_add_all (self ):
83
116
for pair in self .heap_modified :
@@ -95,14 +128,14 @@ def add_pair(self, w, i, nxt):
95
128
if w not in self .locations [pair ]:
96
129
self .locations [pair ][w ] = set ()
97
130
98
- self .counts [pair ] += self .words [w ]
131
+ self .counts [pair ] += self .word_freq [w ]
99
132
self .locations [pair ][w ].add (i )
100
133
101
134
self .heap_modified .add (pair )
102
135
103
136
def collect_pairs (self ):
104
- for w , v in monit .iterate ('Collect pairs' , self .word_codes . items () ):
105
- f = self .words [w ]
137
+ for w , v in monit .enum ('Collect pairs' , self .word_codes ):
138
+ f = self .word_freq [w ]
106
139
107
140
for i in range (len (v ) - 1 ):
108
141
self .add_pair (w , i , i + 1 )
@@ -114,12 +147,8 @@ def remove_pair(self, w, i, nxt):
114
147
assert pair [0 ] != - 1 and pair [1 ] != - 1
115
148
if pair not in self .counts :
116
149
return
117
- try :
118
- self .locations [pair ][w ].remove (i )
119
- except :
120
- print (pair , f"|{ w } |" , i )
121
- raise
122
- self .counts [pair ] -= self .words [w ]
150
+ self .locations [pair ][w ].remove (i )
151
+ self .counts [pair ] -= self .word_freq [w ]
123
152
self .heap_modified .add (pair )
124
153
125
154
def merge_pair (self ):
@@ -177,35 +206,36 @@ def merge_pair(self):
177
206
return pair
178
207
179
208
def bpe_itos (self ):
180
- itos = list (self .itos )
181
- for p1 , p2 in self .bpe [len (self .itos ):]:
209
+ itos = list (self .char_itos )
210
+ for p1 , p2 in self .bpe [len (self .char_itos ):]:
182
211
itos .append (itos [p1 ] + itos [p2 ])
183
212
184
213
return itos
185
214
186
215
def get_length (self ):
187
216
res = 0
188
- for w , v in self .word_codes :
217
+ for w , v in enumerate ( self .word_codes ) :
189
218
cnt = 0
190
219
for idx in v :
191
220
if idx != - 1 :
192
221
cnt += 1
193
- res += cnt * self .words [w ]
222
+ res += cnt * self .word_freq [w ]
194
223
195
224
return res
196
225
197
226
198
- if __name__ == '__main__' :
199
- bpe = BPE ()
200
- bpe .collect_words ()
201
- bpe .build_vocab ()
202
- bpe .build_word_arrays ()
203
- bpe .collect_pairs ()
204
- for i in monit .iterate ('BPE' , 1_000 ):
205
- while True :
206
- res = bpe .merge_pair ()
207
- if res is not None :
208
- break
227
def main():
    """Learn 1000 BPE merges on the first 100,000 characters of ``train.py``.

    Prints the size of the merge table, the learned (non-character) tokens,
    and the data length alongside ``get_length()``.
    """
    path = lab.get_data_path() / 'train.py'

    with open(str(path), 'r') as f:
        # Only a prefix of the file is used for learning.
        data = f.read()[:100_000]

    bpe = BPELearner(data)
    bpe.learn(1000)
    print(len(bpe.bpe))
    print(bpe.bpe_itos()[len(bpe.char_itos):])
    print(len(bpe.data), bpe.get_length())
238
+
239
+
240
# Run the demo only when executed as a script, not on import.
if __name__ == '__main__':
    main()
0 commit comments