用python实现经验模态分解+小波软阈值去噪

PyEmd模块安装

试过很多博主说的pip insyall PyEmd都失败了,偶然间运气好发现正确的安装方式是pip install PyEmd-signal。如果找不到相关的库或者模块,直接去github上去搜索,上面有很详细的安装教程,不要被误导

pywt模块安装

pywt可以实现小波分解与重构,小波阈值降噪,小波包分解等功能,同样安装也是用相应的pip instal pywt来进行安装,如果找不到还是去github上寻找。

特别说明

关于EMD类方法和小波阈值降噪的相关理论知识可直接百度或者在知网找几篇硕博论文来看,里面有详细的推导过程。

不要祈求能推导明白相关公式,这些公式很复杂一边人理解不了,个人建议了解相关算法流程就可以了,相关的优化也是不同方法之间的排列组合,还是要多多尝试。

以上只是个人看法,不喜欢可以划走,别在这体现自己的优越感,我写这文章的目的就是单纯的记录!!!!

代码实现

由于写代码的时候,个人的理论了解程度仅仅停留在入门阶段,IMF分量(EMD类方法分解得到的分量)的选择是凭借个人感觉来选择的,正确的做法是计算多尺度排列熵(github上可以找到相关的模块,个人正在研究相关代码)、相关系数等来进行选择。

个人的玩具代码,有很多不严谨的地方。

EMD类方法实现

import numpy as np
from PyEMD import EMD,EEMD,CEEMDAN,Visualisation
from threshold import Threshold
from matplotlib import pyplot as plt
import pywt
from pylab import mpl
mpl.rcParams['font.sans-serif'] = ['SimHei']

def read_txt_file(input_file_path):
    """该函数主要用来从txt文件中读取所需数据,并转换数据类型
    输入为待处理文件的路径
    输出为一个存放txt文件数据的列表"""
    file_list = []
    file = open(input_file_path)
    file_lines = file.readline()
    file_lines = list(file_lines)
    file_lines.pop(0)
    file_lines.pop(-1)
    file_lines = ''.join(file_lines)
    cur = file_lines.strip().split(",")
    for i in range(0,len(cur)):
        file_list.append(float(cur[i]))
    #print(file_list)
    return file_list

class EmdFunction:
    """
    emd方法的调用
    """
    def __init__(self,data,function_name,sym,level,imfs_start_step,thr_select,thr_way):

        X = np.array(data)
        self.signal = (X - np.mean(X)) / np.std(X)
        self.function_name = function_name
        self.soft_threshold = Threshold(data,thr_select,sym,level,thr_way)
        self.imfs_start_step = imfs_start_step


    def emd_completed(self):

        if self.function_name == 'EMD':
            emd = EMD()
            emd.emd(self.signal)
            ims,res = emd.get_imfs_and_residue()
            return ims,res
        elif self.function_name == 'EEMD':
            eemd = EEMD()
            eemd.eemd(self.signal)
            ims,res = eemd.get_imfs_and_residue()
            return ims,res
        elif self.function_name == 'CEEMDAN':
            ceemdan = CEEMDAN()
            ceemdan.ceemdan(self.signal)
            ims,res = ceemdan.get_imfs_and_residue()
            return ims,res

    def plot_imfs_and_res(self,imfs,res):
        t = np.arange(0,len(self.signal),1)
        vis = Visualisation()
        vis.plot_imfs(imfs=imfs,residue=res,t=t,include_residue=True)
        vis.show()

    def wavelet_and_emd(self):

        useful_imfs_add = np.zeros(len(self.signal)).tolist()
        imfs,res = self.emd_completed()
        for i in range(self.imfs_start_step,len(imfs)):
            useful_imfs_add += imfs[i]

        for j in range(0,self.imfs_start_step):
            data = imfs[j]
            #mid_param = self.soft_threshold.wavelet_dec_rec(data)
            mid_param = self.soft_threshold.wavelet_dec_rec(data)
            useful_imfs_add  += mid_param
        return useful_imfs_add

    def plot_org_sotfthreshold(self):

        end_signal = self.wavelet_and_emd()
        snr = self.soft_threshold.compute_snr(self.signal,end_signal)
        rmse = self.soft_threshold.compute_mse(self.signal,end_signal)
        print('snr:{} , rmse:{}'.format(snr, rmse))
        figure,(ax1,ax2) = plt.subplots(nrows=2,ncols=1)
        ax1.plot(self.signal, label='org signal')
        ax1.set_title('降噪前的信号')
        ax1.legend()

        ax2.plot(end_signal, 'g',label='after wavele')
        ax2.set_title('降噪后的信号')
        ax2.legend()
        plt.show()



if __name__ == "__main__":
    path = 'D:\桌面文件夹\数据文件/0001001228.txt'
    #path = 'D:\桌面文件夹\新建文件夹 (3)/1101000158.txt'

    data = read_txt_file(path)

    function_name = 'CEEMDAN'

    sym = 'sym8'
    level = 3
    imfs_start_step = 4
    thr_select = 'sqtwolog'
    thr_way = 'soft'

    emdfunction = EmdFunction(data,function_name,sym,level,imfs_start_step,thr_select,thr_way)
    # ims,res=emdfunction.emd_completed()
    # emdfunction.plot_imfs_and_res(ims,res)
    emdfunction.plot_org_sotfthreshold()


小波软阈值代码实现

import numpy as np
import os
from matplotlib import pyplot as plt
import pywt
from pylab import mpl
mpl.rcParams['font.sans-serif'] = ['SimHei']


def read_txt_file(input_file_path):
    """该函数主要用来从txt文件中读取所需数据,并转换数据类型
    输入为待处理文件的路径
    输出为一个存放txt文件数据的列表"""
    file_list = []
    file = open(input_file_path)
    file_lines = file.readline()
    file_lines = list(file_lines)
    file_lines.pop(0)
    file_lines.pop(-1)
    file_lines = ''.join(file_lines)
    cur = file_lines.strip().split(",")
    for i in range(0,len(cur)):
        file_list.append(float(cur[i]))
    #print(file_list)
    return file_list

class Threshold:
    """小波阈值降噪"""
    def __init__(self,data,thr_select,wave_bais='sym8',level=3,thr_way='soft'):

        if type(data) == list:
            X = np.array(data)
            self.data = (X - np.mean(X)) / np.std(X)
        else:
            self.data = (data - np.mean(data)) / np.std(data)

        self.data = data

        self.wave_bais = wave_bais
        self.level = level
        self.thr_way = thr_way
        if thr_select in ['rigrsure','heursure','sqtwolog','minimaxi']:
            self.thr_select = thr_select
        else:
            raise print('取值计算函数名称错误,请重新输入')

    def thrselect(self,data):
        """
        阈值lambda的计算方式选择
        :return: 返回阈值
        """
        N = len(data)
        if self.thr_select == 'sqtwolog':
            #固定阈值
            thr = round(np.sqrt(2.0 * np.log(N)),4)
            return thr

        elif self.thr_select == 'minimaxi':
            #极大极小阈值
            if N<32:
                thr =0
            else:
                thr = 0.3936 + 0.1829*(np.log(N)/np.log(2))
            return thr

        elif self.thr_select == 'rigrsure':
            # #风险阈值
            # sx = np.sort(abs(self.data))
            # sx2 = np.square(sx)
            # N1 = np.repeat(N-2*[i for i in range(0,N)],1)
            pass
            return -1

        elif self.thr_select == 'heursure':
            pass
            return -1

    def wavelet_dec_rec(self,data):
        """小波分解"""
        coffe = pywt.wavedec(data,self.wave_bais,level=self.level)
        #低频分量分量
        ca = coffe[0]
        #高频分量
        cd_out_list = []
        cd_out_list.append(ca)
        #阈值
        thr = self.thrselect(data)
        for i in range(1,len(coffe)):
            cd = coffe[i]
            ysotf = pywt.threshold(cd,thr,self.thr_way)
            cd_out_list.append(ysotf)

        Y = pywt.waverec(cd_out_list,self.wave_bais)
        return Y

    def plot_signal(self,data):
        #获得降噪后的信号
        Y = self.wavelet_dec_rec(data)

        #绘制原始图像
        figure, axes = plt.subplots(2, 1)
        ax1 = axes[0]
        ax1.set_title('降噪前的信号')
        ax1.plot(self.data)
        #绘制降噪的图像
        ax2 = axes[1]
        ax2.set_title('降噪后的信号')
        ax2.plot(Y,color='g')
        plt.show()

    @staticmethod
    def compute_snr(org_signal, final_signal):
        """
        信噪比:信噪比越大越好
        均方根误差:均方根误差越小越好,越小去噪效果越好
        :param org_signal:原始信号
        :param final_signal:降噪后的信号
        :return: 信噪比,均方根误差
        """

        clean = np.array(final_signal)
        org_signal = np.array(org_signal)
        #est_noise = org_signal - clean
        # power_data = np.mean(np.square(data))
        # power_noise = np.mean(np.square(data - final_signal))
        #snr = 10 * np.log10((np.sum(clean ** 2)) / (np.sum(est_noise ** 2)))
        # snr = (math.log((power_data/power_noise),10) )* 10
        sigPower = sum(abs(clean) ** 2) / len(clean)  # 求出信号功率
        noisePower = sum(abs(org_signal - clean) ** 2) / len(org_signal - clean)  # 求出噪声功率
        SNR_10 = 10 * np.log10(sigPower / noisePower)
        #SNR_10 = (sigPower / noisePower)
        return SNR_10

    @staticmethod
    def compute_mse(org_signal, final_signal):
        """
        计算均方根误差:均方根误差越小越好,越小去噪效果越好
        :param org_signal:原始信号
        :param final_signal:降噪后的信号
        :return: 均方根误差
        """
        data = np.array(org_signal)
        final_signal = np.array(final_signal)
        rmse = np.sqrt(np.mean(np.square(data - final_signal)))
        return rmse




if __name__ == "__main__":
    #path = 'D:\桌面文件夹\新建文件夹 (3)/1101000158.txt'
    path = 'D:\桌面文件夹\数据文件/0001001228.txt'

    data = read_txt_file(path)
    data = (data - np.mean(data)) / np.std(data)

    thr_select = 'sqtwolog'
    wave_bais = 'sym8'
    level = 3
    thr_way = 'soft'

    wave = Threshold(data,thr_select,wave_bais,level,thr_way)
    Y = wave.wavelet_dec_rec(data)

    snr = wave.compute_snr(wave.data,Y)
    rmse = wave.compute_mse(wave.data,Y)
    print('snr:{},rmse:{}'.format(snr,rmse))

    wave.plot_signal(data)






你可能感兴趣的:(信号处理,python)