TSception:从EEG中捕获时间动态和空间不对称性用于情绪识别

TSception:从EEG中捕获时间动态和空间不对称性用于情绪识别(论文复现)

  • 摘要
  • 模型结构
  • 代码实现
  • 写在最后

**这是一篇代码复现,原文通过Pytorch实现,本文中使用Keras对该结构进行复现。**该论文发表在IEEE Transactions on Affective Computing,第一作者Yi Ding

摘要

高时间分辨率和不对称的空间激活是脑内情绪过程的基本特征。为了学习EEG的时间动态性和空间不对称性,以实现准确和广义的情感识别,Yi Ding等人提出了一种多尺度卷积神经网络TSception,可以从EEG中对情感进行分类。Tsception由动态时间、非对称空间和高级融合层组成,同时学习时间和信道维度的区分表示。动态时域层由多尺度1D卷积核组成,其长度与EEG的采样率相关,其学习EEG的动态时间和频率表示。非对称空间层利用情绪的非对称EEG模式,学习有区别的全局和半球表示。原文代码可在以下网址获得:https://github.com/yi-ding-cs/TSception,本文复现完整代码可在下面获得:https://github.com/ruix6/tsception

模型结构

关于模型结构的相关公式推理可以参考原文,本文不详细展开,下图是模型的具体结构:
TSception:从EEG中捕获时间动态和空间不对称性用于情绪识别_第1张图片
熟悉经典深度学习模型的同学应该能一眼看出来TSception的设计灵感来自Inception模型。结合脑电信号的特点,TSception分四步实现对脑电信号的计算。

  1. 多尺度时域卷积:第一步通过多个尺度的时域卷积核实现对EEG信号的分解与特征提取。多尺度的优势在于可以给模型提供多个不同的感受野,这对于脑电信号这种多源的复杂信号来说是十分合理的。作为对比,我们可以看一下EEGNet的结构,如下图,EEGNet的第一个部分也是一个一维时域卷积结构,但是由于单一的感受野,在很多任务中它很容易被低频部分的噪声所干扰,所以EEGNet在ERP或者SMR这种有效信息分布在低频段的任务比较友好,但是像情绪识别的话,该网络的原始结构似乎不能发挥其全部能量(改变其时域卷积核大小似乎能有效提升其能力)。TSception:从EEG中捕获时间动态和空间不对称性用于情绪识别_第2张图片
  2. 不对称空间卷积层:模型的第二部分由两个尺度的卷积实现。大尺度的卷积层覆盖所有通道,小尺度的卷积层分可以分别卷积大脑左半球的通道和右半球的通道。前面提到过,不对称的空间激活(即受试者在不同的情绪状态下,大脑的左右半球的激活状态是不一样的,原文中分析了受试者的大脑激活状态,想进一步了解可以看看原文)是情绪的重要特征,所以通过小尺度的空间卷积可以进一步的抓住这些特征。
    TSception:从EEG中捕获时间动态和空间不对称性用于情绪识别_第3张图片
  3. 高级融合层:该层为了进一步的融合输入的时空特征而设计,这个和EEGNet的可分离卷积层的效果是一样的,我更愿意称之为为了减小参数量而设计的,这个卷积层的出现,其实使得模型的可解释性进一步降低。
  4. 分类:把高级融合层的输出做个全局平均池化然后全连接,最后输出。

代码实现

代码如下:

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, AveragePooling2D, Flatten, Dense, Dropout, BatchNormalization, concatenate, LeakyReLU


# 定义时域卷积块
def conv_block(input, out_chan, kernel, step, pool):
    x = Conv2D(out_chan, kernel, strides=step)(input)# padding='same', use_bias=False
    x = LeakyReLU()(x)
    x = AveragePooling2D(pool_size=(1, pool), strides=(1, pool))(x)

    return x


def Tsception(num_classes, Chans, Samples, sampling_rate, num_T, num_S, hidden, dropout_rate, pool=8):

    '''
    input_size: 输入数据的维度,(chans, samples, 1)
    '''
    inception_window = [0.5, 0.25, 0.125]
    # 定义输入层
    input = Input(shape=(Chans, Samples, 1))
    # 定义时域卷积层
    x1 = conv_block(input, num_T, (1, int(sampling_rate * inception_window[0])), 1, pool)
    x2 = conv_block(input, num_T, (1, int(sampling_rate * inception_window[1])), 1, pool)
    x3 = conv_block(input, num_T, (1, int(sampling_rate * inception_window[2])), 1, pool)
    # 在height维度上进行拼接
    x = concatenate([x1, x2, x3], axis=2)
    x = BatchNormalization()(x)
    # 定义空域卷积层
    y1 = conv_block(x, num_S, (Chans, 1), (Chans, 1), int(pool*0.25))
    y2 = conv_block(x, num_S, (int(Chans*0.5), 1), (int(Chans*0.5), 1), int(pool*0.25))
    # 在width维度上进行拼接
    y = concatenate([y1, y2], axis=1)
    y = BatchNormalization()(y)
    # 定义fusion_layer
    z = conv_block(y, num_S, (3, 1), (3, 1), 4)
    z = BatchNormalization()(z)
    # 定义全局平均池化层
    z = AveragePooling2D(pool_size=(1, z.shape[2]))(z)
    z = Flatten()(z)
    # 全连接层
    z = Dense(hidden, activation='relu')(z)# , use_bias=False
    z = Dropout(dropout_rate)(z)
    z = Dense(num_classes, activation='softmax')(z)# , use_bias=False

    return Model(inputs=input, outputs=z)

参照原文的各个超参数,引用方式为:

if __name__ == '__main__': 
    model = Tsception(num_classes=2, Chans=28, Samples=512, sampling_rate=128, num_T=15, num_S=15, hidden=32, dropout_rate=0.5)
    model.summary()

最后的输出为:

__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to
==================================================================================================
 input_1 (InputLayer)           [(None, 28, 512, 1)  0           []
                                ]

 conv2d (Conv2D)                (None, 28, 449, 15)  975         ['input_1[0][0]']

 conv2d_1 (Conv2D)              (None, 28, 481, 15)  495         ['input_1[0][0]']

 conv2d_2 (Conv2D)              (None, 28, 497, 15)  255         ['input_1[0][0]']

 leaky_re_lu (LeakyReLU)        (None, 28, 449, 15)  0           ['conv2d[0][0]']

 leaky_re_lu_1 (LeakyReLU)      (None, 28, 481, 15)  0           ['conv2d_1[0][0]']

 leaky_re_lu_2 (LeakyReLU)      (None, 28, 497, 15)  0           ['conv2d_2[0][0]']

 average_pooling2d (AveragePool  (None, 28, 56, 15)  0           ['leaky_re_lu[0][0]']
 ing2D)

 average_pooling2d_1 (AveragePo  (None, 28, 60, 15)  0           ['leaky_re_lu_1[0][0]']
 oling2D)

 average_pooling2d_2 (AveragePo  (None, 28, 62, 15)  0           ['leaky_re_lu_2[0][0]']
 oling2D)

 concatenate (Concatenate)      (None, 28, 178, 15)  0           ['average_pooling2d[0][0]',
                                                                  'average_pooling2d_1[0][0]',
                                                                  'average_pooling2d_2[0][0]']

 batch_normalization (BatchNorm  (None, 28, 178, 15)  60         ['concatenate[0][0]']
 alization)

 conv2d_3 (Conv2D)              (None, 1, 178, 15)   6315        ['batch_normalization[0][0]']

 conv2d_4 (Conv2D)              (None, 2, 178, 15)   3165        ['batch_normalization[0][0]']

 leaky_re_lu_3 (LeakyReLU)      (None, 1, 178, 15)   0           ['conv2d_3[0][0]']

 leaky_re_lu_4 (LeakyReLU)      (None, 2, 178, 15)   0           ['conv2d_4[0][0]']

 average_pooling2d_3 (AveragePo  (None, 1, 89, 15)   0           ['leaky_re_lu_3[0][0]']
 oling2D)

 average_pooling2d_4 (AveragePo  (None, 2, 89, 15)   0           ['leaky_re_lu_4[0][0]']
 oling2D)

 concatenate_1 (Concatenate)    (None, 3, 89, 15)    0           ['average_pooling2d_3[0][0]',
                                                                  'average_pooling2d_4[0][0]']

 batch_normalization_1 (BatchNo  (None, 3, 89, 15)   60          ['concatenate_1[0][0]']
 rmalization)

 conv2d_5 (Conv2D)              (None, 1, 89, 15)    690         ['batch_normalization_1[0][0]']

 leaky_re_lu_5 (LeakyReLU)      (None, 1, 89, 15)    0           ['conv2d_5[0][0]']

 average_pooling2d_5 (AveragePo  (None, 1, 22, 15)   0           ['leaky_re_lu_5[0][0]']
 oling2D)

 batch_normalization_2 (BatchNo  (None, 1, 22, 15)   60          ['average_pooling2d_5[0][0]']
 rmalization)

 average_pooling2d_6 (AveragePo  (None, 1, 1, 15)    0           ['batch_normalization_2[0][0]']
 oling2D)

 flatten (Flatten)              (None, 15)           0           ['average_pooling2d_6[0][0]']

 dense (Dense)                  (None, 32)           512         ['flatten[0][0]']

 dropout (Dropout)              (None, 32)           0           ['dense[0][0]']

 dense_1 (Dense)                (None, 2)            66          ['dropout[0][0]']

==================================================================================================
Total params: 12,653
Trainable params: 12,563
Non-trainable params: 90
__________________________________________________________________________________________________

原文结果,测试数据集是DEAP情绪数据集:
TSception:从EEG中捕获时间动态和空间不对称性用于情绪识别_第4张图片

写在最后

原文的作者并没有对模型进行更加具体的调参,事实上,输出的全连接层的神经元设置为32应该是意义不大的,效果可能不如直接链接到输出层,并且如果要考虑进一步缩小参数量的话,各个卷积层的偏置权重其实可以去除。在原文中,模型的效果表现与其它的卷积神经网络并没有统计学意义上的差别,但是,如果能够进一步调参的话,效果实际上要好很多。

你可能感兴趣的:(深度学习,人工智能)