From be327abbfbe06e5ab82e4a447b44ccce358f3e4c Mon Sep 17 00:00:00 2001 From: Filip Wasilewski Date: Mon, 2 Mar 2009 11:34:36 +0000 Subject: [PATCH 1/2] Added CWT branch. --HG-- branch : pywt-cwt extra : convert_revision : svn%3A993dc4b6-72fc-0310-9b3e-91fa30ebc9a8/pywt/branches/pywt-cwt%40114 From 2f347bb8320680f6a41bde4e7e66d279356dc824 Mon Sep 17 00:00:00 2001 From: Filip Wasilewski Date: Mon, 2 Mar 2009 12:05:34 +0000 Subject: [PATCH 2/2] Added experimental Continuous Wavelet Transform code which can use both discrete and continuous wavelets for CWT. Note 1: CWT API (CWavelet class, continuous wavelet names with params - i.e. "cfbsp4-0.7-1") is designed to work with the current DWT solutions and follow the Matlab Wavelet Toolbox CWT API, but is not perfect and the most cleanest approach. Any suggestions regarding API design are greatly appreciated. Note 2: There are probably errors in the Gaussian Derivatives function declarations in the continuous_wavelets.py module, so watch out. --HG-- branch : pywt-cwt extra : convert_revision : svn%3A993dc4b6-72fc-0310-9b3e-91fa30ebc9a8/pywt/branches/pywt-cwt%40115 --- demo/coeffs.py | 79 ++++++++ demo/continuous.py | 7 + demo/freq.py | 23 +++ doc/cwt.rst | 51 +++++ pywt/__init__.py | 5 +- pywt/continuous_wavelets.py | 387 ++++++++++++++++++++++++++++++++++++ pywt/cwt.py | 119 +++++++++++ pywt/functions.py | 138 +++++++++++++ pywt/helpers.py | 26 +++ 9 files changed, 834 insertions(+), 1 deletion(-) create mode 100644 demo/coeffs.py create mode 100644 demo/continuous.py create mode 100644 demo/freq.py create mode 100644 doc/cwt.rst create mode 100644 pywt/continuous_wavelets.py create mode 100644 pywt/cwt.py create mode 100644 pywt/functions.py create mode 100644 pywt/helpers.py diff --git a/demo/coeffs.py b/demo/coeffs.py new file mode 100644 index 00000000..07865517 --- /dev/null +++ b/demo/coeffs.py @@ -0,0 +1,79 @@ +import pywt, numpy, pylab, time +import sample_data + +interpolation = 'nearest' #'bilinear' +cmap = pylab.cm.jet + +absolute_values = 0 +normalize = 0 + +scales = range(4, 129, 2) +sample = 1 + +if sample == 1: + x = numpy.linspace(-300, 1300, 1024*1) + data = 2*numpy.sin(2*numpy.pi/4 * x) * numpy.exp(-(x-400)**2/(2*300**2)) + \ + numpy.sin(2*numpy.pi/32*x) * numpy.exp(-(x-700)**2/(2*100**2)) + \ + numpy.sin(2*numpy.pi/32 * (x/(1+x/1000)) ) +elif sample == 2: + data = sample_data.cuspamax + data = data + x = range(len(data)) +elif sample == 3: + data = sample_data.linchirp + x = range(len(data)) +elif sample == 4: + data = sample_data.ecg + x = range(len(data)) + +wavelets = ['db1', 'sym5', 'coif3', 'bior1.5', 'bior3.3', 'bior4.4'][:2] +#, pywt.cwt.morlet(len(data)), pywt.cwt.mexican_hat(len(data))]#[:4]#[-1:] +wavelets = [pywt.cwt.cmorlet(len(data)/4, 1.0, 1.0), + pywt.cwt.cmorlet(len(data)/4, 4.0, 1.0), + pywt.cwt.cmorlet(len(data)/4, 1.0, 4.0)]#, pywt.cwt.gauss1(len(data))] #, 'coif3', 'bior1.5', 'bior3.3', 'bior4.4', pywt.cwt.morlet(len(data)), pywt.cwt.mexican_hat(len(data))]#[:4]#[-1:] +#wavelets = [pywt.cwt.morlet(len(data)), pywt.cwt.mexican_hat(len(data)), pywt.cwt.gauss1(len(data)), pywt.cwt.gauss2(len(data)), pywt.cwt.gauss3(len(data))] +#wavelets = ['cmorlet2-2', 'cmorlet1-1'] +#wavelets = ['mexican_hat', 'cmorlet2-2', 'cmorlet1-1'] +#wavelets = ['cfbsp4-0.7-1', pywt.cwt.cfbsp(len(data), 4, 0.7, 1)]#, 'cmorlet1-1'] +#wavelets = [pywt.cwt.cmorlet(2**8, 2, 2)] + +for name in ['mexican_hat', 'morlet', 'gauss1', 'gauss2', 'gauss3', 'cfbsp1-1-1', 'cmorlet1-1', 'cshannon1-1']: + print name, + t = time.clock() + #pywt.cwt.CWavelet(name).wavefun(19) + print time.clock() - t + +#p,x = pywt.cwt.CWavelet('cfbsp1-1-1').wavefun() +#pylab.plot(x,p.real) +#pylab.plot(x,p.imag) +#pylab.show() + + +for wavelet in wavelets: + #print wavelet + pylab.figure() + pylab.subplot(3,1,1) + t = time.clock() + print len(data), len(scales) + c = pywt.cwt.cwt(data, wavelet=wavelet, scales=scales, data_step=x[1]-x[0], precision=16) + print "%.4f" % (time.clock()-t) + + if absolute_values: + c = numpy.abs(numpy.asarray(c).real) + if normalize: + for y in c: + y *= 1.0 / max(abs(y.max()), abs(y.min())) + + c = numpy.asarray(c).real + pylab.imshow(c, origin='image', interpolation=interpolation, aspect='auto', cmap=cmap) + pylab.subplot(3,1,2) + pylab.plot(x, data) + pylab.xlim(x[0], x[-1]) + + pylab.subplot(3,1,3) + for i in [j for j in (4, 16, 32, 64, 128) if j <= max(scales) and j in scales]: + pylab.plot(x, c[scales.index(i)], label= ("scale a = %d" % i)) + pylab.legend() + pylab.xlim(x[0], x[-1]) +pylab.show() + diff --git a/demo/continuous.py b/demo/continuous.py new file mode 100644 index 00000000..6c0a9f54 --- /dev/null +++ b/demo/continuous.py @@ -0,0 +1,7 @@ +import pywt +import pywt.cwt as c +import pylab + +psi, x = c.mexican_hat(1000) +pylab.plot(x, psi) +pylab.show() diff --git a/demo/freq.py b/demo/freq.py new file mode 100644 index 00000000..3901a15c --- /dev/null +++ b/demo/freq.py @@ -0,0 +1,23 @@ +import pywt, numpy +#, pylab, time +import sample_data + +#interpolation = 'nearest' #'bilinear' +#cmap = pylab.cm.jet + +from pywt.functions import centfrq, orthfilt +print centfrq('db1', 8) +print centfrq('db2', 8) +print centfrq(pywt.cwt.morlet(256)) +print centfrq(pywt.cwt.mexican_hat(256)) +#print pywt.cwt.mexican_hat(256)[0] + +import pylab +pylab.plot(*pywt.cwt.cgauss1(256)[::-1]) +pylab.show() + +#for i in pywt.wavelist(): +# print "%s = %s" % (i.replace('.', '_'), centfrq(i, 8)) + + +#print orthfilt([1,2,3,4,5,6]) diff --git a/doc/cwt.rst b/doc/cwt.rst new file mode 100644 index 00000000..908aa166 --- /dev/null +++ b/doc/cwt.rst @@ -0,0 +1,51 @@ +Continuous Wavelet Transform (CWT) +---------------------------------- + +cwavelist() +~~~~~~~~~~~ + + +.. _`CWavelet`: + +Continuous Wavelet - ``CWavelet`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +:: + + CWavelet(name, psi=None, properties={}) + + +Continuous Wavelet Transform with ``cwt`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +1D Continuous Wavelet Transform + +:: + + coeffs = cwt(data, wavelet, scales, data_step=1.0, precision=10) + +data + 1D input data + +wavelet + Wavelet *name*, CWavelet_ object or Wavelet_ object. + + For convenience, a pair of `(function_approximation, x_grid)` arrays can aslo + be used here:: + + coeffs cwt(data, (psi, x_grid), scales, data_step=1.0, precision=10) + +scales + A list of scales at which to perform the CWT. + +data_step + The distance between two neighbour points on the x-axis. + +precision + Applicable only when wavelet *name*, *CWavelet* object or *Wavelet* object is + passed as the *wavelet* parameter and is used to calculate the wavelet function + approximation. + +The function returns a list of coefficients arrays, one for every scale value +in *scales*. + diff --git a/pywt/__init__.py b/pywt/__init__.py index efcec1a1..b0d47bb9 100644 --- a/pywt/__init__.py +++ b/pywt/__init__.py @@ -14,7 +14,9 @@ from multilevel import * from multidim import * from wavelet_packets import * +from functions import * import thresholding +import cwt from release_details import version as __version__, author as __author__, license as __license__ __all__ = [] @@ -22,6 +24,7 @@ __all__ += wavelet_packets.__all__ __all__ += multilevel.__all__ __all__ += multidim.__all__ -__all__ += ['thresholding'] +__all__ += functions.__all__ +__all__ += ['thresholding', 'cwt'] del multilevel, multidim, wavelet_packets diff --git a/pywt/continuous_wavelets.py b/pywt/continuous_wavelets.py new file mode 100644 index 00000000..7235c037 --- /dev/null +++ b/pywt/continuous_wavelets.py @@ -0,0 +1,387 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2006-2008 Filip Wasilewski +# See COPYING for license details. + +# $Id: $ + +""" +Continuous Wavelets definitions. +""" + +__all__ = [ + 'mexican_hat', 'morlet', + 'gauss1', 'gauss2', 'gauss3', + 'cgauss1', 'cgauss2', + 'cmorlet', 'cshannon', 'cfbsp', +] +__all__ += ['cwavelist'] + +from math import sqrt, pi +from numerix import cos, exp, sinc +from numerix import linspace + +""" +class WaveletFunction(object): + def __init__(self, lower_bound, upper_bound): + self.lower_bound = lower_bound + self.upper_bound = upper_bound + + def get_xgrid(self, points): + return linspace(self.lower_bound, self.upper_bound, points) + + @classmethod + def for_name(cls, name): + pass + +class MexicanHat(WaveletFunction): + def __init__(self, lower_bound=-8.0, upper_bound=8.0): + super(MexicanHat, self).__init__(lower_bound, upper_bound) + + self.properties = { + "family": "Mexican Hat", "orthogonal": False, "biorthogonal": False, + "compact support": False, "complex": False,"support": "infinite", + "effective support": [-5, 5], "symmetry": "symmetric" + } + + def __call__(self, points): + x = self.get_xgrid(points) + x2 = x*x + psi = -x2 + psi += 1 + psi *= (2.0/sqrt(3)*(pi**-0.25)) + psi *= exp(-0.5 * x2) + return psi, x + + +class Morlet(WaveletFunction): + def __init__(self, lower_bound=-8.0, upper_bound=8.0): + super(Morlet, self).__init__(lower_bound, upper_bound) + + def __call__(self, points): + x = self.get_xgrid(points) + minus_half_x2 = x*x + minus_half_x2 *= -0.5 + psi = exp(minus_half_x2) + psi *= cos(5*x) + return psi, x + +class GaussDerivative(WaveletFunction): + def __init__(self, lower_bound=-5.0, upper_bound=5.0): + super(GaussDerivative, self).__init__(lower_bound, upper_bound) + assert order in [1, 2, 3] + + self.order = order + self.psi, self.properties = { + 1: (self.gauss1, { + "family": "Gaussian", "orthogonal": False, "biorthogonal": False, + "compact support": False, "complex": False, "support": "infinite", + "effective support": [-5, 5], "symmetry": "symmetric" + } + ), + 2: (self.gauss2, { + "family": "Gaussian", "orthogonal": False, "biorthogonal": False, + "compact support": False, "complex": False, "support": "infinite", + "effective support": [-5, 5], "symmetry": "anti-symmetric" + } + ), + 3: (self.gauss3, { + "family": "Gaussian", "orthogonal": False, "biorthogonal": False, + "compact support": False, "complex": False, "support": "infinite", + "effective support": [-5, 5], "symmetry": "symmetric" + } + ), + }[self.order] + + def __call__(self, points): + return self.psi(points, self.lower_bound, self.upper_bound) + + def gauss1(self, points, lower_bound, upper_bound): + x = self.get_xgrid(points) + x2 = x*x + psi = exp(-x2) + psi *= x + psi *= -2*(2.0/pi)**0.25 + return psi, x + + def gauss2(self, points, lower_bound, upper_bound): + x = self.get_xgrid(points) + x2 = x*x + psi = exp(-x2) + psi *= -1+2*x2 + psi *= -2.0/(3**0.25) * (2.0/pi)**0.25 + return psi, x + + def gauss3(self, points, lower_bound, upper_bound): + x = self.get_xgrid(points) + x2 = x*x + psi = exp(-x2) + psi *= x + psi *= (3-2*x2) + psi *= (2.0/pi)**0.25 * -4.0/(15**0.5) + return psi, x + +class ComplexMorlet(WaveletFunction): + def __init__(self, bandwidth, center_frequency, lower_bound=-8.0, upper_bound=8.0): + super(ComplexMorlet, self).__init__(lower_bound, upper_bound) + assert bandwidth > 0 + assert center_frequency > 0 + self.bandwidth = bandwidth + self.center_frequency = center_frequency + + def __call__(self, points): + x = self.get_xgrid(points) + a = x*x + a *= -1.0/self.bandwidth + psi = 1.0 / sqrt(pi * self.bandwidth) * exp(2j*pi*self.center_frequency*x) + psi *= exp(a) + return psi, x + +class ComplexFrequencyBSpline(WaveletFunction): + def __init__(self, order, bandwidth, center_frequency, lower_bound=-8.0, upper_bound=8.0): + super(ComplexFrequencyBSpline, self).__init__(lower_bound, upper_bound) + assert order > 0 + assert bandwidth > 0 + assert center_frequency > 0 + self.order = order + self.bandwidth = bandwidth + self.center_frequency = center_frequency + + def __call__(self, points): + x = self.get_xgrid(points) + psi = exp(2j*pi*center_frequency*x) + psi *= sinc(bandwidth/order*x)**order + psi *= sqrt(bandwidth) + return psi, x + +def cfbsp(points, order, bandwidth, center_frequency, lower_bound=-8.0, upper_bound=8.0): + order, bandwidth, center_frequency = int(order), float(bandwidth), float(center_frequency) + assert bandwidth > 0 + assert center_frequency > 0 +cfbsp.params = [int, float, float] + +def cshannon(points, bandwidth, center_frequency, lower_bound=-8.0, upper_bound=8.0): + bandwidth, center_frequency = float(bandwidth), float(center_frequency) + assert bandwidth > 0 + assert center_frequency > 0 + x = self.get_xgrid(points) + psi = exp(2j*pi*center_frequency*x) + psi *= sinc(bandwidth*x) + psi *= sqrt(bandwidth) + return psi, x +cshannon.params = [float, float] + +""" + + +def mexican_hat(points, lower_bound=-8.0, upper_bound=8.0): + x = linspace(lower_bound, upper_bound, points) + x2 = x*x + psi = -x2 + psi += 1 + psi *= (2.0/sqrt(3)*(pi**-0.25)) + psi *= exp(-0.5 * x2) + return psi, x + +def morlet(points, lower_bound=-8.0, upper_bound=8.0): + x = linspace(lower_bound, upper_bound, points) + minus_half_x2 = x*x + minus_half_x2 *= -0.5 + psi = exp(minus_half_x2) + psi *= cos(5*x) + return psi, x + +def gauss1(points, lower_bound=-5.0, upper_bound=5.0): + x = linspace(lower_bound, upper_bound, points) + x2 = x*x + psi = exp(-x2) + psi *= x + psi *= -2*(2.0/pi)**0.25 + return psi, x + +def gauss2(points, lower_bound=-5.0, upper_bound=5.0): + x = linspace(lower_bound, upper_bound, points) + x2 = x*x + psi = exp(-x2) + psi *= -1+2*x2 + psi *= -2.0/(3**0.25) * (2.0/pi)**0.25 + return psi, x + +def gauss3(points, lower_bound=-5.0, upper_bound=5.0): + x = linspace(lower_bound, upper_bound, points) + x2 = x*x + psi = exp(-x2) + psi *= x + psi *= (3-2*x2) + psi *= (2.0/pi)**0.25 * -4.0/(15**0.5) + return psi, x + +def cgauss1(points, lower_bound=-5.0, upper_bound=5.0): + x = linspace(lower_bound, upper_bound, points) + x2 = x*x + psi = (-1j-2*x) + psi *= exp(-x2) + psi *= sqrt(2) / sqrt(exp(-0.5) * sqrt(2) * sqrt(pi)) + psi *= exp(-1j*x) + return psi, x + +def cgauss2(points, lower_bound=-5.0, upper_bound=5.0): + x = linspace(lower_bound, upper_bound, points) + x2 = x*x + psi = 1j*(exp(-x2) * exp(-1j*x)) + tmp = 4j*x + x2 *= 4 + tmp += x2 + tmp -= 3 + psi *= tmp + psi *= sqrt(6) * (1.0/sqrt(exp(-0.5) * sqrt(2) * sqrt(pi))) / 3 + return psi, x + +def cmorlet(points, bandwidth, center_frequency, lower_bound=-8.0, upper_bound=8.0): + assert bandwidth > 0 + assert center_frequency > 0 + x = linspace(lower_bound, upper_bound, points) + x2b = x*x + x2b *= -1.0/bandwidth + psi = 1.0 / sqrt(pi * bandwidth) * exp(2j*pi*center_frequency*x) + psi *= exp(x2b) + return psi, x + +def cshannon(points, bandwidth, center_frequency, lower_bound=-8.0, upper_bound=8.0): + assert bandwidth > 0 + assert center_frequency > 0 + x = linspace(lower_bound, upper_bound, points) + psi = exp(2j*pi*center_frequency*x) + psi *= sinc(bandwidth*x) + psi *= sqrt(bandwidth) + return psi, x + +def cfbsp(points, order, bandwidth, center_frequency, lower_bound=-8.0, upper_bound=8.0): + assert order > 0 + assert bandwidth > 0 + assert center_frequency > 0 + x = linspace(lower_bound, upper_bound, points) + psi = exp(2j*pi*center_frequency*x) + psi *= sinc(bandwidth/order*x)**order + psi *= sqrt(bandwidth) + return psi, x + +wavelet_functions = { + #real continuous wavelet functions + "mexican_hat": (mexican_hat, { + "family": "Mexican Hat", "orthogonal": False, "biorthogonal": False, + "compact support": False, "complex": False,"support": "infinite", + "effective support": (-5, 5), "symmetry": "symmetric" + }, None + ), + "morlet": (morlet, { + "family": "Morlet", "orthogonal": False, "biorthogonal": False, + "compact support": False, "complex": False, "support": "infinite", + "effective support": (-4, 4), "symmetry": "symmetric" + }, None + ), + "gauss1": (gauss1, { + "family": "Gaussian", "orthogonal": False, "biorthogonal": False, + "compact support": False, "complex": False, "support": "infinite", + "effective support": (-5, 5), "symmetry": "symmetric" + }, None + ), + "gauss2": (gauss2, { + "family": "Gaussian", "orthogonal": False, "biorthogonal": False, + "compact support": False, "complex": False, "support": "infinite", + "effective support": (-5, 5), "symmetry": "anti-symmetric" + }, None + ), + "gauss3": (gauss3, { + "family": "Gaussian", "orthogonal": False, "biorthogonal": False, + "compact support": False, "complex": False, "support": "infinite", + "effective support": (-5, 5), "symmetry": "symmetric" + }, None + ), + #complex continuous wavelets + #"cgauss1": (cgauss1, { + # "family": "Complex Gaussian", "orthogonal": False, "biorthogonal": False, + # "compact support": False, "complex": True, "support": "infinite", + # "symmetry": "symmetric" + # } + # ), + #"cgauss2": (cgauss2, { + # "family": "Complex Gaussian", "orthogonal": False, "biorthogonal": False, + # "compact support": False, "complex": True, "support": "infinite", + # "symmetry": "anti-symmetric" + # } + # ), + "cmorlet": (cmorlet, { + "family": "Complex Morlet", "orthogonal": False, "biorthogonal": False, + "compact support": False, "complex": True, "support": "infinite" + }, (float, float) + ), + "cshannon": (cshannon, { + "family": "Shannon", "orthogonal": False, "biorthogonal": False, + "compact support": False, "complex": True, "support": "infinite" + }, (float, float) + ), + "cfbsp": (cfbsp, { + "family": "Frequency B-Spline", "orthogonal": False, "biorthogonal": False, + "compact support": False, "complex": True, "support": "infinite" + }, (int, float, float) + ), + } + +""" +def name_to_wavelet_and_params(name): + if name in wavelet_functions: + psi, properties = wavelet_functions[self.name] + return psi, properties, () + elif '-' in name: + for basename in wavelet_functions: + if wavelet_name.startswith(basename): + psi, properties = wavelet_functions[basename] + params = name[len(basename):].split('-') + if not hasattr(psi, 'params_count') or psi.params_count != len(params): + raise ValueError("Invalid parameter string in name: %s. Expected %d hyphen-delimited params." % \ + (name[len(basename):], psi.params_count)) + return psi, properties, params + raise ValueError("Invalid continuous wavelet name: %s." % name) +""" + +def format_wavelet_name(basename, params_spec=[]): + if params_spec: + (basename + '-'.join(["(%s)" % param.__name__ for param in params_spec])) + return basename + +def function_for_name(name): + if name in wavelet_functions: + psi, properties, params_spec = wavelet_functions[name] + if params_spec: + raise ValueError("Invalid wavelet name - '%s'. Missing params part for wavelet '%s'." % \ + format_wavelet_name(name, params_spec)) + else: + psi = None + for basename in wavelet_functions: + if name.startswith(basename): + _psi, properties, params_spec = wavelet_functions[basename] + if not params_spec: + raise ValueError("Invalid wavelet name - '%s'. No params expected for wavelet '%s'." % \ + (name, basename)) + params = name[len(basename):].split('-') + if len(params_spec) != len(params): + raise ValueError("Invalid wavelet name - '%s'. Expected %d parameters for '%s' wavelet, got %d instead." % \ + (name, len(params_spec), format_wavelet_name(basename, params_spec), len(params))) + try: + params = [type(value) for type, value in zip(params_spec, params)] + except ValueError: + raise ValueError("Invalid wavelet name - '%s'. Cannot convert parameter '%s' to type '%s'." % \ + (value, type.__name__)) + params = tuple(params) + + def psi(points, **kwds): + return _psi(points, *params, **kwds) + if psi is None: + raise ValueError("Invalid wavelet name - '%s'." % name) + + return psi, properties.copy() + + +def cwavelist(): + return [format_wavelet_name(basename, spec[2]) for basename, spec in wavelet_functions.items()] diff --git a/pywt/cwt.py b/pywt/cwt.py new file mode 100644 index 00000000..20cbafd6 --- /dev/null +++ b/pywt/cwt.py @@ -0,0 +1,119 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2006-2008 Filip Wasilewski +# See COPYING for license details. + +# $Id: $ + +""" +Continuous Wavelet Transform module. +""" + +__all__ = ['cwt', 'CWavelet', 'cwavelist'] + +from math import sqrt, floor + +from numerix import asarray, linspace +from numerix import intp +from numerix import concatenate, keep +from numerix import convolve, diff + + +from continuous_wavelets import cwavelist, function_for_name +from continuous_wavelets import * + + +class CWavelet(object): + """ + + """ + def __init__(self, name, psi=None, properties={}): + self.name = name.lower() + if psi is not None: + self.psi, self.properties = psi, properties + else: + self.psi, self.properties = function_for_name(name) + + def wavefun(self, iter=10, points=None, lower_bound=None, upper_bound=None): + if points is None: + assert iter > 0 + points = 2**iter + assert points > 0 + + kwds = {} + if lower_bound is not None: + kwds["lower_bound"] = lower_bound + if upper_bound is not None: + kwds["upper_bound"] = upper_bound + + return self.psi(points, **kwds) + + +def cwt(data, wavelet, scales, data_step=1.0, precision=10): + """ + cwt(data, wavelet, scales, data_step=1, precision=10) + + 1D Continuous Wavelet Transform + + data - 1D input data + wavelet - Wavelet name, CWavelet object or Wavelet object. + For convenience, a pair of (function_approximation, x_grid) + arrays can aslo be passed. + scales - List of scales at which the CWT will be computed. + Each scale must me in range (0 < scale < len(data)/2). + data_step - The distance between two neighbour points on the x-axis. + precision - Applicable only when wavelet *name*, *CWavelet* object + or *Wavelet* object is passed as the *wavelet* parameter + and is used to calculate the wavelet function approximation. + """ + from functions import intwave + + data = asarray(data) + if len(data.shape) != 1: + raise ValueError("Expected 1D array, got %dD." % len(data.shape)) + if isinstance(scales, (int, float)): + scales = [scales] + elif not isinstance(scales, (list, tuple)): + raise ValueError("Scales parameter must be a list or tuple of ints or floats, not %s." % type(scales)) + if not len(scales): + raise ValueError("Scales parameter must be non-empty list of ints or floats.") + if not (max(scales) < len(data)/2 and min(scales) > 0): + raise ValueError("Scales values must be in range (0 < scale < len(data)/2).") + + # integrate wavelet function + if isinstance(wavelet, tuple): + psi, psi_x = wavelet + assert len(psi) == len(psi_x) + intwavefun, intwavefun_x = intwave((psi, psi_x)) + else: + _intwavefun = intwave(wavelet, precision) + intwavefun, intwavefun_x = _intwavefun[0], _intwavefun[-1] + del _intwavefun + + intwavefun_step = intwavefun_x[1]-intwavefun_x[0] # xgrid step + intwavefun_x -= intwavefun_x[0] # shift xgrid to start in 0 point + assert intwavefun_x[-1] > 0 + + coeffs = [] + + for scale in scales: + scale = float(scale) + resampled_intwavefun = resample(intwavefun, intwavefun_x, scale, data_step) # scale integrated wavelet function + conv = convolve(data, resampled_intwavefun[::-1]) # match data against scaled function + d = diff(conv) # compute 1st derivative from coefficients + c = keep(d, len(data)) # keep only coefficients in range + c *= -sqrt(scale) # normalize coefficients according to scale + coeffs.append(c) + + return coeffs + + +def resample(function, xgrid, scale, data_step): + """Resample `function` defined on `xgrix` [0, x] using `scale`. + """ + step = int(floor(scale / data_step * xgrid[-1])) + 1 + resampled = function[linspace(0, (len(xgrid)-1), step).astype(intp)] + assert len(resampled) > 0 + if len(resampled) == 1: + resampled = concatenate([resampled, resampled]) + return resampled diff --git a/pywt/functions.py b/pywt/functions.py new file mode 100644 index 00000000..92980778 --- /dev/null +++ b/pywt/functions.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2006-2008 Filip Wasilewski +# See COPYING for license details. + +# $Id: $ + +""" +Other wavelet related functions. +""" + +__all__ = ["intwave", "centfrq", "scal2frq", "qmf", "orthfilt"] + +from math import sqrt + +from _pywt import Wavelet +from cwt import CWavelet +from helpers import wavelet_for_name + +from numerix import asarray, array, float64 +from numerix import integrate +from numerix import argmax, mean +from numerix import fft + + +def intwave(wavelet, precision=8): + """ + intwave(wavelet, precision=8) -> [int_psi, x] - for orthogonal wavelets + intwave(wavelet, precision=8) -> [int_psi_d, int_psi_r, x] - for other wavelets + intwave((function_approx, x), precision=8) -> [int_function, x] - for (function approx., x grid) pair + + Integrate *psi* wavelet function from -Inf to x using the rectangle + integration method. + + wavelet - Wavelet to integrate (Wavelet object, wavelet name string + or (wavelet function approx., x grid) pair) + + precision = 8 - Precision that will be used for wavelet function + approximation computed with the wavefun(level=precision) + Wavelet's method. + + (function_approx, x) - Function to integrate on the x grid. Used instead + of Wavelet object to allow custom wavelet functions. + """ + + if isinstance(wavelet, tuple): + psi, x = asarray(wavelet[0]), asarray(wavelet[1]) + step = x[1] - x[0] + return integrate(psi, step), x + + else: + if not isinstance(wavelet, (Wavelet, CWavelet)): + wavelet = wavelet_for_name(wavelet) + + functions_approximations = wavelet.wavefun(precision) + if len(functions_approximations) == 2: # continuous wavelet + psi, x = functions_approximations + step = x[1] - x[0] + return integrate(psi, step), x + elif len(functions_approximations) == 3: # orthogonal wavelet + phi, psi, x = functions_approximations + step = x[1] - x[0] + return integrate(psi, step), x + else: # biorthogonal wavelet + phi_d, psi_d, phi_r, psi_r, x = functions_approximations + step = x[1] - x[0] + return integrate(psi_d, step), integrate(psi_r, step), x + + +def centfrq(wavelet, precision=8): + """ + centfrq(wavelet, precision=8) -> float - for orthogonal wavelets + centfrq((function_aprox, x), precision=8) -> float - for (function approx., x grid) pair + + Computes the central frequency of the *psi* wavelet function. + + wavelet - Wavelet (Wavelet object, wavelet name string + or (wavelet function approx., x grid) pair) + precision = 8 - Precision that will be used for wavelet function + approximation computed with the wavefun(level=precision) + Wavelet's method. + + (function_approx, xgrid) - Function defined on xgrid. Used instead + of Wavelet object to allow custom wavelet functions. + """ + + if isinstance(wavelet, tuple): + psi, x = asarray(wavelet[0]), asarray(wavelet[1]) + else: + if not isinstance(wavelet, (Wavelet, CWavelet)): + wavelet = wavelet_for_name(wavelet) + functions_approximations = wavelet.wavefun(precision) + + if len(functions_approximations) == 2: + psi, x = functions_approximations + else: + psi, x = functions_approximations[1], functions_approximations[-1] # (psi, x) for (phi, psi, x) and (psi_d, x) for (phi_d, psi_d, phi_r, psi_r, x) + + domain = float(x[-1] - x[0]) + assert domain > 0 + + index = argmax(abs(fft(psi)[1:]))+2 + if index > len(psi)/2: + index = len(psi)-index+2 + + return 1.0/(domain/(index-1)) + + +def scal2frq(wavelet, scale, delta, precision=8): + """ + scal2frq(wavelet, scale, delta, precision=8) -> float - for orthogonal wavelets + scal2frq(wavelet, scale, delta, precision=8) -> float - for (function approx., x grid) pair + + wavelet + scale + delta - sampling + """ + return centfrq(wavelet, precision=precision)/(scale*delta) + + +def qmf(filter): + filter = array(filter)[::-1] + filter[1::2] = -filter[1::2] + return filter + + +def orthfilt(scaling_filter): + assert len(scaling_filter) % 2 == 0 + + scaling_filter = asarray(scaling_filter, dtype=float64) + + rec_lo = sqrt(2) * scaling_filter / sum(scaling_filter) + dec_lo = rec_lo[::-1] + + rec_hi = qmf(rec_lo) + dec_hi = rec_hi[::-1] + + return (dec_lo, dec_hi, rec_lo, rec_hi) diff --git a/pywt/helpers.py b/pywt/helpers.py new file mode 100644 index 00000000..2349d3ca --- /dev/null +++ b/pywt/helpers.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2006-2008 Filip Wasilewski +# See COPYING for license details. + +# $Id: $ + +""" +""" + + +from _pywt import Wavelet +from cwt import CWavelet + +def wavelet_for_name(name): + if not isinstance(name, basestring): + raise TypeError("Wavelet name must be a string, not %s" % type(name)) + try: + wavelet = Wavelet(name) + except ValueError: + try: + wavelet = CWavelet(name) + except: + raise + #raise ValueError("Invalid wavelet name - %s." % name) + return wavelet