Skip to content

Commit 0c6d665

Browse files
laxmareddypcantonioshertschuh
authored
Version bump 0.2.1.dev0 (#115)
* Ignore shard_map attr error in mypy. (#97) * Added TF specific documentation to `DistributedEmbedding`. (#94) * Fix symbolic calls for `EmbedReduce`. (#98) `EmbedReduce` was inheriting the behavior from `Embedding` and not correctly applying the reduction. * Move `DistributedEmbedding` declaration to its own file. (#99) Having it in `__init__.py` doesn't play nice with pytype. * Remove dependency on `tree` and use `keras.tree`. (#100) Keras can already depend on either `dmtree` or `optree` and use whichever is best or available on the current platform. * Only enable JAX on linux_x86_64. (#101) * Add out_sharding argument WrappedKerasInitializer. (#102) This is for forward-compatibility. Latest versions of JAX introduce the `out_sharding` argument. * Use Python 3.10 style type annotations. (#104) Now that we require Python 3.10, we can use the shorter annotation style, which should improve the readability of the documentation. * Do not bundle test utils in wheel. (#105) * Update version number to 0.2.1 (#106) As 0.2.0 was just released. * Fix invalid escape sequence in unit test. (#108) * Replace leftover `unflatten_as` to `pack_sequence_as`. (#109) This instance was missed as it is only run on TPU. * Make the declaration of `Nested` compatible with pytype. (#110) Which doesn't support `|` between forward declarations using a string. * Add ragged support for default_device placement on JAX. (#107) Requires calling `preprocess`. Internally, we currently convert ragged inputs to dense before passing to the embedding call(...) function. * Add documentation for using DistributedEmbedding with JAX. (#111) * `api_gen` now excludes backend specific code. (#103) This: - Allows development (`api_gen` / git presubmit hooks) without all backends and backend specific dependencies installed and working. For instance, jax_tpu_embedding currently doesn't import on MacOS Sequoia, this allows running `api_gen` regardless. - Makes sure we don't accidentally create and honor exports that are backend specific. * Enable preprocess calls with symbolic input tensors. (#113) This allows us to more-easily create functional models via: ```python preprocessed_inputs = distributed_embedding.preprocess(symbolic_inputs, symbolic_weights) outputs = distributed_embedding(preprocessed_inputs) model = keras.Model(inputs=preprocessed_inputs, outputs=outputs) ``` * Check for jax_tpu_embedding on JAX backend. (#114) This is to allow users to potentially run Keras RS _without_ the dependency. If a user doesn't have `jax-tpu-embedding` installed, but are on `linux_x86_64` and has a sparsecore-capable TPU available, and if they try to use `auto` or `sparsecore` placement with distributed embedding, will raise an error informing them to install the dependency. --------- Co-authored-by: C. Antonio Sánchez <[email protected]> Co-authored-by: hertschuh <[email protected]>
1 parent 060d2c5 commit 0c6d665

37 files changed

+1015
-503
lines changed

api_gen.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,14 @@ def build() -> None:
5454
os.chdir(build_dir)
5555
# Generates `keras_rs/api` directory.
5656
open(build_api_init_fname, "w").close()
57-
namex.generate_api_files("keras_rs", code_directory="src")
57+
namex.generate_api_files(
58+
"keras_rs",
59+
code_directory="src",
60+
exclude_directories=[
61+
os.path.join("src", "layers", "embedding", "jax"),
62+
os.path.join("src", "layers", "embedding", "tensorflow"),
63+
],
64+
)
5865
# Add `__version__` to `keras_rs/__init__.py`.
5966
export_version_string(build_api_init_fname)
6067
# Copy back `keras_rs` from build dir to `api` excluding `src/`.

keras_rs/api/layers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
since your modifications would be overwritten.
55
"""
66

7-
from keras_rs.src.layers.embedding import (
7+
from keras_rs.src.layers.embedding.distributed_embedding import (
88
DistributedEmbedding as DistributedEmbedding,
99
)
1010
from keras_rs.src.layers.embedding.distributed_embedding_config import (
Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +0,0 @@
1-
import keras
2-
3-
from keras_rs.src.api_export import keras_rs_export
4-
5-
if keras.backend.backend() == "jax":
6-
from keras_rs.src.layers.embedding.jax.distributed_embedding import (
7-
DistributedEmbedding as BackendDistributedEmbedding,
8-
)
9-
elif keras.backend.backend() == "tensorflow":
10-
from keras_rs.src.layers.embedding.tensorflow.distributed_embedding import (
11-
DistributedEmbedding as BackendDistributedEmbedding,
12-
)
13-
else:
14-
from keras_rs.src.layers.embedding.base_distributed_embedding import (
15-
DistributedEmbedding as BackendDistributedEmbedding,
16-
)
17-
18-
19-
@keras_rs_export("keras_rs.layers.DistributedEmbedding")
20-
class DistributedEmbedding(BackendDistributedEmbedding):
21-
pass

0 commit comments

Comments
 (0)