直接使用 yahoo 开源的模型 open_nsfw 不能满足业务需求,需要对模型进行重新训练。本篇主要讲述如何对模型进行训练。
(在上篇博客已经讲述了怎么将原始模型转换为tensorflow2模型)
1 将开源雅虎nsfw模型转换为 tensorflow2,见tensorflow2模型重构
2 准备训练样本,正负样本 (比例4:1~1:4之间)
3 数据增强
4 模型训练
5 模型保存
6 模型部署 (java部署)
训练数据格式如下,其中 positiveSapmle 目录中为正样本,negetiveSample 目录中为负样本(目录名需与代码中的字符串完全一致):
"""
|-path
    |-positiveSapmle
    |-negetiveSample
"""
# `path` is the sample root directory; `labelName` is either "positiveSapmle"
# or "negetiveSample". The spelling is kept as-is because it must match the
# directory names on disk and the literal compared inside the function.
data_dirs = []
data_labels = []
path = "./sample/"


def get_data_paths_lables(path, labelName):
    """Append every entry of ``path/labelName`` to the module-level sample lists.

    Args:
        path: Root directory that contains the label sub-directories.
        labelName: Name of the sub-directory to scan. Entries found under
            ``"positiveSapmle"`` get label 1; any other name gets label 0.

    Side effects:
        Extends the module-level ``data_dirs`` with the joined paths and
        ``data_labels`` with one label per path. Returns ``None``.
    """
    path_label = os.path.join(path, labelName)
    label = 1 if labelName == "positiveSapmle" else 0
    # os.listdir order is filesystem-dependent; sort for a reproducible listing.
    names = sorted(os.listdir(path_label))
    data_dirs.extend(os.path.join(path_label, name) for name in names)
    data_labels.extend([label] * len(names))
1 加载模型:使用 getModel,上篇博客已经实现了如何用 tf2 加载 yahoo 开源模型。
2 数据加载
格式如下
|-path
    |-positiveSapmle
    |-negetiveSample
3 数据增强
翻转,旋转,增加对比度等
4 模型训练
import os
from sklearn.model_selection import train_test_split
import tensorflow as tf
from nsfwmodel import getModel
from image_utils import create_tensorflow_image_loader,__tf_jpeg_process
from image_utils import create_yahoo_image_loader
def load_image(input_type=1, image_loader="yahoo"):
    """Return a callable that loads an image file in the requested input format.

    Args:
        input_type: ``1`` selects a tensor loader built by ``image_utils``;
            ``2`` selects a loader that returns the file's URL-safe base64
            encoding wrapped in a one-element NumPy array.
        image_loader: Only consulted when ``input_type == 1``. A value equal to
            the module-level ``IMAGE_LOADER_TENSORFLOW`` selects the TensorFlow
            loader; any other value selects the Yahoo loader.

    Returns:
        A function taking a filename and returning the loaded image.

    Raises:
        ValueError: If ``input_type`` is neither 1 nor 2.
    """
    if input_type == 1:
        print('TENSOR...')
        if image_loader == IMAGE_LOADER_TENSORFLOW:
            print('IMAGE_LOADER_TENSORFLOW...')
            return create_tensorflow_image_loader()
        print('create_yahoo_image_loader')
        return create_yahoo_image_loader()

    if input_type == 2:
        print('BASE64_JPEG...')
        # Local imports: numpy is otherwise only imported near the end of the
        # file, after this loader may already have been called.
        import base64
        import numpy as np

        def fn_load_image(filename):
            # Context manager closes the file handle deterministically.
            with open(filename, "rb") as image_file:
                encoded = base64.urlsafe_b64encode(image_file.read())
            return np.array([encoded])

        return fn_load_image

    # Same error as imageToTensor raises for an unsupported input type.
    raise ValueError("invalid input type '{}'".format(input_type))
def imageToTensor(inputs, input_type=1):
    """Convert loader output into the tensor form selected by ``input_type``.

    Args:
        inputs: Value produced by an image loader.
        input_type: ``1`` returns ``inputs`` unchanged; ``2`` passes ``inputs``
            through ``image_utils.load_base64_tensor``.

    Returns:
        ``inputs`` itself for type 1, or the result of ``load_base64_tensor``
        for type 2.

    Raises:
        ValueError: If ``input_type`` is neither 1 nor 2.
    """
    if input_type == 1:
        return inputs
    if input_type == 2:
        # Imported lazily so the tensor path does not need this helper.
        from image_utils import load_base64_tensor
        return load_base64_tensor(inputs)
    raise ValueError("invalid input type '{}'".format(input_type))
# 1. Data preparation
IMAGE_LOADER_TENSORFLOW = "tensorflow"
IMAGE_LOADER_YAHOO = "yahoo"
input_type = 1
image_loader = IMAGE_LOADER_TENSORFLOW

# Image loader (used further down for single-image prediction).
fn_load_image = load_image(input_type, image_loader)

# Collect sample paths and labels: negatives first, then positives.
for _label_dir in ("negetiveSample", "positiveSapmle"):
    get_data_paths_lables(path, _label_dir)

# Hold out 20% of the samples for evaluation.
(
    train_data_dirs,
    test_data_dirs,
    train_data_labels,
    test_data_labels,
) = train_test_split(data_dirs, data_labels, test_size=0.2)
# Per-sample processing (decoding plus random data augmentation)
def load_preprosess_image(path, label):
    """Read one image file and return an augmented ``(image, label)`` pair.

    Used as the ``map`` function of a ``tf.data.Dataset`` of (path, label)
    elements.

    Args:
        path: Path of the image file to read.
        label: Label belonging to that image.

    Returns:
        A tuple ``(image, label)`` where ``image`` is cast to ``tf.float32``
        and ``label`` is reshaped to shape ``[1]``.
    """
    image = tf.io.read_file(path)
    # NOTE(review): presumably decodes and normalises the JPEG bytes the same
    # way the inference loader does -- confirm against image_utils.
    image = __tf_jpeg_process(image)
    image = imageToTensor(image, input_type)

    # Random augmentation: flips, brightness and contrast.
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    image = tf.image.random_brightness(image, 0.3)
    # NOTE(review): a factor range of [0, 1] can only reduce contrast --
    # confirm this is the intended range.
    image = tf.image.random_contrast(image, 0, 1)

    image = tf.cast(image, tf.float32)
    label = tf.reshape(label, [1])
    return image, label
# Model training
BATCH_SIZE = 32
train_image_ds = tf.data.Dataset.from_tensor_slices(
    (train_data_dirs, train_data_labels)
).map(load_preprosess_image)
train_dataset = train_image_ds.shuffle(10000).batch(BATCH_SIZE)

test_image_ds = tf.data.Dataset.from_tensor_slices(
    (test_data_dirs, test_data_labels)
).map(load_preprosess_image)
# Validation must run on the held-out split, not on the training pipeline.
test_dataset = test_image_ds.batch(BATCH_SIZE)

model = getModel()
model.summary()
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=["acc"],
)
model.fit(train_dataset, epochs=50, validation_data=test_dataset)

# Model prediction
model.predict(fn_load_image("./images/ALqhFyWOTw004_1.jpg"))
# Example result: array([[0.97107196, 0.02892802]], dtype=float32)

# Save the fine-tuned weights
model.save_weights("./model/nsfw_finetune.weight")
原始开源模型参数:open_nsfw-weights.npy
import numpy as np
# Export the fine-tuned weights in the same nested-dict layout as the original
# open_nsfw-weights.npy file: {layer_name: {param_name: ndarray}}.
# NOTE: allow_pickle=True unpickles arbitrary objects; only load a trusted file.
npweight_tf1 = np.load(
    "./open_nsfw-weights.npy", allow_pickle=True, encoding="latin1"
).item()

# Parameter names used in the .npy file -> name fragments of the Keras variables.
name_dict = dict(zip(
    ['variance', 'scale', 'offset', 'mean', 'weights', 'biases'],
    ["moving_variance", 'gamma', 'beta', 'moving_mean', 'kernel', 'bias'],
))
# Inverse mapping: Keras name fragment -> parameter name used in the .npy file.
auto_to_manu_dict = {keras_name: npy_name for npy_name, keras_name in name_dict.items()}

mweight = model.weights
targweight = {}
for nwkey, nwvalue in npweight_tf1.items():
    # One output entry per layer of the reference file.
    layer_params = targweight.setdefault(nwkey, {})
    # Keras-side name fragments of the parameters this layer has in the reference file.
    nwkeys = [name_dict[j] for j in nwvalue]
    for tfweigh in mweight:
        # NOTE(review): substring matching on variable names; a layer name that
        # is a prefix of another layer name would match both -- confirm the
        # model's layer names are unambiguous.
        if nwkey not in tfweigh.name:
            continue
        for jj in nwkeys:
            if jj in tfweigh.name:
                layer_params[auto_to_manu_dict[jj]] = tfweigh.numpy()

np.save("./data/open_nsfw-weights_new.npy", targweight)