<阿瑶机器学习之路>使用SNN对DEAP数据集进行情绪四分类

目录

 SNN基础知识讲解

​​​​​

​ DEAP数据集介绍

​​​​​

​ 使用SNN搭建一维Resnet网络进行情绪分类

​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​

​ 尾言

​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​

​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​


 

 

 

ef46fda6c8184735a588a7044cd05630.png SNN基础知识讲解

       dc60211d0ad9406998eb75a6d19f1666.pngSpiking Neural Network(脉冲神经网络,SNN)简介

                第一代神经网络(感知器),第二代神经网络(ANN)它们都是基于神经脉冲的发放频率进行编码,但是神经元的脉冲发放频率并不能完全捕获脉冲序列种包含的信息,因此第三代神经网络(SNN)登场了。第三代神经网络具有更强的生物可解释性的,神经网络内部的信息传递是由脉冲序列完成的。

       dc60211d0ad9406998eb75a6d19f1666.pngSNN和传统ANN的区别

                在传统的前馈人工神经网络中,两个神经元之间仅有一个突触连接,即输出只有一个权值去决定,这些权值受到的是神经脉冲的发放频率的影响。

                在SNN中的前馈型脉冲神经网络种,两个神经元之间是由多个突触连接的方式,每个突触具有不同的延时和权值,因此可以使得突触前神经元输入的脉冲能影响一段时间范围内突触后神经元发放的脉冲

       dc60211d0ad9406998eb75a6d19f1666.pngANN转SNN的理论基础

                SNN相比于ANN,产生的脉冲是离散的,这有利于高效的通信,但是SNN直接训练需要比较多的资源,且代码实现也比较复杂,因此我们自然会想到使用现在非常成熟的ANN转换到SNN。

                现在SNN主流的方式是采用频率编码,因此对于输出层,我们会用神经元输出脉冲数来判断类别。发放率和ANN有没有关系呢?答案是肯定的,ANN中的ReLU神经元非线性激活和SNN中IF神经元(采用减去阈值 V_threshold 方式重置)的发放率有着极强的相关性,我们可以借助这个特性来进行转换

       dc60211d0ad9406998eb75a6d19f1666.pngSNN的实现——spikingjelly

                 spikingjelly 是一个基于 PyTorch ,使用脉冲神经网络(Spiking Neural Network, SNN)进行深度学习的框架。

                中文文档连接:SpikingJelly(惊蛰)

         

 dc60211d0ad9406998eb75a6d19f1666.pngdc60211d0ad9406998eb75a6d19f1666.png​​ef46fda6c8184735a588a7044cd05630.png

 

 

ef46fda6c8184735a588a7044cd05630.png​ DEAP数据集介绍

       dc60211d0ad9406998eb75a6d19f1666.png​DEAP数据集组成(Python预处理版128Hz采样)

                DEAP数据集包含32个.mat文件,一部分字典键为data的数据,内含每名被实验者的脑电实验数据,数据采样频率为128Hz,另一部分为字典键为labels的数据,形状为40*4的矩阵,四列分别代表效价、唤醒度、支配和喜欢的值,其值由被试者在观看40段视频后进行打分得到。

       dc60211d0ad9406998eb75a6d19f1666.png​基于valence 和 arousal的值将情绪分为四类

                

<阿瑶机器学习之路>使用SNN对DEAP数据集进行情绪四分类_第1张图片

 

       dc60211d0ad9406998eb75a6d19f1666.png​DEAP中关于40个通道的解释

                DEAP数据集中有32个脑电通道,另外8个通道是其它传感器采集的信号,在32个EEG通道/电极中,14个通道用于收集情绪数据,它们被称为情绪通道。这些情绪通道列表,分别是:AF3、F3、F7、FC5、T7、P7、O1、AF4、F4、F8、FC6、T8、P8、O2。

                分别对应索引为:1,2,3,4,7,11,13,17,19,20,21,25,29,31

       dc60211d0ad9406998eb75a6d19f1666.png​本文使用的数据集预处理介绍

                原始读取到的字典键为data的数据是为40*40*8064的矩阵,第一个40是40次实验,第二个40是40个通道,8064是(63秒是采集时间*128Hz的采样频率)。然后我们预处理后的数据是40*40*14*15*512,即输入的形状为(268800,512)

                原始读取到的字典键为labels的数据,形状为40*4的矩阵,第一个40为40次实验,4代表效价、唤醒度、支配和喜欢的值,然后我们预处理后标签的形状为(268800,4)


def data_preprocessed():
    def label_mapping(valence, arousal):
        # 根据valence 和arousal的值将数据划分为四类
        if (valence > 4.5 and arousal > 4.5):
            label = 0  
        elif (valence > 4.5 and arousal <= 4.5):
            label = 1  
        elif (valence <= 5 and arousal > 4.5):
            label = 2  
        elif (valence <= 4.5 and arousal <= 4.5):
            label = 3  

        return label

    def map_label(label):
        new_label = label_mapping(label[0], label[1])
        return new_label


    channel = [1, 2, 3, 4, 7, 11, 13, 17, 19, 20, 21, 25, 29, 31]  # 要读取那个通道的数据进行预测,这里只取了14个通道进行预测
    window_size = 512 #滑动窗口长
    step_size = 512 #滑动窗口步长

# 将path替换为自己的路径
    def epoching(sub, channel, window_size, step_size):
        signal = []
        # 这个路径是DEAPpython预处理数据的文件夹路径,
        with open("path" + sub + '.dat', 'rb') as file:

            subject = pickle.load(file, encoding='latin1')  # resolve the python 2 data problem by encoding : latin1

            for i in range(0, 40):
                # loop over 0-39 trials
                data = subject["data"][i]
                labels = subject["labels"][i]
                labels = np.array(map_label(labels))
                start = 0;

                while start + window_size <= data.shape[1]:
                    array = []
                    for j in channel:
                        X = data[j][start: start + window_size]
                        array.append(np.array(X))
                        array.append(np.array(labels))
                        signal.append(np.array(array))

                    start = start + step_size

            signal = np.array(signal)
            np.save('path' + sub, signal, allow_pickle=True, fix_imports=True)

    for subjects in subject_list:
        epoching(subjects, channel, window_size, step_size)
        print(subjects)
    data = []
    label = []
    for subjects in subject_list:
        with open('path' + subjects + '.npy', 'rb') as file:
            sub = np.load(file, allow_pickle=True)
            for i in range(0, sub.shape[0]):
                data.append(sub[i][0])
                label.append(sub[i][1])

                

 

  dc60211d0ad9406998eb75a6d19f1666.pngdc60211d0ad9406998eb75a6d19f1666.png​​ef46fda6c8184735a588a7044cd05630.png

 

 

ef46fda6c8184735a588a7044cd05630.png​ 使用SNN搭建一维Resnet网络进行情绪分类

       dc60211d0ad9406998eb75a6d19f1666.png​SNN_Resnet代码实现

                

class Bottlrneck(torch.nn.Module):
    def __init__(self,In_channel,Med_channel,Out_channel,downsample=False):
        super(Bottlrneck, self).__init__()
        self.stride = 1
        if downsample == True:
            self.stride = 2

        self.layer = torch.nn.Sequential(
            torch.nn.Conv1d(In_channel, Med_channel, 1, self.stride,bias=False),
            torch.nn.BatchNorm1d(Med_channel),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
            torch.nn.Conv1d(Med_channel, Med_channel, 3, padding=1,bias=False),
            torch.nn.Dropout(0.5),
            torch.nn.BatchNorm1d(Med_channel),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
            torch.nn.Dropout(0.5),
            torch.nn.Conv1d(Med_channel, Out_channel, 1,bias=False),
            torch.nn.BatchNorm1d(Out_channel),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
        )

        if In_channel != Out_channel:
            self.res_layer = torch.nn.Conv1d(In_channel, Out_channel,1,self.stride)
        else:
            self.res_layer = None

    def forward(self,x):
        if self.res_layer is not None:
            residual = self.res_layer(x)
        else:
            residual = x
        return self.layer(x)+residual


class SNN_ResNet(torch.nn.Module):
    def __init__(self,in_channels=1):
        super(SNN_ResNet, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv1d(in_channels, 16, kernel_size=7, stride=2, padding=3, bias=False),
            torch.nn.MaxPool1d(3, 2, 1),
        )

        self.features = torch.nn.Sequential(
            torch.nn.Conv1d(in_channels, 16, kernel_size=7, stride=2, padding=3, bias=False),
            torch.nn.MaxPool1d(3, 2, 1),
            Bottlrneck(16, 64, 256, False),
            
            Bottlrneck(256, 128, 512, True),
            
            Bottlrneck(512, 256, 1024, True),
            
            Bottlrneck(1024, 512, 2048, True),

            torch.nn.AdaptiveAvgPool1d(1)
        )
        self.classifer = torch.nn.Sequential(
            torch.nn.Linear(2048,4,bias=False),
        )

    def forward(self,x):
        x=x.permute(0,2,1)
        out=0
        for i in range(50):
            k = self.features(x)
            k = k.view(-1,2048)
            out+= self.classifer(k)
        out=out/50.0

        return out

       dc60211d0ad9406998eb75a6d19f1666.png​重点代码解读——IF神经元

                5ffe574c2fcf483eac7043c762f936a7.png

                 上文提到,只要将ANN中的激活函数换成脉冲神经元,那么就可以将ANN转化为SNN,因此,这是将ANN转化为SNN最关键的代码。

                作用:脉冲神经元接受脉冲输入,输入脉冲输出

                IF脉冲神经元和LIF脉冲神经元的区别在于,在T次重复刺激中,LIF每次膜电位会先衰减在充电。而IF脉冲神经元不会衰减。

                脉冲神经元是有记忆的,脉冲神经元在接受脉冲输入后便会进行充电,当膜电位不超过阈值电压v_threshold 时,其膜电位会一直保持,并输出0,当持续充电后膜电位超过阈值电压,便会放电输出1,然后该脉冲神经元便会重置回V_reset

                一层IFNode里面所有脉冲神经元放电便形成脉冲,但是单次输入刺激获得的脉冲输出是不能作为分类依据的,因为神经元要持续刺激才会产生变化,所以我们要将同一个x 输入进网络 循环T个仿真步长 目的就是对输出层所有神经元的输出脉冲进行累加,得到输出层脉冲释放次数 使用脉冲释放次数除以仿真时长T,得到输出层脉冲发放频率,然后以这个频率进行分类,如代码所示

        out=0
        for i in range(50):
            k = self.features(x)
            k = k.view(-1,2048)
            out+= self.classifer(k)
        out=out/50.0

 dc60211d0ad9406998eb75a6d19f1666.pngdc60211d0ad9406998eb75a6d19f1666.png​​ef46fda6c8184735a588a7044cd05630.png​​

 

 

ef46fda6c8184735a588a7044cd05630.png​ 尾言

       dc60211d0ad9406998eb75a6d19f1666.png​完整代码获取方式

                这是帮某位同学做的,因此不便直接公开,对此感兴趣或者想深入研究的,可以联系我获取完整代码以及数据集(免费)

       dc60211d0ad9406998eb75a6d19f1666.png​关于此文

                本人才疏学浅,文章内容难免有错误之处,欢迎指出错误,本人将虚心学习,一起进步,谢谢。

   dc60211d0ad9406998eb75a6d19f1666.pngdc60211d0ad9406998eb75a6d19f1666.png​​ef46fda6c8184735a588a7044cd05630.png

 

 欢迎私信我一起讨论机器学习、深度学习方面的作业,互相学习进步。

----------------------------------------------------------------天空黑暗到一定程度,星辰就会熠熠生辉

 

  dc60211d0ad9406998eb75a6d19f1666.pngdc60211d0ad9406998eb75a6d19f1666.png​​ef46fda6c8184735a588a7044cd05630.png

 

 

 

你可能感兴趣的:(SNN,阿瑶机器学习之路,分类,人工智能)