Image risk control NSFW (not suitable for work) - 2: fine-tuning the tf2 model

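Used as-is, Yahoo's open-source open_nsfw model does not meet our business requirements, so it has to be retrained. This post covers that training (fine-tuning) step.
(The previous post described how to convert the original model into a tensorflow2 model.)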

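Approach

1 Convert the open-source Yahoo nsfw model to tensorflow2 (see the previous post, "tensorflow2 model reconstruction")
2 Prepare training samples, positive and negative (ratio between 4:1 and 1:4)
3 Data augmentation
4 Model training
5 Model saving
6 Model deployment (Java)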

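1 Data preparation

The training data is laid out as follows: the positiveSapmle directory holds the positive samples and the negetiveSample directory holds the negative samples.
"""
|-path
  |-positiveSapmle
  |-negetiveSample
"""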

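import os

# path is the sample root directory; labelName is either "positiveSapmle" or "negetiveSample"
data_dirs=[]
data_labels=[]
path="./sample/"
def get_data_paths_lables(path,labelName):
    path_label=os.path.join(path,labelName)
    # positive samples get label 1, negative samples get label 0
    label= 1 if labelName=="positiveSapmle" else 0
    path_list=os.listdir(path_label)
    # append the file paths and one matching label per file to the module-level lists
    data_dirs.extend( [os.path.join(path,labelName,name) for name in path_list])
    data_labels.extend([label for i in path_list])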
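
The 4:1 to 1:4 ratio mentioned in the approach can be checked as soon as the labels have been collected. The following is a minimal sketch and not part of the original code: `check_class_ratio` and its `max_ratio` parameter are hypothetical names, and it assumes labels are 1 for positive and 0 for negative, as assigned by `get_data_paths_lables`.

```python
from collections import Counter

def check_class_ratio(labels, max_ratio=4.0):
    """Return (n_pos, n_neg, ok); ok is True when the larger class is
    at most max_ratio times the size of the smaller one."""
    counts = Counter(labels)
    n_pos, n_neg = counts.get(1, 0), counts.get(0, 0)
    if n_pos == 0 or n_neg == 0:
        # One class is missing entirely, so the ratio is undefined.
        return n_pos, n_neg, False
    ratio = max(n_pos, n_neg) / min(n_pos, n_neg)
    return n_pos, n_neg, ratio <= max_ratio

# Example with made-up labels: 30 positives and 100 negatives.
print(check_class_ratio([1] * 30 + [0] * 100))
```

In the post's flow this would be called as `check_class_ratio(data_labels)` after both sample directories have been scanned.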

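2 Model training

  • 1 Load the model with getModel; the previous post showed how to load the Yahoo open-source model in tf2.

  • 2 Load the data
    The layout is as follows
    |-path
      |-positiveSapmle
      |-negetiveSample

  • 3 Data augmentation
    Flipping, rotation, contrast adjustment, etc.

  • 4 Train the model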

import os
import numpy as np
from sklearn.model_selection import train_test_split
import tensorflow as tf
from nsfwmodel import getModel
from image_utils import create_tensorflow_image_loader,create_yahoo_image_loader,__tf_jpeg_process

def load_image(input_type=1,image_loader="yahoo"):
    # Returns a function that loads an image file into model input.
    # IMAGE_LOADER_TENSORFLOW is a module-level constant defined before this function is called.
    if input_type == 1:
        print('TENSOR...')
        if image_loader == IMAGE_LOADER_TENSORFLOW:
            print('IMAGE_LOADER_TENSORFLOW...')
            fn_load_image = create_tensorflow_image_loader()
        else:
            print('create_yahoo_image_loader')
            fn_load_image = create_yahoo_image_loader()
    elif input_type == 2:
        print('BASE64_JPEG...')
        import base64
        fn_load_image = lambda filename: np.array([base64.urlsafe_b64encode(open(filename, "rb").read())])
    else:
        raise ValueError("invalid input type '{}'".format(input_type))
    return fn_load_image

def imageToTensor(inputs,input_type=1):
    if input_type == 1:
        input_tensor = inputs
    elif input_type == 2:
        from image_utils import load_base64_tensor
        input_tensor = load_base64_tensor(inputs)
    else:
        raise ValueError("invalid input type '{}'".format(input_type))
    return input_tensor



# 1 Data preparation
input_type=1
image_loader= "tensorflow"
IMAGE_LOADER_TENSORFLOW = "tensorflow"
IMAGE_LOADER_YAHOO = "yahoo"
# Image loader
fn_load_image=load_image(input_type,image_loader)

# Load the training sample paths and labels
get_data_paths_lables(path,"negetiveSample")
get_data_paths_lables(path,"positiveSapmle")
train_data_dirs,test_data_dirs,train_data_labels,test_data_labels=train_test_split(data_dirs,data_labels,test_size=0.2)
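
With imbalanced classes, a purely random split can leave the test set with a different positive/negative ratio than the training set. The sketch below shows a stratified split in plain Python that splits each class separately; `stratified_split` is a hypothetical helper, not something from the original code. scikit-learn's `train_test_split` also accepts a `stratify=` argument that serves the same purpose.

```python
import random
from collections import defaultdict

def stratified_split(paths, labels, test_size=0.2, seed=0):
    """Split paths/labels so each class contributes about test_size of its
    samples to the test set. Returns the same order as train_test_split:
    (train_paths, test_paths, train_labels, test_labels)."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for p, l in zip(paths, labels):
        by_label[l].append(p)

    train, test = [], []
    for label, items in by_label.items():
        items = items[:]              # copy so the caller's data is not reordered
        rng.shuffle(items)
        n_test = int(round(len(items) * test_size))
        test.extend((p, label) for p in items[:n_test])
        train.extend((p, label) for p in items[n_test:])

    # Mix the classes again so neither split is grouped by label.
    rng.shuffle(train)
    rng.shuffle(test)
    return ([p for p, _ in train], [p for p, _ in test],
            [l for _, l in train], [l for _, l in test])
```

With scikit-learn the equivalent call would be `train_test_split(data_dirs, data_labels, test_size=0.2, stratify=data_labels)`.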
 
 
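# Data processing (decoding plus data augmentation)
def load_preprosess_image(path, label):
    image = tf.io.read_file(path)
    image = __tf_jpeg_process(image)
    image=imageToTensor(image, input_type)
    # Random augmentation: horizontal and vertical flips, then brightness and contrast jitter
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    image = tf.image.random_brightness(image, 0.3)
    # The contrast factor is sampled from [0, 1], so this only ever lowers contrast
    image = tf.image.random_contrast(image, 0, 1)
    image = tf.cast(image, tf.float32)

    # Labels are reshaped to [1] to match the sparse categorical loss
    label = tf.reshape(label, [1])
    return image, label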

# Model training
BATCH_SIZE=32
train_image_ds = tf.data.Dataset.from_tensor_slices((train_data_dirs, train_data_labels)).map(load_preprosess_image)
train_dataset=train_image_ds.shuffle(10000).batch(BATCH_SIZE)
test_image_ds = tf.data.Dataset.from_tensor_slices((test_data_dirs, test_data_labels)).map(load_preprosess_image)
# Validation must use the held-out test split, not the training data
test_dataset=test_image_ds.batch(BATCH_SIZE)
model = getModel()
model.summary()
model.compile(optimizer=tf.keras.optimizers.Adam(1e-4),loss=tf.keras.losses.SparseCategoricalCrossentropy()
             ,metrics=["acc"])
model.fit(train_dataset,epochs=50,validation_data=test_dataset)
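
When the positive/negative ratio is far from 1:1, one common option is to weight the loss per class. The sketch below computes inverse-frequency weights in plain Python. `compute_class_weights` is a hypothetical helper, and passing its result to the `class_weight` argument of Keras `model.fit` is a suggestion, not something the original training code does.

```python
from collections import Counter

def compute_class_weights(labels):
    """Inverse-frequency weights: n_samples / (n_classes * count_of_class).
    A perfectly balanced dataset yields a weight of 1.0 for every class."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {cls: n_samples / (n_classes * cnt) for cls, cnt in counts.items()}

# Example with made-up labels: three negatives and one positive.
print(compute_class_weights([0, 0, 0, 1]))

# Possible use with the training code above (assumption, not from the original):
# weights = compute_class_weights(train_data_labels)
# model.fit(train_dataset, epochs=50, validation_data=test_dataset, class_weight=weights)
```

The rarer class receives the larger weight, so its misclassifications contribute more to the loss.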

3 Model prediction and saving

# Model prediction
model.predict(fn_load_image("./images/ALqhFyWOTw004_1.jpg"))
# Returns: array([[0.97107196, 0.02892802]], dtype=float32)

# Save the model weights
model.save_weights("./model/nsfw_finetune.weight")
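
The prediction above is one row of two scores. Because training used label 1 for positiveSapmle and label 0 for negetiveSample with a sparse categorical loss, index 1 is the score for the positiveSapmle class. A minimal sketch for turning those scores into a decision follows; `classify_scores` and the 0.5 threshold are assumptions for illustration and should be tuned against the business requirement.

```python
def classify_scores(scores, threshold=0.5):
    """scores: iterable of [score_label0, score_label1] rows, e.g. model.predict output.
    Returns a list of (positive_score, is_positive) tuples."""
    results = []
    for row in scores:
        p_positive = float(row[1])   # index 1 corresponds to label 1 (positiveSapmle)
        results.append((p_positive, p_positive >= threshold))
    return results

# Using the scores shown in the prediction above:
print(classify_scores([[0.97107196, 0.02892802]]))
# prints [(0.02892802, False)]
```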

4 Saving in the original .npy format (optional)

Original open-source model parameters: open_nsfw-weights.npy

import numpy as np
npweight_tf1=np.load("./open_nsfw-weights.npy", allow_pickle=True,encoding="latin1").item()
name_dict=dict(zip(['variance', 'scale', 'offset', 'mean','weights', 'biases'],["moving_variance",'gamma','beta','moving_mean','kernel','bias']))
auto_to_manu_dict=dict(zip(["moving_variance",'gamma','beta','moving_mean','kernel','bias'],['variance', 'scale', 'offset', 'mean','weights', 'biases']))
mweight=model.weights
targweight={}
for nwkey,nwvalue in list(npweight_tf1.items()):
    # translate this layer's .npy parameter names into Keras variable names
    nwkeys=  [ name_dict[j] for j in nwvalue.keys()]
    for tfweigh in mweight:  # model weights
        targweight[nwkey]=targweight.get(nwkey,{})
        if  nwkey in tfweigh.name: # the .npy layer name appears in the model weight name
            for jj in nwkeys:
                if jj in tfweigh.name:
                    # store the fine-tuned value under the original .npy parameter name
                    targweight[nwkey][auto_to_manu_dict[jj]]=tfweigh.numpy()
targweight
np.save("./data/open_nsfw-weights_new.npy",targweight)
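
`np.save` stores a Python dict by pickling it inside a 0-d object array, which is why the original file is read back with `allow_pickle=True` followed by `.item()`. The round trip can be checked with a small sketch before relying on the new file. The layer name `conv_1`, the array shapes, and the `roundtrip_weights` helper are made up for illustration.

```python
import os
import tempfile
import numpy as np

def roundtrip_weights(weights, path):
    """Save a nested {layer: {param: ndarray}} dict and load it back."""
    np.save(path, weights)                           # dict is pickled into a 0-d object array
    return np.load(path, allow_pickle=True).item()   # .item() unwraps the dict again

# Tiny stand-in for the real weight dict (hypothetical layer name and shapes).
demo = {"conv_1": {"weights": np.ones((3, 3)), "biases": np.zeros(3)}}
with tempfile.TemporaryDirectory() as tmp:
    loaded = roundtrip_weights(demo, os.path.join(tmp, "demo_weights.npy"))

print(sorted(loaded["conv_1"].keys()))
```

The same check applied to the real file would compare the keys of `targweight` against those of `npweight_tf1` to confirm that no layer was dropped by the substring matching.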
