Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 33 additions & 14 deletions neo/test/utils/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import unittest
import warnings

import numpy as np
import quantities as pq
Expand Down Expand Up @@ -281,11 +282,11 @@ def test__get_epochs(self):

def test__add_epoch(self):
starts = Event(times=[0.5, 10.0, 25.2] * pq.s)
starts.annotate(event_type='trial start')
starts.annotate(event_type='trial start', nix_name='neo.event.0')
starts.array_annotate(trial_id=[1, 2, 3])

stops = Event(times=[5.5, 14.9, 30.1] * pq.s)
stops.annotate(event_type='trial stop')
stops.annotate(event_type='trial stop', nix_name='neo.event.1')
stops.array_annotate(trial_id=[1, 2, 3])

seg = Segment()
Expand All @@ -295,7 +296,7 @@ def test__add_epoch(self):
ep_starts = add_epoch(seg, starts, pre=-300 * pq.ms, post=250 * pq.ms)

assert_neo_object_is_compliant(ep_starts)
assert_same_annotations(ep_starts, starts)
self.assertDictEqual(ep_starts.annotations, {'event_type': 'trial start'})
assert_arrays_almost_equal(ep_starts.times, starts.times - 300 * pq.ms, 1e-12)
assert_arrays_almost_equal(ep_starts.durations,
(550 * pq.ms).rescale(ep_starts.durations.units)
Expand All @@ -305,7 +306,7 @@ def test__add_epoch(self):
ep_trials = add_epoch(seg, starts, stops)

assert_neo_object_is_compliant(ep_trials)
assert_same_annotations(ep_trials, starts)
self.assertDictEqual(ep_trials.annotations, {'event_type': 'trial start'})
assert_arrays_almost_equal(ep_trials.times, starts.times, 1e-12)
assert_arrays_almost_equal(ep_trials.durations, stops - starts, 1e-12)

Expand Down Expand Up @@ -337,16 +338,16 @@ def test__match_events(self):
def test__cut_block_by_epochs(self):
epoch = Epoch([0.5, 10.0, 25.2] * pq.s, durations=[5.1, 4.8, 5.0] * pq.s,
t_start=.1 * pq.s)
epoch.annotate(epoch_type='a', pick='me')
epoch.annotate(epoch_type='a', pick='me', nix_name='neo.epoch.0')
epoch.array_annotate(trial_id=[1, 2, 3])

epoch2 = Epoch([0.6, 9.5, 16.8, 34.1] * pq.s, durations=[4.5, 4.8, 5.0, 5.0] * pq.s,
t_start=.1 * pq.s)
epoch2.annotate(epoch_type='b')
epoch2.annotate(epoch_type='b', nix_name='neo.epoch.1')
epoch2.array_annotate(trial_id=[1, 2, 3, 4])

event = Event(times=[0.5, 10.0, 25.2] * pq.s, t_start=.1 * pq.s)
event.annotate(event_type='trial start')
event.annotate(event_type='trial start', nix_name='neo.event.0')
event.array_annotate(trial_id=[1, 2, 3])

anasig = AnalogSignal(np.arange(50.0) * pq.mV, t_start=.1 * pq.s,
Expand All @@ -362,8 +363,8 @@ def test__cut_block_by_epochs(self):
array_annotations={'spikenum': np.arange(1, 9)})

# test without resetting the time
seg = Segment()
seg2 = Segment(name='NoCut')
seg = Segment(nix_name='neo.segment.0')
seg2 = Segment(name='NoCut', nix_name='neo.segment.1')
seg.epochs = [epoch, epoch2]
seg.events = [event]
seg.analogsignals = [anasig]
Expand All @@ -374,7 +375,10 @@ def test__cut_block_by_epochs(self):
original_block.segments = [seg, seg2]
original_block.create_many_to_one_relationship()

block = cut_block_by_epochs(original_block, properties={'pick': 'me'})
with warnings.catch_warnings(record=True) as w:
# This should raise a warning as one segment does not contain epochs
block = cut_block_by_epochs(original_block, properties={'pick': 'me'})
self.assertEqual(len(w), 1)

assert_neo_object_is_compliant(block)
self.assertEqual(len(block.segments), 3)
Expand All @@ -385,6 +389,10 @@ def test__cut_block_by_epochs(self):
self.assertEqual(len(block.segments[epoch_idx].analogsignals), 1)
self.assertEqual(len(block.segments[epoch_idx].irregularlysampledsignals), 1)

annos = block.segments[epoch_idx].annotations
# new segment objects have different identity
self.assertNotIn('nix_name', annos)

if epoch_idx != 0:
self.assertEqual(len(block.segments[epoch_idx].epochs), 1)
else:
Expand Down Expand Up @@ -414,8 +422,8 @@ def test__cut_block_by_epochs(self):
t_stop=epoch.times[0] + epoch.durations[0]))

# test with resetting the time
seg = Segment()
seg2 = Segment(name='NoCut')
seg = Segment(nix_name='neo.segment.0')
seg2 = Segment(name='NoCut', nix_name='neo.segment.1')
seg.epochs = [epoch, epoch2]
seg.events = [event]
seg.analogsignals = [anasig]
Expand All @@ -426,7 +434,10 @@ def test__cut_block_by_epochs(self):
original_block.segments = [seg, seg2]
original_block.create_many_to_one_relationship()

block = cut_block_by_epochs(original_block, properties={'pick': 'me'}, reset_time=True)
with warnings.catch_warnings(record=True) as w:
# This should raise a warning as one segment does not contain epochs
block = cut_block_by_epochs(original_block, properties={'pick': 'me'}, reset_time=True)
self.assertEqual(len(w), 1)

assert_neo_object_is_compliant(block)
self.assertEqual(len(block.segments), 3)
Expand All @@ -436,6 +447,10 @@ def test__cut_block_by_epochs(self):
self.assertEqual(len(block.segments[epoch_idx].spiketrains), 1)
self.assertEqual(len(block.segments[epoch_idx].analogsignals), 1)
self.assertEqual(len(block.segments[epoch_idx].irregularlysampledsignals), 1)

annos = block.segments[epoch_idx].annotations
self.assertNotIn('nix_name', annos)

if epoch_idx != 0:
self.assertEqual(len(block.segments[epoch_idx].epochs), 1)
else:
Expand Down Expand Up @@ -527,14 +542,18 @@ def test__add_epoch(self):

regular_event = Event(times=loaded_event.times - 1 * loaded_event.units)

loaded_event.annotate(nix_name='neo.event.0')
regular_event.annotate(nix_name='neo.event.1')

seg = Segment()
seg.events = [regular_event, proxy_event]

# test cutting with two events one of which is a proxy
epoch = add_epoch(seg, regular_event, proxy_event)

assert_neo_object_is_compliant(epoch)
assert_same_annotations(epoch, regular_event)
exp_annos = {k: v for k, v in regular_event.annotations.items() if k != 'nix_name'}
self.assertDictEqual(epoch.annotations, exp_annos)
assert_arrays_almost_equal(epoch.times, regular_event.times, 1e-12)
assert_arrays_almost_equal(epoch.durations,
np.ones(regular_event.shape) * loaded_event.units, 1e-12)
Expand Down
31 changes: 26 additions & 5 deletions neo/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
etc. of neo.core objects.
'''

import neo
import copy
import warnings

import numpy as np
import quantities as pq

import neo

reserved_annotations = ['nix_name']

def get_events(container, **properties):
"""
Expand Down Expand Up @@ -345,8 +348,8 @@ def add_epoch(

ep = neo.Epoch(times=times, durations=durations, **kwargs)

ep.annotate(**event1.annotations)
ep.array_annotate(**event1.array_annotations)
ep.annotate(**clean_annotations(event1.annotations))
ep.array_annotate(**clean_annotations(event1.array_annotations))

if attach_result:
segment.epochs.append(ep)
Expand Down Expand Up @@ -543,10 +546,11 @@ def cut_segment_by_epoch(seg, epoch, reset_time=False):
epoch.times[ep_id] + epoch.durations[ep_id],
reset_time=reset_time)

subseg.annotate(**copy.copy(epoch.annotations))
subseg.annotations = clean_annotations(subseg.annotations)
subseg.annotate(**clean_annotations(epoch.annotations))

# Add array-annotations of Epoch
for key, val in epoch.array_annotations.items():
for key, val in clean_annotations(epoch.array_annotations).items():
if len(val):
subseg.annotations[key] = copy.copy(val[ep_id])

Expand All @@ -555,6 +559,23 @@ def cut_segment_by_epoch(seg, epoch, reset_time=False):
return segments


def clean_annotations(dictionary):
"""
Remove reserved keys from an annotation dictionary.

Parameters
----------
dictionary: dict
annotation dictionary to be cleaned

Returns:
--------
dict
A cleaned version of the annotations
"""
return {k: v for k, v in dictionary.items() if k not in reserved_annotations}


def is_block_rawio_compatible(block, return_problems=False):
"""
The neo.rawio layer have some restriction compared to neo.io layer:
Expand Down