3
3
from keras import random
4
4
5
5
from keras_hub .src .api_export import keras_hub_export
6
- from keras_hub .src .utils .tensor_utils import any_equal
7
6
8
7
9
8
@keras_hub_export ("keras_hub.samplers.Sampler" )
@@ -48,17 +47,11 @@ def get_next_token(self, probs):
48
47
```
49
48
"""
50
49
51
def __init__(self, temperature=1.0):
    """Create the sampler.

    Args:
        temperature: float. Softmax temperature used when computing the
            probability distribution over next tokens; `1.0` leaves the
            logits unscaled. Defaults to `1.0`.
    """
    # `_seed_generators` must exist before any attribute assignment so that
    # `__setattr__` can append `SeedGenerator` instances to it.
    self.temperature = temperature
    self._seed_generators = []
58
54
def __setattr__(self, name, value):
    """Assign `value` to `name`, tracking any `SeedGenerator` attributes.

    Every `keras.random.SeedGenerator` set on the sampler is recorded in
    `self._seed_generators` so its seed state can be surfaced as a
    variable (e.g. for stateless execution on the JAX backend).
    """
    is_seed_generator = isinstance(value, random.SeedGenerator)
    if is_seed_generator:
        self._seed_generators.append(value)
    return super().__setattr__(name, value)
@@ -82,54 +75,66 @@ def __call__(
82
75
model = None ,
83
76
):
84
77
max_length = ops .shape (prompt )[- 1 ]
85
- # Make sure `max_length` and `index` are the same dtype.
86
78
index = ops .cast (index , "int32" )
87
79
max_length = ops .cast (max_length , "int32" )
80
+ batch_size = ops .shape (prompt )[0 ]
88
81
if mask is None :
89
82
mask = ops .zeros_like (prompt , dtype = "bool" )
90
83
else :
91
84
mask = ops .cast (mask , dtype = "bool" )
92
- # `ops.while_loop` will not accept `None` as a value for `loop_vars`.
93
85
cache = () if cache is None else cache
86
+ finished = ops .zeros ([batch_size ], dtype = "bool" )
87
+ if stop_token_ids is not None :
88
+ stop_token_ids_tensor = ops .convert_to_tensor (
89
+ stop_token_ids , dtype = prompt .dtype
90
+ )
91
+ else :
92
+ stop_token_ids_tensor = None
94
93
95
- def cond (prompt , cache , index ):
94
+ # Compute generated_mask
95
+ seq_length = ops .shape (prompt )[1 ]
96
+ row_lengths = ops .sum (ops .cast (mask , "int32" ), axis = - 1 )
97
+ indices = ops .arange (seq_length , dtype = "int32" )
98
+ indices = ops .expand_dims (indices , axis = 0 )
99
+ generated_mask = indices >= ops .expand_dims (row_lengths , axis = - 1 )
100
+ generated_mask = ops .cast (generated_mask , "bool" )
101
+
102
+ def cond (prompt , cache , index , finished ):
96
103
if stop_token_ids is None :
97
- return True
98
- # Stop if all sequences have produced a *new* id from
99
- # stop_token_ids.
100
- end_tokens = any_equal (prompt , stop_token_ids , ~ mask )
101
- prompt_done = ops .any (end_tokens , axis = - 1 )
102
- return ops .logical_not (ops .all (prompt_done ))
103
-
104
- def body (prompt , cache , index ):
105
- # Compute the softmax distribution for the next token.
104
+ return index < max_length
105
+ return ops .logical_not (ops .all (finished ))
106
+
107
+ def body (prompt , cache , index , finished ):
106
108
logits , _ , cache = next (prompt , cache , index )
107
109
probabilities = self .compute_probabilities (logits )
108
- # Compute the next token.
109
110
next_token = self .get_next_token (probabilities )
110
- # Don't overwrite anywhere mask is True.
111
111
next_token = ops .cast (next_token , prompt .dtype )
112
+ # Preserve prompt tokens
112
113
next_token = ops .where (mask [:, index ], prompt [:, index ], next_token )
113
- # Update the prompt with the next token.
114
+ if stop_token_ids is not None :
115
+ # Check stop tokens only for generated positions
116
+ # and non-finished sequences
117
+ is_generating = generated_mask [:, index ] & ~ finished
118
+ is_stop = is_generating & ops .any (
119
+ next_token [:, None ] == stop_token_ids_tensor , axis = - 1
120
+ )
121
+ finished = ops .logical_or (finished , is_stop )
114
122
next_token = next_token [:, None ]
115
123
prompt = ops .slice_update (prompt , [0 , index ], next_token )
124
+ return (prompt , cache , index + 1 , finished )
116
125
117
- # Return the next prompt, cache and incremented index.
118
- return (prompt , cache , index + 1 )
119
-
120
- prompt , _ , _ = self .run_loop (
126
+ prompt , _ , _ , _ = self .run_loop (
121
127
cond ,
122
128
body ,
123
- loop_vars = (prompt , cache , index ),
129
+ loop_vars = (prompt , cache , index , finished ),
124
130
maximum_iterations = (max_length - index ),
125
131
model = model ,
126
132
)
127
133
return prompt
128
134
129
135
def compute_probabilities (self , logits ):
130
136
"""Compute token probabilities from logits.
131
-
132
- This will always be done in full precision, regardless of dtype, and
137
+ This will always be done in full precision, regardless of dtype, and
133
138
scale by `temperature`.
134
139
"""
135
140
logits = ops .cast (logits , "float32" )
@@ -138,7 +143,6 @@ def compute_probabilities(self, logits):
138
143
def run_loop (
139
144
self , cond , body , model = None , loop_vars = None , maximum_iterations = None
140
145
):
141
- """Run ops.while_loops with a `StatelessScope` if necessary."""
142
146
if keras .config .backend () == "jax" :
143
147
import itertools
144
148
@@ -165,16 +169,17 @@ def stateless_body(state, *loop_vars):
165
169
)
166
170
with keras .StatelessScope (state_mapping = mapping ) as scope :
167
171
loop_vars = body (* loop_vars )
168
-
169
- sampler_variables = []
170
- for v in self .variables :
171
- new_v = scope .get_current_value (v )
172
- sampler_variables .append (new_v if new_v is not None else v )
173
- state = (
174
- sampler_variables ,
175
- trainable_variables ,
176
- non_trainable_variables ,
177
- )
172
+ sampler_variables = []
173
+ for v in self .variables :
174
+ new_v = scope .get_current_value (v )
175
+ sampler_variables .append (
176
+ new_v if new_v is not None else v
177
+ )
178
+ state = (
179
+ sampler_variables ,
180
+ trainable_variables ,
181
+ non_trainable_variables ,
182
+ )
178
183
return state , * loop_vars
179
184
180
185
variables = [ops .convert_to_tensor (v ) for v in self .variables ]
@@ -184,11 +189,7 @@ def stateless_body(state, *loop_vars):
184
189
non_trainable_variables = [
185
190
ops .convert_to_tensor (v ) for v in model_non_trainable_variables
186
191
]
187
- state = (
188
- variables ,
189
- trainable_variables ,
190
- non_trainable_variables ,
191
- )
192
+ state = (variables , trainable_variables , non_trainable_variables )
192
193
state , * loop_vars = ops .while_loop (
193
194
cond = stateless_cond ,
194
195
body = stateless_body ,
0 commit comments