Skip to content

Commit 9d64537

Browse files
authored
Merge pull request #1016 from JuliaSprenger/fix/anno_copies
[utils] do not copy nix_name annotation to different object types
2 parents cb45900 + d9ba05d commit 9d64537

File tree

2 files changed

+59
-19
lines changed

2 files changed

+59
-19
lines changed

neo/test/utils/test_misc.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
import unittest
6+
import warnings
67

78
import numpy as np
89
import quantities as pq
@@ -281,11 +282,11 @@ def test__get_epochs(self):
281282

282283
def test__add_epoch(self):
283284
starts = Event(times=[0.5, 10.0, 25.2] * pq.s)
284-
starts.annotate(event_type='trial start')
285+
starts.annotate(event_type='trial start', nix_name='neo.event.0')
285286
starts.array_annotate(trial_id=[1, 2, 3])
286287

287288
stops = Event(times=[5.5, 14.9, 30.1] * pq.s)
288-
stops.annotate(event_type='trial stop')
289+
stops.annotate(event_type='trial stop', nix_name='neo.event.1')
289290
stops.array_annotate(trial_id=[1, 2, 3])
290291

291292
seg = Segment()
@@ -295,7 +296,7 @@ def test__add_epoch(self):
295296
ep_starts = add_epoch(seg, starts, pre=-300 * pq.ms, post=250 * pq.ms)
296297

297298
assert_neo_object_is_compliant(ep_starts)
298-
assert_same_annotations(ep_starts, starts)
299+
self.assertDictEqual(ep_starts.annotations, {'event_type': 'trial start'})
299300
assert_arrays_almost_equal(ep_starts.times, starts.times - 300 * pq.ms, 1e-12)
300301
assert_arrays_almost_equal(ep_starts.durations,
301302
(550 * pq.ms).rescale(ep_starts.durations.units)
@@ -305,7 +306,7 @@ def test__add_epoch(self):
305306
ep_trials = add_epoch(seg, starts, stops)
306307

307308
assert_neo_object_is_compliant(ep_trials)
308-
assert_same_annotations(ep_trials, starts)
309+
self.assertDictEqual(ep_trials.annotations, {'event_type': 'trial start'})
309310
assert_arrays_almost_equal(ep_trials.times, starts.times, 1e-12)
310311
assert_arrays_almost_equal(ep_trials.durations, stops - starts, 1e-12)
311312

@@ -337,16 +338,16 @@ def test__match_events(self):
337338
def test__cut_block_by_epochs(self):
338339
epoch = Epoch([0.5, 10.0, 25.2] * pq.s, durations=[5.1, 4.8, 5.0] * pq.s,
339340
t_start=.1 * pq.s)
340-
epoch.annotate(epoch_type='a', pick='me')
341+
epoch.annotate(epoch_type='a', pick='me', nix_name='neo.epoch.0')
341342
epoch.array_annotate(trial_id=[1, 2, 3])
342343

343344
epoch2 = Epoch([0.6, 9.5, 16.8, 34.1] * pq.s, durations=[4.5, 4.8, 5.0, 5.0] * pq.s,
344345
t_start=.1 * pq.s)
345-
epoch2.annotate(epoch_type='b')
346+
epoch2.annotate(epoch_type='b', nix_name='neo.epoch.1')
346347
epoch2.array_annotate(trial_id=[1, 2, 3, 4])
347348

348349
event = Event(times=[0.5, 10.0, 25.2] * pq.s, t_start=.1 * pq.s)
349-
event.annotate(event_type='trial start')
350+
event.annotate(event_type='trial start', nix_name='neo.event.0')
350351
event.array_annotate(trial_id=[1, 2, 3])
351352

352353
anasig = AnalogSignal(np.arange(50.0) * pq.mV, t_start=.1 * pq.s,
@@ -362,8 +363,8 @@ def test__cut_block_by_epochs(self):
362363
array_annotations={'spikenum': np.arange(1, 9)})
363364

364365
# test without resetting the time
365-
seg = Segment()
366-
seg2 = Segment(name='NoCut')
366+
seg = Segment(nix_name='neo.segment.0')
367+
seg2 = Segment(name='NoCut', nix_name='neo.segment.1')
367368
seg.epochs = [epoch, epoch2]
368369
seg.events = [event]
369370
seg.analogsignals = [anasig]
@@ -374,7 +375,10 @@ def test__cut_block_by_epochs(self):
374375
original_block.segments = [seg, seg2]
375376
original_block.create_many_to_one_relationship()
376377

377-
block = cut_block_by_epochs(original_block, properties={'pick': 'me'})
378+
with warnings.catch_warnings(record=True) as w:
379+
# This should raise a warning as one segment does not contain epochs
380+
block = cut_block_by_epochs(original_block, properties={'pick': 'me'})
381+
self.assertEqual(len(w), 1)
378382

379383
assert_neo_object_is_compliant(block)
380384
self.assertEqual(len(block.segments), 3)
@@ -385,6 +389,10 @@ def test__cut_block_by_epochs(self):
385389
self.assertEqual(len(block.segments[epoch_idx].analogsignals), 1)
386390
self.assertEqual(len(block.segments[epoch_idx].irregularlysampledsignals), 1)
387391

392+
annos = block.segments[epoch_idx].annotations
393+
# new segment objects have different identity
394+
self.assertNotIn('nix_name', annos)
395+
388396
if epoch_idx != 0:
389397
self.assertEqual(len(block.segments[epoch_idx].epochs), 1)
390398
else:
@@ -414,8 +422,8 @@ def test__cut_block_by_epochs(self):
414422
t_stop=epoch.times[0] + epoch.durations[0]))
415423

416424
# test with resetting the time
417-
seg = Segment()
418-
seg2 = Segment(name='NoCut')
425+
seg = Segment(nix_name='neo.segment.0')
426+
seg2 = Segment(name='NoCut', nix_name='neo.segment.1')
419427
seg.epochs = [epoch, epoch2]
420428
seg.events = [event]
421429
seg.analogsignals = [anasig]
@@ -426,7 +434,10 @@ def test__cut_block_by_epochs(self):
426434
original_block.segments = [seg, seg2]
427435
original_block.create_many_to_one_relationship()
428436

429-
block = cut_block_by_epochs(original_block, properties={'pick': 'me'}, reset_time=True)
437+
with warnings.catch_warnings(record=True) as w:
438+
# This should raise a warning as one segment does not contain epochs
439+
block = cut_block_by_epochs(original_block, properties={'pick': 'me'}, reset_time=True)
440+
self.assertEqual(len(w), 1)
430441

431442
assert_neo_object_is_compliant(block)
432443
self.assertEqual(len(block.segments), 3)
@@ -436,6 +447,10 @@ def test__cut_block_by_epochs(self):
436447
self.assertEqual(len(block.segments[epoch_idx].spiketrains), 1)
437448
self.assertEqual(len(block.segments[epoch_idx].analogsignals), 1)
438449
self.assertEqual(len(block.segments[epoch_idx].irregularlysampledsignals), 1)
450+
451+
annos = block.segments[epoch_idx].annotations
452+
self.assertNotIn('nix_name', annos)
453+
439454
if epoch_idx != 0:
440455
self.assertEqual(len(block.segments[epoch_idx].epochs), 1)
441456
else:
@@ -527,14 +542,18 @@ def test__add_epoch(self):
527542

528543
regular_event = Event(times=loaded_event.times - 1 * loaded_event.units)
529544

545+
loaded_event.annotate(nix_name='neo.event.0')
546+
regular_event.annotate(nix_name='neo.event.1')
547+
530548
seg = Segment()
531549
seg.events = [regular_event, proxy_event]
532550

533551
# test cutting with two events one of which is a proxy
534552
epoch = add_epoch(seg, regular_event, proxy_event)
535553

536554
assert_neo_object_is_compliant(epoch)
537-
assert_same_annotations(epoch, regular_event)
555+
exp_annos = {k: v for k, v in regular_event.annotations.items() if k != 'nix_name'}
556+
self.assertDictEqual(epoch.annotations, exp_annos)
538557
assert_arrays_almost_equal(epoch.times, regular_event.times, 1e-12)
539558
assert_arrays_almost_equal(epoch.durations,
540559
np.ones(regular_event.shape) * loaded_event.units, 1e-12)

neo/utils/misc.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@
33
etc. of neo.core objects.
44
'''
55

6-
import neo
76
import copy
87
import warnings
8+
99
import numpy as np
1010
import quantities as pq
1111

12+
import neo
13+
14+
reserved_annotations = ['nix_name']
1215

1316
def get_events(container, **properties):
1417
"""
@@ -345,8 +348,8 @@ def add_epoch(
345348

346349
ep = neo.Epoch(times=times, durations=durations, **kwargs)
347350

348-
ep.annotate(**event1.annotations)
349-
ep.array_annotate(**event1.array_annotations)
351+
ep.annotate(**clean_annotations(event1.annotations))
352+
ep.array_annotate(**clean_annotations(event1.array_annotations))
350353

351354
if attach_result:
352355
segment.epochs.append(ep)
@@ -543,10 +546,11 @@ def cut_segment_by_epoch(seg, epoch, reset_time=False):
543546
epoch.times[ep_id] + epoch.durations[ep_id],
544547
reset_time=reset_time)
545548

546-
subseg.annotate(**copy.copy(epoch.annotations))
549+
subseg.annotations = clean_annotations(subseg.annotations)
550+
subseg.annotate(**clean_annotations(epoch.annotations))
547551

548552
# Add array-annotations of Epoch
549-
for key, val in epoch.array_annotations.items():
553+
for key, val in clean_annotations(epoch.array_annotations).items():
550554
if len(val):
551555
subseg.annotations[key] = copy.copy(val[ep_id])
552556

@@ -555,6 +559,23 @@ def cut_segment_by_epoch(seg, epoch, reset_time=False):
555559
return segments
556560

557561

562+
def clean_annotations(dictionary):
563+
"""
564+
Remove reserved keys from an annotation dictionary.
565+
566+
Parameters
567+
----------
568+
dictionary: dict
569+
annotation dictionary to be cleaned
570+
571+
Returns:
572+
--------
573+
dict
574+
A cleaned version of the annotations
575+
"""
576+
return {k: v for k, v in dictionary.items() if k not in reserved_annotations}
577+
578+
558579
def is_block_rawio_compatible(block, return_problems=False):
559580
"""
560581
The neo.rawio layer have some restriction compared to neo.io layer:

0 commit comments

Comments
 (0)