Skip to content

Conversation

copybara-service[bot]
Copy link

Adds initial Keras Orbax checkpointer V2 implementation.

First step in creating a memory efficient Keras + Jax checkpointer that uses nested PyTrees instead of flat tuples to enable model surgery.

  • Checkpoints the serialized model config as metadata.
  • Upgrades the checkpointing logic to the new Orbax API.
  • Writes checkpoints as a dict instead of a tuple.
  • Removes unnecessary expensive jax_state_sync calls.

Reverts changelist 793734230

First step in creating a memory efficient Keras + Jax checkpointer that uses nested PyTrees instead of flat tuples to enable model surgery.

- Checkpoints the serialized model config as metadata.
- Upgrades the checkpointing logic to the new Orbax API.
- Writes checkpoints as a dict instead of a tuple.
- Removes unnecessary expensive `jax_state_sync` calls.

Reverts changelist 793734230

PiperOrigin-RevId: 774946771
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant