Commit 5d7f18a
Use the JTE API to set stacked table stats... (#169)
* Use the JTE API to set stacked table stats to the maximum of the input table specs.
This allows setting parameters like `max_ids_per_partition` and `max_unique_ids_per_partition`,
`suggested_coo_buffer_size` for stacked tables with auto-stacking.
Although the heuristic may not be optimal, this at least provides a method for directly
setting the values in the stacked tables, and is consistent with the default values
if nothing is set.
Uses the `jax_tpu_embedding` API for future-proofing.
* Update keras_rs/src/layers/embedding/jax/distributed_embedding.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---------
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>1 parent 04fb241 commit 5d7f18a
1 file changed
+40
-22
lines changedLines changed: 40 additions & 22 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
2 | 2 | | |
3 | | - | |
4 | 3 | | |
5 | 4 | | |
6 | 5 | | |
| |||
446 | 445 | | |
447 | 446 | | |
448 | 447 | | |
449 | | - | |
450 | | - | |
451 | | - | |
452 | | - | |
453 | | - | |
454 | | - | |
455 | | - | |
456 | | - | |
457 | | - | |
458 | | - | |
459 | | - | |
460 | | - | |
461 | | - | |
462 | | - | |
| 448 | + | |
| 449 | + | |
| 450 | + | |
| 451 | + | |
| 452 | + | |
| 453 | + | |
| 454 | + | |
| 455 | + | |
| 456 | + | |
| 457 | + | |
| 458 | + | |
| 459 | + | |
| 460 | + | |
| 461 | + | |
| 462 | + | |
| 463 | + | |
463 | 464 | | |
464 | | - | |
465 | | - | |
466 | 465 | | |
467 | | - | |
468 | | - | |
469 | | - | |
470 | | - | |
471 | | - | |
| 466 | + | |
| 467 | + | |
| 468 | + | |
| 469 | + | |
| 470 | + | |
| 471 | + | |
| 472 | + | |
| 473 | + | |
| 474 | + | |
| 475 | + | |
| 476 | + | |
| 477 | + | |
| 478 | + | |
| 479 | + | |
| 480 | + | |
| 481 | + | |
| 482 | + | |
| 483 | + | |
| 484 | + | |
| 485 | + | |
| 486 | + | |
| 487 | + | |
| 488 | + | |
| 489 | + | |
472 | 490 | | |
473 | 491 | | |
474 | 492 | | |
| |||
0 commit comments