diff --git a/neo/core/container.py b/neo/core/container.py index 3c162a046..e916aaaea 100644 --- a/neo/core/container.py +++ b/neo/core/container.py @@ -7,6 +7,8 @@ from copy import deepcopy from neo.core.baseneo import BaseNeo, _reference_name, _container_name +from neo.core.spiketrain import SpikeTrain +from neo.core.spiketrainlist import SpikeTrainList def unique_objs(objs): @@ -83,7 +85,11 @@ def filterdata(data, targdict=None, objects=None, **kwargs): results = [result for result in results if result.__class__ in objects or result.__class__.__name__ in objects] - return results + + if results and all(isinstance(obj, SpikeTrain) for obj in results): + return SpikeTrainList(results) + else: + return results class Container(BaseNeo): @@ -411,7 +417,11 @@ def filter(self, targdict=None, data=True, container=False, recursive=True, data = True container = True - children = [] + if objects == SpikeTrain: + children = SpikeTrainList() + else: + children = [] + # get the objects we want if data: if recursive: diff --git a/neo/core/segment.py b/neo/core/segment.py index cf1712693..03969bec7 100644 --- a/neo/core/segment.py +++ b/neo/core/segment.py @@ -13,6 +13,7 @@ from copy import deepcopy from neo.core.container import Container +from neo.core.spiketrainlist import SpikeTrainList class Segment(Container): @@ -89,8 +90,8 @@ def __init__(self, name=None, description=None, file_origin=None, Initialize a new :class:`Segment` instance. ''' super().__init__(name=name, description=description, - file_origin=file_origin, **annotations) - + file_origin=file_origin, **annotations) + self.spiketrains = SpikeTrainList(segment=self) self.file_datetime = file_datetime self.rec_datetime = rec_datetime self.index = index diff --git a/neo/core/spiketrain.py b/neo/core/spiketrain.py index c0f228034..c6358817d 100644 --- a/neo/core/spiketrain.py +++ b/neo/core/spiketrain.py @@ -110,6 +110,56 @@ def _new_spiketrain(cls, signal, t_stop, units=None, dtype=None, copy=True, return obj +def normalize_times_array(times, units=None, dtype=None, copy=True): + """ + Return a quantity array with the correct units. + There are four scenarios: + + A. times (NumPy array), units given as string or Quantities units + B. times (Quantity array), units=None + C. times (Quantity), units given as string or Quantities units + D. times (NumPy array), units=None + + In scenarios A-C we return a tuple (times as a Quantity array, dimensionality) + In scenario C, we rescale the original array to match `units` + In scenario D, we raise a ValueError + """ + if dtype is None: + if not hasattr(times, 'dtype'): + dtype = np.float + if units is None: + # No keyword units, so get from `times` + try: + dim = times.units.dimensionality + except AttributeError: + raise ValueError('you must specify units') + else: + if hasattr(units, 'dimensionality'): + dim = units.dimensionality + else: + dim = pq.quantity.validate_dimensionality(units) + + if hasattr(times, 'dimensionality'): + if times.dimensionality.items() == dim.items(): + units = None # units will be taken from times, avoids copying + else: + if not copy: + raise ValueError("cannot rescale and return view") + else: + # this is needed because of a bug in python-quantities + # see issue # 65 in python-quantities github + # remove this if it is fixed + times = times.rescale(dim) + + # check to make sure the units are time + # this approach is orders of magnitude faster than comparing the + # reference dimensionality + if (len(dim) != 1 or list(dim.values())[0] != 1 or not isinstance(list(dim.keys())[0], + pq.UnitTime)): + ValueError("Units have dimensions %s, not [time]" % dim.simplified) + return pq.Quantity(times, units=units, dtype=dtype, copy=copy), dim + + class SpikeTrain(DataObject): ''' :class:`SpikeTrain` is a :class:`Quantity` array of spike times. @@ -140,7 +190,7 @@ class SpikeTrain(DataObject): each spike. :units: (quantity units) Required if :attr:`times` is a list or :class:`~numpy.ndarray`, not if it is a - :class:`~quantites.Quantity`. + :class:`~quantities.Quantity`. :t_stop: (quantity scalar, numpy scalar, or float) Time at which :class:`SpikeTrain` ended. This will be converted to the same units as :attr:`times`. This argument is required because it @@ -220,37 +270,7 @@ def __new__(cls, times, t_stop, units=None, dtype=None, copy=True, sampling_rate # len(times)!=0 has been used to workaround a bug occuring during neo import raise ValueError("the number of waveforms should be equal to the number of spikes") - # Make sure units are consistent - # also get the dimensionality now since it is much faster to feed - # that to Quantity rather than a unit - if units is None: - # No keyword units, so get from `times` - try: - dim = times.units.dimensionality - except AttributeError: - raise ValueError('you must specify units') - else: - if hasattr(units, 'dimensionality'): - dim = units.dimensionality - else: - dim = pq.quantity.validate_dimensionality(units) - - if hasattr(times, 'dimensionality'): - if times.dimensionality.items() == dim.items(): - units = None # units will be taken from times, avoids copying - else: - if not copy: - raise ValueError("cannot rescale and return view") - else: - # this is needed because of a bug in python-quantities - # see issue # 65 in python-quantities github - # remove this if it is fixed - times = times.rescale(dim) - - if dtype is None: - if not hasattr(times, 'dtype'): - dtype = np.float_ - elif hasattr(times, 'dtype') and times.dtype != dtype: + if dtype is not None and hasattr(times, 'dtype') and times.dtype != dtype: if not copy: raise ValueError("cannot change dtype and return view") @@ -264,15 +284,13 @@ def __new__(cls, times, t_stop, units=None, dtype=None, copy=True, sampling_rate if hasattr(t_stop, 'dtype') and t_stop.dtype != times.dtype: t_stop = t_stop.astype(times.dtype) - # check to make sure the units are time - # this approach is orders of magnitude faster than comparing the - # reference dimensionality - if (len(dim) != 1 or list(dim.values())[0] != 1 or not isinstance(list(dim.keys())[0], - pq.UnitTime)): - ValueError("Unit has dimensions %s, not [time]" % dim.simplified) + # Make sure units are consistent + # also get the dimensionality now since it is much faster to feed + # that to Quantity rather than a unit + times, dim = normalize_times_array(times, units, dtype, copy) # Construct Quantity from data - obj = pq.Quantity(times, units=units, dtype=dtype, copy=copy).view(cls) + obj = times.view(cls) # spiketrain times always need to be 1-dimensional if len(obj.shape) > 1: diff --git a/neo/core/spiketrainlist.py b/neo/core/spiketrainlist.py new file mode 100644 index 000000000..f58040d18 --- /dev/null +++ b/neo/core/spiketrainlist.py @@ -0,0 +1,321 @@ +# -*- coding: utf-8 -*- +""" +This module implements :class:`SpikeTrainList`, a pseudo-list +which supports a multiplexed representation of spike trains +(all times in a single array, with a second array indicating which +neuron/channel the spike is from). +""" + +import numpy as np +import quantities as pq +from .spiketrain import SpikeTrain, normalize_times_array + + +def is_spiketrain_or_proxy(obj): + return isinstance(obj, SpikeTrain) or getattr(obj, "proxy_for", None) == SpikeTrain + + +class SpikeTrainList(object): + """ + This class contains multiple spike trains, and can represent them + either as a list of SpikeTrain objects or as a pair of arrays + (all spike times in a single array, with a second array indicating which + neuron/channel the spike is from). + + A SpikeTrainList object should behave like a list of SpikeTrains + for iteration and item access. It is not intended to be used directly + by users, but is available as the attribute `spiketrains` of Segments. + + Examples: + + # Create from list of SpikeTrain objects + + >>> stl = SpikeTrainList(items=( + ... SpikeTrain([0.5, 0.6, 23.6, 99.2], units="ms", t_start=0 * pq.ms, t_stop=100.0 * pq.ms), + ... SpikeTrain([0.0007, 0.0112], units="s", t_start=0 * pq.ms, t_stop=100.0 * pq.ms), + ... SpikeTrain([1100, 88500], units="us", t_start=0 * pq.ms, t_stop=100.0 * pq.ms), + ... SpikeTrain([], units="ms", t_start=0 * pq.ms, t_stop=100.0 * pq.ms), + ... )) + >>> stl.multiplexed + (array([0, 0, 0, 0, 1, 1, 2, 2]), + array([ 0.5, 0.6, 23.6, 99.2, 0.7, 11.2, 1.1, 88.5]) * ms) + + # Create from a pair of arrays + + >>> stl = SpikeTrainList.from_spike_time_array( + ... np.array([0.5, 0.6, 0.7, 1.1, 11.2, 23.6, 88.5, 99.2]), + ... np.array([0, 0, 1, 2, 1, 0, 2, 0]), + ... all_channel_ids=[0, 1, 2, 3], + ... units='ms', + ... t_start=0 * pq.ms, + ... t_stop=100.0 * pq.ms) + >>> list(stl) + [, + , + , + ] + + """ + + def __init__(self, items=None, segment=None): + """Initialize self""" + if items is None: + self._items = items + else: + for item in items: + if not is_spiketrain_or_proxy(item): + raise ValueError( + "`items` can only contain SpikeTrain objects or proxy pbjects") + self._items = list(items) + self._spike_time_array = None + self._channel_id_array = None + self._all_channel_ids = None + self._spiketrain_metadata = None + self.segment = segment + + def __iter__(self): + """Implement iter(self)""" + if self._items is None: + self._spiketrains_from_array() + for item in self._items: + yield item + + def __getitem__(self, i): + """x.__getitem__(y) <==> x[y]""" + if self._items is None: + self._spiketrains_from_array() + items = self._items[i] + if is_spiketrain_or_proxy(items): + return items + else: + return SpikeTrainList(items=items) + + def __str__(self): + """Return str(self)""" + if self._items is None: + if self._spike_time_array is None: + return str([]) + else: + return "SpikeTrainList containing {} spikes from {} neurons".format( + self._spike_time_array.size, + len(self._all_channel_ids)) + else: + return str(self._items) + + def __len__(self): + """Return len(self)""" + if self._items is None: + if self._all_channel_ids is not None: + return len(self._all_channel_ids) + elif self._channel_id_array is not None: + return np.unique(self._channel_id_array).size + else: + return 0 + else: + return len(self._items) + + def _add_spiketrainlists(self, other, in_place=False): + if self._spike_time_array is None or other._spike_time_array is None: + # if either self or other is not storing multiplexed spike trains + # we combine them using the list of SpikeTrains representation + if self._items is None: + self._spiketrains_from_array() + if other._items is None: + other._spiketrains_from_array() + if in_place: + self._items.extend(other._items) + return self + else: + return self.__class__(items=self._items[:] + other._items) + else: + # both self and other are storing multiplexed spike trains + # so we update the array representation + if self._spiketrain_metadata['t_start'] != other._spiketrain_metadata['t_start']: + raise ValueError("Incompatible t_start") + # todo: adjust times and t_start of other to be compatible with self + if self._spiketrain_metadata['t_stop'] != other._spiketrain_metadata['t_stop']: + raise ValueError("Incompatible t_stop") + # todo: adjust t_stop of self and other as necessary + combined_spike_time_array = np.hstack( + (self._spike_time_array, other._spike_time_array)) + combined_channel_id_array = np.hstack( + (self._channel_id_array, other._channel_id_array)) + combined_channel_ids = set(list(self._all_channel_ids) + other._all_channel_ids) + if len(combined_channel_ids) != ( + len(self._all_channel_ids) + len(other._all_channel_ids) + ): + raise ValueError("Duplicate channel ids, please rename channels before adding") + if in_place: + self._spike_time_array = combined_spike_time_array + self._channel_id_array = combined_channel_id_array + self._all_channel_ids = combined_channel_ids + self._items = None + return self + else: + return self.__class__.from_spike_time_array( + combined_spike_time_array, + combined_channel_id_array, + combined_channel_ids, + t_start=self._spiketrain_metadata['t_start'], + t_stop=self._spiketrain_metadata['t_stop']) + + def __add__(self, other): + """Return self + other""" + if isinstance(other, self.__class__): + return self._add_spiketrainlists(other) + elif other and is_spiketrain_or_proxy(other[0]): + return self._add_spiketrainlists( + self.__class__(items=other, segment=self.segment) + ) + else: + if self._items is None: + self._spiketrains_from_array() + return self._items + other + + def __iadd__(self, other): + """Return self""" + if isinstance(other, self.__class__): + return self._add_spiketrainlists(other, in_place=True) + elif other and is_spiketrain_or_proxy(other[0]): + for obj in other: + obj.segment = self.segment + if self._items is None: + self._spiketrains_from_array() + self._items.extend(other) + return self + else: + raise TypeError("Can only add a SpikeTrainList or a list of SpikeTrains in place") + + def __radd__(self, other): + """Return other + self""" + if isinstance(other, self.__class__): + return other._add_spiketrainlists(self) + elif other and is_spiketrain_or_proxy(other[0]): + for obj in other: + obj.segment = self.segment + if self._items is None: + self._spiketrains_from_array() + self._items.extend(other) + return self + elif len(other) == 0: + return self + else: + if self._items is None: + self._spiketrains_from_array() + return other + self._items + + def append(self, obj): + """L.append(object) -> None -- append object to end""" + if not is_spiketrain_or_proxy(obj): + raise ValueError("Can only append SpikeTrain objects") + if self._items is None: + self._spiketrains_from_array() + obj.segment = self.segment + self._items.append(obj) + + def extend(self, iterable): + """L.extend(iterable) -> None -- extend list by appending elements from the iterable""" + if self._items is None: + self._spiketrains_from_array() + for obj in iterable: + obj.segment = self.segment + self._items.extend(iterable) + + @classmethod + def from_spike_time_array(cls, spike_time_array, channel_id_array, + all_channel_ids, t_stop, units=None, + t_start=0.0 * pq.s, **annotations): + """Create a SpikeTrainList object from an array of spike times + and an array of channel ids. + + *Required attributes/properties*: + + :spike_time_array: (quantity array 1D, numpy array 1D, or list) The times of + all spikes. + :channel_id_array: (numpy array 1D of dtype int) The id of the channel (e.g. the + neuron) to which each spike belongs. This array should have the same length + as :attr:`spike_time_array` + :all_channel_ids: (list, tuple, or numpy array 1D containing integers) All + channel ids. This is needed to represent channels in which there are no + spikes. + :units: (quantity units) Required if :attr:`spike_time_array` is not a + :class:`~quantities.Quantity`. + :t_stop: (quantity scalar, numpy scalar, or float) Time at which + spike recording ended. This will be converted to the + same units as :attr:`spike_time_array` or :attr:`units`. + + *Recommended attributes/properties*: + :t_start: (quantity scalar, numpy scalar, or float) Time at which + spike recording began. This will be converted to the + same units as :attr:`spike_time_array` or :attr:`units`. + Default: 0.0 seconds. + + + *Optional attributes/properties*: + """ + spike_time_array, dim = normalize_times_array(spike_time_array, units) + obj = cls() + obj._spike_time_array = spike_time_array + obj._channel_id_array = channel_id_array + obj._all_channel_ids = all_channel_ids + obj._spiketrain_metadata = { + "t_start": t_start, + "t_stop": t_stop + } + for name, ann_value in annotations.items(): + if len(ann_value) != len(obj): + raise ValueError(f"incorrect length for annotation '{name}'") + obj._annotations = annotations + return obj + + def _spiketrains_from_array(self): + """Convert multiplexed spike time data into a list of SpikeTrain objects""" + if self._spike_time_array is None: + self._items = [] + else: + self._items = [] + for i, channel_id in enumerate(self._all_channel_ids): + mask = self._channel_id_array == channel_id + times = self._spike_time_array[mask] + spiketrain = SpikeTrain(times, **self._spiketrain_metadata) + spiketrain.annotations = { + name: value[i] + for name, value in self._annotations.items() + } + spiketrain.annotate(channel_id=channel_id) + spiketrain.segment = self.segment + self._items.append(spiketrain) + + @property + def multiplexed(self): + """Return spike trains as a pair of arrays. + + The first (plain NumPy) array contains the ids of the channels/neurons that produced + each spike, the second (Quantity) array contains the times of the spikes. + """ + if self._spike_time_array is None: + # need to convert list of SpikeTrains into multiplexed spike times array + if self._items is None: + return np.array([]), np.array([]) + else: + channel_ids = [] + spike_times = [] + dim = self._items[0].units.dimensionality + for i, spiketrain in enumerate(self._items): + if hasattr(spiketrain, "load"): # proxy object + spiketrain = spiketrain.load() + if spiketrain.times.dimensionality.items() == dim.items(): + # no need to rescale + spike_times.append(spiketrain.times) + else: + spike_times.append(spiketrain.times.rescale(dim)) + if ("channel_id" in spiketrain.annotations + and isinstance(spiketrain.annotations["channel_id"], int) + ): + ch_id = spiketrain.annotations["channel_id"] + else: + ch_id = i + channel_ids.append(ch_id * np.ones(spiketrain.shape, dtype=np.int64)) + self._spike_time_array = np.hstack(spike_times) * self._items[0].units + self._channel_id_array = np.hstack(channel_ids) + return self._channel_id_array, self._spike_time_array diff --git a/neo/test/coretest/test_segment.py b/neo/test/coretest/test_segment.py index 4a6fa703a..ecda68a31 100644 --- a/neo/test/coretest/test_segment.py +++ b/neo/test/coretest/test_segment.py @@ -21,6 +21,7 @@ from neo.core.segment import Segment from neo.core import (AnalogSignal, Block, Event, IrregularlySampledSignal, Epoch, SpikeTrain) +from neo.core.spiketrainlist import SpikeTrainList from neo.core.container import filterdata from neo.test.tools import (assert_neo_object_is_compliant, assert_same_sub_schema, assert_same_attributes) @@ -151,6 +152,12 @@ def test__filter_none(self): targ.extend(segment.spiketrains) targ.extend(segment.imagesequences) + # occasionally we randomly get only spike trains, + # and then we have to convert to a SpikeTrainList + # to match the output of segment.filter + if all(isinstance(obj, SpikeTrain) for obj in targ): + targ = SpikeTrainList(items=targ, segment=segment) + res0 = segment.filter() res1 = segment.filter({}) res2 = segment.filter([]) @@ -286,8 +293,12 @@ def test__filter_multi_partres(self): def test__filter_no_annotation_but_object(self): for segment in self.segments: targ = segment.spiketrains + assert isinstance(targ, SpikeTrainList) res = segment.filter(objects=SpikeTrain) - assert_same_sub_schema(res, targ) + if len(res) > 0: + # if res has length 0 it will be just a plain list + assert isinstance(res, SpikeTrainList) + assert_same_sub_schema(res, targ) targ = segment.analogsignals res = segment.filter(objects=AnalogSignal) @@ -295,8 +306,8 @@ def test__filter_no_annotation_but_object(self): targ = segment.analogsignals + segment.spiketrains res = segment.filter(objects=[AnalogSignal, SpikeTrain]) - assert_same_sub_schema(res, targ) - assert_same_sub_schema(res, targ) + if len(res) > 0: + assert_same_sub_schema(res, targ) def test__filter_single_annotation_obj_single(self): segment = simple_block().segments[0] diff --git a/neo/test/coretest/test_spiketrainlist.py b/neo/test/coretest/test_spiketrainlist.py new file mode 100644 index 000000000..4172bcbd5 --- /dev/null +++ b/neo/test/coretest/test_spiketrainlist.py @@ -0,0 +1,295 @@ +# -*- coding: utf-8 -*- +""" +Tests of the neo.core.spiketrainlist.SpikeTrainList class +""" + +import sys + +import unittest +import warnings +from copy import deepcopy + +import numpy as np +from numpy.testing import assert_array_equal +import quantities as pq + +from neo.core.spiketrain import SpikeTrain +from neo.core.spiketrainlist import SpikeTrainList +from neo.io.proxyobjects import SpikeTrainProxy + + +class MockRawIO(object): + raw_annotations = { + "blocks": [{ + "segments": [{ + "spikes": [{ + "__array_annotations__": {} + }] + }] + }] + } + header = { + "spike_channels": [{ + 'wf_sampling_rate': 5, + 'wf_left_sweep': 3, + 'wf_units': "mV" + }], + } + + def source_name(self): + return "name_of_source" + + def segment_t_start(self, block_index=0, seg_index=0): + return 0 + + def segment_t_stop(self, block_index=0, seg_index=0): + return 100.0 + + def spike_count(self, block_index=0, seg_index=0, spike_channel_index=0): + return 2 + + def get_spike_timestamps(self, block_index=0, seg_index=0, spike_channel_index=0, + t_start=None, t_stop=None): + return np.array([0.0011, 0.0885]) + + def rescale_spike_timestamp(self, spike_timestamps, dtype='float64'): + return spike_timestamps * pq.s + + +class TestSpikeTrainList(unittest.TestCase): + + def setUp(self): + spike_time_array = np.array([0.5, 0.6, 0.7, 1.1, 11.2, 23.6, 88.5, 99.2]) + channel_id_array = np.array([0, 0, 1, 2, 1, 0, 2, 0]) + all_channel_ids = (0, 1, 2, 3) + self.stl_from_array = SpikeTrainList.from_spike_time_array( + spike_time_array, + channel_id_array, + all_channel_ids=all_channel_ids, + units='ms', + t_start=0 * pq.ms, + t_stop=100.0 * pq.ms, + identifier=["A", "B", "C", "D"] # annotation + ) + + self.stl_from_obj_list = SpikeTrainList(items=( + SpikeTrain([0.5, 0.6, 23.6, 99.2], units="ms", + t_start=0 * pq.ms, t_stop=100.0 * pq.ms, channel_id=101), + SpikeTrain([0.0007, 0.0112], units="s", t_start=0 * pq.ms, t_stop=100.0 * pq.ms, + channel_id=102), + SpikeTrain([1100, 88500], units="us", t_start=0 * pq.ms, t_stop=100.0 * pq.ms, + channel_id=103), + SpikeTrain([], units="ms", t_start=0 * pq.ms, t_stop=100.0 * pq.ms, + channel_id=104), + )) + + self.stl_from_obj_list_incl_proxy = SpikeTrainList(items=( + SpikeTrain([0.5, 0.6, 23.6, 99.2], units="ms", + t_start=0 * pq.ms, t_stop=100.0 * pq.ms), + SpikeTrain([0.0007, 0.0112], units="s", t_start=0 * pq.ms, t_stop=100.0 * pq.ms), + SpikeTrainProxy(rawio=MockRawIO(), spike_channel_index=0), + SpikeTrain([], units="ms", t_start=0 * pq.ms, t_stop=100.0 * pq.ms), + )) + + def test_create_from_spiketrain_array(self): + self.assertEqual(type(self.stl_from_array._spike_time_array), pq.Quantity) + as_list = list(self.stl_from_array) + assert_array_equal(as_list[0].times.magnitude, + np.array([0.5, 0.6, 23.6, 99.2])) + assert_array_equal(as_list[1].times.magnitude, + np.array([0.7, 11.2])) + assert_array_equal(as_list[2].times.magnitude, + np.array([1.1, 88.5])) + assert_array_equal(as_list[3].times.magnitude, + np.array([])) + self.assertEqual(as_list[0].annotations["identifier"], "A") + self.assertEqual(as_list[1].annotations["identifier"], "B") + self.assertEqual(as_list[2].annotations["identifier"], "C") + self.assertEqual(as_list[3].annotations["identifier"], "D") + + def test_create_from_spiketrain_list(self): + as_list = list(self.stl_from_obj_list) + assert_array_equal(as_list[0].times.rescale(pq.ms).magnitude, + np.array([0.5, 0.6, 23.6, 99.2])) + assert_array_equal(as_list[1].times.rescale(pq.ms).magnitude, + np.array([0.7, 11.2])) + assert_array_equal(as_list[2].times.rescale(pq.ms).magnitude, + np.array([1.1, 88.5])) + assert_array_equal(as_list[3].times.rescale(pq.ms).magnitude, + np.array([])) + + def test_create_from_spiketrain_list_incl_proxy(self): + as_list = list(self.stl_from_obj_list_incl_proxy) + assert_array_equal(as_list[0].times.rescale(pq.ms).magnitude, + np.array([0.5, 0.6, 23.6, 99.2])) + assert_array_equal(as_list[1].times.rescale(pq.ms).magnitude, + np.array([0.7, 11.2])) + assert isinstance(as_list[2], SpikeTrainProxy) + assert_array_equal(as_list[3].times.rescale(pq.ms).magnitude, + np.array([])) + + def test_str(self): + target = "SpikeTrainList containing 8 spikes from 4 neurons" + self.assertEqual(target, str(self.stl_from_array)) + target = ("[," + " ," + " ," + " ]" + ) + self.assertEqual(target, str(self.stl_from_obj_list)) + + def test_get_single_item(self): + """Indexing a SpikeTrainList with a single integer should return a SpikeTrain""" + for stl in (self.stl_from_obj_list, self.stl_from_array): + st = stl[1] + assert isinstance(st, SpikeTrain) + assert_array_equal(st.times.rescale(pq.ms).magnitude, np.array([0.7, 11.2])) + + def test_get_slice(self): + """Slicing a SpikeTrainList should return a SpikeTrainList""" + for stl in (self.stl_from_obj_list, self.stl_from_array): + new_stl = stl[1:3] + self.assertIsInstance(new_stl, SpikeTrainList) + self.assertEqual(len(new_stl), 2) + + def test_len(self): + for stl in (self.stl_from_obj_list, self.stl_from_array): + self.assertEqual(len(stl), 4) + + def test_add_spiketrainlists(self): + """Adding two SpikeTrainLists should return a new SpikeTrainList object, + whatever the internal representation being used by the two SpikeTrainLists.""" + a = self.stl_from_array + b = self.stl_from_obj_list_incl_proxy + c = a + b + self.assertEqual(len(c), 8) + self.assertEqual(len(a), 4) + self.assertNotEqual(id(c), id(a)) + + c = b + a + self.assertEqual(len(c), 8) + self.assertEqual(len(b), 4) + self.assertNotEqual(id(c), id(b)) + + b = deepcopy(a) + b._all_channel_ids = [5, 6, 7, 8] + c = a + b + self.assertEqual(len(c), 8) + self.assertEqual(len(a), 4) + self.assertNotEqual(id(c), id(a)) + + def test_iadd_spiketrainlists(self): + """Adding a SpikeTrainLists to another in place should + return the first SpikeTrainList object""" + a = deepcopy(self.stl_from_array) + b = self.stl_from_obj_list_incl_proxy + c = a + c += b + self.assertEqual(len(c), 8) + self.assertEqual(len(a), 8) + self.assertEqual(len(b), 4) + self.assertEqual(id(c), id(a)) + + a = self.stl_from_array + b = deepcopy(self.stl_from_obj_list_incl_proxy) + c = b + c += a + self.assertEqual(len(c), 8) + self.assertEqual(len(b), 8) + self.assertEqual(len(a), 4) + self.assertEqual(id(c), id(b)) + + a = deepcopy(self.stl_from_array) + b = deepcopy(a) + b._all_channel_ids = [5, 6, 7, 8] + c = a + c += b + self.assertEqual(len(c), 8) + self.assertEqual(len(a), 8) + self.assertEqual(id(c), id(a)) + + def test_add_list_of_spiketrains(self): + """Adding a list of SpikeTrains to a SpikeTrainList should return a new SpikeTrainList""" + extended_stl = self.stl_from_array + [ + SpikeTrain([], units="ms", t_start=0 * pq.ms, t_stop=100.0 * pq.ms), + SpikeTrain([22.2, 33.3], units="ms", t_start=0 * pq.ms, t_stop=100.0 * pq.ms), + SpikeTrain([], units="ms", t_start=0 * pq.ms, t_stop=100.0 * pq.ms), ] + self.assertIsInstance(extended_stl, SpikeTrainList) + self.assertEqual(len(extended_stl), 7) + self.assertNotEqual(id(extended_stl), id(self.stl_from_array)) + + extended_stl = self.stl_from_obj_list_incl_proxy + [ + SpikeTrain([], units="ms", t_start=0 * pq.ms, t_stop=100.0 * pq.ms), + SpikeTrain([22.2, 33.3], units="ms", t_start=0 * pq.ms, t_stop=100.0 * pq.ms), + SpikeTrain([], units="ms", t_start=0 * pq.ms, t_stop=100.0 * pq.ms)] + self.assertIsInstance(extended_stl, SpikeTrainList) + self.assertEqual(len(extended_stl), 7) + + def test_iadd_list_of_spiketrains(self): + """Adding a list of SpikeTrains to a SpikeTrainList in place + should return the original SpikeTrainList""" + extended_stl = deepcopy(self.stl_from_array) + extended_stl += [ + SpikeTrain([], units="ms", t_start=0 * pq.ms, t_stop=100.0 * pq.ms), + SpikeTrain([22.2, 33.3], units="ms", t_start=0 * pq.ms, t_stop=100.0 * pq.ms), + SpikeTrain([], units="ms", t_start=0 * pq.ms, t_stop=100.0 * pq.ms), ] + self.assertIsInstance(extended_stl, SpikeTrainList) + self.assertEqual(len(extended_stl), 7) + + def test_add_list_of_something_else(self): + """Adding something that is not a list of SpikeTrains to a SpikeTrainList + should return a plain list""" + bag = self.stl_from_array + ["apples", "bananas"] + self.assertIsInstance(bag, list) + + def test_radd_list_of_spiketrains(self): + """ """ + extended_stl = [ + SpikeTrain([], units="ms", t_start=0 * pq.ms, t_stop=100.0 * pq.ms), + SpikeTrain([22.2, 33.3], units="ms", t_start=0 * pq.ms, t_stop=100.0 * pq.ms), + SpikeTrain([], units="ms", t_start=0 * pq.ms, t_stop=100.0 * pq.ms) + ] + self.stl_from_array + self.assertIsInstance(extended_stl, SpikeTrainList) + self.assertEqual(len(extended_stl), 7) + + extended_stl = [ + SpikeTrain([], units="ms", t_start=0 * pq.ms, t_stop=100.0 * pq.ms), + SpikeTrain([22.2, 33.3], units="ms", t_start=0 * pq.ms, t_stop=100.0 * pq.ms), + SpikeTrain([], units="ms", t_start=0 * pq.ms, t_stop=100.0 * pq.ms) + ] + self.stl_from_obj_list_incl_proxy + self.assertIsInstance(extended_stl, SpikeTrainList) + self.assertEqual(len(extended_stl), 7) + + def test_radd_list_of_something_else(self): + """Adding a SpikeTrainList to something that is not a list of SpikeTrains + should return a plain list""" + bag = ["apples", "bananas"] + self.stl_from_array + self.assertIsInstance(bag, list) + + def test_append(self): + """Appending a SpikeTrain to a SpikeTrainList should make the STL longer""" + for stl in (self.stl_from_obj_list, self.stl_from_array): + stl.append(SpikeTrain([22.2, 33.3], units="ms", + t_start=0 * pq.ms, t_stop=100.0 * pq.ms)) + self.assertEqual(len(stl), 5) + + def test_append_something_else(self): + """Trying to append something other than a SpikeTrain should raise an Exception""" + for stl in (self.stl_from_obj_list, self.stl_from_array): + self.assertRaises(ValueError, stl.append, None) + + def test_multiplexed(self): + """The multiplexed property should return a pair of arrays""" + channel_id_array, spike_time_array = self.stl_from_array.multiplexed + assert type(spike_time_array) == pq.Quantity + assert type(channel_id_array) == np.ndarray + assert_array_equal(channel_id_array, np.array([0, 0, 1, 2, 1, 0, 2, 0])) + assert_array_equal(spike_time_array, np.array( + [0.5, 0.6, 0.7, 1.1, 11.2, 23.6, 88.5, 99.2]) * pq.ms) + + channel_id_array, spike_time_array = self.stl_from_obj_list.multiplexed + assert type(spike_time_array) == pq.Quantity + assert type(channel_id_array) == np.ndarray + assert_array_equal(channel_id_array, np.array([101, 101, 101, 101, 102, 102, 103, 103])) + assert_array_equal(spike_time_array, np.array( + [0.5, 0.6, 23.6, 99.2, 0.7, 11.2, 1.1, 88.5]) * pq.ms) diff --git a/neo/test/generate_datasets.py b/neo/test/generate_datasets.py index fdda4a394..904118f08 100644 --- a/neo/test/generate_datasets.py +++ b/neo/test/generate_datasets.py @@ -10,8 +10,10 @@ import quantities as pq from neo.core import (AnalogSignal, Block, Epoch, Event, IrregularlySampledSignal, Group, - Segment, SpikeTrain, ImageSequence, CircularRegionOfInterest, ChannelView, - RectangularRegionOfInterest, PolygonRegionOfInterest, class_by_name) + Segment, SpikeTrain, ImageSequence, ChannelView, + CircularRegionOfInterest, RectangularRegionOfInterest, + PolygonRegionOfInterest, class_by_name) +from neo.core.spiketrainlist import SpikeTrainList from neo.core.baseneo import _container_name from neo.core.dataobject import DataObject diff --git a/neo/test/tools.py b/neo/test/tools.py index af3a10d2f..0fcbee5b1 100644 --- a/neo/test/tools.py +++ b/neo/test/tools.py @@ -13,6 +13,7 @@ from neo.core.baseneo import _reference_name, _container_name from neo.core.basesignal import BaseSignal from neo.core.container import Container +from neo.core.spiketrainlist import SpikeTrainList from neo.io.basefromrawio import proxyobjectlist, EventProxy, EpochProxy @@ -194,7 +195,7 @@ def assert_same_sub_schema(ob1, ob2, equal_almost=True, threshold=1e-10, exclude if exclude is None: exclude = [] - if isinstance(ob1, list): + if isinstance(ob1, (list, SpikeTrainList)): assert len(ob1) == len(ob2), 'lens %s and %s not equal for %s and %s' \ '' % (len(ob1), len(ob2), ob1, ob2) for i, (sub1, sub2) in enumerate(zip(ob1, ob2)):