Commit ed4267c: Add files via upload
1 parent 733108d, commit ed4267c
2 files changed: 50 additions, 15 deletions

bert4keras3/layers.py

Lines changed: 1 addition & 2 deletions
@@ -438,7 +438,6 @@ def __init__(self, end_token,k=1, **kwargs):
        super(SearchBase, self).__init__(**kwargs)
        self.k = k
        self.end_token=end_token
-       self.seed = int(np.random.get_state()[1][0])
    def get_config(self):
        config = {
            'k': self.k,
@@ -454,7 +453,7 @@ def sample(self,x):
            # sure we have full precision here.
            x,
            1,
-           seed=self.seed ,
+           seed=np.random.randint(1,2147483648),
            dtype="int32",
            )
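Context for the layers.py change: the removed line pinned a single RNG seed when the search layer was constructed (reading it from NumPy's global generator state), so every call to sample() reused that same seed and produced identical draws; drawing a fresh seed on each call keeps sampling stochastic. A minimal standalone sketch of the two behaviours, assuming the sampling call wraps keras.random.categorical as the argument list (logits, 1, seed, dtype="int32") suggests:

import numpy as np
import keras

logits = keras.ops.convert_to_tensor([[0.1, 0.5, 2.0, 0.3]])

# Old behaviour: one seed fixed at construction time and reused on every call,
# so repeated sampling always returns the same token.
fixed_seed = int(np.random.get_state()[1][0])
a = keras.random.categorical(logits, 1, seed=fixed_seed, dtype="int32")
b = keras.random.categorical(logits, 1, seed=fixed_seed, dtype="int32")  # identical to a

# New behaviour: a fresh seed is drawn for every call, so draws can differ.
c = keras.random.categorical(logits, 1, seed=np.random.randint(1, 2147483648), dtype="int32")
d = keras.random.categorical(logits, 1, seed=np.random.randint(1, 2147483648), dtype="int32")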

bert4keras3/models.py

Lines changed: 49 additions & 13 deletions
@@ -37,13 +37,15 @@ def __init__(
        layers=None,  # externally supplied Keras layers
        prefix=None,  # prefix for layer names
        name=None,  # model name
+       segment_vocab_size=0,
        **kwargs
    ):
        if keep_tokens is not None:
            vocab_size = len(keep_tokens)
        if compound_tokens is not None:
            vocab_size += len(compound_tokens)
        self.vocab_size = vocab_size
+       self.segment_vocab_size = segment_vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
@@ -346,23 +348,23 @@ def Search(self,inputs,k=1,mode='greedy'):
                layer=ToppSearch,
                k=k,
                end_token=self.end_token,
-               name='ToppSearchLayer'
+               name='SearchLayer'
            )
        elif mode=='topk':
            return self.apply(
                inputs=inputs,
                layer=TopkSearch,
                k=k,
                end_token=self.end_token,
-               name='TopkSearchLayer'
+               name='SearchLayer'
            )
        else:
            return self.apply(
                inputs=inputs,
                layer=GreedySearch,
                k=k,
                end_token=self.end_token,
-               name='GreedySearchLayer'
+               name='SearchLayer'
            )

    def compute_cache_position_bias(self, inputs=None,self_cache_update_index=None,index=None):
@@ -480,7 +482,7 @@ def cache_call(self,inputs:list,input_lengths:list,end_token,
        #initial inputs and cache

        z = self.apply_embeddings(inputs)
-       #print(z)
+
        if not isinstance(z,list):
            z = [z]
        j = len(caches)//self.num_hidden_layers
@@ -518,7 +520,7 @@ def body(inputs, caches, index , flags):
            xs = self.slice_inputs(inputs,key,index)
            self.custom_position_ids = self.get_custom_position_ids()
            new_inputs = self.get_new_inputs(inputs,key,xs,index)
-           #print(xs)
+
            z = self.apply_embeddings(new_inputs)

            if not isinstance(z,list):
@@ -537,14 +539,12 @@ def body(inputs, caches, index , flags):

                caches[i*j:i*j+j]=cache

-           #print(z[1])
+
            o = self.apply_final_layers(z)

            index += 1
            search_in = [o,index,inputs[key],flags]
            inputs[key],flags = self.Search(search_in,k=k,mode=search_mode)
-           #print(index)
-           #print(ops.cast(inputs[key],'int32'))
            return (inputs, caches, index , flags)
        class WhileLayer(keras.Layer):
            def call(self, x):
@@ -575,16 +575,18 @@ def compute_output_shape(self, input_shape):
    def build_cache_model(self,input_lengths:list,end_token,
                          search_mode='greedy',k=1,progress_print=False,index_bias=0):

-
        inputs=self.get_cache_inputs(input_lengths)
-       out = self.cache_call(inputs,input_lengths,end_token,search_mode,k,progress_print,index_bias)
+
+       out = self.cache_call(inputs=inputs,input_lengths=input_lengths,end_token=end_token,
+                             search_mode=search_mode,k=k,progress_print=progress_print,index_bias=index_bias)
        model = keras.Model(inputs,out)
        inputs = []
        for modelin in model.inputs:
            shape=keras.ops.shape(modelin)
            shape=[1 if t==None else t for t in shape]
            inputs.append(ops.convert_to_tensor(np.ones(shape),modelin.dtype))
-       self.cache_call(inputs,input_lengths,end_token)
+       self.cache_call(inputs=inputs,input_lengths=input_lengths,end_token=end_token,
+                       search_mode=search_mode,k=k,progress_print=progress_print,index_bias=index_bias)

        return model
class UniLM_Mask(LM_Mask):
@@ -2454,13 +2456,23 @@ def get_inputs(self):
            shape=(self.sequence_length,),
            name='Encoder-Input-Token'
        )
+       if self.segment_vocab_size > 0:
+           s_in = self.apply(
+               layer=Input,
+               shape=(self.sequence_length,),
+               name='Segment-Input-Token'
+           )
+           return [x_in,s_in]
        return x_in

    def apply_embeddings(self, inputs):
        """T5's embedding consists only of the token embedding;
        the relative position embedding is prepared here for attention to use later.
        """
-       x = inputs
+       if type(inputs)==list:
+           x,s = inputs[:]
+       else:
+           x = inputs

        x = self.apply(
            inputs=x,
@@ -2471,6 +2483,18 @@ def apply_embeddings(self, inputs):
            mask_zero=True,
            name='Embedding-Token'
        )
+       if self.segment_vocab_size > 0:
+           s = self.apply(
+               inputs=s,
+               layer=Embedding,
+               input_dim=self.segment_vocab_size,
+               output_dim=self.embedding_size,
+               embeddings_initializer='zeros',
+               name='Embedding-Segment'
+           )
+           x = self.apply(
+               inputs=[x, s], layer=Add, name='Embedding-Token-Segment'
+           )
        x = self.apply(
            inputs=x,
            layer=Dropout,
@@ -3128,7 +3152,7 @@ def build_cache_model(self,input_lengths:list,end_token,search_mode='greedy',k=1
                                              progress_print,
                                              index_bias)
        y = self.cache_decoder([self.encoder.output,self.cache_decoder.inputs[1]])
-       self.cache_t5 = keras.Model([self.encoder.inputs[0],self.cache_decoder.inputs[1]],y)
+       self.cache_t5 = keras.Model(self.encoder.inputs[:]+self.cache_decoder.inputs[1:],y)

        return self.cache_t5
def extend_with_language_model(BaseModel):
@@ -3266,6 +3290,18 @@ def apply_embeddings(self, inputs):
            mask_zero=True,
            name='Embedding-Token'
        )
+       if self.segment_vocab_size > 0:
+           s = self.apply(
+               inputs=s,
+               layer=Embedding,
+               input_dim=self.segment_vocab_size,
+               output_dim=self.embedding_size,
+               embeddings_initializer='zeros',
+               name='Embedding-Segment'
+           )
+           x = self.apply(
+               inputs=[x, s], layer=Add, name='Embedding-Token-Segment'
+           )
        x = self.apply(
            inputs=x,
            layer=Dropout,
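Context for the models.py additions: when segment_vocab_size > 0, the encoder (and the mirrored decoder path) now takes a second integer input of segment ids ('Segment-Input-Token'), embeds it with a zero-initialized Embedding layer, and adds it onto the token embedding ('Embedding-Token-Segment') before dropout, in the style of BERT's segment embeddings. A minimal standalone sketch of that pattern in plain Keras 3; the layer names follow the diff, but the sizes and builder code are illustrative, not the library's own API:

import numpy as np
import keras
from keras import layers

seq_len, vocab_size, segment_vocab_size, hidden = 16, 1000, 2, 64

tok_in = layers.Input(shape=(seq_len,), dtype="int32", name="Encoder-Input-Token")
seg_in = layers.Input(shape=(seq_len,), dtype="int32", name="Segment-Input-Token")

# Token embedding plus a zero-initialized segment embedding, summed as in the diff.
tok_emb = layers.Embedding(vocab_size, hidden, name="Embedding-Token")(tok_in)
seg_emb = layers.Embedding(segment_vocab_size, hidden,
                           embeddings_initializer="zeros",
                           name="Embedding-Segment")(seg_in)
x = layers.Add(name="Embedding-Token-Segment")([tok_emb, seg_emb])
x = layers.Dropout(0.1, name="Embedding-Dropout")(x)

model = keras.Model([tok_in, seg_in], x)

# Usage: pass token ids together with matching segment ids (e.g. 0/1 per segment).
tokens = np.random.randint(0, vocab_size, size=(2, seq_len))
segments = np.zeros((2, seq_len), dtype="int32")
out = model.predict([tokens, segments], verbose=0)

Because the segment embedding is initialized to zeros, enabling it leaves a previously trained checkpoint's behaviour unchanged until the new embedding is trained.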
