tensorflow 使用slim对自己的tfrecord数据集实现迁移学习

目录结构
(此处原为目录结构截图:脚本需要能够导入 TF-slim models 仓库中的 datasets、nets、preprocessing 包,并在当前目录下准备 tmp 与 record 文件夹。)
inception_v1.ckpt的下载地址
http://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz
下载后解压,将得到的 inception_v1.ckpt 放入检查点文件夹中。
我存储在 './tmp/checkpoints'(即代码中的 checkpoints_dir)。
tfrecord数据集下载地址:
链接:https://pan.baidu.com/s/1XVt2EJvgH3dqaI8sx8G2gQ
提取码:v1oa

import os

from datasets import flowers
from nets import inception
from preprocessing import inception_preprocessing
import tensorflow as tf
from tensorflow.contrib import slim
# Input resolution Inception V1 was defined with (not referenced elsewhere in
# this script; the record shape is configured separately via heigh/width).
image_size = inception.inception_v1.default_image_size

# Directory that holds the pretrained 'inception_v1.ckpt'.
checkpoints_dir = './tmp/checkpoints'
# Directory where fine-tuned checkpoints and summaries are written.
train_dir = './tmp/inception_finetuned/'

# Passed to tf.train.string_input_producer: how many times the record file is
# read before the input pipeline signals end of input.
num_epochs = 2000

# Shape of each decoded record and the number of target classes.
heigh = 224
width = 224
channels = 3
n_class = 3

# Examples per training batch.
batchSize = 10
def tfRecordRead(fileNameQue, heigh, width, channels, n_class):
    """Read and decode one (image, one-hot label) example from a TFRecord queue.

    Args:
        fileNameQue: filename queue to read from, e.g. the output of
            ``tf.train.string_input_producer``.
        heigh: image height used to reshape the decoded pixel buffer.
        width: image width used to reshape the decoded pixel buffer.
        channels: number of image channels.
        n_class: number of classes, i.e. the depth of the one-hot label.

    Returns:
        ``(image, labels)``: a float32 tensor reshaped to
        ``[heigh, width, channels]`` and a one-hot float32 tensor of length
        ``n_class``.
    """
    reader = tf.TFRecordReader()
    # Pull one serialized Example from the queue supplied by the caller.
    _, serialized_example = reader.read(fileNameQue)
    # Parse the Example into a raw image byte string and an int64 class id.
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        })
    # NOTE(review): assumes the records store raw float32 pixels. A file
    # written from uint8 arrays would need tf.uint8 here plus a cast/scale
    # afterwards; otherwise the element count will not match the reshape.
    image = tf.decode_raw(features['image'], tf.float32)
    image = tf.reshape(image, [heigh, width, channels])
    # 'label' is already int64 per the FixedLenFeature spec above, so it can be
    # one-hot encoded directly (tf.one_hot produces float32 by default).
    labels = tf.one_hot(features['label'], n_class)
    return image, labels


def tfRecordBatchRead(filename, heigh, width, channels, n_class, batchSize,
                      min_after_dequeue=1000, num_threads=1):
    """Build a shuffled batching pipeline over a single TFRecord file.

    Args:
        filename: path of the TFRecord file to read.
        heigh: image height passed through to ``tfRecordRead``.
        width: image width passed through to ``tfRecordRead``.
        channels: image channel count passed through to ``tfRecordRead``.
        n_class: number of classes for the one-hot labels.
        batchSize: number of examples per batch.
        min_after_dequeue: minimum number of examples left in the shuffle
            buffer after each dequeue; larger shuffles better but uses more
            memory. Defaults to 1000, the previously hard-coded value.
        num_threads: number of threads enqueuing into the shuffle queue.
            Defaults to 1, which is ``tf.train.shuffle_batch``'s own default.

    Returns:
        ``(imageBatch, labelBatch)``, each with a leading batch dimension of
        size ``batchSize``.

    The file is cycled for the module-level ``num_epochs`` passes. Because
    ``num_epochs`` is not None, ``string_input_producer`` creates a local
    epoch counter, so local variables must be initialized before the queue
    runners start.
    """
    fileNameQue = tf.train.string_input_producer(
        [filename], shuffle=False, num_epochs=num_epochs)
    image, labels = tfRecordRead(fileNameQue, heigh, width, channels, n_class)
    # Headroom above min_after_dequeue so enqueuing threads are not blocked
    # while a batch is being dequeued.
    capacity = min_after_dequeue + 3 * batchSize
    # Prefetch and shuffle into batches; the outputs gain a leading batch axis.
    imageBatch, labelBatch = tf.train.shuffle_batch(
        [image, labels],
        batch_size=batchSize,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue,
        num_threads=num_threads)
    return imageBatch, labelBatch


# Path of the training TFRecord file. Built with os.path.join so it resolves to
# the 'record' subdirectory on every OS; the old mixed-separator literal
# r'./record\Imageoutput.tfrecords' only worked on Windows (on POSIX the
# backslash is part of the file name).
filename = os.path.join('.', 'record', 'Imageoutput.tfrecords')

def get_init_fn():
    """Build the warm-start callback that loads pretrained Inception V1 weights.

    Every model variable is restored from ``<checkpoints_dir>/inception_v1.ckpt``
    except those whose op name starts with one of the excluded scopes. These
    are the original output layers, which the training graph rebuilds with
    ``n_class`` outputs, so their shapes presumably would not match the
    checkpoint.

    Returns:
        The callable produced by ``slim.assign_from_checkpoint_fn``; it takes
        a session and assigns the checkpoint values to the kept variables.
    """
    excluded_prefixes = tuple(
        scope.strip()
        for scope in ["InceptionV1/Logits", "InceptionV1/AuxLogits"]
    )
    keep = [
        var for var in slim.get_model_variables()
        if not var.op.name.startswith(excluded_prefixes)
    ]
    ckpt_path = os.path.join(checkpoints_dir, 'inception_v1.ckpt')
    return slim.assign_from_checkpoint_fn(ckpt_path, keep)




with tf.Graph().as_default():
    tf.logging.set_verbosity(tf.logging.INFO)

    # Input pipeline: shuffled batches of (image, one-hot label) pairs.
    imageBatch, labelBatch = tfRecordBatchRead(filename, heigh, width, channels, n_class, batchSize)

    # Build Inception V1 under its default arg scope (configures batch norm)
    # with a fresh n_class-way classifier.
    # NOTE(review): the decoded float32 pixels are fed in as-is
    # (inception_preprocessing is imported but unused). The pretrained weights
    # presumably expect inputs scaled to [-1, 1]; confirm how the records were
    # written.
    with slim.arg_scope(inception.inception_v1_arg_scope()):
        logits, _ = inception.inception_v1(imageBatch, num_classes=n_class, is_training=True)

    # Cross-entropy against the one-hot labels. Keywords make the argument
    # order explicit: the deprecated slim/contrib API takes logits first,
    # unlike tf.losses.
    slim.losses.softmax_cross_entropy(logits=logits, onehot_labels=labelBatch)
    # Registered losses plus regularization losses.
    total_loss = slim.losses.get_total_loss()

    # Summary tag names may not contain spaces; TF would otherwise rewrite
    # 'losses/Total Loss' to this exact name and log a warning every run.
    tf.summary.scalar('losses/Total_Loss', total_loss)

    # Specify the optimizer and create the train op.
    optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
    train_op = slim.learning.create_train_op(total_loss, optimizer)

    # Run up to 2000 steps. get_init_fn() warm-starts from the pretrained
    # checkpoint; presumably it only runs when train_dir holds no checkpoint
    # yet (the Supervisor restores from logdir otherwise).
    final_loss = slim.learning.train(train_op,
                                     logdir=train_dir,
                                     init_fn=get_init_fn(),
                                     number_of_steps=2000,
                                     save_summaries_secs=60)

# slim.learning.train returns None when no training step completed (e.g. the
# input queue raised OutOfRange immediately); '%f' cannot format None.
if final_loss is None:
    print('Finished training. No training step was run.')
else:
    print('Finished training. Last batch loss %f' % final_loss)

你可能感兴趣的:(迁移学习)