Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 19 additions & 20 deletions neo/core/baseneo.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def merge_annotation(a, b):
return a


def merge_annotations(A, B):
def merge_annotations(A, *Bs):
"""
Merge two sets of annotations.

Expand All @@ -102,21 +102,19 @@ def merge_annotations(A, B):
For strings: concatenate with ';'
Otherwise: warn if the annotations are not equal
"""
merged = {}
for name in A:
if name in B:
try:
merged[name] = merge_annotation(A[name], B[name])
except BaseException as exc:
# exc.args += ('key %s' % name,)
# raise
merged[name] = "MERGE CONFLICT" # temporary hack
else:
merged[name] = A[name]
for name in B:
if name not in merged:
merged[name] = B[name]
logger.debug("Merging annotations: A=%s B=%s merged=%s", A, B, merged)
merged = A.copy()
for B in Bs:
for name in B:
if name not in merged:
merged[name] = B[name]
else:
try:
merged[name] = merge_annotation(merged[name], B[name])
except BaseException as exc:
# exc.args += ('key %s' % name,)
# raise
merged[name] = "MERGE CONFLICT" # temporary hack
logger.debug("Merging annotations: A=%s Bs=%s merged=%s", A, Bs, merged)
return merged


Expand Down Expand Up @@ -369,7 +367,7 @@ def _all_attrs(self):
"""
return self._necessary_attrs + self._recommended_attrs

def merge_annotations(self, other):
def merge_annotations(self, *others):
"""
Merge annotations from the other object into this one.

Expand All @@ -381,17 +379,18 @@ def merge_annotations(self, other):
For strings: concatenate with ';'
Otherwise: fail if the annotations are not equal
"""
other_annotations = [other.annotations for other in others]
merged_annotations = merge_annotations(self.annotations,
other.annotations)
*other_annotations)
self.annotations.update(merged_annotations)

def merge(self, other):
def merge(self, *others):
"""
Merge the contents of another object into this one.

See :meth:`merge_annotations` for details of the merge operation.
"""
self.merge_annotations(other)
self.merge_annotations(*others)

def set_parent(self, obj):
"""
Expand Down
134 changes: 90 additions & 44 deletions neo/core/spiketrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

# needed for python 3 compatibility
from __future__ import absolute_import, division, print_function

import neo
import sys

from copy import deepcopy, copy
Expand Down Expand Up @@ -635,77 +637,119 @@ def time_shift(self, t_shift):

return new_st

def merge(self, other):
def merge(self, *others):
'''
Merge another :class:`SpikeTrain` into this one.
Merge other :class:`SpikeTrain` objects into this one.

The times of the :class:`SpikeTrain` objects combined in one array
and sorted.

If the attributes of the two :class:`SpikeTrain` are not
If the attributes of the :class:`SpikeTrain` objects are not
compatible, an Exception is raised.
'''
if self.sampling_rate != other.sampling_rate:
raise MergeError("Cannot merge, different sampling rates")
if self.t_start != other.t_start:
raise MergeError("Cannot merge, different t_start")
if self.t_stop != other.t_stop:
raise MemoryError("Cannot merge, different t_stop")
if self.left_sweep != other.left_sweep:
raise MemoryError("Cannot merge, different left_sweep")
if self.segment != other.segment:
raise MergeError("Cannot merge these two signals as they belong to"
" different segments.")
if hasattr(self, "lazy_shape"):
if hasattr(other, "lazy_shape"):
merged_lazy_shape = (self.lazy_shape[0] + other.lazy_shape[0])
else:
raise MergeError("Cannot merge a lazy object with a real"
" object.")
if other.units != self.units:
other = other.rescale(self.units)
wfs = [self.waveforms is not None, other.waveforms is not None]
for other in others:
if isinstance(other, neo.io.proxyobjects.SpikeTrainProxy):
raise MergeError("Cannot merge, SpikeTrainProxy objects cannot be merged"
"into regular SpikeTrain objects, please load them first.")
elif not isinstance(other, SpikeTrain):
raise MergeError("Cannot merge, only SpikeTrain"
"can be merged into a SpikeTrain.")
if self.sampling_rate != other.sampling_rate:
raise MergeError("Cannot merge, different sampling rates")
if self.t_start != other.t_start:
raise MergeError("Cannot merge, different t_start")
if self.t_stop != other.t_stop:
raise MergeError("Cannot merge, different t_stop")
if self.left_sweep != other.left_sweep:
raise MergeError("Cannot merge, different left_sweep")
if self.segment != other.segment:
raise MergeError("Cannot merge these signals as they belong to"
" different segments.")

all_spiketrains = [self]
all_spiketrains.extend([st.rescale(self.units) for st in others])

wfs = [st.waveforms is not None for st in all_spiketrains]
if any(wfs) and not all(wfs):
raise MergeError("Cannot merge signal with waveform and signal "
"without waveform.")
stack = np.concatenate((np.asarray(self), np.asarray(other)))
stack = np.concatenate([np.asarray(st) for st in all_spiketrains])
sorting = np.argsort(stack)
stack = stack[sorting]

kwargs = {}

kwargs['array_annotations'] = self._merge_array_annotations(other, sorting=sorting)
kwargs['array_annotations'] = self._merge_array_annotations(others, sorting=sorting)

for name in ("name", "description", "file_origin"):
attr_self = getattr(self, name)
attr_other = getattr(other, name)
if attr_self == attr_other:
kwargs[name] = attr_self
else:
kwargs[name] = "merge({}, {})".format(attr_self, attr_other)
merged_annotations = merge_annotations(self.annotations, other.annotations)
attr = getattr(self, name)

# check if self is already a merged spiketrain
# if it is, get rid of the bracket at the end to append more attributes
if attr is not None:
if attr.startswith('merge(') and attr.endswith(')'):
attr = attr[:-1]

for other in others:
attr_other = getattr(other, name)

# both attributes are None --> nothing to do
if attr is None and attr_other is None:
continue

# one of the attributes is None --> convert to string in order to merge them
elif attr is None or attr_other is None:
attr = str(attr)
attr_other = str(attr_other)

# check if the other spiketrain is already a merged spiketrain
# if it is, append all of its merged attributes that aren't already in attr
if attr_other.startswith('merge(') and attr_other.endswith(')'):
for subattr in attr_other[6:-1].split('; '):
if subattr not in attr:
attr += '; ' + subattr
if not attr.startswith('merge('):
attr = 'merge(' + attr

# if the other attribute is not in the list --> append
# if attr doesn't already start with merge add merge( in the beginning
elif attr_other not in attr:
attr += '; ' + attr_other
if not attr.startswith('merge('):
attr = 'merge(' + attr

# close the bracket of merge(...) if necessary
if attr is not None:
if attr.startswith('merge('):
attr += ')'

# write attr into kwargs dict
kwargs[name] = attr

merged_annotations = merge_annotations(*(st.annotations for st in
all_spiketrains))
kwargs.update(merged_annotations)

train = SpikeTrain(stack, units=self.units, dtype=self.dtype, copy=False,
t_start=self.t_start, t_stop=self.t_stop,
sampling_rate=self.sampling_rate, left_sweep=self.left_sweep, **kwargs)
if all(wfs):
wfs_stack = np.vstack((self.waveforms, other.waveforms))
wfs_stack = wfs_stack[sorting]
wfs_stack = np.vstack([st.waveforms.rescale(self.waveforms.units)
for st in all_spiketrains])
wfs_stack = wfs_stack[sorting] * self.waveforms.units
train.waveforms = wfs_stack
train.segment = self.segment
if train.segment is not None:
self.segment.spiketrains.append(train)

if hasattr(self, "lazy_shape"):
train.lazy_shape = merged_lazy_shape
return train

def _merge_array_annotations(self, other, sorting=None):
def _merge_array_annotations(self, others, sorting=None):
'''
Merges array annotations of 2 different objects.
Merges array annotations of multiple different objects.
The merge happens in such a way that the result fits the merged data
In general this means concatenating the arrays from the 2 objects.
If an annotation is only present in one of the objects, it will be omitted.
In general this means concatenating the arrays from the objects.
If an annotation is not present in one of the objects, it will be omitted.
Apart from that the array_annotations need to be sorted according to the sorting of
the spikes.
:return Merged array_annotations
Expand All @@ -721,7 +765,8 @@ def _merge_array_annotations(self, other, sorting=None):
for key in keys:
try:
self_ann = deepcopy(self.array_annotations[key])
other_ann = deepcopy(other.array_annotations[key])
other_ann = np.concatenate([deepcopy(other.array_annotations[key])
for other in others])
if isinstance(self_ann, pq.Quantity):
other_ann.rescale(self_ann.units)
arr_ann = np.concatenate([self_ann, other_ann]) * self_ann.units
Expand All @@ -734,13 +779,14 @@ def _merge_array_annotations(self, other, sorting=None):
omitted_keys_self.append(key)
continue

omitted_keys_other = [key for key in other.array_annotations if
key not in self.array_annotations]
omitted_keys_other = [key for key in np.unique([key for other in others
for key in other.array_annotations])
if key not in self.array_annotations]

if omitted_keys_self or omitted_keys_other:
warnings.warn("The following array annotations were omitted, because they were only "
"present in one of the merged objects: {} from the one that was merged "
"into and {} from the one that was merged into the other"
"into and {} from the ones that were merged into it."
"".format(omitted_keys_self, omitted_keys_other), UserWarning)

return merged_array_annotations
Expand Down
Loading