Commit 5bac50f
authored
Fix PARSeq decoder for TFLite compatibility (#2467)
* Fix PARSeq decoder for TFLite compatibility
Use ops.cond() instead of Python if-statement in PARSeqDecoder.call()
to ensure graph mode compatibility during TFLite conversion.
- Changed 'if tokens_length > 1:' to ops.cond(tokens_length > 1, ...)
- This allows TensorFlow to properly trace the graph with symbolic tensors
This fixes the 'Using a symbolic tf.Tensor as a Python bool is not allowed'
error during TFLite conversion when sequence length is dynamic (None).
* Fix PARSeq decoder for TFLite and JAX compatibility
* Fix PARSeq decoder for TFLite and JAX compatibility
* Update parseq_decoder.py
* Update parseq_decoder.py
* Refactor content and query embedding logic in PARSeqDecoder
Simplifies content and query embedding construction for better compatibility with JAX/TF graph backends. Removes dynamic slicing and Python conditionals, using ops.take and shape-based indexing to ensure consistent tensor shapes.1 parent cd82a95 commit 5bac50f
1 file changed
+21
-9
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
364 | 364 | | |
365 | 365 | | |
366 | 366 | | |
367 | | - | |
368 | | - | |
369 | | - | |
370 | | - | |
371 | | - | |
372 | | - | |
373 | | - | |
374 | | - | |
| 367 | + | |
| 368 | + | |
| 369 | + | |
| 370 | + | |
| 371 | + | |
| 372 | + | |
| 373 | + | |
| 374 | + | |
| 375 | + | |
| 376 | + | |
| 377 | + | |
| 378 | + | |
| 379 | + | |
| 380 | + | |
| 381 | + | |
| 382 | + | |
375 | 383 | | |
376 | 384 | | |
377 | 385 | | |
378 | 386 | | |
379 | 387 | | |
380 | | - | |
| 388 | + | |
| 389 | + | |
| 390 | + | |
| 391 | + | |
| 392 | + | |
381 | 393 | | |
382 | 394 | | |
383 | 395 | | |
| |||
0 commit comments