引言
当我们在大自然中行走的时候,经常会碰到各种各样的菌子,这时候我们就有了疑问:我们可以触碰它们吗?它们可以吃吗?如果有一个可以识别菌子的app就很棒了,so,现在让我们来实现吧~
在我们开始之前,让我们理解一些概念。计算机视觉是人工智能的一个有趣分支之一,是教模型在图像中查找信息从而理解视觉内容的艺术。当对人类(猫、狗、汽车……)进行图像分类非常简单时,机器总是很难具有竞争力,这是我们人类从小就学习的东西。计算机视觉已经走过了漫长的道路,现在有了深度学习,它的识别和人类一样好,在特定领域甚至更好。例如,在医学放射学中,可以训练人工智能来检测和分类肿瘤,并且通常比人类有更好的结果。
计算机视觉的第一步是图像检测。图像检测是在给定的图像中找到图像中的特定对象,并返回其坐标或包围盒。
图像分类是当你给出一个物体的图像时,你的模型以概率和置信率返回一个类。因此,我们的模型应该首先检测对象,然后根据它所训练的类型对它们进行分类。为此,我们通常使用 CNN(卷积神经网络)。
图像识别是当您给模型一个图像与多个对象。该模型为图像中的每个物体提供了它的边界框(目标检测)和类的预测,并给出了置信率。
现在我们遇到的是多目标的图像分类问题。
收集数据
为了训练一个模型,你需要好的标记数据,如果这一步出现了错误,后面所有的步骤都将徒劳无功。现在我们用的是 Kaggle 的真菌数据集,这是一个非常好的数据集,有1394个类可以在这里使用。数据集的链接如下:https://www.kaggle.com/c/fungi-challenge-fgvc-2018。
数据处理
Tensorflow 为我们提供了一个很便利的API,即 tf.data.dataset。我们可以很方便的用一行代码创建一个有效的数据集,让我们来看看吧~
data_dir = '/Mydirectory/images/'
img_height = 256
img_width = 256
batch_size = 32
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="training",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="validation",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size)
class_names = train_ds.class_names
以下是我们将使用的 10 个类:
[‘11082_Xerocomellus_chrysenteron’, ‘12919_Cylindrobasidium_laeve’, ‘14064_Fomitopsis_pinicola’, ‘14160_Ganoderma_pfeifferi’, ‘17233_Mycena_galericulata’, ‘20983_Trametes_versicolor’, ‘21143_Tricholoma_scalpturatum’, ‘40392_Armillaria_lutea’, ‘40985_Byssomerulius_corium’, ‘61207_Coprinellus_micaceus’]
让我们设置数据集性能
#################################################
# Dataset Performance
##################################################
AUTOTUNE = tf.data.experimental.AUTOTUNE
train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
迁移学习模式
开始创建自己的 CNN,但效果不佳。我不是那么有耐心去改进它,我选择进行迁移学习。迁移学习是重用在更大数据集上训练的模型的能力,这些模型已经学习了多个特征。为此,我们冻结顶层并使用新类重新训练,权重可重复使用。所以让我们用这些预先训练好的模型来帮助自己。我使用了 MobileNetV2 模型,因为它非常轻巧,在我的 GPU 上运行只需几秒钟。
为了提高准确性,我增加了一个独特的步骤,那就是数据增强。数据增强是: 对于一个标记为图像的输入,您可以缩放或翻转它,并将其作为模型的输入添加。这有助于模型继续识别对象,即使它并不总是处于相同的位置。
#################################################
# Data Augmentation
##################################################
data_augmentation = tf.keras.Sequential([
tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
tf.keras.layers.experimental.preprocessing.RandomZoom(0.1),
])
#################################################
# CREATE THE MODEL
##################################################
num_classes = 10
preprocess_input_mobilenet_v2 = tf.keras.applications.mobilenet_v2.preprocess_input
base_model = tf.keras.applications.MobileNetV2(input_shape=(256, 256, 3),
include_top=False,
weights='imagenet')
base_model.trainable = False
image_batch, label_batch = next(iter(train_ds))
feature_batch = base_model(image_batch)
global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
feature_batch_average = global_average_layer(feature_batch)
prediction_layer = tf.keras.layers.Dense(num_classes, kernel_regularizer=tf.keras.regularizers.l2(0.001))
prediction_batch = prediction_layer(feature_batch_average)
inputs = tf.keras.Input(shape=(256, 256, 3))
x = data_augmentation(inputs)
x = preprocess_input_mobilenet_v2(x)
x = base_model(x, training=False)
x = global_average_layer(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = prediction_layer(x)
model = tf.keras.Model(inputs, outputs)
#################################################
# COMPILE THE MODEL
##################################################
#
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
#################################################
# TRAIN THE MODEL
##################################################
epochs = 10
history = model.fit(
train_ds,
validation_data=val_ds,
epochs=epochs
)
结果如下
现在让我们来预测下,代码如下所示:
# #################################################
# # LOAD THE MODEL
# ##################################################
model = tf.keras.models.load_model('MobileNetV2_Ep20')
# #################################################
# # Predictions
# ##################################################
img_url = "https://www.mycodb.fr/photos/Xerocomellus_chrysenteron_2014_rp_1.jpg"
img_path = tf.keras.utils.get_file('mushroom_image', origin=img_url)
img = tf.keras.preprocessing.image.load_img(
img_path, target_size=(256, 256, 3)
)
img_array = tf.keras.preprocessing.image.img_to_array(img)
img_array = tf.expand_dims(img_array, 0) # Create a batch
predictions = model.predict(img_array)
predictions_sigmoid = tf.nn.sigmoid(predictions)
score = tf.nn.softmax(predictions[0])
print(
"This image most likely belongs to {} with a {:.2f} percent confidence."
.format(class_names[np.argmax(score)], 100 * np.max(score))
)
预测结果:
结果还是相当不错的吧~
总结
在本文中,我们了解了如何使用 tensorflow 训练一个用于分类菌子的模型,下一步我们就可以将它移植到移动端,想想还是很兴奋的呢~
· END ·
HAPPY LIFE