cnn 实现图片识别

在入门之后需要对机器学习的一些思维和方法体验下2333

一个github上的一些源码变式,体验一下现在深度学习识别图片的速度之快,准确度之好。

使用了Tensorflow里的高级神经网络库,使用其keras.applications模块获取在ILSVRC竞赛中获胜的多个卷积网络模型,可识别物体量从10类增加到1001类,可为:狗熊 椅子 汽车 键盘 箱子 婴儿床 旗杆iPod播放器 轮船 面包车 项链 降落伞 桌子 钱包 球拍 步枪等等
接着导入ResNet50网络模型进行处理,主要图像数据处理函数如下:


image.img_to_array:将PIL格式的图像转换为numpy数组。


np.expand_dims:将我们的(3,224,224)大小的图像转换为(1,3,224,224)。因为model.predict函数需要4维数组作为输入,其中第4维为每批预测图像的数量。这也就是说,我们可以一次性分类多个图像。


preprocess_input:使用训练数据集中的平均通道值对图像数据进行零值处理,即使得图像所有点的和为0。这是非常重要的步骤,如果跳过,将大大影响实际预测效果。这个步骤称为数据归一化。


model.predict:对我们的数据分批处理并返回预测值。


decode_predictions:采用与model.predict函数相同的编码标签,并从ImageNet ILSVRC集返回可读的标签。
然后通过调用官方api,得到了这个classfier,可以读取同目录下images文件夹里的指定命名的图片,这个分类器在我的笔记本tensorflow中训练差不多需要两个小时的时间,训练好之后实际识别的速率快达1~2分钟一张图,准确率高达90%以上

代码:

GUI:(PyQt5)

from PyQt5 import QtWidgets
from PyQt5.QtWidgets import QFileDialog
from PyQt5 import QtCore, QtGui
import classify
class MyWindow(QtWidgets.QWidget):
    def __init__(self):
        super(MyWindow, self).__init__()
        self.setObjectName("widget")
        self.resize(490, 506)
        self.setMinimumSize(QtCore.QSize(100, 100))
        self.setCursor(QtGui.QCursor(QtCore.Qt.ArrowCursor))
        self.gridLayoutWidget = QtWidgets.QWidget(self)
        self.gridLayoutWidget.setGeometry(QtCore.QRect(60, 120, 381, 301))
        self.gridLayoutWidget.setObjectName("gridLayoutWidget")
        self.gridLayout = QtWidgets.QGridLayout(self.gridLayoutWidget)
        self.gridLayout.setContentsMargins(0, 0, 0, 0)
        self.gridLayout.setObjectName("gridLayout")
        self.label = QtWidgets.QLabel(self)
        self.label.setGeometry(QtCore.QRect(70, 50, 54, 12))
        self.label.setObjectName("label")
        self.textEdit = QtWidgets.QTextEdit(self)
        self.textEdit.setGeometry(QtCore.QRect(120, 45, 261, 25))
        self.textEdit.setObjectName("textEdit")
        self.toolButton = QtWidgets.QToolButton(self)
        self.toolButton.setGeometry(QtCore.QRect(379, 43, 50, 28))
        self.toolButton.setObjectName("toolButton")
        self.toolButton.clicked.connect(self.msg)
        self.pushButton = QtWidgets.QPushButton(self)
        self.pushButton.setGeometry(QtCore.QRect(200, 80, 81, 31))
        self.pushButton.setObjectName("pushButton")
        self.pushButton.clicked.connect(self.sbing)
        #  放图片的label
        self.label2 = QtWidgets.QLabel(self)
        self.label2.setGeometry(QtCore.QRect(72, 150, 360, 300))
        #  参数分别是左上点距左边框宽度,距顶高度,长度,高度
        self.label2.setObjectName("label2")

        self.retranslateUi(self)
        QtCore.QMetaObject.connectSlotsByName(self)
    def retranslateUi(self, widget):
        _translate = QtCore.QCoreApplication.translate
        widget.setWindowTitle(_translate("widget", "图片识别器"))
        self.label.setText(_translate("widget", "目标图片"))
        self.toolButton.setText(_translate("widget", "浏览"))
        self.pushButton.setText(_translate("widget", "开始识别"))
    def msg(self):
        '''directory1 = QFileDialog.getExistingDirectory(self,
                                                      "选取文件夹",
                                                      "C:/")  # 起始路径
        print(directory1)'''

        fileName1, filetype = QFileDialog.getOpenFileName(self,
                                                          "选取文件",
                                                          "C:/",
                                                          "All Files (*);;Text Files (*.txt)")  # 设置文件扩展名过滤,注意用双分号间隔
        #  print(fileName1, filetype)
        #  print(fileName1)
        '''files, ok1 = QFileDialog.getOpenFileNames(self,
                                                  "多文件选择",
                                                  "C:/",
                                                  "All Files (*);;Text Files (*.txt)")
        print(files, ok1)

        fileName2, ok2 = QFileDialog.getSaveFileName(self,
                                                     "文件保存",
                                                     "C:/",
                                                     "All Files (*);;Text Files (*.txt)")
        '''
        png = QtGui.QPixmap(fileName1).scaled(self.label2.width(), self.label2.height())
        self.label2.setPixmap(png)
        self.textEdit.setText(fileName1)
        classify.imgf=fileName1
    def sbing(self):
        self.pushButton.setText("识别中")
        classify.sjsy()
        self.pushButton.setText("开始识别")
if __name__ == "__main__":
    import sys
    app = QtWidgets.QApplication(sys.argv)
    myshow = MyWindow()
    myshow.show()
    sys.exit(app.exec_())
    exit()
classify

import sys
import argparse
import numpy as np
from PIL import Image
import requests
from io import BytesIO
import matplotlib.pyplot as plt

from keras.preprocessing import image
from keras.applications.resnet50 import ResNet50, preprocess_input, decode_predictions

model = ResNet50(weights='imagenet')
target_size = (224, 224)
imgf = ""

def predict(model, img, target_size, top_n=3):
    """Run model prediction on image
  Args:
    model: keras model
    img: PIL format image
    target_size: (w,h) tuple
    top_n: # of top predictions to return
  Returns:
    list of predicted labels and their probabilities
  """
    if img.size != target_size:
        img = img.resize(target_size)

    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)
    preds = model.predict(x)
    return decode_predictions(preds, top=top_n)[0]


def plot_preds(image, preds):
    """Displays image and the top-n predicted probabilities in a bar graph
  Args:
    image: PIL image
    preds: list of predicted labels and their probabilities
  """
    plt.imshow(image)
    plt.axis('off')

    plt.figure()
    order = list(reversed(range(len(preds))))
    bar_preds = [pr[2] for pr in preds]
    labels = (pr[1] for pr in preds)
    plt.barh(order, bar_preds, alpha=0.5)
    plt.yticks(order, labels)
    plt.xlabel('Probability')
    plt.xlim(0, 1.01)
    plt.tight_layout()
    plt.show()


def sjsy():
    print(imgf)
    img = Image.open(imgf)
    preds = predict(model, img, target_size)
    plot_preds(img, preds)


'''if __name__=="__main__":
  img = Image.open("images/3.jpg")
  preds = predict(model, img, target_size)
  plot_preds(img, preds)
  a = argparse.ArgumentParser()
  a.add_argument("--image", help="path to image")
  a.add_argument("--image_url", help="url to image")
  args = a.parse_args()

  if args.image is None and args.image_url is None:
    a.print_help()
    sys.exit(1)

  if args.image is not None:
    img = Image.open(args.image)
    preds = predict(model, img, target_size)
    plot_preds(img, preds)

  if args.image_url is not None:
    response = requests.get(args.image_url)
    img = Image.open(BytesIO(response.content))
    preds = predict(model, img, target_size)
    plot_preds(img, preds)



实验对象及运行结果
分别拿水瓶,汽车,大象,电脑等对象做了识别测试,发现实验结果非常令人满意

cnn 实现图片识别_第1张图片

cnn 实现图片识别_第2张图片

cnn 实现图片识别_第3张图片

你可能感兴趣的:(机器学习)