用于图神经网络的脑电数据处理实现

以下是基于两篇论文的数据集处理代码实现:
原理可以参考:

1、图神经网络 EEG 论文阅读和分析:《EEG-Based Emotion Recognition Using Regularized Graph Neural Networks》(KPer_Yang 的 CSDN 博客)

2、EEG-GNN 论文阅读和分析:《EEG Emotion Recognition Using Dynamical Graph Convolutional Neural Networks》(KPer_Yang 的 CSDN 博客)

构建特征

# -*- coding: utf-8 -*-
#
# Copyright (C) 2022 Emperor_Yang, Inc. All Rights Reserved 
#
# @CreateTime    : 2023/1/15 13:57
# @Author        : Emperor_Yang 
# @File          : feature_x.py
# @Software      : PyCharm

from utils.load_m_data import load_m_data
import numpy as np


def build_ML_feature_data(feature_path: str, label_path: str):
    """
    Build a flat feature matrix for traditional machine-learning models.

    Every time slice becomes one row. Columns are laid out band-major,
    ``[band_1: channel_1..channel_62, ..., band_5: channel_1..channel_62]``,
    i.e. the value for (channel c, band b) lands in column ``62 * b + c``.

    :param feature_path: .mat file holding arrays ``de_LDS1`` ... ``de_LDS15``,
        each indexed as ``[channel, sample, band]``
    :param label_path: .mat file whose ``label`` entry is indexed as
        ``label[0, k]`` for the (k+1)-th ``de_LDS`` array
    :return: ``(feature_all, label_all)`` with shapes
        ``(total_samples, 62 * 5)`` and ``(total_samples, 1)``
    """
    # Basic information
    feature_dict = load_m_data(feature_path)
    label_dict = load_m_data(label_path)
    # NOTE(review): the 15 ``de_LDS*`` arrays are called "subjects" here; in SEED
    # they are presumably the 15 trials of one recording -- confirm.
    subject_num = 15
    channel_num = 62
    frequency_band_num = 5
    feature_dim = channel_num * frequency_band_num

    # Number of time slices (axis 1) in each array, and their total.
    sample_num_s = [feature_dict['de_LDS' + str(i + 1)].shape[1] for i in range(subject_num)]
    sample_num_sum = sum(sample_num_s)

    feature_all = np.zeros((sample_num_sum, feature_dim))
    label_all = np.zeros((sample_num_sum, 1))
    start = 0
    for subject_index, sample_slice_num in enumerate(sample_num_s):
        feature_3d = feature_dict['de_LDS' + str(subject_index + 1)]
        stop = start + sample_slice_num
        # [channel, sample, band] -> [sample, band, channel]; flattening the last
        # two axes row-major puts (c, b) at column channel_num * b + c. The band
        # stride must be channel_num (not frequency_band_num), otherwise bands
        # overwrite each other. reshape raises ValueError if an array has too
        # few channels/bands.
        per_sample = np.transpose(feature_3d[:channel_num, :, :frequency_band_num], (1, 2, 0))
        feature_all[start:stop] = per_sample.reshape(sample_slice_num, feature_dim)
        # Every slice of one array shares that array's label.
        label_all[start:stop] = label_dict['label'][0, subject_index]
        start = stop

    return feature_all, label_all


def build_graph_feature_data(feature_path: str, label_path: str):
    """
    Build node-feature arrays for graph neural networks.

    Every time slice becomes one graph sample whose 62 nodes are the channels,
    each node carrying its 5 band features.

    :param feature_path: .mat file holding arrays ``de_LDS1`` ... ``de_LDS15``,
        each indexed as ``[channel, sample, band]``
    :param label_path: .mat file whose ``label`` entry is indexed as
        ``label[0, k]`` for the (k+1)-th ``de_LDS`` array
    :return: ``(feature_all, label_all)`` with shapes
        ``(total_samples, 62, 5)`` and ``(total_samples, 1)``
    """
    # Basic information
    feature_dict = load_m_data(feature_path)
    label_dict = load_m_data(label_path)
    # NOTE(review): the 15 ``de_LDS*`` arrays are called "subjects" here; in SEED
    # they are presumably the 15 trials of one recording -- confirm.
    subject_num = 15
    channel_num = 62
    frequency_band_num = 5

    # Number of time slices (axis 1) in each array, and their total.
    sample_num_s = [feature_dict['de_LDS' + str(i + 1)].shape[1] for i in range(subject_num)]
    sample_num_sum = sum(sample_num_s)

    feature_all = np.zeros((sample_num_sum, channel_num, frequency_band_num))
    label_all = np.zeros((sample_num_sum, 1))
    start = 0
    for subject_index, sample_slice_num in enumerate(sample_num_s):
        feature_3d = feature_dict['de_LDS' + str(subject_index + 1)]
        stop = start + sample_slice_num
        # [channel, sample, band] -> [sample, channel, band] in one vectorised copy
        # instead of one Python-level assignment per element.
        feature_all[start:stop] = np.transpose(
            feature_3d[:channel_num, :, :frequency_band_num], (1, 0, 2))
        # Every slice of one array shares that array's label.
        label_all[start:stop] = label_dict['label'][0, subject_index]
        start = stop

    return feature_all, label_all


if __name__ == '__main__':
    # Manual smoke run: build graph features for a single SEED feature file.
    demo_feature_path = '../data/SEED/ExtractedFeatures/1_20131027.mat'
    demo_label_path = '../data/SEED/ExtractedFeatures/label.mat'
    build_graph_feature_data(demo_feature_path, demo_label_path)

构建邻接矩阵

# -*- coding: utf-8 -*-
#
# Copyright (C) 2022 Emperor_Yang, Inc. All Rights Reserved 
#
# @CreateTime    : 2023/1/17 21:39
# @Author        : Emperor_Yang 
# @File          : edge_weight.py
# @Software      : PyCharm
import numpy as np
from utils.load_channel_index import get_channel_index
from utils.local_connect_matrix import get_local_connect_matrix


def build_edge_weight_DGCNN(dist_ij_2D: np.array, tau_value: float = 2, theta_value: float = 2) -> np.array:
    """
    Build the distance-based edge-weight matrix described in
    paper:《EEG Emotion Recognition Using Dynamical Graph Convolutional Neural Networks》

        w_ij = exp(-dist_ij ** 2 / (2 * theta ** 2))   if dist_ij <= tau
        w_ij = 0                                       otherwise

    :param dist_ij_2D: matrix of pairwise electrode distances (62 x 62 for SEED)
    :param tau_value: distance threshold; the paper does not give a value, 2 is a guess
    :param theta_value: Gaussian kernel width; the paper does not give a value, 2 is a guess
    :return: float64 matrix with the same shape as ``dist_ij_2D``
    """
    dist = np.asarray(dist_ij_2D, dtype=np.float64)
    # The denominator must be parenthesised: ``/ 2 * theta ** 2`` would multiply
    # by theta**2 instead of dividing by it.
    gaussian = np.exp(-dist ** 2 / (2 * theta_value ** 2))
    # Pairs farther apart than tau are disconnected.
    return np.where(dist > tau_value, 0.0, gaussian)


def build_edge_weight_RGNN(dist_ij_2D: np.array, delta_value: float = 2) -> np.array:
    """
    Build the initial adjacency matrix described in
    paper:《EEG-Based Emotion Recognition Using Regularized Graph Neural Networks》

    Local connections:  A_ij = min(1, delta / dist_ij ** 2)
    Global connections: A_ij = A_ij - 1 for the symmetric left/right electrode
    pairs listed below, applied to both (i, j) and (j, i).

    :param dist_ij_2D: matrix of pairwise electrode distances (62 x 62 for SEED)
    :param delta_value: calibration constant delta
    :return: float64 matrix with the same shape as ``dist_ij_2D``
    """
    global_connect_pair = [['FP1', 'FP2'],
                           ['AF3', 'AF4'],
                           ['F5', 'F6'],
                           ['FC5', 'FC6'],
                           ['C5', 'C6'],
                           ['CP5', 'CP6'],
                           ['P5', 'P6'],
                           ['PO5', 'PO6'],
                           ['O1', 'O2']]

    dist = np.asarray(dist_ij_2D, dtype=np.float64)
    # A zero distance (e.g. the diagonal) gives delta / 0 = inf, which the
    # element-wise minimum clamps to 1; silence the divide-by-zero warning.
    with np.errstate(divide='ignore'):
        edge_weight = np.minimum(1.0, delta_value / dist ** 2)

    for first_name, second_name in global_connect_pair:
        i = get_channel_index(first_name)
        j = get_channel_index(second_name)
        # Update both directions so the adjacency matrix stays symmetric.
        edge_weight[i, j] -= 1
        edge_weight[j, i] -= 1
    return edge_weight


def build_edge_weight_equal(dist_ij_2D: np.array) -> np.array:
    """
    Use the local-connection matrix directly as the edge-weight matrix, so every
    connected pair carries the same weight (presumably 1 -- depends on the
    values returned by get_local_connect_matrix).

    :param dist_ij_2D: unused; presumably kept so this builder shares the
        signature of the distance-based builders
    :return: numpy copy of ``get_local_connect_matrix()``
    """
    return np.array(get_local_connect_matrix())

构建边索引

# -*- coding: utf-8 -*-
#
# Copyright (C) 2022 Emperor_Yang, Inc. All Rights Reserved 
#
# @CreateTime    : 2023/1/15 11:50
# @Author        : Emperor_Yang 
# @File          : edg_index.py
# @Software      : PyCharm
import os.path

import torch
import pandas as pd
import numpy as np
from utils.local_connect_matrix import get_local_connect_matrix


def build_local_edge_index_pt(path: str):
    """
    Load a previously saved edge-index tensor, e.g. 'local_edge_index.pt'.

    :param path: path of a file written with ``torch.save``
    :return: the object stored in the file (expected to be the edge-index tensor)
    :raises FileNotFoundError: if ``path`` does not exist
    """
    # An explicit raise (unlike ``assert``) still fires under ``python -O``.
    if not os.path.exists(path):
        raise FileNotFoundError(f"edge index file not found: {path}")
    # torch.load unpickles the file: only load files from trusted sources.
    return torch.load(path)


def build_local_edge_index_xlsx(path: str):
    """
    Build a COO edge index from an adjacency matrix stored in an .xlsx file,
    for example '../data/SEED/local_connect__matrix.xlsx'.

    Empty cells are treated as 0; a cell equal to 1 at (row, col) becomes the
    edge row -> col.

    :param path: path of the .xlsx adjacency matrix
    :return: ``torch.long`` tensor of shape ``(2, num_edges)``
    """
    # NOTE(review): read_excel treats the first sheet row as the header by default,
    # so the matrix is assumed to start on the second row -- confirm against the file.
    data_df = pd.read_excel(path)
    # np.compat was removed in NumPy 2.0; np.int64 keeps the integer cast portable.
    data_np = data_df.fillna(0).to_numpy(dtype=np.int64)
    # argwhere yields (row, col) pairs in row-major order, i.e. the same order as
    # a nested row/col scan, and has shape (0, 2) when there are no edges, so the
    # result is still a well-formed (2, 0) tensor.
    edge_pairs = np.argwhere(data_np == 1)
    edge_index = torch.tensor(edge_pairs, dtype=torch.long)
    return edge_index.t().contiguous()


def build_local_edge_index_code():
    """
    Build a COO edge index from the in-code local connection matrix
    returned by ``get_local_connect_matrix()``.

    A matrix entry equal to 1 at (row, col) becomes the edge row -> col.
    The result can be stored with ``torch.save`` to produce a file such as
    'local_edge_index.pt' for ``build_local_edge_index_pt``.

    :return: ``torch.long`` tensor of shape ``(2, num_edges)``
    """
    # np.compat was removed in NumPy 2.0; np.int64 keeps the integer cast portable.
    data_np = np.array(get_local_connect_matrix(), dtype=np.int64)
    # Row-major (row, col) pairs, matching a nested row/col scan; shape (0, 2)
    # when there are no edges, so the result is still a well-formed (2, 0) tensor.
    edge_pairs = np.argwhere(data_np == 1)
    edge_index = torch.tensor(edge_pairs, dtype=torch.long)
    return edge_index.t().contiguous()


构建数据集

# -*- coding: utf-8 -*-
#
# Copyright (C) 2022 Emperor_Yang, Inc. All Rights Reserved 
#
# @CreateTime    : 2023/1/14 22:38
# @Author        : Emperor_Yang 
# @File          : seed_loader_gnn.py
# @Software      : PyCharm

import os
import numpy as np
import torch
from torch_geometric.data import Data
from torch_geometric.data import InMemoryDataset
from utils.load_m_data import load_m_data
from data_process.edge_index import build_local_edge_index_pt
from data_process.feature_x import build_graph_feature_data
from torch_geometric.loader import DataLoader


class SeedGnnMemoryDataset(InMemoryDataset):
    """
    In-memory PyG dataset built from the .mat feature files in
    ``<root>/ExtractedFeatures/``.

    Every time slice of every feature file becomes one ``Data`` graph whose
    node features ``x`` come from ``build_graph_feature_data`` (one row per
    channel) and whose edges come from ``<root>/local_edge_index.pt``.
    """

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        # Load the collated tensors written by process().
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Feature .mat files in sorted order, followed by 'label.mat' as the last entry."""
        data_dir = os.path.join(self.root, 'ExtractedFeatures')
        # Sort so that sample order (and hence dataset indices) does not depend on
        # the filesystem's directory-listing order; endswith avoids matching
        # names such as 'x.mat.bak'.
        file_name_s = sorted(
            file_name for file_name in os.listdir(data_dir)
            if file_name.endswith('.mat') and 'label.mat' not in file_name
        )
        # By convention the label file is always the last entry.
        file_name_s.append('label.mat')
        return file_name_s

    @property
    def processed_file_names(self):
        return ['seed_data.pt']

    def download(self):
        # Nothing to download: the raw files are expected to already be on disk.
        ...

    def process(self):
        """Convert every feature file into per-slice graphs, filter/transform them, and save."""
        extracted_features_dir = os.path.join(self.root, 'ExtractedFeatures')
        raw_file_names = self.raw_file_names  # list the directory only once
        label_path = os.path.join(extracted_features_dir, raw_file_names[-1])
        edge_index = build_local_edge_index_pt(os.path.join(self.root, 'local_edge_index.pt'))

        # Iterate over the feature files; every time slice becomes one graph.
        data_list = []
        for file_name in raw_file_names[:-1]:
            feature_path = os.path.join(extracted_features_dir, file_name)
            data_x_s, label_s = build_graph_feature_data(feature_path, label_path)  # data_x_s: (samples, 62, 5)
            for sample_index in range(data_x_s.shape[0]):
                # One row of node features per channel.
                x = torch.tensor(data_x_s[sample_index], dtype=torch.float)
                data_list.append(Data(x=x, edge_index=edge_index, y=label_s[sample_index, 0]))

        # Optional filtering
        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        # Optional pre-transform
        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        # Collate into one big storage object plus slice indices and save.
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])


if __name__ == '__main__':
    # Manual smoke run: build the dataset rooted at ../data/SEED/ and wrap it
    # in a mini-batch loader.
    seed_dataset = SeedGnnMemoryDataset(root='../data/SEED/')
    seed_loader = DataLoader(
        seed_dataset,
        batch_size=32,
        shuffle=False,
        num_workers=8,
    )

你可能感兴趣的:(机器学习,信号处理算法,神经网络,深度学习,EEG)