# 解压代码
!unzip /home/aistudio/data/data41298/code.zip -d /home/aistudio/work/
!pip install paddlex
拳头表示向下走:
手掌表示向上走:
下面两个分别是向左和向右:
空白表示原地不动:
# Work from the unpacked Pacman project root so that the relative paths used
# below (./data, weights/final, test.jpg, demo.py, game.py) resolve there.
import os
PROJECT_DIR = '/home/aistudio/work/Pacman-master/'
os.chdir(PROJECT_DIR)
这一步需要在本地运行collect文件夹下PalmTracker.py文件进行手势数据采集;
运行该程序时会打开摄像头,在指定区域做出手势,按s保存;
# !python collect/PalmTracker.py
collect data game.py pacman.py test.jpg utils.py
config.py demo.py images src tools weights
这一步使用PaddleX提供的ResNet18进行训练;
预训练模型使用在'IMAGENET'上训练的权重,PaddleX中设置参数 pretrain_weights='IMAGENET' 即可;
我这里每种手势共收集了40张左右,训练集上的准确率在93%以上(此处未划分验证集);
from paddlex.cls import transforms
import os
import cv2
import numpy as np
import paddlex as pdx
base = './data'

# Class folders under ./data, listed once and sorted so that label ids are
# stable across runs and identical in train_list.txt and labels.txt.
# Hidden entries (e.g. a .ipynb_checkpoints folder Jupyter may create) and
# stray files are skipped so they never become a class.
class_names = sorted(
    name for name in os.listdir(base)
    if not name.startswith('.') and os.path.isdir(os.path.join(base, name))
)

# train_list.txt: one "<class_folder>/<image_file> <label_id>" line per image;
# image paths are relative to `base`.
with open('train_list.txt', 'w', encoding='utf-8') as f:
    for label_id, cls_fold in enumerate(class_names):
        cls_base = os.path.join(base, cls_fold)
        files = sorted(
            pt for pt in os.listdir(cls_base)
            if not pt.startswith('.')
            and os.path.isfile(os.path.join(cls_base, pt))
        )
        print('{} train num:'.format(cls_fold), len(files))
        for pt in files:
            img = os.path.join(cls_fold, pt)
            f.write('{} {}\n'.format(img, label_id))

# labels.txt: class names in label-id order (line k names label id k).
with open('labels.txt', 'w', encoding='utf-8') as f:
    f.writelines(name + '\n' for name in class_names)
# Training-time preprocessing: random crop with crop_size=224, then normalize.
# NOTE(review): do not add a horizontal flip here - mirroring a "left"
# gesture image would make it look like "right".
_train_ops = [
    transforms.RandomCrop(crop_size=224),
    transforms.Normalize(),
]
train_transforms = transforms.Compose(_train_ops)

# Dataset built from the file lists written above; image paths in
# train_list.txt are resolved relative to data_dir.
train_dataset = pdx.datasets.ImageNet(
    data_dir=base,
    file_list='train_list.txt',
    label_list='labels.txt',
    transforms=train_transforms,
    shuffle=True,
)
此处训练20个epoch,初始学习率为2e-2,并在第5、10、15个epoch时衰减学习率
# Fine-tune a ResNet18 classifier with one output per gesture class.
num_classes = len(train_dataset.labels)
model = pdx.cls.ResNet18(num_classes=num_classes)
model.train(
    num_epochs=20,
    train_dataset=train_dataset,
    train_batch_size=32,
    # ImageNet-pretrained backbone, as described in the text (this is also
    # PaddleX 1.x's default value for pretrain_weights).
    pretrain_weights='IMAGENET',
    learning_rate=2e-2,
    lr_decay_epochs=[5, 10, 15],  # epochs at which the learning rate is decayed
    log_interval_steps=5,
    save_interval_epochs=4,
    # NOTE(review): checkpoints go to ./w, but the prediction cell loads
    # weights/final - i.e. it uses the shipped weights, not this run's output.
    save_dir='w',
)
from paddlex.cls import transforms
import matplotlib.pyplot as plt
import paddlex
import cv2
import warnings
warnings.filterwarnings('ignore')
train_transforms = transforms.Compose([
transforms.RandomCrop(crop_size=224),
transforms.Normalize()
])
model = paddlex.load_model('weights/final')
im = cv2.imread('test.jpg')
result = model.predict(im, topk=1, transforms=train_transforms)
print("Predict Result:", result)
%matplotlib inline
plt.imshow(im)
("Predict Result:", result)
%matplotlib inline
plt.imshow(im)
plt.show()
2020-06-23 09:27:29 [INFO] Model[ResNet18] loaded.
Predict Result: [{'category_id': 1, 'category': 'left', 'score': 0.9999609}]
![预测结果](output_13_1.png)
本地运行demo.py即可;
# Run the gesture-recognition demo script; the notebook text says to run it locally.
!python demo.py
然后将该手势控制嵌入到游戏中即可~
游戏代码来自:https://github.com/hbokmann/Pacman
# Launch the gesture-controlled Pacman game script (game code based on
# github.com/hbokmann/Pacman). Presumably it needs a display/webcam, so run it locally.
!python game.py
视频链接(YouTube):https://youtu.be/tlZT2WeaK1U
视频链接(B站):https://www.bilibili.com/video/BV1xa4y1Y7Mb/
北京理工大学 大二在读
感兴趣的方向为:目标检测、人脸识别、EEG识别等
也欢迎大家fork、评论交流
作者博客主页:https://blog.csdn.net/weixin_44936889
权重文件或者源码需要的请私戳作者~