nsfw模型的搭建(复现实操)

声明:该模型参数来源于开源模型nsfw,但代码编写版本为tensorflow1。本文根据CSDN的一位大佬提供的代码进行tf2版本的复现。

  1. 将github项目下载到本地

nsfw模型的搭建(复现实操)_第1张图片
  1. 将大佬的函数写到tf2.ipynb,附带需要导入的库

import tensorflow as tf
# from  nsfwmodel import ResModel,getModel
from image_utils import create_tensorflow_image_loader
from  image_utils import  create_yahoo_image_loader
import numpy as np
import os
import math
from tensorflow import keras
import keras.layers as layers
nsfw模型的搭建(复现实操)_第2张图片
  1. 模型实例化和查看模型网络结构

model = getModel()
model.summary()
  1. 测试模型效果,需要查看项目文档选择正确的变量输出result[0]

# 测试
import tensorflow as tf
# from  nsfwmodel import ResModel,getModel
from image_utils import create_tensorflow_image_loader
from  image_utils import  create_yahoo_image_loader
import numpy as np
import os

IMAGE_LOADER_TENSORFLOW = "tensorflow"
IMAGE_LOADER_YAHOO = "yahoo"
model_path='./nsfwmodel'
IMAGE_DIR=r'./jpg'
def findAllFile(base):
    for root, ds, fs in os.walk(base):
        for f in fs:
            if f.endswith('.jpg') or f.endswith('.png'):
            # if re.match(r'.*\d.*', f):
                fullname = os.path.join(root, f)
                yield fullname
def load_image(input_type=1,image_loader="yahoo"):
    if input_type == 1:
        print('TENSOR...')
        if image_loader == IMAGE_LOADER_TENSORFLOW:
            print('IMAGE_LOADER_TENSORFLOW...')
            fn_load_image = create_tensorflow_image_loader(tf.Session(graph=tf.Graph()))
        else:
            print('create_yahoo_image_loader')
            fn_load_image = create_yahoo_image_loader()
    elif input_type == 2:
        print('BASE64_JPEG...')
        import base64
        fn_load_image = lambda filename: np.array([base64.urlsafe_b64encode(open(filename, "rb").read())])
    return fn_load_image
def imageToTensor(inputs,input_type=1):
    if input_type == 1:
        input_tensor = inputs
    elif input_type == 2:
        from image_utils import load_base64_tensor
        input_tensor = load_base64_tensor(inputs)
    else:
        raise ValueError("invalid input type '{}'".format(input_type))
    return input_tensor

if __name__=='__main__':
    input_type=1
    image_loader= "yahoo"
    fn_load_image=load_image(input_type,image_loader)
    for i in findAllFile(IMAGE_DIR):
        print('predict for: ' + i)
        image = fn_load_image(i)
        imageTensor=imageToTensor(image, input_type)
        # print(model(imageTensor))
        result=model(imageTensor)
        print("\tSFW score:\t{}\n\tNSFW score:\t{}".format(*result[0])) 

结果展示

nsfw模型的搭建(复现实操)_第3张图片

后续改进方向

  • 寻找更多的数据进行训练:nsfw图片收集

  • 改进模型适应视频检测:将视频切分成帧

  • 将二分类改成多分类

你可能感兴趣的:(tensorflow,视觉检测)