目录
1-- 环境搭建
2-- 数据集划分
3-- 训练模型
4-- 推理测试
安装 PaddleOCR(可参考 PaddleOCR 官方安装文档)
① 创建环境
# Create a dedicated conda environment with Python 3.8, then switch into it.
conda create --name paddle_env python=3.8
conda activate paddle_env
② 安装paddlepaddle
# Switch the shell to CUDA 11.1 (adjust the version/path to your machine).
# ~/.bashrc belongs to your own user, so no sudo is needed to edit it.
# Example lines to add (assumes CUDA is installed under /usr/local/cuda-11.1):
#   export PATH=/usr/local/cuda-11.1/bin:$PATH
#   export LD_LIBRARY_PATH=/usr/local/cuda-11.1/lib64:$LD_LIBRARY_PATH
gedit ~/.bashrc
source ~/.bashrc
# Install the PaddlePaddle GPU build for CUDA 11.1 (the ".post111" suffix).
python -m pip install paddlepaddle-gpu==2.3.0.post111 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
③ 安装依赖
# Install the PaddleOCR wheel. `python -m pip` (as used for paddlepaddle above)
# guarantees the package goes into the currently active interpreter/env.
python -m pip install "paddleocr>=2.0.1"
# Layout-analysis dependency: the layoutparser wheel hosted by PaddleOCR.
python -m pip install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
# 本文采用 CCPD 数据集中的 challenge 子集(ccpd_challenge)
# 其中 45003 张图片用作训练集,5000 张图片用作验证集
数据集划分代码(从 ccpd_challenge 中移出 5000 张图片作为验证集,其余留作训练集):
import os
import shutil
from itertools import islice


def remove_file(old_path, new_path, num_files=5000):
    """Move up to ``num_files`` entries from ``old_path`` into ``new_path``.

    Used here to carve a validation split out of the CCPD challenge images:
    the moved files form the validation set and whatever stays in
    ``old_path`` is the training set.

    Args:
        old_path: Source directory.
        new_path: Destination directory; created if it does not exist yet.
        num_files: How many entries to move (default 5000, the validation
            size used in this note).

    Note:
        ``os.listdir`` returns bare names (no directory part) in an arbitrary,
        filesystem-dependent order, so which files are moved is not
        guaranteed to be the same on another machine.
    """
    # shutil.move fails with FileNotFoundError when the destination folder
    # is missing, so make sure it exists first.
    os.makedirs(new_path, exist_ok=True)
    for name in islice(os.listdir(old_path), num_files):
        shutil.move(os.path.join(old_path, name), os.path.join(new_path, name))


if __name__ == '__main__':
    remove_file(r"/civi/Chinese_license_plate_Note/detection/dataset/ccpd_challenge", r"/civi/Chinese_license_plate_Note/detection/dataset/test_dataset")
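# 创建 label 标注文件
标注格式参考 PaddleOCR 文本检测数据集说明:每行为“图片文件名 + 制表符(\t)+ 标注 JSON”,例如 `xxx.jpg\t[{"transcription": "车牌号", "points": [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]}]`,四个顶点需从左上角开始按顺时针排列。
代码样例: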
import json
import os

# CCPD character tables: the first index of a plate selects the province
# from con_list; every following index selects from words_list.
words_list = [
    "A", "B", "C", "D", "E",
    "F", "G", "H", "J", "K",
    "L", "M", "N", "P", "Q",
    "R", "S", "T", "U", "V",
    "W", "X", "Y", "Z", "0",
    "1", "2", "3", "4", "5",
    "6", "7", "8", "9",
]
con_list = [
    "皖", "沪", "津", "渝", "冀",
    "晋", "蒙", "辽", "吉", "黑",
    "苏", "浙", "京", "闽", "赣",
    "鲁", "豫", "鄂", "湘", "粤",
    "桂", "琼", "川", "贵", "云",
    "藏", "陕", "甘", "青", "宁",  # index 25 is Tibet, whose plate prefix is 藏
    "新",
]

# Folder of CCPD images to label and the label file to (re)write.
IMAGE_DIR = '/civi/Chinese_license_plate_Note/detection/dataset/test_dataset/'
LABEL_FILE = '/civi/Chinese_license_plate_Note/detection/dataset/' + 'test_label.txt'


def _parse_point(text):
    """Convert a CCPD vertex string such as ``"154&383"`` into ``[154, 383]``."""
    x, y = text.split('&')
    return [int(x), int(y)]


def parse_ccpd_name(file_name):
    """Decode the plate text and the four vertices from a CCPD file name.

    A CCPD name has seven '-'-separated fields. The fourth holds the four
    vertices ("x&y" joined by '_'), and the fifth holds the character indices
    (joined by '_').

    NOTE: this first version keeps CCPD's native vertex order, which starts
    at the bottom-right corner. PaddleOCR expects the top-left corner first;
    that is the mistake pointed out right after this snippet.

    Returns:
        ``(transcription, points)``, where points is a list of ``[x, y]`` lists.

    Raises:
        ValueError: if the name or a vertex has the wrong number of fields.
        IndexError: if a character index is outside the tables above.
    """
    _, _, _, vertices, plate, _, _ = file_name.split('-')
    points = [_parse_point(v) for v in vertices.split('_')]
    indices = [int(i) for i in plate.split('_')]
    transcription = con_list[indices[0]] + ''.join(words_list[i] for i in indices[1:])
    return transcription, points


def make_label_line(file_name):
    """Build one PaddleOCR detection label line (``name\\t<json>\\n``)."""
    transcription, points = parse_ccpd_name(file_name)
    annotation = [{"transcription": transcription, "points": points}]
    # ensure_ascii=False keeps the Chinese province character readable.
    return file_name + '\t' + json.dumps(annotation, ensure_ascii=False) + '\n'


def write_label_file(image_dir=IMAGE_DIR, label_file=LABEL_FILE):
    """Write one label line for every CCPD image found in ``image_dir``.

    The label file is opened once in 'w' mode, so re-running the script
    rewrites it instead of appending duplicate lines. Entries that do not
    have CCPD's seven '-'-separated fields (e.g. stray system files) are
    skipped instead of crashing the run.
    """
    with open(label_file, 'w', encoding='UTF-8') as f:
        for item in sorted(os.listdir(image_dir)):
            if item.count('-') != 6:
                continue
            f.write(make_label_line(item))


if __name__ == "__main__":
    write_label_file(IMAGE_DIR, LABEL_FILE)
上述代码有一个错误:CCPD 文件名中的四个顶点坐标从右下角开始按顺时针排列(右下、左下、左上、右上),而 PaddleOCR 检测标注要求从左上角开始按顺时针排列(左上、右上、右下、左下)。因此需要按下标 [2, 3, 0, 1] 的顺序重排这四个点,修改后的代码如下:
import json
import os

# CCPD character tables: the first index of a plate selects the province
# from con_list; every following index selects from words_list.
words_list = [
    "A", "B", "C", "D", "E",
    "F", "G", "H", "J", "K",
    "L", "M", "N", "P", "Q",
    "R", "S", "T", "U", "V",
    "W", "X", "Y", "Z", "0",
    "1", "2", "3", "4", "5",
    "6", "7", "8", "9",
]
con_list = [
    "皖", "沪", "津", "渝", "冀",
    "晋", "蒙", "辽", "吉", "黑",
    "苏", "浙", "京", "闽", "赣",
    "鲁", "豫", "鄂", "湘", "粤",
    "桂", "琼", "川", "贵", "云",
    "藏", "陕", "甘", "青", "宁",  # index 25 is Tibet, whose plate prefix is 藏
    "新",
]

# Folder of CCPD images to label and the label file to (re)write.
IMAGE_DIR = '/civi/Chinese_license_plate_Note/detection/dataset/test_dataset/'
LABEL_FILE = '/civi/Chinese_license_plate_Note/detection/dataset/' + 'test_label.txt'


def _parse_point(text):
    """Convert a CCPD vertex string such as ``"154&383"`` into ``[154, 383]``."""
    x, y = text.split('&')
    return [int(x), int(y)]


def parse_ccpd_name(file_name):
    """Decode the plate text and the four vertices from a CCPD file name.

    A CCPD name has seven '-'-separated fields. The fourth holds the four
    vertices ("x&y" joined by '_'), listed from the bottom-right corner:
    bottom-right, bottom-left, top-left, top-right. The fifth holds the
    character indices (joined by '_').

    Returns:
        ``(transcription, points)``, where points holds four ``[x, y]`` lists
        ordered clockwise from the top-left corner, as PaddleOCR expects.

    Raises:
        ValueError: if the name does not have seven fields, or the vertex
            field does not hold exactly four "x&y" pairs.
        IndexError: if a character index is outside the tables above.
    """
    _, _, _, vertices, plate, _, _ = file_name.split('-')
    rb, lb, lt, rt = (_parse_point(v) for v in vertices.split('_'))
    indices = [int(i) for i in plate.split('_')]
    transcription = con_list[indices[0]] + ''.join(words_list[i] for i in indices[1:])
    # Reorder to top-left, top-right, bottom-right, bottom-left.
    return transcription, [lt, rt, rb, lb]


def make_label_line(file_name):
    """Build one PaddleOCR detection label line (``name\\t<json>\\n``)."""
    transcription, points = parse_ccpd_name(file_name)
    annotation = [{"transcription": transcription, "points": points}]
    # ensure_ascii=False keeps the Chinese province character readable.
    return file_name + '\t' + json.dumps(annotation, ensure_ascii=False) + '\n'


def write_label_file(image_dir=IMAGE_DIR, label_file=LABEL_FILE):
    """Write one label line for every CCPD image found in ``image_dir``.

    The label file is opened once in 'w' mode, so re-running the script
    (e.g. after the earlier buggy version) rewrites it instead of appending
    duplicate lines. Entries that do not have CCPD's seven '-'-separated
    fields (e.g. stray system files) are skipped instead of crashing the run.
    """
    with open(label_file, 'w', encoding='UTF-8') as f:
        for item in sorted(os.listdir(image_dir)):
            if item.count('-') != 6:
                continue
            f.write(make_label_line(item))


if __name__ == "__main__":
    write_label_file(IMAGE_DIR, LABEL_FILE)
训练参考
① 下载预训练模型(DB 文本检测模型)
本文使用 ch_ppocr_server_v2.0_det_train,可在 PaddleOCR 官方模型列表中下载,解压后将其路径填入配置文件的 Global.pretrained_model。
② 配置 config 文件
示例如下(带注释的字段需要根据自己的环境检查并修改):
Global:
  use_gpu: true  # train on GPU
  epoch_num: 200  # number of training epochs
  log_smooth_window: 20
  print_batch_step: 2
  save_model_dir: /civi/Chinese_license_plate_Note/detection/Models_Well_trained/200epochs  # where checkpoints are saved
  save_epoch_step: 10  # save a checkpoint every N epochs
  # evaluation starts at iteration 3000 and then runs every 2000 iterations
  eval_batch_step: [3000, 2000]
  cal_metric_during_train: False
  pretrained_model: /civi/Chinese_license_plate_Note/detection/pretrain/ch_ppocr_server_v2.0_det_train/best_accuracy  # pretrained weights
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img: /civi/Chinese_license_plate_Note/detection/test_img/test6.22.png  # test image
  # NOTE(review): this points at the 100epochs folder while save_model_dir uses 200epochs; align them if you rely on this file
  save_res_path: /civi/Chinese_license_plate_Note/detection/Models_Well_trained/100epochs/det_db/predicts_db.txt

Architecture:
  model_type: det
  algorithm: DB
  Transform:
  Backbone:
    name: ResNet
    layers: 18
    disable_se: True
  Neck:
    name: DBFPN
    out_channels: 256
  Head:
    name: DBHead
    k: 50

Loss:
  name: DBLoss
  balance_loss: true
  main_loss_type: DiceLoss
  alpha: 5
  beta: 10
  ohem_ratio: 3

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Cosine
    learning_rate: 0.001
    warmup_epoch: 2
  regularizer:
    name: 'L2'
    factor: 0

PostProcess:
  name: DBPostProcess
  thresh: 0.3
  box_thresh: 0.6
  max_candidates: 1000
  unclip_ratio: 1.5

Metric:
  name: DetMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: /civi/Chinese_license_plate_Note/detection/dataset/train_dataset/  # training images
    label_file_list:
      - /civi/Chinese_license_plate_Note/detection/dataset/train_label.txt  # training labels
    ratio_list: [1.0]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - DetLabelEncode: # Class handling label
      - IaaAugment:
          augmenter_args:
            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
            - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
            - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
      - EastRandomCropData:
          size: [960, 960]
          max_tries: 50
          keep_ratio: true
      - MakeBorderMap:
          shrink_ratio: 0.4
          thresh_min: 0.3
          thresh_max: 0.7
      - MakeShrinkMap:
          shrink_ratio: 0.4
          min_text_size: 8
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 8  # batch size per card (GPU)
    num_workers: 4

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: /civi/Chinese_license_plate_Note/detection/dataset/test_dataset/  # validation images
    label_file_list:
      - /civi/Chinese_license_plate_Note/detection/dataset/test_label.txt  # validation labels
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - DetLabelEncode: # Class handling label
      - DetResizeForTest:
          # image_shape: [736, 1280]
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 1 # must be 1
    num_workers: 2
③ 开始训练
# Run these from the root of a PaddleOCR source checkout: tools/train.py,
# tools/export_model.py and tools/infer/predict_det.py are relative paths.
DET_ROOT=/civi/Chinese_license_plate_Note/detection
CONFIG="$DET_ROOT/test.yml"
# Must match Global.save_model_dir in $CONFIG so export finds the trained weights.
MODEL_DIR="$DET_ROOT/Models_Well_trained/200epochs"
INFER_DIR="$DET_ROOT/Models_Well_trained/200epochs_inference/"

# Multi-GPU training; the author used GPUs 1 and 2 (adjust to your machine).
python3 -m paddle.distributed.launch --gpus '1,2' tools/train.py \
    -c "$CONFIG"

# Export the best checkpoint of that run to an inference model.
python3 tools/export_model.py -c "$CONFIG" \
    -o Global.pretrained_model="$MODEL_DIR/best_accuracy" \
    Global.save_inference_dir="$INFER_DIR"

# Inference test on a single image with the exported model.
python3 tools/infer/predict_det.py --det_algorithm="DB" \
    --det_model_dir="$INFER_DIR" \
    --image_dir="$DET_ROOT/test_img/test2.jpeg" \
    --use_gpu=True
## 未完待续