3 changes: 2 additions & 1 deletion neo/core/segment.py
@@ -15,6 +15,7 @@
 import numpy as np
 
 from neo.core.container import Container
+from neo.core.spiketrainlist import SpikeTrainList
 
 
 class Segment(Container):
@@ -92,7 +93,7 @@ def __init__(self, name=None, description=None, file_origin=None,
         '''
         super(Segment, self).__init__(name=name, description=description,
                                       file_origin=file_origin, **annotations)
-
+        self.spiketrains = SpikeTrainList(segment=self)
         self.file_datetime = file_datetime
         self.rec_datetime = rec_datetime
         self.index = index
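In practical terms, a Segment now hands out a SpikeTrainList that keeps the parent link in sync. A quick sketch of the intended behaviour (standard neo API; the values are made up):

import quantities as pq
from neo.core import Segment, SpikeTrain

seg = Segment(name="trial 1")
st = SpikeTrain([1.0, 2.5, 7.3] * pq.ms, t_stop=10.0 * pq.ms)
seg.spiketrains.append(st)   # SpikeTrainList.append() sets st.segment
assert st.segment is seg
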
150 changes: 150 additions & 0 deletions neo/core/spiketrainlist.py
@@ -0,0 +1,150 @@
# -*- coding: utf-8 -*-
"""
This module implements :class:`SpikeTrainList`, a pseudo-list
which takes care of the relationships between Neo parent and child objects.

In addition, it supports a multiplexed representation of spike trains
(all times in a single array, with a second array indicating which
neuron/channel each spike is from).
"""

import numpy as np
from .spiketrain import SpikeTrain


class SpikeTrainList(object):
    """
    A list-like container of :class:`SpikeTrain` objects, which maintains
    the parent-child relationship with its :class:`Segment`, and which can
    also hold spike trains in a multiplexed representation (a single array
    of spike times plus a parallel array of channel ids), converting lazily
    to a list of :class:`SpikeTrain` objects only when needed.
    """

    def __init__(self, items=None, segment=None):
        """Initialize self"""
        self._items = items
        self._spike_time_array = None
        self._channel_id_array = None
        self._all_channel_ids = None
        self._spiketrain_metadata = None
        self.segment = segment

    def __iter__(self):
        """Implement iter(self)"""
        if self._items is None:
            self._spiketrains_from_array()
        for item in self._items:
            yield item

    def __getitem__(self, i):
        """x.__getitem__(y) <==> x[y]"""
        if self._items is None:
            self._spiketrains_from_array()
        return self._items[i]

    def __str__(self):
        """Return str(self)"""
        if self._items is None:
            if self._spike_time_array is None:
                return str([])
            else:
                return "SpikeTrainList containing {} spikes from {} neurons".format(
                    self._spike_time_array.size,
                    np.unique(self._channel_id_array).size)  # distinct channels, not total spikes
        else:
            return str(self._items)

    def __len__(self):
        """Return len(self)"""
        if self._items is None:
            if self._all_channel_ids is not None:
                return len(self._all_channel_ids)
            elif self._channel_id_array is not None:
                return np.unique(self._channel_id_array).size
            else:
                return 0
        else:
            return len(self._items)

    def __add__(self, other):
        """Return self + other"""
        if isinstance(other, self.__class__):
            if self._items is None or other._items is None:
                # todo: update self._spike_time_array, etc.
                raise NotImplementedError
            else:
                self._items.extend(other._items)
            return self
        elif other and isinstance(other[0], SpikeTrain):

Member: In principle this would need to check whether other is iterable and
assert that all the contained items are SpikeTrain instances, no?

Member Author: yes, this was a shortcut, but it would be more robust to do a full check
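For illustration, a fuller check along those lines might look like this sketch (the helper name is hypothetical and not part of the PR):

def _require_spike_trains(iterable):
    """Return the items as a list, raising ValueError if any is not a SpikeTrain."""
    items = list(iterable)  # also verifies that the argument is iterable
    for obj in items:
        if not isinstance(obj, SpikeTrain):
            raise ValueError("Can only combine SpikeTrainList with SpikeTrain objects")
    return items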

            if self._items is None:
                self._spiketrains_from_array()
            for obj in other:
                obj.segment = self.segment
            self._items.extend(other)
            return self
        else:
            if self._items is None:
                self._spiketrains_from_array()
            return self._items + other

    def __radd__(self, other):
        """Return other + self"""
        if self._items is None:
            self._spiketrains_from_array()
        other.extend(self._items)
        return other

    def append(self, obj):
        """L.append(object) -> None -- append object to end"""
        if not isinstance(obj, SpikeTrain):
            raise ValueError("Can only append SpikeTrain objects")
        if self._items is None:
            self._spiketrains_from_array()
        obj.segment = self.segment
        self._items.append(obj)

    def extend(self, iterable):
        """L.extend(iterable) -> None -- extend list by appending elements from the iterable"""
        if self._items is None:
            self._spiketrains_from_array()
        for obj in iterable:
            obj.segment = self.segment
        self._items.extend(iterable)

Member: maybe it's necessary to also verify that the items of iterable are
instances of SpikeTrain? Otherwise arbitrary objects might end up in the
SpikeTrainList...

Member Author: yes, more robust checks are needed
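Building on the hypothetical _require_spike_trains helper sketched earlier, extend could guard itself roughly like this (again a sketch, not the PR's code):

def extend(self, iterable):
    """L.extend(iterable) -> None -- extend list by appending elements from the iterable"""
    items = _require_spike_trains(iterable)  # rejects non-SpikeTrain items up front
    if self._items is None:
        self._spiketrains_from_array()
    for obj in items:
        obj.segment = self.segment
    self._items.extend(items)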


    @classmethod
    def from_spike_time_array(cls, spike_time_array, channel_id_array,
                              all_channel_ids=None, units='ms',
                              t_start=None, t_stop=None):
        """Create a SpikeTrainList object from an array of spike times
        and an array of channel ids."""
        obj = cls()
        obj._spike_time_array = spike_time_array
        obj._channel_id_array = channel_id_array
        obj._all_channel_ids = all_channel_ids
        obj._spiketrain_metadata = {
            "units": units,
            "t_start": t_start,
            "t_stop": t_stop
        }
        return obj

Member: This function is quite expensive, as it generates all the SpikeTrain
objects explicitly. As it is used in many of the SpikeTrainList methods, this
would cause quite a bit of overhead. Maybe it would be good to use it in as
few places as possible?

Member Author: It's only ever used once, because the generated SpikeTrain
objects are cached in self._items

Member: Yes, you are right. So in this PR the SpikeTrain representation is
still the reference that all data will eventually be converted to, even when
merging two SpikeTrainLists coming from a spike_time_array representation. Is
there any way we can avoid that? Maybe by having a default operation mode that
can switch between SpikeTrain and spike_time_list? Or would it be worth having
the _spike_time_array and the corresponding masks for individual
channels/units as the base representation, and only generating SpikeTrains
when required, without caching/duplicating the data? This would rely even more
strongly on numpy for inserting/reordering the spike time lists and masks. I
am not sure which method would be more performant in the end. @apdavison Do
you have a better intuition there?

Member Author: you're right that merging two array-representation
SpikeTrainLists should avoid creating SpikeTrains. I think that's easy enough
to do.

The general idea is to keep data in the representation it arrives in for as
long as possible, to avoid unnecessary transformations, i.e. the
"reference/base" representation depends on how the object was initialized.

Note that the multiplexed property (not yet implemented in this branch, but
you can see it in the spiketrainlist branch, which will replace this PR)
allows users to access the array representation.
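For the array-backed case discussed above, a merge that never materialises SpikeTrain objects could be sketched like this (assuming both lists share units, t_start and t_stop, and have all_channel_ids set; the helper is hypothetical):

def _merge_multiplexed(a, b):
    """Concatenate the multiplexed representations of two SpikeTrainLists (sketch)."""
    spike_times = np.concatenate((a._spike_time_array, b._spike_time_array))
    channel_ids = np.concatenate((a._channel_id_array, b._channel_id_array))
    order = np.argsort(spike_times, kind="stable")  # keep the spikes in time order
    return SpikeTrainList.from_spike_time_array(
        spike_times[order], channel_ids[order],
        all_channel_ids=np.union1d(a._all_channel_ids, b._all_channel_ids),
        **a._spiketrain_metadata)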

"""Convert multiplexed spike time data into a list of SpikeTrain objects"""
        self._items = []  # the generated SpikeTrain objects are cached here
        if self._spike_time_array is not None:
            if self._all_channel_ids is None:
                all_channel_ids = np.unique(self._channel_id_array)
            else:
                all_channel_ids = self._all_channel_ids
            for channel_id in all_channel_ids:
                mask = self._channel_id_array == channel_id
                times = self._spike_time_array[mask]
                spiketrain = SpikeTrain(times, **self._spiketrain_metadata)
                spiketrain.segment = self.segment
                self._items.append(spiketrain)

    @property
    def multiplexed(self):
        """Return spike trains as a pair of arrays.

        The first array contains the ids of the channels/neurons that produced each spike,
        the second array contains the times of the spikes.
        """
        if self._spike_time_array is None:
            # need to convert the list of SpikeTrains into a multiplexed spike time array
            raise NotImplementedError
        return self._channel_id_array, self._spike_time_array
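To make the multiplexed representation concrete, here is a usage sketch based on the API in this diff (the numbers are made up):

import numpy as np
from neo.core.spiketrainlist import SpikeTrainList

# All spike times in a single array; a parallel array says which channel fired.
spike_times = np.array([0.5, 0.6, 23.6, 48.2, 48.3])
channel_ids = np.array([2, 0, 0, 1, 0])

stl = SpikeTrainList.from_spike_time_array(
    spike_times, channel_ids, all_channel_ids=[0, 1, 2],
    units='ms', t_start=0.0, t_stop=100.0)

print(len(stl))                   # 3 -- one (possibly empty) train per channel
ids_arr, times_arr = stl.multiplexed   # returns the arrays, no conversion needed
for st in stl:                    # iterating triggers the lazy conversion to SpikeTrain objects
    print(st.times)
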
7 changes: 4 additions & 3 deletions neo/test/coretest/test_segment.py
@@ -235,7 +235,8 @@ def test__merge(self):
         seg1a.epochs.append(self.epcs2[0])
         seg1a.annotate(seed=self.seed2)
         seg1a.merge(self.seg2)
-        self.check_creation(self.seg2)
+        self.check_creation(self.seg2)  # arguably we're checking the wrong thing here
+        # self.check_creation(seg1a)  # should be checking seg1a

         assert_same_sub_schema(self.sigarrs1a + self.sigarrs2,
                                seg1a.analogsignals)

@@ -481,15 +482,15 @@ def test__filter_multi_partres(self):
         assert_same_sub_schema(res5, targ)
 
     def test__filter_no_annotation_but_object(self):
-        targ = self.targobj.spiketrains
+        targ = list(self.targobj.spiketrains)
         res = self.targobj.filter(objects=SpikeTrain)
         assert_same_sub_schema(res, targ)
 
         targ = self.targobj.analogsignals
         res = self.targobj.filter(objects=AnalogSignal)
         assert_same_sub_schema(res, targ)
 
-        targ = self.targobj.analogsignals + self.targobj.spiketrains
+        targ = self.targobj.analogsignals + list(self.targobj.spiketrains)
         res = self.targobj.filter(objects=[AnalogSignal, SpikeTrain])
         assert_same_sub_schema(res, targ)
         assert_same_sub_schema(res, targ)
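The list(...) wrapping matters because concatenating a plain list with a SpikeTrainList now goes through SpikeTrainList.__radd__, which extends the left-hand list in place and returns it as a plain list; a minimal sketch of that behaviour:

from neo.core.spiketrainlist import SpikeTrainList

stl = SpikeTrainList(items=[])     # empty, list-backed
combined = [] + stl                # list.__add__ declines, so SpikeTrainList.__radd__ runs
assert isinstance(combined, list)  # the result is a plain list, not a SpikeTrainList
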
3 changes: 2 additions & 1 deletion neo/test/tools.py
@@ -12,6 +12,7 @@
 import neo
 from neo.core import objectlist
 from neo.core.baseneo import _reference_name, _container_name
+from neo.core.spiketrainlist import SpikeTrainList
 
 
 def assert_arrays_equal(a, b, dtype=False):
@@ -197,7 +198,7 @@ def assert_same_sub_schema(ob1, ob2, equal_almost=True, threshold=1e-10,
     if exclude is None:
         exclude = []
 
-    if isinstance(ob1, list):
+    if isinstance(ob1, (list, SpikeTrainList)):
         assert len(ob1) == len(ob2), \
             'lens %s and %s not equal for %s and %s' % \
             (len(ob1), len(ob2), ob1, ob2)