diff --git a/src/nitypes/waveform/_analog_waveform.py b/src/nitypes/waveform/_analog_waveform.py index f7ba073a..b8f825b4 100644 --- a/src/nitypes/waveform/_analog_waveform.py +++ b/src/nitypes/waveform/_analog_waveform.py @@ -13,6 +13,10 @@ from nitypes._arguments import arg_to_uint, validate_dtype, validate_unsupported_arg from nitypes._exceptions import invalid_arg_type, invalid_array_ndim from nitypes._typing import Self, TypeAlias +from nitypes.waveform._exceptions import ( + input_array_data_type_mismatch, + input_waveform_data_type_mismatch, +) from nitypes.waveform._extended_properties import ( CHANNEL_NAME, UNIT_DESCRIPTION, @@ -800,16 +804,9 @@ def _append_array( timestamps: Sequence[dt.datetime] | Sequence[ht.datetime] | None = None, ) -> None: if array.dtype != self.dtype: - raise TypeError( - "The data type of the input array must match the waveform data type.\n\n" - f"Input array data type: {array.dtype}\n" - f"Waveform data type: {self.dtype}" - ) + raise input_array_data_type_mismatch(array.dtype, self.dtype) if array.ndim != 1: - raise ValueError( - "The input array must be a one-dimensional array.\n\n" - f"Number of dimensions: {array.ndim}" - ) + raise invalid_array_ndim("input array", "one-dimensional array", array.ndim) if timestamps is not None and len(array) != len(timestamps): raise ValueError( "The number of irregular timestamps must be equal to the input array length.\n\n" @@ -832,11 +829,7 @@ def _append_waveform(self, waveform: AnalogWaveform[_ScalarType_co]) -> None: def _append_waveforms(self, waveforms: Sequence[AnalogWaveform[_ScalarType_co]]) -> None: for waveform in waveforms: if waveform.dtype != self.dtype: - raise TypeError( - "The data type of the input waveform must match the waveform data type.\n\n" - f"Input waveform data type: {waveform.dtype}\n" - f"Waveform data type: {self.dtype}" - ) + raise input_waveform_data_type_mismatch(waveform.dtype, self.dtype) if waveform._scale_mode != self._scale_mode: warnings.warn(scale_mode_mismatch()) @@ -859,6 +852,60 @@ def _increase_capacity(self, amount: int) -> None: if new_capacity > self.capacity: self.capacity = new_capacity + def load_data( + self, + array: npt.NDArray[_ScalarType_co], + *, + copy: bool = True, + start_index: SupportsIndex | None = 0, + sample_count: SupportsIndex | None = None, + ) -> None: + """Load new data into an existing waveform. + + Args: + array: A NumPy array containing the data to load. + copy: Specifies whether to copy the array or save a reference to it. + start_index: The sample index at which the analog waveform data begins. + sample_count: The number of samples in the analog waveform. + """ + if isinstance(array, np.ndarray): + self._load_array(array, copy=copy, start_index=start_index, sample_count=sample_count) + else: + raise invalid_arg_type("input array", "array", array) + + def _load_array( + self, + array: npt.NDArray[_ScalarType_co], + *, + copy: bool = True, + start_index: SupportsIndex | None = 0, + sample_count: SupportsIndex | None = None, + ) -> None: + if array.dtype != self.dtype: + raise input_array_data_type_mismatch(array.dtype, self.dtype) + if array.ndim != 1: + raise invalid_array_ndim("input array", "one-dimensional array", array.ndim) + if self._timing._timestamps is not None and len(array) != len(self._timing._timestamps): + raise ValueError( + "The input array length must be equal to the number of irregular timestamps.\n\n" + f"Array length: {len(array)}\n" + f"Number of timestamps: {len(self._timing._timestamps)}" + ) + + start_index = arg_to_uint("start index", start_index, 0) + sample_count = arg_to_uint("sample count", sample_count, len(array) - start_index) + + if copy: + if sample_count > len(self._data): + self.capacity = sample_count + self._data[0:sample_count] = array[start_index : start_index + sample_count] + self._start_index = 0 + self._sample_count = sample_count + else: + self._data = array + self._start_index = start_index + self._sample_count = sample_count + def __eq__(self, value: object, /) -> bool: """Return self==value.""" if not isinstance(value, self.__class__): diff --git a/src/nitypes/waveform/_exceptions.py b/src/nitypes/waveform/_exceptions.py index f52fb788..6a6da901 100644 --- a/src/nitypes/waveform/_exceptions.py +++ b/src/nitypes/waveform/_exceptions.py @@ -7,6 +7,24 @@ class TimingMismatchError(RuntimeError): pass +def input_array_data_type_mismatch(input_dtype: object, waveform_dtype: object) -> TypeError: + """Create a TypeError for an input array data type mismatch.""" + return TypeError( + "The data type of the input array must match the waveform data type.\n\n" + f"Input array data type: {input_dtype}\n" + f"Waveform data type: {waveform_dtype}" + ) + + +def input_waveform_data_type_mismatch(input_dtype: object, waveform_dtype: object) -> TypeError: + """Create a TypeError for an input waveform data type mismatch.""" + return TypeError( + "The data type of the input waveform must match the waveform data type.\n\n" + f"Input waveform data type: {input_dtype}\n" + f"Waveform data type: {waveform_dtype}" + ) + + def no_timestamp_information() -> RuntimeError: """Create a RuntimeError for waveform timing with no timestamp information.""" return RuntimeError( diff --git a/tests/unit/waveform/test_analog_waveform.py b/tests/unit/waveform/test_analog_waveform.py index b91e5f24..0a1a67ad 100644 --- a/tests/unit/waveform/test_analog_waveform.py +++ b/tests/unit/waveform/test_analog_waveform.py @@ -1421,6 +1421,170 @@ def test___regular_waveform_and_irregular_waveform_list___append___raises_runtim assert waveform.timing.sample_interval == dt.timedelta(milliseconds=1) +############################################################################### +# load data +############################################################################### +def test___empty_ndarray___load_data___clears_data() -> None: + waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.int32) + array = np.array([], np.int32) + + waveform.load_data(array) + + assert list(waveform.raw_data) == [] + + +def test___int32_ndarray___load_data___overwrites_data() -> None: + waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.int32) + array = np.array([3, 4, 5], np.int32) + + waveform.load_data(array) + + assert list(waveform.raw_data) == [3, 4, 5] + + +def test___float64_ndarray___load_data___overwrites_data() -> None: + waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.float64) + array = np.array([3, 4, 5], np.float64) + + waveform.load_data(array) + + assert list(waveform.raw_data) == [3, 4, 5] + + +def test___ndarray_with_mismatched_dtype___load_data___raises_type_error() -> None: + waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.float64) + array = np.array([3, 4, 5], np.int32) + + with pytest.raises(TypeError) as exc: + waveform.load_data(array) # type: ignore[arg-type] + + assert exc.value.args[0].startswith( + "The data type of the input array must match the waveform data type." + ) + + +def test___ndarray_2d___load_data___raises_value_error() -> None: + waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.float64) + array = np.array([[3, 4, 5], [6, 7, 8]], np.float64) + + with pytest.raises(ValueError) as exc: + waveform.load_data(array) + + assert exc.value.args[0].startswith("The input array must be a one-dimensional array.") + + +def test___smaller_ndarray___load_data___preserves_capacity() -> None: + waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.int32) + array = np.array([3], np.int32) + + waveform.load_data(array) + + assert list(waveform.raw_data) == [3] + assert waveform.capacity == 3 + + +def test___larger_ndarray___load_data___grows_capacity() -> None: + waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.int32) + array = np.array([3, 4, 5, 6], np.int32) + + waveform.load_data(array) + + assert list(waveform.raw_data) == [3, 4, 5, 6] + assert waveform.capacity == 4 + + +def test___waveform_with_start_index___load_data___clears_start_index() -> None: + waveform = AnalogWaveform.from_array_1d( + np.array([0, 1, 2], np.int32), np.int32, copy=False, start_index=1, sample_count=1 + ) + assert waveform._start_index == 1 + array = np.array([3], np.int32) + + waveform.load_data(array) + + assert list(waveform.raw_data) == [3] + assert waveform._start_index == 0 + + +def test___ndarray_subset___load_data___overwrites_data() -> None: + waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.int32) + array = np.array([3, 4, 5], np.int32) + + waveform.load_data(array, start_index=1, sample_count=1) + + assert list(waveform.raw_data) == [4] + assert waveform._start_index == 0 + assert waveform.capacity == 3 + + +def test___smaller_ndarray_no_copy___load_data___takes_ownership_of_array() -> None: + waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.int32) + array = np.array([3], np.int32) + + waveform.load_data(array, copy=False) + + assert list(waveform.raw_data) == [3] + assert waveform._data is array + + +def test___larger_ndarray_no_copy___load_data___takes_ownership_of_array() -> None: + waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.int32) + array = np.array([3, 4, 5, 6], np.int32) + + waveform.load_data(array, copy=False) + + assert list(waveform.raw_data) == [3, 4, 5, 6] + assert waveform._data is array + + +def test___ndarray_subset_no_copy___load_data___takes_ownership_of_array_subset() -> None: + waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.int32) + array = np.array([3, 4, 5, 6], np.int32) + + waveform.load_data(array, copy=False, start_index=1, sample_count=2) + + assert list(waveform.raw_data) == [4, 5] + assert waveform._data is array + + +def test___irregular_waveform_and_int32_ndarray_with_timestamps___load_data___overwrites_data_but_not_timestamps() -> ( + None +): + start_time = dt.datetime.now(dt.timezone.utc) + waveform_offsets = [dt.timedelta(0), dt.timedelta(1), dt.timedelta(2)] + waveform_timestamps = [start_time + offset for offset in waveform_offsets] + waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.int32) + waveform.timing = Timing.create_with_irregular_interval(waveform_timestamps) + array = np.array([3, 4, 5], np.int32) + + waveform.load_data(array) + + assert list(waveform.raw_data) == [3, 4, 5] + assert waveform.timing.sample_interval_mode == SampleIntervalMode.IRREGULAR + assert waveform.timing._timestamps == waveform_timestamps + + +def test___irregular_waveform_and_int32_ndarray_with_wrong_sample_count___load_data___raises_value_error_and_does_not_overwrite_data() -> ( + None +): + start_time = dt.datetime.now(dt.timezone.utc) + waveform_offsets = [dt.timedelta(0), dt.timedelta(1), dt.timedelta(2)] + waveform_timestamps = [start_time + offset for offset in waveform_offsets] + waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.int32) + waveform.timing = Timing.create_with_irregular_interval(waveform_timestamps) + array = np.array([3, 4], np.int32) + + with pytest.raises(ValueError) as exc: + waveform.load_data(array) + + assert exc.value.args[0].startswith( + "The input array length must be equal to the number of irregular timestamps." + ) + assert list(waveform.raw_data) == [0, 1, 2] + assert waveform.timing.sample_interval_mode == SampleIntervalMode.IRREGULAR + assert waveform.timing._timestamps == waveform_timestamps + + ############################################################################### # magic methods ###############################################################################