Skip to content

Commit 8451e66

Browse files
lukebaumanncopybara-github
authored andcommitted
Fixed a couple bugs with pathwaysutils.jax
PiperOrigin-RevId: 813906532
1 parent d83f050 commit 8451e66

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

pathwaysutils/jax/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def register_backend_cache(cache: Any, name: str, util=util): # pylint: disable
7777

7878
try:
7979
# jax>=0.8.0
80-
from jax.jaxlib import _pathways # pylint: disable=g-import-not-at-top
80+
from jaxlib import _pathways # pylint: disable=g-import-not-at-top
8181

8282
jaxlib_pathways = _pathways
8383
del _pathways
@@ -87,4 +87,6 @@ def register_backend_cache(cache: Any, name: str, util=util): # pylint: disable
8787
jaxlib_pathways = _FakeJaxModule("jax.jaxlib._pathways", "0.8.0")
8888

8989

90+
del jax
91+
del Any
9092
del _FakeJaxModule

0 commit comments

Comments
 (0)