Skip to content

Commit 545bb18

Browse files
Merge pull request #31439 from chapman20j:abstract_device
PiperOrigin-RevId: 801031638
2 parents bbc9d86 + 9c9bb2a commit 545bb18

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

jax/sharding.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
)
2828
from jax._src.mesh import (
2929
Mesh as Mesh,
30+
AbstractDevice as AbstractDevice,
3031
AbstractMesh as AbstractMesh,
3132
AxisType as AxisType,
3233
get_abstract_mesh as get_abstract_mesh,

tests/pjit_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from jax.lax import with_sharding_constraint
4242
from jax._src import prng
4343
from jax.sharding import (PartitionSpec as P, Mesh, auto_axes, explicit_axes,
44-
reshard)
44+
reshard, AbstractDevice)
4545
from jax.experimental import multihost_utils
4646
from jax._src.shard_map import shard_map
4747
from jax._src.compilation_cache import is_persistent_cache_enabled
@@ -53,7 +53,7 @@
5353
from jax._src.sharding_impls import (
5454
AUTO, UNSPECIFIED, NamedSharding, GSPMDSharding,
5555
SingleDeviceSharding, parse_flatten_op_sharding)
56-
from jax._src.mesh import use_abstract_mesh, AbstractDevice
56+
from jax._src.mesh import use_abstract_mesh
5757
from jax._src.pjit import pjit, _pjit_lower
5858
from jax._src.layout import Format, Layout as DLL
5959
from jax._src.named_sharding import DuplicateSpecError

0 commit comments

Comments
 (0)