TensorFlow Study Notes (3): LeNet


Author: Joyner

Training LeNet on the MNIST dataset.


1. Dataset

192.168.9.5:/DATACENTER1/zhiwen.wang/tensorflow-wzw/MNIST_data

t10k-images-idx3-ubyte.gz

t10k-labels-idx1-ubyte.gz

train-images-idx3-ubyte.gz

train-labels-idx1-ubyte.gz
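The four files above are gzip-compressed IDX files. As a rough sketch of what they contain (not part of the original repository; the helper name `read_idx` is made up for illustration), the following reads one of them directly with NumPy. It assumes the standard MNIST layout: a 4-byte header (two zero bytes, a data-type code, the number of dimensions), then one big-endian 32-bit size per dimension, then the raw unsigned bytes.

```python
import gzip
import struct

import numpy as np


def read_idx(path):
    """Read a gzip-compressed IDX file of unsigned bytes into a NumPy array.

    Hypothetical helper for illustration; pre_data.py uses TensorFlow's
    input_data loader instead.
    """
    with gzip.open(path, "rb") as f:
        # Header: two zero bytes, a dtype code (0x08 = unsigned byte), number of dimensions
        zero, dtype_code, ndim = struct.unpack(">HBB", f.read(4))
        if zero != 0 or dtype_code != 0x08:
            raise ValueError("not an unsigned-byte IDX file: {}".format(path))
        # Each dimension size is stored as a big-endian 32-bit unsigned integer
        shape = struct.unpack(">" + "I" * ndim, f.read(4 * ndim))
        data = np.frombuffer(f.read(), dtype=np.uint8)
    return data.reshape(shape)
```

The `idx3` and `idx1` parts of the file names refer to the number of dimensions: the image files hold a 3-D array (count, rows, columns) and the label files a 1-D array.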


2. Download the code

https://github.com/sujaybabruwad/LeNet-in-Tensorflow


3. Update the dataset path in pre_data.py

# Note: tensorflow.examples.tutorials is only available in TensorFlow 1.x
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np

def pre_data():
    # Load MNIST from the local directory; reshape=False keeps each image as 28x28x1
    mnist = input_data.read_data_sets("/DATACENTER1/zhiwen.wang/tensorflow-wzw/MNIST_data", reshape=False)
    X_train, y_train = mnist.train.images, mnist.train.labels
    X_validation, y_validation = mnist.validation.images, mnist.validation.labels
    X_test, y_test = mnist.test.images, mnist.test.labels

    # Every image must have a matching label
    assert len(X_train) == len(y_train)
    assert len(X_validation) == len(y_validation)
    assert len(X_test) == len(y_test)

    print("Image Shape: {}".format(X_train[0].shape))
    print("Training Set:   {} samples".format(len(X_train)))
    print("Validation Set: {} samples".format(len(X_validation)))
    print("Test Set:       {} samples".format(len(X_test)))

    # Pad images with 0s: 2 pixels on each side of height and width (28x28 -> 32x32)
    X_train = np.pad(X_train, ((0,0),(2,2),(2,2),(0,0)), 'constant')
    X_validation = np.pad(X_validation, ((0,0),(2,2),(2,2),(0,0)), 'constant')
    X_test = np.pad(X_test, ((0,0),(2,2),(2,2),(0,0)), 'constant')

    return X_train, y_train, X_validation, y_validation, X_test, y_test
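The padding step is easy to check in isolation. The classic LeNet-5 architecture takes 32x32 inputs while MNIST images are 28x28, which is presumably why the code pads 2 pixels of zeros on every side. The sketch below uses a made-up batch of ones in place of the real data, so it runs without TensorFlow or the dataset.

```python
import numpy as np

# Stand-in batch with the same layout as the MNIST images:
# (batch, height, width, channels). The values are placeholders.
batch = np.ones((4, 28, 28, 1), dtype=np.float32)

# Pad 2 pixels before and after the height and width axes;
# (0, 0) leaves the batch and channel axes untouched.
padded = np.pad(batch, ((0, 0), (2, 2), (2, 2), (0, 0)), 'constant')

print(padded.shape)        # (4, 32, 32, 1)
print(padded[0, 0, 0, 0])  # 0.0, a padded border pixel
print(padded[0, 2, 2, 0])  # 1.0, the first original pixel
```

With mode 'constant' and no explicit value, np.pad fills with zeros, which matches the black background of the MNIST digits.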


4. Training (run in a Python 3 environment)

cd /DATACENTER1/zhiwen.wang/tensorflow-wzw/Lenet-5-tensorflow/src

CUDA_VISIBLE_DEVICES=1 python3 main/train_and_evaluate.py
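Setting CUDA_VISIBLE_DEVICES=1 on the command line exposes only the second physical GPU to the process, where it then appears as device 0. As a sketch of an equivalent approach (an assumption, not something the original repository does), the variable can also be set from Python, provided this happens before TensorFlow is imported.

```python
import os

# Must be set before TensorFlow (or any other CUDA library) is imported,
# otherwise the setting has no effect on the already-initialized runtime.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# import tensorflow as tf  # import only after the variable is set

print(os.environ["CUDA_VISIBLE_DEVICES"])  # 1
```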


5. Training results


[Figure: LeNet training results]
