Skip to content

Commit 93b1ce7

Browse files
authored
Merge pull request #836 from JuliaSprenger/enh/patch_signals
Add concatenate functionality for signal objects
2 parents d8febf7 + 2933091 commit 93b1ce7

File tree

7 files changed

+593
-30
lines changed

7 files changed

+593
-30
lines changed

neo/core/analogsignal.py

Lines changed: 134 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
import numpy as np
3030
import quantities as pq
3131

32-
from neo.core.baseneo import BaseNeo, MergeError, merge_annotations
32+
from neo.core.baseneo import BaseNeo, MergeError, merge_annotations, intersect_annotations
3333
from neo.core.dataobject import DataObject
3434
from copy import copy, deepcopy
3535

@@ -657,3 +657,136 @@ def rectify(self, **kwargs):
657657
rectified_signal.array_annotations = self.array_annotations.copy()
658658

659659
return rectified_signal
660+
661+
def concatenate(self, *signals, overwrite=False, padding=False):
662+
"""
663+
Concatenate multiple neo.AnalogSignal objects across time.
664+
665+
Units, sampling_rate and number of signal traces must be the same
666+
for all signals. Otherwise a ValueError is raised.
667+
Note that timestamps of concatenated signals might shift in oder to
668+
align the sampling times of all signals.
669+
670+
Parameters
671+
----------
672+
signals: neo.AnalogSignal objects
673+
AnalogSignals that will be concatenated
674+
overwrite : bool
675+
If True, samples of the earlier (lower index in `signals`)
676+
signals are overwritten by that of later (higher index in `signals`)
677+
signals.
678+
If False, samples of the later are overwritten by earlier signal.
679+
Default: False
680+
padding : bool, scalar quantity
681+
Sampling values to use as padding in case signals do not overlap.
682+
If False, do not apply padding. Signals have to align or
683+
overlap. If True, signals will be padded using
684+
np.NaN as pad values. If a scalar quantity is provided, this
685+
will be used for padding. The other signal is moved
686+
forward in time by maximum one sampling period to
687+
align the sampling times of both signals.
688+
Default: False
689+
690+
Returns
691+
-------
692+
signal: neo.AnalogSignal
693+
concatenated output signal
694+
"""
695+
696+
# Sanity of inputs
697+
if not hasattr(signals, '__iter__'):
698+
raise TypeError('signals must be iterable')
699+
if not all([isinstance(a, AnalogSignal) for a in signals]):
700+
raise TypeError('Entries of anasiglist have to be of type neo.AnalogSignal')
701+
if len(signals) == 0:
702+
return self
703+
704+
signals = [self] + list(signals)
705+
706+
# Check required common attributes: units, sampling_rate and shape[-1]
707+
shared_attributes = ['units', 'sampling_rate']
708+
attribute_values = [tuple((getattr(anasig, attr) for attr in shared_attributes))
709+
for anasig in signals]
710+
# add shape dimensions that do not relate to time
711+
attribute_values = [(attribute_values[i] + (signals[i].shape[1:],))
712+
for i in range(len(signals))]
713+
if not all([attrs == attribute_values[0] for attrs in attribute_values]):
714+
raise MergeError(
715+
f'AnalogSignals have to share {shared_attributes} attributes to be concatenated.')
716+
units, sr, shape = attribute_values[0]
717+
718+
# find gaps between Analogsignals
719+
combined_time_ranges = self._concatenate_time_ranges(
720+
[(s.t_start, s.t_stop) for s in signals])
721+
missing_time_ranges = self._invert_time_ranges(combined_time_ranges)
722+
if len(missing_time_ranges):
723+
diffs = np.diff(np.asarray(missing_time_ranges), axis=1)
724+
else:
725+
diffs = []
726+
727+
if padding is False and any(diffs > signals[0].sampling_period):
728+
raise MergeError(f'Signals are not continuous. Can not concatenate signals with gaps. '
729+
f'Please provide a padding value.')
730+
if padding is not False:
731+
logger.warning('Signals will be padded using {}.'.format(padding))
732+
if padding is True:
733+
padding = np.NaN * units
734+
if isinstance(padding, pq.Quantity):
735+
padding = padding.rescale(units).magnitude
736+
else:
737+
raise MergeError('Invalid type of padding value. Please provide a bool value '
738+
'or a quantities object.')
739+
740+
t_start = min([a.t_start for a in signals])
741+
t_stop = max([a.t_stop for a in signals])
742+
n_samples = int(np.rint(((t_stop - t_start) * sr).rescale('dimensionless').magnitude))
743+
shape = (n_samples,) + shape
744+
745+
# Collect attributes and annotations across all concatenated signals
746+
kwargs = {}
747+
common_annotations = signals[0].annotations
748+
common_array_annotations = signals[0].array_annotations
749+
for anasig in signals[1:]:
750+
common_annotations = intersect_annotations(common_annotations, anasig.annotations)
751+
common_array_annotations = intersect_annotations(common_array_annotations,
752+
anasig.array_annotations)
753+
754+
kwargs['annotations'] = common_annotations
755+
kwargs['array_annotations'] = common_array_annotations
756+
757+
for name in ("name", "description", "file_origin"):
758+
attr = [getattr(s, name) for s in signals]
759+
if all([a == attr[0] for a in attr]):
760+
kwargs[name] = attr[0]
761+
else:
762+
kwargs[name] = f'concatenation ({attr})'
763+
764+
conc_signal = AnalogSignal(np.full(shape=shape, fill_value=padding, dtype=signals[0].dtype),
765+
sampling_rate=sr, t_start=t_start, units=units, **kwargs)
766+
767+
if not overwrite:
768+
signals = signals[::-1]
769+
while len(signals) > 0:
770+
conc_signal.splice(signals.pop(0), copy=False)
771+
772+
return conc_signal
773+
774+
def _concatenate_time_ranges(self, time_ranges):
775+
time_ranges = sorted(time_ranges)
776+
new_ranges = time_ranges[:1]
777+
for t_start, t_stop in time_ranges[1:]:
778+
# time range are non continuous -> define new range
779+
if t_start > new_ranges[-1][1]:
780+
new_ranges.append((t_start, t_stop))
781+
# time range is continuous -> extend time range
782+
elif t_stop > new_ranges[-1][1]:
783+
new_ranges[-1] = (new_ranges[-1][0], t_stop)
784+
return new_ranges
785+
786+
def _invert_time_ranges(self, time_ranges):
787+
i = 0
788+
new_ranges = []
789+
while i < len(time_ranges) - 1:
790+
new_ranges.append((time_ranges[i][1], time_ranges[i + 1][0]))
791+
i += 1
792+
return new_ranges

neo/core/baseneo.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
used by all :module:`neo.core` classes.
44
"""
55

6+
from copy import deepcopy
67
from datetime import datetime, date, time, timedelta
78
from decimal import Decimal
89
import logging
@@ -109,6 +110,37 @@ def merge_annotations(A, *Bs):
109110
return merged
110111

111112

113+
def intersect_annotations(A, B):
114+
"""
115+
Identify common entries in dictionaries A and B
116+
and return these in a separate dictionary.
117+
118+
Entries have to share key as well as value to be
119+
considered common.
120+
121+
Parameters
122+
----------
123+
A, B : dict
124+
Dictionaries to merge.
125+
"""
126+
127+
result = {}
128+
129+
for key in set(A.keys()) & set(B.keys()):
130+
v1, v2 = A[key], B[key]
131+
assert type(v1) == type(v2), 'type({}) {} != type({}) {}'.format(v1, type(v1),
132+
v2, type(v2))
133+
if isinstance(v1, dict) and v1 == v2:
134+
result[key] = deepcopy(v1)
135+
elif isinstance(v1, str) and v1 == v2:
136+
result[key] = A[key]
137+
elif isinstance(v1, list) and v1 == v2:
138+
result[key] = deepcopy(v1)
139+
elif isinstance(v1, np.ndarray) and all(v1 == v2):
140+
result[key] = deepcopy(v1)
141+
return result
142+
143+
112144
def _reference_name(class_name):
113145
"""
114146
Given the name of a class, return an attribute name to be used for

neo/core/basesignal.py

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import numpy as np
2222
import quantities as pq
2323

24-
from neo.core.baseneo import BaseNeo, MergeError, merge_annotations
24+
from neo.core.baseneo import MergeError, merge_annotations
2525
from neo.core.dataobject import DataObject, ArrayDict
2626
from neo.core.channelindex import ChannelIndex
2727

@@ -282,11 +282,52 @@ def merge(self, other):
282282
# merge channel_index (move to ChannelIndex.merge()?)
283283
if self.channel_index and other.channel_index:
284284
signal.channel_index = ChannelIndex(index=np.arange(signal.shape[1]),
285-
channel_ids=np.hstack(
286-
[self.channel_index.channel_ids, other.channel_index.channel_ids]),
287-
channel_names=np.hstack(
288-
[self.channel_index.channel_names, other.channel_index.channel_names]))
285+
channel_ids=np.hstack(
286+
[self.channel_index.channel_ids,
287+
other.channel_index.channel_ids]),
288+
channel_names=np.hstack(
289+
[self.channel_index.channel_names,
290+
other.channel_index.channel_names]))
289291
else:
290292
signal.channel_index = ChannelIndex(index=np.arange(signal.shape[1]))
291293

292294
return signal
295+
296+
def time_slice(self, t_start, t_stop):
297+
'''
298+
Creates a new AnalogSignal corresponding to the time slice of the
299+
original Signal between times t_start, t_stop.
300+
'''
301+
NotImplementedError('Needs to be implemented for subclasses.')
302+
303+
def concatenate(self, *signals):
304+
'''
305+
Concatenate multiple signals across time.
306+
307+
The signal objects are concatenated vertically
308+
(row-wise, :func:`np.vstack`). Concatenation can be
309+
used to combine signals across segments.
310+
Note: Only (array) annotations common to
311+
both signals are attached to the concatenated signal.
312+
313+
If the attributes of the signals are not
314+
compatible, an Exception is raised.
315+
316+
Parameters
317+
----------
318+
signals : multiple neo.BaseSignal objects
319+
The objects that is concatenated with this one.
320+
321+
Returns
322+
-------
323+
signal : neo.BaseSignal
324+
Signal containing all non-overlapping samples of
325+
the source signals.
326+
327+
Raises
328+
------
329+
MergeError
330+
If `other` object has incompatible attributes.
331+
'''
332+
333+
NotImplementedError('Patching need to be implemented in sublcasses')

neo/core/irregularlysampledsignal.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
import numpy as np
3232
import quantities as pq
3333

34-
from neo.core.baseneo import BaseNeo, MergeError, merge_annotations
34+
from neo.core.baseneo import MergeError, merge_annotations, intersect_annotations
3535
from neo.core.basesignal import BaseSignal
3636
from neo.core.analogsignal import AnalogSignal
3737
from neo.core.channelindex import ChannelIndex
@@ -514,3 +514,94 @@ def merge(self, other):
514514
signal.channel_index = ChannelIndex(index=np.arange(signal.shape[1]))
515515

516516
return signal
517+
518+
def concatenate(self, other, allow_overlap=False):
519+
'''
520+
Combine this and another signal along the time axis.
521+
522+
The signal objects are concatenated vertically
523+
(row-wise, :func:`np.vstack`). Patching can be
524+
used to combine signals across segments.
525+
Note: Only array annotations common to
526+
both signals are attached to the concatenated signal.
527+
528+
If the attributes of the two signal are not
529+
compatible, an Exception is raised.
530+
531+
Required attributes of the signal are used.
532+
533+
Parameters
534+
----------
535+
other : neo.BaseSignal
536+
The object that is merged into this one.
537+
allow_overlap : bool
538+
If false, overlapping samples between the two
539+
signals are not permitted and an ValueError is raised.
540+
If true, no check for overlapping samples is
541+
performed and all samples are combined.
542+
543+
Returns
544+
-------
545+
signal : neo.IrregularlySampledSignal
546+
Signal containing all non-overlapping samples of
547+
both source signals.
548+
549+
Raises
550+
------
551+
MergeError
552+
If `other` object has incompatible attributes.
553+
'''
554+
555+
for attr in self._necessary_attrs:
556+
if not (attr[0] in ['signal', 'times', 't_start', 't_stop', 'times']):
557+
if getattr(self, attr[0], None) != getattr(other, attr[0], None):
558+
raise MergeError(
559+
"Cannot concatenate these two signals as the %s differ." % attr[0])
560+
561+
if hasattr(self, "lazy_shape"):
562+
if hasattr(other, "lazy_shape"):
563+
if self.lazy_shape[-1] != other.lazy_shape[-1]:
564+
raise MergeError("Cannot concatenate signals as they contain"
565+
" different numbers of traces.")
566+
merged_lazy_shape = (self.lazy_shape[0] + other.lazy_shape[0], self.lazy_shape[-1])
567+
else:
568+
raise MergeError("Cannot concatenate a lazy object with a real object.")
569+
if other.units != self.units:
570+
other = other.rescale(self.units)
571+
572+
new_times = np.hstack((self.times, other.times))
573+
sorting = np.argsort(new_times)
574+
new_samples = np.vstack((self.magnitude, other.magnitude))
575+
576+
kwargs = {}
577+
for name in ("name", "description", "file_origin"):
578+
attr_self = getattr(self, name)
579+
attr_other = getattr(other, name)
580+
if attr_self == attr_other:
581+
kwargs[name] = attr_self
582+
else:
583+
kwargs[name] = "merge({}, {})".format(attr_self, attr_other)
584+
merged_annotations = merge_annotations(self.annotations, other.annotations)
585+
kwargs.update(merged_annotations)
586+
587+
kwargs['array_annotations'] = intersect_annotations(self.array_annotations,
588+
other.array_annotations)
589+
590+
if not allow_overlap:
591+
if max(self.t_start, other.t_start) <= min(self.t_stop, other.t_stop):
592+
raise ValueError('Can not combine signals that overlap in time. Allow for '
593+
'overlapping samples using the "no_overlap" parameter.')
594+
595+
t_start = min(self.t_start, other.t_start)
596+
t_stop = max(self.t_start, other.t_start)
597+
598+
signal = IrregularlySampledSignal(signal=new_samples[sorting], times=new_times[sorting],
599+
units=self.units, dtype=self.dtype, copy=False,
600+
t_start=t_start, t_stop=t_stop, **kwargs)
601+
signal.segment = None
602+
signal.channel_index = None
603+
604+
if hasattr(self, "lazy_shape"):
605+
signal.lazy_shape = merged_lazy_shape
606+
607+
return signal

0 commit comments

Comments
 (0)