一个基于tensorflow的finetune整体流程代码

本文主要是https://github.com/joelthchao/tensorflow-finetune-flickr-style代码的解释,用于阐述如何构建网络,载入数据以及微调一个已有数据的网络。

1.整体结构

工程主要有5个python文件构成,其中:

  • finetune.py 为工程主体,调用不同模块完成finetune过程
  • model.py 构架所使用的网络
  • network.py 详细定义网络的不同层的属性
  • dataset.py 用于读入数据
  • assemble.py 用于下载数据文件。是一个独立模块

2.构建网络

构建网络主要由network.py和model.py构成。
其中model.py定义并返回一个alexnet。关于AlexNet的介绍可以参考网上其他资料,在此不赘述
而network.py详细定义Net的每一层实现。在AlexNet中主要使用的是卷积层conv、归一化层norm、池化层pool以及全连接层fc让我们来了解下各层参数的意义.

2.1.conv

def conv(input, k_h, k_w, c_o, s_h, s_w, name, relu=True, padding=DEFAULT_PADDING, group=1)
其实在TF里面是存在直接定义卷积层的函数tf.nn.conv2d的,当然这里面也使用了conv2d来定义卷积层。
参数列表:

  • input 输入图像
  • k_h kernel高度
  • k_w kernel宽度
  • c_o 通道的输出数目
  • s_h sdrider的高度,即纵向卷积间隔
  • s_w sdrider的宽度
  • name 定义的名字
  • relu 是否添加relu层
  • padding 填充方式
  • group 按通道进行分组的数目

2.2.norm

归一化层,在这里使用的是local-response-normal.关于归一化的方式有非常多,这里的归一化方式可以参考论文:
今年来关于Batch-normalize的研究也在推进。很大部分网络训练不收敛的问题其实都可以通过norm来解决,只是norm的方式到底什么样子的最好现在还处于研究阶段。

2.3.pool

本例子中使用的是max-pooling,其作用类似于信号处理中的降采样。将核函数区域内的响应最大的数据保留。以此获取最大的感受野。

2.4.fc

定义全连接层。
这里这几句话的用法

    op = tf.nn.relu_layer if relu else tf.nn.xw_plus_b
    fc = op(input, weights, biases, name=scope.name)
    return fc

实际上类似函数的重命名方法。op可以代指原本的relu_layer

3.数据读取

3.1.dataset

dataset类用于读取图像及其标签。
初始化函数读取train_list以及test_list,将所有的图像路径及其标签保存在内存中。这样相比于将所有的图像文件全部读入到内存更加节约空间。但是读取读图像的时候用了opencv的接口,如果遇到cv2 import失败的问题请自行安装opencv

next_batch函数返回一个batch_size大小的图像集合。使用opencv作为接口处理图像并将图像缩放裁剪到合适大小。

3.2.assemble_data

下载并创建数据集。和训练关系不是非常大。

4.训练

4.1.finetune

由于不是很长,我们直接看代码解析整个流程:

def main():
    # Dataset path
    train_list = '/path/to/data/flickr_style/train.txt'
    test_list = '/path/to/data/flickr_style/test.txt'

训练和测试样本的路径

    # Learning params
    learning_rate = 0.001
    training_iters = 12800 # 10 epochs

学习步长以及总迭代次数

    batch_size = 50

每个batch的大小,batch即每次输入训练的图像的数目

    display_step = 20
    test_step = 640 # 0.5 epoch

每20次迭代显示一次,每460次迭代测试一次训练结果

    # Network params
    n_classes = 20

分类的类别

    keep_rate = 0.5

drop_out初始比率

    # Graph input
    x = tf.placeholder(tf.float32, [batch_size, 227, 227, 3])
    y = tf.placeholder(tf.float32, [None, n_classes])

输入数据和标签,使用palceholder作为输入数据和标签存放的结构,方便之后更换batch

    keep_var = tf.placeholder(tf.float32)

控制drop_out的变量

    # Model
    pred = Model.alexnet(x, keep_var)

构建一个alexnet,keep_var控制drop_out的数目

    # Loss and optimizer
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y))
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(loss)

loss函数和优化方式
这里loss选择的是交叉熵均值。优化为梯度下降

    # Evaluation
    correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

评判模型输出结果和真实标签之间的差异

    # Init
    init = tf.initialize_all_variables()

    # Load dataset
    dataset = Dataset(train_list, test_list)

    # Launch the graph
    with tf.Session() as sess:
        print 'Init variable'
        sess.run(init)

        # Load pretrained model
        load_with_skip('caffenet.npy', sess, ['fc8']) # Skip weights from fc8

载入之前的训练数据(除了fc8层)

        print 'Start training'
        step = 1
        while step < training_iters:
            batch_xs, batch_ys = dataset.next_batch(batch_size, 'train')
            sess.run(optimizer, feed_dict={x: batch_xs, y: batch_ys, keep_var: keep_rate})

            # Display testing status
            if step%test_step == 0:
                test_acc = 0.
                test_count = 0
                for _ in range(int(dataset.test_size/batch_size)):
                    batch_tx, batch_ty = dataset.next_batch(batch_size, 'test')
                    acc = sess.run(accuracy, feed_dict={x: batch_tx, y: batch_ty, keep_var: 1.})
                    test_acc += acc
                    test_count += 1
                test_acc /= test_count
                print >> sys.stderr, "{} Iter {}: Testing Accuracy = {:.4f}".format(datetime.now(), step, test_acc)


            # Display training status
            if step%display_step == 0:
                acc = sess.run(accuracy, feed_dict={x: batch_xs, y: batch_ys, keep_var: 1.})
                batch_loss = sess.run(loss, feed_dict={x: batch_xs, y: batch_ys, keep_var: 1.})
                print >> sys.stderr, "{} Iter {}: Training Loss = {:.4f}, Accuracy = {:.4f}".format(datetime.now(), step, batch_loss, acc)

            step += 1
        print "Finish!"

你可能感兴趣的:(tensorflow,机器学习)