Commit 05b1f0e

updated stop condition in sampler
1 parent e5e5eb9 commit 05b1f0e

3 files changed: +62 additions, -65 deletions

keras_hub/api/models/__init__.py
Lines changed: 12 additions & 13 deletions

@@ -579,25 +579,24 @@
 from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (
     StableDiffusion3TextToImagePreprocessor as StableDiffusion3TextToImagePreprocessor,
 )
+from keras_hub.src.models.stablelm.stablelm_backbone import (
+    StableLMBackbone as StableLMBackbone,
+)
+from keras_hub.src.models.stablelm.stablelm_causal_lm import (
+    StableLMCausalLM as StableLMCausalLM,
+)
+from keras_hub.src.models.stablelm.stablelm_causal_lm_preprocessor import (
+    StableLMCausalLMPreprocessor as StableLMCausalLMPreprocessor,
+)
+from keras_hub.src.models.stablelm.stablelm_tokenizer import (
+    StableLMTokenizer as StableLMTokenizer,
+)
 from keras_hub.src.models.t5.t5_backbone import T5Backbone as T5Backbone
 from keras_hub.src.models.t5.t5_preprocessor import (
     T5Preprocessor as T5Preprocessor,
 )
-
-from keras_hub.src.models.stablelm.stablelm_backbone import StableLMBackbone
-from keras_hub.src.models.stablelm.stablelm_causal_lm import StableLMCausalLM
-from keras_hub.src.models.stablelm.stablelm_causal_lm_preprocessor import (
-    StableLMCausalLMPreprocessor,
-)
-from keras_hub.src.models.stablelm.stablelm_tokenizer import StableLMTokenizer
-from keras_hub.src.models.t5.t5_backbone import T5Backbone
-from keras_hub.src.models.t5.t5_preprocessor import T5Preprocessor
-from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer
-from keras_hub.src.models.task import Task
-from keras_hub.src.models.text_classifier import TextClassifier
 from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer as T5Tokenizer
 from keras_hub.src.models.task import Task as Task
-
 from keras_hub.src.models.text_classifier import TextClassifier as Classifier
 from keras_hub.src.models.text_classifier import (
     TextClassifier as TextClassifier,

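
A note on the import style above: the redundant `Name as Name` aliases mark each symbol as an explicit re-export, so strict type checkers (e.g. mypy with no-implicit-reexport, or pyright) treat the API module as deliberately exposing the symbol rather than flagging an unused import. A minimal sketch of the idiom, using a hypothetical package `pkg`:

# pkg/api.py (hypothetical module for illustration)
# Bare form, may be reported as an implicit, unused import:
#     from pkg.internal import Thing
# Redundant-alias form, an explicit re-export usable as pkg.api.Thing:
from pkg.internal import Thing as Thing
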
keras_hub/api/tokenizers/__init__.py
Lines changed: 3 additions & 6 deletions

@@ -86,16 +86,13 @@
 from keras_hub.src.models.siglip.siglip_tokenizer import (
     SigLIPTokenizer as SigLIPTokenizer,
 )
+from keras_hub.src.models.stablelm.stablelm_tokenizer import (
+    StableLMTokenizer as StableLMTokenizer,
+)
 from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer as T5Tokenizer
 from keras_hub.src.models.whisper.whisper_tokenizer import (
     WhisperTokenizer as WhisperTokenizer,
 )
-from keras_hub.src.models.phi3.phi3_tokenizer import Phi3Tokenizer
-from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer
-from keras_hub.src.models.siglip.siglip_tokenizer import SigLIPTokenizer
-from keras_hub.src.models.stablelm.stablelm_tokenizer import StableLMTokenizer
-from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer
-from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer
 from keras_hub.src.models.xlm_roberta.xlm_roberta_tokenizer import (
     XLMRobertaTokenizer as XLMRobertaTokenizer,
 )

keras_hub/src/samplers/sampler.py
Lines changed: 47 additions & 46 deletions

@@ -3,7 +3,6 @@
 from keras import random
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.utils.tensor_utils import any_equal
 
 
 @keras_hub_export("keras_hub.samplers.Sampler")
@@ -48,17 +47,11 @@ def get_next_token(self, probs):
     ```
     """
 
-    def __init__(
-        self,
-        temperature=1.0,
-    ):
+    def __init__(self, temperature=1.0):
         self.temperature = temperature
         self._seed_generators = []
 
     def __setattr__(self, name, value):
-        # We could update to the `Tracker` class from keras-core if our needs
-        # become more advanced (e.g. list assignment, nested trackables). For
-        # now, we only track `SeedGenerator` instances directly on the sampler.
        if isinstance(value, random.SeedGenerator):
            self._seed_generators.append(value)
        return super().__setattr__(name, value)
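
For context on the tracking hook kept above: any `keras.random.SeedGenerator` assigned as an attribute is appended to `_seed_generators`, whose variables `run_loop` later threads through the JAX while-loop. A minimal sketch, with a made-up `MySampler` subclass:

from keras import random

from keras_hub.src.samplers.sampler import Sampler


class MySampler(Sampler):
    def __init__(self, seed=None):
        super().__init__()
        # Routed through Sampler.__setattr__, which auto-tracks the
        # generator in self._seed_generators.
        self.seed_generator = random.SeedGenerator(seed)


sampler = MySampler(seed=42)
assert sampler.seed_generator in sampler._seed_generators
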
@@ -82,54 +75,66 @@ def __call__(
         model=None,
     ):
         max_length = ops.shape(prompt)[-1]
-        # Make sure `max_length` and `index` are the same dtype.
         index = ops.cast(index, "int32")
         max_length = ops.cast(max_length, "int32")
+        batch_size = ops.shape(prompt)[0]
         if mask is None:
             mask = ops.zeros_like(prompt, dtype="bool")
         else:
             mask = ops.cast(mask, dtype="bool")
-        # `ops.while_loop` will not accept `None` as a value for `loop_vars`.
         cache = () if cache is None else cache
+        finished = ops.zeros([batch_size], dtype="bool")
+        if stop_token_ids is not None:
+            stop_token_ids_tensor = ops.convert_to_tensor(
+                stop_token_ids, dtype=prompt.dtype
+            )
+        else:
+            stop_token_ids_tensor = None
 
-        def cond(prompt, cache, index):
+        # Compute generated_mask
+        seq_length = ops.shape(prompt)[1]
+        row_lengths = ops.sum(ops.cast(mask, "int32"), axis=-1)
+        indices = ops.arange(seq_length, dtype="int32")
+        indices = ops.expand_dims(indices, axis=0)
+        generated_mask = indices >= ops.expand_dims(row_lengths, axis=-1)
+        generated_mask = ops.cast(generated_mask, "bool")
+
+        def cond(prompt, cache, index, finished):
             if stop_token_ids is None:
-                return True
-            # Stop if all sequences have produced a *new* id from
-            # stop_token_ids.
-            end_tokens = any_equal(prompt, stop_token_ids, ~mask)
-            prompt_done = ops.any(end_tokens, axis=-1)
-            return ops.logical_not(ops.all(prompt_done))
-
-        def body(prompt, cache, index):
-            # Compute the softmax distribution for the next token.
+                return index < max_length
+            return ops.logical_not(ops.all(finished))
+
+        def body(prompt, cache, index, finished):
             logits, _, cache = next(prompt, cache, index)
             probabilities = self.compute_probabilities(logits)
-            # Compute the next token.
             next_token = self.get_next_token(probabilities)
-            # Don't overwrite anywhere mask is True.
             next_token = ops.cast(next_token, prompt.dtype)
+            # Preserve prompt tokens
             next_token = ops.where(mask[:, index], prompt[:, index], next_token)
-            # Update the prompt with the next token.
+            if stop_token_ids is not None:
+                # Check stop tokens only for generated positions
+                # and non-finished sequences
+                is_generating = generated_mask[:, index] & ~finished
+                is_stop = is_generating & ops.any(
+                    next_token[:, None] == stop_token_ids_tensor, axis=-1
+                )
+                finished = ops.logical_or(finished, is_stop)
             next_token = next_token[:, None]
             prompt = ops.slice_update(prompt, [0, index], next_token)
+            return (prompt, cache, index + 1, finished)
 
-            # Return the next prompt, cache and incremented index.
-            return (prompt, cache, index + 1)
-
-        prompt, _, _ = self.run_loop(
+        prompt, _, _, _ = self.run_loop(
             cond,
             body,
-            loop_vars=(prompt, cache, index),
+            loop_vars=(prompt, cache, index, finished),
             maximum_iterations=(max_length - index),
             model=model,
         )
         return prompt
 
     def compute_probabilities(self, logits):
         """Compute token probabilities from logits.
-
-        This will always be done in full precision, regardless of dtype, and
+        This will always be done in full precision, regardless of dtype, and
         scale by `temperature`.
         """
         logits = ops.cast(logits, "float32")
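
Where the old `cond` rescanned the entire prompt with `any_equal` on every step, the new version keeps a per-sequence `finished` flag: `generated_mask` marks the positions past each prompt, and only a stop token sampled at such a position can finish a sequence. A NumPy sketch of the bookkeeping (batch shape, prompt lengths, and the stop id are made up for illustration; the real code uses `keras.ops`):

import numpy as np

batch_size, seq_length = 2, 8
stop_token_ids = np.array([2])  # assume id 2 is the end-of-sequence token

# `mask` is True over prompt positions; prompts of length 3 and 5 here.
mask = np.zeros((batch_size, seq_length), dtype=bool)
mask[0, :3] = True
mask[1, :5] = True

# generated_mask: True wherever the sampler, not the prompt, supplies tokens.
row_lengths = mask.sum(axis=-1)             # [3, 5]
positions = np.arange(seq_length)[None, :]  # shape (1, seq_length)
generated_mask = positions >= row_lengths[:, None]

finished = np.zeros(batch_size, dtype=bool)

# One step of `body` at position `index`, with made-up sampled tokens:
index = 5
next_token = np.array([2, 7])  # sequence 0 emits the stop token
is_generating = generated_mask[:, index] & ~finished
is_stop = is_generating & np.any(
    next_token[:, None] == stop_token_ids, axis=-1
)
finished = finished | is_stop  # -> [True, False]

`cond` keeps iterating until `finished` is True for every sequence; when no stop ids are given, the `index < max_length` check (together with `maximum_iterations`) bounds the loop instead.
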
@@ -138,7 +143,6 @@ def compute_probabilities(self, logits):
     def run_loop(
         self, cond, body, model=None, loop_vars=None, maximum_iterations=None
     ):
-        """Run ops.while_loops with a `StatelessScope` if necessary."""
         if keras.config.backend() == "jax":
             import itertools
 
@@ -165,16 +169,17 @@ def stateless_body(state, *loop_vars):
                 )
                 with keras.StatelessScope(state_mapping=mapping) as scope:
                     loop_vars = body(*loop_vars)
-
-                sampler_variables = []
-                for v in self.variables:
-                    new_v = scope.get_current_value(v)
-                    sampler_variables.append(new_v if new_v is not None else v)
-                state = (
-                    sampler_variables,
-                    trainable_variables,
-                    non_trainable_variables,
-                )
+                    sampler_variables = []
+                    for v in self.variables:
+                        new_v = scope.get_current_value(v)
+                        sampler_variables.append(
+                            new_v if new_v is not None else v
+                        )
+                    state = (
+                        sampler_variables,
+                        trainable_variables,
+                        non_trainable_variables,
+                    )
                 return state, *loop_vars
 
             variables = [ops.convert_to_tensor(v) for v in self.variables]
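
The re-indented block moves the variable read-back inside the `keras.StatelessScope` context, so `scope.get_current_value` runs while the scope's state mapping is active. A standalone sketch of the pattern (values made up; not keras-hub code):

import keras
from keras import ops

v = keras.Variable(1.0)

# Map `v` to a tensor for the duration of the scope; assignments inside
# are recorded on the scope as functional updates, not applied in place.
mapping = [(v, ops.convert_to_tensor(10.0))]
with keras.StatelessScope(state_mapping=mapping) as scope:
    v.assign(v + 1.0)  # reads the mapped value 10.0, records 11.0
    new_v = scope.get_current_value(v)  # tensor holding 11.0; `v` unchanged
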
@@ -184,11 +189,7 @@ def stateless_body(state, *loop_vars):
             non_trainable_variables = [
                 ops.convert_to_tensor(v) for v in model_non_trainable_variables
             ]
-            state = (
-                variables,
-                trainable_variables,
-                non_trainable_variables,
-            )
+            state = (variables, trainable_variables, non_trainable_variables)
             state, *loop_vars = ops.while_loop(
                 cond=stateless_cond,
                 body=stateless_body,
