该函数用于读取数据集（data_resize 是我的数据集文件夹名，训练时改成你自己的文件夹名即可）
function [dataTrain,dataTest] = merchData(folder, trainFraction)
%MERCHDATA Load a labelled image dataset and split it into train/test sets.
%   [dataTrain,dataTest] = MERCHDATA() reads every image under the folder
%   'data_resize' (searching sub-folders; each image's label is the name of
%   the folder it is in). It puts 70% of the images of each label into
%   dataTrain and the remainder into dataTest, then shuffles dataTrain.
%
%   [dataTrain,dataTest] = MERCHDATA(folder, trainFraction) uses the given
%   dataset folder and per-label training proportion instead, so the
%   folder name no longer has to be edited in this file. Passing [] for
%   either argument keeps its default.
%
%   Inputs:
%     folder        - dataset root folder (default 'data_resize')
%     trainFraction - proportion of each label used for training,
%                     passed to splitEachLabel (default 0.7)
%   Outputs:
%     dataTrain - imageDatastore with the training images (shuffled)
%     dataTest  - imageDatastore with the remaining images (not shuffled)
%
%   Note: splitEachLabel is called without the 'randomized' option, so each
%   label's files are assigned in datastore order. The split is therefore
%   deterministic rather than random.
if nargin < 1 || isempty(folder)
    folder = 'data_resize';
end
if nargin < 2 || isempty(trainFraction)
    trainFraction = 0.7;
end
% unzip(fullfile(matlabroot,'examples','nnet','MerchData.zip'));
data = imageDatastore(folder, ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');
[dataTrain,dataTest] = splitEachLabel(data,trainFraction);
% Only the training set is shuffled; the test set keeps its file order.
dataTrain = shuffle(dataTrain);
end
下面的代码为主程序（一个脚本，会调用上面的 merchData 函数）
% This script builds a ResNet-50 based image classifier by transfer learning.

% Read the dataset and split it into training / test image datastores.
[merchImagesTrain,merchImagesTest] = merchData();
% Number of classes = number of label categories (displayed in the Command Window).
numClasses = numel(categories(merchImagesTrain.Labels))

% Pretrained ResNet-50. Its last three layers are dropped so that a new
% classification head can be appended (presumably these are the original
% 1000-class fc / softmax / classification layers -- TODO confirm).
net = resnet50();
layer = net.Layers(1:numel(net.Layers)-3);
% Append a new head sized for this dataset: fully connected -> softmax ->
% classification output. Then wrap the layer array in a layer graph.
% layerGraph(layerArray) connects the layers strictly in array order, so the
% residual shortcut connections of ResNet-50 are lost at this point; the
% code that follows removes and re-wires the shortcut layers.
layers = [ ...
    layer; ...
    fullyConnectedLayer(numClasses,'Name','fc3', ...
        'WeightLearnRateFactor',1,'BiasLearnRateFactor',1); ...
    softmaxLayer('Name','fc3_softmax'); ...
    classificationLayer('Name','ClassificationLayer_fc3')];
lgraph = layerGraph(layers);
figure
plot(lgraph)
% Take the projection-shortcut layers (res*a_branch1 convolution and
% bn*a_branch1 batch norm of blocks res2a, res3a, res4a, res5a) out of the
% graph. In the sequentially chained graph they sit on the main path.
lgraph = removeLayers(lgraph, { ...
    'res2a_branch1','bn2a_branch1', ...
    'res3a_branch1','bn3a_branch1', ...
    'res4a_branch1','bn4a_branch1', ...
    'res5a_branch1','bn5a_branch1'});
figure
plot(lgraph)
% Rebuild a graph from the remaining layer array. layerGraph chains the
% layers in array order again, which closes the gaps left by the removals.
layers_1 = lgraph.Layers;
lgraph_1 = layerGraph(layers_1);
figure
plot(lgraph_1)
% Add the four projection-shortcut branches (1x1 convolution + batch norm
% for blocks res2a, res3a, res4a, res5a) back into lgraph_1 as unconnected
% layers; the connectLayers calls that follow wire them up.
%
% The layers are taken from the pretrained network rather than being
% created with convolution2dLayer / batchNormalizationLayer. Newly created
% layers have empty weights (initialised only when training starts) and no
% batch-norm statistics, which would discard the pretrained parameters of
% all four shortcut paths. Their names are unchanged; their filter counts
% and strides are presumably the same as those the layers were previously
% rebuilt with (256/512/1024/2048 filters, stride 1/2/2/2) -- confirm with
% analyzeNetwork(net) if in doubt.
pretrainedLayers = net.Layers;
pretrainedNames = arrayfun(@(L) L.Name, pretrainedLayers, 'UniformOutput', false);
% Look up a layer of the pretrained network by its name.
takePretrained = @(name) pretrainedLayers(strcmp(pretrainedNames, name));

res2a_branch1 = takePretrained('res2a_branch1');
bn2a_branch1  = takePretrained('bn2a_branch1');
% ---------------------------------------------------------------------------
res3a_branch1 = takePretrained('res3a_branch1');
bn3a_branch1  = takePretrained('bn3a_branch1');
% ---------------------------------------------------------------------------
res4a_branch1 = takePretrained('res4a_branch1');
bn4a_branch1  = takePretrained('bn4a_branch1');
% ---------------------------------------------------------------------------
res5a_branch1 = takePretrained('res5a_branch1');
bn5a_branch1  = takePretrained('bn5a_branch1');

% Add each layer on its own. Passing them together as an array would make
% addLayers chain them, and the explicit connectLayers calls below would
% then fail on an already-existing connection.
lgraph_1 = addLayers(lgraph_1,res2a_branch1);
lgraph_1 = addLayers(lgraph_1,bn2a_branch1);
lgraph_1 = addLayers(lgraph_1,res3a_branch1);
lgraph_1 = addLayers(lgraph_1,bn3a_branch1);
lgraph_1 = addLayers(lgraph_1,res4a_branch1);
lgraph_1 = addLayers(lgraph_1,bn4a_branch1);
lgraph_1 = addLayers(lgraph_1,res5a_branch1);
lgraph_1 = addLayers(lgraph_1,bn5a_branch1);
figure;plot(lgraph_1)
% Re-create the residual shortcut connections. Only the second input (in2)
% of each addition layer is connected here; in1 is assumed to already be
% fed by the sequential chain that layerGraph built.
%   - First block of each stage (add_1, add_4, add_8, add_14): the block
%     input goes through the 1x1 conv + batch-norm projection branch.
%   - Every other block: the activation feeding the block goes straight
%     to in2 (identity shortcut).
% Each row is {source layer, destination layer[/input]}, applied in order.
shortcutEdges = {
    'max_pooling2d_1',    'res2a_branch1'   % res2a projection shortcut
    'res2a_branch1',      'bn2a_branch1'
    'bn2a_branch1',       'add_1/in2'
    'activation_4_relu',  'add_2/in2'
    'activation_7_relu',  'add_3/in2'
    'activation_10_relu', 'res3a_branch1'   % res3a projection shortcut
    'res3a_branch1',      'bn3a_branch1'
    'bn3a_branch1',       'add_4/in2'
    'activation_13_relu', 'add_5/in2'
    'activation_16_relu', 'add_6/in2'
    'activation_19_relu', 'add_7/in2'
    'activation_22_relu', 'res4a_branch1'   % res4a projection shortcut
    'res4a_branch1',      'bn4a_branch1'
    'bn4a_branch1',       'add_8/in2'
    'activation_25_relu', 'add_9/in2'
    'activation_28_relu', 'add_10/in2'
    'activation_31_relu', 'add_11/in2'
    'activation_34_relu', 'add_12/in2'
    'activation_37_relu', 'add_13/in2'
    'activation_40_relu', 'res5a_branch1'   % res5a projection shortcut
    'res5a_branch1',      'bn5a_branch1'
    'bn5a_branch1',       'add_14/in2'
    'activation_43_relu', 'add_15/in2'
    'activation_46_relu', 'add_16/in2'
    };
for edgeIdx = 1:size(shortcutEdges,1)
    lgraph_1 = connectLayers(lgraph_1, shortcutEdges{edgeIdx,1}, shortcutEdges{edgeIdx,2});
end
figure
plot(lgraph_1)
% Fine-tune the re-wired network with SGDM (small batches, low learning
% rate), then report the fraction of correctly classified test images.
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.0001, ...
    'MaxEpochs',10, ...
    'MiniBatchSize',5);
netTransfer = trainNetwork(merchImagesTrain, lgraph_1, options);

testLabels = merchImagesTest.Labels;
predictedLabels = classify(netTransfer, merchImagesTest);
% Test accuracy, displayed in the Command Window.
accuracy = mean(predictedLabels == testLabels)