使用过TensorFlow的同学应该知道,这个深度学习库是有点难用的,它需要我们自己定义所有的计算节点(op)。比如,如果我们要定义一个全连接层,首先需要定义全连接层的参数W和b(Variable),然后定义输出Tensor等于输入Tensor乘以W加上b再取激活函数a。想想就觉得好麻烦......
TFLearn则是一个建立在TensorFlow上的模块化的、透明的深度学习库。它比TensorFlow提供了更高层次的API,从而使我们更加快速地进行实验。比如,我们想要定义一个全连接层就只需要一个语句:net=tflearn.fully_connected(net,64)。是不是非常方便?
TFLearn库有如下特性
1)Easy-to-use and understand high-level API for implementing deep neural networks, with tutorial and examples.
2)Fast prototyping through highly modular built-in neural network layers, regularizers, optimizers, metrics...
3)Full transparency over Tensorflow. All functions are built over tensors and can be used independently of TFLearn.
4)Powerful helper functions to train any TensorFlow graph, with support of multiple inputs, outputs and optimizers.
5)Easy and beautiful graph visualization, with details about weights, gradients, activations and more...
6)Effortless device placement for using multiple CPU/GPU.
目前,提供的高层次API支持最新的深度学习模型,比如Convolutions, LSTM, BiRNN, BatchNorm, PReLU, Residual networks, Generative networks。
值得注意的是,目前最新版本的TFLearn(v0.3)仅仅兼容TensorFlow v1.0及更高版本。
GitHub:https://github.com/tflearn/tflearn
官方网站:http://tflearn.org/
关于从哪儿开始学习TFLearn,官网给出了下面几个链接:
1)安装TFLearn:安装指导;
2)入门TFLearn:开始啃TFLearn,TFLearn教程;
3)示例:从例子中窥探TFLearn;
4)API列表:API文档。
下面我们举一个例子来帮助我们快速上手TFLearn,在这个例子中我们将对泰坦尼克号上的乘客进行存活可能性预测。
首先,我们来看一下数据集,每一个乘客的相关信息如下:
其中总共有9项,我们将其分为标签(label)和输入(data),则标签为是否存活,存活为1,输入包含8项,其中我们认为姓名以及船票的号码(可以由票价直接体现)对于我们预测乘客的存活几率是没有什么用的,所以在预处理中,我们将其抛弃。下面是乘客信息的样本:
在了解清楚数据结构之后,我们就可以愉快地加载并处理数据啦:
数据集被存储为csv文件格式。csv,全称为Comma-Separated Values,即逗号分隔值,其文本以纯文本形式存储表格数据,我们可以使用文本编辑器或excel直接打开。下面我们先加载数据到内存中:
import numpy as np
import tflearn
# Download the Titanic dataset
from tflearn.datasets import titanic
titanic.download_dataset('titanic_dataset.csv')
# Load CSV file, indicate that the first column represents labels
from tflearn.data_utils import load_csv
data, labels = load_csv('titanic_dataset.csv', target_column=0,
categorical_labels=True, n_classes=2)
上面使用load_csv()函数从csv文件中读取数据,并转为python List。其中target_column参数用于表示我们的标签列id,该函数将返回一个元组:(data,labels)。
然后按照我们前面说的,抛弃输入中的姓名以及船票号码字段,并将性别字段转为数值,0表示男性,1表示女性,预处理如下:
# Preprocessing function
def preprocess(data, columns_to_ignore):
# Sort by descending id and delete columns
for id in sorted(columns_to_ignore, reverse=True):
[r.pop(id) for r in data]
for i in range(len(data)):
# Converting 'sex' field to float (id is 1 after removing labels column)
data[i][1] = 1. if data[i][1] == 'female' else 0.
return np.array(data, dtype=np.float32)
# Ignore 'name' and 'ticket' columns (id 1 & 6 of data array)
to_ignore=[1, 6]
# Preprocess data
data = preprocess(data, to_ignore)
接着我们就进入最激动人心的部分了,构建神经网络:
# Build neural network
net = tflearn.input_data(shape=[None, 6])
net = tflearn.fully_connected(net, 32)
net = tflearn.fully_connected(net, 32)
net = tflearn.fully_connected(net, 2, activation='softmax')
net = tflearn.regression(net)
正如上面所说,TFLearn中采用Tensor进行运算,因此这里的net都是Tensor,与TensorFlow中一样,我们也可以将其中的某一个部分用TensorFlow中的函数自己写,从而实现一些TFLearn库中没有的功能,其中input_data和fully_connected定义在/usr/local/lib/python2.7/dist-packages/tflearn/layers/core.py文件中,其中全连接层的W(weights_init)和b(bias_init)可以指定,不过默认为W:'truncated_normal',b:'zeros',此外,其中的activation参数默认为'linear'。
然后是训练:
# Define model
model = tflearn.DNN(net)
# Start training (apply gradient descent algorithm)
model.fit(data, labels, n_epoch=10, batch_size=16, show_metric=True)
其中tflearn.DNN是TFLearn中提供的一个模型wrapper,相当于我们将很多功能包装起来,我们给它一个net结构,生成一个model对象,然后调用model对象的训练、预测、存储等功能,DNN类有三个属性(成员变量):trainer,predictor,session。在fit()函数中n_epoch=10表示整个训练数据集将会用10遍,batch_size=16表示一次用16个数据计算参数的更新。
最后利用训练得到的模型进行预测:
# Let's create some data for DiCaprio and Winslet
dicaprio = [3, 'Jack Dawson', 'male', 19, 0, 0, 'N/A', 5.0000]
winslet = [1, 'Rose DeWitt Bukater', 'female', 17, 1, 2, 'N/A', 100.0000]
# Preprocess data
dicaprio, winslet = preprocess([dicaprio, winslet], to_ignore)
# Predict surviving chances (class 1 results)
pred = model.predict([dicaprio, winslet])
print("DiCaprio Surviving Rate:", pred[0][1])
print("Winslet Surviving Rate:", pred[1][1])
其中的pred为对于[dicaprio,winslet]预测得到的结果,对于其中某一个(比如dicaprio)进行预测的结果为[死亡概率,存活概率],所以这里打印的是pred[i][1]。
调用model对象的predict()函数对数据进行预测,结果如下:
其中Dicaprio是男主角,Winslet为女主角,可以看出预测还是比较准的。关于TFLearn的第一次介绍就到这里咯~
马上到周末了,祝大家玩得愉快~