@@ -110,6 +110,56 @@ def _new_spiketrain(cls, signal, t_stop, units=None, dtype=None, copy=True,
110110 return obj
111111
112112
113+ def normalize_times_array (times , units = None , dtype = None , copy = True ):
114+ """
115+ Return a quantity array with the correct units.
116+ There are four scenarios:
117+
118+ A. times (NumPy array), units given as string or Quantities units
119+ B. times (Quantity array), units=None
120+ C. times (Quantity), units given as string or Quantities units
121+ D. times (NumPy array), units=None
122+
123+ In scenarios A-C we return a tuple (times as a Quantity array, dimensionality)
124+ In scenario C, we rescale the original array to match `units`
125+ In scenario D, we raise a ValueError
126+ """
127+ if dtype is None :
128+ if not hasattr (times , 'dtype' ):
129+ dtype = np .float
130+ if units is None :
131+ # No keyword units, so get from `times`
132+ try :
133+ dim = times .units .dimensionality
134+ except AttributeError :
135+ raise ValueError ('you must specify units' )
136+ else :
137+ if hasattr (units , 'dimensionality' ):
138+ dim = units .dimensionality
139+ else :
140+ dim = pq .quantity .validate_dimensionality (units )
141+
142+ if hasattr (times , 'dimensionality' ):
143+ if times .dimensionality .items () == dim .items ():
144+ units = None # units will be taken from times, avoids copying
145+ else :
146+ if not copy :
147+ raise ValueError ("cannot rescale and return view" )
148+ else :
149+ # this is needed because of a bug in python-quantities
150+ # see issue # 65 in python-quantities github
151+ # remove this if it is fixed
152+ times = times .rescale (dim )
153+
154+ # check to make sure the units are time
155+ # this approach is orders of magnitude faster than comparing the
156+ # reference dimensionality
157+ if (len (dim ) != 1 or list (dim .values ())[0 ] != 1 or not isinstance (list (dim .keys ())[0 ],
158+ pq .UnitTime )):
159+ ValueError ("Units have dimensions %s, not [time]" % dim .simplified )
160+ return pq .Quantity (times , units = units , dtype = dtype , copy = copy ), dim
161+
162+
113163class SpikeTrain (DataObject ):
114164 '''
115165 :class:`SpikeTrain` is a :class:`Quantity` array of spike times.
@@ -140,7 +190,7 @@ class SpikeTrain(DataObject):
140190 each spike.
141191 :units: (quantity units) Required if :attr:`times` is a list or
142192 :class:`~numpy.ndarray`, not if it is a
143- :class:`~quantites .Quantity`.
193+ :class:`~quantities .Quantity`.
144194 :t_stop: (quantity scalar, numpy scalar, or float) Time at which
145195 :class:`SpikeTrain` ended. This will be converted to the
146196 same units as :attr:`times`. This argument is required because it
@@ -220,37 +270,7 @@ def __new__(cls, times, t_stop, units=None, dtype=None, copy=True, sampling_rate
220270 # len(times)!=0 has been used to workaround a bug occuring during neo import
221271 raise ValueError ("the number of waveforms should be equal to the number of spikes" )
222272
223- # Make sure units are consistent
224- # also get the dimensionality now since it is much faster to feed
225- # that to Quantity rather than a unit
226- if units is None :
227- # No keyword units, so get from `times`
228- try :
229- dim = times .units .dimensionality
230- except AttributeError :
231- raise ValueError ('you must specify units' )
232- else :
233- if hasattr (units , 'dimensionality' ):
234- dim = units .dimensionality
235- else :
236- dim = pq .quantity .validate_dimensionality (units )
237-
238- if hasattr (times , 'dimensionality' ):
239- if times .dimensionality .items () == dim .items ():
240- units = None # units will be taken from times, avoids copying
241- else :
242- if not copy :
243- raise ValueError ("cannot rescale and return view" )
244- else :
245- # this is needed because of a bug in python-quantities
246- # see issue # 65 in python-quantities github
247- # remove this if it is fixed
248- times = times .rescale (dim )
249-
250- if dtype is None :
251- if not hasattr (times , 'dtype' ):
252- dtype = np .float_
253- elif hasattr (times , 'dtype' ) and times .dtype != dtype :
273+ if dtype is not None and hasattr (times , 'dtype' ) and times .dtype != dtype :
254274 if not copy :
255275 raise ValueError ("cannot change dtype and return view" )
256276
@@ -264,15 +284,13 @@ def __new__(cls, times, t_stop, units=None, dtype=None, copy=True, sampling_rate
264284 if hasattr (t_stop , 'dtype' ) and t_stop .dtype != times .dtype :
265285 t_stop = t_stop .astype (times .dtype )
266286
267- # check to make sure the units are time
268- # this approach is orders of magnitude faster than comparing the
269- # reference dimensionality
270- if (len (dim ) != 1 or list (dim .values ())[0 ] != 1 or not isinstance (list (dim .keys ())[0 ],
271- pq .UnitTime )):
272- ValueError ("Unit has dimensions %s, not [time]" % dim .simplified )
287+ # Make sure units are consistent
288+ # also get the dimensionality now since it is much faster to feed
289+ # that to Quantity rather than a unit
290+ times , dim = normalize_times_array (times , units , dtype , copy )
273291
274292 # Construct Quantity from data
275- obj = pq . Quantity ( times , units = units , dtype = dtype , copy = copy ) .view (cls )
293+ obj = times .view (cls )
276294
277295 # spiketrain times always need to be 1-dimensional
278296 if len (obj .shape ) > 1 :
0 commit comments