diff --git a/keras_hub/src/models/parseq/parseq_decoder.py b/keras_hub/src/models/parseq/parseq_decoder.py
index 69303c7e67..ea01f85beb 100644
--- a/keras_hub/src/models/parseq/parseq_decoder.py
+++ b/keras_hub/src/models/parseq/parseq_decoder.py
@@ -364,20 +364,32 @@ def call(
         null_context = self.hidden_dim**0.5 * self.token_embedding(
             token_ids[:, :1]
         )
-        if tokens_length > 1:
-            content = self.pos_query_embeddings[:, : tokens_length - 1, :]
-            content = content + self.hidden_dim**0.5 * self.token_embedding(
-                token_ids[:, 1:]
-            )
-            content = ops.concatenate([null_context, content], axis=1)
-        else:
-            content = null_context
+
+        # Build content embeddings. When tokens_length == 1, this produces an
+        # empty tensor (shape: bs, 0, hidden), avoiding the need for a Python
+        # conditional.
+        content_embeddings = self.hidden_dim**0.5 * self.token_embedding(
+            token_ids[:, 1:]
+        )
+        # Use ops.take instead of dynamic slicing for JAX/TF graph
+        # compatibility.
+        pos_embeds = ops.take(
+            self.pos_query_embeddings,
+            ops.arange(ops.shape(content_embeddings)[1], dtype="int32"),
+            axis=1,
+        )
+        content = ops.concatenate(
+            [null_context, pos_embeds + content_embeddings], axis=1
+        )
         content = self.dropout(content)
 
         query = ops.multiply(
             ops.ones((bs, 1, 1), dtype=self.dtype),
-            self.pos_query_embeddings[:, :tokens_length, :],
+            ops.take(
+                self.pos_query_embeddings,
+                ops.arange(tokens_length, dtype="int32"),
+                axis=1,
+            ),
         )
         query = self.dropout(query)
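
For reviewers, a minimal sketch of why the gather-based form is equivalent to the old slice: `ops.take` along axis 1 with an `ops.arange` index reproduces `pos_query_embeddings[:, :n, :]`, while keeping the index length traceable under JAX/TF graph compilation, where dynamic slice bounds are not allowed. The table shape `(1, 26, 384)` below is an illustrative stand-in, not the model's actual configuration:

```python
import numpy as np
from keras import ops

# Illustrative stand-in for the decoder's positional query table; the real
# shape depends on the model's max_length and hidden_dim.
pos_query_embeddings = np.random.rand(1, 26, 384).astype("float32")
n = 10  # plays the role of tokens_length in call()

# Old form: static slice along the sequence axis.
sliced = pos_query_embeddings[:, :n, :]

# New form: gather the same positions with ops.take. This stays valid when
# the index length is derived from a traced shape inside a compiled graph.
gathered = ops.take(
    pos_query_embeddings,
    ops.arange(n, dtype="int32"),
    axis=1,
)

np.testing.assert_allclose(ops.convert_to_numpy(gathered), sliced)
```

The same pattern covers the removed `else` branch: when `token_ids[:, 1:]` is empty, `ops.arange(0)` yields an empty index vector, so the gather returns a `(bs, 0, hidden)` tensor and the concatenate with `null_context` degenerates to `content = null_context`, as before.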