-
Notifications
You must be signed in to change notification settings - Fork 292
[DeepSeek R1] Qwen2.5 Distillations #2236
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
d3d6164
169ec14
eece281
3b22383
0ee0033
05d500f
59629e8
cad36b8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -22,6 +22,12 @@ | |||||||||||||||||
from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import ( | ||||||||||||||||||
DebertaV3Tokenizer as DebertaV3Tokenizer, | ||||||||||||||||||
) | ||||||||||||||||||
from keras_hub.src.models.deepseek_r1.deepseek_r1_qwen_tokenizer import ( | ||||||||||||||||||
DeepSeekR1QwenTokenizer as DeepSeekR1Qwen2Tokenizer, | ||||||||||||||||||
) | ||||||||||||||||||
from keras_hub.src.models.deepseek_r1.deepseek_r1_qwen_tokenizer import ( | ||||||||||||||||||
DeepSeekR1QwenTokenizer as DeepSeekR1QwenTokenizer, | ||||||||||||||||||
Comment on lines
+25
to
+29
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to the
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ditto. |
||||||||||||||||||
) | ||||||||||||||||||
from keras_hub.src.models.distil_bert.distil_bert_tokenizer import ( | ||||||||||||||||||
DistilBertTokenizer as DistilBertTokenizer, | ||||||||||||||||||
) | ||||||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,300 @@ | ||
import keras | ||
from keras import ops | ||
|
||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.models.causal_lm import CausalLM | ||
from keras_hub.src.models.deepseek_r1.deepseek_r1_qwen_causal_lm_preprocessor import ( # noqa: E501 | ||
DeepSeekR1QwenCausalLMPreprocessor, | ||
) | ||
from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone | ||
from keras_hub.src.utils.tensor_utils import any_equal | ||
|
||
|
||
@keras_hub_export( | ||
[ | ||
"keras_hub.models.DeepSeekR1QwenCausalLM", | ||
"keras_hub.models.DeepSeekR1Qwen2CausalLM", | ||
] | ||
) | ||
class DeepSeekR1QwenCausalLM(CausalLM): | ||
backbone_cls = QwenBackbone | ||
preprocessor_cls = DeepSeekR1QwenCausalLMPreprocessor | ||
|
||
def __init__(self, backbone, preprocessor=None, **kwargs): | ||
# === Layers === | ||
self.backbone = backbone | ||
self.preprocessor = preprocessor | ||
|
||
# === Functional Model === | ||
# This must be "backbone.input" i.e. the full input structure, | ||
# rather than "backbone.inputs" which is the flattened list of inputs. | ||
inputs = backbone.input | ||
hidden_states = backbone(inputs) | ||
outputs = backbone.token_embedding(hidden_states, reverse=True) | ||
super().__init__( | ||
inputs=inputs, | ||
outputs=outputs, | ||
**kwargs, | ||
) | ||
|
||
def call_with_cache( | ||
self, | ||
token_ids, | ||
cache, | ||
cache_update_index, | ||
): | ||
"""Forward pass of `DeepSeekR1QwenCausalLM` with cache. | ||
|
||
`call_with_cache` adds an additional forward pass for the model for | ||
autoregressive inference. Unlike calling the model directly, this method | ||
allows caching previous key/value Tensors in multi-head attention layer, | ||
and avoids recomputing the outputs of seen tokens. | ||
|
||
Args: | ||
token_ids: a dense int Tensor with shape `(batch_size, max_length)`. | ||
cache: a dense float Tensor, the cache of key and value. | ||
cache_update_index: int, or int Tensor. The index of current inputs | ||
in the whole sequence. | ||
|
||
Returns: | ||
A (logits, hidden_states, cache) tuple. Where `logits` is the | ||
language model logits for the input token_ids, `hidden_states` is | ||
the final hidden representation of the input tokens, and `cache` is | ||
the decoding cache. | ||
""" | ||
x = self.backbone.token_embedding(token_ids) | ||
# Each decoder layer has a cache; we update them separately. | ||
updated_cache = [] | ||
for i in range(self.backbone.num_layers): | ||
current_cache = cache[:, i, ...] | ||
x, next_cache = self.backbone.transformer_layers[i]( | ||
x, | ||
self_attention_cache=current_cache, | ||
self_attention_cache_update_index=cache_update_index, | ||
) | ||
updated_cache.append(next_cache) | ||
cache = ops.stack(updated_cache, axis=1) | ||
hidden_states = x = self.backbone.layer_norm(x) | ||
logits = self.backbone.token_embedding(x, reverse=True) | ||
return logits, hidden_states, cache | ||
|
||
def _build_cache(self, token_ids): | ||
"""Build an empty cache for use with `call_with_cache()`.""" | ||
batch_size = ops.shape(token_ids)[0] | ||
max_length = ops.shape(token_ids)[1] | ||
num_layers = self.backbone.num_layers | ||
num_key_value_heads = self.backbone.num_key_value_heads | ||
head_dim = self.backbone.hidden_dim // self.backbone.num_query_heads | ||
shape = [ | ||
batch_size, | ||
num_layers, | ||
2, | ||
max_length, | ||
num_key_value_heads, | ||
head_dim, | ||
] | ||
cache = ops.zeros(shape, dtype=self.compute_dtype) | ||
# Seed the cache. | ||
_, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) | ||
return hidden_states, cache | ||
|
||
def generate_step( | ||
self, | ||
inputs, | ||
stop_token_ids=None, | ||
): | ||
"""A compilable generation function for a single batch of inputs. | ||
|
||
This function represents the inner, XLA-compilable, generation function | ||
for a single batch of inputs. Inputs should have the same structure as | ||
model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. | ||
|
||
Args: | ||
inputs: A dictionary with two keys `"token_ids"` and | ||
`"padding_mask"` and batched tensor values. | ||
stop_token_ids: Tuple of id's of the end token to stop on. If all | ||
sequences have produced a new stop token, generation | ||
will stop. | ||
""" | ||
token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] | ||
# Create and seed cache with a single forward pass. | ||
hidden_states, cache = self._build_cache(token_ids) | ||
# Compute the lengths of all user inputted tokens ids. | ||
row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) | ||
# Start at the first index that has no user inputted id. | ||
index = ops.min(row_lengths) | ||
|
||
def next(prompt, cache, index): | ||
# The cache index is the index of our previous token. | ||
cache_update_index = index - 1 | ||
batch_size = ops.shape(prompt)[0] | ||
prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) | ||
logits, hidden_states, cache = self.call_with_cache( | ||
prompt, | ||
cache, | ||
cache_update_index, | ||
) | ||
return ( | ||
ops.squeeze(logits, axis=1), | ||
ops.squeeze(hidden_states, axis=1), | ||
cache, | ||
) | ||
|
||
token_ids = self.sampler( | ||
next=next, | ||
prompt=token_ids, | ||
cache=cache, | ||
index=index, | ||
mask=padding_mask, | ||
stop_token_ids=stop_token_ids, | ||
hidden_states=hidden_states, | ||
model=self, | ||
) | ||
|
||
# Compute an output padding mask with the token ids we updated. | ||
if stop_token_ids is not None: | ||
# Build a mask of stop token locations not in the original | ||
# prompt (not in locations where `padding_mask` is True). | ||
end_locations = any_equal( | ||
token_ids, stop_token_ids, ops.logical_not(padding_mask) | ||
) | ||
end_locations = ops.cast(end_locations, "int32") | ||
# Use cumsum to get ones in all locations after end_locations. | ||
cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") | ||
overflow = cumsum - end_locations | ||
# Our padding mask is the inverse of these overflow locations. | ||
padding_mask = ops.logical_not(ops.cast(overflow, "bool")) | ||
else: | ||
# Without early stopping, all locations will have been updated. | ||
padding_mask = ops.ones_like(token_ids, dtype="bool") | ||
return { | ||
"token_ids": token_ids, | ||
"padding_mask": padding_mask, | ||
} | ||
|
||
def score( | ||
self, | ||
token_ids, | ||
padding_mask=None, | ||
scoring_mode="logits", | ||
layer_intercept_fn=None, | ||
target_ids=None, | ||
): | ||
"""Score a generation represented by the provided token ids. | ||
|
||
Args: | ||
token_ids: A <int>[batch_size, num_tokens] tensor containing tokens | ||
to score. Typically, this tensor captures the output from a call | ||
to `QwenCausalLM.generate()`, i.e., tokens for both the input | ||
text and the model-generated text. | ||
padding_mask: A <bool>[batch_size, num_tokens] tensor indicating the | ||
tokens that should be preserved during generation. This is an | ||
artifact required by the `QwenBackbone` and isn't influential | ||
on the computation of this function. If omitted, this function | ||
uses `keras.ops.ones()` to create a tensor of the appropriate | ||
shape. | ||
scoring_mode: The type of scores to return, either "logits" or | ||
"loss", both will be per input token. | ||
layer_intercept_fn: An optional function for augmenting activations | ||
with additional computation, for example, as part of | ||
interpretability research. This function will be passed the | ||
activations as its first parameter and a numeric index | ||
associated with that backbone layer. _This index _is not_ an | ||
index into `self.backbone.layers`_. The index -1 accompanies the | ||
embeddings returned by calling `self.backbone.token_embedding()` | ||
on `token_ids` in the forward direction. All subsequent indexes | ||
will be 0-based indices for the activations returned by each of | ||
the Transformers layers in the backbone. This function must | ||
return a <float>[batch_size, num_tokens, hidden_dims] tensor | ||
that can be passed as an input to the next layer in the model. | ||
target_ids: An <bool>[batch_size, num_tokens] tensor containing the | ||
predicted tokens against which the loss should be computed. If a | ||
span of tokens is provided (sequential truthy values along | ||
axis=1 in the tensor), the loss will be computed as the | ||
aggregate across those tokens. | ||
|
||
Raises: | ||
ValueError: If an unsupported scoring_mode is provided, or if the | ||
target_ids are not provided when using ScoringMode.LOSS. | ||
|
||
Returns: | ||
The per-token scores as a tensor of size | ||
<float>[batch_size, num_tokens, vocab_size] in "logits" mode, or | ||
<float>[batch_size, num_tokens] in "loss" mode. | ||
|
||
Example: | ||
|
||
Compute gradients between embeddings and loss scores with TensorFlow: | ||
```python | ||
qwen_lm = keras_hub.models.QwenCausalLM.from_preset("qwen2.5_0.5b_en") | ||
generations = qwen_lm.generate( | ||
["This is a", "Where are you"], | ||
max_length=30 | ||
) | ||
preprocessed = qwen_lm.preprocessor.generate_preprocess(generations) | ||
generation_ids = preprocessed["token_ids"] | ||
padding_mask = preprocessed["padding_mask"] | ||
target_ids = keras.ops.roll(generation_ids, shift=-1, axis=1) | ||
|
||
embeddings = None | ||
with tf.GradientTape(watch_accessed_variables=True) as tape: | ||
def layer_intercept_fn(x, i): | ||
if i == -1: | ||
nonlocal embeddings, tape | ||
embeddings = x | ||
tape.watch(embeddings) | ||
return x | ||
|
||
losses = qwen_lm.score( | ||
token_ids=generation_ids, | ||
padding_mask=padding_mask, | ||
scoring_mode="loss", | ||
layer_intercept_fn=layer_intercept_fn, | ||
target_ids=target_ids, | ||
) | ||
|
||
grads = tape.gradient(losses, embeddings) | ||
``` | ||
""" | ||
if scoring_mode not in ("logits", "loss"): | ||
raise ValueError( | ||
"Unsupported scoring_mode. Must be one of 'logits' or 'loss'." | ||
) | ||
|
||
if scoring_mode == "loss" and target_ids is None: | ||
raise ValueError( | ||
"Cannot compute loss without targets. Please provide target " | ||
"token ids via the target_ids parameter." | ||
) | ||
|
||
batch_shape = ops.shape(token_ids)[:2] | ||
assert len(batch_shape) == 2 | ||
|
||
if padding_mask is None: | ||
padding_mask = ops.ones(shape=batch_shape) | ||
|
||
if layer_intercept_fn is None: | ||
|
||
def default_layer_intercept_fn(x, unused_i): | ||
return x | ||
|
||
layer_intercept_fn = default_layer_intercept_fn | ||
|
||
token_embeddings = self.backbone.token_embedding(token_ids) | ||
x = layer_intercept_fn(token_embeddings, -1) | ||
|
||
for i, transformer_layer in enumerate(self.backbone.transformer_layers): | ||
x = transformer_layer(x, decoder_padding_mask=padding_mask) | ||
x = layer_intercept_fn(x, i) | ||
|
||
x = self.backbone.layer_norm(x) | ||
logits = self.backbone.token_embedding(x, reverse=True) | ||
|
||
if scoring_mode == "logits": | ||
return logits | ||
|
||
per_token_loss_fn = keras.losses.SparseCategoricalCrossentropy( | ||
from_logits=True, reduction="none" | ||
) | ||
per_token_loss = per_token_loss_fn(target_ids, logits) | ||
return per_token_loss |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor | ||
from keras_hub.src.models.deepseek_r1.deepseek_r1_qwen_tokenizer import ( | ||
DeepSeekR1QwenTokenizer, | ||
) | ||
from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone | ||
|
||
|
||
@keras_hub_export( | ||
[ | ||
"keras_hub.models.DeepSeekR1QwenCausalLMPreprocessor", | ||
"keras_hub.models.DeepSeekR1Qwen2CausalLMPreprocessor", | ||
] | ||
) | ||
class DeepSeekR1QwenCausalLMPreprocessor(CausalLMPreprocessor): | ||
backbone_cls = QwenBackbone | ||
tokenizer_cls = DeepSeekR1QwenTokenizer | ||
|
||
def __init__(self, *args, **kwargs): | ||
super().__init__(*args, **kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There appear to be duplicate imports with different aliases.
DeepSeekR1QwenCausalLM
andDeepSeekR1QwenCausalLMPreprocessor
andDeepSeekR1QwenTokenizer
are imported twice, once with an alias ending in2
and once without. This can lead to confusion and potential errors if the wrong alias is used. Consider removing the duplicate imports or ensuring the aliases are used consistently.Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is part of the API design - these have multiple aliases.