diff --git a/demo/coeffs.py b/demo/coeffs.py new file mode 100644 index 00000000..07865517 --- /dev/null +++ b/demo/coeffs.py @@ -0,0 +1,79 @@ +import pywt, numpy, pylab, time +import sample_data + +interpolation = 'nearest' #'bilinear' +cmap = pylab.cm.jet + +absolute_values = 0 +normalize = 0 + +scales = range(4, 129, 2) +sample = 1 + +if sample == 1: + x = numpy.linspace(-300, 1300, 1024*1) + data = 2*numpy.sin(2*numpy.pi/4 * x) * numpy.exp(-(x-400)**2/(2*300**2)) + \ + numpy.sin(2*numpy.pi/32*x) * numpy.exp(-(x-700)**2/(2*100**2)) + \ + numpy.sin(2*numpy.pi/32 * (x/(1+x/1000)) ) +elif sample == 2: + data = sample_data.cuspamax + data = data + x = range(len(data)) +elif sample == 3: + data = sample_data.linchirp + x = range(len(data)) +elif sample == 4: + data = sample_data.ecg + x = range(len(data)) + +wavelets = ['db1', 'sym5', 'coif3', 'bior1.5', 'bior3.3', 'bior4.4'][:2] +#, pywt.cwt.morlet(len(data)), pywt.cwt.mexican_hat(len(data))]#[:4]#[-1:] +wavelets = [pywt.cwt.cmorlet(len(data)/4, 1.0, 1.0), + pywt.cwt.cmorlet(len(data)/4, 4.0, 1.0), + pywt.cwt.cmorlet(len(data)/4, 1.0, 4.0)]#, pywt.cwt.gauss1(len(data))] #, 'coif3', 'bior1.5', 'bior3.3', 'bior4.4', pywt.cwt.morlet(len(data)), pywt.cwt.mexican_hat(len(data))]#[:4]#[-1:] +#wavelets = [pywt.cwt.morlet(len(data)), pywt.cwt.mexican_hat(len(data)), pywt.cwt.gauss1(len(data)), pywt.cwt.gauss2(len(data)), pywt.cwt.gauss3(len(data))] +#wavelets = ['cmorlet2-2', 'cmorlet1-1'] +#wavelets = ['mexican_hat', 'cmorlet2-2', 'cmorlet1-1'] +#wavelets = ['cfbsp4-0.7-1', pywt.cwt.cfbsp(len(data), 4, 0.7, 1)]#, 'cmorlet1-1'] +#wavelets = [pywt.cwt.cmorlet(2**8, 2, 2)] + +for name in ['mexican_hat', 'morlet', 'gauss1', 'gauss2', 'gauss3', 'cfbsp1-1-1', 'cmorlet1-1', 'cshannon1-1']: + print name, + t = time.clock() + #pywt.cwt.CWavelet(name).wavefun(19) + print time.clock() - t + +#p,x = pywt.cwt.CWavelet('cfbsp1-1-1').wavefun() +#pylab.plot(x,p.real) +#pylab.plot(x,p.imag) +#pylab.show() + + +for wavelet in wavelets: + #print wavelet + pylab.figure() + pylab.subplot(3,1,1) + t = time.clock() + print len(data), len(scales) + c = pywt.cwt.cwt(data, wavelet=wavelet, scales=scales, data_step=x[1]-x[0], precision=16) + print "%.4f" % (time.clock()-t) + + if absolute_values: + c = numpy.abs(numpy.asarray(c).real) + if normalize: + for y in c: + y *= 1.0 / max(abs(y.max()), abs(y.min())) + + c = numpy.asarray(c).real + pylab.imshow(c, origin='image', interpolation=interpolation, aspect='auto', cmap=cmap) + pylab.subplot(3,1,2) + pylab.plot(x, data) + pylab.xlim(x[0], x[-1]) + + pylab.subplot(3,1,3) + for i in [j for j in (4, 16, 32, 64, 128) if j <= max(scales) and j in scales]: + pylab.plot(x, c[scales.index(i)], label= ("scale a = %d" % i)) + pylab.legend() + pylab.xlim(x[0], x[-1]) +pylab.show() + diff --git a/demo/continuous.py b/demo/continuous.py new file mode 100644 index 00000000..6c0a9f54 --- /dev/null +++ b/demo/continuous.py @@ -0,0 +1,7 @@ +import pywt +import pywt.cwt as c +import pylab + +psi, x = c.mexican_hat(1000) +pylab.plot(x, psi) +pylab.show() diff --git a/demo/freq.py b/demo/freq.py new file mode 100644 index 00000000..3901a15c --- /dev/null +++ b/demo/freq.py @@ -0,0 +1,23 @@ +import pywt, numpy +#, pylab, time +import sample_data + +#interpolation = 'nearest' #'bilinear' +#cmap = pylab.cm.jet + +from pywt.functions import centfrq, orthfilt +print centfrq('db1', 8) +print centfrq('db2', 8) +print centfrq(pywt.cwt.morlet(256)) +print centfrq(pywt.cwt.mexican_hat(256)) +#print pywt.cwt.mexican_hat(256)[0] + +import pylab +pylab.plot(*pywt.cwt.cgauss1(256)[::-1]) +pylab.show() + +#for i in pywt.wavelist(): +# print "%s = %s" % (i.replace('.', '_'), centfrq(i, 8)) + + +#print orthfilt([1,2,3,4,5,6]) diff --git a/doc/cwt.rst b/doc/cwt.rst new file mode 100644 index 00000000..908aa166 --- /dev/null +++ b/doc/cwt.rst @@ -0,0 +1,51 @@ +Continuous Wavelet Transform (CWT) +---------------------------------- + +cwavelist() +~~~~~~~~~~~ + + +.. _`CWavelet`: + +Continuous Wavelet - ``CWavelet`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +:: + + CWavelet(name, psi=None, properties={}) + + +Continuous Wavelet Transform with ``cwt`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +1D Continuous Wavelet Transform + +:: + + coeffs = cwt(data, wavelet, scales, data_step=1.0, precision=10) + +data + 1D input data + +wavelet + Wavelet *name*, CWavelet_ object or Wavelet_ object. + + For convenience, a pair of `(function_approximation, x_grid)` arrays can aslo + be used here:: + + coeffs cwt(data, (psi, x_grid), scales, data_step=1.0, precision=10) + +scales + A list of scales at which to perform the CWT. + +data_step + The distance between two neighbour points on the x-axis. + +precision + Applicable only when wavelet *name*, *CWavelet* object or *Wavelet* object is + passed as the *wavelet* parameter and is used to calculate the wavelet function + approximation. + +The function returns a list of coefficients arrays, one for every scale value +in *scales*. + diff --git a/pywt/__init__.py b/pywt/__init__.py index efcec1a1..b0d47bb9 100644 --- a/pywt/__init__.py +++ b/pywt/__init__.py @@ -14,7 +14,9 @@ from multilevel import * from multidim import * from wavelet_packets import * +from functions import * import thresholding +import cwt from release_details import version as __version__, author as __author__, license as __license__ __all__ = [] @@ -22,6 +24,7 @@ __all__ += wavelet_packets.__all__ __all__ += multilevel.__all__ __all__ += multidim.__all__ -__all__ += ['thresholding'] +__all__ += functions.__all__ +__all__ += ['thresholding', 'cwt'] del multilevel, multidim, wavelet_packets diff --git a/pywt/continuous_wavelets.py b/pywt/continuous_wavelets.py new file mode 100644 index 00000000..7235c037 --- /dev/null +++ b/pywt/continuous_wavelets.py @@ -0,0 +1,387 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2006-2008 Filip Wasilewski +# See COPYING for license details. + +# $Id: $ + +""" +Continuous Wavelets definitions. +""" + +__all__ = [ + 'mexican_hat', 'morlet', + 'gauss1', 'gauss2', 'gauss3', + 'cgauss1', 'cgauss2', + 'cmorlet', 'cshannon', 'cfbsp', +] +__all__ += ['cwavelist'] + +from math import sqrt, pi +from numerix import cos, exp, sinc +from numerix import linspace + +""" +class WaveletFunction(object): + def __init__(self, lower_bound, upper_bound): + self.lower_bound = lower_bound + self.upper_bound = upper_bound + + def get_xgrid(self, points): + return linspace(self.lower_bound, self.upper_bound, points) + + @classmethod + def for_name(cls, name): + pass + +class MexicanHat(WaveletFunction): + def __init__(self, lower_bound=-8.0, upper_bound=8.0): + super(MexicanHat, self).__init__(lower_bound, upper_bound) + + self.properties = { + "family": "Mexican Hat", "orthogonal": False, "biorthogonal": False, + "compact support": False, "complex": False,"support": "infinite", + "effective support": [-5, 5], "symmetry": "symmetric" + } + + def __call__(self, points): + x = self.get_xgrid(points) + x2 = x*x + psi = -x2 + psi += 1 + psi *= (2.0/sqrt(3)*(pi**-0.25)) + psi *= exp(-0.5 * x2) + return psi, x + + +class Morlet(WaveletFunction): + def __init__(self, lower_bound=-8.0, upper_bound=8.0): + super(Morlet, self).__init__(lower_bound, upper_bound) + + def __call__(self, points): + x = self.get_xgrid(points) + minus_half_x2 = x*x + minus_half_x2 *= -0.5 + psi = exp(minus_half_x2) + psi *= cos(5*x) + return psi, x + +class GaussDerivative(WaveletFunction): + def __init__(self, lower_bound=-5.0, upper_bound=5.0): + super(GaussDerivative, self).__init__(lower_bound, upper_bound) + assert order in [1, 2, 3] + + self.order = order + self.psi, self.properties = { + 1: (self.gauss1, { + "family": "Gaussian", "orthogonal": False, "biorthogonal": False, + "compact support": False, "complex": False, "support": "infinite", + "effective support": [-5, 5], "symmetry": "symmetric" + } + ), + 2: (self.gauss2, { + "family": "Gaussian", "orthogonal": False, "biorthogonal": False, + "compact support": False, "complex": False, "support": "infinite", + "effective support": [-5, 5], "symmetry": "anti-symmetric" + } + ), + 3: (self.gauss3, { + "family": "Gaussian", "orthogonal": False, "biorthogonal": False, + "compact support": False, "complex": False, "support": "infinite", + "effective support": [-5, 5], "symmetry": "symmetric" + } + ), + }[self.order] + + def __call__(self, points): + return self.psi(points, self.lower_bound, self.upper_bound) + + def gauss1(self, points, lower_bound, upper_bound): + x = self.get_xgrid(points) + x2 = x*x + psi = exp(-x2) + psi *= x + psi *= -2*(2.0/pi)**0.25 + return psi, x + + def gauss2(self, points, lower_bound, upper_bound): + x = self.get_xgrid(points) + x2 = x*x + psi = exp(-x2) + psi *= -1+2*x2 + psi *= -2.0/(3**0.25) * (2.0/pi)**0.25 + return psi, x + + def gauss3(self, points, lower_bound, upper_bound): + x = self.get_xgrid(points) + x2 = x*x + psi = exp(-x2) + psi *= x + psi *= (3-2*x2) + psi *= (2.0/pi)**0.25 * -4.0/(15**0.5) + return psi, x + +class ComplexMorlet(WaveletFunction): + def __init__(self, bandwidth, center_frequency, lower_bound=-8.0, upper_bound=8.0): + super(ComplexMorlet, self).__init__(lower_bound, upper_bound) + assert bandwidth > 0 + assert center_frequency > 0 + self.bandwidth = bandwidth + self.center_frequency = center_frequency + + def __call__(self, points): + x = self.get_xgrid(points) + a = x*x + a *= -1.0/self.bandwidth + psi = 1.0 / sqrt(pi * self.bandwidth) * exp(2j*pi*self.center_frequency*x) + psi *= exp(a) + return psi, x + +class ComplexFrequencyBSpline(WaveletFunction): + def __init__(self, order, bandwidth, center_frequency, lower_bound=-8.0, upper_bound=8.0): + super(ComplexFrequencyBSpline, self).__init__(lower_bound, upper_bound) + assert order > 0 + assert bandwidth > 0 + assert center_frequency > 0 + self.order = order + self.bandwidth = bandwidth + self.center_frequency = center_frequency + + def __call__(self, points): + x = self.get_xgrid(points) + psi = exp(2j*pi*center_frequency*x) + psi *= sinc(bandwidth/order*x)**order + psi *= sqrt(bandwidth) + return psi, x + +def cfbsp(points, order, bandwidth, center_frequency, lower_bound=-8.0, upper_bound=8.0): + order, bandwidth, center_frequency = int(order), float(bandwidth), float(center_frequency) + assert bandwidth > 0 + assert center_frequency > 0 +cfbsp.params = [int, float, float] + +def cshannon(points, bandwidth, center_frequency, lower_bound=-8.0, upper_bound=8.0): + bandwidth, center_frequency = float(bandwidth), float(center_frequency) + assert bandwidth > 0 + assert center_frequency > 0 + x = self.get_xgrid(points) + psi = exp(2j*pi*center_frequency*x) + psi *= sinc(bandwidth*x) + psi *= sqrt(bandwidth) + return psi, x +cshannon.params = [float, float] + +""" + + +def mexican_hat(points, lower_bound=-8.0, upper_bound=8.0): + x = linspace(lower_bound, upper_bound, points) + x2 = x*x + psi = -x2 + psi += 1 + psi *= (2.0/sqrt(3)*(pi**-0.25)) + psi *= exp(-0.5 * x2) + return psi, x + +def morlet(points, lower_bound=-8.0, upper_bound=8.0): + x = linspace(lower_bound, upper_bound, points) + minus_half_x2 = x*x + minus_half_x2 *= -0.5 + psi = exp(minus_half_x2) + psi *= cos(5*x) + return psi, x + +def gauss1(points, lower_bound=-5.0, upper_bound=5.0): + x = linspace(lower_bound, upper_bound, points) + x2 = x*x + psi = exp(-x2) + psi *= x + psi *= -2*(2.0/pi)**0.25 + return psi, x + +def gauss2(points, lower_bound=-5.0, upper_bound=5.0): + x = linspace(lower_bound, upper_bound, points) + x2 = x*x + psi = exp(-x2) + psi *= -1+2*x2 + psi *= -2.0/(3**0.25) * (2.0/pi)**0.25 + return psi, x + +def gauss3(points, lower_bound=-5.0, upper_bound=5.0): + x = linspace(lower_bound, upper_bound, points) + x2 = x*x + psi = exp(-x2) + psi *= x + psi *= (3-2*x2) + psi *= (2.0/pi)**0.25 * -4.0/(15**0.5) + return psi, x + +def cgauss1(points, lower_bound=-5.0, upper_bound=5.0): + x = linspace(lower_bound, upper_bound, points) + x2 = x*x + psi = (-1j-2*x) + psi *= exp(-x2) + psi *= sqrt(2) / sqrt(exp(-0.5) * sqrt(2) * sqrt(pi)) + psi *= exp(-1j*x) + return psi, x + +def cgauss2(points, lower_bound=-5.0, upper_bound=5.0): + x = linspace(lower_bound, upper_bound, points) + x2 = x*x + psi = 1j*(exp(-x2) * exp(-1j*x)) + tmp = 4j*x + x2 *= 4 + tmp += x2 + tmp -= 3 + psi *= tmp + psi *= sqrt(6) * (1.0/sqrt(exp(-0.5) * sqrt(2) * sqrt(pi))) / 3 + return psi, x + +def cmorlet(points, bandwidth, center_frequency, lower_bound=-8.0, upper_bound=8.0): + assert bandwidth > 0 + assert center_frequency > 0 + x = linspace(lower_bound, upper_bound, points) + x2b = x*x + x2b *= -1.0/bandwidth + psi = 1.0 / sqrt(pi * bandwidth) * exp(2j*pi*center_frequency*x) + psi *= exp(x2b) + return psi, x + +def cshannon(points, bandwidth, center_frequency, lower_bound=-8.0, upper_bound=8.0): + assert bandwidth > 0 + assert center_frequency > 0 + x = linspace(lower_bound, upper_bound, points) + psi = exp(2j*pi*center_frequency*x) + psi *= sinc(bandwidth*x) + psi *= sqrt(bandwidth) + return psi, x + +def cfbsp(points, order, bandwidth, center_frequency, lower_bound=-8.0, upper_bound=8.0): + assert order > 0 + assert bandwidth > 0 + assert center_frequency > 0 + x = linspace(lower_bound, upper_bound, points) + psi = exp(2j*pi*center_frequency*x) + psi *= sinc(bandwidth/order*x)**order + psi *= sqrt(bandwidth) + return psi, x + +wavelet_functions = { + #real continuous wavelet functions + "mexican_hat": (mexican_hat, { + "family": "Mexican Hat", "orthogonal": False, "biorthogonal": False, + "compact support": False, "complex": False,"support": "infinite", + "effective support": (-5, 5), "symmetry": "symmetric" + }, None + ), + "morlet": (morlet, { + "family": "Morlet", "orthogonal": False, "biorthogonal": False, + "compact support": False, "complex": False, "support": "infinite", + "effective support": (-4, 4), "symmetry": "symmetric" + }, None + ), + "gauss1": (gauss1, { + "family": "Gaussian", "orthogonal": False, "biorthogonal": False, + "compact support": False, "complex": False, "support": "infinite", + "effective support": (-5, 5), "symmetry": "symmetric" + }, None + ), + "gauss2": (gauss2, { + "family": "Gaussian", "orthogonal": False, "biorthogonal": False, + "compact support": False, "complex": False, "support": "infinite", + "effective support": (-5, 5), "symmetry": "anti-symmetric" + }, None + ), + "gauss3": (gauss3, { + "family": "Gaussian", "orthogonal": False, "biorthogonal": False, + "compact support": False, "complex": False, "support": "infinite", + "effective support": (-5, 5), "symmetry": "symmetric" + }, None + ), + #complex continuous wavelets + #"cgauss1": (cgauss1, { + # "family": "Complex Gaussian", "orthogonal": False, "biorthogonal": False, + # "compact support": False, "complex": True, "support": "infinite", + # "symmetry": "symmetric" + # } + # ), + #"cgauss2": (cgauss2, { + # "family": "Complex Gaussian", "orthogonal": False, "biorthogonal": False, + # "compact support": False, "complex": True, "support": "infinite", + # "symmetry": "anti-symmetric" + # } + # ), + "cmorlet": (cmorlet, { + "family": "Complex Morlet", "orthogonal": False, "biorthogonal": False, + "compact support": False, "complex": True, "support": "infinite" + }, (float, float) + ), + "cshannon": (cshannon, { + "family": "Shannon", "orthogonal": False, "biorthogonal": False, + "compact support": False, "complex": True, "support": "infinite" + }, (float, float) + ), + "cfbsp": (cfbsp, { + "family": "Frequency B-Spline", "orthogonal": False, "biorthogonal": False, + "compact support": False, "complex": True, "support": "infinite" + }, (int, float, float) + ), + } + +""" +def name_to_wavelet_and_params(name): + if name in wavelet_functions: + psi, properties = wavelet_functions[self.name] + return psi, properties, () + elif '-' in name: + for basename in wavelet_functions: + if wavelet_name.startswith(basename): + psi, properties = wavelet_functions[basename] + params = name[len(basename):].split('-') + if not hasattr(psi, 'params_count') or psi.params_count != len(params): + raise ValueError("Invalid parameter string in name: %s. Expected %d hyphen-delimited params." % \ + (name[len(basename):], psi.params_count)) + return psi, properties, params + raise ValueError("Invalid continuous wavelet name: %s." % name) +""" + +def format_wavelet_name(basename, params_spec=[]): + if params_spec: + (basename + '-'.join(["(%s)" % param.__name__ for param in params_spec])) + return basename + +def function_for_name(name): + if name in wavelet_functions: + psi, properties, params_spec = wavelet_functions[name] + if params_spec: + raise ValueError("Invalid wavelet name - '%s'. Missing params part for wavelet '%s'." % \ + format_wavelet_name(name, params_spec)) + else: + psi = None + for basename in wavelet_functions: + if name.startswith(basename): + _psi, properties, params_spec = wavelet_functions[basename] + if not params_spec: + raise ValueError("Invalid wavelet name - '%s'. No params expected for wavelet '%s'." % \ + (name, basename)) + params = name[len(basename):].split('-') + if len(params_spec) != len(params): + raise ValueError("Invalid wavelet name - '%s'. Expected %d parameters for '%s' wavelet, got %d instead." % \ + (name, len(params_spec), format_wavelet_name(basename, params_spec), len(params))) + try: + params = [type(value) for type, value in zip(params_spec, params)] + except ValueError: + raise ValueError("Invalid wavelet name - '%s'. Cannot convert parameter '%s' to type '%s'." % \ + (value, type.__name__)) + params = tuple(params) + + def psi(points, **kwds): + return _psi(points, *params, **kwds) + if psi is None: + raise ValueError("Invalid wavelet name - '%s'." % name) + + return psi, properties.copy() + + +def cwavelist(): + return [format_wavelet_name(basename, spec[2]) for basename, spec in wavelet_functions.items()] diff --git a/pywt/cwt.py b/pywt/cwt.py new file mode 100644 index 00000000..20cbafd6 --- /dev/null +++ b/pywt/cwt.py @@ -0,0 +1,119 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2006-2008 Filip Wasilewski +# See COPYING for license details. + +# $Id: $ + +""" +Continuous Wavelet Transform module. +""" + +__all__ = ['cwt', 'CWavelet', 'cwavelist'] + +from math import sqrt, floor + +from numerix import asarray, linspace +from numerix import intp +from numerix import concatenate, keep +from numerix import convolve, diff + + +from continuous_wavelets import cwavelist, function_for_name +from continuous_wavelets import * + + +class CWavelet(object): + """ + + """ + def __init__(self, name, psi=None, properties={}): + self.name = name.lower() + if psi is not None: + self.psi, self.properties = psi, properties + else: + self.psi, self.properties = function_for_name(name) + + def wavefun(self, iter=10, points=None, lower_bound=None, upper_bound=None): + if points is None: + assert iter > 0 + points = 2**iter + assert points > 0 + + kwds = {} + if lower_bound is not None: + kwds["lower_bound"] = lower_bound + if upper_bound is not None: + kwds["upper_bound"] = upper_bound + + return self.psi(points, **kwds) + + +def cwt(data, wavelet, scales, data_step=1.0, precision=10): + """ + cwt(data, wavelet, scales, data_step=1, precision=10) + + 1D Continuous Wavelet Transform + + data - 1D input data + wavelet - Wavelet name, CWavelet object or Wavelet object. + For convenience, a pair of (function_approximation, x_grid) + arrays can aslo be passed. + scales - List of scales at which the CWT will be computed. + Each scale must me in range (0 < scale < len(data)/2). + data_step - The distance between two neighbour points on the x-axis. + precision - Applicable only when wavelet *name*, *CWavelet* object + or *Wavelet* object is passed as the *wavelet* parameter + and is used to calculate the wavelet function approximation. + """ + from functions import intwave + + data = asarray(data) + if len(data.shape) != 1: + raise ValueError("Expected 1D array, got %dD." % len(data.shape)) + if isinstance(scales, (int, float)): + scales = [scales] + elif not isinstance(scales, (list, tuple)): + raise ValueError("Scales parameter must be a list or tuple of ints or floats, not %s." % type(scales)) + if not len(scales): + raise ValueError("Scales parameter must be non-empty list of ints or floats.") + if not (max(scales) < len(data)/2 and min(scales) > 0): + raise ValueError("Scales values must be in range (0 < scale < len(data)/2).") + + # integrate wavelet function + if isinstance(wavelet, tuple): + psi, psi_x = wavelet + assert len(psi) == len(psi_x) + intwavefun, intwavefun_x = intwave((psi, psi_x)) + else: + _intwavefun = intwave(wavelet, precision) + intwavefun, intwavefun_x = _intwavefun[0], _intwavefun[-1] + del _intwavefun + + intwavefun_step = intwavefun_x[1]-intwavefun_x[0] # xgrid step + intwavefun_x -= intwavefun_x[0] # shift xgrid to start in 0 point + assert intwavefun_x[-1] > 0 + + coeffs = [] + + for scale in scales: + scale = float(scale) + resampled_intwavefun = resample(intwavefun, intwavefun_x, scale, data_step) # scale integrated wavelet function + conv = convolve(data, resampled_intwavefun[::-1]) # match data against scaled function + d = diff(conv) # compute 1st derivative from coefficients + c = keep(d, len(data)) # keep only coefficients in range + c *= -sqrt(scale) # normalize coefficients according to scale + coeffs.append(c) + + return coeffs + + +def resample(function, xgrid, scale, data_step): + """Resample `function` defined on `xgrix` [0, x] using `scale`. + """ + step = int(floor(scale / data_step * xgrid[-1])) + 1 + resampled = function[linspace(0, (len(xgrid)-1), step).astype(intp)] + assert len(resampled) > 0 + if len(resampled) == 1: + resampled = concatenate([resampled, resampled]) + return resampled diff --git a/pywt/functions.py b/pywt/functions.py new file mode 100644 index 00000000..92980778 --- /dev/null +++ b/pywt/functions.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2006-2008 Filip Wasilewski +# See COPYING for license details. + +# $Id: $ + +""" +Other wavelet related functions. +""" + +__all__ = ["intwave", "centfrq", "scal2frq", "qmf", "orthfilt"] + +from math import sqrt + +from _pywt import Wavelet +from cwt import CWavelet +from helpers import wavelet_for_name + +from numerix import asarray, array, float64 +from numerix import integrate +from numerix import argmax, mean +from numerix import fft + + +def intwave(wavelet, precision=8): + """ + intwave(wavelet, precision=8) -> [int_psi, x] - for orthogonal wavelets + intwave(wavelet, precision=8) -> [int_psi_d, int_psi_r, x] - for other wavelets + intwave((function_approx, x), precision=8) -> [int_function, x] - for (function approx., x grid) pair + + Integrate *psi* wavelet function from -Inf to x using the rectangle + integration method. + + wavelet - Wavelet to integrate (Wavelet object, wavelet name string + or (wavelet function approx., x grid) pair) + + precision = 8 - Precision that will be used for wavelet function + approximation computed with the wavefun(level=precision) + Wavelet's method. + + (function_approx, x) - Function to integrate on the x grid. Used instead + of Wavelet object to allow custom wavelet functions. + """ + + if isinstance(wavelet, tuple): + psi, x = asarray(wavelet[0]), asarray(wavelet[1]) + step = x[1] - x[0] + return integrate(psi, step), x + + else: + if not isinstance(wavelet, (Wavelet, CWavelet)): + wavelet = wavelet_for_name(wavelet) + + functions_approximations = wavelet.wavefun(precision) + if len(functions_approximations) == 2: # continuous wavelet + psi, x = functions_approximations + step = x[1] - x[0] + return integrate(psi, step), x + elif len(functions_approximations) == 3: # orthogonal wavelet + phi, psi, x = functions_approximations + step = x[1] - x[0] + return integrate(psi, step), x + else: # biorthogonal wavelet + phi_d, psi_d, phi_r, psi_r, x = functions_approximations + step = x[1] - x[0] + return integrate(psi_d, step), integrate(psi_r, step), x + + +def centfrq(wavelet, precision=8): + """ + centfrq(wavelet, precision=8) -> float - for orthogonal wavelets + centfrq((function_aprox, x), precision=8) -> float - for (function approx., x grid) pair + + Computes the central frequency of the *psi* wavelet function. + + wavelet - Wavelet (Wavelet object, wavelet name string + or (wavelet function approx., x grid) pair) + precision = 8 - Precision that will be used for wavelet function + approximation computed with the wavefun(level=precision) + Wavelet's method. + + (function_approx, xgrid) - Function defined on xgrid. Used instead + of Wavelet object to allow custom wavelet functions. + """ + + if isinstance(wavelet, tuple): + psi, x = asarray(wavelet[0]), asarray(wavelet[1]) + else: + if not isinstance(wavelet, (Wavelet, CWavelet)): + wavelet = wavelet_for_name(wavelet) + functions_approximations = wavelet.wavefun(precision) + + if len(functions_approximations) == 2: + psi, x = functions_approximations + else: + psi, x = functions_approximations[1], functions_approximations[-1] # (psi, x) for (phi, psi, x) and (psi_d, x) for (phi_d, psi_d, phi_r, psi_r, x) + + domain = float(x[-1] - x[0]) + assert domain > 0 + + index = argmax(abs(fft(psi)[1:]))+2 + if index > len(psi)/2: + index = len(psi)-index+2 + + return 1.0/(domain/(index-1)) + + +def scal2frq(wavelet, scale, delta, precision=8): + """ + scal2frq(wavelet, scale, delta, precision=8) -> float - for orthogonal wavelets + scal2frq(wavelet, scale, delta, precision=8) -> float - for (function approx., x grid) pair + + wavelet + scale + delta - sampling + """ + return centfrq(wavelet, precision=precision)/(scale*delta) + + +def qmf(filter): + filter = array(filter)[::-1] + filter[1::2] = -filter[1::2] + return filter + + +def orthfilt(scaling_filter): + assert len(scaling_filter) % 2 == 0 + + scaling_filter = asarray(scaling_filter, dtype=float64) + + rec_lo = sqrt(2) * scaling_filter / sum(scaling_filter) + dec_lo = rec_lo[::-1] + + rec_hi = qmf(rec_lo) + dec_hi = rec_hi[::-1] + + return (dec_lo, dec_hi, rec_lo, rec_hi) diff --git a/pywt/helpers.py b/pywt/helpers.py new file mode 100644 index 00000000..2349d3ca --- /dev/null +++ b/pywt/helpers.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2006-2008 Filip Wasilewski +# See COPYING for license details. + +# $Id: $ + +""" +""" + + +from _pywt import Wavelet +from cwt import CWavelet + +def wavelet_for_name(name): + if not isinstance(name, basestring): + raise TypeError("Wavelet name must be a string, not %s" % type(name)) + try: + wavelet = Wavelet(name) + except ValueError: + try: + wavelet = CWavelet(name) + except: + raise + #raise ValueError("Invalid wavelet name - %s." % name) + return wavelet