Faster R-CNN Inception-ResNet-v2 物品识别比赛 Demo 记录

使用 TensorFlow Models 中的 Object Detection API 进行训练。由于比赛没有时间限制，选用了速度较慢但精度较高的 Faster R-CNN Inception-ResNet-v2 模型来识别 10 个类别；在 GTX 1080 Ti 上训练了 5 个小时，在 GTX 1050 上测试，大约 2 秒处理一张图片。
标签映射文件 pascal_label_map.pbtxt 的定义如下（共 10 个类别，id 从 1 开始）：

item {
  id: 1
  name: 'cola'
}

item {
  id: 2
  name: 'milk tea'
}

item {
  id: 3
  name: 'ice tea'
}
item {
  id: 4
  name: 'beer'
}
item {
  id: 5
  name: 'shampoo'
}
item {
  id: 6
  name: 'toothpaste'
}
item {
  id: 7
  name: 'soap'
}
item {
  id: 8
  name: 'pear'
}
item {
  id: 9
  name: 'apple'
}
item {
  id: 10
  name: 'orange'
}

测试代码

#-*-coding:utf-8-*-
import sys
import argparse
from PIL import Image
import os
import cv2
import numpy as np
import speech_recognition as sr

import wave
import requests
import time
import base64
from pyaudio import PyAudio, paInt16
import webbrowser
import serial
import speech
import numpy as np
import os
import sys
import tensorflow as tf
from PIL import Image
sys.path.append("..")
from utils import label_map_util
from utils import visualization_utils as vis_util
import cv2
from timeit import default_timer as timer

# Audio capture settings shared by save_wave_file() and my_record().
framerate = 16000  # sample rate (Hz)
num_samples = 2000  # frames per buffer read from the input stream
channels = 1  # mono
sampwidth = 2  # sample width in bytes (16-bit PCM)
FILEPATH = 'speech.wav'  # where my_record() writes the recorded clip

# Baidu OAuth token endpoint; fill in APIKey / SecretKey before use.
base_url = "https://openapi.baidu.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s"
APIKey = "***"
SecretKey = "***"

HOST = base_url % (APIKey, SecretKey)

# Exported frozen detector graph and its label map.
PATH_TO_CKPT = 'F:/python_project/比赛' + '/frozen_inference_graph.pb'
PATH_TO_LABELS = 'F:/python_project/比赛/pascal_label_map.pbtxt'

# Number of classes in pascal_label_map.pbtxt (ids 1-10). convert_label_map_to_categories
# ignores label ids above this bound, so it must match the label map.
NUM_CLASSES = 10

# Load the frozen GraphDef into its own graph so detect() can open a session on it.
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')

# Map class ids to {'id', 'name'} entries used when drawing labels.
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES,
                                                            use_display_name=True)
category_index = label_map_util.create_category_index(categories)
print(category_index)

def detect(camera_index=1, min_objects=5):
    """Run the frozen detector on a live camera feed until enough objects are seen.

    Each frame is processed as follows:

    - It is converted BGR -> RGB, run through ``detection_graph``, and annotated
      with ``vis_util.visualize_boxes_and_labels_on_image_array_``. That is a
      customised helper from the local ``utils`` package that returns the
      annotated image plus a count ``num``.
    - When ``num >= min_objects``, the annotated frame is saved as
      ``wxy-TRY-<HHMM>.jpg`` in the working directory, "识别完成" is spoken,
      and the loop stops.
    - Per-frame inference time is printed.

    Args:
        camera_index: OpenCV capture index (default 1, the originally hard-coded camera).
        min_objects: value of ``num`` at or above which detection ends (default 5).
    """
    with detection_graph.as_default():
        with tf.Session(graph=detection_graph) as sess:
            # Graph tensors never change, so resolve them once instead of on every frame.
            image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
            fetches = [
                detection_graph.get_tensor_by_name('detection_boxes:0'),
                detection_graph.get_tensor_by_name('detection_scores:0'),
                detection_graph.get_tensor_by_name('detection_classes:0'),
                detection_graph.get_tensor_by_name('num_detections:0'),
            ]
            cap = cv2.VideoCapture(camera_index)
            try:
                while True:
                    start = timer()
                    ok, frame = cap.read()
                    if not ok or frame is None:
                        # A failed read yields frame=None, which would crash cvtColor below.
                        print("读取摄像头画面失败,停止检测")
                        break
                    # cvtColor returns a fresh RGB array; it is also the canvas the visualiser draws on.
                    image_np = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    boxes, scores, classes, _ = sess.run(
                        fetches,
                        feed_dict={image_tensor: np.expand_dims(image_np, axis=0)})
                    end = timer()
                    image_np, num = vis_util.visualize_boxes_and_labels_on_image_array_(
                        image_np, np.squeeze(boxes),
                        np.squeeze(classes).astype(np.int32),
                        np.squeeze(scores),
                        category_index,
                        use_normalized_coordinates=True,
                        line_thickness=8)
                    # Back to BGR for OpenCV display and JPEG encoding.
                    image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
                    print(num)
                    finished = num >= min_objects
                    if finished:
                        # cv2.imwrite cannot handle non-ASCII (e.g. Chinese) paths; imencode + tofile can.
                        filename = "wxy-TRY-" + time.strftime("%H%M", time.localtime()) + ".jpg"
                        cv2.imencode('.jpg', image_np)[1].tofile(filename)
                        speech.say("识别完成")
                        print("写入成功,停止检测")

                    cv2.imshow("test", image_np)
                    cv2.waitKey(1)

                    print(end - start)
                    if finished:
                        break
            finally:
                # Release the camera and window even if inference, saving or speech raises.
                cap.release()
                cv2.destroyAllWindows()

def getToken(host, timeout=10):
    """Request a Baidu OAuth access token.

    Args:
        host: full token URL (``HOST``) with client_id / client_secret filled in.
        timeout: seconds to wait for the HTTP response; without it a stalled
            connection would block forever.

    Returns:
        The ``access_token`` string from the JSON response.

    Raises:
        RuntimeError: the response JSON has no ``access_token``; the message
            includes the server's response body for diagnosis.
    """
    res = requests.post(host, timeout=timeout)
    payload = res.json()
    try:
        return payload['access_token']
    except KeyError:
        raise RuntimeError('获取百度 access_token 失败: %s' % (payload,))


def save_wave_file(filepath, data):
    """Write the raw PCM chunks in *data* to *filepath* as a WAV file.

    Channel count, sample width and frame rate come from the module-level
    ``channels``, ``sampwidth`` and ``framerate`` settings.

    Args:
        filepath: destination path of the WAV file.
        data: iterable of ``bytes`` chunks (as returned by ``stream.read``).
    """
    # The context manager closes (and finalises the header of) the file even
    # if writeframes raises.
    with wave.open(filepath, 'wb') as wf:
        wf.setnchannels(channels)
        wf.setsampwidth(sampwidth)
        wf.setframerate(framerate)
        wf.writeframes(b''.join(data))


def my_record(seconds=4):
    """Record from the default input device and save the clip to ``FILEPATH``.

    Audio is captured as 16-bit PCM using the module-level ``channels`` and
    ``framerate``, reading ``num_samples`` frames per call.

    Args:
        seconds: how long to keep reading from the stream (default 4).
    """
    pa = PyAudio()
    stream = None
    try:
        stream = pa.open(format=paInt16, channels=channels,
                         rate=framerate, input=True, frames_per_buffer=num_samples)
        chunks = []
        deadline = time.time() + seconds
        print('正在录音...')
        while time.time() < deadline:
            chunks.append(stream.read(num_samples))
        print('录音结束.')
    finally:
        # Always give the device back; the original never called terminate(),
        # leaking PortAudio resources on every recording.
        if stream is not None:
            stream.stop_stream()
            stream.close()
        pa.terminate()
    save_wave_file(FILEPATH, chunks)


def get_audio(file):
    """Load *file* in binary mode and hand back its raw bytes (e.g. the recorded WAV)."""
    with open(file, mode='rb') as handle:
        return handle.read()


def speech2text(speech_data, token, dev_pid=1537, timeout=30):
    """Send 16 kHz mono WAV bytes to Baidu's speech API and return the transcript.

    Args:
        speech_data: raw bytes of the WAV file (sent base64-encoded).
        token: access token from ``getToken``.
        dev_pid: recognition model id passed through to the API.
        timeout: seconds to wait for the HTTP response.

    Returns:
        The first entry of the response's ``result`` list when present;
        otherwise the whole decoded JSON response (a dict) so the caller can
        inspect the error.
    """
    FORMAT = 'wav'
    RATE = 16000  # the API's JSON schema types `rate` as an int, not a string
    CHANNEL = 1
    CUID = '*******'
    SPEECH = base64.b64encode(speech_data).decode('utf-8')

    data = {
        'format': FORMAT,
        'rate': RATE,
        'channel': CHANNEL,
        'cuid': CUID,
        'len': len(speech_data),  # byte length of the un-encoded audio
        'speech': SPEECH,
        'token': token,
        'dev_pid': dev_pid
    }
    url = 'https://vop.baidu.com/server_api'
    headers = {'Content-Type': 'application/json'}
    print('正在识别...')
    r = requests.post(url, json=data, headers=headers, timeout=timeout)
    Result = r.json()
    if 'result' in Result:
        return Result['result'][0]
    return Result

def detect_img(yolo, camera_index=0, target_count=2):
    """Run a YOLO wrapper on a live camera feed until it reports ``target_count`` objects.

    ``yolo.detect_image_`` is expected to take a PIL RGB image and return
    ``(annotated_image, num)``. This is inferred from the unpacking below;
    confirm against the yolo wrapper. When ``num == target_count``, the
    annotated frame is written to ``WXY-TRY-<HHMM>.jpg``, "识别完成" is spoken,
    and the loop stops.

    Args:
        yolo: detector object exposing ``detect_image_``.
        camera_index: OpenCV capture index (default 0, as originally hard-coded).
        target_count: exact value of ``num`` that ends detection (default 2).
    """
    cap = cv2.VideoCapture(camera_index)
    try:
        while True:
            ok, frame = cap.read()
            if not ok or frame is None:
                # A failed read yields frame=None, which would crash cvtColor below.
                print("读取摄像头画面失败,停止检测")
                break
            rgb = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            annotated, num = yolo.detect_image_(rgb)
            image = cv2.cvtColor(np.asarray(annotated), cv2.COLOR_RGB2BGR)
            finished = num == target_count
            if finished:
                cv2.imwrite("WXY-TRY-" + time.strftime("%H%M", time.localtime()) + ".jpg", image)
                speech.say("识别完成")
                print("写入成功,停止检测")

            cv2.imshow("test", image)
            cv2.waitKey(1)
            if finished:
                break
    finally:
        # Release the camera and window even if detection or speech raises.
        cap.release()
        cv2.destroyAllWindows()




def test():
    """Wait for a spoken "开始" (via Baidu ASR), then announce and start detect().

    Records clips with ``my_record`` and transcribes each with model id 1536
    until the transcript equals "开始".
    """
    # Fetch the access token once instead of once per recorded clip.
    token = getToken(HOST)
    while True:
        my_record()
        result = speech2text(get_audio(FILEPATH), token, 1536)
        print(result)
        if result == "开始":
            break
    speech.say("开始识别")
    detect()

def test2():
    """Wait for any non-empty line on serial port COM4, then run detect().

    The port is opened at 115200 baud with a 0.5 s read timeout, so
    ``readline`` returns empty bytes when nothing arrives and the loop keeps
    polling. The port stays open until ``detect()`` returns.
    """
    serialPort = "COM4"  # serial port
    baudRate = 115200  # baud rate
    # The context manager closes the port even if detect() raises; the
    # original only closed it on the success path.
    with serial.Serial(serialPort, baudRate, timeout=0.5) as ser:
        print("参数设置:串口=%s ,波特率=%d" % (serialPort, baudRate))
        while not ser.readline():
            pass
        detect()

def test3():
    """Smoke-test the `speech` module: speak a phrase, then print the current HHMM time."""
    speech.say("开始识别")
    now_hhmm = time.strftime("%H%M", time.localtime())
    print(now_hhmm)
if __name__ == '__main__':
    # Pick ONE entry point; the alternatives are kept commented out.
    #test()    # trigger via Baidu speech recognition
    #test2()   # trigger via iFlytek voice wake-up (serial port)
    #test3()   # test the `speech` module on Windows
    detect()   # start detecting immediately

你可能感兴趣的:(python,tensorflow,深度学习)