PaddlePaddle, TensorFlow, MXNet, Caffe2 , PyTorch五大深度学习框架2017-10最新评测

前言

本文将是2017下半年以来,最新也是最全的一个深度学习框架评测。这里的评测并不是简单的使用评测,我们将用这五个框架共同完成一个深度学习任务,从框架使用的易用性、训练的速度、数据预处理的繁琐程度,以及显存占用大小等几个方面来进行全方位的测评,除此之外,我们还将给出一个非常客观,非常全面的使用建议。最后提醒大家本篇文章不仅仅是一个评测,你甚至可以作为五大框架的入门教程

0. 五大框架概览

在评测之前,让我们先对这五大框架进行一个全方位的概览,以及他们目前所处的发展地位。首先在这五大框架中,很多人肯定会问,为什么没有Keras?为什么没有CNTK?在这里我说明一点,本篇文章偏向于工业化级别的应用评测,主要评测主流框架,当然不是说Keras和CNTK就不主流了,文章没有任何利益相关的东西,只不过是Keras本身就拥有多种框架作为后端,因此与它的后端框架对比也就没有任何意义,Keras毫无疑问是速度最慢的。而CNTK由于笔者对Windows无感因此也就没有在评测范围之内(CNTK也是一个优秀的框架,当然也跨平台,感兴趣者可以去踩踩坑)。

TensorFlow可以说是目前发展来说最活跃的,TensorFlow目前已经有72.3k个star,MXNet是11.5k,Caffe2是5.9K, 当然caffe2要推出的稍晚一些,MXNet的官方GitHub repo也是后来又转到Apache的孵化项目中。但是从GitHub受关注度来看,无疑TensorFlow和MXNet是更被看好的。

即使我不做这篇测评,很多人也知道这些框架目前为止有一些这样的评价:

  • TensorFlow API比较繁杂,使用上手困难,乱七八糟的东西很多,但是生态丰富,很多深度学习模型多有TF的实现,有Google大佬加持;
  • MXNet 占用内存小,速度快,非常小巧玲珑,有着天生的开源基因,完全靠社区推动的框架;
  • Caffe2 是面向工业级应用的框架,但是推出较晚,而且主打Python2(execuse me? 2017年了还主打Python2?), 我不由自主的黑一下,从安装部署角度来说用户体验不是非常友好;
  • PyTorch 是Facebook面向学术界推出的一个框架,使用非常简单,搭建神经网络就像Keras和matlab一样,但是我又不得不黑一下,每次还得判断一下是GPU还是CPU?(execuse me? 真的应了那句话,我踩过了tf的坑才知道tf的好);
  • PaddlePadddle 百度开源的一个框架,国内也有很多人用,我的感受是,非常符合中国人的使用习惯,但是在API的实用性上还有待进一步加强,我曾经写过一篇博客入门PaddlePaddle,不得不说,PaddlePaddle的中文文档写的非常清楚,上手比较简单PaddlePaddle三行代码从入门到精通;

以上评价是以前的评价,夹杂着一丝个人使用感受,最后说一下他们各自目前的好的动向:

  • TensorFlow models这个模型库更新非常快,以前的一些图片分类,目标检测,图片生成文字,生成对抗网络都有现成的深度学习应用的例子,包括现在更新的基于知识图谱的问答项目,神经网络编程机器人等项目,这些官方生态对于一个框架来说非常有用,这无疑是tf的一个长处
  • MXNet早在几个月前就推出了Gluon这个接口,说白了就是一个Keras,包装了一个更加方便使用的API,但是目前来说还只能实现一些简单的网络的构建,复杂的还是得用原生的API,这里有一个教程链接Gluon资料, 除此之外,MXNet也有一个实例仓库,其中有一些有意思的项目比如语音识别,但是感觉实现的非常不友好,代码几乎凌乱不堪;
  • Caffe2 Caffe2相对于前面两者来说可以说非常弱了,没有丝毫亮点,说好的一个C++高速工业级框架的呢?除了吹牛逼忽悠大众能搞些有用的官方使用文档或者教程出来吗?不好多说什么。
  • PyTorch就一笔带过了,偏向于学术快速实现,要工业级应用,比如做个模型跑到服务器上或者安卓手机上或者嵌入式上应该搞不来;
  • PaddlePaddle 现在做的还不错,我强调一句,Paddle是唯一一个不配置任何第三方库,克隆直接make就能成功的框架, 被caffe编译虐过的人应该对此深有感触。

说了这么多,相信大家对目前的框架有了一个大致的了解,那么接下来我们就用其中几个框架来完成分类图片这么一个任务吧,这里面将包含图片如何导入模型如何写网络, 整个训练的Pipeline等内容。

我们此次评测的任务是图片分类,大家尝试任何一个框架只需要新建一个文件夹,比如mxnet_classifier, 把数据扔到 data 里即可,我们侧重评测数据预处理的复杂程度,和网络编写的复杂程度。

图片下载地址images.tar , annotations.tar. 解压之后得到:

       
       
       
       
paddle_test
└── data
├── annotation .tar
└── images .tar

解压之后Images下面每一个文件夹是一个类别的狗, 其实分类任务我们只要这个就可以了。

1. MXNet

首先上场的,用MXNet吧。建议大家看一下上面我贴出的Gluon李沐大神写的PPT,包含了Gluon和其他框架的区别,以及MXNet在多GPU上训练的优势。

没有安装的安装一下:

       
       
       
       
sudo pip3 install mxnet
sudo pip3 install mxnet-cu80
sudo pip3 install mxnet-cu80mkl

分别是CPU乞丐版,GPU土豪版,GPU加CPU加速至尊豪华版。安装完了你应该clone一下mxnet的源代码,从tools里面找到im2rec.py这个工具,我们做图片,不管是检测还是分割还是分类,都按照mxnet的逻辑把图片转成二进制的rec格式吧。

我们现在有了Images文件夹,用im2rec.py处理参数这样写:

       
       
       
       
python3 im2rec . py standford_dogs Images/ - - list true - - recursive true - - train - ratio 0 . 8 - - test - ratio 0 . 2

这一步会生成两个文件:

  • standford_dogs_train.lst
  • standford_dogs_test.lst

standford_dogs 是前缀, —list true表示生成列表,recursive用户这种每一个文件夹代表一类的情况,最后在standford_dogs_train.lst 里面的一行是这样的:

       
       
       
       
5008 27.000000 n 02092339-Weimaraner/n020 92339_2885.jpg
5092 27.000000 n 02092339-Weimaraner/n020 92339_6548.jpg

第一个数字是图片的总数目的index,第二个应该是类别的index但是这个.0000有点不可思议。好了,有了这个lst文件我们继续用im2rec来生成rec二进制数据吧, 这一步非常简单了,直接load上面的prefix和Images这个图片根目录即可:

       
       
       
       
python3 im2rec .py standford_dogs Images/

mxnet会依次生成train和test的rec文件:

PaddlePaddle, TensorFlow, MXNet, Caffe2 , PyTorch五大深度学习框架2017-10最新评测_第1张图片

OK, mxnet做数据集也不是非常的麻烦,这个过程如果满分五分的话我给4分,pytorch如果不考虑性能的话应该是最直接的,直接从文件夹导入,但是rec格式更快。生成之后总共有了2.8G的文件。

好了,数据准备了,直接写一个网络开始训练罗?我要写一个vgg怎么办?我要看论文吗?我要从第一层开始看网络结构吗?我要换ResNet怎么办?要换Inception怎么办?没有关系!mxnet 官方example包含了大多数这些网络结构!!

       
       
       
       
├── alexnet .py
├── googlenet .py
├── inception-bn .py
├── inception-resnet-v2 .py
├── inception-v3 .py
├── inception-v4 .py
├── lenet .py
├── mlp .py
├── mobilenet .py
├── resnet-v1 .py
├── resnet .py
├── resnext .py
└── vgg.py

更重要的是,我们看看alexnet的代码:

       
       
       
       
import mxnet as mx
import numpy as np
def get_symbol(num_classes, dtype='float32', **kwargs):
input_data = mx.sym.Variable( name="data")
if dtype == 'float16':
input_data = mx.sym.Cast( data=input_data, dtype=np.float16)
# stage 1
conv1 = mx.sym.Convolution( name='conv1',
data=input_data, kernel=(11, 11), stride=(4, 4), num_filter=96)
relu1 = mx.sym.Activation( data=conv1, act_type="relu")
lrn1 = mx.sym.LRN( data=relu1, alpha=0.0001, beta=0.75, knorm=2, nsize=5)
pool1 = mx.sym.Pooling(
data=lrn1, pool_type="max", kernel=(3, 3), stride=(2,2))
# stage 2
conv2 = mx.sym.Convolution( name='conv2',
data=pool1, kernel=(5, 5), pad=(2, 2), num_filter=256)
relu2 = mx.sym.Activation( data=conv2, act_type="relu")
lrn2 = mx.sym.LRN( data=relu2, alpha=0.0001, beta=0.75, knorm=2, nsize=5)
pool2 = mx.sym.Pooling( data=lrn2, kernel=(3, 3), stride=(2, 2), pool_type="max")
# stage 3
conv3 = mx.sym.Convolution( name='conv3',
data=pool2, kernel=(3, 3), pad=(1, 1), num_filter=384)
relu3 = mx.sym.Activation( data=conv3, act_type="relu")
conv4 = mx.sym.Convolution( name='conv4',
data=relu3, kernel=(3, 3), pad=(1, 1), num_filter=384)
relu4 = mx.sym.Activation( data=conv4, act_type="relu")
conv5 = mx.sym.Convolution( name='conv5',
data=relu4, kernel=(3, 3), pad=(1, 1), num_filter=256)
relu5 = mx.sym.Activation( data=conv5, act_type="relu")
pool3 = mx.sym.Pooling( data=relu5, kernel=(3, 3), stride=(2, 2), pool_type="max")
# stage 4
flatten = mx.sym.Flatten( data=pool3)
fc1 = mx.sym.FullyConnected( name='fc1', data=flatten, num_hidden=4096)
relu6 = mx.sym.Activation( data=fc1, act_type="relu")
dropout1 = mx.sym.Dropout( data=relu6, p=0.5)
# stage 5
fc2 = mx.sym.FullyConnected( name='fc2', data=dropout1, num_hidden=4096)
relu7 = mx.sym.Activation( data=fc2, act_type="relu")
dropout2 = mx.sym.Dropout( data=relu7, p=0.5)
# stage 6
fc3 = mx.sym.FullyConnected( name='fc3', data=dropout2, num_hidden=num_classes)
if dtype == 'float16':
fc3 = mx.sym.Cast( data=fc3, dtype=np.float32)
softmax = mx.sym.SoftmaxOutput( data=fc3, name='softmax')
return softmax

非常非常非常简洁!!!!,只是一个函数,唯一不同的就是类别的数目不同,最后函数根据类别不同返回一个softmax的loss。

最后我们看看怎么把数据导入,然后训练的!!!

       
       
       
       
"""
train pipe line in mxnet
"""
import mxnet as mx
from symbols.vgg import get_vgg
def train():
num_classes = 120
batch_size = 64
# shape not have to be it exactly are
data_shape = ( 3, 64, 64)
num_epoch = 50
prefix = 'standford_dogs_model'
train_iter = mx.io.ImageRecordIter(
path_imgrec= "data/standford_dogs_train.rec",
data_shape=data_shape,
batch_size=batch_size,
)
val_iter = mx.io.ImageRecordIter(
path_imgrec= "data/standford_dogs_test.rec",
data_shape=data_shape,
batch_size=batch_size,
)
model = mx.model.FeedForward(
# set mx.gpu(0, 1) for multiple gpu
ctx=mx.cpu(),
symbol=get_vgg(num_classes=num_classes),
num_epoch=num_epoch,
learning_rate= 0.01,
)
model.fit(
X=train_iter,
eval_data=val_iter,
# every 10 iteration log info
batch_end_callback=mx.callback.Speedometer(batch_size, 10),
epoch_end_callback=mx.callback.do_checkpoint(prefix=prefix)
)
if __name__ == '__main__':
train()

尼玛,简直简单到想哭。大家注意这里get_vgg就是直接从官方的example/image-classification里面拿的,我们训练一个vgg看看。运行之后发现网络已经跑起来了:

PaddlePaddle, TensorFlow, MXNet, Caffe2 , PyTorch五大深度学习框架2017-10最新评测_第2张图片

温馨提示一下,MXNet貌似已经摒弃了上面的写法,上面的写法和PyTorch一样,是一种生成式的写法,Model和Module的区别就是,后者更加Tensor化,也就是图化,运行之前先把GPU占领一下再说。

OK, MXNet的坑已经踩完了。我来总结一下MXNet不为人知的几点:

  • 这是一个良心框架。可以看出它的开发者再用心的追求速度和易用性,否则也不会推出Gluon这个接口了,这个接口就是让普通开发者更加易用,同时追求速度;
  • MXNet是唯一一个比较中立的框架,你要知道,Google推出TensorFlow可是有小九九的,其内部至少有几套速度更快的纯C写的版本,否则TensorFlow怎么那么慢?不拉开差距怎么来的KPI?怎么让全球开发者为Google服务?(不是Google员工也是不是Google敌对员工,逃…)
  • MXNet的未来潜力很大,我最近在研究MXNet构建复杂的网络,比如Cycle-GAN,比如Seq2Seq的实现,但是不得不承认,这方面TensorFlow更加强大…

2. PaddlePaddle

为什么第二个评测用PaddlePaddle?第一,它最近表现很好,但是知道人很少,秉着为开发者引路的原则,增加以下曝光度,其实说实话,很多人不知道PaddlePaddle已经升级到了v2的Python API,而且内部还引入很多Go语言的代码,我没有仔细看这些代码是用来干啥的,但是很显然,PaddlePaddle在追求速度。

对Paddle的评测我这里列举以下Paddle的几个亮点的地方:

  • 相对来说更易用的API,所谓相对是因为,它还是有一些冗杂的地方;
  • 占用内存小,速度快,Paddle在百度内部应该也服务了相当多的项目,因此工业应用不成问题;
  • 中文支持,不想国外的框架,PaddlePaddle还是有着相当多的中文文档的;
  • PaddlePaddle在自然语言处理上有很多现成的历程,比如情感分类,甚至是语音识别都有Demo;
  • PaddlePaddle支持多机多卡训练,也算是集大成者。

关于PaddlePaddle使用的Pipeline异步到我之前写的一个文章传送门。

3. TensorFlow

关于tf,还真的是爱恨交加,从刚入手到现在,他的API的繁杂性以及训练的繁琐几乎让人望而却步,不过好在它有一个非常强大的生态。我们来看看TensorFlow做分类任务应该怎么做。

首先,毫无疑问,最好的方法是把图片放到tfrecord这个文件类型中去。但是如何生成tfrecord是个蛋疼的问题,在这里我申明一点,tfrecord和MXNet的rec文件不同:

tfrecod是将文件以键值对的形式存放起来了,每个记录就是一个example,而MXNet存储需要先建立一个lst,然后从lst转成二进制文件。好吧其实也差不多,不过你应该能理解我说的意思。

我们看一下一个用来将图片转为tfrecord的代码:

       
       
       
       
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datetime import datetime
import os
import random
import sys
import threading
import numpy as np
import tensorflow as tf
class TFRecordsGenerator(object):
"""
this class is using for tf_records generations in image classification use
For usages:
All images must contains in different folders, TFRecordsGenerator will traverse
all folders and find different classes.
"""
def __init__(self,
name,
images_dir,
classes_file_path,
tf_records_save_dir,
num_shards= 4,
num_threads= 4):
self.name = name
self.classes_file_path = classes_file_path
self.images_dir = images_dir
self.tf_records_saved_dir = tf_records_save_dir
self.num_shards = num_shards
self.num_threads = num_threads
@staticmethod
def _int64_feature(value):
if not isinstance(value, list):
value = [value]
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
@staticmethod
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _convert_to_example(self, filename, image_buffer, label, text, height, width):
"""
Example for image classification
:param filename:
:param image_buffer:
:param label:
:param text:
:param height:
:param width:
:return:
"""
color_space = 'RGB'
channels = 3
image_format = 'JPEG'
example = tf.train.Example(features=tf.train.Features(feature={
'image/height': self._int64_feature(height),
'image/width': self._int64_feature(width),
'image/color_space': self._bytes_feature(tf.compat.as_bytes(color_space)),
'image/channels': self._int64_feature(channels),
'image/class/label': self._int64_feature(label),
'image/class/text': self._bytes_feature(tf.compat.as_bytes(text)),
'image/format': self._bytes_feature(tf.compat.as_bytes(image_format)),
'image/filename': self._bytes_feature(tf.compat.as_bytes(os.path.basename(filename))),
'image/encoded': self._bytes_feature(tf.compat.as_bytes(image_buffer))}))
return example
class ImageCoder(object):
def __init__(self):
self._sess = tf.Session()
self._png_data = tf.placeholder(dtype=tf.string)
image = tf.image.decode_png(self._png_data, channels= 3)
self._png_to_jpeg = tf.image.encode_jpeg(image, format= 'rgb', quality= 100)
self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels= 3)
def png_to_jpeg(self, image_data):
return self._sess.run(self._png_to_jpeg,
feed_dict={self._png_data: image_data})
def decode_jpeg(self, image_data):
image = self._sess.run(self._decode_jpeg,
feed_dict={self._decode_jpeg_data: image_data})
assert len(image.shape) == 3
assert image.shape[ 2] == 3
return image
@staticmethod
def _is_png(filename):
return '.png' in filename
def _process_image(self, filename, coder):
with tf.gfile.FastGFile(filename, 'r') as f:
image_data = f.read()
if self._is_png(filename):
print( 'Converting PNG to JPEG for %s' % filename)
image_data = coder.png_to_jpeg(image_data)
image = coder.decode_jpeg(image_data)
assert len(image.shape) == 3
height = image.shape[ 0]
width = image.shape[ 1]
assert image.shape[ 2] == 3
return image_data, height, width
def _process_image_files_batch(self, coder, thread_index, ranges, name, file_names,
texts, labels, num_shards):
num_threads = len(ranges)
assert not num_shards % num_threads
num_shards_per_batch = int(num_shards / num_threads)
shard_ranges = np.linspace(ranges[thread_index][ 0],
ranges[thread_index][ 1],
num_shards_per_batch + 1).astype(int)
num_files_in_thread = ranges[thread_index][ 1] - ranges[thread_index][ 0]
counter = 0
for s in range(num_shards_per_batch):
shard = thread_index * num_shards_per_batch + s
output_filename = '%s-%.5d-of-%.5d.tfrecord' % (name, shard, num_shards)
output_file = os.path.join(self.tf_records_saved_dir, output_filename)
writer = tf.python_io.TFRecordWriter(output_file)
shard_counter = 0
files_in_shard = np.arange(shard_ranges[s], shard_ranges[s + 1], dtype=int)
for i in files_in_shard:
filename = file_names[i]
label = labels[i]
text = texts[i]
image_buffer, height, width = self._process_image(filename, coder)
example = self._convert_to_example(filename, image_buffer, label,
text, height, width)
writer.write(example.SerializeToString())
shard_counter += 1
counter += 1
if not counter % 1000:
print( '%s [thread %d]: Processed %d of %d images in thread batch.' %
(datetime.now(), thread_index, counter, num_files_in_thread))
sys.stdout.flush()
writer.close()
print( '%s [thread %d]: Wrote %d images to %s' %
(datetime.now(), thread_index, shard_counter, output_file))
sys.stdout.flush()
shard_counter = 0
print( '%s [thread %d]: Wrote %d images to %d shards.' %
(datetime.now(), thread_index, counter, num_files_in_thread))
sys.stdout.flush()
def _process_image_files(self, file_names, texts, labels):
assert len(file_names) == len(texts)
assert len(file_names) == len(labels)
spacing = np.linspace( 0, len(file_names), self.num_threads + 1).astype(np.int)
ranges = []
for i in range(len(spacing) - 1):
ranges.append([spacing[i], spacing[i + 1]])
print( 'Launching %d threads for spacings: %s' % (self.num_threads, ranges))
sys.stdout.flush()
coord = tf.train.Coordinator()
coder = self.ImageCoder()
threads = []
for thread_index in range(len(ranges)):
args = (coder, thread_index, ranges, self.name, file_names,
texts, labels, self.num_shards)
t = threading.Thread(target=self._process_image_files_batch, args=args)
t.start()
threads.append(t)
coord.join(threads)
print( '%s: Finished writing all %d images in data set.' %
(datetime.now(), len(file_names)))
sys.stdout.flush()
def _find_image_files(self):
print( 'Determining list of input files and labels from %s.' % self.images_dir)
unique_labels = [l.strip() for l in tf.gfile.FastGFile(
self.classes_file_path, 'r').readlines()]
labels = []
file_names = []
texts = []
label_index = 1
for text in unique_labels:
jpeg_file_path = '%s/%s/*' % (self.images_dir, text)
matching_files = tf.gfile.Glob(jpeg_file_path)
labels.extend([label_index] * len(matching_files))
texts.extend([text] * len(matching_files))
file_names.extend(matching_files)
if not label_index % 100:
print( 'Finished finding files in %d of %d classes.' % (
label_index, len(labels)))
label_index += 1
shuffled_index = list(range(len(file_names)))
random.seed( 12345)
random.shuffle(shuffled_index)
file_names = [file_names[i] for i in shuffled_index]
texts = [texts[i] for i in shuffled_index]
labels = [labels[i] for i in shuffled_index]
print( 'Found %d JPEG files across %d labels inside %s.' %
(len(file_names), len(unique_labels), self.images_dir))
print( '[INFO] Attempting logging out file_names list: {}'.format( '\n'.join(file_names)))
return file_names, texts, labels
def generate(self):
assert not self.num_shards % self.num_threads, (
'Please make the FLAGS.num_threads commensurate with FLAGS.train_shards')
print( 'Saving results to %s' % self.tf_records_saved_dir)
file_names, texts, labels = self._find_image_files()
self._process_image_files(file_names, texts, labels)
print( 'All Done! Solved {} images. tf_records file saved into {}.'.format(len(file_names), os.path.abspath(
self.tf_records_saved_dir)))

这是我包装的一个类,只要传入路径调用generate就可以生成tfrecord文件。看到这里估计你已经哭了,尼玛这么复杂?!!!!????

好吧,暂且不管这个具体咋么实现的,再来看看数据怎么load进模型的吧:

       
       
       
       
import tensorflow as tf
import logging
import numpy as np
import os
import time
from datasets.tiny5.tiny5 import Tiny5
from models.alexnet import AlexNet
from models.vgg import VGGNet
from models.fanet import FaNet
logging.basicConfig(level=logging.DEBUG,
format= '%(asctime)s %(filename)s line:%(lineno)d %(levelname)s %(message)s',
datefmt= '%a, %d %b %Y %H:%M:%S')
tf.app.flags.DEFINE_string( 'checkpoints_dir', './checkpoints/tiny5/', 'checkpoints save path.')
tf.app.flags.DEFINE_string( 'model_prefix', 'tiny5-alex-net', 'model save prefix.')
tf.app.flags.DEFINE_boolean( 'is_restore', False, 'to restore from previous or not.')
tf.app.flags.DEFINE_integer( 'target_width', 256, 'target width for resize.')
tf.app.flags.DEFINE_integer( 'target_height', 256, 'target height for resize.')
tf.app.flags.DEFINE_integer( 'batch_size', 24, 'batch size for train.')
FLAGS = tf.app.flags.FLAGS
def running(is_train=True):
if not os.path.exists(FLAGS.checkpoints_dir):
os.makedirs(FLAGS.checkpoints_dir)
tiny5 = Tiny5(
images_dir= './datasets/tiny5/images',
classes_file_path= './datasets/tiny5/tiny5_classes.txt',
target_height=FLAGS.target_height,
target_width=FLAGS.target_width,
batch_size=FLAGS.batch_size
)
images, labels = tiny5.batch_inputs()
print(images)
# model = AlexNet(num_classes=5)
# model = VGGNet(num_classes=5)
model = FaNet(num_classes= 5)
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
saver = tf.train.Saver(max_to_keep= 2)
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
sess.run(init_op)
start_epoch = 0
checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoints_dir)
if FLAGS.is_restore:
if checkpoint:
saver.restore(sess, checkpoint)
logging.info( "restore from the checkpoint {0}".format(checkpoint))
start_epoch += int(checkpoint.split( '-')[ -1])
if is_train:
step = 0
logging.info( 'training start...')
try:
while not coord.should_stop():
feed_dict = model.make_train_inputs(images, labels)
_, loss, step = sess.run(
[model.train_op, model.loss, model.global_step], feed_dict=feed_dict
)
logging.info( 'epoch {}, loss {}'.format(step, loss))
except tf.errors.OutOfRangeError:
logging.info( 'optimization done! enjoy color net.')
saver.save(sess, os.path.join(FLAGS.checkpoints_dir, FLAGS.checkpoints_prefix), global_step=step)
except KeyboardInterrupt:
logging.info( 'interrupt manually, try saving checkpoint for now...')
saver.save(sess, os.path.join(FLAGS.checkpoints_dir, FLAGS.model_prefix), global_step=step)
logging.info( 'last epoch were saved, next time will start from epoch {}.'.format(step))
finally:
coord.request_stop()
coord.join(threads)
else:
logging.info( 'start inference...')
inference_image_path = './images/1.png'
input_image = tiny5.single_image_input(inference_image_path)
feed_dict = model.make_inference_inputs(input_image)
outputs = sess.run([model.inference_outputs(n_top= 2)], feed_dict=feed_dict)
print(outputs)
def main(args):
running(args)
if __name__ == '__main__':
tf.app.run()

这个训练的代码,大概的训练步骤分为:

  • 使用tf.ConfigProto()来生成一个config,设置gpu自动生长,同时设置一个saver,这个saver就是最大保存的数目;
  • 设置初始化的变量op,设置一个tf.Train.Coordinator()来作为训练协调者,初始化图;
  • for循环所有的epoch,在每次循环里面catch一下tf.errors.OutOfRangeError表示一个batch训练完了,catch一下KeyBoardInterrupt;
  • 最后是保存模型

大家可以感受一下TensorFlow一整套流程下来的复杂程度。这里面还没有写我的网络,没有写我的数据DataLoader,整个代码在我的GitHub仓库可以找到原始代码,传送门, 如果你觉得那个项目过于陈旧可以跟进我的一些最新的项目,我近期在TensorFlow上做的工作有:

  • 用Google最新nmt模型训练聊天机器人;
  • 使用GAN做Cylce-GAN生成;
  • 使用KnowledgeDatabase和知识图谱做问答系统;
  • 目标检测和分割等常规性工作

4. PyTorch

PyTorch如果做图片预测我就不详细讲了,很多人说PyTorch很简单,但是我并没有觉得简单到哪里去,我总结一下PyTorch目前来说一些优点吧。

  • 立即式编程,也就是运行立马出结果,不同于TensorFlow的图式,你必须把所有程序写完之后才知道结果什么;
  • 安装也比较方便,但是跨平台部署就比较麻烦了,这也和PyTorch的定位有关,当然PyTorch刚推出来的时候有几篇官方教程写的不错,主要是RNN文本生成,Seq2Seq翻译的实现,有兴趣的同学可以看一下,但是都是非常简单的实现,跟TensorFlow的官方例子差距蛮大;
  • 只是构建网络比较简单,但是具体训练的PipeLine还是有点麻烦,尤其是我每次变量还得指定是CPU还是GPU,每次load模型的时候还得load是CPU还是GPU,个人感觉略麻烦;

PyTorch推出来的时候很火,现在貌似熄火了….

5. Caffe2

caffe2 不得不提一下,caffe的进化版本????caffe用着还好,c++调接口还蛮方便,例子也很多,caffe2为毛主打python,还python2???不过这也跟caffe2定位于工业使用有关,但是总体来说有这么几点:

  • 感觉没有多少社区,虽然caffe非常多公司用,但是那毕竟是第一代版本,一般公司用用还行,容易与时代脱节;
  • caffe2也没有多少亮点,官方的教程我是没有看到什么实质性的东西,后期也没有更多的example;
  • 好像C++接口也不是非常友好,至少在例子上很少….一个框架推出来,不教人去用那推出来有啥意思?

总结

我写文章喜欢一目了然,文章结构大致对比了5种框架的优缺点,那么我直接给使用者一些建议,防止大家采坑:

  • 如果你是深度学习老鸟,你应该选择TensorFlow,但是我不得不告诉你TensorFlow在1.2版本推出来的API,在1.4版本很有可能就大改了…..
  • 如果你是深度学习菜鸟,你应该选择MXNet或者PaddlePaddle,很多人会说,我曹,为什么不用Keras??好吧,Keras当然也可以用,但是不建议一直用,还是得熟悉一下稍微底层一些的框架;
  • 如果你是….如果你是小学生?高中生或者初中生,你可以用一下PaddlePaddle,因为你英文可能不太好。

如果你想跟进我的更多TensorFlow项目欢迎在Github寻找我的联系方式,加入QQ群交流。

This article was original written by Jin Tian, welcome re-post, first come with https://jinfagang.github.io . but please keep this copyright info, thanks, any question could be asked via wechat: jintianiloveu

你可能感兴趣的:(TensorFlow,Caffe,PyTorch,PaddlePaddle,MXNet)