Skip to content

Commit 5bac50f

Browse files
authored
Fix PARSeq decoder for TFLite compatibility (#2467)
* Fix PARSeq decoder for TFLite compatibility Use ops.cond() instead of Python if-statement in PARSeqDecoder.call() to ensure graph mode compatibility during TFLite conversion. - Changed 'if tokens_length > 1:' to ops.cond(tokens_length > 1, ...) - This allows TensorFlow to properly trace the graph with symbolic tensors This fixes the 'Using a symbolic tf.Tensor as a Python bool is not allowed' error during TFLite conversion when sequence length is dynamic (None). * Fix PARSeq decoder for TFLite and JAX compatibility * Fix PARSeq decoder for TFLite and JAX compatibility * Update parseq_decoder.py * Update parseq_decoder.py * Refactor content and query embedding logic in PARSeqDecoder Simplifies content and query embedding construction for better compatibility with JAX/TF graph backends. Removes dynamic slicing and Python conditionals, using ops.take and shape-based indexing to ensure consistent tensor shapes.
1 parent cd82a95 commit 5bac50f

File tree

1 file changed

+21
-9
lines changed

1 file changed

+21
-9
lines changed

keras_hub/src/models/parseq/parseq_decoder.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -364,20 +364,32 @@ def call(
364364
null_context = self.hidden_dim**0.5 * self.token_embedding(
365365
token_ids[:, :1]
366366
)
367-
if tokens_length > 1:
368-
content = self.pos_query_embeddings[:, : tokens_length - 1, :]
369-
content = content + self.hidden_dim**0.5 * self.token_embedding(
370-
token_ids[:, 1:]
371-
)
372-
content = ops.concatenate([null_context, content], axis=1)
373-
else:
374-
content = null_context
367+
368+
# Build content embeddings. When tokens_length == 1, this produces an
369+
# empty tensor (shape: bs, 0, hidden), avoiding the need for a Python
370+
# conditional.
371+
content_embeddings = self.hidden_dim**0.5 * self.token_embedding(
372+
token_ids[:, 1:]
373+
)
374+
# Use ops.take instead of dynamic slicing for JAX/TF graph compatibility
375+
pos_embeds = ops.take(
376+
self.pos_query_embeddings,
377+
ops.arange(ops.shape(content_embeddings)[1], dtype="int32"),
378+
axis=1,
379+
)
380+
content = ops.concatenate(
381+
[null_context, pos_embeds + content_embeddings], axis=1
382+
)
375383

376384
content = self.dropout(content)
377385

378386
query = ops.multiply(
379387
ops.ones((bs, 1, 1), dtype=self.dtype),
380-
self.pos_query_embeddings[:, :tokens_length, :],
388+
ops.take(
389+
self.pos_query_embeddings,
390+
ops.arange(tokens_length, dtype="int32"),
391+
axis=1,
392+
),
381393
)
382394
query = self.dropout(query)
383395

0 commit comments

Comments
 (0)