Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/zarr/core/buffer/cpu.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import numbers
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -155,7 +156,10 @@ def create(
fill_value: Any | None = None,
) -> Self:
# np.zeros is much faster than np.full, and therefore using it when possible is better.
if fill_value is None or (isinstance(fill_value, int) and fill_value == 0):
# See https://numpy.org/doc/stable/reference/generated/numpy.isscalar.html#numpy-isscalar
# notes for why we use `numbers.Number`.
# Tehcnically `numbers.Number` need not support __eq__ hence the `ignore`.
if fill_value is None or (isinstance(fill_value, numbers.Number) and fill_value == 0): # type: ignore[comparison-overlap]
return cls(np.zeros(shape=tuple(shape), dtype=dtype, order=order))
else:
return cls(np.full(shape=tuple(shape), fill_value=fill_value, dtype=dtype, order=order))
Expand Down
27 changes: 27 additions & 0 deletions tests/test_buffer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Literal
from unittest import mock

import numpy as np
import pytest
Expand Down Expand Up @@ -243,3 +244,29 @@ def test_empty(
assert result.flags.c_contiguous # type: ignore[attr-defined]
else:
assert result.flags.f_contiguous # type: ignore[attr-defined]


@pytest.mark.parametrize("dtype", [np.int8, np.uint16, np.float32, int, float])
@pytest.mark.parametrize("fill_value", [None, 0, 1])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@pytest.mark.parametrize("fill_value", [None, 0, 1])
@pytest.mark.parametrize("fill_value", [None, 0, 0.0, 1])

Worth explicitly including a float here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think everything is cast anyway by dtype, so it shouldn't matter, no?

def test_no_full_with_zeros(
dtype: type[np.number[np.typing.NBitBase] | float],
fill_value: None | float,
) -> None:
"""Ensure that fill value of 0 (or None with a numeric dtype) does not trigger np.full, and instead triggers np.zeros"""
# full never called with fill 0
if fill_value == 0:
with mock.patch("numpy.full", side_effect=RuntimeError):
cpu.buffer_prototype.nd_buffer.create(
shape=(10,), dtype=dtype, fill_value=dtype(fill_value)
)
# full or zeros called appropriately based on fill value
with mock.patch(
"numpy.zeros" if fill_value == 0 or fill_value is None else "numpy.full",
side_effect=RuntimeError("called"),
):
with pytest.raises(RuntimeError, match=r"called"):
cpu.buffer_prototype.nd_buffer.create(
shape=(10,),
dtype=dtype,
fill_value=dtype(fill_value) if fill_value is not None else fill_value,
)
Loading