Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 61 additions & 14 deletions src/nitypes/waveform/_analog_waveform.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
from nitypes._arguments import arg_to_uint, validate_dtype, validate_unsupported_arg
from nitypes._exceptions import invalid_arg_type, invalid_array_ndim
from nitypes._typing import Self, TypeAlias
from nitypes.waveform._exceptions import (
input_array_data_type_mismatch,
input_waveform_data_type_mismatch,
)
from nitypes.waveform._extended_properties import (
CHANNEL_NAME,
UNIT_DESCRIPTION,
Expand Down Expand Up @@ -800,16 +804,9 @@ def _append_array(
timestamps: Sequence[dt.datetime] | Sequence[ht.datetime] | None = None,
) -> None:
if array.dtype != self.dtype:
raise TypeError(
"The data type of the input array must match the waveform data type.\n\n"
f"Input array data type: {array.dtype}\n"
f"Waveform data type: {self.dtype}"
)
raise input_array_data_type_mismatch(array.dtype, self.dtype)
if array.ndim != 1:
raise ValueError(
"The input array must be a one-dimensional array.\n\n"
f"Number of dimensions: {array.ndim}"
)
raise invalid_array_ndim("input array", "one-dimensional array", array.ndim)
if timestamps is not None and len(array) != len(timestamps):
raise ValueError(
"The number of irregular timestamps must be equal to the input array length.\n\n"
Expand All @@ -832,11 +829,7 @@ def _append_waveform(self, waveform: AnalogWaveform[_ScalarType_co]) -> None:
def _append_waveforms(self, waveforms: Sequence[AnalogWaveform[_ScalarType_co]]) -> None:
for waveform in waveforms:
if waveform.dtype != self.dtype:
raise TypeError(
"The data type of the input waveform must match the waveform data type.\n\n"
f"Input waveform data type: {waveform.dtype}\n"
f"Waveform data type: {self.dtype}"
)
raise input_waveform_data_type_mismatch(waveform.dtype, self.dtype)
if waveform._scale_mode != self._scale_mode:
warnings.warn(scale_mode_mismatch())

Expand All @@ -859,6 +852,60 @@ def _increase_capacity(self, amount: int) -> None:
if new_capacity > self.capacity:
self.capacity = new_capacity

def load_data(
self,
array: npt.NDArray[_ScalarType_co],
*,
copy: bool = True,
start_index: SupportsIndex | None = 0,
sample_count: SupportsIndex | None = None,
) -> None:
"""Load new data into an existing waveform.

Args:
array: A NumPy array containing the data to load.
copy: Specifies whether to copy the array or save a reference to it.
start_index: The sample index at which the analog waveform data begins.
sample_count: The number of samples in the analog waveform.
"""
if isinstance(array, np.ndarray):
self._load_array(array, copy=copy, start_index=start_index, sample_count=sample_count)
else:
raise invalid_arg_type("input array", "array", array)

def _load_array(
self,
array: npt.NDArray[_ScalarType_co],
*,
copy: bool = True,
start_index: SupportsIndex | None = 0,
sample_count: SupportsIndex | None = None,
) -> None:
if array.dtype != self.dtype:
raise input_array_data_type_mismatch(array.dtype, self.dtype)
if array.ndim != 1:
raise invalid_array_ndim("input array", "one-dimensional array", array.ndim)
if self._timing._timestamps is not None and len(array) != len(self._timing._timestamps):
raise ValueError(
"The input array length must be equal to the number of irregular timestamps.\n\n"
f"Array length: {len(array)}\n"
f"Number of timestamps: {len(self._timing._timestamps)}"
)

start_index = arg_to_uint("start index", start_index, 0)
sample_count = arg_to_uint("sample count", sample_count, len(array) - start_index)

if copy:
if sample_count > len(self._data):
self.capacity = sample_count
self._data[0:sample_count] = array[start_index : start_index + sample_count]
self._start_index = 0
self._sample_count = sample_count
else:
self._data = array
self._start_index = start_index
self._sample_count = sample_count

def __eq__(self, value: object, /) -> bool:
"""Return self==value."""
if not isinstance(value, self.__class__):
Expand Down
18 changes: 18 additions & 0 deletions src/nitypes/waveform/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,24 @@ class TimingMismatchError(RuntimeError):
pass


def input_array_data_type_mismatch(input_dtype: object, waveform_dtype: object) -> TypeError:
"""Create a TypeError for an input array data type mismatch."""
return TypeError(
"The data type of the input array must match the waveform data type.\n\n"
f"Input array data type: {input_dtype}\n"
f"Waveform data type: {waveform_dtype}"
)


def input_waveform_data_type_mismatch(input_dtype: object, waveform_dtype: object) -> TypeError:
"""Create a TypeError for an input waveform data type mismatch."""
return TypeError(
"The data type of the input waveform must match the waveform data type.\n\n"
f"Input waveform data type: {input_dtype}\n"
f"Waveform data type: {waveform_dtype}"
)


def no_timestamp_information() -> RuntimeError:
"""Create a RuntimeError for waveform timing with no timestamp information."""
return RuntimeError(
Expand Down
164 changes: 164 additions & 0 deletions tests/unit/waveform/test_analog_waveform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1421,6 +1421,170 @@ def test___regular_waveform_and_irregular_waveform_list___append___raises_runtim
assert waveform.timing.sample_interval == dt.timedelta(milliseconds=1)


###############################################################################
# load data
###############################################################################
def test___empty_ndarray___load_data___clears_data() -> None:
waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.int32)
array = np.array([], np.int32)

waveform.load_data(array)

assert list(waveform.raw_data) == []


def test___int32_ndarray___load_data___overwrites_data() -> None:
waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.int32)
array = np.array([3, 4, 5], np.int32)

waveform.load_data(array)

assert list(waveform.raw_data) == [3, 4, 5]


def test___float64_ndarray___load_data___overwrites_data() -> None:
waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.float64)
array = np.array([3, 4, 5], np.float64)

waveform.load_data(array)

assert list(waveform.raw_data) == [3, 4, 5]


def test___ndarray_with_mismatched_dtype___load_data___raises_type_error() -> None:
waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.float64)
array = np.array([3, 4, 5], np.int32)

with pytest.raises(TypeError) as exc:
waveform.load_data(array) # type: ignore[arg-type]

assert exc.value.args[0].startswith(
"The data type of the input array must match the waveform data type."
)


def test___ndarray_2d___load_data___raises_value_error() -> None:
waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.float64)
array = np.array([[3, 4, 5], [6, 7, 8]], np.float64)

with pytest.raises(ValueError) as exc:
waveform.load_data(array)

assert exc.value.args[0].startswith("The input array must be a one-dimensional array.")


def test___smaller_ndarray___load_data___preserves_capacity() -> None:
waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.int32)
array = np.array([3], np.int32)

waveform.load_data(array)

assert list(waveform.raw_data) == [3]
assert waveform.capacity == 3


def test___larger_ndarray___load_data___grows_capacity() -> None:
waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.int32)
array = np.array([3, 4, 5, 6], np.int32)

waveform.load_data(array)

assert list(waveform.raw_data) == [3, 4, 5, 6]
assert waveform.capacity == 4


def test___waveform_with_start_index___load_data___clears_start_index() -> None:
waveform = AnalogWaveform.from_array_1d(
np.array([0, 1, 2], np.int32), np.int32, copy=False, start_index=1, sample_count=1
)
assert waveform._start_index == 1
array = np.array([3], np.int32)

waveform.load_data(array)

assert list(waveform.raw_data) == [3]
assert waveform._start_index == 0


def test___ndarray_subset___load_data___overwrites_data() -> None:
waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.int32)
array = np.array([3, 4, 5], np.int32)

waveform.load_data(array, start_index=1, sample_count=1)

assert list(waveform.raw_data) == [4]
assert waveform._start_index == 0
assert waveform.capacity == 3


def test___smaller_ndarray_no_copy___load_data___takes_ownership_of_array() -> None:
waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.int32)
array = np.array([3], np.int32)

waveform.load_data(array, copy=False)

assert list(waveform.raw_data) == [3]
assert waveform._data is array


def test___larger_ndarray_no_copy___load_data___takes_ownership_of_array() -> None:
waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.int32)
array = np.array([3, 4, 5, 6], np.int32)

waveform.load_data(array, copy=False)

assert list(waveform.raw_data) == [3, 4, 5, 6]
assert waveform._data is array


def test___ndarray_subset_no_copy___load_data___takes_ownership_of_array_subset() -> None:
waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.int32)
array = np.array([3, 4, 5, 6], np.int32)

waveform.load_data(array, copy=False, start_index=1, sample_count=2)

assert list(waveform.raw_data) == [4, 5]
assert waveform._data is array


def test___irregular_waveform_and_int32_ndarray_with_timestamps___load_data___overwrites_data_but_not_timestamps() -> (
None
):
start_time = dt.datetime.now(dt.timezone.utc)
waveform_offsets = [dt.timedelta(0), dt.timedelta(1), dt.timedelta(2)]
waveform_timestamps = [start_time + offset for offset in waveform_offsets]
waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.int32)
waveform.timing = Timing.create_with_irregular_interval(waveform_timestamps)
array = np.array([3, 4, 5], np.int32)

waveform.load_data(array)

assert list(waveform.raw_data) == [3, 4, 5]
assert waveform.timing.sample_interval_mode == SampleIntervalMode.IRREGULAR
assert waveform.timing._timestamps == waveform_timestamps


def test___irregular_waveform_and_int32_ndarray_with_wrong_sample_count___load_data___raises_value_error_and_does_not_overwrite_data() -> (
None
):
start_time = dt.datetime.now(dt.timezone.utc)
waveform_offsets = [dt.timedelta(0), dt.timedelta(1), dt.timedelta(2)]
waveform_timestamps = [start_time + offset for offset in waveform_offsets]
waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.int32)
waveform.timing = Timing.create_with_irregular_interval(waveform_timestamps)
array = np.array([3, 4], np.int32)

with pytest.raises(ValueError) as exc:
waveform.load_data(array)

assert exc.value.args[0].startswith(
"The input array length must be equal to the number of irregular timestamps."
)
assert list(waveform.raw_data) == [0, 1, 2]
assert waveform.timing.sample_interval_mode == SampleIntervalMode.IRREGULAR
assert waveform.timing._timestamps == waveform_timestamps


###############################################################################
# magic methods
###############################################################################
Expand Down