CNN卷积神经网络 入门案例

数据介绍

Clifar 10 数据集

5w张 32x32 的图片 训练
1w张 32x32 的图片 测试

输入是分为10个标签,下面的图的左边已经给出了。
CNN卷积神经网络 入门案例_第1张图片

导入数据集

CNN卷积神经网络 入门案例_第2张图片

可视化一张图片看看
打印出 x的第一张图片的像素点看看
打印出 x的第一张图片对应的输出分类结果
查看 测试集的整体大小 1000张 32x32像素 3通道的图片集合

CNN卷积神经网络 入门案例_第3张图片

搭建卷积神经网络

口诀:CBAPD
一层卷积:5x5卷积核,一共有6个
2x2的卷积池,步长为2
两层全连接
第一次层128个神经元
第二层 10分类问题,输出10神经元的全连接层
CNN卷积神经网络 入门案例_第4张图片

下图是详细解释
CNN卷积神经网络 入门案例_第5张图片
搭建网络

class Baseline(Model):
    def __init__(self):
        super(Baseline,self).__init__()
        self.c1 = Conv2D(filters = 6, kernel_size =(5,5), padding = "same")
        self.b1 = BatchNormalization()
        self.a1 = Activation('relu')
        self.p1 = MaxPool2D(pool_size=(2,2),strides=2,padding='same')
        self.d1 = Dropout(0.2)
        
        self.flatten = Flatten()
        self.f1 = Dense(128,activation='relu') # 全连接网络
        self.d2 = Dropout(0.2) # 20比例休眠神经元
        self.f2 = Dense(10,activation='softmax')
    def call(self,x):
        x = self.c1(x)
        x = self.b1(x)
        x = self.a1(x)
        x = self.p1(x)
        x = self.d1(x)
        x = self.flatten(x)
        x = self.f1(x)
        x = self.d2(x)
        y = self.f2(x)
        return y
model = Baseline()

模型配置参数

# 模型配置参数
model.compile(optimizer=keras.optimizers.Adam(),
              loss='sparse_categorical_crossentropy', 
              metrics=['accuracy'])

指定模型输入

# 指定模型的输入
model.build(input_shape=[None, 32, 32, 3])  # 指定输入

查看一张图片的形状和通道情况

x[0].shape

CNN卷积神经网络 入门案例_第6张图片

自己创建一个x
看一下通过模型的输出是否符合

# 检查输入通道是否正常
# 创建一个图片数据
# 1是指输入一张图像,两个32是图像长宽,3是指3通道
x = tf.random.normal([1,32,32,3]) 
out = model(x)
out.shape

CNN卷积神经网络 入门案例_第7张图片

给模型导入数据,进行训练
train_db 是打包了的训练集
validation_data 是打包了的测试集

# 训练模型
history = model.fit(train_db,epochs=5,validation_data=test_db)

查看模型结构

model.summary()

查看模型的检验数据
比如 loss 或者 acc

history.history()

绘制 loss 或acc随着迭代次数的变化图
如果有loss的话就绘制loss的变化

loss = history.history['loss']
val_loss = history.history['val_loss']
# 绘制loss的图
plt.figure(figsize=(20,10))
plt.subplot(1,2,1)
plt.plot(loss,label='Trainning loss')
plt.plot(val_loss,label='Validation loss')
plt.legend()
plt.grid()
plt.title('loss')

如果有acc就绘制acc的变化

经典卷积网络

CNN卷积神经网络 入门案例_第8张图片

你可能感兴趣的:(机器学习,深度学习,tensorflow,人工智能)