@@ -52,6 +52,9 @@ def __init__(self, temperature=1.0):
52
52
self ._seed_generators = []
53
53
54
54
def __setattr__ (self , name , value ):
55
+ # We could update to the `Tracker` class from keras-core if our needs
56
+ # become more advanced (e.g. list assignment, nested trackables). For
57
+ # now, we only track `SeedGenerator` instances directly on the sampler.
55
58
if isinstance (value , random .SeedGenerator ):
56
59
self ._seed_generators .append (value )
57
60
return super ().__setattr__ (name , value )
@@ -75,13 +78,15 @@ def __call__(
75
78
model = None ,
76
79
):
77
80
max_length = ops .shape (prompt )[- 1 ]
81
+ # Make sure `max_length` and `index` are the same dtype.
78
82
index = ops .cast (index , "int32" )
79
83
max_length = ops .cast (max_length , "int32" )
80
84
batch_size = ops .shape (prompt )[0 ]
81
85
if mask is None :
82
86
mask = ops .zeros_like (prompt , dtype = "bool" )
83
87
else :
84
88
mask = ops .cast (mask , dtype = "bool" )
89
+ # `ops.while_loop` will not accept `None` as a value for `loop_vars`.
85
90
cache = () if cache is None else cache
86
91
finished = ops .zeros ([batch_size ], dtype = "bool" )
87
92
if stop_token_ids is not None :
0 commit comments