Skip to content

Commit e836a78

Browse files
committed
comments
1 parent cdb5f76 commit e836a78

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

keras_hub/src/samplers/sampler.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ def __init__(self, temperature=1.0):
5252
self._seed_generators = []
5353

5454
def __setattr__(self, name, value):
55+
# We could update to the `Tracker` class from keras-core if our needs
56+
# become more advanced (e.g. list assignment, nested trackables). For
57+
# now, we only track `SeedGenerator` instances directly on the sampler.
5558
if isinstance(value, random.SeedGenerator):
5659
self._seed_generators.append(value)
5760
return super().__setattr__(name, value)
@@ -75,13 +78,15 @@ def __call__(
7578
model=None,
7679
):
7780
max_length = ops.shape(prompt)[-1]
81+
# Make sure `max_length` and `index` are the same dtype.
7882
index = ops.cast(index, "int32")
7983
max_length = ops.cast(max_length, "int32")
8084
batch_size = ops.shape(prompt)[0]
8185
if mask is None:
8286
mask = ops.zeros_like(prompt, dtype="bool")
8387
else:
8488
mask = ops.cast(mask, dtype="bool")
89+
# `ops.while_loop` will not accept `None` as a value for `loop_vars`.
8590
cache = () if cache is None else cache
8691
finished = ops.zeros([batch_size], dtype="bool")
8792
if stop_token_ids is not None:

0 commit comments

Comments
 (0)