基于 Tensorflow 的蘑菇分类

引言

当我们在大自然中行走的时候,经常会碰到各种各样的菌子,这时候我们就有了疑问:我们可以触碰它们吗?它们可以吃吗?如果有一个可以识别菌子的app就很棒了,so,现在让我们来实现吧~

在我们开始之前,让我们理解一些概念。计算机视觉是人工智能的一个有趣分支之一,是教模型在图像中查找信息从而理解视觉内容的艺术。当对人类(猫、狗、汽车……)进行图像分类非常简单时,机器总是很难具有竞争力,这是我们人类从小就学习的东西。计算机视觉已经走过了漫长的道路,现在有了深度学习,它的识别和人类一样好,在特定领域甚至更好。例如,在医学放射学中,可以训练人工智能来检测和分类肿瘤,并且通常比人类有更好的结果。

计算机视觉的第一步是图像检测。图像检测是在给定的图像中找到图像中的特定对象,并返回其坐标或包围盒。

图像分类是当你给出一个物体的图像时,你的模型以概率和置信率返回一个类。因此,我们的模型应该首先检测对象,然后根据它所训练的类型对它们进行分类。为此,我们通常使用 CNN(卷积神经网络)。

图像识别是当您给模型一个图像与多个对象。该模型为图像中的每个物体提供了它的边界框(目标检测)和类的预测,并给出了置信率。

现在我们遇到的是多目标的图像分类问题。

收集数据

为了训练一个模型,你需要好的标记数据,如果这一步出现了错误,后面所有的步骤都将徒劳无功。现在我们用的是 Kaggle 的真菌数据集,这是一个非常好的数据集,有1394个类可以在这里使用。数据集的链接如下:https://www.kaggle.com/c/fungi-challenge-fgvc-2018。

数据处理

Tensorflow 为我们提供了一个很便利的API,即 tf.data.dataset。我们可以很方便的用一行代码创建一个有效的数据集,让我们来看看吧~

 data_dir = '/Mydirectory/images/'
    img_height = 256
    img_width = 256
    batch_size = 32
    
    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset="training",
        seed=123,
        image_size=(img_height, img_width),
        batch_size=batch_size)
    
    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset="validation",
        seed=123,
        image_size=(img_height, img_width),
        batch_size=batch_size)


    class_names = train_ds.class_names

以下是我们将使用的 10 个类:

[‘11082_Xerocomellus_chrysenteron’, ‘12919_Cylindrobasidium_laeve’, ‘14064_Fomitopsis_pinicola’, ‘14160_Ganoderma_pfeifferi’, ‘17233_Mycena_galericulata’, ‘20983_Trametes_versicolor’, ‘21143_Tricholoma_scalpturatum’, ‘40392_Armillaria_lutea’, ‘40985_Byssomerulius_corium’, ‘61207_Coprinellus_micaceus’]

基于 Tensorflow 的蘑菇分类_第1张图片

让我们设置数据集性能

    #################################################
    # Dataset Performance
    ##################################################
    AUTOTUNE = tf.data.experimental.AUTOTUNE
    train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
    val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

迁移学习模式

开始创建自己的 CNN,但效果不佳。我不是那么有耐心去改进它,我选择进行迁移学习。迁移学习是重用在更大数据集上训练的模型的能力,这些模型已经学习了多个特征。为此,我们冻结顶层并使用新类重新训练,权重可重复使用。所以让我们用这些预先训练好的模型来帮助自己。我使用了 MobileNetV2 模型,因为它非常轻巧,在我的 GPU 上运行只需几秒钟。

为了提高准确性,我增加了一个独特的步骤,那就是数据增强。数据增强是: 对于一个标记为图像的输入,您可以缩放或翻转它,并将其作为模型的输入添加。这有助于模型继续识别对象,即使它并不总是处于相同的位置。

#################################################
    # Data Augmentation
    ##################################################
    data_augmentation = tf.keras.Sequential([
        tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
        tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
        tf.keras.layers.experimental.preprocessing.RandomZoom(0.1),
    ])
    #################################################
    # CREATE THE MODEL
    ##################################################
    num_classes = 10
    preprocess_input_mobilenet_v2 = tf.keras.applications.mobilenet_v2.preprocess_input


    base_model = tf.keras.applications.MobileNetV2(input_shape=(256, 256, 3),
                                                   include_top=False,
                                                   weights='imagenet')
    
      
    base_model.trainable = False
    
    image_batch, label_batch = next(iter(train_ds))
    feature_batch = base_model(image_batch)


    global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
    feature_batch_average = global_average_layer(feature_batch)
    prediction_layer = tf.keras.layers.Dense(num_classes, kernel_regularizer=tf.keras.regularizers.l2(0.001))
    prediction_batch = prediction_layer(feature_batch_average)
    inputs = tf.keras.Input(shape=(256, 256, 3))
    x = data_augmentation(inputs)
    x = preprocess_input_mobilenet_v2(x)
    x = base_model(x, training=False)
    x = global_average_layer(x)
    x = tf.keras.layers.Dropout(0.2)(x)
    outputs = prediction_layer(x)
    model = tf.keras.Model(inputs, outputs)


    #################################################
    # COMPILE THE MODEL
    ##################################################
    #
    model.compile(optimizer='adam',
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])


    #################################################
    # TRAIN THE MODEL
    ##################################################


    epochs = 10
    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs
    )

结果如下

基于 Tensorflow 的蘑菇分类_第2张图片

现在让我们来预测下,代码如下所示:

# #################################################
# # LOAD THE MODEL
# ##################################################
model = tf.keras.models.load_model('MobileNetV2_Ep20')


# #################################################
# # Predictions
# ##################################################


img_url = "https://www.mycodb.fr/photos/Xerocomellus_chrysenteron_2014_rp_1.jpg"
img_path = tf.keras.utils.get_file('mushroom_image', origin=img_url)


img = tf.keras.preprocessing.image.load_img(
    img_path, target_size=(256, 256, 3)
)


img_array = tf.keras.preprocessing.image.img_to_array(img)
img_array = tf.expand_dims(img_array, 0)  # Create a batch


predictions = model.predict(img_array)


predictions_sigmoid = tf.nn.sigmoid(predictions)
score = tf.nn.softmax(predictions[0])


print(
    "This image most likely belongs to {} with a {:.2f} percent confidence."
    .format(class_names[np.argmax(score)], 100 * np.max(score))
)


预测结果:

结果还是相当不错的吧~

总结

在本文中,我们了解了如何使用 tensorflow 训练一个用于分类菌子的模型,下一步我们就可以将它移植到移动端,想想还是很兴奋的呢~

·  END  ·

HAPPY LIFE

基于 Tensorflow 的蘑菇分类_第3张图片

你可能感兴趣的:(神经网络,深度学习,机器学习,人工智能,计算机视觉)