@@ -37,13 +37,15 @@ def __init__(
         layers=None,  # externally provided Keras layers
         prefix=None,  # prefix for layer names
         name=None,  # model name
+        segment_vocab_size=0,
         **kwargs
     ):
         if keep_tokens is not None:
             vocab_size = len(keep_tokens)
         if compound_tokens is not None:
             vocab_size += len(compound_tokens)
         self.vocab_size = vocab_size
+        self.segment_vocab_size = segment_vocab_size
         self.hidden_size = hidden_size
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
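The new segment_vocab_size argument defaults to 0, so existing configurations build exactly as before; the segment branch added further down only activates for positive values. A minimal sketch of opting in (the subclass name and the other hyperparameter values are illustrative, not taken from this commit):

    # Hypothetical instantiation of a Transformer subclass from this file;
    # only segment_vocab_size is new, the rest is pre-existing API.
    model = T5_Encoder(
        vocab_size=32128,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        segment_vocab_size=2,  # > 0 enables the Segment-Input-Token branch
    )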
@@ -346,23 +348,23 @@ def Search(self,inputs,k=1,mode='greedy'):
                 layer=ToppSearch,
                 k=k,
                 end_token=self.end_token,
-                name='ToppSearchLayer'
+                name='SearchLayer'
             )
         elif mode == 'topk':
             return self.apply(
                 inputs=inputs,
                 layer=TopkSearch,
                 k=k,
                 end_token=self.end_token,
-                name='TopkSearchLayer'
+                name='SearchLayer'
             )
         else:
             return self.apply(
                 inputs=inputs,
                 layer=GreedySearch,
                 k=k,
                 end_token=self.end_token,
-                name='GreedySearchLayer'
+                name='SearchLayer'
             )
 
     def compute_cache_position_bias(self, inputs=None, self_cache_update_index=None, index=None):
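With all three modes sharing one layer name, code that inspects a built model no longer needs to know which search mode it was built with. A sketch of the kind of lookup this enables (cache_model stands for any model built by build_cache_model below):

    # The same name now resolves regardless of search mode.
    search_layer = cache_model.get_layer('SearchLayer')
    print(type(search_layer).__name__)  # GreedySearch, TopkSearch, or ToppSearch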
@@ -480,7 +482,7 @@ def cache_call(self,inputs:list,input_lengths:list,end_token,
         # initial inputs and cache
 
         z = self.apply_embeddings(inputs)
-        #print(z)
+
         if not isinstance(z, list):
             z = [z]
         j = len(caches)//self.num_hidden_layers
@@ -518,7 +520,7 @@ def body(inputs, caches, index , flags):
             xs = self.slice_inputs(inputs, key, index)
             self.custom_position_ids = self.get_custom_position_ids()
             new_inputs = self.get_new_inputs(inputs, key, xs, index)
-            #print(xs)
+
             z = self.apply_embeddings(new_inputs)
 
             if not isinstance(z, list):
@@ -537,14 +539,12 @@ def body(inputs, caches, index , flags):
 
                 caches[i*j:i*j+j] = cache
 
-            #print(z[1])
+
             o = self.apply_final_layers(z)
 
             index += 1
             search_in = [o, index, inputs[key], flags]
             inputs[key], flags = self.Search(search_in, k=k, mode=search_mode)
-            #print(index)
-            #print(ops.cast(inputs[key],'int32'))
             return (inputs, caches, index, flags)
         class WhileLayer(keras.Layer):
             def call(self, x):
@@ -575,16 +575,18 @@ def compute_output_shape(self, input_shape):
     def build_cache_model(self, input_lengths: list, end_token,
                           search_mode='greedy', k=1, progress_print=False, index_bias=0):
 
-
         inputs = self.get_cache_inputs(input_lengths)
-        out = self.cache_call(inputs, input_lengths, end_token, search_mode, k, progress_print, index_bias)
+
+        out = self.cache_call(inputs=inputs, input_lengths=input_lengths, end_token=end_token,
+                              search_mode=search_mode, k=k, progress_print=progress_print, index_bias=index_bias)
         model = keras.Model(inputs, out)
         inputs = []
         for modelin in model.inputs:
             shape = keras.ops.shape(modelin)
             shape = [1 if t == None else t for t in shape]
             inputs.append(ops.convert_to_tensor(np.ones(shape), modelin.dtype))
-        self.cache_call(inputs, input_lengths, end_token)
+        self.cache_call(inputs=inputs, input_lengths=input_lengths, end_token=end_token,
+                        search_mode=search_mode, k=k, progress_print=progress_print, index_bias=index_bias)
 
         return model
 class UniLM_Mask(LM_Mask):
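cache_call runs twice in build_cache_model: once on symbolic inputs to define the graph, then once on all-ones dummy tensors to exercise it. The old second call passed only the first three arguments, so the warm-up always ran with the default greedy settings whatever the caller requested; forwarding every keyword keeps the two invocations consistent. A usage sketch (argument values are illustrative):

    # Both internal cache_call invocations now see the same
    # search_mode, k, progress_print, and index_bias.
    cache_model = model.build_cache_model(
        input_lengths=[128],
        end_token=2,
        search_mode='topk',
        k=5,
    )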
@@ -2454,13 +2456,23 @@ def get_inputs(self):
             shape=(self.sequence_length,),
             name='Encoder-Input-Token'
         )
+        if self.segment_vocab_size > 0:
+            s_in = self.apply(
+                layer=Input,
+                shape=(self.sequence_length,),
+                name='Segment-Input-Token'
+            )
+            return [x_in, s_in]
         return x_in
 
     def apply_embeddings(self, inputs):
         """T5 has only a token embedding;
         the relative position embedding is prepared here for attention to use.
         """
-        x = inputs
+        if type(inputs) == list:
+            x, s = inputs[:]
+        else:
+            x = inputs
 
         x = self.apply(
             inputs=x,
@@ -2471,6 +2483,18 @@ def apply_embeddings(self, inputs):
             mask_zero=True,
             name='Embedding-Token'
         )
+        if self.segment_vocab_size > 0:
+            s = self.apply(
+                inputs=s,
+                layer=Embedding,
+                input_dim=self.segment_vocab_size,
+                output_dim=self.embedding_size,
+                embeddings_initializer='zeros',
+                name='Embedding-Segment'
+            )
+            x = self.apply(
+                inputs=[x, s], layer=Add, name='Embedding-Token-Segment'
+            )
         x = self.apply(
             inputs=x,
             layer=Dropout,
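Together with the get_inputs change above, an encoder built with segment_vocab_size > 0 now expects a second tensor of segment ids. Because Embedding-Segment starts from a zeros initializer and is merged by Add, the new branch is initially an exact no-op on top of pretrained token embeddings. A minimal sketch of feeding such a model (shapes and values are illustrative; encoder stands for a model built from this class):

    import numpy as np

    token_ids = np.array([[5, 19, 88, 3, 0, 0]])   # (batch, seq_len)
    segment_ids = np.array([[0, 0, 0, 1, 1, 1]])   # same shape, ids < segment_vocab_size
    features = encoder.predict([token_ids, segment_ids])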
@@ -3128,7 +3152,7 @@ def build_cache_model(self,input_lengths:list,end_token,search_mode='greedy',k=1
                 progress_print,
                 index_bias)
         y = self.cache_decoder([self.encoder.output, self.cache_decoder.inputs[1]])
-        self.cache_t5 = keras.Model([self.encoder.inputs[0], self.cache_decoder.inputs[1]], y)
+        self.cache_t5 = keras.Model(self.encoder.inputs[:] + self.cache_decoder.inputs[1:], y)
 
         return self.cache_t5
 def extend_with_language_model(BaseModel):
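Hard-coding self.encoder.inputs[0] assumed a single encoder input. Using the full encoder input list, plus every decoder input after the first (the first is the encoder output, which is wired in internally), keeps the cached T5 graph valid whether or not the encoder carries the optional segment input. A sketch of the resulting call (tensor names are illustrative):

    # Two inputs with segment_vocab_size == 0, three with it > 0.
    cache_t5 = t5.build_cache_model(input_lengths=[64, 64], end_token=1)
    preds = cache_t5.predict([enc_token_ids, enc_segment_ids, dec_token_ids])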
@@ -3266,6 +3290,18 @@ def apply_embeddings(self, inputs):
             mask_zero=True,
             name='Embedding-Token'
         )
+        if self.segment_vocab_size > 0:
+            s = self.apply(
+                inputs=s,
+                layer=Embedding,
+                input_dim=self.segment_vocab_size,
+                output_dim=self.embedding_size,
+                embeddings_initializer='zeros',
+                name='Embedding-Segment'
+            )
+            x = self.apply(
+                inputs=[x, s], layer=Add, name='Embedding-Token-Segment'
+            )
         x = self.apply(
             inputs=x,
             layer=Dropout,