diff --git a/neo/test/utils/test_misc.py b/neo/test/utils/test_misc.py index 7291307f2..f547cd159 100644 --- a/neo/test/utils/test_misc.py +++ b/neo/test/utils/test_misc.py @@ -3,6 +3,7 @@ """ import unittest +import warnings import numpy as np import quantities as pq @@ -281,11 +282,11 @@ def test__get_epochs(self): def test__add_epoch(self): starts = Event(times=[0.5, 10.0, 25.2] * pq.s) - starts.annotate(event_type='trial start') + starts.annotate(event_type='trial start', nix_name='neo.event.0') starts.array_annotate(trial_id=[1, 2, 3]) stops = Event(times=[5.5, 14.9, 30.1] * pq.s) - stops.annotate(event_type='trial stop') + stops.annotate(event_type='trial stop', nix_name='neo.event.1') stops.array_annotate(trial_id=[1, 2, 3]) seg = Segment() @@ -295,7 +296,7 @@ def test__add_epoch(self): ep_starts = add_epoch(seg, starts, pre=-300 * pq.ms, post=250 * pq.ms) assert_neo_object_is_compliant(ep_starts) - assert_same_annotations(ep_starts, starts) + self.assertDictEqual(ep_starts.annotations, {'event_type': 'trial start'}) assert_arrays_almost_equal(ep_starts.times, starts.times - 300 * pq.ms, 1e-12) assert_arrays_almost_equal(ep_starts.durations, (550 * pq.ms).rescale(ep_starts.durations.units) @@ -305,7 +306,7 @@ def test__add_epoch(self): ep_trials = add_epoch(seg, starts, stops) assert_neo_object_is_compliant(ep_trials) - assert_same_annotations(ep_trials, starts) + self.assertDictEqual(ep_trials.annotations, {'event_type': 'trial start'}) assert_arrays_almost_equal(ep_trials.times, starts.times, 1e-12) assert_arrays_almost_equal(ep_trials.durations, stops - starts, 1e-12) @@ -337,16 +338,16 @@ def test__match_events(self): def test__cut_block_by_epochs(self): epoch = Epoch([0.5, 10.0, 25.2] * pq.s, durations=[5.1, 4.8, 5.0] * pq.s, t_start=.1 * pq.s) - epoch.annotate(epoch_type='a', pick='me') + epoch.annotate(epoch_type='a', pick='me', nix_name='neo.epoch.0') epoch.array_annotate(trial_id=[1, 2, 3]) epoch2 = Epoch([0.6, 9.5, 16.8, 34.1] * pq.s, durations=[4.5, 4.8, 5.0, 5.0] * pq.s, t_start=.1 * pq.s) - epoch2.annotate(epoch_type='b') + epoch2.annotate(epoch_type='b', nix_name='neo.epoch.1') epoch2.array_annotate(trial_id=[1, 2, 3, 4]) event = Event(times=[0.5, 10.0, 25.2] * pq.s, t_start=.1 * pq.s) - event.annotate(event_type='trial start') + event.annotate(event_type='trial start', nix_name='neo.event.0') event.array_annotate(trial_id=[1, 2, 3]) anasig = AnalogSignal(np.arange(50.0) * pq.mV, t_start=.1 * pq.s, @@ -362,8 +363,8 @@ def test__cut_block_by_epochs(self): array_annotations={'spikenum': np.arange(1, 9)}) # test without resetting the time - seg = Segment() - seg2 = Segment(name='NoCut') + seg = Segment(nix_name='neo.segment.0') + seg2 = Segment(name='NoCut', nix_name='neo.segment.1') seg.epochs = [epoch, epoch2] seg.events = [event] seg.analogsignals = [anasig] @@ -374,7 +375,10 @@ def test__cut_block_by_epochs(self): original_block.segments = [seg, seg2] original_block.create_many_to_one_relationship() - block = cut_block_by_epochs(original_block, properties={'pick': 'me'}) + with warnings.catch_warnings(record=True) as w: + # This should raise a warning as one segment does not contain epochs + block = cut_block_by_epochs(original_block, properties={'pick': 'me'}) + self.assertEqual(len(w), 1) assert_neo_object_is_compliant(block) self.assertEqual(len(block.segments), 3) @@ -385,6 +389,10 @@ def test__cut_block_by_epochs(self): self.assertEqual(len(block.segments[epoch_idx].analogsignals), 1) self.assertEqual(len(block.segments[epoch_idx].irregularlysampledsignals), 1) + annos = block.segments[epoch_idx].annotations + # new segment objects have different identity + self.assertNotIn('nix_name', annos) + if epoch_idx != 0: self.assertEqual(len(block.segments[epoch_idx].epochs), 1) else: @@ -414,8 +422,8 @@ def test__cut_block_by_epochs(self): t_stop=epoch.times[0] + epoch.durations[0])) # test with resetting the time - seg = Segment() - seg2 = Segment(name='NoCut') + seg = Segment(nix_name='neo.segment.0') + seg2 = Segment(name='NoCut', nix_name='neo.segment.1') seg.epochs = [epoch, epoch2] seg.events = [event] seg.analogsignals = [anasig] @@ -426,7 +434,10 @@ def test__cut_block_by_epochs(self): original_block.segments = [seg, seg2] original_block.create_many_to_one_relationship() - block = cut_block_by_epochs(original_block, properties={'pick': 'me'}, reset_time=True) + with warnings.catch_warnings(record=True) as w: + # This should raise a warning as one segment does not contain epochs + block = cut_block_by_epochs(original_block, properties={'pick': 'me'}, reset_time=True) + self.assertEqual(len(w), 1) assert_neo_object_is_compliant(block) self.assertEqual(len(block.segments), 3) @@ -436,6 +447,10 @@ def test__cut_block_by_epochs(self): self.assertEqual(len(block.segments[epoch_idx].spiketrains), 1) self.assertEqual(len(block.segments[epoch_idx].analogsignals), 1) self.assertEqual(len(block.segments[epoch_idx].irregularlysampledsignals), 1) + + annos = block.segments[epoch_idx].annotations + self.assertNotIn('nix_name', annos) + if epoch_idx != 0: self.assertEqual(len(block.segments[epoch_idx].epochs), 1) else: @@ -527,6 +542,9 @@ def test__add_epoch(self): regular_event = Event(times=loaded_event.times - 1 * loaded_event.units) + loaded_event.annotate(nix_name='neo.event.0') + regular_event.annotate(nix_name='neo.event.1') + seg = Segment() seg.events = [regular_event, proxy_event] @@ -534,7 +552,8 @@ def test__add_epoch(self): epoch = add_epoch(seg, regular_event, proxy_event) assert_neo_object_is_compliant(epoch) - assert_same_annotations(epoch, regular_event) + exp_annos = {k: v for k, v in regular_event.annotations.items() if k != 'nix_name'} + self.assertDictEqual(epoch.annotations, exp_annos) assert_arrays_almost_equal(epoch.times, regular_event.times, 1e-12) assert_arrays_almost_equal(epoch.durations, np.ones(regular_event.shape) * loaded_event.units, 1e-12) diff --git a/neo/utils/misc.py b/neo/utils/misc.py index 3cd6f9003..00980eb04 100644 --- a/neo/utils/misc.py +++ b/neo/utils/misc.py @@ -3,12 +3,15 @@ etc. of neo.core objects. ''' -import neo import copy import warnings + import numpy as np import quantities as pq +import neo + +reserved_annotations = ['nix_name'] def get_events(container, **properties): """ @@ -345,8 +348,8 @@ def add_epoch( ep = neo.Epoch(times=times, durations=durations, **kwargs) - ep.annotate(**event1.annotations) - ep.array_annotate(**event1.array_annotations) + ep.annotate(**clean_annotations(event1.annotations)) + ep.array_annotate(**clean_annotations(event1.array_annotations)) if attach_result: segment.epochs.append(ep) @@ -543,10 +546,11 @@ def cut_segment_by_epoch(seg, epoch, reset_time=False): epoch.times[ep_id] + epoch.durations[ep_id], reset_time=reset_time) - subseg.annotate(**copy.copy(epoch.annotations)) + subseg.annotations = clean_annotations(subseg.annotations) + subseg.annotate(**clean_annotations(epoch.annotations)) # Add array-annotations of Epoch - for key, val in epoch.array_annotations.items(): + for key, val in clean_annotations(epoch.array_annotations).items(): if len(val): subseg.annotations[key] = copy.copy(val[ep_id]) @@ -555,6 +559,23 @@ def cut_segment_by_epoch(seg, epoch, reset_time=False): return segments +def clean_annotations(dictionary): + """ + Remove reserved keys from an annotation dictionary. + + Parameters + ---------- + dictionary: dict + annotation dictionary to be cleaned + + Returns: + -------- + dict + A cleaned version of the annotations + """ + return {k: v for k, v in dictionary.items() if k not in reserved_annotations} + + def is_block_rawio_compatible(block, return_problems=False): """ The neo.rawio layer have some restriction compared to neo.io layer: