简单CNN卷积神经网络搭建 TensorFlow

        以下文章中用到的的卷积网络仅用于分类操作

        已知有训练集X_training、Y_training,和验证集X_test、Y_test。其中Y_training与Y_test两个标签数据均为one-hot类型数据即【0,0,1,0,0】这种数据。如果你的标签数据是int型整数的话,可以用如下代码改写为one-hot类型。(比如在一个分类任务中一共有5种类别,其中类别3将会转换为【0,0,1,0,0】)

Y_train = tf.keras.utils.to_categorical(Y_train, num)#num为类别总数
Y_test = tf.keras.utils.to_categorical(Y_test, num)

        下面我们来创建一个简单的CNN网络,该网络结构如下图所示:

简单CNN卷积神经网络搭建 TensorFlow_第1张图片

        模型的代码实现:

input_features=tf.keras.layers.Conv2D(filters=64,
                                      kernel_size=[4,4],
                                      padding='same',
                                      activation=tf.nn.relu)(input_1)#input_1为该层输入,input_features为该层输出,切记不要弄混

input_features= tf.keras.layers.MaxPool2D(pool_size=[2,2],strides=2)

input_features=tf.keras.layers.Conv2D(filters=64,
                                      kernel_size=[4,4],
                                      padding='same',
                                      activation=tf.nn.relu)(input_1)

input_features= tf.keras.layers.MaxPool2D(pool_size=[2,2],strides=2)

output = tf.keras.layers.Dense(num,name="output", activation='softmax')(input_features)#num为类别总数

        如此网络便构建完成。

        下面进行模型定义:

model = tf.keras.Model(
        inputs=input1,
        outputs=output)

        这一步相当于把上方自己写的模型封装进model中,当模型定义完成后,对模型的后续操作只需要调用变量model即可。

        模型编译:

optimizer=tf.keras.optimizers.Adam(learning_rate=0.01)#定义学习率
model.compile(optimizer=optimizer,#导入上面的学习率
              loss='categorical_crossentropy',#选择损失函数
              metrics=['accuracy']#在训练时输出对训练集的精确度(可删)
             )

         模型训练并输出结果:

model.fit(
        X_training,
        Y_train,#导入训练集
        epochs=2000,#迭代次数设置为2000
        batch_size=32,#每次送入32个数据进行训练(如果少于32则会直接带入当前全部)
        validation_data=(X_test, Y_test),#当训练完成后直接对验证集进行测试并输出结果
        shuffle=True#打乱数据集顺序,防止过拟合
        )

        总的来说用TensorFlow创建简单CNN网络还是比较简单的,但是在输入数据时要注意数据的shape,如果报错的话可以用numpy.expand_dims()函数将输入数据的形状更改为模型需要的形状。总之多调试,多研究报错信息就可以跑通程序。

你可能感兴趣的:(tensorflow,cnn,人工智能)