大批量图像处理(7)——用各种各样的常见模型识别分类大量的图片(绝对干货)

准备如下:

所有常见模型的网络结构文件
(图1:所有常见模型的网络结构文件)
需要用到的各种模型的参数文件
(图2:各模型对应的参数(checkpoint)文件)

代码如下:

更换模型时需要改动的地方并不多,已在代码中用注释标出。

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import csv
import os

from cleverhans.attacks import FastGradientMethod
from io import BytesIO
import numpy as np
from PIL import Image
from scipy.misc import imread
from scipy.misc import imsave
import tensorflow as tf
#from tensorflow.contrib.slim.nets import inception
from nets import inception_v3, inception_resnet_v2                 #改动地方之一,换网络结构,加上去就行了

slim = tf.contrib.slim

# --- Session / checkpoint ---------------------------------------------------
# Passed as `master` to ChiefSessionCreator when the session is built.
tensorflow_master = ""
# Change point: checkpoint file of the network selected in the model section.
checkpoint_path = "/home/caad/workspace/CAAD2018/dataset/models/inception_v3.ckpt"

# --- Data locations ---------------------------------------------------------
# Directory that holds target_class.csv (a directory, despite the name).
input_csv = "/home/caad/workspace/CAAD2018/dataset/images"
# Directory whose *.png files are classified.
input_dir = "/home/NEWDISK/output_champion"

# --- Input geometry ---------------------------------------------------------
image_height = 299
image_width = 299
batch_size = 10
# NHWC batch layout fed to the placeholder: (batch, height, width, RGB).
batch_shape = [batch_size, image_height, image_width, 3]

# Width of the classifier output; presumably 1000 ImageNet classes plus a
# background class at index 0 -- confirm against the checkpoint.
num_classes = 1001

# Presumably the attack's perturbation budget in 0..255 pixel units, and the
# same budget rescaled to the [-1, 1] model input range (factor 2/255).
# Neither is referenced elsewhere in the visible script.
max_epsilon = 16.0
eps = 2.0 * max_epsilon / 255.0

def load_images(input_dir, batch_shape):
    """Yield batches of PNG images from ``input_dir`` scaled to [-1, 1].

    Files matching ``*.png`` are read in sorted path order, decoded as RGB
    and mapped from the [0, 255] pixel range to [-1, 1].

    Args:
        input_dir: directory containing the ``.png`` files.
        batch_shape: shape of each yielded array,
            ``[batch_size, height, width, 3]``. Every image must already be
            ``height x width`` pixels, or the row assignment fails.

    Yields:
        ``(filenames, images)``. ``filenames`` is the list of file basenames
        in the batch. ``images`` is a float64 array of shape ``batch_shape``.
        The final batch may hold fewer than ``batch_size`` files. Its unused
        trailing rows are left as zeros, so callers should only read the
        first ``len(filenames)`` rows.
    """
    batch_size = batch_shape[0]
    images = np.zeros(batch_shape)
    filenames = []
    for filepath in sorted(tf.gfile.Glob(os.path.join(input_dir, '*.png'))):
        with tf.gfile.Open(filepath, "rb") as f:
            # convert('RGB') forces a full decode while the file is still
            # open. It matches the removed scipy.misc.imread(mode='RGB').
            # np.float (removed in NumPy 1.24) was an alias for float64.
            pixels = np.array(Image.open(f).convert('RGB')).astype(np.float64)
        images[len(filenames)] = pixels * 2.0 / 255.0 - 1.0
        filenames.append(os.path.basename(filepath))
        if len(filenames) == batch_size:
            yield filenames, images
            filenames = []
            images = np.zeros(batch_shape)
    # Flush a final, partially filled batch.
    if filenames:
        yield filenames, images

def load_target_class(input_dir):
    """Load the per-image target classes from ``target_class.csv``.

    Args:
        input_dir: directory containing ``target_class.csv``.

    Returns:
        dict mapping ``row[0] + '.png'`` to ``int(row[7])`` for each CSV row.
        Rows too short to have an index-7 column are skipped.
    """
    with tf.gfile.Open(os.path.join(input_dir, 'target_class.csv')) as f:
        # row[7] is the 8th column, so at least 8 fields are required.
        return {row[0] + '.png': int(row[7])
                for row in csv.reader(f) if len(row) >= 8}

# Image filename -> target class id, read from target_class.csv.
all_images_target_class = load_target_class(input_csv)
# Number of scored images (those whose target class is non-zero).
total_count = 0
# Number of scored images whose predicted class equals the target class.
right_number = 0
with tf.Graph().as_default():

    x_input = tf.placeholder(tf.float32, shape=batch_shape)

    # Six common models are listed below. Exactly one must be active; keep the
    # other five commented out. Change point: switch the network here (and
    # point checkpoint_path at the matching weights).
    with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
        _, end_points = inception_v3.inception_v3(
            x_input, num_classes=num_classes, is_training=False)

    # with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
    #     _, end_points = inception_v3.inception_v3(
    #         x_input, num_classes=num_classes, is_training=False,
    #         scope='AdvInceptionV3')
    #
    # with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
    #     _, end_points = inception_v3.inception_v3(
    #         x_input, num_classes=num_classes, is_training=False,
    #         scope='Ens3AdvInceptionV3')
    #
    # with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
    #     _, end_points = inception_v3.inception_v3(
    #         x_input, num_classes=num_classes, is_training=False,
    #         scope='Ens4AdvInceptionV3')
    #
    # with slim.arg_scope(inception_resnet_v2.inception_resnet_v2_arg_scope()):
    #     _, end_points = inception_resnet_v2.inception_resnet_v2(
    #         x_input, num_classes=num_classes, is_training=False,
    #         scope='EnsAdvInceptionResnetV2')
    #
    # with slim.arg_scope(inception_resnet_v2.inception_resnet_v2_arg_scope()):
    #     _, end_points = inception_resnet_v2.inception_resnet_v2(
    #         x_input, num_classes=num_classes, is_training=False,
    #         scope='AdvInceptionResnetV2')

    predicted_labels = tf.argmax(end_points['Predictions'], 1)

    # Restore only the model variables from the checkpoint.
    saver = tf.train.Saver(slim.get_model_variables())
    session_creator = tf.train.ChiefSessionCreator(
        scaffold=tf.train.Scaffold(saver=saver),
        checkpoint_filename_with_path=checkpoint_path,
        master=tensorflow_master)

    with tf.train.MonitoredSession(session_creator=session_creator) as sess:
        for filenames, images in load_images(input_dir, batch_shape):
            predicted_targeted_classes = sess.run(
                predicted_labels, feed_dict={x_input: images})

            # Only the first len(filenames) rows are real images; the rest of
            # a final partial batch is zero padding and must not be scored.
            for name, predicted in zip(filenames, predicted_targeted_classes):
                target = all_images_target_class[name]
                # Numerator and denominator use the same set of images.
                if target != 0:
                    total_count += 1
                    if predicted == target:
                        right_number += 1
                print("TARGETED ADVERSARIAL IMAGE", total_count,
                      "\n\tPredicted class:", predicted,
                      "\n\tattack class:", target)
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~over~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print("\n\t分类图片总数: ", total_count)
        print("\n\t攻击成功数量: ", right_number)
        # Avoid ZeroDivisionError when no image was scored.
        if total_count:
            print("\n\taccuracy:", right_number / total_count)
        else:
            print("\n\taccuracy:", float('nan'))






你可能感兴趣的:(机器对抗学习,图像处理)