Skip to content

Commit 2a0164e

Browse files
committed
Redefine patch parameter overwrite and add padding feature
1 parent dee5fad commit 2a0164e

File tree

2 files changed

+131
-36
lines changed

2 files changed

+131
-36
lines changed

neo/core/analogsignal.py

Lines changed: 69 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -660,7 +660,7 @@ def rectify(self, **kwargs):
660660

661661
return rectified_signal
662662

663-
def patch(self, other, overwrite=True):
663+
def patch(self, other, overwrite=True, padding=False):
664664
'''
665665
Patch another signal to this one.
666666
@@ -681,10 +681,24 @@ def patch(self, other, overwrite=True):
681681
----------
682682
other : neo.BaseSignal
683683
The object that is merged into this one.
684+
The other signal needs cover a later time period than
685+
this one, i.e. self.t_start < other.t_start
684686
overwrite : bool
685-
If False, samples of this signal are overwritten
686-
by other signal. If True, samples of other signal
687-
are overwritten by this signal. Default: True
687+
If True, samples of the earlier (smaller t_start)
688+
signal are overwritten by the later signal.
689+
If False, samples of the later (higher t_start)
690+
are overwritten by earlier signal.
691+
Default: False
692+
padding : bool, scalar quantity
693+
Sampling values to use as padding in case signals
694+
do not overlap.
695+
If False, do not apply padding. Signals have to align or
696+
overlap. If True, signals will be padded using
697+
np.NaN as pad values. If a scalar quantity is provided, this
698+
will be used for padding. The other signal is moved
699+
forward in time by maximum one sampling period to
700+
align the sampling times of both signals.
701+
Default: False
688702
689703
Returns
690704
-------
@@ -698,11 +712,21 @@ def patch(self, other, overwrite=True):
698712
If `other` object has incompatible attributes.
699713
'''
700714

715+
if other.units != self.units:
716+
other = other.rescale(self.units)
717+
718+
if self.t_start > other.t_start:
719+
signal1, signal2 = other, self
720+
else:
721+
signal1, signal2 = self, other
722+
# raise MergeError('Inconsistent timing of signals. Other signal needs to be later than'
723+
# ' this signal')
724+
701725
for attr in self._necessary_attrs:
702-
if 'signal' != attr[0]:
726+
if attr[0] not in ['signal', 't_start', 't_stop']:
703727
if getattr(self, attr[0], None) != getattr(other, attr[0], None):
704-
if attr[0] in ['t_start','t_stop']:
705-
continue
728+
# if attr[0] in ['t_start','t_stop']:
729+
# continue
706730
raise MergeError("Cannot patch these two signals as the %s differ." % attr[0])
707731

708732
if hasattr(self, "lazy_shape"):
@@ -713,24 +737,45 @@ def patch(self, other, overwrite=True):
713737
merged_lazy_shape = (self.lazy_shape[0] + other.lazy_shape[0], self.lazy_shape[-1])
714738
else:
715739
raise MergeError("Cannot patch a lazy object with a real object.")
716-
if other.units != self.units:
717-
other = other.rescale(self.units)
718740

719-
if self.t_start > other.t_stop:
720-
raise MergeError('Signals do not overlap.')
721-
722-
# adjust overlapping signals
723-
if self.t_stop + self.sampling_period >= other.t_start:
724-
if not overwrite: # removing samples of other signal
725-
slice_t_start = self.t_stop + self.sampling_period
726-
sliced_other = other.time_slice(slice_t_start, None)
727-
stack = np.vstack((self.magnitude, sliced_other.magnitude))
728-
else: # removing samples of this signal
729-
slice_t_stop = other.t_start - other.sampling_period
730-
sliced_self = self.time_slice(None, slice_t_stop)
731-
stack = np.vstack((sliced_self.magnitude, other.magnitude))
741+
# in case of non-overlapping signals consider padding
742+
if signal2.t_start > signal1.t_stop + signal1.sampling_period:
743+
if padding != False:
744+
logger.warning('Signals will be padded using {}.'.format(padding))
745+
pad_time = signal2.t_start-signal1.t_stop
746+
n_pad_samples = int(((pad_time)*self.sampling_rate).rescale('dimensionless'))
747+
if padding is True:
748+
padding = np.NaN * self.units
749+
if isinstance(padding, pq.Quantity):
750+
padding = padding.rescale(self.units).magnitude
751+
else:
752+
raise ValueError('Invalid type of padding value. Please provide a bool value '
753+
'or a quantities object.')
754+
pad_data = np.full((n_pad_samples,) + signal1.shape[1:], padding)
755+
756+
# create new signal 1 with extended data, but keep array_annotations
757+
signal_tmp = signal1.duplicate_with_new_data(np.vstack((signal1.magnitude, pad_data)))
758+
signal_tmp.array_annotations = signal1.array_annotations
759+
signal1 = signal_tmp
760+
else:
761+
raise MergeError('Signals do not overlap, but no padding is provided.'
762+
'Please provide a padding mode.')
763+
764+
# in case of overlapping signals slice according to overwrite parameter
765+
elif signal2.t_start < signal1.t_stop + signal1.sampling_period:
766+
n_samples = int(((signal1.t_stop - signal2.t_start)*signal1.sampling_rate).simplified)
767+
logger.warning('Overwriting {} samples while patching signals.'.format(n_samples))
768+
if not overwrite: # removing samples second signal
769+
slice_t_start = signal1.t_stop + signal1.sampling_period
770+
signal2 = signal2.time_slice(slice_t_start, None)
771+
else: # removing samples of the first signal
772+
slice_t_stop = signal2.t_start - signal2.sampling_period
773+
signal1 = signal1.time_slice(None, slice_t_stop)
732774
else:
733-
raise MergeError("Cannot patch signals with non-overlapping times")
775+
assert signal2.t_start == signal1.t_stop + signal1.sampling_period, \
776+
"Cannot patch signals with non-overlapping times"
777+
778+
stack = np.vstack((signal1.magnitude, signal2.magnitude))
734779

735780
kwargs = {}
736781
for name in ("name", "description", "file_origin"):
@@ -747,7 +792,7 @@ def patch(self, other, overwrite=True):
747792
other.array_annotations)
748793

749794
signal = self.__class__(stack, units=self.units, dtype=self.dtype, copy=False,
750-
t_start=self.t_start, sampling_rate=self.sampling_rate, **kwargs)
795+
t_start=signal1.t_start, sampling_rate=self.sampling_rate, **kwargs)
751796
signal.segment = None
752797
signal.channel_index = None
753798

neo/test/coretest/test_analogsignal.py

Lines changed: 62 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1558,18 +1558,28 @@ def test__merge(self):
15581558
assert_arrays_equal(mergeddata24, targdata24)
15591559

15601560
def test_patch_simple(self):
1561-
signal1 = AnalogSignal([0,1,2,3]*pq.s, sampling_rate=1*pq.Hz)
1562-
signal2 = AnalogSignal([4,5,6]*pq.s, sampling_rate=1*pq.Hz,
1561+
signal1 = AnalogSignal([0,1,2,3]*pq.V, sampling_rate=1*pq.Hz)
1562+
signal2 = AnalogSignal([4,5,6]*pq.V, sampling_rate=1*pq.Hz,
15631563
t_start=signal1.t_stop + signal1.sampling_period)
15641564

15651565
result = signal1.patch(signal2)
15661566
assert_array_equal(np.arange(7).reshape((-1, 1)), result.magnitude)
15671567
for attr in signal1._necessary_attrs:
15681568
self.assertEqual(getattr(signal1, attr[0], None), getattr(result, attr[0], None))
15691569

1570+
def test_patch_inverse_signals(self):
1571+
signal1 = AnalogSignal([0,1,2,3]*pq.V, sampling_rate=1*pq.Hz)
1572+
signal2 = AnalogSignal([4,5,6]*pq.V, sampling_rate=1*pq.Hz,
1573+
t_start=signal1.t_stop + signal1.sampling_period)
1574+
1575+
result = signal2.patch(signal1)
1576+
assert_array_equal(np.arange(7).reshape((-1, 1)), result.magnitude)
1577+
for attr in signal1._necessary_attrs:
1578+
self.assertEqual(getattr(signal1, attr[0], None), getattr(result, attr[0], None))
1579+
15701580
def test_patch_no_overlap(self):
1571-
signal1 = AnalogSignal([0,1,2,3]*pq.s, sampling_rate=1*pq.Hz)
1572-
signal2 = AnalogSignal([4,5,6]*pq.s, sampling_rate=1*pq.Hz,
1581+
signal1 = AnalogSignal([0,1,2,3]*pq.V, sampling_rate=1*pq.Hz)
1582+
signal2 = AnalogSignal([4,5,6]*pq.V, sampling_rate=1*pq.Hz,
15731583
t_start=10*pq.s + signal1.sampling_period)
15741584

15751585
with self.assertRaises(MergeError):
@@ -1578,8 +1588,8 @@ def test_patch_no_overlap(self):
15781588
def test_patch_multi_trace(self):
15791589
data1 = np.arange(4).reshape(2,2)
15801590
data2 = np.arange(4,8).reshape(2,2)
1581-
signal1 = AnalogSignal(data1*pq.s, sampling_rate=1*pq.Hz)
1582-
signal2 = AnalogSignal(data2*pq.s, sampling_rate=1*pq.Hz,
1591+
signal1 = AnalogSignal(data1*pq.V, sampling_rate=1*pq.Hz)
1592+
signal2 = AnalogSignal(data2*pq.V, sampling_rate=1*pq.Hz,
15831593
t_start=signal1.t_stop + signal1.sampling_period)
15841594

15851595
result = signal1.patch(signal2)
@@ -1589,30 +1599,70 @@ def test_patch_multi_trace(self):
15891599
self.assertEqual(getattr(signal1, attr[0], None), getattr(result, attr[0], None))
15901600

15911601
def test_patch_overwrite_true(self):
1592-
signal1 = AnalogSignal([0,1,2,3]*pq.s, sampling_rate=1*pq.Hz)
1593-
signal2 = AnalogSignal([4,5,6]*pq.s, sampling_rate=1*pq.Hz,
1602+
signal1 = AnalogSignal([0,1,2,3]*pq.V, sampling_rate=1*pq.Hz)
1603+
signal2 = AnalogSignal([4,5,6]*pq.V, sampling_rate=1*pq.Hz,
15941604
t_start=signal1.t_stop)
15951605

15961606
result = signal1.patch(signal2, overwrite=True)
15971607
assert_array_equal(np.array([0,1,2,4,5,6]).reshape((-1, 1)), result.magnitude)
15981608

15991609
def test_patch_overwrite_false(self):
1600-
signal1 = AnalogSignal([0,1,2,3]*pq.s, sampling_rate=1*pq.Hz)
1601-
signal2 = AnalogSignal([4,5,6]*pq.s, sampling_rate=1*pq.Hz,
1610+
signal1 = AnalogSignal([0,1,2,3]*pq.V, sampling_rate=1*pq.Hz)
1611+
signal2 = AnalogSignal([4,5,6]*pq.V, sampling_rate=1*pq.Hz,
16021612
t_start=signal1.t_stop)
16031613

16041614
result = signal1.patch(signal2, overwrite=False)
16051615
assert_array_equal(np.array([0,1,2,3,5,6]).reshape((-1, 1)), result.magnitude)
16061616

1617+
def test_patch_padding_False(self):
1618+
signal1 = AnalogSignal([0,1,2,3]*pq.V, sampling_rate=1*pq.Hz)
1619+
signal2 = AnalogSignal([4,5,6]*pq.V, sampling_rate=1*pq.Hz,
1620+
t_start=10*pq.s)
1621+
1622+
with self.assertRaises(MergeError):
1623+
result = signal1.patch(signal2, overwrite=False, padding=False)
1624+
1625+
def test_patch_padding_True(self):
1626+
signal1 = AnalogSignal([0,1,2,3]*pq.V, sampling_rate=1*pq.Hz)
1627+
signal2 = AnalogSignal([4,5,6]*pq.V, sampling_rate=1*pq.Hz,
1628+
t_start=signal1.t_stop + 3 * signal1.sampling_period)
1629+
1630+
result = signal1.patch(signal2, overwrite=False, padding=True)
1631+
assert_array_equal(np.array([0,1,2,3,np.NaN,np.NaN,np.NaN,4,5,6]).reshape((-1, 1)),
1632+
result.magnitude)
1633+
1634+
def test_patch_padding_quantity(self):
1635+
signal1 = AnalogSignal([0,1,2,3]*pq.V, sampling_rate=1*pq.Hz)
1636+
signal2 = AnalogSignal([4,5,6]*pq.V, sampling_rate=1*pq.Hz,
1637+
t_start=signal1.t_stop + 3 * signal1.sampling_period)
1638+
1639+
result = signal1.patch(signal2, overwrite=False, padding=-1*pq.mV)
1640+
assert_array_equal(np.array([0,1,2,3,-1e-3,-1e-3,-1e-3,4,5,6]).reshape((-1, 1)),
1641+
result.magnitude)
1642+
1643+
def test_patch_padding_invalid(self):
1644+
signal1 = AnalogSignal([0,1,2,3]*pq.V, sampling_rate=1*pq.Hz)
1645+
signal2 = AnalogSignal([4,5,6]*pq.V, sampling_rate=1*pq.Hz,
1646+
t_start=signal1.t_stop + 3 * signal1.sampling_period)
1647+
1648+
with self.assertRaises(ValueError):
1649+
result = signal1.patch(signal2, overwrite=False, padding=1)
1650+
with self.assertRaises(ValueError):
1651+
result = signal1.patch(signal2, overwrite=False, padding=[1])
1652+
with self.assertRaises(ValueError):
1653+
result = signal1.patch(signal2, overwrite=False, padding='a')
1654+
with self.assertRaises(ValueError):
1655+
result = signal1.patch(signal2, overwrite=False, padding=np.array([1,2,3]))
1656+
16071657
def test_patch_array_annotations(self):
16081658
array_anno1 = {'first': ['a','b']}
16091659
array_anno2 = {'first': ['a','b'],
16101660
'second': ['c','d']}
16111661
data1 = np.arange(4).reshape(2,2)
16121662
data2 = np.arange(4,8).reshape(2,2)
1613-
signal1 = AnalogSignal(data1*pq.s, sampling_rate=1*pq.Hz,
1663+
signal1 = AnalogSignal(data1*pq.V, sampling_rate=1*pq.Hz,
16141664
array_annotations=array_anno1)
1615-
signal2 = AnalogSignal(data2*pq.s, sampling_rate=1*pq.Hz,
1665+
signal2 = AnalogSignal(data2*pq.V, sampling_rate=1*pq.Hz,
16161666
t_start=signal1.t_stop + signal1.sampling_period,
16171667
array_annotations=array_anno2)
16181668

0 commit comments

Comments
 (0)