diff --git a/neo/core/baseneo.py b/neo/core/baseneo.py index 027b7f605..dc881fbed 100644 --- a/neo/core/baseneo.py +++ b/neo/core/baseneo.py @@ -90,7 +90,7 @@ def merge_annotation(a, b): return a -def merge_annotations(A, B): +def merge_annotations(A, *Bs): """ Merge two sets of annotations. @@ -102,21 +102,19 @@ def merge_annotations(A, B): For strings: concatenate with ';' Otherwise: warn if the annotations are not equal """ - merged = {} - for name in A: - if name in B: - try: - merged[name] = merge_annotation(A[name], B[name]) - except BaseException as exc: - # exc.args += ('key %s' % name,) - # raise - merged[name] = "MERGE CONFLICT" # temporary hack - else: - merged[name] = A[name] - for name in B: - if name not in merged: - merged[name] = B[name] - logger.debug("Merging annotations: A=%s B=%s merged=%s", A, B, merged) + merged = A.copy() + for B in Bs: + for name in B: + if name not in merged: + merged[name] = B[name] + else: + try: + merged[name] = merge_annotation(merged[name], B[name]) + except BaseException as exc: + # exc.args += ('key %s' % name,) + # raise + merged[name] = "MERGE CONFLICT" # temporary hack + logger.debug("Merging annotations: A=%s Bs=%s merged=%s", A, Bs, merged) return merged @@ -369,7 +367,7 @@ def _all_attrs(self): """ return self._necessary_attrs + self._recommended_attrs - def merge_annotations(self, other): + def merge_annotations(self, *others): """ Merge annotations from the other object into this one. @@ -381,17 +379,18 @@ def merge_annotations(self, other): For strings: concatenate with ';' Otherwise: fail if the annotations are not equal """ + other_annotations = [other.annotations for other in others] merged_annotations = merge_annotations(self.annotations, - other.annotations) + *other_annotations) self.annotations.update(merged_annotations) - def merge(self, other): + def merge(self, *others): """ Merge the contents of another object into this one. See :meth:`merge_annotations` for details of the merge operation. """ - self.merge_annotations(other) + self.merge_annotations(*others) def set_parent(self, obj): """ diff --git a/neo/core/spiketrain.py b/neo/core/spiketrain.py index 20a87594c..0e39f8bed 100644 --- a/neo/core/spiketrain.py +++ b/neo/core/spiketrain.py @@ -20,6 +20,8 @@ # needed for python 3 compatibility from __future__ import absolute_import, division, print_function + +import neo import sys from copy import deepcopy, copy @@ -635,77 +637,119 @@ def time_shift(self, t_shift): return new_st - def merge(self, other): + def merge(self, *others): ''' - Merge another :class:`SpikeTrain` into this one. + Merge other :class:`SpikeTrain` objects into this one. The times of the :class:`SpikeTrain` objects combined in one array and sorted. - If the attributes of the two :class:`SpikeTrain` are not + If the attributes of the :class:`SpikeTrain` objects are not compatible, an Exception is raised. ''' - if self.sampling_rate != other.sampling_rate: - raise MergeError("Cannot merge, different sampling rates") - if self.t_start != other.t_start: - raise MergeError("Cannot merge, different t_start") - if self.t_stop != other.t_stop: - raise MemoryError("Cannot merge, different t_stop") - if self.left_sweep != other.left_sweep: - raise MemoryError("Cannot merge, different left_sweep") - if self.segment != other.segment: - raise MergeError("Cannot merge these two signals as they belong to" - " different segments.") - if hasattr(self, "lazy_shape"): - if hasattr(other, "lazy_shape"): - merged_lazy_shape = (self.lazy_shape[0] + other.lazy_shape[0]) - else: - raise MergeError("Cannot merge a lazy object with a real" - " object.") - if other.units != self.units: - other = other.rescale(self.units) - wfs = [self.waveforms is not None, other.waveforms is not None] + for other in others: + if isinstance(other, neo.io.proxyobjects.SpikeTrainProxy): + raise MergeError("Cannot merge, SpikeTrainProxy objects cannot be merged" + "into regular SpikeTrain objects, please load them first.") + elif not isinstance(other, SpikeTrain): + raise MergeError("Cannot merge, only SpikeTrain" + "can be merged into a SpikeTrain.") + if self.sampling_rate != other.sampling_rate: + raise MergeError("Cannot merge, different sampling rates") + if self.t_start != other.t_start: + raise MergeError("Cannot merge, different t_start") + if self.t_stop != other.t_stop: + raise MergeError("Cannot merge, different t_stop") + if self.left_sweep != other.left_sweep: + raise MergeError("Cannot merge, different left_sweep") + if self.segment != other.segment: + raise MergeError("Cannot merge these signals as they belong to" + " different segments.") + + all_spiketrains = [self] + all_spiketrains.extend([st.rescale(self.units) for st in others]) + + wfs = [st.waveforms is not None for st in all_spiketrains] if any(wfs) and not all(wfs): raise MergeError("Cannot merge signal with waveform and signal " "without waveform.") - stack = np.concatenate((np.asarray(self), np.asarray(other))) + stack = np.concatenate([np.asarray(st) for st in all_spiketrains]) sorting = np.argsort(stack) stack = stack[sorting] + kwargs = {} - kwargs['array_annotations'] = self._merge_array_annotations(other, sorting=sorting) + kwargs['array_annotations'] = self._merge_array_annotations(others, sorting=sorting) for name in ("name", "description", "file_origin"): - attr_self = getattr(self, name) - attr_other = getattr(other, name) - if attr_self == attr_other: - kwargs[name] = attr_self - else: - kwargs[name] = "merge({}, {})".format(attr_self, attr_other) - merged_annotations = merge_annotations(self.annotations, other.annotations) + attr = getattr(self, name) + + # check if self is already a merged spiketrain + # if it is, get rid of the bracket at the end to append more attributes + if attr is not None: + if attr.startswith('merge(') and attr.endswith(')'): + attr = attr[:-1] + + for other in others: + attr_other = getattr(other, name) + + # both attributes are None --> nothing to do + if attr is None and attr_other is None: + continue + + # one of the attributes is None --> convert to string in order to merge them + elif attr is None or attr_other is None: + attr = str(attr) + attr_other = str(attr_other) + + # check if the other spiketrain is already a merged spiketrain + # if it is, append all of its merged attributes that aren't already in attr + if attr_other.startswith('merge(') and attr_other.endswith(')'): + for subattr in attr_other[6:-1].split('; '): + if subattr not in attr: + attr += '; ' + subattr + if not attr.startswith('merge('): + attr = 'merge(' + attr + + # if the other attribute is not in the list --> append + # if attr doesn't already start with merge add merge( in the beginning + elif attr_other not in attr: + attr += '; ' + attr_other + if not attr.startswith('merge('): + attr = 'merge(' + attr + + # close the bracket of merge(...) if necessary + if attr is not None: + if attr.startswith('merge('): + attr += ')' + + # write attr into kwargs dict + kwargs[name] = attr + + merged_annotations = merge_annotations(*(st.annotations for st in + all_spiketrains)) kwargs.update(merged_annotations) train = SpikeTrain(stack, units=self.units, dtype=self.dtype, copy=False, t_start=self.t_start, t_stop=self.t_stop, sampling_rate=self.sampling_rate, left_sweep=self.left_sweep, **kwargs) if all(wfs): - wfs_stack = np.vstack((self.waveforms, other.waveforms)) - wfs_stack = wfs_stack[sorting] + wfs_stack = np.vstack([st.waveforms.rescale(self.waveforms.units) + for st in all_spiketrains]) + wfs_stack = wfs_stack[sorting] * self.waveforms.units train.waveforms = wfs_stack train.segment = self.segment if train.segment is not None: self.segment.spiketrains.append(train) - if hasattr(self, "lazy_shape"): - train.lazy_shape = merged_lazy_shape return train - def _merge_array_annotations(self, other, sorting=None): + def _merge_array_annotations(self, others, sorting=None): ''' - Merges array annotations of 2 different objects. + Merges array annotations of multiple different objects. The merge happens in such a way that the result fits the merged data - In general this means concatenating the arrays from the 2 objects. - If an annotation is only present in one of the objects, it will be omitted. + In general this means concatenating the arrays from the objects. + If an annotation is not present in one of the objects, it will be omitted. Apart from that the array_annotations need to be sorted according to the sorting of the spikes. :return Merged array_annotations @@ -721,7 +765,8 @@ def _merge_array_annotations(self, other, sorting=None): for key in keys: try: self_ann = deepcopy(self.array_annotations[key]) - other_ann = deepcopy(other.array_annotations[key]) + other_ann = np.concatenate([deepcopy(other.array_annotations[key]) + for other in others]) if isinstance(self_ann, pq.Quantity): other_ann.rescale(self_ann.units) arr_ann = np.concatenate([self_ann, other_ann]) * self_ann.units @@ -734,13 +779,14 @@ def _merge_array_annotations(self, other, sorting=None): omitted_keys_self.append(key) continue - omitted_keys_other = [key for key in other.array_annotations if - key not in self.array_annotations] + omitted_keys_other = [key for key in np.unique([key for other in others + for key in other.array_annotations]) + if key not in self.array_annotations] if omitted_keys_self or omitted_keys_other: warnings.warn("The following array annotations were omitted, because they were only " "present in one of the merged objects: {} from the one that was merged " - "into and {} from the one that was merged into the other" + "into and {} from the ones that were merged into it." "".format(omitted_keys_self, omitted_keys_other), UserWarning) return merged_array_annotations diff --git a/neo/test/coretest/test_base.py b/neo/test/coretest/test_base.py index 968ad8d95..e409e44d3 100644 --- a/neo/test/coretest/test_base.py +++ b/neo/test/coretest/test_base.py @@ -144,10 +144,13 @@ class Test_BaseNeo_merge_annotations_merge(unittest.TestCase): def setUp(self): self.name1 = 'a base 1' self.name2 = 'a base 2' + self.name3 = 'a base 3' self.description1 = 'this is a test 1' self.description2 = 'this is a test 2' + self.description3 = 'this is a test 3' self.base1 = BaseNeo(name=self.name1, description=self.description1) self.base2 = BaseNeo(name=self.name2, description=self.description2) + self.base3 = BaseNeo(name=self.name3, description=self.description3) def test_merge_annotations__dict(self): self.base1.annotations = {'val0': 'val0', 'val1': 1, @@ -189,6 +192,57 @@ def test_merge_annotations__dict(self): self.assertEqual(self.description1, self.base1.description) self.assertEqual(self.description2, self.base2.description) + def test_merge_multiple_annotations__dict(self): + self.base1.annotations = {'val0': 'val0', 'val1': 1, + 'val2': 2.2, 'val3': 'test1', + 'val4': [.4], 'val5': {0: 0, 1: {0: 0}}, + 'val6': np.array([0, 1, 2])} + self.base2.annotations = {'val2': 2.2, 'val3': 'test2', + 'val4': [4, 4.4], 'val5': {1: {1: 1}, 2: 2}, + 'val6': np.array([4, 5, 6]), 'val7': True} + self.base3.annotations = {'val2': 2.2, 'val3': 'test3', + 'val4': [44], 'val5': {1: {2: 2}, 2: 2, 3: 3}, + 'val6': np.array([8, 9, 10]), 'val8': False} + + ann1 = self.base1.annotations + ann2 = self.base2.annotations + ann3 = self.base3.annotations + ann1c = self.base1.annotations.copy() + ann2c = self.base2.annotations.copy() + ann3c = self.base3.annotations.copy() + + targ = {'val0': 'val0', 'val1': 1, 'val2': 2.2, 'val3': 'test1;test2;test3', + 'val4': [.4, 4, 4.4, 44], 'val5': {0: 0, 1: {0: 0, 1: 1, 2: 2}, 2: 2, 3: 3}, + 'val7': True, 'val8': False} + + self.base1.merge_annotations(self.base2, self.base3) + + val6t = np.array([0, 1, 2, 4, 5, 6, 8, 9, 10]) + val61 = ann1.pop('val6') + val61c = ann1c.pop('val6') + val62 = ann2.pop('val6') + val62c = ann2c.pop('val6') + val63 = ann3.pop('val6') + val63c = ann3c.pop('val6') + + self.assertEqual(ann1, self.base1.annotations) + self.assertNotEqual(ann1c, self.base1.annotations) + self.assertEqual(ann2c, self.base2.annotations) + self.assertEqual(ann3c, self.base3.annotations) + self.assertEqual(targ, self.base1.annotations) + + assert_arrays_equal(val61, val6t) + self.assertRaises(AssertionError, assert_arrays_equal, val61c, val6t) + assert_arrays_equal(val62, val62c) + assert_arrays_equal(val63, val63c) + + self.assertEqual(self.name1, self.base1.name) + self.assertEqual(self.name2, self.base2.name) + self.assertEqual(self.name3, self.base3.name) + self.assertEqual(self.description1, self.base1.description) + self.assertEqual(self.description2, self.base2.description) + self.assertEqual(self.description3, self.base3.description) + def test_merge_annotations__func__dict(self): ann1 = {'val0': 'val0', 'val1': 1, 'val2': 2.2, 'val3': 'test1', 'val4': [.4], 'val5': {0: 0, 1: {0: 0}}, @@ -222,6 +276,47 @@ def test_merge_annotations__func__dict(self): assert_arrays_equal(val61, val61c) assert_arrays_equal(val62, val62c) + def test_merge_multiple_annotations__func__dict(self): + ann1 = {'val0': 'val0', 'val1': 1, 'val2': 2.2, 'val3': 'test1', + 'val4': [.4], 'val5': {0: 0, 1: {0: 0}}, + 'val6': np.array([0, 1, 2])} + ann2 = {'val2': 2.2, 'val3': 'test2', + 'val4': [4, 4.4], 'val5': {1: {1: 1}, 2: 2}, + 'val6': np.array([4, 5, 6]), 'val7': True} + ann3 = {'val2': 2.2, 'val3': 'test3', + 'val4': [44], 'val5': {1: {2: 2}, 2: 2, 3: 3}, + 'val6': np.array([8, 9, 10]), 'val8': False} + + ann1c = ann1.copy() + ann2c = ann2.copy() + ann3c = ann3.copy() + + targ = {'val0': 'val0', 'val1': 1, 'val2': 2.2, 'val3': 'test1;test2;test3', + 'val4': [.4, 4, 4.4, 44], 'val5': {0: 0, 1: {0: 0, 1: 1, 2: 2}, 2: 2, 3: 3}, + 'val7': True, 'val8': False} + + res = merge_annotations(ann1, ann2, ann3) + + val6t = np.array([0, 1, 2, 4, 5, 6, 8, 9, 10]) + val6r = res.pop('val6') + val61 = ann1.pop('val6') + val61c = ann1c.pop('val6') + val62 = ann2.pop('val6') + val62c = ann2c.pop('val6') + val63 = ann3.pop('val6') + val63c = ann3c.pop('val6') + + self.assertEqual(ann1, ann1c) + self.assertEqual(ann2, ann2c) + self.assertEqual(ann3, ann3c) + self.assertEqual(res, targ) + + assert_arrays_equal(val6r, val6t) + self.assertRaises(AssertionError, assert_arrays_equal, val61, val6t) + assert_arrays_equal(val61, val61c) + assert_arrays_equal(val62, val62c) + assert_arrays_equal(val63, val63c) + def test_merge_annotation__func__str(self): ann1 = 'test1' ann2 = 'test2' @@ -343,6 +438,37 @@ def test_merge__dict(self): self.assertEqual(self.description1, self.base1.description) self.assertEqual(self.description2, self.base2.description) + def test_merge_multiple__dict(self): + self.base1.annotations = {'val0': 'val0', 'val1': 1, + 'val2': 2.2, 'val3': 'test1'} + self.base2.annotations = {'val2': 2.2, 'val3': 'test2', + 'val4': [4, 4.4], 'val5': True} + self.base3.annotations = {'val2': 2.2, 'val3': 'test3', + 'val4': [44], 'val5': True, 'val6': False} + + ann1 = self.base1.annotations + ann1c = self.base1.annotations.copy() + ann2c = self.base2.annotations.copy() + ann3c = self.base3.annotations.copy() + + targ = {'val0': 'val0', 'val1': 1, 'val2': 2.2, 'val3': 'test1;test2;test3', + 'val4': [4, 4.4, 44], 'val5': True, 'val6': False} + + self.base1.merge(self.base2, self.base3) + + self.assertEqual(ann1, self.base1.annotations) + self.assertNotEqual(ann1c, self.base1.annotations) + self.assertEqual(ann2c, self.base2.annotations) + self.assertEqual(ann3c, self.base3.annotations) + self.assertEqual(targ, self.base1.annotations) + + self.assertEqual(self.name1, self.base1.name) + self.assertEqual(self.name2, self.base2.name) + self.assertEqual(self.name3, self.base3.name) + self.assertEqual(self.description1, self.base1.description) + self.assertEqual(self.description2, self.base2.description) + self.assertEqual(self.description3, self.base3.description) + def test_merge_annotations__different_type_AssertionError(self): self.base1.annotations = {'val1': 1, 'val2': 2.2, 'val3': 'tester'} self.base2.annotations = {'val3': False, 'val4': [4, 4.4], @@ -355,6 +481,22 @@ def test_merge_annotations__different_type_AssertionError(self): 'val4': [4, 4.4], 'val5': True}) + def test_merge_multiple_annotations__different_type_AssertionError(self): + self.base1.annotations = {'val1': 1, 'val2': 2.2, 'val3': 'tester'} + self.base2.annotations = {'val3': False, 'val4': [4, 4.4], + 'val5': True} + self.base3.annotations = {'val5': 1, 'val6': 79, + 'val7': True} + self.base1.merge_annotations(self.base2, self.base3) + self.assertEqual(self.base1.annotations, + {'val1': 1, + 'val2': 2.2, + 'val3': 'MERGE CONFLICT', + 'val4': [4, 4.4], + 'val5': 'MERGE CONFLICT', + 'val6': 79, + 'val7': True}) + def test_merge__different_type_AssertionError(self): self.base1.annotations = {'val1': 1, 'val2': 2.2, 'val3': 'tester'} self.base2.annotations = {'val3': False, 'val4': [4, 4.4], @@ -367,6 +509,22 @@ def test_merge__different_type_AssertionError(self): 'val4': [4, 4.4], 'val5': True}) + def test_merge_multiple__different_type_AssertionError(self): + self.base1.annotations = {'val1': 1, 'val2': 2.2, 'val3': 'tester'} + self.base2.annotations = {'val3': False, 'val4': [4, 4.4], + 'val5': True} + self.base3.annotations = {'val5': 3.1, 'val6': False, + 'val7': 'val7'} + self.base1.merge(self.base2, self.base3) + self.assertEqual(self.base1.annotations, + {'val1': 1, + 'val2': 2.2, + 'val3': 'MERGE CONFLICT', + 'val4': [4, 4.4], + 'val5': 'MERGE CONFLICT', + 'val6': False, + 'val7': 'val7'}) + def test_merge_annotations__unmergable_unequal_AssertionError(self): self.base1.annotations = {'val1': 1, 'val2': 2.2, 'val3': True} self.base2.annotations = {'val3': False, 'val4': [4, 4.4], @@ -379,6 +537,22 @@ def test_merge_annotations__unmergable_unequal_AssertionError(self): 'val4': [4, 4.4], 'val5': True}) + def test_merge_multiple_annotations__unmergable_unequal_AssertionError(self): + self.base1.annotations = {'val1': 1, 'val2': 2.2, 'val3': True} + self.base2.annotations = {'val3': False, 'val4': [4, 4.4], + 'val5': 3.5} + self.base3.annotations = {'val5': 3.4, 'val6': [4, 4.4], + 'val7': True} + self.base1.merge_annotations(self.base2, self.base3) + self.assertEqual(self.base1.annotations, + {'val1': 1, + 'val2': 2.2, + 'val3': 'MERGE CONFLICT', + 'val4': [4, 4.4], + 'val5': 'MERGE CONFLICT', + 'val6': [4, 4.4], + 'val7': True}) + def test_merge__unmergable_unequal_AssertionError(self): self.base1.annotations = {'val1': 1, 'val2': 2.2, 'val3': True} self.base2.annotations = {'val3': False, 'val4': [4, 4.4], @@ -391,6 +565,22 @@ def test_merge__unmergable_unequal_AssertionError(self): 'val4': [4, 4.4], 'val5': True}) + def test_merge_multiple__unmergable_unequal_AssertionError(self): + self.base1.annotations = {'val1': 1, 'val2': 2.2, 'val3': True} + self.base2.annotations = {'val3': False, 'val4': [4, 4.4], + 'val5': True} + self.base3.annotations = {'val5': 3.4, 'val6': [4, 4.4], + 'val7': True} + self.base1.merge(self.base2, self.base3) + self.assertEqual(self.base1.annotations, + {'val1': 1, + 'val2': 2.2, + 'val3': 'MERGE CONFLICT', + 'val4': [4, 4.4], + 'val5': 'MERGE CONFLICT', + 'val6': [4, 4.4], + 'val7': True}) + class TestBaseNeoCoreTypes(unittest.TestCase): ''' diff --git a/neo/test/coretest/test_spiketrain.py b/neo/test/coretest/test_spiketrain.py index 325f9bc3c..41a8a8f79 100644 --- a/neo/test/coretest/test_spiketrain.py +++ b/neo/test/coretest/test_spiketrain.py @@ -25,6 +25,9 @@ else: HAVE_IPYTHON = True +from neo.rawio.examplerawio import ExampleRawIO +from neo.io.proxyobjects import SpikeTrainProxy + from neo.core.spiketrain import (check_has_dimensions_time, SpikeTrain, _check_time_in_range, _new_spiketrain) from neo.core import Segment, Unit @@ -1291,8 +1294,8 @@ def test_merge_typical(self): "omitted, because they were only present" " in one of the merged objects: " "['label'] from the one that was merged " - "into and ['label2'] from the one that " - "was merged into the other") + "into and ['label2'] from the ones that " + "were merged into it.") assert_neo_object_is_compliant(result) @@ -1303,6 +1306,48 @@ def test_merge_typical(self): np.array([1, 101, 2, 102, 3, 103, 4, 104, 5, 105, 6, 106])) self.assertIsInstance(result.array_annotations, ArrayDict) + def test_merge_multiple(self): + self.train1.waveforms = None + + train3 = self.train1.duplicate_with_new_data(self.train1.times.magnitude * pq.microsecond) + train3.segment = self.train1.segment + train3.array_annotate(index=np.arange(301, 307)) + + train4 = self.train1.duplicate_with_new_data(self.train1.times / 2) + train4.segment = self.train1.segment + train4.array_annotate(index=np.arange(401, 407)) + + # Array annotations merge warning was already tested, can be ignored now + with warnings.catch_warnings(record=True) as w: + result = self.train1.merge(train3, train4) + self.assertEqual(len(w), 1) + self.assertTrue("array annotations" in str(w[0].message)) + + assert_neo_object_is_compliant(result) + + self.assertEqual(len(result.shape), 1) + self.assertEqual(result.shape[0], sum(len(st) + for st in (self.train1, train3, train4))) + + self.assertEqual(self.train1.sampling_rate, result.sampling_rate) + + time_unit = result.units + + expected = np.concatenate((self.train1.rescale(time_unit).times, + train3.rescale(time_unit).times, + train4.rescale(time_unit).times)) + expected *= time_unit + sorting = np.argsort(expected) + expected = expected[sorting] + np.testing.assert_array_equal(result.times, expected) + + # Make sure array annotations are merged correctly + self.assertTrue('label' not in result.array_annotations) + assert_arrays_equal(result.array_annotations['index'], + np.concatenate([st.array_annotations['index'] + for st in (self.train1, train3, train4)])[sorting]) + self.assertIsInstance(result.array_annotations, ArrayDict) + def test_merge_with_waveforms(self): # Array annotations merge warning was already tested, can be ignored now with warnings.catch_warnings(record=True) as w: @@ -1311,6 +1356,39 @@ def test_merge_with_waveforms(self): self.assertTrue("array annotations" in str(w[0].message)) assert_neo_object_is_compliant(result) + def test_merge_multiple_with_waveforms(self): + train3 = self.train1.duplicate_with_new_data(self.train1.times.magnitude * pq.microsecond) + train3.segment = self.train1.segment + train3.array_annotate(index=np.arange(301, 307)) + train3.waveforms = self.train1.waveforms / 10 + + train4 = self.train1.duplicate_with_new_data(self.train1.times / 2) + train4.segment = self.train1.segment + train4.array_annotate(index=np.arange(401, 407)) + train4.waveforms = self.train1.waveforms / 2 + + # Array annotations merge warning was already tested, can be ignored now + with warnings.catch_warnings(record=True) as w: + result = self.train1.merge(train3, train4) + self.assertEqual(len(w), 1) + self.assertTrue("array annotations" in str(w[0].message)) + + assert_neo_object_is_compliant(result) + self.assertEqual(len(result.shape), 1) + self.assertEqual(result.shape[0], sum(len(st) for st in (self.train1, train3, train4))) + + time_unit = result.units + + expected = np.concatenate((self.train1.rescale(time_unit).times, + train3.rescale(time_unit).times, + train4.rescale(time_unit).times)) + sorting = np.argsort(expected) + + assert_arrays_equal(result.waveforms, + np.vstack([st.waveforms.rescale(self.train1.waveforms.units) + for st in (self.train1, train3, train4)])[sorting] + * self.train1.waveforms.units) + def test_correct_shape(self): # Array annotations merge warning was already tested, can be ignored now with warnings.catch_warnings(record=True) as w: @@ -1358,6 +1436,63 @@ def test_rescaling_units(self): np.array([1, 2, 3, 4, 5, 6, 101, 102, 103, 104, 105, 106])) self.assertIsInstance(result.array_annotations, ArrayDict) + def test_name_file_origin_description(self): + self.train1.waveforms = None + self.train2.waveforms = None + self.train1.name = 'name1' + self.train1.description = 'desc1' + self.train1.file_origin = 'file1' + self.train2.name = 'name2' + self.train2.description = 'desc2' + self.train2.file_origin = 'file2' + + train3 = self.train1.duplicate_with_new_data(self.train1.times.magnitude * pq.microsecond) + train3.segment = self.train1.segment + train3.name = 'name3' + train3.description = 'desc3' + train3.file_origin = 'file3' + + train4 = self.train1.duplicate_with_new_data(self.train1.times / 2) + train4.segment = self.train1.segment + train4.name = 'name3' + train4.description = 'desc3' + train4.file_origin = 'file3' + + # merge two spiketrains with different attributes + merge1 = self.train1.merge(self.train2) + + self.assertEqual(merge1.name, 'merge(name1; name2)') + self.assertEqual(merge1.description, 'merge(desc1; desc2)') + self.assertEqual(merge1.file_origin, 'merge(file1; file2)') + + # merge a merged spiketrain with a regular one + merge2 = merge1.merge(train3) + + self.assertEqual(merge2.name, 'merge(name1; name2; name3)') + self.assertEqual(merge2.description, 'merge(desc1; desc2; desc3)') + self.assertEqual(merge2.file_origin, 'merge(file1; file2; file3)') + + # merge two merged spiketrains + merge3 = merge1.merge(merge2) + + self.assertEqual(merge3.name, 'merge(name1; name2; name3)') + self.assertEqual(merge3.description, 'merge(desc1; desc2; desc3)') + self.assertEqual(merge3.file_origin, 'merge(file1; file2; file3)') + + # merge two spiketrains with identical attributes + merge4 = train3.merge(train4) + + self.assertEqual(merge4.name, 'name3') + self.assertEqual(merge4.description, 'desc3') + self.assertEqual(merge4.file_origin, 'file3') + + # merge a reqular spiketrain with a merged spiketrain + merge5 = train3.merge(merge1) + + self.assertEqual(merge5.name, 'merge(name3; name1; name2)') + self.assertEqual(merge5.description, 'merge(desc3; desc1; desc2)') + self.assertEqual(merge5.file_origin, 'merge(file3; file1; file2)') + def test_sampling_rate(self): # Array annotations merge warning was already tested, can be ignored now with warnings.catch_warnings(record=True) as w: @@ -1390,6 +1525,56 @@ def test_incompatible_t_start(self): with self.assertRaises(MergeError): self.train2.merge(train3) + def test_merge_multiple_raise_merge_errors(self): + # different t_start + train3 = self.train1.duplicate_with_new_data(self.train1, t_start=-1 * pq.s) + train3.segment = self.train1.segment + with self.assertRaises(MergeError): + train3.merge(self.train2, self.train1) + with self.assertRaises(MergeError): + self.train2.merge(train3, self.train1) + + # different t_stop + train3 = self.train1.duplicate_with_new_data(self.train1, t_stop=133 * pq.s) + train3.segment = self.train1.segment + with self.assertRaises(MergeError): + train3.merge(self.train2, self.train1) + with self.assertRaises(MergeError): + self.train2.merge(train3, self.train1) + + # different segment + train3 = self.train1.duplicate_with_new_data(self.train1) + seg = Segment() + train3.segment = seg + with self.assertRaises(MergeError): + train3.merge(self.train2, self.train1) + with self.assertRaises(MergeError): + self.train2.merge(train3, self.train1) + + # missing waveforms + train3 = self.train1.duplicate_with_new_data(self.train1) + train3.waveforms = None + with self.assertRaises(MergeError): + train3.merge(self.train2, self.train1) + with self.assertRaises(MergeError): + self.train2.merge(train3, self.train1) + + # different sampling rate + train3 = self.train1.duplicate_with_new_data(self.train1) + train3.sampling_rate = 1 * pq.s + with self.assertRaises(MergeError): + train3.merge(self.train2, self.train1) + with self.assertRaises(MergeError): + self.train2.merge(train3, self.train1) + + # different left sweep + train3 = self.train1.duplicate_with_new_data(self.train1) + train3.left_sweep = 1 * pq.s + with self.assertRaises(MergeError): + train3.merge(self.train2, self.train1) + with self.assertRaises(MergeError): + self.train2.merge(train3, self.train1) + class TestDuplicateWithNewData(unittest.TestCase): def setUp(self):