A Detailed Walkthrough of the FCN-TensorFlow Code (Fully Convolutional Networks)

Complete FCN-TensorFlow code on GitHub: https://github.com/EternityZY/FCN-TensorFlow.git

This post walks through all of the code, with detailed comments added throughout.

Notes:

  • As the code requires, download the VGG-19 model and the training set ahead of time; downloading them at run time is very slow.

  • MODEL_URL =  'http://www.vlfeat.org/matconvnet/models/beta16/imagenet-vgg-verydeep-19.mat'

  • DATA_URL =  'http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip'

  • The code has been modified so that it runs on TensorFlow 1.4.

  • To train the model, simply run python FCN.py.

  • Lower the learning rate to 1e-5 or even smaller (the tf.flags defaults can be overridden on the command line, e.g. python FCN.py --learning_rate=1e-5); otherwise the loss keeps hovering around 3.

  • The debug flag can be set for training to add summaries about activations, gradients, variables, and so on.


FCN.py


# coding=utf-8
from __future__ import print_function
import tensorflow as tf
import numpy as np

import TensorflowUtils as utils
import read_MITSceneParsingData as scene_parsing
import datetime
import BatchDatsetReader as dataset
from six.moves import xrange

# Flag definitions
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_integer("batch_size", "2", "batch size for training")
tf.flags.DEFINE_string("logs_dir", "logs/", "path to logs directory")
tf.flags.DEFINE_string("data_dir", "Data_zoo/MIT_SceneParsing/", "path to dataset")
tf.flags.DEFINE_float("learning_rate", "1e-6", "Learning rate for Adam Optimizer")
tf.flags.DEFINE_string("model_dir", "Model_zoo/", "Path to vgg model mat")
tf.flags.DEFINE_bool('debug', "True", "Debug mode: True/ False")
tf.flags.DEFINE_string('mode', "train", "Mode train/ test/ visualize")

MODEL_URL = 'http://www.vlfeat.org/matconvnet/models/beta16/imagenet-vgg-verydeep-19.mat'

MAX_ITERATION = 20000               # number of training iterations
NUM_OF_CLASSESS = 151               # number of classes (151)
IMAGE_SIZE = 224                    # input image size (224)
fine_tuning = False

# VGG network: weights is the pretrained weight collection, image is the input image tensor
def vgg_net(weights, image):
    # The first five blocks of VGG-19
    layers = (
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',

        'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',

        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3',
        'relu3_3', 'conv3_4', 'relu3_4', 'pool3',

        'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3',
        'relu4_3', 'conv4_4', 'relu4_4', 'pool4',

        'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3',
        'relu5_3', 'conv5_4', 'relu5_4'
    )

    net = {}
    current = image     # the input image
    for i, name in enumerate(layers):
        kind = name[:4]
        if kind == 'conv':
            kernels, bias = weights[i][0][0][0][0]
            # matconvnet: weights are [width, height, in_channels, out_channels]
            # tensorflow: weights are [height, width, in_channels, out_channels]
            kernels = utils.get_variable(np.transpose(kernels, (1, 0, 2, 3)), name=name + "_w")     # conv1_1_w
            bias = utils.get_variable(bias.reshape(-1), name=name + "_b")       # conv1_1_b
            current = utils.conv2d_basic(current, kernels, bias)        # forward result of this conv layer
        elif kind == 'relu':
            current = tf.nn.relu(current, name=name)    # relu1_1
            if FLAGS.debug:     # debug mode enabled: True / False
                utils.add_activation_summary(current)       # add an activation summary for TensorBoard
        elif kind == 'pool':
            # All five VGG blocks downsample with stride 2, i.e. each block halves the spatial size.
            # The pooling for the first four blocks is handled here, using average pooling.
            # The fifth pool is applied outside this function (in inference), using max pooling.
            # pool1: size / 2
            # pool2: size / 4
            # pool3: size / 8
            # pool4: size / 16
            current = utils.avg_pool_2x2(current)
        net[name] = current     # store each layer's output in the net dict

    return net
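
# Note: net["pool3"] and net["pool4"] are read back out in inference() below to
# build the FCN-8s skip connections, and net["conv5_3"] is the backbone output
# on which the new fully-convolutional layers sit.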


# Inference: image is the input image, keep_prob is the dropout keep probability
def inference(image, keep_prob):
    """
    Semantic segmentation network definition
    :param image: input image. Should have values in range 0-255
    :param keep_prob:
    :return:
    """
    # Load the pretrained VGG network
    print("setting up vgg initialized conv layers ...")
    # model_dir: Model_zoo/
    # MODEL_URL: download URL for VGG-19
    model_data = utils.get_model_data(FLAGS.model_dir, MODEL_URL)       # returns the contents of the VGG-19 .mat file

    mean = model_data['normalization'][0][0][0]                         # image mean
    mean_pixel = np.mean(mean, axis=(0, 1))                             # RGB

    weights = np.squeeze(model_data['layers'])                          # squeeze out singleton dimensions; what remains are the weights

    processed_image = utils.process_image(image, mean_pixel)            # subtract the mean from the image

    with tf.variable_scope("inference"):                                # variable scope "inference"
        image_net = vgg_net(weights, processed_image)                   # run the image through VGG and collect every layer's output
        conv_final_layer = image_net["conv5_3"]                         # take the conv5_3 output

        pool5 = utils.max_pool_2x2(conv_final_layer)                    # /32: downsampled 32x relative to the input

        W6 = utils.weight_variable([7, 7, 512, 4096], name="W6")        # initialize layer-6 weights and biases
        b6 = utils.bias_variable([4096], name="b6")
        conv6 = utils.conv2d_basic(pool5, W6, b6)
        relu6 = tf.nn.relu(conv6, name="relu6")
        if FLAGS.debug:
            utils.add_activation_summary(relu6)
        relu_dropout6 = tf.nn.dropout(relu6, keep_prob=keep_prob)

        W7 = utils.weight_variable([1, 1, 4096, 4096], name="W7")       # layer-7 1x1 convolution
        b7 = utils.bias_variable([4096], name="b7")
        conv7 = utils.conv2d_basic(relu_dropout6, W7, b7)
        relu7 = tf.nn.relu(conv7, name="relu7")
        if FLAGS.debug:
            utils.add_activation_summary(relu7)
        relu_dropout7 = tf.nn.dropout(relu7, keep_prob=keep_prob)

        W8 = utils.weight_variable([1, 1, 4096, NUM_OF_CLASSESS], name="W8")
        b8 = utils.bias_variable([NUM_OF_CLASSESS], name="b8")
        conv8 = utils.conv2d_basic(relu_dropout7, W8, b8)               # layer-8 1x1 convolution, scoring NUM_OF_CLASSESS classes
        # annotation_pred1 = tf.argmax(conv8, dimension=3, name="prediction1")

        # now to upscale to actual image size
        deconv_shape1 = image_net["pool4"].get_shape()                  # shape of pool4 (1/16 of the input), used for fusion [b,h,w,c]
        # Deconvolution filter W and bias b, shaped [H, W, OUT_C, IN_C]:
        # OUT_C is pool4's channel count, IN_C is conv8's.
        # To upsample 2x, use stride = 2 and kernel_size = 4.
        W_t1 = utils.weight_variable([4, 4, deconv_shape1[3].value, NUM_OF_CLASSESS], name="W_t1")
        b_t1 = utils.bias_variable([deconv_shape1[3].value], name="b_t1")
        # Deconvolve conv8: double its spatial size and match pool4's channel count
        conv_t1 = utils.conv2d_transpose_strided(conv8, W_t1, b_t1, output_shape=tf.shape(image_net["pool4"]))
        fuse_1 = tf.add(conv_t1, image_net["pool4"], name="fuse_1")     # fuse by element-wise addition

        # Shape of pool3, 1/8 of the input size
        deconv_shape2 = image_net["pool3"].get_shape()
        # OUT_C is pool3's channel count, IN_C is pool4's
        W_t2 = utils.weight_variable([4, 4, deconv_shape2[3].value, deconv_shape1[3].value], name="W_t2")
        b_t2 = utils.bias_variable([deconv_shape2[3].value], name="b_t2")
        # Deconvolve the previous fusion fuse_1, doubling it again to pool3's size
        conv_t2 = utils.conv2d_transpose_strided(fuse_1, W_t2, b_t2, output_shape=tf.shape(image_net["pool3"]))
        # Fusion: deconv(fuse_1) + pool3
        fuse_2 = tf.add(conv_t2, image_net["pool3"], name="fuse_2")

        shape = tf.shape(image)     # original image size
        # Stack into the deconvolution output shape [b, H, W, NUM_OF_CLASSESS]
        deconv_shape3 = tf.stack([shape[0], shape[1], shape[2], NUM_OF_CLASSESS])
        # Deconvolution filter: 8x upsampling needs kernel_size = 16;
        # OUT_C = NUM_OF_CLASSESS, IN_C = pool3's channel count
        W_t3 = utils.weight_variable([16, 16, NUM_OF_CLASSESS, deconv_shape2[3].value], name="W_t3")
        b_t3 = utils.bias_variable([NUM_OF_CLASSESS], name="b_t3")
        # Deconvolve fuse_2 to the output shape [b, H, W, NUM_OF_CLASSESS]
        conv_t3 = utils.conv2d_transpose_strided(fuse_2, W_t3, b_t3, output_shape=deconv_shape3, stride=8)

        # conv_t3 now has the same spatial size as the original image, with NUM_OF_CLASSESS channels.
        # For every pixel position, argmax over the third (channel) dimension picks the class:
        # whichever of the NUM_OF_CLASSESS channel values is largest determines which class
        # that pixel belongs to.
        # The result is a map of per-pixel class labels, shape [b, h, w].
        annotation_pred = tf.argmax(conv_t3, axis=3, name="prediction")
    # Expand along the third dimension to [b, h, w, c] with c = 1;
    # conv_t3 is also returned as the NUM_OF_CLASSESS-channel score map.
    return tf.expand_dims(annotation_pred, axis=3), conv_t3


def train(loss_val, var_list):
    """

    :param loss_val: the loss to minimize
    :param var_list: the variables to optimize
    :return:
    """
    optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
    grads = optimizer.compute_gradients(loss_val, var_list=var_list)
    if FLAGS.debug:
        # print(len(var_list))
        for grad, var in grads:
            utils.add_gradient_summary(grad, var)
    return optimizer.apply_gradients(grads)     # apply the gradients


def main(argv=None):
    # dropout keep probability
    keep_probability = tf.placeholder(tf.float32, name="keep_probabilty")
    # placeholder for input images
    image = tf.placeholder(tf.float32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3], name="input_image")
    # placeholder for the annotations (labels)
    annotation = tf.placeholder(tf.int32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 1], name="annotation")

    # Run inference on a batch: prediction map [b,h,w,c=1] and score map [b,h,w,c=151]
    pred_annotation, logits = inference(image, keep_probability)
    tf.summary.image("input_image", image, max_outputs=2)
    tf.summary.image("ground_truth", tf.cast(annotation, tf.uint8), max_outputs=2)
    tf.summary.image("pred_annotation", tf.cast(pred_annotation, tf.uint8), max_outputs=2)
    # Spatial cross-entropy between logits [b,h,w,c=151] and labels [b,h,w], compared pixel by pixel
    loss = tf.reduce_mean((tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                                          labels=tf.squeeze(annotation, axis=[3]),
                                                                          name="entropy")))
    tf.summary.scalar("entropy", loss)

    # Get the list of trainable variables
    trainable_var = tf.trainable_variables()
    if FLAGS.debug:
        for var in trainable_var:
            utils.add_to_regularization_and_summary(var)

    # Build the training op from the loss and the trainable variables
    train_op = train(loss, trainable_var)

    print("Setting up summary op...")
    # merge all summaries for TensorBoard
    summary_op = tf.summary.merge_all()

    print("Setting up image reader...")
    # data_dir = Data_zoo/MIT_SceneParsing/
    # training: [{image: full image path, annotation: full label path, filename: image name}, {...}, ...]
    train_records, valid_records = scene_parsing.read_dataset(FLAGS.data_dir)
    print(len(train_records))   # number of records
    print(len(valid_records))

    print("Setting up dataset reader")
    image_options = {'resize': True, 'resize_size': IMAGE_SIZE}
    if FLAGS.mode == 'train':
        # Read the images into a reader object that holds all of the image data
        train_dataset_reader = dataset.BatchDatset(train_records, image_options)
    validation_dataset_reader = dataset.BatchDatset(valid_records, image_options)

    sess = tf.Session()

    print("Setting up Saver...")
    saver = tf.train.Saver()
    summary_writer = tf.summary.FileWriter(FLAGS.logs_dir, sess.graph)

    sess.run(tf.global_variables_initializer())
    # logs/
    if fine_tuning:
        ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)    # resume training from a checkpoint
        if ckpt and ckpt.model_checkpoint_path:                 # if a checkpoint file exists, restore the session
            saver.restore(sess, ckpt.model_checkpoint_path)
            print("Model restored...")

    if FLAGS.mode == "train":
        for itr in range(MAX_ITERATION):
            # Fetch the next batch
            train_images, train_annotations = train_dataset_reader.next_batch(FLAGS.batch_size)
            feed_dict = {image: train_images, annotation: train_annotations, keep_probability: 0.85}

            # Run one optimization step on the trainable variables
            sess.run(train_op, feed_dict=feed_dict)

            if itr % 10 == 0:
                # print progress every 10 iterations
                train_loss, summary_str = sess.run([loss, summary_op], feed_dict=feed_dict)
                print("Step: %d, Train_loss:%g" % (itr, train_loss))
                summary_writer.add_summary(summary_str, itr)

            if itr % 500 == 0:
                # validate every 500 iterations
                valid_images, valid_annotations = validation_dataset_reader.next_batch(FLAGS.batch_size)
                valid_loss = sess.run(loss, feed_dict={image: valid_images, annotation: valid_annotations,
                                                       keep_probability: 1.0})
                print("%s ---> Validation_loss: %g" % (datetime.datetime.now(), valid_loss))
                # save the model
                saver.save(sess, FLAGS.logs_dir + "model.ckpt", itr)

    elif FLAGS.mode == "visualize":
        # visualization
        valid_images, valid_annotations = validation_dataset_reader.get_random_batch(FLAGS.batch_size)
        # pred_annotation is the predicted label map
        pred = sess.run(pred_annotation, feed_dict={image: valid_images, annotation: valid_annotations,
                                                    keep_probability: 1.0})
        valid_annotations = np.squeeze(valid_annotations, axis=3)
        pred = np.squeeze(pred, axis=3)

        for itr in range(FLAGS.batch_size):
            utils.save_image(valid_images[itr].astype(np.uint8), FLAGS.logs_dir, name="inp_" + str(5+itr))
            utils.save_image(valid_annotations[itr].astype(np.uint8), FLAGS.logs_dir, name="gt_" + str(5+itr))
            utils.save_image(pred[itr].astype(np.uint8), FLAGS.logs_dir, name="pred_" + str(5+itr))
            print("Saved image: %d" % itr)


if __name__ == "__main__":
    tf.app.run()
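
Before moving on to the data code, it helps to make the skip-architecture shape arithmetic in inference() concrete. The following standalone sketch (written for this post, not part of the repository) traces a 224x224 input through the strides used above:

# Illustration only: FCN-8s shape arithmetic for IMAGE_SIZE = 224
IMAGE_SIZE = 224
NUM_OF_CLASSESS = 151

pool5 = IMAGE_SIZE // 32    # 7   -- conv6/conv7/conv8 operate at this size
pool4 = IMAGE_SIZE // 16    # 14  -- conv_t1 (stride-2 deconv of conv8) matches this
pool3 = IMAGE_SIZE // 8     # 28  -- conv_t2 (stride-2 deconv of fuse_1) matches this

print("conv8  : %d x %d x %d" % (pool5, pool5, NUM_OF_CLASSESS))    # 7 x 7 x 151
print("fuse_1 : %d x %d (conv_t1 + pool4)" % (pool4, pool4))        # 14 x 14
print("fuse_2 : %d x %d (conv_t2 + pool3)" % (pool3, pool3))        # 28 x 28
print("conv_t3: %d x %d x %d (stride-8 deconv of fuse_2)" % (IMAGE_SIZE, IMAGE_SIZE, NUM_OF_CLASSESS))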

read_MITSceneParsingData.py

# coding=utf-8
__author__ = 'charlie'
import numpy as np
import os
import random
from six.moves import cPickle as pickle
from tensorflow.python.platform import gfile
import glob

import TensorflowUtils as utils

# DATA_URL = 'http://sceneparsing.csail.mit.edu/data/ADEChallengeData2016.zip'
DATA_URL = 'http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip'


def read_dataset(data_dir):
    # data_dir = Data_zoo/MIT_SceneParsing/
    pickle_filename = "MITSceneParsing.pickle"
    # File path: Data_zoo/MIT_SceneParsing/MITSceneParsing.pickle
    pickle_filepath = os.path.join(data_dir, pickle_filename)
    if not os.path.exists(pickle_filepath):
        utils.maybe_download_and_extract(data_dir, DATA_URL, is_zipfile=True)       # download if the file does not exist
        SceneParsing_folder = os.path.splitext(DATA_URL.split("/")[-1])[0]          # ADEChallengeData2016
        # result = {training:   [{image: full image path, annotation: full label path, filename: image name}, ...],
        #           validation: [{image: full image path, annotation: full label path, filename: image name}, ...]}
        result = create_image_lists(os.path.join(data_dir, SceneParsing_folder))    # Data_zoo/MIT_SceneParsing/ADEChallengeData2016
        print("Pickling ...")       # build the pickle file
        with open(pickle_filepath, 'wb') as f:
            pickle.dump(result, f, pickle.HIGHEST_PROTOCOL)
    else:
        print ("Found pickle file!")

    with open(pickle_filepath, 'rb') as f:      # open the pickle file
        result = pickle.load(f)                 # load it
        training_records = result['training']
        validation_records = result['validation']
        del result
    # training: [{image: full image path, annotation: full label path, filename: image name}, {...}, ...]
    return training_records, validation_records


def create_image_lists(image_dir):
    """

    :param image_dir:   Data_zoo / MIT_SceneParsing / ADEChallengeData2016
    :return:
    """
    if not gfile.Exists(image_dir):
        print("Image directory '" + image_dir + "' not found.")
        return None
    directories = ['training', 'validation']
    image_list = {}     # image dict: training: [], validation: []

    for directory in directories:       # build the training and validation lists separately
        file_list = []
        image_list[directory] = []
        # Data_zoo/MIT_SceneParsing/ADEChallengeData2016/images/training/*.jpg
        file_glob = os.path.join(image_dir, "images", directory, '*.' + 'jpg')
        # Extend the file list with the full path of every image, e.g. Data_zoo/MIT_SceneParsing/ADEChallengeData2016/images/training/hi.jpg
        file_list.extend(glob.glob(file_glob))

        if not file_list:   # no files found
            print('No files found')
        else:
            for f in file_list:     # scan the file list; f is a full file path
                # image name without the extension, e.g. hi
                filename = os.path.splitext(f.split("/")[-1])[0]
                # Data_zoo/MIT_SceneParsing/ADEChallengeData2016/annotations/training/*.png
                annotation_file = os.path.join(image_dir, "annotations", directory, filename + '.png')
                if os.path.exists(annotation_file):     # if the annotation file exists
                    # image: full image path, annotation: full label path, filename: image name
                    record = {'image': f, 'annotation': annotation_file, 'filename': filename}
                    # image_list = {training:   [{image: ..., annotation: ..., filename: ...}, ...],
                    #               validation: [{image: ..., annotation: ..., filename: ...}, ...]}
                    image_list[directory].append(record)
                else:
                    print("Annotation file not found for %s - Skipping" % filename)
        # shuffle the image list
        random.shuffle(image_list[directory])
        no_of_images = len(image_list[directory])   # number of image files
        print('No. of %s files: %d' % (directory, no_of_images))

    return image_list
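
For reference, here is a minimal usage sketch of read_dataset; the record contents below are an example of the expected ADEChallengeData2016 layout, not actual files:

# Usage sketch (example paths only)
train_records, valid_records = read_dataset("Data_zoo/MIT_SceneParsing/")
print(train_records[0])
# e.g. {'image': 'Data_zoo/MIT_SceneParsing/ADEChallengeData2016/images/training/ADE_train_00000001.jpg',
#       'annotation': 'Data_zoo/MIT_SceneParsing/ADEChallengeData2016/annotations/training/ADE_train_00000001.png',
#       'filename': 'ADE_train_00000001'}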

TensorflowUtils.py

# coding=utf-8
__author__ = 'Charlie'
# Utils used with the TensorFlow implementation
import tensorflow as tf
import numpy as np
import scipy.misc as misc
import os, sys
from six.moves import urllib
import tarfile
import zipfile
import scipy.io


# Fetch the pretrained VGG model
def get_model_data(dir_path, model_url):
    # model_dir: Model_zoo/
    # MODEL_URL: download URL for VGG-19
    maybe_download_and_extract(dir_path, model_url)     # download the file if the directory or file does not exist
    filename = model_url.split("/")[-1]                 # split the url on '/' and take the last piece as the filename
    filepath = os.path.join(dir_path, filename)         # full file path: dir_path/filename
    if not os.path.exists(filepath):                    # check that the file exists
        raise IOError("VGG Model not found!")
    data = scipy.io.loadmat(filepath)                   # read the VGG .mat file with scipy.io
    return data


def maybe_download_and_extract(dir_path, url_name, is_tarfile=False, is_zipfile=False):
    # dir_path: Model_zoo/
    # url_name: download URL for VGG-19
    if not os.path.exists(dir_path):        # create the directory if it does not exist
        os.makedirs(dir_path)
    filename = url_name.split('/')[-1]      # split the url on '/' and take the last piece as the filename
    filepath = os.path.join(dir_path, filename)     # file path = dir_path/filename
    if not os.path.exists(filepath):        # if the file does not exist yet, download it
        def _progress(count, block_size, total_size):       # inner progress-report callback
            sys.stdout.write(
                '\r>> Downloading %s %.1f%%' % (filename, float(count * block_size) / float(total_size) * 100.0))
            sys.stdout.flush()

        filepath, _ = urllib.request.urlretrieve(url_name, filepath, reporthook=_progress)    # download the url into filepath
        print()
        statinfo = os.stat(filepath)
        print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
        if is_tarfile:          # if it is a tar file, extract it
            tarfile.open(filepath, 'r:gz').extractall(dir_path)
        elif is_zipfile:        # if it is a zip file, extract it
            with zipfile.ZipFile(filepath) as zf:
                zip_dir = zf.namelist()[0]
                zf.extractall(dir_path)
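
The excerpt above covers only the download helpers, but FCN.py also calls a number of small wrappers from TensorflowUtils (get_variable, weight_variable, bias_variable, conv2d_basic, conv2d_transpose_strided, the pooling ops, process_image, and the summary helpers). They are not reproduced in this post; the sketch below shows plausible implementations reconstructed from how they are used above — treat it as an illustration and consult the repository for the authoritative versions.

# Sketch of the remaining TensorflowUtils helpers, reconstructed from how
# FCN.py calls them (assumes the imports at the top of TensorflowUtils.py).
def get_variable(weights, name):
    # wrap a pretrained numpy array as a TensorFlow variable
    init = tf.constant_initializer(weights, dtype=tf.float32)
    return tf.get_variable(name=name, initializer=init, shape=weights.shape)

def weight_variable(shape, stddev=0.02, name=None):
    # truncated-normal initialized weights
    initial = tf.truncated_normal(shape, stddev=stddev)
    return tf.Variable(initial) if name is None else tf.get_variable(name, initializer=initial)

def bias_variable(shape, name=None):
    # zero-initialized biases
    initial = tf.constant(0.0, shape=shape)
    return tf.Variable(initial) if name is None else tf.get_variable(name, initializer=initial)

def conv2d_basic(x, W, bias):
    # stride-1 SAME convolution plus bias
    conv = tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding="SAME")
    return tf.nn.bias_add(conv, bias)

def conv2d_transpose_strided(x, W, b, output_shape=None, stride=2):
    # transposed convolution, used for the 2x and 8x upsampling in inference()
    conv = tf.nn.conv2d_transpose(x, W, output_shape, strides=[1, stride, stride, 1], padding="SAME")
    return tf.nn.bias_add(conv, b)

def avg_pool_2x2(x):
    return tf.nn.avg_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")

def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")

def process_image(image, mean_pixel):
    # subtract the dataset mean
    return image - mean_pixel

def add_activation_summary(var):
    # histogram + sparsity summaries for an activation tensor
    tf.summary.histogram(var.op.name + "/activation", var)
    tf.summary.scalar(var.op.name + "/sparsity", tf.nn.zero_fraction(var))

def add_gradient_summary(grad, var):
    if grad is not None:
        tf.summary.histogram(var.op.name + "/gradient", grad)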

BatchDatsetReader.py

# coding=utf-8
"""
Code ideas from https://github.com/Newmu/dcgan and tensorflow mnist dataset reader
"""
import numpy as np
import scipy.misc as misc


class BatchDatset:
    files = []
    images = []
    annotations = []
    image_options = {}
    batch_offset = 0
    epochs_completed = 0

    def __init__(self, records_list, image_options={}):
        """
        Initialize a generic file reader with batching for a list of files
        :param records_list: list of file records to read -
        sample record: {'image': f, 'annotation': annotation_file, 'filename': filename}
        :param image_options: A dictionary of options for modifying the output image
        Available options:
        resize = True/ False
        resize_size = #size of output image - does bilinear resize
        color=True/False
        """
        print("Initializing Batch Dataset Reader...")
        print(image_options)
        self.files = records_list       # list of file records
        self.image_options = image_options  # image-processing options, e.g. resize to 224
        self._read_images()

    def _read_images(self):
        self.__channels = True
        # Walk every 'image' entry (full path) in the files list:
        # read each image from disk and expand it to RGB if needed
        self.images = np.array([self._transform(filename['image']) for filename in self.files])
        self.__channels = False

        # Walk every 'annotation' entry (full path) in the files list:
        # read each label image and append a singleton channel dimension
        # (axis=2 turns an (h, w) label map into (h, w, 1))
        self.annotations = np.array(
            [np.expand_dims(self._transform(filename['annotation']), axis=2) for filename in self.files])
        print(self.images.shape)
        print(self.annotations.shape)

    def _transform(self, filename):
        # read the image file
        image = misc.imread(filename)
        if self.__channels and len(image.shape) < 3:  # make sure images are of shape (h, w, 3)
            # replicate a grayscale image into three identical channels
            # (stacking along axis=2 gives (h, w, 3); the original
            # np.array([image for i in range(3)]) produced (3, h, w))
            image = np.stack([image] * 3, axis=2)

        if self.image_options.get("resize", False):
            resize_size = int(self.image_options["resize_size"])
            # resize with nearest-neighbor interpolation
            resize_image = misc.imresize(image,
                                         [resize_size, resize_size], interp='nearest')
        else:
            resize_image = image

        return np.array(resize_image)       # return the resized image

    def get_records(self):
        """
        Return the image and annotation arrays
        :return:
        """
        return self.images, self.annotations

    def reset_batch_offset(self, offset=0):
        """
        Reset the batch offset
        :param offset:
        :return:
        """
        self.batch_offset = offset

    def next_batch(self, batch_size):
        # start of the current batch
        start = self.batch_offset
        # advance the offset past this batch
        self.batch_offset += batch_size
        # images holds all of the image data; images.shape = (N, h, w, 3)
        if self.batch_offset > self.images.shape[0]:      # if the new offset passes the image count, an epoch is complete
            # Finished epoch
            self.epochs_completed += 1      # one more epoch completed
            print("****************** Epochs completed: " + str(self.epochs_completed) + "******************")
            # Shuffle the data
            perm = np.arange(self.images.shape[0])      # indices 0 .. N-1
            np.random.shuffle(perm)         # shuffle the indices
            self.images = self.images[perm]     # reorder the images
            self.annotations = self.annotations[perm]
            # Start next epoch
            start = 0           # the next epoch starts at 0
            self.batch_offset = batch_size  # offset after the first batch of the new epoch

        end = self.batch_offset             # the batch spans [start, end)
        return self.images[start:end], self.annotations[start:end]      # return the batch

    def get_random_batch(self, batch_size):
        # sample batch_size random indices from the full image set, like a shuffle
        indexes = np.random.randint(0, self.images.shape[0], size=[batch_size]).tolist()
        return self.images[indexes], self.annotations[indexes]
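
Putting the reader together with the records from read_MITSceneParsingData, a minimal usage sketch (train_records as returned by read_dataset; shapes assume resize_size 224 and a batch size of 2):

# Usage sketch: fetch one batch of two 224x224 training images
image_options = {'resize': True, 'resize_size': 224}
reader = BatchDatset(train_records, image_options)
images, annotations = reader.next_batch(2)
# images.shape      -> (2, 224, 224, 3)
# annotations.shape -> (2, 224, 224, 1)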

