基于TF的迁移学习VGG16网络,并进行图片的测试。

vgg16.py文件

1.一开始在通过vgg =Vgg16()的调用,会自动初始化,在初始化过程中,会找到目录下的vgg16.npy文件,也就是在imagenet上训练好的权重文件,然后通过np.load()完成权重文件的加载。代码如下:

class Vgg16:
    def __init__(self, vgg16_npy_path=None):
        if vgg16_npy_path is None:
            path = inspect.getfile(Vgg16)
            path = os.path.abspath(os.path.join(path, os.pardir))
            path = os.path.join(path, "vgg16.npy")
            vgg16_npy_path = path
            print(path)

        self.data_dict = np.load(vgg16_npy_path, encoding='latin1').item()
        #print (self.data_dict)
        print("npy file loaded")

2.紧接着是进入build()方法中,用途:从npy文件中加载权重来初始化VGG网络(),也就是建立VGG网络的过程。因为一开始加载的RGB图像是4维的·[batch,height,width,3],所以先把rgb通道的图像改为bgr顺序的通道。我当时查了一下是否有必要进行通道的转换,这个要根据自己的定义的模型而定。并不是一定要转化的。我这边转化后其检测结果比转化前的结果要好一点,具体原因我也不清楚。代码如下:

        start_time = time.time()
        print("build model started")
        rgb_scaled = rgb * 255.0

        # Convert RGB to BGR
        red, green, blue = tf.split(axis=3, num_or_size_splits=3, value=rgb_scaled)
        assert red.get_shape().as_list()[1:] == [224, 224, 1]
        assert green.get_shape().as_list()[1:] == [224, 224, 1]
        assert blue.get_shape().as_list()[1:] == [224, 224, 1]
        bgr = tf.concat(axis=3, values=[
            blue - VGG_MEAN[0],
            green - VGG_MEAN[1],
            red - VGG_MEAN[2],
        ])
        assert bgr.get_shape().as_list()[1:] == [224, 224, 3]

3.构建网络,此时输入图像已经准备好了,接下来就是把图像送入到网络中。网络的结构如下:

        self.conv1_1 = self.conv_layer(bgr, "conv1_1")
        self.conv1_2 = self.conv_layer(self.conv1_1, "conv1_2")
        self.pool1 = self.max_pool(self.conv1_2, 'pool1')

        self.conv2_1 = self.conv_layer(self.pool1, "conv2_1")
        self.conv2_2 = self.conv_layer(self.conv2_1, "conv2_2")
        self.pool2 = self.max_pool(self.conv2_2, 'pool2')

        self.conv3_1 = self.conv_layer(self.pool2, "conv3_1")
        self.conv3_2 = self.conv_layer(self.conv3_1, "conv3_2")
        self.conv3_3 = self.conv_layer(self.conv3_2, "conv3_3")
        self.pool3 = self.max_pool(self.conv3_3, 'pool3')

        self.conv4_1 = self.conv_layer(self.pool3, "conv4_1")
        self.conv4_2 = self.conv_layer(self.conv4_1, "conv4_2")
        self.conv4_3 = self.conv_layer(self.conv4_2, "conv4_3")
        self.pool4 = self.max_pool(self.conv4_3, 'pool4')

        self.conv5_1 = self.conv_layer(self.pool4, "conv5_1")
        self.conv5_2 = self.conv_layer(self.conv5_1, "conv5_2")
        self.conv5_3 = self.conv_layer(self.conv5_2, "conv5_3")
        self.pool5 = self.max_pool(self.conv5_3, 'pool5')

        self.fc6 = self.fc_layer(self.pool5, "fc6")
        assert self.fc6.get_shape().as_list()[1:] == [4096]
        self.relu6 = tf.nn.relu(self.fc6)

        self.fc7 = self.fc_layer(self.relu6, "fc7")
        self.relu7 = tf.nn.relu(self.fc7)

        self.fc8 = self.fc_layer(self.relu7, "fc8")

        self.prob = tf.nn.softmax(self.fc8, name="prob")

到pool5的时候,此时输入224*224*3已经变成了7*7*512的形式了。之所以提到pool5的输出,因为我们在迁移学习的过程中,通过前面的卷积层来提取特征,这里的卷积层一般就是指的pool5前面的所有层。用来提取通用的图像特征。后面紧接着三个全连接层,在接一个输出。

4.这样我们的网络也构建好了,而且输入图片的维度也调整好了,以及VGG16的权重文件也加载进去了,那么现在可以加载几张图片测试一下。(这里采用的imagenet数据集训练好的参数来直接加载网络进行判别)代码如下:

import numpy as np
import tensorflow as tf
import vgg16
import utils

img1 = utils.load_image("./test_data/tiger.jpeg")
img2 = utils.load_image("./test_data/hudie.jpeg")

batch1 = img1.reshape((1, 224, 224, 3))
batch2 = img2.reshape((1, 224, 224, 3))

batch = np.concatenate((batch1, batch2), 0)

# with tf.Session(config=tf.ConfigProto(gpu_options=(tf.GPUOptions(per_process_gpu_memory_fraction=0.7)))) as sess:
with tf.device('/cpu:0'):
    with tf.Session() as sess:
        images = tf.placeholder("float", [2, 224, 224, 3])
        feed_dict = {images: batch}

        vgg = vgg16.Vgg16()
        with tf.name_scope("content_vgg"):
            vgg.build(images)

        prob = sess.run(vgg.prob, feed_dict=feed_dict)

在这里我加载了两张图片进行测试,因为输入的图片的维度是4维度的,所以加载完图片后在进行reshape()操作,然后堆在batch里,然后喂到feed_dict里面,调用网络,再通过sess.run(y,feed_dict=feed_dict)得到最终的输出y。此时这个y是softmax的输出,是1000个分数值,代表每个类型的分数。


















你可能感兴趣的:(深度学习,tensorflow,迁移学习,VGG16)