从零基础入门Tensorflow2.0 ----六、32cifar10数据训练

every blog every motto:

0. 前言

cifar10 训练

1. 代码部分

1. 导入模块

%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn
import os,sys
import tensorflow as tf
import time
from tensorflow import keras
os.environ['CUDA_VISIBLE_DEVICES'] = '/gpu:0'
print(tf.__version__)
print(sys.version_info)
for module in mpl,pd,sklearn,tf,keras:
    print(module.__name__,module.__version__)

2. 读取数据

# The ten CIFAR-10 categories, ordered by their label index (0-9).
class_names = 'airplane automobile bird cat deer dog frog horse ship truck'.split()

# Kaggle CIFAR-10 layout: a labels csv for training, a sample-submission csv
# standing in for test labels, and one folder of png images per split.
train_labels_file = './cifar10/trainLabels.csv'
test_csv_file = './cifar10/sampleSubmission.csv'
train_folder = './cifar10/train'
test_folder = './cifar10/test'

def parse_csv_file(filepath, folder):
    """Parse a Kaggle-style label csv into (image_path, label) pairs.

    Parameters
    ----------
    filepath : str
        Path to a csv whose first row is a header and whose data rows
        look like ``<image_id>,<label>``.
    folder : str
        Directory holding the images; each id maps to ``<folder>/<id>.png``.

    Returns
    -------
    list[tuple[str, str]]
        One ``(full_image_path, label_string)`` tuple per data row.
    """
    import csv  # local import keeps this function self-contained

    results = []
    # newline='' lets the csv module normalise line endings itself; the
    # original strip('\n') left a trailing '\r' in labels on CRLF files.
    with open(filepath, 'r', newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row
        for row in reader:
            if not row:
                continue  # tolerate trailing blank lines
            image_id, label_str = row[0], row[1]
            results.append((os.path.join(folder, image_id + '.png'), label_str))
    return results


# Materialise the (filepath, label) pairs for both splits.
# NOTE(review): requires ./cifar10/ to exist with the Kaggle files in place.
train_labels_info = parse_csv_file(train_labels_file,train_folder)
test_csv_info = parse_csv_file(test_csv_file,test_folder)

import pprint
# Peek at a few parsed entries and report the dataset sizes.
pprint.pprint(train_labels_info[0:5])
pprint.pprint(test_csv_info[0:5])
print(len(train_labels_info),len(test_csv_info))

2.2 划分数据

# Hold out the last 5,000 labelled images for validation; the test frame
# carries the sample-submission placeholder labels.
train_df = pd.DataFrame(train_labels_info[:45000], columns=['filepath', 'class'])
valid_df = pd.DataFrame(train_labels_info[45000:], columns=['filepath', 'class'])
test_df = pd.DataFrame(test_csv_info, columns=['filepath', 'class'])

# Sanity-check the three frames.
for frame in (train_df, valid_df, test_df):
    print(frame.head())

3. 读取图片

# Image geometry and training hyper-parameters shared by all generators.
height = 32
width = 32
channels = 3
batch_size = 32
num_classes = 10

# Training-time augmentation: rescale to [0,1] plus random geometric jitter.
train_datagen = keras.preprocessing.image.ImageDataGenerator(
rescale = 1. / 255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range = 0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip = True,
    fill_mode = 'nearest',
)
# class_mode='sparse' yields integer labels (matches the sparse loss below);
# classes=class_names pins the label-to-index mapping to our list order.
train_generator = train_datagen.flow_from_dataframe(train_df,directory='./',x_col='filepath',y_col='class',classes=class_names,
                                                    target_size=(height,width),batch_size=batch_size,seed=7,shuffle=True,
                                                    class_mode='sparse',)


# Validation uses only rescaling — no augmentation — and no shuffling so
# predictions line up with valid_df rows.
valid_datagen = keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
valid_generator = valid_datagen.flow_from_dataframe(valid_df,directory='./',x_col='filepath',y_col='class',classes=class_names,
                                                    target_size=(height,width),batch_size=batch_size,seed=7,shuffle=False,
                                                    class_mode="sparse")

train_num = train_generator.samples
valid_num = valid_generator.samples
print(train_num,valid_num)
# Pull two sample batches to sanity-check tensor shapes and label values.
for i in range(2):
    x,y = train_generator.next()
    print(x.shape,y.shape)
    print(y)

4. 模型搭建

# VGG-style CNN: three stages of (Conv-BN x2 -> MaxPool) with channel widths
# 128 -> 256 -> 512, then a 512-unit dense head and a softmax over the classes.
model = keras.models.Sequential()
for n_filters in (128, 256, 512):
    for _ in range(2):
        # Only the very first layer declares the input shape.
        extra = {'input_shape': [width, height, channels]} if not model.layers else {}
        model.add(keras.layers.Conv2D(filters=n_filters, kernel_size=3,
                                      padding='same', activation='relu', **extra))
        model.add(keras.layers.BatchNormalization())
    model.add(keras.layers.MaxPool2D(pool_size=2))

# Classification head.
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(512, activation='relu'))
model.add(keras.layers.Dense(num_classes, activation='softmax'))

# Sparse loss matches the integer labels produced by class_mode='sparse'.
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()

5. 训练

epochs = 20
# Model.fit accepts Python generators directly in TF2; fit_generator is
# deprecated and removed in newer releases.
history = model.fit(
    train_generator,
    steps_per_epoch=train_num // batch_size,
    epochs=epochs,
    validation_data=valid_generator,
    validation_steps=valid_num // batch_size,
)

6. 学习曲线

# Learning curves
def plot_learning_curves(history, label, epochs, min_value, max_value):
    """Plot the training and validation curve for one tracked metric.

    The original signature misspelled the parameter as ``hsitory`` while the
    body referenced ``history`` — which only worked because it accidentally
    resolved to the module-level variable.

    Parameters
    ----------
    history : keras History object (has a ``.history`` dict of metric lists).
    label : str, metric key, e.g. 'accuracy' or 'loss'.
    epochs : int, x-axis upper bound.
    min_value, max_value : floats, y-axis bounds.
    """
    data = {
        label: history.history[label],
        'val_' + label: history.history['val_' + label],
    }
    pd.DataFrame(data).plot(figsize=(8, 5))
    plt.grid(True)
    plt.axis([0, epochs, min_value, max_value])
    plt.show()

plot_learning_curves(history, 'accuracy', epochs, 0, 1)
plot_learning_curves(history, 'loss', epochs, 1.5, 2.5)

7. 测试集上预测

# Test-time pipeline mirrors validation: rescale only, no shuffling so each
# prediction row maps back to the same row of test_df.
test_datagen = keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255)
test_generator = test_datagen.flow_from_dataframe(test_df,directory='./',x_col='filepath',y_col='class',classes=class_names,
                                                    target_size=(height,width),batch_size=batch_size,seed=7,shuffle=False,
                                                    class_mode="sparse")
test_num = test_generator.samples
print(test_num)
# predict_generator is deprecated in TF2; Model.predict takes generators.
# use_multiprocessing: True = worker processes, False = worker threads.
test_predict = model.predict(test_generator, workers=10, use_multiprocessing=False)
print(test_predict.shape)
print(test_predict[0:5])
# Collapse the (num_samples, 10) probability matrix to class indices, then
# map each index back to its human-readable class name.
test_predict_class_indices = np.argmax(test_predict, axis=1)
print(test_predict_class_indices[0:5])
test_predict_class = [class_names[index] for index in test_predict_class_indices]
print(test_predict_class[0:5])
def generate_submissions(filename, predict_class):
    """Write a Kaggle submission csv: header ``id,label`` then 1-based rows.

    Fixes two defects in the original: the parameter was misspelled
    ``fielname`` while the body used the undefined name ``filename``
    (a guaranteed NameError), and ``rangelen((...))`` was a typo for
    ``range(len(...))``.

    Parameters
    ----------
    filename : str, output csv path.
    predict_class : sequence of str, predicted class name per test image.
    """
    with open(filename, 'w') as f:
        f.write('id,label\n')
        # Kaggle ids start at 1, hence enumerate(..., start=1).
        for i, label in enumerate(predict_class, start=1):
            f.write('%d,%s\n' % (i, label))

# Write the final submission file next to the dataset.
# NOTE(review): assumes ./cifar10/ exists and is writable.
output_file = './cifar10/submission.csv'
generate_submissions(output_file,test_predict_class)

你可能感兴趣的:(Tensorflow2.0)