A Fresh and Uncluttered Rewrite of the TensorFlow CIFAR10 Tutorial: Simpler, Faster Data Loading and Real-Time loss/accuracy Output

If you are frustrated by how long-winded the official TensorFlow tutorials are and how convoluted their data loading is, this post is for you~

Motivation

I took plenty of detours while learning TensorFlow. The MNIST tutorial at the beginning was still fairly easy to digest. Then I dug into im2txt for about two or three weeks and found its structure far more complex than MNIST's and genuinely hard to follow. So I switched to the CIFAR10 tutorial, but the data loading alone took me a week, and I am still fuzzy about how its queues work.
After a month of wrestling with it, I found that although TensorFlow makes network structures easy to define and lets you build custom layers flexibly, it also has some pitfalls:

  1. The tutorials contain a large amount of functions and methods that exist purely to organize the code. Note that these do not describe the network; they are only there to give the code more structure. That is no doubt a mark of the Google engineers' skill, but it is disorienting for beginners: besides conv, relu, and pool, you will also see piles of name_scope, variable_scope, tf.add_to_collection, and so on. This scaffolding can amount to more code than the model description itself and tends to steal the show.
  2. Data loading is fairly complex; to really understand it you usually have to understand TensorFlow's "queues". For example, the CIFAR10 data-loading function returns an image node, not image data. Without understanding TensorFlow queues you will struggle to understand what that image node means, and therefore cannot manipulate it freely. (Interested readers can look at the official cifar10_input.py.)

----------- Therefore -----------
I decided to reinvent the wheel and rewrite the official CIFAR10 tutorial in a fresh, uncluttered way.

Advantages of this code:

  1. Data loading uses the CIFAR10 loader from LiFeiFei's cs231n course, which reads the whole CIFAR10 dataset into a numpy.array, so if you are comfortable with numpy you can preprocess the data however you like in the data layer. The code also shows how to do that preprocessing with numpy, including how to shuffle the data, how to turn the labels into one-hot form, and how to re-shuffle after every epoch (a minimal sketch of these numpy tricks follows this list).
  2. Data preprocessing borrows TensorFlow's crop, flip, and related functions, and works around the fact that some TensorFlow preprocessing functions only operate on a single image rather than a whole batch.
  3. During training the official tutorial only outputs the train loss; getting the test loss out of it is hard (in order to watch for overfitting I tried all sorts of approaches on the official tutorial and failed every time, which is the other motivation for rewriting the CIFAR10 data layer). My code outputs train loss, test loss, train accuracy, and test accuracy in real time and displays them in TensorBoard.
  4. The code style is minimalist: from top to bottom it is preprocessing, graph construction, then training and evaluation. The structure is clear, there is essentially no jumping back and forth, and TensorFlow functions that are impressive but not strictly necessary are avoided. That is why I call it "fresh and uncluttered".
  5. The model part of this code matches the official tutorial, and the resulting accuracy and loss curves also match the official ones, which shows the code is not made up.
  6. There are more comments than code, all in English for the convenience of international readers (actually because my IDE cannot handle Chinese input). The comments describe the tricks above in detail, link to the relevant Stack Overflow threads, and record the pitfalls I hit while tuning. Hence more comments than code...
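
As a quick illustration of point 1, here is a minimal numpy sketch of the two tricks used in the data layer: turning integer labels into one-hot vectors, and re-shuffling images and labels together at the start of each epoch. The shapes and variable names here are illustrative, not copied from the code below.

import numpy as np

# toy stand-ins for the real CIFAR10 arrays (shapes are illustrative)
X_train = np.random.rand(8, 32, 32, 3)
y_train = np.random.randint(0, 10, size=8)

# integer labels -> one-hot: row i gets a 1 in column y_train[i]
y_one_hot = np.zeros((y_train.shape[0], 10))
y_one_hot[np.arange(y_train.shape[0]), y_train] = 1

# per-epoch shuffle: permute images and labels with the same index array,
# so every image stays paired with its own label
perm = np.random.permutation(X_train.shape[0])
X_train, y_one_hot = X_train[perm], y_one_hot[perm]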

Disadvantages of this code:

  1. Data loading uses the cs231n CIFAR10 loader, which reads the whole dataset into a numpy.array. On a larger dataset this will very likely run out of memory, so this code only demonstrates a way to load small datasets; large datasets should be handled with TensorFlow queues (a rough sketch of that alternative follows this list).
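
For reference, here is a rough sketch of what the queue-based alternative looks like in TensorFlow 1.x, in the spirit of the official cifar10_input.py. Treat it as an outline rather than a drop-in replacement: the file names and the "1 label byte + 32*32*3 image bytes" record layout are assumptions based on the CIFAR10 binary format, and you still need to start queue runners inside a session before pulling batches.

import tensorflow as tf

# assumed CIFAR10 binary files: each record is 1 label byte followed by 32*32*3 image bytes
filenames = ['data_batch_%d.bin' % i for i in range(1, 6)]
filename_queue = tf.train.string_input_producer(filenames)

reader = tf.FixedLengthRecordReader(record_bytes=1 + 32 * 32 * 3)
_, value = reader.read(filename_queue)
record = tf.decode_raw(value, tf.uint8)

label = tf.cast(tf.strided_slice(record, [0], [1]), tf.int32)
label.set_shape([1])
image = tf.reshape(tf.strided_slice(record, [1], [1 + 32 * 32 * 3]), [3, 32, 32])
image = tf.cast(tf.transpose(image, [1, 2, 0]), tf.float32)  # to [height, width, channel]

# shuffle_batch keeps a bounded buffer of examples and emits batch-sized tensors,
# so the whole dataset never has to sit in memory at once
images, labels = tf.train.shuffle_batch([image, label], batch_size=128,
                                        capacity=20000, min_after_dequeue=10000)

# in a session: tf.train.start_queue_runners(sess), then sess.run([images, labels])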

Results:

At the start of training:

[Figure 1: output shortly after training starts]

At step = 500k:

[Figure 2: output at step 500k]

Code (including cifar10_easy_demo.py and data_utils.py)

cifar10_easy_demo.py

import re
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from data_utils import load_CIFAR10

#==============================================================================
# configurations
#==============================================================================
batch=128
INITIAL_LEARNING_RATE=0.1
LEARNING_RATE_DECAY_FACTOR=0.1
NUM_CLASSES=10
decay_steps=130000
MAX_STEPS=500000
log_dir='/home/kcheng/tensorflow_easy_demo/cifar10_log'
checkpoint_dir='/home/kcheng/tensorflow_easy_demo/cifar10_checkpoint'

def load_onehot_shuffle():
    #==============================================================================
    # this function loads the data, shuffles it, splits it, and turns the labels into one-hot form
    # Notice that the following operations are implemented with numpy. The official TensorFlow CIFAR10 code implements them with a queue instead; see cifar10_input.py for the details
    # That official code really puzzled me, because I could not understand how the queue returns a batch, so I decided to preprocess the data with numpy instead of a queue
    # This is one of the biggest shortcomings of my code, because it loads the whole training set into memory.
    # Fortunately, the CIFAR10 dataset is small enough for memory to hold it.
    # If you try this code on a bigger dataset (e.g. ImageNet), you may run out of memory.

    # All in all, this function shows an easy way to load small datasets such as CIFAR10.
    # If you have any suggestions for loading or preprocessing data in TensorFlow, please let me know
    #==============================================================================


    #==============================================================================
    # read data
    #==============================================================================
    cifar10_dir = '/home/kcheng/tmaster/models-master/tutorials/image/cifar10/winter1516_assignment1/assignment1/cs231n/datasets/cifar-10-batches-py'
    X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)

    # As a sanity check, we print out the size of the training and test data.
    print 'Training data shape: ', X_train.shape
    print 'Training labels shape: ', y_train.shape
    print 'Test data shape: ', X_test.shape
    print 'Test labels shape: ', y_test.shape

    #==============================================================================
    # one_hot
    #==============================================================================
    y_train_one_hot=np.zeros((50000,10))
    y_train_one_hot[np.arange(50000), y_train] = 1
    print y_train_one_hot.shape

    y_test_one_hot=np.zeros((10000,10))
    y_test_one_hot[np.arange(10000), y_test] = 1
    print y_test_one_hot.shape

    y_train=y_train_one_hot
    y_test=y_test_one_hot


    #==============================================================================
    # shuffle
    #==============================================================================
    p_train = range(50000)
    p_test=range(10000)
    np.random.shuffle(p_train)
    np.random.shuffle(p_test)

    # fancy indexing returns a copy,
    # so we delete the original arrays to save memory
    X_train_shuffle=X_train[p_train,:,:,:]
    y_train_shuffle=y_train[p_train,:]
    X_test_shuffle=X_test[p_test,:,:,:]
    y_test_shuffle=y_test[p_test,:]

    del X_train
    del y_train
    del X_test
    del y_test

    # these are just new references (no extra copy)
    X_train=X_train_shuffle
    y_train=y_train_shuffle
    X_test=X_test_shuffle
    y_test=y_test_shuffle


    print 'Train data shape: ', X_train.shape
    print 'Train labels shape: ', y_train.shape
    print 'Test data shape: ', X_test.shape
    print 'Test labels shape: ', y_test.shape

    print '*****************  shuffle    *******************' # tell you a shuffle operation has been done


    return (X_train,y_train,X_test,y_test)

def re_shuffle(X_train,y_train,X_test,y_test):

    #==============================================================================
    # this function re-shuffles the images at the beginning of each epoch
    # without it the order of images would be the same in every epoch, which aggravates overfitting and lowers test accuracy by roughly 5%
    #==============================================================================
    p_train = range(50000)
    p_test=range(10000)
    np.random.shuffle(p_train)
    np.random.shuffle(p_test)

    # fancy indexing returns a copy,
    # so we delete the original arrays to save memory
    X_train_shuffle=X_train[p_train,:,:,:]
    y_train_shuffle=y_train[p_train,:]
    X_test_shuffle=X_test[p_test,:,:,:]
    y_test_shuffle=y_test[p_test,:]

    del X_train
    del y_train
    del X_test
    del y_test

    # these are just new references (no extra copy)
    X_train=X_train_shuffle
    y_train=y_train_shuffle
    X_test=X_test_shuffle
    y_test=y_test_shuffle

    print '*****************  shuffle    *******************'

    return (X_train,y_train,X_test,y_test)


#==============================================================================
#  Up to now we have not finished preprocessing the data: we still want crop, mirror and other data augmentation operations
#  If we implemented crop and mirror with numpy, like the operations above, the crops and mirrors would never change during training; call that "static data augmentation"
#  In contrast, if we implement crop and mirror inside the TensorFlow graph, so that different batches get different random crops and flips, we get "dynamic data augmentation"
#  In most cases we want "dynamic data augmentation", because it prevents overfitting more effectively
#  To implement "dynamic data augmentation", insert the TensorFlow augmentation operations into your graph (otherwise you would have to write more functions like re_shuffle)
#  So, let's build a graph
#==============================================================================

image=tf.placeholder(tf.float32,shape=[None,32,32,3]) #image placeholder
label=tf.placeholder(tf.float32,shape=[None,10]) #label placeholder


#==============================================================================
# use tensorflow functions to do Data Amplification
#==============================================================================


# Randomly crop the image.
# Note that the size argument of tf.random_crop includes the batch dimension; you cannot leave it unspecified (e.g. -1), or the op will raise an error.
# In my code I set it to batch, which means the graph can only receive exactly 128 images at a time; any other batch size will cause an error.
# During training we feed the graph batches of 128. However, when evaluating on the validation or test set we usually want to feed more images at once (e.g. 1000 in the official code).
# This is one of the shortcomings of my code, but you can work around it by accumulating in a loop, i.e. feeding 128 images at a time and averaging (a sketch of this is given after the full listing).
height = 24
width =  24
distorted_image = tf.random_crop(image, [batch,height, width, 3])


# Randomly flip the image horizontally
# ATTENTION:
# tf.image.random_flip_left_right is made for single images (i.e. 3-D tensors with shape [height, width, color-channel]). 
# How can we make them work on a batch of images (i.e. 4-D tensors with shape [batch, height, width, color-channel])?
# NEXT LINE is the solution; for more information, SEE http://stackoverflow.com/questions/38920240/tensorflow-image-operations-for-batches
distorted_image = tf.map_fn(lambda img: tf.image.random_flip_left_right(img),distorted_image)

# Randomly adjust brightness, then contrast
distorted_image = tf.image.random_brightness(distorted_image,
                                               max_delta=63)
distorted_image = tf.image.random_contrast(distorted_image,
                                             lower=0.2, upper=1.8)


# Subtract off the mean and divide by the variance of the pixels.
# ATTENTION:
# tf.image.per_image_standardization is also made for single images, so...
float_image = tf.map_fn(lambda img: tf.image.per_image_standardization(img), distorted_image)



#==============================================================================
# # Up to now, we have finished preprocessing the data, and the implementation differs from the official TensorFlow code.
# # In fact our implementation is much easier to understand. This is one of the main contributions of this post.
#==============================================================================



#==============================================================================
# Next is the CNN model, which we implement similarly to the official TensorFlow code
# Notice that the official code creates kernels and biases through a helper function, where tf.add_to_collection('losses', weight_decay) collects the L2 loss automatically
# Here we simplify the code and create the kernels and biases directly
# This causes another shortcoming of our code: we have to collect the L2 loss by hand
#==============================================================================
with tf.variable_scope('conv1') as scope:
    kernel = tf.Variable(tf.truncated_normal(shape=[5, 5, 3, 64],stddev=5e-2))
    conv = tf.nn.conv2d(float_image, kernel, [1, 1, 1, 1], padding='SAME')

    biases = tf.Variable(tf.zeros( [64]))
    pre_activation = tf.nn.bias_add(conv, biases)
    conv1 = tf.nn.relu(pre_activation, name=scope.name)
  # pool1
pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
                         padding='SAME', name='pool1')
  # norm1
norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
                    name='norm1')

  # conv2
with tf.variable_scope('conv2') as scope:

    kernel = tf.Variable(tf.truncated_normal(shape=[5, 5, 64, 64],stddev=5e-2))
    conv = tf.nn.conv2d(norm1, kernel, [1, 1, 1, 1], padding='SAME')
    biases = tf.get_variable(name='biases',shape=[64], initializer=tf.constant_initializer(0.1))
    pre_activation = tf.nn.bias_add(conv, biases)
    conv2 = tf.nn.relu(pre_activation, name=scope.name)

  # norm2
norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
                    name='norm2')
  # pool2
pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1],
                         strides=[1, 2, 2, 1], padding='SAME', name='pool2')

  # local3
with tf.variable_scope('conv3') as scope:
    # Move everything into depth so we can perform a single matrix multiply.
    reshape = tf.reshape(pool2, [batch, -1])
    dim = reshape.get_shape()[1].value
weights3 = tf.Variable(tf.truncated_normal(shape=[dim, 384],stddev=0.04))                          # DON'T FORGET REGULARIZATION !!!!!!!!!!!!!
biases3 = tf.get_variable(name='biases3',shape=[384], initializer=tf.constant_initializer(0.1))
local3 = tf.nn.relu(tf.matmul(reshape, weights3) + biases3, name=scope.name)

weights4 = tf.Variable(tf.truncated_normal(shape=[384,192],stddev=0.04))                           # DON'T FORGET REGULARIZATION !!!!!!!!!!!!!
biases4 = tf.get_variable(name='biases4',shape=[192], initializer=tf.constant_initializer(0.1))
local4 = tf.nn.relu(tf.matmul(local3, weights4) + biases4, name=scope.name)

with tf.variable_scope('softmax_linear') as scope:

    weights = tf.Variable(tf.truncated_normal(shape=[192,10],stddev=1/192.0))
    biases = tf.Variable(tf.zeros( [10]))
    softmax_linear = tf.add(tf.matmul(local4, weights), biases, name=scope.name)


#==============================================================================
# # In the official code, tf.nn.sparse_softmax_cross_entropy_with_logits is used to compute the loss
# # Here we use tf.nn.softmax_cross_entropy_with_logits instead, because our labels are one-hot
#==============================================================================
loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=label,logits=softmax_linear)) +tf.multiply(tf.nn.l2_loss(weights3),0.004) +tf.multiply(tf.nn.l2_loss(weights4),0.004)


#==============================================================================
# Decay the learning rate exponentially based on the number of steps.
# NOTE:
# most training processes need a decaying learning rate, so I strongly recommend using global_step together with tf.train.exponential_decay
# DO NOT implement learning rate decay by hand in the training loop, because the learning rate here is not a plain number but an operation in the graph.
# if you still feel confused, the following code shows why
#==============================================================================
global_step = tf.Variable(0, name='global_step', trainable=False)
learning_rate = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
                                  global_step,
                                  decay_steps,
                                  LEARNING_RATE_DECAY_FACTOR,
                                  staircase=True)

# Create the gradient descent optimizer with the given learning rate.
train_op=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)

# Calculate accuracy
correct_prediction = tf.equal(tf.argmax(label,1), tf.argmax(softmax_linear,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# define a saver to save checkpoints
saver = tf.train.Saver()

# Congratulations! We have finished building a graph
# let's define a session
sess=tf.Session()
sess.run(tf.global_variables_initializer())


#==============================================================================
# Record the variables you want to observe
# here, we use a trick to observe both train_loss and test_loss, as the following code shows
# for more information about this trick, SEE http://stackoverflow.com/questions/34471563/logging-training-and-validation-loss-in-tensorboard
#==============================================================================
train_loss_summary=tf.summary.scalar('train loss',loss)
test_loss_summary=tf.summary.scalar('test  loss',loss)
train_accuracy_summary=tf.summary.scalar('train accuracy',accuracy)
test_accuracy_summary=tf.summary.scalar('test  accuracy',accuracy)
#merge_summary = tf.summary.merge_all()
summary_writer = tf.summary.FileWriter(log_dir,sess.graph)


# define some variables used in the upcoming loop
begin=0
train_num=50000

use_batch=None

train_losses=[]
test_losses=[]
train_acc=[]
test_acc=[]

(X_train,y_train,X_test,y_test) = load_onehot_shuffle()

for step in range(MAX_STEPS):
    # get a batch index
    use_batch=range(begin,begin+batch)
    begin=begin+batch
    # prevent out of range
    if begin+batch>=train_num:
        begin=0
#        (X_train,y_train,X_test,y_test) = load_onehot_shuffle()
        (X_train,y_train,X_test,y_test) = re_shuffle(X_train,y_train,X_test,y_test)


    # show loss and accuracy every 100 steps
    if (step+1)%100 == 0 or (step+1) == MAX_STEPS: # borrowed from the official TensorFlow MNIST code
        # train
        train_accuracy,train_loss=sess.run([accuracy,loss],feed_dict={image:X_train[use_batch,:,:,:],label:y_train[use_batch,:]})
        # test
        test_accuracy,test_loss=sess.run([accuracy,loss],feed_dict={image:X_test[:batch,:,:,:],label:y_test[:batch,:]})
        # record loss and accuracy
        train_losses.append(train_loss)
        test_losses.append(test_loss)
        train_acc.append(train_accuracy)
        test_acc.append(test_accuracy)
        print "step %d,learning rate %g, train accuracy %g, test accuracy %g, train loss %g, test loss %g"%(step,sess.run(learning_rate), train_accuracy,test_accuracy,train_loss,test_loss)

        checkpoint_file = os.path.join(checkpoint_dir, 'model.ckpt') #
        saver.save(sess, checkpoint_file, global_step=step)

#==============================================================================
#     LET'S WRITE SUMMARY
#==============================================================================
        loss_summary_train, _ = sess.run([train_loss_summary,loss],feed_dict={image:X_train[use_batch,:,:,:],label:y_train[use_batch,:]})
        summary_writer.add_summary(loss_summary_train,step)
        loss_summary_test, _ = sess.run([test_loss_summary,loss],feed_dict={image:X_test[:batch,:,:,:],label:y_test[:batch,:]})
        summary_writer.add_summary(loss_summary_test,step)

        accuracy_summary_train, _ = sess.run([train_accuracy_summary,accuracy],feed_dict={image:X_train[use_batch,:,:,:],label:y_train[use_batch,:]})
        summary_writer.add_summary(accuracy_summary_train,step)
        accuracy_summary_test, _ = sess.run([test_accuracy_summary,accuracy],feed_dict={image:X_test[:batch,:,:,:],label:y_test[:batch,:]})
        summary_writer.add_summary(accuracy_summary_test,step)

#==============================================================================
#     THIS IS THE FINAL CODE
#     LET'S TRAIN HAPPILY ^o^
#==============================================================================
    sess.run([train_op],feed_dict={image:X_train[use_batch,:,:,:],label:y_train[use_batch,:]})
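
A note on the batch-size limitation mentioned in the crop comment above: because tf.random_crop here is built for a fixed batch of 128, the test numbers printed during training come from a single 128-image batch. Below is a small sketch of the accumulate-in-a-loop workaround suggested there. It is not part of the script; it simply assumes the sess, accuracy, image and label names defined above and averages the accuracy over non-overlapping 128-image chunks of the test set (the last few images that cannot fill a full batch are skipped).

def full_test_accuracy(sess, X_test, y_test, batch=128):
    # average the accuracy op over the whole test set, 128 images at a time
    accs = []
    for start in range(0, X_test.shape[0] - batch + 1, batch):
        acc = sess.run(accuracy,
                       feed_dict={image: X_test[start:start + batch],
                                  label: y_test[start:start + batch]})
        accs.append(acc)
    return float(np.mean(accs))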

data_utils.py

import cPickle as pickle
import numpy as np
import os
from scipy.misc import imread

def load_CIFAR_batch(filename):
  """ load single batch of cifar """
  with open(filename, 'rb') as f:
    datadict = pickle.load(f)
    X = datadict['data']
    Y = datadict['labels']
    X = X.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype("float")
    Y = np.array(Y)
    return X, Y

def load_CIFAR10(ROOT):
  """ load all of cifar """
  xs = []
  ys = []
  for b in range(1,6):
    f = os.path.join(ROOT, 'data_batch_%d' % (b, ))
    X, Y = load_CIFAR_batch(f)
    xs.append(X)
    ys.append(Y)    
  Xtr = np.concatenate(xs)
  Ytr = np.concatenate(ys)
  del X, Y
  Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
  return Xtr, Ytr, Xte, Yte

def load_tiny_imagenet(path, dtype=np.float32):
  """
  Load TinyImageNet. Each of TinyImageNet-100-A, TinyImageNet-100-B, and
  TinyImageNet-200 have the same directory structure, so this can be used
  to load any of them.

  Inputs:
  - path: String giving path to the directory to load.
  - dtype: numpy datatype used to load the data.

  Returns: A tuple of
  - class_names: A list where class_names[i] is a list of strings giving the
    WordNet names for class i in the loaded dataset.
  - X_train: (N_tr, 3, 64, 64) array of training images
  - y_train: (N_tr,) array of training labels
  - X_val: (N_val, 3, 64, 64) array of validation images
  - y_val: (N_val,) array of validation labels
  - X_test: (N_test, 3, 64, 64) array of testing images.
  - y_test: (N_test,) array of test labels; if test labels are not available
    (such as in student code) then y_test will be None.
  """
  # First load wnids
  with open(os.path.join(path, 'wnids.txt'), 'r') as f:
    wnids = [x.strip() for x in f]

  # Map wnids to integer labels
  wnid_to_label = {wnid: i for i, wnid in enumerate(wnids)}

  # Use words.txt to get names for each class
  with open(os.path.join(path, 'words.txt'), 'r') as f:
    wnid_to_words = dict(line.split('\t') for line in f)
    for wnid, words in wnid_to_words.iteritems():
      wnid_to_words[wnid] = [w.strip() for w in words.split(',')]
  class_names = [wnid_to_words[wnid] for wnid in wnids]

  # Next load training data.
  X_train = []
  y_train = []
  for i, wnid in enumerate(wnids):
    if (i + 1) % 20 == 0:
      print 'loading training data for synset %d / %d' % (i + 1, len(wnids))
    # To figure out the filenames we need to open the boxes file
    boxes_file = os.path.join(path, 'train', wnid, '%s_boxes.txt' % wnid)
    with open(boxes_file, 'r') as f:
      filenames = [x.split('\t')[0] for x in f]
    num_images = len(filenames)

    X_train_block = np.zeros((num_images, 3, 64, 64), dtype=dtype)
    y_train_block = wnid_to_label[wnid] * np.ones(num_images, dtype=np.int64)
    for j, img_file in enumerate(filenames):
      img_file = os.path.join(path, 'train', wnid, 'images', img_file)
      img = imread(img_file)
      if img.ndim == 2:
        ## grayscale file
        img.shape = (64, 64, 1)
      X_train_block[j] = img.transpose(2, 0, 1)
    X_train.append(X_train_block)
    y_train.append(y_train_block)

  # We need to concatenate all training data
  X_train = np.concatenate(X_train, axis=0)
  y_train = np.concatenate(y_train, axis=0)

  # Next load validation data
  with open(os.path.join(path, 'val', 'val_annotations.txt'), 'r') as f:
    img_files = []
    val_wnids = []
    for line in f:
      img_file, wnid = line.split('\t')[:2]
      img_files.append(img_file)
      val_wnids.append(wnid)
    num_val = len(img_files)
    y_val = np.array([wnid_to_label[wnid] for wnid in val_wnids])
    X_val = np.zeros((num_val, 3, 64, 64), dtype=dtype)
    for i, img_file in enumerate(img_files):
      img_file = os.path.join(path, 'val', 'images', img_file)
      img = imread(img_file)
      if img.ndim == 2:
        img.shape = (64, 64, 1)
      X_val[i] = img.transpose(2, 0, 1)

  # Next load test images
  # Students won't have test labels, so we need to iterate over files in the
  # images directory.
  img_files = os.listdir(os.path.join(path, 'test', 'images'))
  X_test = np.zeros((len(img_files), 3, 64, 64), dtype=dtype)
  for i, img_file in enumerate(img_files):
    img_file = os.path.join(path, 'test', 'images', img_file)
    img = imread(img_file)
    if img.ndim == 2:
      img.shape = (64, 64, 1)
    X_test[i] = img.transpose(2, 0, 1)

  y_test = None
  y_test_file = os.path.join(path, 'test', 'test_annotations.txt')
  if os.path.isfile(y_test_file):
    with open(y_test_file, 'r') as f:
      img_file_to_wnid = {}
      for line in f:
        line = line.split('\t')
        img_file_to_wnid[line[0]] = line[1]
    y_test = [wnid_to_label[img_file_to_wnid[img_file]] for img_file in img_files]
    y_test = np.array(y_test)

  return class_names, X_train, y_train, X_val, y_val, X_test, y_test


def load_models(models_dir):
  """
  Load saved models from disk. This will attempt to unpickle all files in a
  directory; any files that give errors on unpickling (such as README.txt) will
  be skipped.

  Inputs:
  - models_dir: String giving the path to a directory containing model files.
    Each model file is a pickled dictionary with a 'model' field.

  Returns:
  A dictionary mapping model file names to models.
  """
  models = {}
  for model_file in os.listdir(models_dir):
    with open(os.path.join(models_dir, model_file), 'rb') as f:
      try:
        models[model_file] = pickle.load(f)['model']
      except pickle.UnpicklingError:
        continue
  return models
