通过小波变换对运动想象信号进行特征提取,生成时频图像作为神经网络的输入。
使用BCI竞赛2008–Graz dataset A
中的A01
受试者的数据作为数据集。
采样频率250Hz,每个数据前三个为伪迹参考信号,后6个为EEG信号集,
每一个划成48次,四个任务,每个任务12次;
每次任务大概8s,每次大概从3-6s(750-1500点)为运动想像时间,
拟采集770-1462点,每个样本512个点,采样间隔20个点,采集25个样本;
%% 采样信号
clear;
fs = 250; % 采样频率 250Hz
Na = 96735; %
Nt = 48; % 一个EEG信号集划成48次数据,即四个任务、每个任务12次
Ns = 25; % 样本数 25个
Np = 20; % 采样间隔20个点
N = 256;
%%
x00 = load('A01T');
%%
for k=1:6 % 后6个为EEG信号集,即data {1,4} {1,5} {1,6} {1,7} {1,8} {1,9}
x01 = x00.data{1, k+3}.X; % EEG信号
y01 = x00.data{1, k+3}.y; % 类别
t = x00.data{1, k+3}.trial; % 试验(trials),包含伪迹
t(Nt+1) = Na;
%figure
for i = 1:Nt
x0 = x01(t(i):t(i+1), :);
%subplot(6,8,i);
%plot(x0(:,1));xlim([0 2100]);ylim([-100 100]);
for j = 1:Ns
x1 = x0(750+Np*(j-1):750+Np*(j-1)+N-1, 1:22);
x2 = (x1-min(x1(:)))/(max(x1(:))-min(x1(:))); % 最大最小归一化
XTr(:, :, 1, 1200*(k-1)+25*(i-1)+j) = x2;
YTr(1, 1200*(k-1)+25*(i-1)+j) = categorical(y01(i));
end
clear x0; % 每次迭代x0的长度会发生变化
end
end
save SubA_Train XTr YTr;
%% 小波变换
clear
load SubA_Train;
%%
id=[8 10 12]; % 选三个电极,
parfor i=1:length(XTr)
for j=1:3
x = XTr(:,id(j),1,i);
x1 = abs(cwt(x)); % 小波变换
XTrft(:,:,j,i) = (x1-min(x1(:)))/(max(x1(:))-min(x1(:))); % 归一化
end
end
save SubA_TF_Train XTrft YTr;
%% 可视化一个样本为彩色图片
size(XTrft(:,:,:,1)) % 51×256×3
categories(YTr) % 查看类别数
figure;
imshow(XTrft(:,:,:,1))
%% 转成图片格式,先新建一个images文件夹,然后在images里面新建4个文件夹,分别为0、1、2、3.
load SubA_TF_Train
for i = 1:7200
k = double(string(YTr(1,i)))-1; % label
imwrite(XTrft(:,:,:,i),['images\',num2str(k)','\',num2str(i),'.jpg']) % 保存为图片
end
利用deepNetworkDesigner搭建网络,导出到工作区,训练。需要注意的是,网络的输出层为4类。可以采用典型的网络,例如Googlenet、resnet等。
clear;
%% 导入数据集
imdsTrain = imageDatastore("images","IncludeSubfolders",true,"LabelSource","foldernames");
[imdsTrain, imdsValidation] = splitEachLabel(imdsTrain,0.8,"randomized");
% 调整图像大小以匹配网络输入层
% inputsize = [256 256 3];
inputsize = [51 256 3];
augimdsTrain = augmentedImageDatastore(inputsize,imdsTrain);
augimdsValidation = augmentedImageDatastore(inputsize,imdsValidation);
%% 网络结构alexnet
% Net = alexnet;
% Net = googlenet;
% Net = inceptionresnetv2;
deepNetworkDesigner
%% 训练网络
miniBatchSize = 128;
learnRate = 0.0001;
valFrequency = floor(0.8*7200.0/miniBatchSize);
options = trainingOptions('adam', ...
'InitialLearnRate',learnRate, ...
'MaxEpochs',20, ...
'MiniBatchSize',miniBatchSize, ...
'Shuffle','every-epoch', ...
'Plots','training-progress', ...
'Verbose',false, ...
'ValidationData',augimdsValidation, ...
'ValidationFrequency',valFrequency, ...
'LearnRateSchedule','piecewise', ...
'LearnRateDropFactor',0.1, ...
'LearnRateDropPeriod',5);
trainedNet = trainNetwork(augimdsTrain, lgraph_1, options);
%% 评估
% 准确率
% 训练集
[YPred,probs] = classify(trainedNet,augimdsTrain);
accuracy = mean(YPred == imdsTrain.Labels)
disp("training acc: " + accuracy*100 + "%")
% 验证集
[YPred,probs] = classify(trainedNet,augimdsValidation);
accuracy = mean(YPred == imdsValidation.Labels)
disp("val acc: " + accuracy*100 + "%")
% 混淆矩阵
figure('Units','normalized','Position',[0.2 0.2 0.4 0.4]);
cm = confusionchart(imdsValidation.Labels,YPred);
cm.Title = 'Confusion Matrix for Validation Data';
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';