In routine CNN image training, both MATLAB and PyTorch ship plenty of standard examples, but their input is a single N*N*3 image. In some scenarios, researchers want the model to draw on imaging data from multiple viewing angles or different sources and produce one combined output, so that more feature detail is taken into account and prediction accuracy improves. This article follows the training method of the paper 'Noninvasive Detection of Salt Stress in Cotton Seedlings by Combining Multicolor Fluorescence–Multispectral Reflectance Imaging with EfficientNet-OB2', which feeds nine different data sources as separate inputs into an improved EfficientNet and thereby achieves better results.
Figure: how the inputs are distinguished in the original paper, merged input (left) versus split inputs (right).
Paper source:
Noninvasive Detection of Salt Stress in Cotton Seedlings by Combining Multicolor Fluorescence–Multispectral Reflectance Imaging with EfficientNet-OB2, Plant Phenomics. https://spj.science.org/doi/10.34133/plantphenomics.0125#sec-1
The script first builds the 9-input improved EfficientNet, then uses imageDatastore and arrayDatastore to wrap the image data and the label data, and finally uses combine to merge them into a single multi-source input datastore. Training then proceeds with the usual trainNetwork call. Note that the files in each source (input) folder must be ordered so that they correspond one-to-one; if the orderings differ, the samples will be mismatched after combine, so it is recommended to give corresponding files in every folder the same name (an optional alignment check is shown before the combine call below).
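Build_ENetOB2 is the author's own helper and is not reproduced here. As a rough orientation only, the following minimal sketch (three toy inputs instead of nine; layer sizes, names, and the simple convolution branches are illustrative assumptions, not the EfficientNet-OB2 architecture) shows the general pattern such a helper follows in MATLAB: one imageInputLayer per source, the branches merged by a concatenation layer, then shared layers.

lg = layerGraph();
for k = 1:3                                  % three inputs shown for brevity
    branch = [
        imageInputLayer([224 224 3],"Name","in"+k,"Normalization","none")
        convolution2dLayer(3,16,"Padding","same","Name","conv"+k)
        reluLayer("Name","relu"+k)];
    lg = addLayers(lg,branch);
end
tail = [
    concatenationLayer(3,3,"Name","concat")  % merge the branches along the channel dimension
    globalAveragePooling2dLayer("Name","gap")
    fullyConnectedLayer(2,"Name","fc")
    softmaxLayer("Name","sm")
    classificationLayer("Name","out")];
lg = addLayers(lg,tail);
for k = 1:3
    lg = connectLayers(lg,"relu"+k,"concat/in"+k);
end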
lgraph=Build_ENetOB2(N,PC,9); % build the 9-input improved EfficientNet (EfficientNet-OB2) layer graph with the author's helper function
%% Training set
imdsTrainPC1 = imageDatastore("C:\Users\94562\Desktop\cotton\Environment_train\PC1","IncludeSubfolders",true,"LabelSource","foldernames");
imdsTrainPC2 = imageDatastore("C:\Users\94562\Desktop\cotton\Environment_train\PC2","IncludeSubfolders",true,"LabelSource","foldernames");
imdsTrainPC3 = imageDatastore("C:\Users\94562\Desktop\cotton\Environment_train\PC3","IncludeSubfolders",true,"LabelSource","foldernames");
imdsTrainPC4 = imageDatastore("C:\Users\94562\Desktop\cotton\Environment_train\PC4","IncludeSubfolders",true,"LabelSource","foldernames");
imdsTrainPC5 = imageDatastore("C:\Users\94562\Desktop\cotton\Environment_train\PC5","IncludeSubfolders",true,"LabelSource","foldernames");
imdsTrainPC6 = imageDatastore("C:\Users\94562\Desktop\cotton\Environment_train\PC6","IncludeSubfolders",true,"LabelSource","foldernames");
imdsTrainPC7 = imageDatastore("C:\Users\94562\Desktop\cotton\Environment_train\PC7","IncludeSubfolders",true,"LabelSource","foldernames");
imdsTrainPC8 = imageDatastore("C:\Users\94562\Desktop\cotton\Environment_train\PC8","IncludeSubfolders",true,"LabelSource","foldernames");
imdsTrainPC9 = imageDatastore("C:\Users\94562\Desktop\cotton\Environment_train\PC9","IncludeSubfolders",true,"LabelSource","foldernames");
imdsTrainM=arrayDatastore(imdsTrainPC1.Labels); % wrap the training labels (shared by all nine sources) as the response datastore
%% Validation set
TimdsTrainPC1 = imageDatastore("C:\Users\94562\Desktop\cotton\Environment_450_300_check\PC1","IncludeSubfolders",true,"LabelSource","foldernames");
TimdsTrainPC2 = imageDatastore("C:\Users\94562\Desktop\cotton\Environment_450_300_check\PC2","IncludeSubfolders",true,"LabelSource","foldernames");
TimdsTrainPC3 = imageDatastore("C:\Users\94562\Desktop\cotton\Environment_450_300_check\PC3","IncludeSubfolders",true,"LabelSource","foldernames");
TimdsTrainPC4 = imageDatastore("C:\Users\94562\Desktop\cotton\Environment_450_300_check\PC4","IncludeSubfolders",true,"LabelSource","foldernames");
TimdsTrainPC5 = imageDatastore("C:\Users\94562\Desktop\cotton\Environment_450_300_check\PC5","IncludeSubfolders",true,"LabelSource","foldernames");
TimdsTrainPC6 = imageDatastore("C:\Users\94562\Desktop\cotton\Environment_450_300_check\PC6","IncludeSubfolders",true,"LabelSource","foldernames");
TimdsTrainPC7 = imageDatastore("C:\Users\94562\Desktop\cotton\Environment_450_300_check\PC7","IncludeSubfolders",true,"LabelSource","foldernames");
TimdsTrainPC8 = imageDatastore("C:\Users\94562\Desktop\cotton\Environment_450_300_check\PC8","IncludeSubfolders",true,"LabelSource","foldernames");
TimdsTrainPC9 = imageDatastore("C:\Users\94562\Desktop\cotton\Environment_450_300_check\PC9","IncludeSubfolders",true,"LabelSource","foldernames");
TimdsTrainM=arrayDatastore(TimdsTrainPC1.Labels); % wrap the validation labels the same way
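% Optional sanity check (not part of the original script): combine pairs the k-th
% file of every source, so the file names in the nine training folders must line up.
% A minimal sketch assuming corresponding files share the same base name:
srcs = {imdsTrainPC1,imdsTrainPC2,imdsTrainPC3,imdsTrainPC4,imdsTrainPC5,...
    imdsTrainPC6,imdsTrainPC7,imdsTrainPC8,imdsTrainPC9};
[~,ref] = cellfun(@fileparts,srcs{1}.Files,'UniformOutput',false);
for k = 2:numel(srcs)
    [~,nm] = cellfun(@fileparts,srcs{k}.Files,'UniformOutput',false);
    assert(isequal(ref,nm),'PC%d file names do not match PC1',k);
end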
imdsTrainPC=combine(imdsTrainPC1,imdsTrainPC2,imdsTrainPC3,imdsTrainPC4,imdsTrainPC5,imdsTrainPC6,imdsTrainPC7,imdsTrainPC8,imdsTrainPC9,imdsTrainM,ReadOrder='associated');
read(imdsTrainPC) % preview one combined observation (nine images plus its label) to verify the pairing
TimdsTrainPC=combine(TimdsTrainPC1,TimdsTrainPC2,TimdsTrainPC3,TimdsTrainPC4,TimdsTrainPC5,TimdsTrainPC6,TimdsTrainPC7,TimdsTrainPC8,TimdsTrainPC9,TimdsTrainM,ReadOrder='associated');
%[imdsTrain, imdsValidation] = splitEachLabel(imdsTrainPC,0.7,"randomized");
opts = trainingOptions("adam",...
"ExecutionEnvironment","gpu",...
"InitialLearnRate",0.01,...
"MaxEpochs",300,...
"Shuffle","every-epoch",...
"Plots","training-progress",'MiniBatchSize',128,'ValidationData',TimdsTrainPC,'ValidationFrequency',50);
[net, traininfo] = trainNetwork(imdsTrainPC,lgraph,opts);
%resultN=predict(net,TimdsTrainPC);
Yt = TimdsTrainPC1.Labels;       % ground-truth labels of the validation set
Yp = classify(net,TimdsTrainPC); % predicted labels from the trained multi-input network
figure
confusionchart(Yt,Yp);
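As a small addition that is not part of the original script, the overall validation accuracy can be reported alongside the confusion chart by comparing the two label vectors directly:

acc = mean(Yp == Yt);                             % fraction of correctly classified validation samples
fprintf("Validation accuracy: %.2f%%\n",100*acc);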