21天深度学习tensorflow

本节的内容主要是将数处理成tfrecord的格式,然后送进网络进行训练。

准备数据

首先使用代码data_convert.py将图片转化为tfrecord的格式
代码如下:我们需要将数据放在同一目录下的pic文件里面,包含训练集,验证集,文件结构如下:
21天深度学习tensorflow_第1张图片
21天深度学习tensorflow_第2张图片

# coding:utf-8
from __future__ import absolute_import
import argparse
import os
import logging
from src.tfrecord import main

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('-t', '--tensorflow-data-dir', default='pic/')
    parser.add_argument('--train-shards', default=2, type=int)
    parser.add_argument('--validation-shards', default=2, type=int)
    parser.add_argument('--num-threads', default=2, type=int)
    parser.add_argument('--dataset-name', default='satellite', type=str)
    return parser.parse_args()

if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    args = parse_args()
    args.tensorflow_dir = args.tensorflow_data_dir
    args.train_directory = os.path.join(args.tensorflow_dir, 'train')
    args.validation_directory = os.path.join(args.tensorflow_dir, 'validation')
    args.output_directory = args.tensorflow_dir
    args.labels_file = os.path.join(args.tensorflow_dir, 'label.txt')
    if os.path.exists(args.labels_file) is False:
        logging.warning('Can\'t find label.txt. Now create it.')
        all_entries = os.listdir(args.train_directory)
        dirnames = []
        for entry in all_entries:
            if os.path.isdir(os.path.join(args.train_directory, entry)):
                dirnames.append(entry)
        with open(args.labels_file, 'w') as f:
            for dirname in dirnames:
                f.write(dirname + '\n')
    main(args)

然后运行一下代码:

python data_convert.py -t pic/ \
--train-shards 2 \#将数据集分成两块
--validation-shards 2 \
--num-threads 2 \#采用两个线程产生数据
--dataset-name satellite#给生成的数据集取一个名字

下面进行训练准备

如果需要使用 Slim 微调模型,首先要下载 Slim 的源代码 。 Slim 的源代码保存在 tensorflow/models 项目中,可以使用下面的 git 命 令下载tensorflow/models ·
,git clone https://github.corn/tensorflow/models.git找到 models/research/目录中的 slim 文件夹 , 这就是要用到的 TensorF lowSlim 的源代码 。
21天深度学习tensorflow_第3张图片

  1. 定义新的dataset文件
    在slim/dataset文件夹下面创建一个文件satellite.py,将flower.py文件的内容复制到satellite.py 文件里面,接下来修改代码
    第一处是 FILE PATTERN 、 SPLITS_TO_ SIZES 、 NUM CLASSES , 将
    真进行以下修改:
    在这里插入图片描述
    在这里插入图片描述
    21天深度学习tensorflow_第4张图片
    修改完 satellite.py 后,还需要在同目录的 dataset_factory. py 文件中注册satellite 数据库

    21天深度学习tensorflow_第5张图片
  2. 定义完数据;集后,在 slim 文件夹下再新建一个 satellite 目录,在这个目
    录中,完成最后的几项准备工作:
    21天深度学习tensorflow_第6张图片
  3. inceptionv3的文件直接在Linux里面运行程序就可以了
 wget http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz#下载模型
  tar -xvf inception_v3_2016_08_28.tar.gz #解压模型
 

  1. 开始训练
python train_image_classifier.py \
  --train_dir=satellite/train_dir \
  --dataset_name=satellite \
  --dataset_split_name=train \
  --dataset_dir=satellite/data \
  --model_name=inception_v3 \
  --checkpoint_path=satellite/pretrained/inception_v3.ckpt \
  --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
  --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
  --max_number_of_steps=100000 \
  --batch_size=32 \
  --learning_rate=0.001 \
  --learning_rate_decay_type=fixed \
  --save_interval_secs=300 \
  --save_summaries_secs=2 \
  --log_every_n_steps=10 \
  --optimizer=rmsprop \
  --weight_decay=0.00004 \

我前面弄了很久都不对,都是这个命令写错了,后来复制了一片博客的就好很多,所以复制我的没有错
后记:
这个是slim的通用框架,以后自己加什么数据都可以拿来用的,是可以训练自己的数据集的,学知识都是从一个不熟悉到熟悉的过程,所以我们熟能生巧,多加练习吧~

你可能感兴趣的:(tensorflow)