TensorFlow Learning (3): LeNet
Author: Joyner
Training LeNet on the MNIST dataset
1. Dataset
192.168.9.5:/DATACENTER1/zhiwen.wang/tensorflow-wzw/MNIST_data
t10k-images-idx3-ubyte.gz
t10k-labels-idx1-ubyte.gz
train-images-idx3-ubyte.gz
train-labels-idx1-ubyte.gz
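The four files above are gzip-compressed IDX files. Each one starts with a 4-byte magic number (two zero bytes, a type code of 0x08 for unsigned bytes, and the number of dimensions), followed by one big-endian 32-bit size per dimension and then the raw pixel or label bytes. The sketch below can be used to sanity-check a download before training. It assumes only gzip and NumPy rather than TensorFlow, and `parse_idx` and `load_idx_gz` are hypothetical helper names, not part of the LeNet repository.

```python
import gzip
import struct

import numpy as np


def parse_idx(raw: bytes) -> np.ndarray:
    """Parse an uncompressed IDX byte string (unsigned-byte data only) into an array."""
    # Magic number: 2 zero bytes, 1 byte type code, 1 byte number of dimensions
    zero, dtype_code, ndim = struct.unpack(">HBB", raw[:4])
    if zero != 0 or dtype_code != 0x08:
        raise ValueError("expected an unsigned-byte IDX file")
    # One big-endian uint32 size per dimension
    dims = struct.unpack(">" + "I" * ndim, raw[4:4 + 4 * ndim])
    # Everything after the header is the data, one uint8 per value
    data = np.frombuffer(raw, dtype=np.uint8, offset=4 + 4 * ndim)
    return data.reshape(dims)


def load_idx_gz(path: str) -> np.ndarray:
    """Decompress a .gz IDX file and parse it."""
    with gzip.open(path, "rb") as f:
        return parse_idx(f.read())
```

For the standard MNIST training set, `load_idx_gz("/DATACENTER1/zhiwen.wang/tensorflow-wzw/MNIST_data/train-images-idx3-ubyte.gz").shape` should be `(60000, 28, 28)`, and the t10k files should contain 10000 images and labels.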
2. Download the code
https://github.com/sujaybabruwad/LeNet-in-Tensorflow
3. Change the data path in pre_data.py
# input_data ships with TensorFlow 1.x only; it was removed in TensorFlow 2
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np

def pre_data():
    # Point read_data_sets at the local directory holding the four .gz files from step 1;
    # reshape=False keeps each image as 28x28x1 instead of a flat 784-vector
    mnist = input_data.read_data_sets("/DATACENTER1/zhiwen.wang/tensorflow-wzw/MNIST_data", reshape=False)
    X_train, y_train = mnist.train.images, mnist.train.labels
    X_validation, y_validation = mnist.validation.images, mnist.validation.labels
    X_test, y_test = mnist.test.images, mnist.test.labels

    # Every image must have exactly one label
    assert len(X_train) == len(y_train)
    assert len(X_validation) == len(y_validation)
    assert len(X_test) == len(y_test)

    print("Image Shape: {}".format(X_train[0].shape))
    print("Training Set: {} samples".format(len(X_train)))
    print("Validation Set: {} samples".format(len(X_validation)))
    print("Test Set: {} samples".format(len(X_test)))

    # Pad images with 0s: 2 pixels on each side of height and width, 28x28 -> 32x32
    X_train = np.pad(X_train, ((0, 0), (2, 2), (2, 2), (0, 0)), 'constant')
    X_validation = np.pad(X_validation, ((0, 0), (2, 2), (2, 2), (0, 0)), 'constant')
    X_test = np.pad(X_test, ((0, 0), (2, 2), (2, 2), (0, 0)), 'constant')
    return X_train, y_train, X_validation, y_validation, X_test, y_test
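With `reshape=False`, `read_data_sets` returns images shaped `(N, 28, 28, 1)`, and with its default `validation_size` of 5000 the printed counts should be 55000 training, 5000 validation and 10000 test samples. The `np.pad` width tuple gives one `(before, after)` pair per axis, so only height and width grow, by 2 + 2 pixels each, producing the 32x32 input used by the original LeNet-5. The standalone check below uses a dummy all-ones batch in place of real MNIST images; `pad_to_32` is a hypothetical helper name.

```python
import numpy as np

# One (before, after) pair per axis: leave batch and channel axes alone, pad H and W by 2
PAD_WIDTH = ((0, 0), (2, 2), (2, 2), (0, 0))


def pad_to_32(images: np.ndarray) -> np.ndarray:
    """Zero-pad a (N, 28, 28, C) batch to (N, 32, 32, C)."""
    return np.pad(images, PAD_WIDTH, "constant")


# Dummy batch of four all-ones "images" standing in for MNIST data
batch = np.ones((4, 28, 28, 1), dtype=np.float32)
padded = pad_to_32(batch)
print(padded.shape)  # (4, 32, 32, 1)
```

The original pixels end up at rows and columns 2 to 29; the new border is all zeros, which for MNIST matches the black background.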
4. Training (run in a Python 3 environment)
cd /DATACENTER1/zhiwen.wang/tensorflow-wzw/Lenet-5-tensorflow/src
CUDA_VISIBLE_DEVICES=1 python3 main/train_and_evaluate.py
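`CUDA_VISIBLE_DEVICES=1` makes only GPU 1 visible to the training process. The script trains a LeNet-5 style network. Assuming the classic layer configuration (two 5x5 'VALID' convolutions with 6 and 16 filters, each followed by 2x2 max pooling with stride 2, then fully connected layers of 120, 84 and 10 units), which is an assumption here rather than something read from the repository, the feature-map sizes for a padded 32x32x1 input can be traced as below. `conv_out` and `lenet5_shapes` are hypothetical helpers for the arithmetic only; they do not build a TensorFlow graph.

```python
from typing import List, Tuple


def conv_out(size: int, kernel: int, stride: int = 1) -> int:
    """Output height/width of a 'VALID' (unpadded) convolution or pooling window."""
    return (size - kernel) // stride + 1


def lenet5_shapes(input_size: int = 32, channels: int = 1) -> List[Tuple[str, Tuple[int, ...]]]:
    """Trace output shapes through a classic LeNet-5 layout (assumed, not taken from the repo)."""
    shapes = [("input", (input_size, input_size, channels))]
    s = conv_out(input_size, 5)       # conv1: 5x5 kernel, 6 filters
    shapes.append(("conv1", (s, s, 6)))
    s = conv_out(s, 2, stride=2)      # pool1: 2x2 max pool, stride 2
    shapes.append(("pool1", (s, s, 6)))
    s = conv_out(s, 5)                # conv2: 5x5 kernel, 16 filters
    shapes.append(("conv2", (s, s, 16)))
    s = conv_out(s, 2, stride=2)      # pool2: 2x2 max pool, stride 2
    shapes.append(("pool2", (s, s, 16)))
    shapes.append(("flatten", (s * s * 16,)))
    shapes += [("fc1", (120,)), ("fc2", (84,)), ("logits", (10,))]
    return shapes


for name, shape in lenet5_shapes():
    print(f"{name:8s} {shape}")
```

Under these assumptions the chain is 32 -> 28 -> 14 -> 10 -> 5, so the flattened vector has 5 * 5 * 16 = 400 values feeding the first fully connected layer. This is one reason for padding the 28x28 MNIST images to 32x32 in step 3.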
5. Training results