Notes on Using TensorFlow Eager Execution (Dynamic Graphs)

While studying Keras recently, I came across how TensorFlow builds dynamic graphs (eager execution). I adapted code from more experienced developers to run it on the MNIST dataset, and I am writing it up here so the relevant points are easy to find later.

Loading the libraries

import tensorflow as tf
import tensorflow.contrib.eager as tfe  # TF 1.x eager helpers (Saver, implicit_gradients)
tf.enable_eager_execution()  # must run once, before any other TF ops
print("tensorflow: {}".format(tf.VERSION))
from keras.datasets import mnist
import numpy as np
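
With eager execution enabled, operations run immediately and return concrete values, with no Session or graph building required. A minimal sketch to confirm the mode is active (the values here are purely illustrative):

x = tf.constant([[1., 2.], [3., 4.]])
y = tf.matmul(x, x)            # executes right away
print(tf.executing_eagerly())  # True
print(y.numpy())               # [[ 7. 10.] [15. 22.]]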

Dataset handling

# Download the dataset (cached locally by Keras after the first call)
def load_data():
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    return (x_train, y_train), (x_test, y_test)

# Wrap the dataset with shuffling and batching helpers
class MnistData:
    def __init__(self, name, need_shuffle):
        (x_train, y_train), (x_test, y_test) = load_data()
        # pick the split and normalize pixel values into [0, 1]
        if name == 'train':
            self._data = x_train.astype(np.float32) / 255.0
            self._labels = y_train
        else:
            self._data = x_test.astype(np.float32) / 255.0
            self._labels = y_test
        print(self._data.shape)
        print(self._labels.shape)

        self._num_examples = self._data.shape[0]
        self._need_shuffle = need_shuffle
        self._indicator = 0
        if self._need_shuffle:
            self._shuffle_data()

    # Shuffle data and labels with the same permutation
    def _shuffle_data(self):
        # e.g. [0,1,2,3,4,5] -> [5,3,2,4,0,1]
        p = np.random.permutation(self._num_examples)
        self._data = self._data[p]
        self._labels = self._labels[p]

    # Fetch the next batch_size examples, reshuffling when an epoch ends
    def next_batch(self, batch_size):
        """Return batch_size examples as a batch."""
        end_indicator = self._indicator + batch_size
        if end_indicator > self._num_examples:
            if self._need_shuffle:
                self._shuffle_data()
                self._indicator = 0
                end_indicator = batch_size
            else:
                raise Exception("have no more examples")
        if end_indicator > self._num_examples:
            raise Exception("batch size is larger than all examples")
        batch_data = self._data[self._indicator: end_indicator]
        batch_labels = self._labels[self._indicator: end_indicator]
        self._indicator = end_indicator
        return batch_data, batch_labels

Fetching the data

train_data = MnistData('train', True)
batch_size = 20
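
As an optional sanity check (note that it advances the MnistData object's internal cursor by one batch), you can pull a single batch and inspect the shapes:

xb, yb = train_data.next_batch(batch_size)
print(xb.shape, yb.shape)  # (20, 28, 28) (20,)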

Building the model with tf.layers.Layer

# MNIST model: __init__ defines the layers, call composes them into the forward pass
class MNISTModel(tf.layers.Layer):
    def __init__(self, name):
        super(MNISTModel, self).__init__(name=name)
        self._input_shape = [-1, 28, 28, 1]
        self.conv1 = tf.layers.Conv2D(32, 5, activation=tf.nn.relu)
        self.conv2 = tf.layers.Conv2D(64, 5, activation=tf.nn.relu)
        self.fc1 = tf.layers.Dense(1024, activation=tf.nn.relu)
        self.fc2 = tf.layers.Dense(10)
        self.dropout = tf.layers.Dropout(0.5)
        self.max_pool2d = tf.layers.MaxPooling2D(
            (2, 2), (2, 2), padding='SAME')

    def call(self, inputs, training):
        x = tf.reshape(inputs, self._input_shape)
        x = self.conv1(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.max_pool2d(x)
        x = tf.layers.flatten(x)  # flatten feature maps before the dense layers
        x = self.fc1(x)
        if training:
            x = self.dropout(x)
        x = self.fc2(x)
        return x
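
Before training it is worth verifying the wiring with a single forward pass: with 5x5 valid convolutions and 2x2 pooling, the spatial size shrinks 28 -> 24 -> 12 -> 8 -> 4, so fc1 sees 4*4*64 = 1024 features. A throwaway smoke test (the 'smoke_test' instance below is illustrative only):

smoke = MNISTModel('smoke_test')
dummy = tf.zeros([2, 28, 28], dtype=tf.float32)  # a fake batch of two images
print(smoke(dummy, training=False).shape)  # (2, 10) logits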

Defining the loss function

def loss(model, inputs, labels, training=True):
    predictions = model(inputs, training=training)
    # per-example cross entropy on the raw logits, summed over the batch
    cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=predictions, labels=labels)
    return tf.reduce_sum(cost)

optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
# implicit_gradients(loss) returns a function yielding (gradient, variable) pairs
grad = tfe.implicit_gradients(loss)
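
Calling grad(model, inputs, labels) produces a list of (gradient, variable) pairs, which is exactly what apply_gradients expects. The same computation can be written explicitly with a gradient tape; the sketch below is an equivalent alternative (grad_tape is a hypothetical name, not used elsewhere):

def grad_tape(model, inputs, labels):
    with tfe.GradientTape() as tape:
        loss_value = loss(model, inputs, labels)
    # pair each gradient with its variable, as apply_gradients expects
    return list(zip(tape.gradient(loss_value, model.variables), model.variables))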

Training code

model = MNISTModel('net')

global_step = tf.train.get_or_create_global_step()

for epoch in range(5):
    for i in range(train_data._num_examples // batch_size):  # full pass over all 60000 training examples
        inputs, labels = train_data.next_batch(batch_size)
        inputs, labels = tf.cast(inputs, tf.float32), tf.cast(labels, tf.int32)
        optimizer.apply_gradients(grad(model, inputs, labels), global_step=global_step)

        if i % 50 == 0:
            print("Step %d: Loss on training set : %f" %
                  (i, loss(model, inputs, labels).numpy()))
            all_variables = (
                model.variables
                + optimizer.variables()
                + [global_step])
            tfe.Saver(all_variables).save(
                "./log/mnistmodel.ckpt", global_step=global_step)

Evaluating on the test set

test_data = MnistData('test', False)
test_inputs, test_labels = test_data.next_batch(100)
test_inputs, test_labels = tf.cast(test_inputs, tf.float32), tf.cast(test_labels, tf.int32)
print("Loss on test set : %f" %
      (loss(model, test_inputs, test_labels, training=False).numpy()))
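
The raw loss is hard to interpret on its own; classification accuracy on the same test batch takes a few more lines on top of the logits (a sketch; training=False keeps dropout disabled during evaluation):

test_logits = model(test_inputs, training=False)
predicted = tf.argmax(test_logits, axis=1, output_type=tf.int32)
accuracy = tf.reduce_mean(tf.cast(tf.equal(predicted, test_labels), tf.float32))
print("Accuracy on test set : %f" % accuracy.numpy())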
