本文介绍如何基于 HALCON 实现图像分类。本人对 HALCON 自带的例程稍作修改,使训练过程更容易理解。
本文使用中国象棋棋子图像作为数据集,目标是对棋子进行分类,识别出各类棋子。数据集如图所示:
每一类棋子的图像都按上图所示的文件夹结构存放(一个文件夹对应一个类别,文件夹名即类别名);图像分类不需要额外的标注文件。
一些重要变量如下:
* ---- Model selection ----
* Pretrained network files provided by HALCON (three variants):
* index 0 = compact, 1 = enhanced, 2 = resnet50.
Base_ModelFile := ['./DL_BaseModel/BaseModel_cls_compact.dat','./DL_BaseModel/BaseModel_cls_enhanced.dat','./DL_BaseModel/BaseModel_cls_resnet50.dat']
* Index into Base_ModelFile choosing which network is fine-tuned.
ModelType := 0
* File the best model (lowest validation error) is written to during training.
ModelFile := './best_Cls.dat'
* ---- Data locations ----
* Directory with the raw training images.
ImageDir := './DataSet/DataImage_Cls/DataImage'
* Directory with the validation/test images.
ValDir := './DataSet/DataImage_Cls/test'
* Scratch directory receiving the preprocessed training images.
OutPutDir := './DataSet/Temp_ClsDataSet'
* ---- Training hyper-parameters ----
* Side length (pixels) every image is resized to before entering the network.
ImageSize := 224
* Number of images processed per training iteration (batch size).
BatchSize := 4
* Learning rate.
learning_rate := 0.0001
* Number of passes over the whole training set.
NumEpochs := 100
按上述目录结构重新生成图像文件,即对图像进行预处理(缩放、灰度拉伸、转换为 real 类型)并保存:
* Preprocess the raw training images and store each one as
* OutPutDir/<label>/<basename>.hobj, so that a later
* read_dl_classifier_data_set with 'last_folder' can recover the label
* from the folder name.
* Assumes RawImageFiles and Labels (one label per file) were obtained
* earlier in the script -- TODO confirm.
*
* write_object does not create missing directories, so make sure the
* output directory and one sub-directory per class exist before writing.
file_exists (OutPutDir, OutPutDirExists)
if (not OutPutDirExists)
    make_dir (OutPutDir)
endif
UniqueLabels := uniq(sort(Labels))
for LabelIndex := 0 to |UniqueLabels| - 1 by 1
    LabelDir := OutPutDir + '/' + UniqueLabels[LabelIndex]
    file_exists (LabelDir, LabelDirExists)
    if (not LabelDirExists)
        make_dir (LabelDir)
    endif
endfor
parse_filename (RawImageFiles, BaseNames, Extensions, Directories)
ObjectFilesOut := OutPutDir + '/' + Labels + '/' + BaseNames + '.hobj'
for I := 0 to |RawImageFiles| - 1 by 1
    read_image (Image, RawImageFiles[I])
    * Resize to the network input size, spread the gray values with
    * scale_image_max and convert to 'real' (the same steps must be
    * applied to images at inference time).
    zoom_image_size (Image, ImageZoom, ImageSize, ImageSize, 'constant')
    scale_image_max (ImageZoom, ImageZoom)
    convert_image_type (ImageZoom, ImageZoom, 'real')
    * Save the preprocessed image to disk.
    write_object (ImageZoom, ObjectFilesOut[I])
endfor
dev_clear_window ()
* Load the pretrained network selected by ModelType from its serialized file.
open_file (Base_ModelFile[ModelType], 'input_binary', FileHandle)
fread_serialized_item (FileHandle, SerializedItemHandle)
close_file (FileHandle)
deserialize_dl_classifier (SerializedItemHandle, DLClassifierHandle)
* Apply the configured hyper-parameters to the classifier; without these
* calls BatchSize and learning_rate have no effect on training.
* NOTE(review): the training loop passes BatchSize images per call to
* train_dl_classifier_batch, so the classifier's 'batch_size' must match it.
* If the omitted part of the script already sets these, this is redundant
* but harmless.
set_dl_classifier_param (DLClassifierHandle, 'batch_size', BatchSize)
set_dl_classifier_param (DLClassifierHandle, 'learning_rate', learning_rate)
读取预处理后的数据集,并将其划分为训练集和验证集:
* Read the preprocessed data set; with 'last_folder' the label of each
* image is the name of the folder containing it.
read_dl_classifier_data_set (OutPutDir, 'last_folder', ImageFiles, Labels, LabelsIndices, Classes)
* The classifier must know the class names before training can start.
set_dl_classifier_param (DLClassifierHandle, 'classes', Classes)
* 85 % training / 15 % validation. The two percentages sum to 100, so the
* test split (TestImages/TestLabels) stays empty; separate test images are
* kept in ValDir instead.
TrainingPercent := 85
ValidationPercent := 15
split_dl_classifier_data_set (ImageFiles, Labels, TrainingPercent, ValidationPercent, TrainingImages, TrainingLabels, ValidationImages, ValidationLabels, TestImages, TestLabels)
* Training loop: per epoch, shuffle the training order and train on
* consecutive batches of BatchSize images; at selected iterations evaluate
* the top-1 error and save the classifier when validation error improves.
* Relies on variables initialized outside this excerpt: TrainSequence
* (indices into TrainingImages), NumBatchesInEpoch, PlottedIterations,
* MinValidationError, LossByIteration, TrainingErrors, ValidationErrors,
* TrainingImagesSelected and TrainingLabelsSelected -- TODO confirm setup.
for Epoch := 0 to NumEpochs - 1 by 1
* Reshuffle the training order at the start of every epoch.
tuple_shuffle (TrainSequence, TrainSequence)
for Iteration := 0 to NumBatchesInEpoch - 1 by 1
* Take BatchSize consecutive entries of the shuffled sequence
* (HDevelop tuple slices [a:b] include both ends).
BatchStart := Iteration * BatchSize
BatchEnd := BatchStart + (BatchSize - 1)
BatchIndices := TrainSequence[BatchStart:BatchEnd]
BatchImageFiles := TrainingImages[BatchIndices]
BatchLabels := TrainingLabels[BatchIndices]
* Load the preprocessed images of this batch and apply the 'mirror'
* augmentation with value 'rc'.
read_image (BatchImages, BatchImageFiles)
augment_images (BatchImages, BatchImages, 'mirror', 'rc')
* One training step. If it throws, the identical call is retried once;
* a second failure propagates out of the loop.
* NOTE(review): the first exception is discarded without logging --
* presumably a workaround for transient errors; confirm it is intended.
try
train_dl_classifier_batch (BatchImages, DLClassifierHandle, BatchLabels, DLClassifierTrainResultHandle)
catch (Exception)
train_dl_classifier_batch (BatchImages, DLClassifierHandle, BatchLabels, DLClassifierTrainResultHandle)
endtry
* Record the loss of this iteration.
get_dl_classifier_train_result (DLClassifierTrainResultHandle, 'loss', Loss)
LossByIteration := [LossByIteration,Loss]
* Iteration counter across all epochs.
CurrentIteration := int(Iteration + (NumBatchesInEpoch * Epoch))
* Evaluate only when CurrentIteration is one of PlottedIterations.
if (sum(CurrentIteration [==] PlottedIterations))
* Top-1 error on the selected training subset and on the validation set.
apply_dl_classifier_batchwise (TrainingImagesSelected, DLClassifierHandle, TrainingDLClassifierResultIDs, TrainingPredictedLabels, TrainingConfidences)
apply_dl_classifier_batchwise (ValidationImages, DLClassifierHandle, ValidationDLClassifierResultIDs, ValidationPredictedLabels, ValidationConfidences)
evaluate_dl_classifier (TrainingLabelsSelected, DLClassifierHandle, TrainingDLClassifierResultIDs, 'top1_error', 'global', TrainingTop1Error)
evaluate_dl_classifier (ValidationLabels, DLClassifierHandle, ValidationDLClassifierResultIDs, 'top1_error', 'global', ValidationTop1Error)
* Append to the error history.
TrainingErrors := [TrainingErrors,TrainingTop1Error]
ValidationErrors := [ValidationErrors,ValidationTop1Error]
* Save the classifier to ModelFile whenever the validation error is not
* worse than the best so far ('<=' means ties also overwrite the file).
if (ValidationTop1Error <= MinValidationError)
serialize_dl_classifier (DLClassifierHandle, SerializedItemHandle)
open_file (ModelFile, 'output_binary', FileHandle)
fwrite_serialized_item (FileHandle, SerializedItemHandle)
close_file (FileHandle)
MinValidationError := ValidationTop1Error
endif
endif
endfor
endfor
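注意:TrainingConfidences 和 ValidationConfidences 这两个变量只保存最近一次评估时各图像的预测置信度;训练过程中的损失值记录在 LossByIteration 中,每次评估得到的训练集/验证集 top-1 错误率分别记录在 TrainingErrors 和 ValidationErrors 中。如有需要,可以根据这些数据绘制训练曲线。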
* Classify every test image and display the predicted class.
* Test images must go through the same preprocessing as the training
* images (resize to ImageSize x ImageSize, scale_image_max, conversion to
* 'real'); otherwise the network receives inputs of a different size and
* value range than it was trained on.
* Assumes test_img_file holds the test image paths (set outside this
* excerpt) -- TODO confirm.
* NOTE(review): this uses the classifier state after the last epoch, not
* the best model saved to ModelFile; load ModelFile first if that is intended.
for Index := 0 to |test_img_file| - 1 by 1
    ImageFile := test_img_file[Index]
    read_image (Image, ImageFile)
    zoom_image_size (Image, ImageZoom, ImageSize, ImageSize, 'constant')
    scale_image_max (ImageZoom, ImageZoom)
    convert_image_type (ImageZoom, ImageZoom, 'real')
    apply_dl_classifier (ImageZoom, DLClassifierHandle, DLClassifierResultHandle)
    get_dl_classifier_result (DLClassifierResultHandle, 'all', 'predicted_classes', PredictedClass)
    * Show the original image together with the predicted class.
    dev_display (Image)
    Text := 'Predicted class: ' + PredictedClass
    dev_disp_text (Text, 'window', 'top', 'left', 'white', 'box', 'false')
    stop ()
endfor
完整代码可私信获取。