Skip to content

Commit 47a5aad

Browse files
Merge pull request #1000 from apdavison/spiketrainlist
Implement SpikeTrainList class, and make Segment.spiketrains an instance of this class
2 parents 3d8b139 + 96abd9f commit 47a5aad

File tree

8 files changed

+708
-49
lines changed

8 files changed

+708
-49
lines changed

neo/core/container.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
from copy import deepcopy
99
from neo.core.baseneo import BaseNeo, _reference_name, _container_name
10+
from neo.core.spiketrain import SpikeTrain
11+
from neo.core.spiketrainlist import SpikeTrainList
1012

1113

1214
def unique_objs(objs):
@@ -83,7 +85,11 @@ def filterdata(data, targdict=None, objects=None, **kwargs):
8385
results = [result for result in results if
8486
result.__class__ in objects or
8587
result.__class__.__name__ in objects]
86-
return results
88+
89+
if results and all(isinstance(obj, SpikeTrain) for obj in results):
90+
return SpikeTrainList(results)
91+
else:
92+
return results
8793

8894

8995
class Container(BaseNeo):
@@ -411,7 +417,11 @@ def filter(self, targdict=None, data=True, container=False, recursive=True,
411417
data = True
412418
container = True
413419

414-
children = []
420+
if objects == SpikeTrain:
421+
children = SpikeTrainList()
422+
else:
423+
children = []
424+
415425
# get the objects we want
416426
if data:
417427
if recursive:

neo/core/segment.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from copy import deepcopy
1414

1515
from neo.core.container import Container
16+
from neo.core.spiketrainlist import SpikeTrainList
1617

1718

1819
class Segment(Container):
@@ -89,8 +90,8 @@ def __init__(self, name=None, description=None, file_origin=None,
8990
Initialize a new :class:`Segment` instance.
9091
'''
9192
super().__init__(name=name, description=description,
92-
file_origin=file_origin, **annotations)
93-
93+
file_origin=file_origin, **annotations)
94+
self.spiketrains = SpikeTrainList(segment=self)
9495
self.file_datetime = file_datetime
9596
self.rec_datetime = rec_datetime
9697
self.index = index

neo/core/spiketrain.py

Lines changed: 57 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,56 @@ def _new_spiketrain(cls, signal, t_stop, units=None, dtype=None, copy=True,
110110
return obj
111111

112112

113+
def normalize_times_array(times, units=None, dtype=None, copy=True):
114+
"""
115+
Return a quantity array with the correct units.
116+
There are four scenarios:
117+
118+
A. times (NumPy array), units given as string or Quantities units
119+
B. times (Quantity array), units=None
120+
C. times (Quantity), units given as string or Quantities units
121+
D. times (NumPy array), units=None
122+
123+
In scenarios A-C we return a tuple (times as a Quantity array, dimensionality)
124+
In scenario C, we rescale the original array to match `units`
125+
In scenario D, we raise a ValueError
126+
"""
127+
if dtype is None:
128+
if not hasattr(times, 'dtype'):
129+
dtype = np.float
130+
if units is None:
131+
# No keyword units, so get from `times`
132+
try:
133+
dim = times.units.dimensionality
134+
except AttributeError:
135+
raise ValueError('you must specify units')
136+
else:
137+
if hasattr(units, 'dimensionality'):
138+
dim = units.dimensionality
139+
else:
140+
dim = pq.quantity.validate_dimensionality(units)
141+
142+
if hasattr(times, 'dimensionality'):
143+
if times.dimensionality.items() == dim.items():
144+
units = None # units will be taken from times, avoids copying
145+
else:
146+
if not copy:
147+
raise ValueError("cannot rescale and return view")
148+
else:
149+
# this is needed because of a bug in python-quantities
150+
# see issue # 65 in python-quantities github
151+
# remove this if it is fixed
152+
times = times.rescale(dim)
153+
154+
# check to make sure the units are time
155+
# this approach is orders of magnitude faster than comparing the
156+
# reference dimensionality
157+
if (len(dim) != 1 or list(dim.values())[0] != 1 or not isinstance(list(dim.keys())[0],
158+
pq.UnitTime)):
159+
ValueError("Units have dimensions %s, not [time]" % dim.simplified)
160+
return pq.Quantity(times, units=units, dtype=dtype, copy=copy), dim
161+
162+
113163
class SpikeTrain(DataObject):
114164
'''
115165
:class:`SpikeTrain` is a :class:`Quantity` array of spike times.
@@ -140,7 +190,7 @@ class SpikeTrain(DataObject):
140190
each spike.
141191
:units: (quantity units) Required if :attr:`times` is a list or
142192
:class:`~numpy.ndarray`, not if it is a
143-
:class:`~quantites.Quantity`.
193+
:class:`~quantities.Quantity`.
144194
:t_stop: (quantity scalar, numpy scalar, or float) Time at which
145195
:class:`SpikeTrain` ended. This will be converted to the
146196
same units as :attr:`times`. This argument is required because it
@@ -220,37 +270,7 @@ def __new__(cls, times, t_stop, units=None, dtype=None, copy=True, sampling_rate
220270
# len(times)!=0 has been used to workaround a bug occuring during neo import
221271
raise ValueError("the number of waveforms should be equal to the number of spikes")
222272

223-
# Make sure units are consistent
224-
# also get the dimensionality now since it is much faster to feed
225-
# that to Quantity rather than a unit
226-
if units is None:
227-
# No keyword units, so get from `times`
228-
try:
229-
dim = times.units.dimensionality
230-
except AttributeError:
231-
raise ValueError('you must specify units')
232-
else:
233-
if hasattr(units, 'dimensionality'):
234-
dim = units.dimensionality
235-
else:
236-
dim = pq.quantity.validate_dimensionality(units)
237-
238-
if hasattr(times, 'dimensionality'):
239-
if times.dimensionality.items() == dim.items():
240-
units = None # units will be taken from times, avoids copying
241-
else:
242-
if not copy:
243-
raise ValueError("cannot rescale and return view")
244-
else:
245-
# this is needed because of a bug in python-quantities
246-
# see issue # 65 in python-quantities github
247-
# remove this if it is fixed
248-
times = times.rescale(dim)
249-
250-
if dtype is None:
251-
if not hasattr(times, 'dtype'):
252-
dtype = np.float_
253-
elif hasattr(times, 'dtype') and times.dtype != dtype:
273+
if dtype is not None and hasattr(times, 'dtype') and times.dtype != dtype:
254274
if not copy:
255275
raise ValueError("cannot change dtype and return view")
256276

@@ -264,15 +284,13 @@ def __new__(cls, times, t_stop, units=None, dtype=None, copy=True, sampling_rate
264284
if hasattr(t_stop, 'dtype') and t_stop.dtype != times.dtype:
265285
t_stop = t_stop.astype(times.dtype)
266286

267-
# check to make sure the units are time
268-
# this approach is orders of magnitude faster than comparing the
269-
# reference dimensionality
270-
if (len(dim) != 1 or list(dim.values())[0] != 1 or not isinstance(list(dim.keys())[0],
271-
pq.UnitTime)):
272-
ValueError("Unit has dimensions %s, not [time]" % dim.simplified)
287+
# Make sure units are consistent
288+
# also get the dimensionality now since it is much faster to feed
289+
# that to Quantity rather than a unit
290+
times, dim = normalize_times_array(times, units, dtype, copy)
273291

274292
# Construct Quantity from data
275-
obj = pq.Quantity(times, units=units, dtype=dtype, copy=copy).view(cls)
293+
obj = times.view(cls)
276294

277295
# spiketrain times always need to be 1-dimensional
278296
if len(obj.shape) > 1:

0 commit comments

Comments
 (0)