diff --git a/penzai/core/named_axes.py b/penzai/core/named_axes.py index 5194fe1..13d9cfc 100644 --- a/penzai/core/named_axes.py +++ b/penzai/core/named_axes.py @@ -1802,7 +1802,7 @@ def is_namedarray(value) -> typing.TypeGuard[NamedArrayBase]: def full( named_shape: Mapping[AxisName, int], fill_value: jax.typing.ArrayLike, - dtype: np.DTypeLike | None = None, + dtype: jax.typing.DTypeLike | None = None, ) -> NamedArray: """Constructs a full named array with a given shape. @@ -1823,7 +1823,7 @@ def full( def zeros( named_shape: Mapping[AxisName, int], - dtype: np.DTypeLike | None = None, + dtype: jax.typing.DTypeLike | None = None, ) -> NamedArray: """Constructs a named array of zeros with a given shape. @@ -1842,7 +1842,7 @@ def zeros( def ones( named_shape: Mapping[AxisName, int], - dtype: np.DTypeLike | None = None, + dtype: jax.typing.DTypeLike | None = None, ) -> NamedArray: """Constructs a named array of ones with a given shape.