本文是对age-gender-estimation项目的详细讲解,它给出了使用keras进行性别和年龄识别的完整流程。
采用的数据集为imdb-wiki,这是一个包含 20,284名人的460,723张以及维基百科上imdb的 62,328张共计523,051 张人脸图像的数据集,是目前开源的数据集中量级最大的,它给出了图像中人物的性别和出生时间、照片的拍摄时间等信息。原始的图片很大,分成了9个部分共计100多G,而裁剪出人脸的图片比较小,只有3G多,因此大家使用的基本都是wiki.tar.gz,不需要注册,直接就可以下载,这点很良心,省去了很多下载的时间。
解压后的目录为100个子文件夹,每个子文件夹再存储图片文件,这也是分类任务里最常见的做法
不过由于标注是采用matlab的mat格式文件存储的,实际用起来还要做一些转化。里面还含有一些噪声,比如性别标记为NAN,年龄算出来不对等,我写了一些代码来对这些信息进行过滤和统计
import os
import numpy as np
from scipy.io import loadmat
from datetime import datetime
from tqdm import tqdm
import matplotlib.pyplot as plt
def calc_age(taken, dob):
birth = datetime.fromordinal(max(int(dob) - 366, 1))
# assume the photo was taken in the middle of the year
if birth.month < 7:
return taken - birth.year
else:
return taken - birth.year - 1
def get_meta(mat_path, db):
meta = loadmat(mat_path)
full_path = meta[db][0, 0]["full_path"][0]
dob = meta[db][0, 0]["dob"][0] # Matlab serial date number
gender = meta[db][0, 0]["gender"][0]
photo_taken = meta[db][0, 0]["photo_taken"][0] # year
face_score = meta[db][0, 0]["face_score"][0]
second_face_score = meta[db][0, 0]["second_face_score"][0]
age = [calc_age(photo_taken[i], dob[i]) for i in range(len(dob))]
return full_path, dob, gender, photo_taken, face_score, second_face_score, age
def load_data(mat_path):
d = loadmat(mat_path)
return d["image"], d["gender"][0], d["age"][0], d["db"][0], d["img_size"][0, 0], d["min_score"][0, 0]
def convert2txt(mat_path="imdb.mat",db="imdb"):
lines=[]
min_score=1.0
full_path, dob, gender, photo_taken, face_score, second_face_score, age = get_meta(mat_path,db)
genders=[0,0]
ages=[]
for i in range(101):
ages.append(0)
for i in tqdm(range(len(full_path))):
#if face_score[i] < min_score:
#continue
#if (~np.isnan(second_face_score[i])) and second_face_score[i] > 0.0:
#continue
if ~(0 <= age[i] <= 100):
continue
if np.isnan(gender[i]):
continue
g=int(gender[i])
genders[g]+=1
ag=int(age[i])
ages[ag]+=1
#print(i,gender[i],age[i])
line=full_path[i][0]+" "+str(g)+" "+str(ag)
lines.append(line)
with open("gt.txt","w")as f:
for line in lines:
f.write(line+"\n")
print("genders",genders[0],genders[1])
print("age:")
for i in range(101):
print(i,ages[i])
plt.plot(np.linspace(0, 101,101),ages)
plt.savefig("plot.png")
plt.show()
if __name__=="__main__":
convert2txt()
结果如下:
性别比(男:女)=188746:262834
年龄分布画成图如下:
不难看出30-50岁之间的图片最多 ,这也是主流的分布。
具体到age-gender-estimation项目,可以简单的通过
./download.sh
下载,然后使用
python3 create_db.py --output data/imdb_db.mat --db imdb --img_size 64
将数据集转换为需要的格式,这个格式主要是清理无效标签,省的每次都再重复做,代码和我上面给出的差不多,不再赘述。
使用的模型为WiderResnet,可以通过Netron可视化,是由6个残差模型拼起来的,不过输出部分有两个输出,一个是性别的2,另一个是年龄的101
训练部分也比较简单,生成了数据文件后直接使用
python3 train.py --input data/imdb_db.mat
就可以了,如果还想使用数据增强,可以加上--aug
python3 train.py --input data/imdb_db.mat --aug
想看训练好的效果可以运行
python3 demo.py