From 116d91bc8bc6c555a91673b8e234df9be0a667bf Mon Sep 17 00:00:00 2001 From: James Mochizuki-Freeman Date: Thu, 16 Oct 2025 16:53:47 -0400 Subject: [PATCH] Fix dtype type annotations to use `jax.typing.DTypeLike`. The `named_axis` functions `full()`, `zeros()` and `ones()` annotate their dtype parameters with the non-existent `np.DTypeLike`. This change replaces the annotations with `jax.typing.DTypeLike | None` to match the dtype parameter of their wrapped functions. --- penzai/core/named_axes.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/penzai/core/named_axes.py b/penzai/core/named_axes.py index 5194fe1..13d9cfc 100644 --- a/penzai/core/named_axes.py +++ b/penzai/core/named_axes.py @@ -1802,7 +1802,7 @@ def is_namedarray(value) -> typing.TypeGuard[NamedArrayBase]: def full( named_shape: Mapping[AxisName, int], fill_value: jax.typing.ArrayLike, - dtype: np.DTypeLike | None = None, + dtype: jax.typing.DTypeLike | None = None, ) -> NamedArray: """Constructs a full named array with a given shape. @@ -1823,7 +1823,7 @@ def full( def zeros( named_shape: Mapping[AxisName, int], - dtype: np.DTypeLike | None = None, + dtype: jax.typing.DTypeLike | None = None, ) -> NamedArray: """Constructs a named array of zeros with a given shape. @@ -1842,7 +1842,7 @@ def zeros( def ones( named_shape: Mapping[AxisName, int], - dtype: np.DTypeLike | None = None, + dtype: jax.typing.DTypeLike | None = None, ) -> NamedArray: """Constructs a named array of ones with a given shape.