连续小波变换算法的C++实现

头文件:

/*
 * Copyright (c) 2008-2011 Zhang Ming (M. Zhang), [email protected]
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by the
 * Free Software Foundation, either version 2 or any later version.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 *    this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
 * more details. A copy of the GNU General Public License is available at:
 * http://www.fsf.org/licensing/licenses
 */


/*****************************************************************************
 *                                   cwt.h
 *
 * Continuous Wavelet Transform.
 *
 * Class for continuous wavelet transform, which is designed for computing
 * the continuous wavelet transform and its inverse transform of 1D signals.
 * For now this class only supports the "Mexican hat" (real) and "Morlet"
 * (complex) wavelets.
 *
 * The mother wavelet and the scale parameters are specified by users. The
 * inverse transform can not achieve perfect reconstruction, but it reaches
 * sufficient precision in practice. Of course you can improve the accuracy
 * by extending the range of the scale parameter.
 *
 * Zhang Ming, 2010-03, Xi'an Jiaotong University.
 *****************************************************************************/


#ifndef CWT_H
#define CWT_H


#include <cstdlib>
#include <string>
#include <complex>
#include <fft.h>
#include <matrix.h>


namespace splab
{

    /**
     * Continuous wavelet transform of 1D signals, parameterized on the
     * sample type. The wavelet is chosen by name at construction and
     * setScales() fills the set of scales used by the transforms.
     */
    template <typename Type>
    class CWT
    {

    public:

        // "name" selects the mother wavelet: "mexiHat" or "morlet"
        CWT( const string &name );
        ~CWT();

        // derive the scales from the sampling rate "fs", the frequency
        // bounds "fmin"/"fmax" and the step "dj" of the power-of-two exponent
        void setScales( Type fs, Type fmin, Type fmax, Type dj=0.25 );

        // forward / inverse transform producing real coefficients
        Matrix<Type> cwtR( const Vector<Type> &signal );
        Vector<Type> icwtR( const Matrix<Type> &coefs );

        // forward / inverse transform producing complex coefficients
        Matrix< complex<Type> > cwtC( const Vector<Type> &signal );
        Vector<Type> icwtC( const Matrix< complex<Type> > &coefs );

    private:

        // fill "table" for signals of length N and refresh "delta"
        void setTable( int N );

        // normalization constant computed from "table" and "scales"
        Type constDelta();

        string waveType;        // wavelet name passed to the constructor
        Type delta;             // result of constDelta(), set by setTable()
        Vector<Type> scales;    // scales generated by setScales()
        Matrix<Type> table;     // one row per scale, built by setTable()

    };
	// class CWT

    #include <cwt-impl.h>

}
// namespace splab


#endif
// CWT_H

实现文件:

/*
 * Copyright (c) 2008-2011 Zhang Ming (M. Zhang), [email protected]
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by the
 * Free Software Foundation, either version 2 or any later version.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 *    this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
 * more details. A copy of the GNU General Public License is available at:
 * http://www.fsf.org/licensing/licenses
 */


/*****************************************************************************
 *                               cwt-impl.h
 *
 * Implementation for CWT class.
 *
 * Zhang Ming, 2010-03, Xi'an Jiaotong University.
 *****************************************************************************/


/**
 * Constructor and destructor.
 *
 * The constructor stores the wavelet name and accepts only "mexiHat" and
 * "morlet"; any other name prints a message to cerr and terminates the
 * program with exit status 1.
 */
template<typename Type>
CWT<Type>::CWT( const string &name ) : waveType(name)
{
    const bool supported = (waveType == "mexiHat") || (waveType == "morlet");

    if( !supported )
    {
        cerr << "No such wavelet type!" << endl;
        exit(1);
    }
}

// Nothing to release explicitly; the members clean up after themselves.
template<typename Type>
CWT<Type>::~CWT()
{
}


/**
 * Set the scales on which the CWT is computed. The frequency of the wavelet
 * family generated from the mother wavelet should cover the frequency
 * band of the signal for better reconstruction.
 *
 * The scales are a^j for every integer j in [jMin, jMax], with a = 2^dj.
 * jMin and jMax are derived from fs/fmax and fs/fmin together with two
 * wavelet-dependent constants (flc, fuc) -- presumably the lower and upper
 * cutoff of the mother wavelet's band in normalized frequency; TODO confirm.
 *
 * fs    sampling rate of the signal (must be positive)
 * fmin  lowest frequency to be covered (must be positive)
 * fmax  highest frequency to be covered (must be positive)
 * dj    step of the power-of-two exponent between scales (must be positive)
 *
 * Invalid arguments print a message to cerr and terminate with exit(1),
 * matching the error handling of the constructor.
 */
template<typename Type>
void CWT<Type>::setScales( Type fs, Type fmin, Type fmax, Type dj )
{
    // written as !(x > 0) so that NaN arguments are rejected as well;
    // otherwise log() and the division by dj below yield NaN/inf, whose
    // conversion to int is undefined
    if( !(fs > Type(0)) || !(fmin > Type(0)) ||
        !(fmax > Type(0)) || !(dj > Type(0)) )
    {
        cerr << "The sampling rate, frequency bounds and scale step "
             << "must be positive!" << endl;
        exit(1);
    }

    double  flc = 0,
            fuc = 0,
            a = pow( Type(2), dj );

    if( waveType == "mexiHat" )
    {
        flc = 0.01;
        fuc = 0.7;
    }
    else if( waveType == "morlet" )
    {
        flc = 0.4;
        fuc = 1.2;
    }

	int jMin = int( ceil( log(fs*flc/fmax)/log(2.0) / dj ) ),
        jMax = int( ceil( log(fs*fuc/fmin)/log(2.0) / dj ) ),
        J = jMax-jMin+1;

    // fmin far above fmax leaves no exponent in range; resizing "scales"
    // to a non-positive length would be meaningless
    if( J < 1 )
    {
        cerr << "No scale fits the given frequency range!" << endl;
        exit(1);
    }

    scales.resize(J);
	for( int j=0; j<J; ++j )
		scales[j] = Type(pow(a,double(j+jMin)));

    // The transforms rebuild "table" only when its dimensions change, so a
    // new scale set with the same count would silently keep using the table
    // and "delta" computed from the old scales. Rebuild them now for the
    // signal length the table was last built for.
    if( table.cols() > 0 )
        setTable( table.cols() );
}


/**
 * Build the table used by the forward and backward transforms for signals
 * of length N, then refresh "delta" from it.
 *
 * Row j holds N frequency-domain samples of the mother wavelet scaled by
 * "scales[j]":
 *   mexiHat:  w^2 * exp(-w^2/2)
 *   morlet:   exp(-(w-6.2)^2/2)
 * where w is the scaled angular frequency of each sample. Every row is
 * multiplied by sqrt(2*PI*scales[j]/N) for normalization.
 */
template <typename Type>
void CWT<Type>::setTable( int N )
{
    const int numScales = scales.size();
    const int half = N/2;
    Vector<Type> omega(N), row(N);
    table.resize( numScales, N );

    for( int j=0; j<numScales; ++j )
    {
        // fundamental frequency of the N-point grid, times the scale
        const Type c = Type( 2*PI*scales[j]/N );
        const Type gain = Type( sqrt(c) );

        // samples past N/2 are mapped to negative frequencies
        for( int i=0; i<N; ++i )
            omega[i] = ( i <= half ) ? c*i : c*(i-N);

        if( waveType == "mexiHat" )
        {
            // omega now holds w^2 element by element
            omega *= omega;
            row = gain * ( omega * exp(Type(-0.5)*omega) );
        }
        else if( waveType == "morlet" )
        {
            const Type sigma = Type(6.2);

            // omega now holds the exponent -(w-sigma)^2/2
            omega = Type(-0.5) * ( (omega-sigma)*(omega-sigma) );
            row = gain * exp(omega);
        }
        else
        {
            cerr << "No such wavelet type!" << endl;
            exit(1);
        }

        table.setRow( row, j );
    }

    delta = constDelta();
}


/**
 * Compute the delta constant, which comes from substituting the delta
 * function for the reconstruction function.
 *
 * Returns the sum over all scales of (sum of table row j) / sqrt(scales[j]).
 */
template <typename Type>
Type CWT<Type>::constDelta()
{
    const int numScales = scales.size();
    const int numCols = table.cols();
    Type total = 0;

    for( int j=0; j<numScales; ++j )
    {
        // add up the table row that belongs to scales[j]
        Type rowSum = 0;
        for( int k=0; k<numCols; ++k )
            rowSum += table[j][k];

        total += Type( rowSum / sqrt(scales[j]) );
    }

    return total;
}


/**
 * Compute the continuous wavelet transform with a complex mother wavelet.
 * This is a fast algorithm that uses the FFT by way of the convolution
 * theorem: the DFT of the signal is multiplied by each table row and
 * transformed back.
 *
 * Returns a matrix with one row per scale and one column per sample.
 */
template <typename Type>
Matrix< complex<Type> > CWT<Type>::cwtC( const Vector<Type> &signal )
{
    const int numSamples = signal.size();
    const int numScales = scales.size();

    // rebuild the table whenever its dimensions do not fit this call
    const bool tableFits = (table.cols() == numSamples) &&
                           (table.rows() == numScales);
    if( !tableFits )
        setTable(numSamples);

    Matrix< complex<Type> > coefs(numScales, numSamples);

    // spectrum of the input signal, shared by all scales
    Vector< complex<Type> > sigDFT = fft( signal );

    Vector< complex<Type> > row(numSamples), rowDFT(numSamples);

    for( int j=0; j<numScales; ++j )
    {
        // spectrum of the coefficients at scales[j]
        for( int k=0; k<numSamples; ++k )
            rowDFT[k] = sigDFT[k]*table[j][k];

        row = ifft( rowDFT );
        coefs.setRow( row, j );
    }

    return coefs;
}


/**
 * Compute the inverse CWT for a complex mother wavelet. The redundancy of
 * the CWT makes it possible to reconstruct the signal using a different
 * wavelet, the easiest of which is the delta function. In this case the
 * reconstructed signal is the sum over all scales of the real part of the
 * coefficients, each divided by sqrt(scale), finally multiplied by
 * N/delta.
 *
 * coefs  one row per entry of "scales", one column per sample
 *
 * A row count different from the number of scales prints a message to cerr
 * and terminates with exit(1).
 */
template <typename Type>
Vector<Type> CWT<Type>::icwtC( const Matrix< complex<Type> > &coefs )
{
    int J = coefs.rows(),
        N = coefs.cols();

    // every row is paired with scales[j]; a mismatch would index "scales"
    // out of range
    if( J != int(scales.size()) )
    {
        cerr << "The coefficients do not match the current scales!" << endl;
        exit(1);
    }

    // "delta" is assigned only by setTable(); without this check it would
    // be read uninitialized when no forward transform has run on this
    // object, or be stale for a different signal length
    if( (table.cols() != N) || (table.rows() != J) )
        setTable(N);

    // NOTE(review): the accumulation below relies on Vector(N) starting
    // out as all zeros -- confirm against the Vector constructor
    Vector<Type> signal(N);

    // recover "signal"; walking row by row keeps the per-sample summation
    // order (j ascending) while computing each square root only once
    for( int j=0; j<J; ++j )
    {
        const Type norm = Type(sqrt(scales[j]));
        for( int i=0; i<N; ++i )
            signal[i] += real(coefs[j][i]) / norm;
    }

    signal = signal*Type(N) / delta;
    return signal;
}


/**
 * Calculate the continuous wavelet transform with a real mother wavelet.
 * The DFT of the signal is multiplied by each table row and transformed
 * back with the complex-to-real inverse FFT.
 *
 * Returns a matrix with one row per scale and one column per sample.
 */
template <typename Type>
Matrix<Type> CWT<Type>::cwtR( const Vector<Type> &signal )
{
    const int numSamples = signal.size();
    const int numScales = scales.size();

    // rebuild the table whenever its dimensions do not fit this call
    const bool tableFits = (table.cols() == numSamples) &&
                           (table.rows() == numScales);
    if( !tableFits )
        setTable(numSamples);

    Matrix<Type> coefs(numScales, numSamples);

    // spectrum of the input signal, shared by all scales
    Vector< complex<Type> > sigDFT = fft( signal );

    Vector<Type> row(numSamples);
    Vector< complex<Type> > rowDFT(numSamples);

    for( int j=0; j<numScales; ++j )
    {
        // spectrum of the coefficients at scales[j]
        for( int k=0; k<numSamples; ++k )
            rowDFT[k] = sigDFT[k]*table[j][k];

        row = ifftc2r( rowDFT );
        coefs.setRow( row, j );
    }

    return coefs;
}


/**
 * Calculate the inverse CWT for a real mother wavelet: the sum over all
 * scales of the coefficients, each divided by sqrt(scale), finally
 * multiplied by N/delta.
 *
 * coefs  one row per entry of "scales", one column per sample
 *
 * A row count different from the number of scales prints a message to cerr
 * and terminates with exit(1).
 */
template <typename Type>
Vector<Type> CWT<Type>::icwtR( const Matrix<Type> &coefs )
{
    int J = coefs.rows(),
        N = coefs.cols();

    // every row is paired with scales[j]; a mismatch would index "scales"
    // out of range
    if( J != int(scales.size()) )
    {
        cerr << "The coefficients do not match the current scales!" << endl;
        exit(1);
    }

    // "delta" is assigned only by setTable(); without this check it would
    // be read uninitialized when no forward transform has run on this
    // object, or be stale for a different signal length
    if( (table.cols() != N) || (table.rows() != J) )
        setTable(N);

    // NOTE(review): the accumulation below relies on Vector(N) starting
    // out as all zeros -- confirm against the Vector constructor
    Vector<Type> signal(N);

    // recover "signal"; walking row by row keeps the per-sample summation
    // order (j ascending) while computing each square root only once
    for( int j=0; j<J; ++j )
    {
        const Type norm = Type(sqrt(scales[j]));
        for( int i=0; i<N; ++i )
            signal[i] += coefs[j][i] / norm;
    }

    signal = signal*Type(N) / delta;
    return signal;
}

测试代码:

/*****************************************************************************
 *                               cwt_test.cpp
 *
 * Continuous wavelet transform testing.
 *
 * Zhang Ming, 2010-03, Xi'an Jiaotong University.
 *****************************************************************************/


#define BOUNDS_CHECK

#include <iostream>
#include <vectormath.h>
#include <statistics.h>
#include <timing.h>
#include <cwt.h>


using namespace std;
using namespace splab;


const   int     Ls = 1000;
const   double  fs = 1000.0;


int main()
{

	/******************************* [ signal ] ******************************/
	Vector<double> t = linspace( 0.0, (Ls-1)/fs, Ls );
	Vector<double> st = sin( 200*PI*pow(t,2.0) );
	st = st-mean(st);

	/******************************** [ CWT ] ********************************/
	Matrix< complex<double> > coefs;
	CWT<double> wavelet("morlet");
	wavelet.setScales( fs, fs/Ls, fs/2 );
	Timing cnt;
	double runtime = 0.0;
	cout << "Taking continuous wavelet transform(Morlet)." << endl;
	cnt.start();
	coefs = wavelet.cwtC(st);
	cnt.stop();
	runtime = cnt.read();
	cout << "The running time = " << runtime << " (ms)" << endl << endl;

	/******************************** [ ICWT ] *******************************/
	cout << "Taking inverse continuous wavelet transform." << endl;
	cnt.start();
	Vector<double> xt = wavelet.icwtC(coefs);
	cnt.stop();
	runtime = cnt.read();
	cout << "The running time = " << runtime << " (ms)" << endl << endl;

	cout << "The relative error is : " << endl;
	cout << "norm(st-xt) / norm(st) = " << norm(st-xt)/norm(st) << endl;
	cout << endl << endl;


    /******************************* [ signal ] ******************************/
    Vector<float> tf = linspace( float(0.0), (Ls-1)/float(fs), Ls );
	Vector<float> stf = sin( float(200*PI) * pow(tf,float(2.0) ) );
	stf = stf-mean(stf);

	/******************************** [ CWT ] ********************************/
	CWT<float> waveletf("mexiHat");
	waveletf.setScales( float(fs), float(fs/Ls), float(fs/2), float(0.25) );
	runtime = 0.0;
	cout << "Taking continuous wavelet transform(Mexican Hat)." << endl;
	cnt.start();
	Matrix<float> coefsf = waveletf.cwtR(stf);
	cnt.stop();
	runtime = cnt.read();
	cout << "The running time = " << runtime << " (ms)" << endl << endl;

	/******************************** [ ICWT ] *******************************/
	cout << "Taking inverse continuous wavelet transform." << endl;
	cnt.start();
	Vector<float> xtf = waveletf.icwtR(coefsf);
	cnt.stop();
	runtime = cnt.read();
	cout << "The running time = " << runtime << " (ms)" << endl << endl;

	cout << "The relative error is : " << endl;
	cout << "norm(st-xt) / norm(st) = " << norm(stf-xtf)/norm(stf) << endl << endl;

	return 0;
}

运行结果:

Taking continuous wavelet transform(Morlet).
The running time = 0.046 (ms)

Taking inverse continuous wavelet transform.
The running time = 1.73472e-018 (ms)

The relative error is :
norm(st-xt) / norm(st) = 0.000604041


Taking continuous wavelet transform(Mexican Hat).
The running time = 0.032 (ms)

Taking inverse continuous wavelet transform.
The running time = 0.015 (ms)

The relative error is :
norm(st-xt) / norm(st) = 0.00157322


Process returned 0 (0x0)   execution time : 0.156 s
Press any key to continue.

你可能感兴趣的:(连续小波变换算法的C++实现)