win10 + vs2019 + opencv(OpenCV_VERSION 3.4.6) + keras(tensorflow后端)
从http://yann.lecun.com/exdb/mnist/下载
“data/t10k-labels-idx1-ubyte.gz”;
“data/t10k-images-idx3-ubyte.gz”;
“data/train-labels-idx1-ubyte.gz”;
“data/train-images-idx3-ubyte.gz”;
用解压软件直接解压,得到以下文件:
“data/t10k-labels-idx1-ubyte”;
“data/t10k-images-idx3-ubyte”;
“data/train-labels-idx1-ubyte”;
“data/train-images-idx3-ubyte”;
使用C++和OpenCV读取MNIST文件[https://blog.csdn.net/sheng_ai/article/details/23267039]
(代码可以直接使用)
#%%
import keras
import tensorflow as tf
from keras import backend as K
from keras.datasets import mnist
from keras.layers import Conv2D, MaxPool2D, ReLU, Input, Softmax, Reshape
from keras.models import Model
from keras.utils import np_utils
#%%
# Clear TensorFlow's default graph before the model is built.
# NOTE(review): presumably so that re-running the cells does not leave stale
# nodes in the graph that is frozen and exported at the end -- confirm.
tf.reset_default_graph()
#%%
def net(input_size, optimizer):
    """Build and compile the small convolutional digit classifier.

    The network is three 5x5 'same'-padded convolutions (20, 50 and 10
    filters, he_normal initialisation). The first two are each followed by a
    2x2 max-pool and a ReLU; the third is followed by a 7x7 max-pool. The
    result is reshaped to a 10-element vector and passed through a softmax
    layer named "output".

    The pooling factors (2 * 2 * 7 = 28) reduce a 28x28 input to 1x1, which
    is what makes the Reshape([10]) valid; other spatial sizes will not fit.

    Args:
        input_size: Shape of one input sample without the batch axis,
            e.g. [28, 28, 1].
        optimizer: Optimizer passed straight to ``model.compile``.

    Returns:
        The compiled ``Model`` (categorical cross-entropy loss, accuracy
        metric). Its input layer is named "x".
    """
    input_x = Input(input_size, name="x")

    conv1 = Conv2D(20, 5, padding='same', kernel_initializer='he_normal')(input_x)
    pool1 = MaxPool2D(2)(conv1)
    relu1 = ReLU()(pool1)

    conv2 = Conv2D(50, 5, padding='same', kernel_initializer='he_normal')(relu1)
    pool2 = MaxPool2D(2)(conv2)
    relu2 = ReLU()(pool2)

    # 10 filters: one channel per digit class, collapsed to 1x1 by the 7x7 pool.
    conv3 = Conv2D(10, 5, padding='same', kernel_initializer='he_normal')(relu2)
    pool3 = MaxPool2D(7)(conv3)

    out = Reshape([10])(pool3)
    out = Softmax(name="output")(out)

    model = Model(inputs=input_x, outputs=out)
    model.compile(
        optimizer=optimizer,
        loss="categorical_crossentropy",
        metrics=["accuracy"]
    )
    return model
# Load MNIST, add a trailing channel axis and divide the pixel values by 255.
# The cast to float32 happens before the division so the arrays are not
# promoted to float64 (twice the memory for no benefit).
(train_x, train_y), (test_x, test_y) = mnist.load_data()
train_x = train_x.reshape(train_x.shape[0], 28, 28, 1).astype("float32") / 255
test_x = test_x.reshape(test_x.shape[0], 28, 28, 1).astype("float32") / 255
# One-hot encode the labels, as required by the categorical cross-entropy loss.
train_y = np_utils.to_categorical(train_y, num_classes=10)
test_y = np_utils.to_categorical(test_y, num_classes=10)
rmsprop = keras.optimizers.RMSprop(lr=0.001, rho=0.9, epsilon=1e-08, decay=0.0)
model = net([28, 28, 1], rmsprop)
#%%
# model.summary() prints the layer table itself and returns None, so wrapping
# it in print() only appended a stray "None" line.
model.summary()
#%%
print("Training --------------")
model.fit(train_x, train_y, epochs=4, batch_size=32)
print("Testing --------------")
loss, accuracy = model.evaluate(test_x, test_y)
print("test loss: ", loss)
print("test accuracy: ", accuracy)
# List every node in the graph (useful for finding the output node name
# needed when freezing the model).
# tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
# print(tensor_name_list)
#%%
# Export the trained model as a frozen graph: variables are folded into
# constants and the result is written as a binary .pb file.
sess = K.get_session()
frozen_graph = tf.compat.v1.graph_util.convert_variables_to_constants(
    sess=sess,
    input_graph_def=sess.graph_def,
    # Softmax op of the layer that net() names "output".
    output_node_names=["output/Softmax"],
)
tf.train.write_graph(
    frozen_graph,
    "log/1.2",
    "mnist.pb",
    as_text=False,
)
输出结果为:
Epoch 1/4
60000/60000 [==============================] - 13s 216us/step - loss: 0.2229 - acc: 0.9322
Epoch 2/4
60000/60000 [==============================] - 10s 167us/step - loss: 0.0754 - acc: 0.9767
Epoch 3/4
60000/60000 [==============================] - 10s 167us/step - loss: 0.0550 - acc: 0.9828
Epoch 4/4
60000/60000 [==============================] - 10s 168us/step - loss: 0.0450 - acc: 0.9859
Testing --------------
10000/10000 [==============================] - 1s 62us/step
test loss: 0.05498976424507564
test accuracy: 0.9838
#include <iostream>
#include <string>
#include <vector>
#include <opencv2/core.hpp>
#include <opencv2/dnn.hpp>
#include "MNIST.h"
std::vector Argmax(cv::Mat x)
{
std::vector res;
for (int i = 0; i < x.rows; i++)
{
int maxIdx = 0;
float maxNum = 0.0;
for (int j = 0; j < x.cols; j++)
{
float tmp = x.at(i, j);
if (tmp > maxNum)
{
maxIdx = j;
maxNum = tmp;
}
}
res.push_back(maxIdx);
}
return res;
}
/**
 * @brief Accuracy Run the frozen TensorFlow model on a set of images and
 *        return the fraction that is classified correctly.
 * @param x Images, one per row: 28 * 28 continuous CV_8UC1 values each.
 * @param y Labels as CV_32SC1, one element per image, in the same order.
 * @param pbfile Path of the frozen .pb model.
 * @return Number of matching predictions divided by x.rows; 0 when x or y
 *         is empty.
 * @throws cv::Exception if the inputs have an unexpected type/layout or
 *         the model cannot be loaded or run.
 */
float Accuracy(cv::Mat x, cv::Mat y, std::string pbfile)
{
    // Nothing to evaluate (e.g. a data file failed to load); also avoids
    // dividing by zero below.
    if (x.empty() || y.empty())
    {
        return 0.0f;
    }
    // The raw buffer of x is reinterpreted below, so its layout must match.
    CV_Assert(x.type() == CV_8UC1 && x.isContinuous() && x.cols == 28 * 28);
    CV_Assert(y.type() == CV_32SC1);

    cv::dnn::Net net = cv::dnn::readNetFromTensorflow(pbfile);

    // blobFromImages needs input with at least 3 dims: view the row-per-image
    // buffer as (number of images, 28, 28) without copying it.
    int size[] = { x.rows, 28, 28 };
    cv::Mat imgs = cv::Mat(3, size, CV_8UC1, x.data);
    // Pixels are scaled by 1/255 here, so x must not be pre-scaled.
    cv::Mat blob = cv::dnn::blobFromImages(imgs, 1.0 / 255.0, cv::Size(28, 28), cv::Scalar(), false, false);
    net.setInput(blob);
    cv::Mat pred = net.forward();
    std::vector<int> res = Argmax(pred);

    // Never read more labels than y actually holds.
    const size_t labelCount = y.total();
    const size_t compared = res.size() < labelCount ? res.size() : labelCount;
    float count = 0.0f;
    for (size_t i = 0; i < compared; i++)
    {
        if (y.at<int>(static_cast<int>(i)) == res[i])
        {
            count = count + 1;
        }
    }
    return count / x.rows;
}
int main()
{
std::string testLableFile = "data/t10k-labels-idx1-ubyte";
std::string testImageFile = "data/t10k-images-idx3-ubyte";
std::string trainLableFile = "data/train-labels-idx1-ubyte";
std::string trainImageFile = "data/train-images-idx3-ubyte";
cv::Mat trainY = ReadLabels(trainLableFile);
cv::Mat testY = ReadLabels(testLableFile);
cv::Mat trainX = ReadImages(trainImageFile);
cv::Mat testX = ReadImages(testImageFile);
testY.convertTo(testY, CV_32SC1);
std::string pbfile = "mnist.pb";
//testX.convertTo(testX, CV_32FC1, 1.0/255.0, 0);
float acc = Accuracy(testX, testY, pbfile);
std::cout << acc;
return 0;
}
控制台输出为:
[ INFO:0] Initialize OpenCL runtime...
0.9838
#pragma once
#ifndef MNIST_H
#define MNIST_H
#include <fstream>
#include <string>
#include <opencv2/core.hpp>
/**
 * @brief Raw header found at the start of an MNIST image file.
 *
 * ReadImages() reads sizeof(MNISTImageFileHeader) bytes from the file
 * straight into this struct. Every field is a 4-byte array that is turned
 * into an int with ConvertCharArrayToInt(), most significant byte first.
 */
struct MNISTImageFileHeader
{
unsigned char MagicNumber[4];    // checked against MAGICNUMBEROFIMAGE
unsigned char NumberOfImages[4]; // number of image records that follow
unsigned char NumberOfRows[4];   // presumably the image height, per the name
unsigned char NumberOfColums[4]; // presumably the image width, per the name
};
/**
 * @brief Raw header found at the start of an MNIST label file.
 *
 * ReadLabels() reads sizeof(MNISTLabelFileHeader) bytes from the file
 * straight into this struct. Both fields are 4-byte arrays that are turned
 * into an int with ConvertCharArrayToInt(), most significant byte first.
 */
struct MNISTLabelFileHeader
{
unsigned char MagicNumber[4];    // checked against MAGICNUMBEROFLABEL
unsigned char NumberOfLabels[4]; // number of one-byte labels that follow
};
// Magic numbers expected in the first header field of image / label files.
const int MAGICNUMBEROFIMAGE = 2051;
const int MAGICNUMBEROFLABEL = 2049;
// Combine LengthOfArray bytes into one int, most significant byte first.
int ConvertCharArrayToInt(unsigned char* array, int LengthOfArray);
// True when the bytes decode to MAGICNUMBEROFIMAGE.
bool IsImageDataFile(unsigned char* MagicNumber, int LengthOfArray);
// True when the bytes decode to MAGICNUMBEROFLABEL.
bool IsLabelDataFile(unsigned char* MagicNumber, int LengthOfArray);
// Read NumberOfData records of DataSizeInBytes bytes each into a CV_8UC1 Mat
// (one record per row) and close the file.
cv::Mat ReadData(std::fstream& DataFile, int NumberOfData, int DataSizeInBytes);
// ReadData() with a record size of 28*28 bytes.
cv::Mat ReadImageData(std::fstream& ImageDataFile, int NumberOfImages);
// ReadData() with a record size of 1 byte.
cv::Mat ReadLabelData(std::fstream& LabelDataFile, int NumberOfLabel);
// Open an MNIST image / label file, check its header and return the data;
// an empty Mat is returned when the file cannot be opened or the magic
// number does not match. The name is taken by non-const reference, so a
// named std::string (not a temporary or literal) must be passed.
cv::Mat ReadImages(std::string& FileName);
cv::Mat ReadLabels(std::string& FileName);
#endif // MNIST_H
#include "MNIST.h"
/**
 * @brief ConvertCharArrayToInt Combine a byte array into one integer,
 *        treating array[0] as the most significant byte.
 * @param array The bytes to combine.
 * @param LengthOfArray Number of bytes to use from the array.
 * @return The combined value, or -1 when array is null or LengthOfArray
 *         is not positive.
 */
int ConvertCharArrayToInt(unsigned char* array, int LengthOfArray)
{
    // A length of 0 is rejected as well: there is no array[0] to read.
    if (!array || LengthOfArray <= 0)
    {
        return -1;
    }
    // Accumulate as unsigned so shifting a set high bit upwards is well
    // defined; the value is converted to int once at the end.
    unsigned int result = 0;
    for (int i = 0; i < LengthOfArray; i++)
    {
        result = (result << 8) | array[i];
    }
    return static_cast<int>(result);
}
/**
 * @brief IsImageDataFile Check whether a magic number is the one used by
 *        MNIST image files (MAGICNUMBEROFIMAGE).
 * @param MagicNumber The bytes of the magic number to check.
 * @param LengthOfArray The number of bytes in MagicNumber.
 * @return true if the decoded value equals MAGICNUMBEROFIMAGE,
 *         false otherwise.
 */
bool IsImageDataFile(unsigned char* MagicNumber, int LengthOfArray)
{
    const int DecodedMagic = ConvertCharArrayToInt(MagicNumber, LengthOfArray);
    return DecodedMagic == MAGICNUMBEROFIMAGE;
}
/**
 * @brief IsLabelDataFile Check whether a magic number is the one used by
 *        MNIST label files (MAGICNUMBEROFLABEL).
 * @param MagicNumber The bytes of the magic number to check.
 * @param LengthOfArray The number of bytes in MagicNumber.
 * @return true if the decoded value equals MAGICNUMBEROFLABEL,
 *         false otherwise.
 */
bool IsLabelDataFile(unsigned char* MagicNumber, int LengthOfArray)
{
    const int DecodedMagic = ConvertCharArrayToInt(MagicNumber, LengthOfArray);
    return DecodedMagic == MAGICNUMBEROFLABEL;
}
/**
 * @brief ReadData Read fixed-size records from an opened file.
 * @param DataFile The file the data is read from, positioned at the first
 *        record. It is closed before returning if it was open.
 * @param NumberOfData The number of records to read.
 * @param DataSizeInBytes The size of every record in bytes.
 * @return A CV_8UC1 Mat with one record per row. An empty Mat is returned
 *         if the file is not open, a size is not positive, or fewer bytes
 *         than requested could be read.
 */
cv::Mat ReadData(std::fstream& DataFile, int NumberOfData, int DataSizeInBytes)
{
    cv::Mat DataMat;
    if (!DataFile.is_open())
    {
        return DataMat;
    }
    if (NumberOfData > 0 && DataSizeInBytes > 0)
    {
        // Read straight into the Mat's own buffer: no temporary array to
        // allocate, copy and free.
        cv::Mat Buffer(NumberOfData, DataSizeInBytes, CV_8UC1);
        // Widen before multiplying so the byte count cannot overflow int.
        const std::streamsize AllDataSizeInBytes =
            static_cast<std::streamsize>(NumberOfData) * DataSizeInBytes;
        DataFile.read(reinterpret_cast<char*>(Buffer.data), AllDataSizeInBytes);
        // Keep the data only if every requested byte arrived; a truncated
        // file would otherwise leave uninitialised bytes in the result.
        if (DataFile.gcount() == AllDataSizeInBytes)
        {
            DataMat = Buffer;
        }
    }
    DataFile.close();
    return DataMat;
}
/**
 * @brief ReadImageData Read the image records of an MNIST image file.
 * @param ImageDataFile The opened file, positioned at the first image.
 * @param NumberOfImages The number of images to read.
 * @return Whatever ReadData() returns for records of 28*28 bytes: a Mat
 *         holding one image per row.
 */
cv::Mat ReadImageData(std::fstream& ImageDataFile, int NumberOfImages)
{
    // Every image record is read as 28 * 28 bytes.
    return ReadData(ImageDataFile, NumberOfImages, 28 * 28);
}
/**
 * @brief ReadLabelData Read the label records of an MNIST label file.
 * @param LabelDataFile The opened file, positioned at the first label.
 * @param NumberOfLabel The number of labels to read.
 * @return Whatever ReadData() returns for one-byte records: a Mat holding
 *         one label per row.
 */
cv::Mat ReadLabelData(std::fstream& LabelDataFile, int NumberOfLabel)
{
    // Every label record is a single byte.
    return ReadData(LabelDataFile, NumberOfLabel, 1);
}
/**
 * @brief ReadImages Read all images from an MNIST image file.
 * @param FileName The name of the file.
 * @return A Mat with one image per row. An empty Mat is returned if the
 *         file cannot be opened, its header is incomplete, the magic
 *         number does not match, the image count is not positive, or the
 *         images are not 28x28.
 */
cv::Mat ReadImages(std::string& FileName)
{
    std::fstream File(FileName.c_str(), std::ios_base::in | std::ios_base::binary);
    if (!File.is_open())
    {
        return cv::Mat();
    }

    MNISTImageFileHeader FileHeader;
    File.read(reinterpret_cast<char*>(&FileHeader), sizeof(FileHeader));
    // A failed read leaves the header uninitialised, so test the stream
    // before looking at any field.
    if (!File || !IsImageDataFile(FileHeader.MagicNumber, 4))
    {
        return cv::Mat();
    }

    const int NumberOfImage = ConvertCharArrayToInt(FileHeader.NumberOfImages, 4);
    const int NumberOfRows = ConvertCharArrayToInt(FileHeader.NumberOfRows, 4);
    const int NumberOfColumns = ConvertCharArrayToInt(FileHeader.NumberOfColums, 4);
    // ReadImageData() reads fixed 28*28-byte records; any other geometry
    // would be split into rows incorrectly.
    if (NumberOfImage <= 0 || NumberOfRows != 28 || NumberOfColumns != 28)
    {
        return cv::Mat();
    }
    return ReadImageData(File, NumberOfImage);
}
/**
 * @brief ReadLabels Read all labels from an MNIST label file.
 * @param FileName The name of the file.
 * @return A Mat with one label per row. An empty Mat is returned if the
 *         file cannot be opened, its header is incomplete, the magic
 *         number does not match, or the label count is not positive.
 */
cv::Mat ReadLabels(std::string& FileName)
{
    std::fstream File(FileName.c_str(), std::ios_base::in | std::ios_base::binary);
    if (!File.is_open())
    {
        return cv::Mat();
    }

    MNISTLabelFileHeader FileHeader;
    File.read(reinterpret_cast<char*>(&FileHeader), sizeof(FileHeader));
    // A failed read leaves the header uninitialised, so test the stream
    // before looking at any field.
    if (!File || !IsLabelDataFile(FileHeader.MagicNumber, 4))
    {
        return cv::Mat();
    }

    const int NumberOfLabels = ConvertCharArrayToInt(FileHeader.NumberOfLabels, 4);
    if (NumberOfLabels <= 0)
    {
        return cv::Mat();
    }
    return ReadLabelData(File, NumberOfLabels);
}