Skip to content

Commit 0f41c62

Browse files
lukebaumanncopybara-github
authored andcommitted
Fixed a couple bugs with pathwaysutils.jax
PiperOrigin-RevId: 813906532
1 parent d83f050 commit 0f41c62

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

pathwaysutils/jax/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,14 +77,14 @@ def register_backend_cache(cache: Any, name: str, util=util): # pylint: disable
7777

7878
try:
7979
# jax>=0.8.0
80-
from jax.jaxlib import _pathways # pylint: disable=g-import-not-at-top
80+
from jaxlib import _pathways as jaxlib_pathways # pylint: disable=g-import-not-at-top
8181

82-
jaxlib_pathways = _pathways
83-
del _pathways
84-
except ModuleNotFoundError:
82+
except ImportError:
8583
# jax<0.8.0
8684

8785
jaxlib_pathways = _FakeJaxModule("jax.jaxlib._pathways", "0.8.0")
8886

8987

88+
del jax
89+
del Any
9090
del _FakeJaxModule

0 commit comments

Comments
 (0)