未完待续……
本文只说明原理,提供参考,实际应用需考虑其他因素。
环境
win7
python3.6.3
tensorflow-gpu 1.5(搭配 cuda_9.0.176_windows.exe 与 cudnn-7.0.5;其他版本会报错,或在运行占用内存较多时异常终止)
keras2.1.4
注意:各软件的版本之间存在适配问题,请尽量按上述版本组合安装。
目标
通过人工智能技术实现目标(人脸)检测和识别
步骤
1.使用 selective search(选择性搜索)技术选出候选框(如下图蓝色框)
2.训练 face_model 模型,在步骤1得到的候选框基础上逐个预测并判断是否为人脸(如下图红色框)
要获得较好的效果,需要调整 selective search 的参数并优化 face_model 模型
优化后效果
代码
face_model.py
实现人脸和非人脸二分类模型训练。
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function
import keras
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
# 模型结构图
from keras.utils import plot_model
import os
import matplotlib.pyplot as plt
# Number of samples per training batch; batches per epoch = training samples / batch_size.
batch_size = 16
# A single output unit is enough to separate the two classes (face / non-face).
num_classes = 1
# Number of training epochs; one epoch is one pass of every training sample
# through the model. (An earlier setting of 100 was replaced by 20.)
epochs = 20
# Directory and file name under which the trained model is stored.
save_dir = os.path.join(os.getcwd(), 'saved_models')
model_name = 'keras_face_trained_model.h5'
# Size every input image is resized to before it is fed to the network.
img_w = 150
img_h = 150
class LossHistory(keras.callbacks.Callback):
    """Keras callback that records loss and accuracy values and plots them.

    Four histories are kept (``losses``, ``accuracy``, ``val_loss``,
    ``val_acc``); each is a dict with a ``'batch'`` list and an ``'epoch'``
    list. Values are read from the ``logs`` dict passed to the callback
    under the keys ``'loss'``, ``'acc'``, ``'val_loss'`` and ``'val_acc'``;
    a key that is absent is recorded as ``None``.
    """

    def on_train_begin(self, logs=None):
        """Reset every history to empty lists."""
        self.losses = {'batch': [], 'epoch': []}
        self.accuracy = {'batch': [], 'epoch': []}
        self.val_loss = {'batch': [], 'epoch': []}
        self.val_acc = {'batch': [], 'epoch': []}

    def _record(self, granularity, logs):
        """Append the four tracked values from ``logs`` to the given history."""
        # Tolerate a missing logs dict instead of sharing a mutable default.
        logs = logs or {}
        self.losses[granularity].append(logs.get('loss'))
        self.accuracy[granularity].append(logs.get('acc'))
        self.val_loss[granularity].append(logs.get('val_loss'))
        self.val_acc[granularity].append(logs.get('val_acc'))

    def on_batch_end(self, batch, logs=None):
        """Record the values reported at the end of a batch."""
        self._record('batch', logs)

    def on_epoch_end(self, batch, logs=None):
        """Record the values reported at the end of an epoch.

        The first parameter keeps its original name ``batch``.
        """
        self._record('epoch', logs)

    def loss_plot(self, loss_type):
        """Plot the recorded curves and show the figure.

        Args:
            loss_type: ``'batch'`` or ``'epoch'``; selects which history is
                drawn. Validation curves are drawn only for ``'epoch'``.
        """
        iters = range(len(self.losses[loss_type]))
        plt.figure()
        # Training accuracy.
        plt.plot(iters, self.accuracy[loss_type], 'r', label='train acc')
        # Training loss.
        plt.plot(iters, self.losses[loss_type], 'g', label='train loss')
        if loss_type == 'epoch':
            # Validation accuracy.
            plt.plot(iters, self.val_acc[loss_type], 'b', label='val acc')
            # Validation loss.
            plt.plot(iters, self.val_loss[loss_type], 'k', label='val loss')
        plt.grid(True)
        plt.xlabel(loss_type)
        plt.ylabel('acc-loss')
        plt.legend(loc="upper right")
        plt.show()
# Binary face / non-face classifier: two convolutional stages, each followed
# by max-pooling and dropout, then a dense head ending in a sigmoid.
# The input is an image converted to an array of img_w x img_h pixels with
# 3 channels (RGB); the shape is taken from the same constants the data
# generators use for target_size, so both stay in step.
model = Sequential([
    Conv2D(32, (3, 3), padding='same', input_shape=(img_w, img_h, 3)),
    Activation('relu'),
    Conv2D(32, (3, 3)),
    Activation('relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.5),

    Conv2D(64, (3, 3), padding='same'),
    Activation('relu'),
    Conv2D(64, (3, 3)),
    Activation('relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.25),

    Flatten(),
    Dense(512),
    Activation('relu'),
    Dropout(0.5),
    Dense(num_classes),
    Activation('sigmoid'),
])
model.summary()
# Create the RMSprop optimizer. The class name ``RMSprop`` is used instead of
# the lowercase ``rmsprop`` alias, which is not available in every Keras release.
opt = keras.optimizers.RMSprop(lr=0.0001, decay=1e-6)
# Train with binary cross-entropy and report accuracy as the metric.
model.compile(loss='binary_crossentropy',
              optimizer=opt,
              metrics=['accuracy'])
# Callback instance that collects loss/accuracy values during training.
history = LossHistory()

# Training images: pixel values scaled by 1/255 and augmented with shear,
# zoom and horizontal flips.
train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)
# Validation images are only scaled by 1/255.
test_datagen = ImageDataGenerator(rescale=1. / 255)

# Options shared by both directory iterators: resize to img_w x img_h, fixed
# batch size, shuffled order, and binary (two-class) labels.
flow_options = dict(
    target_size=(img_w, img_h),
    batch_size=batch_size,
    shuffle=True,
    class_mode='binary',
)
# Training samples.
train_generator = train_datagen.flow_from_directory(
    './data/train', **flow_options)
# Validation samples.
validation_generator = test_datagen.flow_from_directory(
    './data/validation', **flow_options)
# Fit the model from the directory generators.
# The Keras 2 argument names are used; they count batches ("steps") rather
# than samples, so the intended sample totals are divided by the batch size.
model.fit_generator(
    train_generator,                       # training set
    # 2400 training samples per epoch; per the original note the augmenting
    # generator makes up the difference to the 240 images on disk.
    steps_per_epoch=2400 // batch_size,
    epochs=epochs,
    validation_data=validation_generator,  # validation set
    # 800 validation samples per epoch (60 images on disk per the original note).
    validation_steps=800 // batch_size,
    callbacks=[history])                   # records loss/acc for the plot below
# Persist the trained model (architecture and weights) under save_dir,
# creating that directory first when it does not exist yet.
model_path = os.path.join(save_dir, model_name)
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
model.save(model_path)
print('Saved trained model at %s ' % model_path)

# Show the per-epoch loss/accuracy curves collected by the callback.
history.loss_plot('epoch')

# Write a diagram of the model structure, including layer shapes.
plot_model(model, to_file='model.png', show_shapes=True)
selectivesearch.py
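实现候选框选择。
(注意:下面贴出的代码实际上是 NumPy 包的 `__init__.py`,并不是 selectivesearch 的实现,应为粘贴错误;select_rpn.py 中调用的 `selectivesearch.selective_search` 应来自第三方 selectivesearch 库,请以该库的源码为准。)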
"""
NumPy
=====
Provides
1. An array object of arbitrary homogeneous items
2. Fast mathematical operations over arrays
3. Linear Algebra, Fourier Transforms, Random Number Generation
How to use the documentation
----------------------------
Documentation is available in two forms: docstrings provided
with the code, and a loose standing reference guide, available from
`the NumPy homepage `_.
We recommend exploring the docstrings using
`IPython `_, an advanced Python shell with
TAB-completion and introspection capabilities. See below for further
instructions.
The docstring examples assume that `numpy` has been imported as `np`::
>>> import numpy as np
Code snippets are indicated by three greater-than signs::
>>> x = 42
>>> x = x + 1
Use the built-in ``help`` function to view a function's docstring::
>>> help(np.sort)
... # doctest: +SKIP
For some objects, ``np.info(obj)`` may provide additional help. This is
particularly true if you see the line "Help on ufunc object:" at the top
of the help() page. Ufuncs are implemented in C, not Python, for speed.
The native Python help() does not know how to view their help, but our
np.info() function does.
To search for documents containing a keyword, do::
>>> np.lookfor('keyword')
... # doctest: +SKIP
General-purpose documents like a glossary and help on the basic concepts
of numpy are available under the ``doc`` sub-module::
>>> from numpy import doc
>>> help(doc)
... # doctest: +SKIP
Available subpackages
---------------------
doc
Topical documentation on broadcasting, indexing, etc.
lib
Basic functions used by several sub-packages.
random
Core Random Tools
linalg
Core Linear Algebra Tools
fft
Core FFT routines
polynomial
Polynomial tools
testing
NumPy testing tools
f2py
Fortran to Python Interface Generator.
distutils
Enhancements to distutils with support for
Fortran compilers support and more.
Utilities
---------
test
Run numpy unittests
show_config
Show numpy build configuration
dual
Overwrite certain functions with high-performance Scipy tools
matlib
Make everything matrices.
__version__
NumPy version string
Viewing documentation using IPython
-----------------------------------
Start IPython with the NumPy profile (``ipython -p numpy``), which will
import `numpy` under the alias `np`. Then, use the ``cpaste`` command to
paste examples into the shell. To see which functions are available in
`numpy`, type ``np.`` (where ```` refers to the TAB key), or use
``np.*cos*?`` (where ```` refers to the ENTER key) to narrow
down the list. To view the docstring for a function, use
``np.cos?`` (to view the docstring) and ``np.cos??`` (to view
the source code).
Copies vs. in-place operation
-----------------------------
Most of the functions in `numpy` return a copy of the array argument
(e.g., `np.sort`). In-place versions of these functions are often
available as array methods, i.e. ``x = np.array([1,2,3]); x.sort()``.
Exceptions to this rule are documented.
"""
from __future__ import division, absolute_import, print_function
import sys
import warnings
from ._globals import ModuleDeprecationWarning, VisibleDeprecationWarning
from ._globals import _NoValue
# We first need to detect if we're being called as part of the numpy setup
# procedure itself in a reliable manner.
try:
__NUMPY_SETUP__
except NameError:
__NUMPY_SETUP__ = False
if __NUMPY_SETUP__:
sys.stderr.write('Running from numpy source directory.\n')
else:
try:
from numpy.__config__ import show as show_config
except ImportError:
msg = """Error importing numpy: you should not try to import numpy from
its source directory; please exit the numpy source tree, and relaunch
your python interpreter from there."""
raise ImportError(msg)
from .version import git_revision as __git_revision__
from .version import version as __version__
from ._import_tools import PackageLoader
def pkgload(*packages, **options):
    # Build a PackageLoader with infunc=True and invoke it directly,
    # forwarding every positional and keyword argument unchanged.
    return PackageLoader(infunc=True)(*packages, **options)
from . import add_newdocs
__all__ = ['add_newdocs',
'ModuleDeprecationWarning',
'VisibleDeprecationWarning']
pkgload.__doc__ = PackageLoader.__call__.__doc__
# We don't actually use this ourselves anymore, but I'm not 100% sure that
# no-one else in the world is using it (though I hope not)
from .testing import Tester, _numpy_tester
test = _numpy_tester().test
bench = _numpy_tester().bench
# Allow distributors to run custom init code
from . import _distributor_init
from . import core
from .core import *
from . import compat
from . import lib
from .lib import *
from . import linalg
from . import fft
from . import polynomial
from . import random
from . import ctypeslib
from . import ma
from . import matrixlib as _mat
from .matrixlib import *
from .compat import long
# Make these accessible from numpy name-space
# but not imported in from numpy import *
if sys.version_info[0] >= 3:
from builtins import bool, int, float, complex, object, str
unicode = str
else:
from __builtin__ import bool, int, float, complex, object, unicode, str
from .core import round, abs, max, min
__all__.extend(['__version__', 'pkgload', 'PackageLoader',
'show_config'])
__all__.extend(core.__all__)
__all__.extend(_mat.__all__)
__all__.extend(lib.__all__)
__all__.extend(['linalg', 'fft', 'random', 'ctypeslib', 'ma'])
# Filter annoying Cython warnings that serve no good purpose.
warnings.filterwarnings("ignore", message="numpy.dtype size changed")
warnings.filterwarnings("ignore", message="numpy.ufunc size changed")
warnings.filterwarnings("ignore", message="numpy.ndarray size changed")
# oldnumeric and numarray were removed in 1.9. In case some packages import
# but do not use them, we define them here for backward compatibility.
oldnumeric = 'removed'
numarray = 'removed'
def _sanity_check():
    """
    Run a quick check for common bugs caused by the environment.

    Some problems (e.g., the wrong BLAS ABI) produce wrong results only
    under specific runtime conditions that test suite runs do not
    necessarily reach, so it is useful to catch them early.
    See https://github.com/numpy/numpy/issues/8577 and other
    similar bug reports.

    Raises a RuntimeError when the dot product of a two-element float32
    vector of ones with itself is not within 1e-5 of 2.0.
    """
    tolerance = 1e-5
    try:
        probe = ones(2, dtype=float32)
        deviation = abs(probe.dot(probe) - 2.0)
        # Written as "not <" so that a NaN result also counts as a failure.
        if not deviation < tolerance:
            raise AssertionError()
    except AssertionError:
        template = ("The current Numpy installation ({!r}) fails to "
                    "pass simple sanity checks. This can be caused for example "
                    "by incorrect BLAS library being linked in.")
        raise RuntimeError(template.format(__file__))
_sanity_check()
del _sanity_check
select_rpn.py
实现候选框判断,区分出人脸。
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import (
division,
print_function,
)
from skimage import transform, data
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import selectivesearch
import numpy as np
import keras
from keras.applications.imagenet_utils import decode_predictions
from keras.preprocessing import image
from keras.applications import *
from skimage import io
import os
save_dir = os.path.join(os.getcwd(), 'saved_models')
model_name = r'keras_face_trained_model.h5'
def _select_candidates(regions, min_size=2000, max_aspect=1.2):
    """Return the set of region rects that are large enough and roughly square.

    Args:
        regions: iterable of dicts with a ``'rect'`` entry ``(x, y, w, h)``
            and a ``'size'`` entry, as produced by the selective search call.
        min_size: regions whose ``'size'`` is below this value are dropped.
        max_aspect: rects whose w/h or h/w ratio exceeds this are dropped.
    """
    candidates = set()
    for region in regions:
        if region['size'] < min_size:
            continue
        rect = region['rect']
        _, _, w, h = rect
        # A zero-width or zero-height rect cannot be cropped and would make
        # the aspect-ratio test divide by zero.
        if w <= 0 or h <= 0:
            continue
        if w / h > max_aspect or h / w > max_aspect:
            continue
        # The set drops the same rectangle reported for different segments.
        candidates.add(rect)
    return candidates


def find_face(image_path=r'./images/astronaut.jpg', threshold=0.5):
    """Detect face regions in an image and display them.

    Runs selective search on the image, keeps the large, roughly square
    candidate boxes, classifies each crop with the saved Keras model and
    draws every box on the image: red when the model output is below
    ``threshold``, blue otherwise. The figure is shown with ``plt.show()``.

    Args:
        image_path: path of the image to analyse.
        threshold: decision boundary applied to the model output. Per the
            original note, 0 means face and 1 means not-face -- TODO confirm
            against the class indices of the training generator.
    """
    img = io.imread(image_path)

    # Perform selective search; only the region list is needed here.
    _, regions = selectivesearch.selective_search(
        img, scale=500, sigma=0.9, min_size=10)
    candidates = _select_candidates(regions)

    # Draw rectangles on the original image.
    _, ax = plt.subplots(ncols=1, nrows=1, figsize=(6, 6))
    ax.imshow(img)

    # Load the trained classifier once, before looping over the candidates.
    model = keras.models.load_model(os.path.join(save_dir, model_name))

    # Crops are written to ./temp before being reloaded at the model's input
    # size; make sure the directory exists so imsave does not fail.
    if not os.path.isdir(r'./temp'):
        os.makedirs(r'./temp')

    for x, y, w, h in candidates:
        print(x, y, w, h)
        img_pr = img[y:y + h, x:x + w, :]
        tmp_file_name = r'./temp/file_%d_%d_%d_%d.jpg' % (x, y, w, h)
        io.imsave(tmp_file_name, img_pr)

        # Reload the crop resized to 150x150, scale by 1/255 as in training,
        # and add the batch dimension.
        images = image.load_img(tmp_file_name, target_size=(150, 150))
        sx = image.img_to_array(images)
        sx = sx.astype("float") / 255.0
        sx = np.expand_dims(sx, axis=0)

        face_label = model.predict(sx)
        print('face_label', face_label[0][0])

        # Red marks boxes judged to be faces, blue marks the rest.
        edgecolor = 'red' if face_label[0][0] < threshold else 'blue'
        rect = mpatches.Rectangle(
            (x, y), w, h, fill=False, edgecolor=edgecolor, linewidth=1)
        ax.add_patch(rect)

    plt.show()
if __name__ == "__main__":
find_face()
模型结构图
补充
faster-rcnn基于WIDERFace数据集(已经标注的人脸数据库),可以训练人脸检测和定位(x、y、w、h)效果更好。
参考
https://blog.csdn.net/zhanghongxing007/article/details/56479206
http://scikit-image.org/docs/dev/api/skimage.draw.html#module-skimage.draw
https://github.com/playerkk/face-py-faster-rcnn
http://mmlab.ie.cuhk.edu.hk/projects/WIDERFace/
完整项目下载
为方便没积分童鞋,请加企鹅,共享文件夹。
包括:代码、数据集合(图片)、已生成model、安装库文件等。
https://github.com/gbusr/ML/tree/master/facecnn
QQ:facedetect.zip
详细讲解
推荐阅读
https://blog.csdn.net/wyx100/article/details/80939360
https://blog.csdn.net/wyx100/article/details/80950499
https://blog.csdn.net/wyx100/article/details/80647379
http://www.cnblogs.com/neo-T/p/6426029.html