原图链接:https://pan.baidu.com/s/17iI62gt9HyRbQ-Wr8h28jw 提取码:4swb。运行命令mkdir n01440764创建文件夹n01440764。运行命令tar -xvf n01440764.tar -C n01440764完成压缩文件的解压,其中-C参数后面必须为已经存在的文件夹,否则运行命令会报错。

1.加快apt-get命令的下载速度,需要做Ubuntu系统的换源:Ubuntu的设置Settings中选择Software & Updates,将Download from的值设置为http://mirrors.aliyun.com/ubuntu
2.运行命令sudo apt-get install pyqt5-dev-tools安装软件pyqt5-dev-tools。
3.运行命令cd labelImg-master进入文件夹labelImg-master。运行命令pip install -r requirements/requirements-linux-python3.txt安装软件labelImg运行时需要的库,如果已经安装Anaconda此步可能不用进行。
4.运行命令make qt5py3编译产生软件labelImg运行时需要的组件。python labelImg.py 运行打开labelImg软件。

1.1 挑选像素足够的图片

n01440764中有一部分图片像素不足416 * 416,不利于模型训练,新建_01_select_images.py

import os
import random
from PIL import Image
import shutil

def getFilePathList(dirPath, partOfFileName=''):
    """Return the paths of files directly inside ``dirPath`` whose name contains ``partOfFileName``.

    Args:
        dirPath: Folder to list; sub-folders are not searched.
        partOfFileName: Substring a file name must contain (default: match all).

    Returns:
        A list of ``os.path.join(dirPath, name)`` strings; empty when the
        folder does not exist or holds no matching file.
    """
    # next() reads only the top-level entry; list(os.walk(...)) would first
    # traverse every sub-folder and raise IndexError for a missing folder.
    _, _, allFileName_list = next(os.walk(dirPath), (dirPath, [], []))
    return [os.path.join(dirPath, k) for k in allFileName_list if partOfFileName in k]

def generate_qualified_images(dirPath, sample_number, new_dirPath):
    """Copy up to ``sample_number`` images of at least 416x416 pixels into ``new_dirPath``.

    Args:
        dirPath: Folder holding the source ``.JPEG`` images.
        sample_number: Maximum number of images to copy.
        new_dirPath: Destination folder; created when missing.

    The copies are named ``001.jpg``, ``002.jpg``, ... in selection order.
    """
    jpgFilePath_list = getFilePathList(dirPath, '.JPEG')
    if not os.path.isdir(new_dirPath):
        os.makedirs(new_dirPath)
    i = 0
    for jpgFilePath in jpgFilePath_list:
        # Checking first also covers sample_number == 0.
        if i >= sample_number:
            break
        # Only the size is needed, so the file handle is released right away.
        with Image.open(jpgFilePath) as image:
            width, height = image.size
        if width >= 416 and height >= 416:
            i += 1
            new_jpgFilePath = os.path.join(new_dirPath, '%03d.jpg' %i)
            shutil.copy(jpgFilePath, new_jpgFilePath)

# Select 200 qualifying images from the folder n01440764 into the folder selected_images.
generate_qualified_images('n01440764', 200, 'selected_images')
# 命令行写法
import os
import random
from PIL import Image
import cv2
import argparse

# 获取文件夹中的文件路径
def get_filePathList(dirPath, partOfFileName=''):
    """List the files directly inside ``dirPath`` whose name contains ``partOfFileName``.

    Returns the matching names joined onto ``dirPath``; sub-folders are not searched.
    """
    top_level_entry = next(os.walk(dirPath))
    matched_paths = []
    for file_name in top_level_entry[2]:
        if partOfFileName in file_name:
            matched_paths.append(os.path.join(dirPath, file_name))
    return matched_paths
# 选出一部分像素足够,即长,宽都大于指定数值的图片
def select_qualifiedImages(in_dirPath, out_dirPath, in_suffix, out_suffix, sample_number, required_width, required_height):
    """Re-save up to ``sample_number`` sufficiently large images into ``out_dirPath``.

    Args:
        in_dirPath: Folder holding the source images.
        out_dirPath: Destination folder; created when missing.
        in_suffix: Text a source file name must contain, e.g. ``'.JPG'``.
        out_suffix: Suffix of the written files, e.g. ``'.jpg'``.
        sample_number: Maximum number of images to write.
        required_width: Minimum accepted image width in pixels.
        required_height: Minimum accepted image height in pixels.

    The written files are named ``001<out_suffix>``, ``002<out_suffix>``, ...
    """
    imageFilePath_list = get_filePathList(in_dirPath, in_suffix)
    if not os.path.isdir(out_dirPath):
        os.makedirs(out_dirPath)
    count = 0
    for imageFilePath in imageFilePath_list:
        # Checking first also covers sample_number == 0.
        if count >= sample_number:
            break
        # Only the size is needed, so the file handle is released right away.
        with Image.open(imageFilePath) as image:
            image_width, image_height = image.size
        if image_width >= required_width and image_height >= required_height:
            count += 1
            out_imageFilePath = os.path.join(out_dirPath, '%03d%s' %(count, out_suffix))
            # Read and re-write through OpenCV so the output takes the format implied by out_suffix.
            image_ndarray = cv2.imread(imageFilePath)
            cv2.imwrite(out_imageFilePath, image_ndarray)

# 解析运行代码文件时传入的参数            
def parse_args():
    """Build the command-line parser of the image-selection script and parse ``sys.argv``.

    Returns:
        The argparse namespace with ``in_dir``, ``out_dir``, ``in_suffix``,
        ``out_suffix``, ``number``, ``width`` and ``height``.
    """
    arg_parser = argparse.ArgumentParser()
    option_specs = (
        (('-i', '--in_dir'), dict(type=str, default='../resources/source_images', help='输入文件夹')),
        (('-o', '--out_dir'), dict(type=str, default='../resources/selected_images', help='输出文件夹')),
        (('--in_suffix',), dict(type=str, default='.JPG')),
        (('--out_suffix',), dict(type=str, default='.jpg')),
        (('-n', '--number'), dict(type=int, default=500)),
        (('-w', '--width'), dict(type=int, default=416)),
        (('-he', '--height'), dict(type=int, default=416)),
    )
    for flags, settings in option_specs:
        arg_parser.add_argument(*flags, **settings)
    return arg_parser.parse_args()
# 获取数量为200的合格样本存放到selected_images文件夹中
if __name__ == "__main__":
    # Read the settings passed on the command line.
    argument_namespace = parse_args()
    in_dirPath = argument_namespace.in_dir.strip()
    assert os.path.exists(in_dirPath), 'not exists this path: %s' %in_dirPath
    out_dirPath = argument_namespace.out_dir.strip()
    sample_number = argument_namespace.number
    # Normalise both suffixes to exactly one leading dot ("JPG", ".JPG" and "..JPG" all work).
    in_suffix = argument_namespace.in_suffix.strip()
    in_suffix = '.' + in_suffix.lstrip('.')
    out_suffix = argument_namespace.out_suffix.strip()
    out_suffix = '.' + out_suffix.lstrip('.')
    required_width = argument_namespace.width
    required_height = argument_namespace.height
    select_qualifiedImages(in_dirPath, out_dirPath, in_suffix, out_suffix, sample_number, required_width, required_height)
    # Report the absolute output folder and how many images it now holds.
    out_dirPath = os.path.abspath(out_dirPath)
    print('选出的图片文件保存到文件夹:%s' %out_dirPath)
    imageFilePath_list = get_filePathList(out_dirPath, out_suffix)
    selectedImages_number = len(imageFilePath_list)
    print('总共选出%d张图片' %selectedImages_number)
    # Warn when fewer images qualified than the caller asked for.
    if selectedImages_number < sample_number:
        print('合格的图片数量不足%d张,只选出了%d张' %(sample_number, selectedImages_number))

1.2 数据标注及检查

作者已提供标签好200张图片:链接:https://pan.baidu.com/s/13-fRksSjUeEii54gClA3Pw 提取码:57lz

import os
def getFilePathList(dirPath, partOfFileName=''):
    """Return the paths of files directly inside ``dirPath`` whose name contains ``partOfFileName``.

    Args:
        dirPath: Folder to list; sub-folders are not searched.
        partOfFileName: Substring a file name must contain (default: match all).

    Returns:
        A list of ``os.path.join(dirPath, name)`` strings; empty when the
        folder does not exist or holds no matching file.
    """
    # next() reads only the top-level entry; list(os.walk(...)) would first
    # traverse every sub-folder and raise IndexError for a missing folder.
    _, _, allFileName_list = next(os.walk(dirPath), (dirPath, [], []))
    return [os.path.join(dirPath, k) for k in allFileName_list if partOfFileName in k]

def check_1(dirPath):
    """Report every ``.jpg`` file in ``dirPath`` that has no ``.xml`` annotation beside it.

    Prints one line per unannotated picture, or a single success message when
    every picture has an annotation file.
    """
    unmarked_paths = []
    for jpgFilePath in getFilePathList(dirPath, '.jpg'):
        # The annotation shares the picture's path with the last 4 characters swapped for ".xml".
        if os.path.exists(jpgFilePath[:-4] + '.xml'):
            continue
        print('%s this picture is not marked.' %jpgFilePath)
        unmarked_paths.append(jpgFilePath)
    if not unmarked_paths:
        print('congratulation! it is been verified that all jpg file are marked.')

import xml.etree.ElementTree as ET
def check_2(dirPath, className_list):
    """Report every object in the ``.xml`` files of ``dirPath`` whose class name is not allowed.

    Args:
        dirPath: Folder holding the annotation files.
        className_list: The allowed class names.
    """
    allowed_names = set(className_list)
    wrong_name_found = False
    for xmlFilePath in getFilePathList(dirPath, '.xml'):
        with open(xmlFilePath) as xml_file:
            xml_text = xml_file.read()
        for annotated_object in ET.XML(xml_text).findall('object'):
            className = annotated_object.find('name').text
            if className in allowed_names:
                continue
            print('%s this xml file has wrong class name "%s" ' %(xmlFilePath, className))
            wrong_name_found = True
    if not wrong_name_found:
        print('congratulation! it is been verified that all xml file are correct.')

if __name__ == '__main__':
    # Folder whose .xml annotation files are checked.
    dirPath = 'selected_images'
    className_list = ['fish', 'human_face'] # edit by hand; check_2 turns this list into a set
    check_2(dirPath, className_list)
# 命令行写法
import os
def get_filePathList(dirPath, partOfFileName=''):
    """List the files directly inside ``dirPath`` whose name contains ``partOfFileName``.

    Returns the matching names joined onto ``dirPath``; sub-folders are not searched.
    """
    _, _, top_level_names = next(os.walk(dirPath))
    matched_paths = []
    for file_name in top_level_names:
        if partOfFileName in file_name:
            matched_paths.append(os.path.join(dirPath, file_name))
    return matched_paths

# 删除文件前检查文件是否存在,如果不存在则报告此文件不存在,存在则报告此文件会被删除
def delete_file(filePath):
    """Delete ``filePath`` after announcing it, or report that the path does not exist.

    Args:
        filePath: Path of the file to remove.
    """
    if not os.path.exists(filePath):
        print('%s 这个文件路径不存在,请检查一下' %filePath)
    else:
        # Announce before removing so the console log shows what was deleted.
        print('%s 这个路径的文件将被删除' %filePath)
        os.remove(filePath)
# 此段代码删除不对应的图片文件或xml文件
def check_1(dirPath, suffix):
    """Delete images without an ``.xml`` annotation and ``.xml`` files without an image.

    Args:
        dirPath: Folder holding the images and their annotation files.
        suffix: Text an image file name must contain, e.g. ``'.JPG'``.
    """
    # Pass 1: an image with no annotation file beside it was never labelled.
    imageFilePath_list = get_filePathList(dirPath, suffix)
    allFileMarked = True
    for imageFilePath in imageFilePath_list:
        # splitext handles suffixes of any length (".jpg", ".JPEG", ...).
        xmlFilePath = os.path.splitext(imageFilePath)[0] + '.xml'
        if not os.path.exists(xmlFilePath):
            delete_file(imageFilePath)
            allFileMarked = False
    if allFileMarked:
        print('祝贺你! 已经通过检验,所有图片文件都有对应的xml文件')
    # Pass 2: an annotation file whose image is gone is redundant.
    xmlFilePathPrefix_set = {os.path.splitext(k)[0] for k in get_filePathList(dirPath, '.xml')}
    imageFilePathPrefix_set = {os.path.splitext(k)[0] for k in get_filePathList(dirPath, suffix)}
    # Sorted so the deletion messages come out in a stable order.
    for xmlFilePathPrefix in sorted(xmlFilePathPrefix_set - imageFilePathPrefix_set):
        delete_file(xmlFilePathPrefix + '.xml')
# 此段代码检查标记的xml文件中是否有物体标记类别拼写错误        
import xml.etree.ElementTree as ET
def check_2(dirPath, className_list):
    """Report every object in the ``.xml`` files of ``dirPath`` whose class name is not allowed.

    Args:
        dirPath: Folder holding the annotation files.
        className_list: The allowed class names.
    """
    allowed_names = set(className_list)
    wrong_name_found = False
    for xmlFilePath in get_filePathList(dirPath, '.xml'):
        with open(xmlFilePath) as xml_file:
            xml_text = xml_file.read()
        for annotated_object in ET.XML(xml_text).findall('object'):
            className = annotated_object.find('name').text
            if className in allowed_names:
                continue
            print('%s 这个xml文件中有错误的种类名称 "%s" ' %(xmlFilePath, className))
            wrong_name_found = True
    if not wrong_name_found:
        print('祝贺你! 已经通过检验,所有xml文件中的标注都有正确的种类名称')
# 此段代码检测标记的box是否超过图片的边界
# 如果有此类型的box,则直接删除与box相关的xml文件和图片文件
from PIL import Image
def check_3(dirPath, suffix):
    """Delete every annotation/image pair that contains a box crossing the image border.

    Args:
        dirPath: Folder holding the images and their annotation files.
        suffix: Suffix of the image files, with or without the leading dot.

    A box is valid when ``xmin >= 1``, ``ymin >= 1``, ``xmax <= width`` and
    ``ymax <= height``.
    """
    xmlFilePath_list = get_filePathList(dirPath, '.xml')
    allFileCorrect = True
    for xmlFilePath in xmlFilePath_list:
        imageFilePath = xmlFilePath[:-4] + '.' + suffix.strip('.')
        # Close the image at once so it can be deleted below on every platform.
        with Image.open(imageFilePath) as image:
            width, height = image.size
        with open(xmlFilePath) as file:
            fileContent = file.read()
        root = ET.XML(fileContent)
        object_list = root.findall('object')
        for object_item in object_list:
            bndbox = object_item.find('bndbox')
            xmin = int(bndbox.find('xmin').text)
            ymin = int(bndbox.find('ymin').text)
            xmax = int(bndbox.find('xmax').text)
            ymax = int(bndbox.find('ymax').text)
            if xmin>=1 and ymin>=1 and xmax<=width and ymax<=height:
                continue
            # One bad box condemns the whole pair; the remaining boxes need no check.
            delete_file(xmlFilePath)
            delete_file(imageFilePath)
            allFileCorrect = False
            break
    if allFileCorrect:
        print('祝贺你! 已经通过检验,所有xml文件中的标注框都没有越界')

# 从文本文件中解析出物体种类列表className_list,要求每个种类占一行
def get_classNameList(txtFilePath):
    """Read the class names from a UTF-8 text file holding one name per line.

    Blank lines are dropped, surrounding whitespace is stripped and the names
    are returned in ascending order.
    """
    with open(txtFilePath, 'r', encoding='utf8') as file:
        raw_lines = file.read().split('\n')
    names = []
    for raw_line in raw_lines:
        stripped = raw_line.strip()
        if stripped != '':
            names.append(stripped)
    names.sort()
    return names
# 解析运行代码文件时传入的参数
import argparse
def parse_args():
    """Parse the command line of the label-checking script.

    Returns:
        The argparse namespace with ``dirPath``, ``suffix`` and ``class_txtFilePath``.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        '-d', '--dirPath',
        type=str,
        default='../resources/source_images',
        help='文件夹路径',
    )
    arg_parser.add_argument('-s', '--suffix', default='.JPG', type=str)
    arg_parser.add_argument(
        '-c', '--class_txtFilePath',
        type=str,
        default='../resources/category_list.txt',
    )
    return arg_parser.parse_args()
# 主函数    
if __name__ == '__main__':
    # Read the folder, the image suffix and the category file from the command line.
    argument_namespace = parse_args()
    dirPath = argument_namespace.dirPath
    assert os.path.exists(dirPath), 'not exists this path: %s' %dirPath
    class_txtFilePath = argument_namespace.class_txtFilePath
    className_list = get_classNameList(class_txtFilePath)
    suffix = argument_namespace.suffix
    # Run the three checks in order on the same folder.
    check_1(dirPath, suffix)
    check_2(dirPath, className_list)
    check_3(dirPath, suffix)

1.3 图像压缩


import os
def getFilePathList(dirPath, partOfFileName=''):
    """Return the paths of files directly inside ``dirPath`` whose name contains ``partOfFileName``.

    Args:
        dirPath: Folder to list; sub-folders are not searched.
        partOfFileName: Substring a file name must contain (default: match all).

    Returns:
        A list of ``os.path.join(dirPath, name)`` strings; empty when the
        folder does not exist or holds no matching file.
    """
    # next() reads only the top-level entry; list(os.walk(...)) would first
    # traverse every sub-folder and raise IndexError for a missing folder.
    _, _, allFileName_list = next(os.walk(dirPath), (dirPath, [], []))
    return [os.path.join(dirPath, k) for k in allFileName_list if partOfFileName in k]

import xml.etree.ElementTree as ET
def generateNewXmlFile(old_xmlFilePath, new_xmlFilePath, new_size):
    """Write a copy of an annotation file rescaled to a new image size.

    Args:
        old_xmlFilePath: Path of the annotation file to read.
        new_xmlFilePath: Path the rescaled annotation is written to.
        new_size: ``(new_width, new_height)`` of the resized image.

    The ``size/width`` and ``size/height`` nodes are set to the new size and
    every ``bndbox`` coordinate is multiplied by the matching scale factor
    and truncated to an integer.
    """
    new_width, new_height = new_size
    # ET.parse lets the XML parser handle the file's encoding.
    tree = ET.parse(old_xmlFilePath)
    root = tree.getroot()
    size = root.find('size')
    width = size.find('width')
    height = size.find('height')
    width_times = new_width / int(width.text)
    height_times = new_height / int(height.text)
    width.text = str(new_width)
    height.text = str(new_height)
    # x coordinates scale with the width factor, y coordinates with the height factor.
    times_by_tag = (('xmin', width_times), ('ymin', height_times),
                    ('xmax', width_times), ('ymax', height_times))
    for object_item in root.findall('object'):
        bndbox = object_item.find('bndbox')
        for tag, times in times_by_tag:
            node = bndbox.find(tag)
            node.text = str(int(int(node.text) * times))
    # Without this write the rescaled annotation would only exist in memory.
    tree.write(new_xmlFilePath, encoding='utf-8')

def batch_modify_xml(old_dirPath, new_dirPath, new_size):
    """Rescale every ``.xml`` annotation of ``old_dirPath`` into ``new_dirPath``.

    Each file keeps its name; ``new_size`` is passed on to ``generateNewXmlFile``.
    """
    for old_xml_path in getFilePathList(old_dirPath, '.xml'):
        # basename() is the file-name half that os.path.split() returns.
        target_xml_path = os.path.join(new_dirPath, os.path.basename(old_xml_path))
        generateNewXmlFile(old_xml_path, target_xml_path, new_size)

from PIL import Image
def generateNewJpgFile(old_jpgFilePath, new_jpgFilePath, new_size):
    """Resize one image to ``new_size`` and save it under ``new_jpgFilePath``.

    Args:
        old_jpgFilePath: Path of the image to read.
        new_jpgFilePath: Path the resized image is written to.
        new_size: ``(width, height)`` of the output image.
    """
    with Image.open(old_jpgFilePath) as old_image:
        # LANCZOS is the current name of the high-quality ANTIALIAS filter,
        # which Pillow 10 removed.
        new_image = old_image.resize(new_size, Image.LANCZOS)
    # Without this save the resized image would be discarded.
    new_image.save(new_jpgFilePath)

def batch_modify_jpg(old_dirPath, new_dirPath, new_size):
    """Resize the ``.jpg`` image of every annotation in ``old_dirPath`` into ``new_dirPath``.

    Args:
        old_dirPath: Folder holding the annotated images.
        new_dirPath: Destination folder; created when missing.
        new_size: ``(width, height)`` of the output images.

    The work is driven by the ``.xml`` files, so only annotated images are resized.
    """
    if not os.path.isdir(new_dirPath):
        os.makedirs(new_dirPath)
    xmlFilePath_list = getFilePathList(old_dirPath, '.xml')
    for xmlFilePath in xmlFilePath_list:
        old_jpgFilePath = xmlFilePath[:-4] + '.jpg'
        jpgFileName = os.path.split(old_jpgFilePath)[1]
        new_jpgFilePath = os.path.join(new_dirPath, jpgFileName)
        generateNewJpgFile(old_jpgFilePath, new_jpgFilePath, new_size)

if __name__ == '__main__':
    # Source folder and target image size.
    old_dirPath = 'selected_images'
    new_width = 416
    new_height = 416
    new_size = (new_width, new_height)
    # The output folder is named after the target size, e.g. images_416x416.
    new_dirPath = 'images_%sx%s' %(str(new_width), str(new_height))
    # Images go first: batch_modify_jpg is the call that receives new_dirPath before any file is written there.
    batch_modify_jpg(old_dirPath, new_dirPath, new_size)
    batch_modify_xml(old_dirPath, new_dirPath, new_size)
# 命令行写法
import os
def get_filePathList(dirPath, partOfFileName=''):
    """Return the paths of files directly inside ``dirPath`` whose name contains ``partOfFileName``.

    Args:
        dirPath: Folder to list; sub-folders are not searched.
        partOfFileName: Substring a file name must contain (default: match all).

    Returns:
        A list of ``os.path.join(dirPath, name)`` strings; empty when the
        folder does not exist or holds no matching file.
    """
    # next() reads only the top-level entry; list(os.walk(...)) would first
    # traverse every sub-folder and raise IndexError for a missing folder.
    _, _, allFileName_list = next(os.walk(dirPath), (dirPath, [], []))
    return [os.path.join(dirPath, k) for k in allFileName_list if partOfFileName in k]

# 修改文件夹中的单个xml文件
import xml.etree.ElementTree as ET
def single_xmlCompress(old_xmlFilePath, new_xmlFilePath, new_size):
    """Write a copy of one annotation file rescaled to a new image size.

    Args:
        old_xmlFilePath: Path of the annotation file to read.
        new_xmlFilePath: Path the rescaled annotation is written to.
        new_size: ``(new_width, new_height)`` of the resized image.

    The ``size/width`` and ``size/height`` nodes are set to the new size and
    every ``bndbox`` coordinate is multiplied by the matching scale factor
    and truncated to an integer.
    """
    new_width, new_height = new_size
    # ET.parse lets the XML parser handle the file's encoding.
    tree = ET.parse(old_xmlFilePath)
    root = tree.getroot()
    size = root.find('size')
    # Scale factors come from the old size stored in the file itself.
    width = size.find('width')
    height = size.find('height')
    width_times = new_width / int(width.text)
    height_times = new_height / int(height.text)
    width.text = str(new_width)
    height.text = str(new_height)
    # x coordinates scale with the width factor, y coordinates with the height factor.
    times_by_tag = (('xmin', width_times), ('ymin', height_times),
                    ('xmax', width_times), ('ymax', height_times))
    for object_item in root.findall('object'):
        bndbox = object_item.find('bndbox')
        for tag, times in times_by_tag:
            node = bndbox.find(tag)
            node.text = str(int(int(node.text) * times))
    # Without this write the rescaled annotation would only exist in memory.
    tree.write(new_xmlFilePath, encoding='utf-8')
# 修改文件夹中的若干xml文件
def batch_xmlCompress(old_dirPath, new_dirPath, new_size):
    """Rescale every ``.xml`` annotation of ``old_dirPath`` into ``new_dirPath``.

    Each file keeps its name; ``new_size`` is passed on to ``single_xmlCompress``.
    """
    for source_xml_path in get_filePathList(old_dirPath, '.xml'):
        # basename() is the file-name half that os.path.split() returns.
        target_xml_path = os.path.join(new_dirPath, os.path.basename(source_xml_path))
        single_xmlCompress(source_xml_path, target_xml_path, new_size)
from PIL import Image
def single_imageCompress(old_imageFilePath, new_imageFilePath, new_size):
    """Resize one image to ``new_size`` and save it under ``new_imageFilePath``.

    Args:
        old_imageFilePath: Path of the image to read.
        new_imageFilePath: Path the resized image is written to.
        new_size: ``(width, height)`` of the output image.
    """
    with Image.open(old_imageFilePath) as old_image:
        # LANCZOS is the current name of the high-quality ANTIALIAS filter,
        # which Pillow 10 removed.
        new_image = old_image.resize(new_size, Image.LANCZOS)
    # Without this save the resized image would be discarded.
    new_image.save(new_imageFilePath)
# 修改文件夹中的若干jpg文件
def batch_imageCompress(old_dirPath, new_dirPath, new_size, suffix):
    """Resize every image of ``old_dirPath`` into ``new_dirPath``.

    Args:
        old_dirPath: Folder holding the source images.
        new_dirPath: Destination folder; created when missing.
        new_size: ``(width, height)`` of the output images.
        suffix: Text an image file name must contain, e.g. ``'.JPG'``.

    Each image keeps its file name.
    """
    if not os.path.isdir(new_dirPath):
        os.makedirs(new_dirPath)
    for old_imageFilePath in get_filePathList(old_dirPath, suffix):
        imageFileName = os.path.split(old_imageFilePath)[1]
        new_imageFilePath = os.path.join(new_dirPath, imageFileName)
        single_imageCompress(old_imageFilePath, new_imageFilePath, new_size)

# 解析运行代码文件时传入的参数
import argparse
def parse_args():
    """Parse the command line of the image-compression script.

    Returns:
        The argparse namespace with ``dirPath``, ``width``, ``height`` and ``suffix``.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        '-d', '--dirPath',
        type=str,
        default='../resources/source_images',
        help='文件夹路径',
    )
    # Both target dimensions are integers with the same default.
    for short_flag, long_flag in (('-w', '--width'), ('-he', '--height')):
        arg_parser.add_argument(short_flag, long_flag, type=int, default=416)
    arg_parser.add_argument('-s', '--suffix', default='.JPG', type=str)
    return arg_parser.parse_args()

# 主函数    
if __name__ == '__main__':
    # Read the source folder, target size and image suffix from the command line.
    argument_namespace = parse_args()
    old_dirPath = argument_namespace.dirPath
    assert os.path.exists(old_dirPath), 'not exists this path: %s' %old_dirPath
    width = argument_namespace.width
    height = argument_namespace.height
    new_size = (width, height)
    # The output folder is named after the target size, e.g. ../resources/images_416x416.
    new_dirPath = '../resources/images_%sx%s' %(str(width), str(height))
    suffix = argument_namespace.suffix
    # Resize the images, then rescale their annotation files into the same folder.
    batch_imageCompress(old_dirPath, new_dirPath, new_size, suffix)
    batch_xmlCompress(old_dirPath, new_dirPath, new_size)

1.4 划分训练集和测试集

编辑类别文件resources/category_list.txt,每1行表示1个类别。运行命令python _04_generate_txtFile.py -d ../resources/images_416x416会划分训练集dataset_train.txt和测试集dataset_test.txt,_04_generate_txtFile.py代码如下:

import xml.etree.ElementTree as ET
import os
import argparse
from sklearn.model_selection import train_test_split

# 从文本文件中解析出物体种类列表className_list,要求每个种类占一行
def get_classNameList(txtFilePath):
    """Read the class names from a UTF-8 text file holding one name per line.

    Blank lines are dropped, surrounding whitespace (spaces, tabs, line breaks)
    is stripped and the names are returned in ascending order.
    """
    with open(txtFilePath, 'r', encoding='utf8') as file:
        raw_lines = file.read().split('\n')
    stripped_lines = (raw_line.strip() for raw_line in raw_lines)
    names = [line for line in stripped_lines if line != '']
    names.sort()
    return names
# 获取文件夹中的文件路径
import os
def get_filePathList(dirPath, partOfFileName=''):
    """Return the paths of files directly inside ``dirPath`` whose name contains ``partOfFileName``.

    Args:
        dirPath: Folder to list; sub-folders are not searched.
        partOfFileName: Substring a file name must contain (default: match all).

    Returns:
        A list of ``os.path.join(dirPath, name)`` strings; empty when the
        folder does not exist or holds no matching file.
    """
    # next() reads only the top-level entry; list(os.walk(...)) would first
    # traverse every sub-folder and raise IndexError for a missing folder.
    _, _, allFileName_list = next(os.walk(dirPath), (dirPath, [], []))
    return [os.path.join(dirPath, k) for k in allFileName_list if partOfFileName in k]
# 解析运行代码文件时传入的参数
import argparse
def parse_args():
    """Parse the command line of the dataset-splitting script.

    Returns:
        The argparse namespace with ``dirPath``, ``suffix`` and ``class_txtFilePath``.
    """
    arg_parser = argparse.ArgumentParser()
    option_specs = (
        (('-d', '--dirPath'), dict(type=str, help='文件夹路径', default='../resources/images_416x416')),
        (('-s', '--suffix'), dict(type=str, default='.JPG')),
        (('-c', '--class_txtFilePath'), dict(type=str, default='../resources/category_list.txt')),
    )
    for flags, settings in option_specs:
        arg_parser.add_argument(*flags, **settings)
    return arg_parser.parse_args()
# 主函数
if __name__ == '__main__':
    # Read the dataset folder, image suffix and category file from the command line.
    argument_namespace = parse_args()
    dataset_dirPath = argument_namespace.dirPath
    assert os.path.exists(dataset_dirPath), 'not exists this path: %s' %dataset_dirPath
    # Normalise the suffix to exactly one leading dot so "JPG" and ".JPG" both work.
    suffix = '.' + argument_namespace.suffix.strip().lstrip('.')
    class_txtFilePath = argument_namespace.class_txtFilePath
    xmlFilePath_list = get_filePathList(dataset_dirPath, '.xml')
    className_list = get_classNameList(class_txtFilePath)
    # Hold out 10% of the annotated images as the test set.
    train_xmlFilePath_list, test_xmlFilePath_list = train_test_split(xmlFilePath_list, test_size=0.1)
    dataset_list = [('dataset_train', train_xmlFilePath_list), ('dataset_test', test_xmlFilePath_list)]
    for dataset_name, dataset_xmlFilePath_list in dataset_list:
        txtFile_path = '%s.txt' %dataset_name
        # One output line per image: "<image path> x1,y1,x2,y2,classId x1,y1,...".
        with open(txtFile_path, 'w') as txtFile:
            for xmlFilePath in dataset_xmlFilePath_list:
                jpgFilePath = os.path.splitext(xmlFilePath)[0] + suffix
                txtFile.write(jpgFilePath)
                with open(xmlFilePath) as xmlFile:
                    xmlFileContent = xmlFile.read()
                root = ET.XML(xmlFileContent)
                for obj in root.iter('object'):
                    className = obj.find('name').text
                    if className not in className_list:
                        print('error!! className not in className_list')
                        # Skip the box; list.index() would raise ValueError for an unknown name.
                        continue
                    classId = className_list.index(className)
                    bndbox = obj.find('bndbox')
                    bound = [int(bndbox.find('xmin').text), int(bndbox.find('ymin').text),
                             int(bndbox.find('xmax').text), int(bndbox.find('ymax').text)]
                    txtFile.write(" " + ",".join([str(k) for k in bound]) + ',' + str(classId))
                txtFile.write('\n')


from os import listdir
from os.path import isfile, isdir, join
import random
# Randomly split the annotation files into test / train / val name lists and
# write them as text files under ./Main, one file stem per line.
from os import makedirs

path = './Annotations' # folder holding only .xml annotation files
files = listdir(path)

# Share of each subset in percent; the three values add up to 100.
data_rate = {
    'test': 10,
    'train': 60,
    'val': 30
}
# Cumulative upper bounds for a random integer drawn from 1..100.
test_limit = data_rate['test']
train_limit = test_limit + data_rate['train']

test, train, validation = list(), list(), list()
for file_name in files:
    rand = random.randint(1,100)
    # Keep only the part of the file name before the first dot.
    filename = file_name.split('.')[0]
    if rand <= test_limit:
        test.append(filename)
    elif rand <= train_limit:
        train.append(filename)
    else:
        validation.append(filename)

print('test: \n', test)
print('train: \n', train)
print('validation: \n', validation)


def _write_name_lists(txtFilePath, *name_lists):
    """Write every name of the given lists to ``txtFilePath``, one per line."""
    with open(txtFilePath, 'w') as f:
        for name_list in name_lists:
            for name in name_list:
                f.write(name + '\n')


# open(..., 'w') does not create missing folders.
if not isdir('./Main'):
    makedirs('./Main')
_write_name_lists('./Main/test.txt', test) # 0.1
_write_name_lists('./Main/train.txt', train) # 0.6
_write_name_lists('./Main/val.txt', validation)  # 0.3
_write_name_lists('./Main/trainval.txt', train, validation) # 0.9



文件夹keras-yolo3-master中打开终端Terminal,然后运行命令python _05_train.py即可开始训练。调整模型训练的轮次epochs需要修改代码文件_05_train.py的第85行fit_generator方法中的参数,即第90行参数epochs的值。_05_train.py代码如下:

# 导入常用的库
import os
import numpy as np
# 导入keras库
import keras.backend as K
from keras.layers import Input, Lambda
from keras.models import Model
# 导入yolo3文件夹中model.py、utils.py这2个代码文件中的方法
from yolo3.model import preprocess_true_boxes, yolo_body, yolo_loss
from yolo3.utils import get_random_data

# 从文本文件中解析出物体种类列表category_list,要求每个种类占一行
def get_categoryList(txtFilePath):
    """Read the category names from a UTF-8 text file holding one name per line.

    Blank lines are dropped, surrounding whitespace is stripped and the names
    are returned in ascending order.
    """
    with open(txtFilePath, 'r', encoding='utf8') as file:
        raw_lines = file.read().split('\n')
    categories = []
    for raw_line in raw_lines:
        stripped = raw_line.strip()
        if stripped != '':
            categories.append(stripped)
    categories.sort()
    return categories

# 从表示anchor的文本文件中解析出anchor_ndarray
def get_anchorNdarray(anchor_txtFilePath):
    """Parse a comma-separated anchor file into an array with two columns.

    Args:
        anchor_txtFilePath: Path of the anchor text file, e.g. ./model_data/yolo_anchors.txt.

    Returns:
        A float ndarray of shape ``(-1, 2)`` built from the numbers in file order.
    """
    with open(anchor_txtFilePath) as file:
        anchor_text = file.read()
    anchor_values = list(map(float, anchor_text.split(',')))
    return np.array(anchor_values).reshape(-1, 2)

# 创建YOLOv3模型,通过yolo_body方法架构推理层inference,配合损失函数完成搭建卷积神经网络。   
def create_model(input_shape,
                 anchor_ndarray,
                 num_classes,
                 load_pretrained=True,
                 freeze_body=False,
                 weights_h5FilePath='saved_model/trained_weights.h5'):
    """Build the trainable YOLOv3 model: the ``yolo_body`` network plus a loss layer.

    Args:
        input_shape: ``(height, width)`` of the network input.
        anchor_ndarray: Anchor array with one ``(w, h)`` row per anchor.
        num_classes: Number of object categories.
        load_pretrained: Load ``weights_h5FilePath`` when that file exists.
        freeze_body: Freeze all but the last 7 layers after loading weights.
        weights_h5FilePath: Path of the weights file to start from.
            NOTE(review): the default path is an assumption - confirm it.

    Returns:
        A Keras ``Model`` whose inputs are the image plus the three ``y_true``
        tensors and whose output is the value of ``yolo_loss``.
    """
    K.clear_session() # get a new session
    image_input = Input(shape=(None, None, 3))
    height, width = input_shape
    num_anchors = len(anchor_ndarray)
    # One ground-truth tensor per output scale (strides 32, 16 and 8), each
    # using a third of the anchors.
    y_true = [Input(shape=(height // k,
                           width // k,
                           num_anchors // 3,
                           num_classes + 5)) for k in [32, 16, 8]]
    model_body = yolo_body(image_input, num_anchors//3, num_classes)
    print('Create YOLOv3 model with {} anchors and {} classes.'.format(num_anchors, num_classes))

    if load_pretrained and os.path.exists(weights_h5FilePath):
        model_body.load_weights(weights_h5FilePath, by_name=True, skip_mismatch=True)
        print('Load weights from this path: {}.'.format(weights_h5FilePath))
        if freeze_body:
            num = len(model_body.layers)-7
            for i in range(num):
                model_body.layers[i].trainable = False
            print('Freeze the first {} layers of total {} layers.'.format(num, len(model_body.layers)))

    # The loss is computed inside the graph by a Lambda layer, so the model's
    # output is the loss value itself.
    model_loss = Lambda(yolo_loss,
                        name='yolo_loss',
                        arguments={'anchors': anchor_ndarray,
                                   'num_classes': num_classes,
                                   'ignore_thresh': 0.5})(
                                        [*model_body.output, *y_true])
    model = Model([model_body.input, *y_true], model_loss)
    return model

# 调用此方法时,模型开始训练
def train(model,
          annotationFilePath,
          input_shape,
          anchor_ndarray,
          num_classes,
          logDirPath='saved_model',
          epochs=1000):
    """Train ``model`` on the annotation file and save the weights afterwards.

    Args:
        model: Model returned by ``create_model``.
        annotationFilePath: Text file with one annotated image per line.
        input_shape: ``(height, width)`` of the network input.
        anchor_ndarray: Anchor array with one ``(w, h)`` row per anchor.
        num_classes: Number of object categories.
        logDirPath: Folder the trained weights are saved into; created when missing.
            NOTE(review): the default folder is an assumption - confirm it.
        epochs: Number of training epochs.
            NOTE(review): the default value is an assumption - confirm it.
    """
    # The model output already is the loss, so the Keras loss just passes
    # the prediction through. The key names the layer 'yolo_loss'.
    model.compile(optimizer='adam',
                  loss={'yolo_loss': lambda y_true, y_pred: y_pred})
    # Split into training and validation samples.
    batch_size = 8
    val_split = 0.05
    with open(annotationFilePath) as file:
        lines = file.readlines()
    num_val = int(len(lines)*val_split)
    num_train = len(lines) - num_val
    print('Train on {} samples, val on {} samples, with batch size {}.'.format(num_train, num_val, batch_size))
    # Train on batches produced by the data generator.
    model.fit_generator(
        data_generator(lines[:num_train], batch_size, input_shape, anchor_ndarray, num_classes),
        steps_per_epoch=max(1, num_train // batch_size),
        validation_data=data_generator(lines[num_train:], batch_size, input_shape, anchor_ndarray, num_classes),
        validation_steps=max(1, num_val // batch_size),
        epochs=epochs,
        initial_epoch=0)
    # Save the weights once training has finished.
    if not os.path.isdir(logDirPath):
        os.makedirs(logDirPath)
    model_savedPath = os.path.join(logDirPath, 'trained_weights.h5')
    model.save_weights(model_savedPath)

# 图像数据生成器
def data_generator(annotation_lines, batch_size, input_shape, anchors, num_classes):
    """Yield training batches forever, cycling through ``annotation_lines``.

    Args:
        annotation_lines: One annotation line per image.
        batch_size: Number of samples per yielded batch.
        input_shape: ``(height, width)`` of the network input.
        anchors: Anchor array passed on to ``preprocess_true_boxes``.
        num_classes: Number of object categories.

    Yields:
        ``([image_data, *y_true], zeros)`` where ``zeros`` has one entry per
        sample; it is a placeholder target, as the model outputs the loss itself.
    """
    n = len(annotation_lines)
    i = 0
    while True:
        image_data = []
        box_data = []
        for b in range(batch_size):
            # Wrap around once the end of the list is reached.
            i %= n
            image, box = get_random_data(annotation_lines[i], input_shape, random=True)
            # Without these appends every batch would be empty.
            image_data.append(image)
            box_data.append(box)
            i += 1
        image_data = np.array(image_data)
        box_data = np.array(box_data)
        y_true = preprocess_true_boxes(box_data, input_shape, anchors, num_classes)
        yield [image_data, *y_true], np.zeros(batch_size)

# 解析运行代码文件时传入的参数
import argparse
def parse_args():
    """Parse the command line of the training script.

    Returns:
        The argparse namespace with ``width``, ``height``,
        ``class_txtFilePath`` and ``anchor_txtFilePath``.
    """
    arg_parser = argparse.ArgumentParser()
    # Both input dimensions are integers with the same default.
    for short_flag, long_flag in (('-w', '--width'), ('-he', '--height')):
        arg_parser.add_argument(short_flag, long_flag, type=int, default=416)
    path_options = (
        ('-c', '--class_txtFilePath', '../resources/category_list.txt'),
        ('-a', '--anchor_txtFilePath', './model_data/yolo_anchors.txt'),
    )
    for short_flag, long_flag, default_path in path_options:
        arg_parser.add_argument(short_flag, long_flag, type=str, default=default_path)
    return arg_parser.parse_args()
# 主函数
if __name__ == '__main__':
    # Read the input size and the category / anchor files from the command line.
    argument_namespace = parse_args()
    class_txtFilePath = argument_namespace.class_txtFilePath
    anchor_txtFilePath = argument_namespace.anchor_txtFilePath
    category_list = get_categoryList(class_txtFilePath)
    anchor_ndarray = get_anchorNdarray(anchor_txtFilePath)
    width = argument_namespace.width
    height = argument_namespace.height
    # create_model unpacks input_shape as (height, width); both must be multiples of 32.
    input_shape = (height, width)
    model = create_model(input_shape, anchor_ndarray, len(category_list))
    annotationFilePath = 'dataset_train.txt'
    train(model, annotationFilePath, input_shape, anchor_ndarray, len(category_list))


提取码:a0ld , fish_weights.zip解压后,将文件trained_weights.h5放到文件夹saved_model中。

3.1 单张图片

文件夹keras-yolo3-master中打开终端Terminal运行命令jupyter notebook,打开代码文件_07_yolo_test.ipynb如下:第1个代码块加载YOLOv3模型;第2个代码块加载测试集文本文件dataset_test.txt,并取出其中的图片路径赋值给变量jpgFilePath_list;第3个代码块是根据图片路径打开图片后,调用YOLO对象的detect_image方法对图片做目标检测。

# Cell 1: load the YOLOv3 model. The keyword must match the key
# "weights_h5FilePath" in YoloModel.defaults; any other name is stored but never read.
from _06_yolo import YoloModel
yolo_model = YoloModel(weights_h5FilePath='saved_model/trained_weights.h5')
# Cell 2: the first token of every line in dataset_test.txt is an image path.
with open('dataset_test.txt') as file:
    line_list = file.readlines()
jpgFilePath_list = [k.split()[0] for k in line_list]
# Cell 3: open the first test image and run object detection on it.
from PIL import Image
jpgFilePath = jpgFilePath_list[0]
image = Image.open(jpgFilePath)
yolo_model.detect_image(image)


3.2 视频

将图片合成为1部视频:文件夹keras-YOLOv3中打开Terminal,运行命令sudo apt-get install ffmpeg安装软件ffmpeg。继续在此Terminal中运行命令ffmpeg -start_number 1 -r 1 -i images_416x416/%03d.jpg -vcodec mpeg4 keras-yolo3-master/1.mp4,请读者确保当前Terminal所在目录中有文件夹images_416x416
继续在此Terminal中运行命令pip install opencv-python安装opencv-python库。cd keras-yolo3-master,在此Terminal中运行命令python yolo_video.py --input 1.mp4 --output fish_output.mp4,表示对视频文件1.mp4做目标检测,并将检测结果保存为视频文件fish_output.mp4。YOLOv3模型速度很快,本案例中检测1张图片只需要0.05秒。如果不人为干预,完成1帧图片的目标检测后立即开始下1帧,速度过快,人眼看不清楚。本文作者修改了代码文件_06_yolo.py的第183行,使完成1帧的目标检测后停止0.5秒,这样视频的展示效果能够易于人眼接受。_06_yolo.py代码如下:

# -*- coding: utf-8 -*-
# 导入常用的库
import os
import time
import numpy as np
# 导入keras库
from keras import backend as K
from keras.layers import Input
# 导入yolo3文件夹中model.py、utils.py这2个代码文件中的方法
from yolo3.model import yolo_eval, yolo_body
from yolo3.utils import letterbox_image
# 导入PIL画图库
from PIL import Image, ImageFont, ImageDraw

# 根据种类的数量,生成每个种类对应的颜色,颜色变量color为rgb这3个数值组成的元组
import colorsys
def get_colorList(category_quantity):
    """Return one distinct RGB colour per category.

    Args:
        category_quantity: Number of colours to generate.

    Returns:
        A list of ``(r, g, b)`` tuples with integer channels; the hues are
        spread evenly over the colour wheel at full saturation and value.
    """
    hsv_list = [(i / category_quantity, 1, 1) for i in range(category_quantity)]
    colorFloat_list = [colorsys.hsv_to_rgb(*k) for k in hsv_list]
    # hsv_to_rgb returns floats in [0, 1]; scale them to 0..255.
    color_list = [tuple([int(x * 255) for x in k]) for k in colorFloat_list]
    return color_list

# 定义类YoloModel
class YoloModel(object):
    """YOLOv3 detector that loads trained weights and draws detections on PIL images."""

    # Default settings; __init__ copies every key onto the instance and lets
    # keyword arguments override them.
    defaults = {
        "weights_h5FilePath": '../resources/trained_weights.h5',
        "anchor_txtFilePath": 'model_data/yolo_anchors.txt',
        "category_txtFilePath": '../resources/category_list.txt',
        "score" : 0.3,
        "iou" : 0.35,
        "model_image_size" : (416, 416) #must be a multiple of 32
    }

    @classmethod
    def get_defaults(cls, n):
        """Return the default value of setting ``n``, or a message when ``n`` is unknown."""
        if n in cls.defaults:
            return cls.defaults[n]
        else:
            return 'Unrecognized attribute name "%s"' %n

    def __init__(self, **kwargs):
        """Apply defaults and overrides, read the text files and build the model."""
        self.__dict__.update(self.defaults) # set up default values
        self.__dict__.update(kwargs) # and update with user overrides
        self.category_list = self.get_categoryList()
        self.anchor_ndarray = self.get_anchorNdarray()
        self.session = K.get_session()
        self.boxes, self.scores, self.classes = self.generate()

    def get_categoryList(self):
        """Read the sorted category names from ``self.category_txtFilePath`` (one per line)."""
        with open(self.category_txtFilePath, 'r', encoding='utf8') as file:
            fileContent = file.read()
        line_list = [k.strip() for k in fileContent.split('\n') if k.strip()!='']
        category_list= sorted(line_list, reverse=False)
        return category_list

    def get_anchorNdarray(self):
        """Parse the comma-separated numbers of ``self.anchor_txtFilePath`` into a (-1, 2) array."""
        with open(self.anchor_txtFilePath, 'r', encoding='utf8') as file:
            number_list = [float(k) for k in file.read().split(',')]
        anchor_ndarray = np.array(number_list).reshape(-1, 2)
        return anchor_ndarray

    def generate(self):
        """Build the network, load the weights and return the output tensors.

        Returns:
            The ``boxes``, ``scores`` and ``classes`` tensors produced by ``yolo_eval``.
        """
        # Only the weights were saved after training, so the network has to
        # be built first and the weights loaded into it.
        num_anchors = len(self.anchor_ndarray)
        num_classes = len(self.category_list)
        self.yolo_model = yolo_body(Input(shape=(None, None, 3)),
                                    num_anchors//3,
                                    num_classes)
        self.yolo_model.load_weights(self.weights_h5FilePath)
        # One box colour per category.
        category_quantity = len(self.category_list)
        self.color_list = get_colorList(category_quantity)
        # Detection outputs: boxes, scores and classes.
        self.input_image_size = K.placeholder(shape=(2, ))
        boxes, scores, classes = yolo_eval(self.yolo_model.output,
                                           self.anchor_ndarray,
                                           num_classes,
                                           self.input_image_size,
                                           score_threshold=self.score,
                                           iou_threshold=self.iou)
        return boxes, scores, classes

    def detect_image(self, image):
        """Run detection on a PIL image and draw the labelled boxes onto it.

        Args:
            image: The PIL image to analyse; it is modified in place.

        Returns:
            The same image object with boxes and labels drawn on it.
        """
        startTime = time.time()
        # Prepare the network input: letterbox to the model size, scale to 0..1.
        boxed_image = letterbox_image(image, tuple(reversed(self.model_image_size)))
        image_data = np.array(boxed_image).astype('float') / 255
        image_data = np.expand_dims(image_data, 0)  # Add batch dimension.
        # Run the network.
        out_boxes, out_scores, out_classes = self.session.run(
            [self.boxes, self.scores, self.classes],
            feed_dict={
                self.yolo_model.input: image_data,
                self.input_image_size: [image.size[1], image.size[0]],
                K.learning_phase(): 0
            })
        # Font size and box thickness scale with the image size.
        font = ImageFont.truetype(font='font/FiraMono-Medium.otf',
             size=np.floor(2e-2 * image.size[1] + 0.5).astype('int32'))
        thickness = (image.size[0] + image.size[1]) // 300
        # Draw every detected box.
        for i, c in enumerate(out_classes):
            # Text shown above the box: category name and score.
            predicted_class = self.category_list[c]
            score = out_scores[i]
            label = '{} {:.2f}'.format(predicted_class, score)
            draw = ImageDraw.Draw(image)
            # NOTE(review): ImageDraw.textsize was removed in Pillow 10 - confirm the Pillow version in use.
            label_size = draw.textsize(label, font)
            box = out_boxes[i]
            top, left, bottom, right = box
            # Round to whole pixels and clip to the image.
            top = max(0, np.floor(top + 0.5).astype('int32'))
            left = max(0, np.floor(left + 0.5).astype('int32'))
            bottom = min(image.size[1], np.floor(bottom + 0.5).astype('int32'))
            right = min(image.size[0], np.floor(right + 0.5).astype('int32'))
            # Put the label above the box, or just inside it when there is no room above.
            if top - label_size[1] >= 0:
                text_region = np.array([left, top - label_size[1]])
            else:
                text_region = np.array([left, top + 1])
            # A box of thickness t is drawn as t nested one-pixel rectangles.
            for j in range(thickness):
                draw.rectangle([left + j, top + j, right - j, bottom - j],
                               outline=self.color_list[c])
            # Filled background for the label, then the label text itself.
            draw.rectangle([tuple(text_region), tuple(text_region + label_size)],
                           fill=self.color_list[c])
            draw.text(text_region, label, fill=(0, 0, 0), font=font)
            del draw
        # Print how long this image took.
        usedTime = time.time() - startTime
        print('检测这张图片用时%.2f秒' %(usedTime))
        return image

    def close_session(self):
        """Close the TensorFlow session held by this object."""
        self.session.close()
# 对视频进行检测
def detect_video(yolo, video_path, output_path=""):
    """Run the detector on every frame of a video and display the result.

    Each frame is read with OpenCV, flipped from BGR to RGB, passed through
    ``yolo.detect_image``, stamped with a frames-per-second counter and shown
    in a window named "result". Pressing ``q`` in that window stops early.

    Args:
        yolo: Detector object; only ``detect_image(pil_image)`` is called on
            it, and its return value is converted to a numpy array.
        video_path: Source handed to ``cv2.VideoCapture``.
        output_path: File the annotated video is written to. The default
            empty string disables writing.

    Raises:
        IOError: If the video source cannot be opened.
    """
    import cv2
    vid = cv2.VideoCapture(video_path)
    if not vid.isOpened():
        raise IOError("Couldn't open webcam or video")
    video_FourCC = int(vid.get(cv2.CAP_PROP_FOURCC))
    video_fps = vid.get(cv2.CAP_PROP_FPS)
    video_size = (int(vid.get(cv2.CAP_PROP_FRAME_WIDTH)),
                  int(vid.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    isOutput = output_path != ""
    out = None
    try:
        if isOutput:
            print("!!! TYPE:", type(output_path), type(video_FourCC), type(video_fps), type(video_size))
            print(video_FourCC, video_fps, video_size)
            out = cv2.VideoWriter(output_path, video_FourCC, video_fps, video_size)
        accum_time = 0
        curr_fps = 0
        fps = "FPS: ??"
        prev_time = time.time()
        cv2.namedWindow("result", cv2.WINDOW_NORMAL)
        cv2.resizeWindow('result', video_size[0], video_size[1])
        while True:
            return_value, frame = vid.read()
            # read() reports failure once the stream is exhausted; stop
            # instead of trying to index a missing frame.
            if not return_value:
                break
            # OpenCV delivers BGR; reverse the channel axis to get RGB for PIL.
            image = Image.fromarray(frame[..., ::-1])
            image = yolo.detect_image(image)
            # np.array makes a writable copy, which cv2.putText needs because
            # it draws in place.
            result = np.array(image)
            # Count frames and refresh the FPS label about once per second.
            curr_time = time.time()
            accum_time += curr_time - prev_time
            prev_time = curr_time
            curr_fps += 1
            if accum_time > 1:
                accum_time -= 1
                fps = "FPS: " + str(curr_fps)
                curr_fps = 0
            cv2.putText(result, text=fps, org=(3, 15), fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                        fontScale=0.50, color=(255, 0, 0), thickness=2)
            # Back to BGR for everything that goes through OpenCV.
            bgr_result = result[..., ::-1]
            cv2.imshow("result", bgr_result)
            if out is not None:
                out.write(bgr_result)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Release what this function opened, even if detection raised.
        vid.release()
        if out is not None:
            out.release()
        cv2.destroyAllWindows()

3.3 多张图片


# 导入YOLO类
from _06_yolo import YoloModel
# 导入常用的库
from PIL import Image
import cv2
import os
import time
import numpy as np

# 获取文件夹中的文件路径
def get_filePathList(dirPath, partOfFileName=''):
    """Return the paths of files directly inside ``dirPath`` that match a name filter.

    Args:
        dirPath: Directory to list. Sub-directories are not descended into
            and are never part of the result.
        partOfFileName: Substring a file name must contain to be kept. The
            default empty string keeps every file.

    Returns:
        A list of ``os.path.join(dirPath, name)`` strings, sorted by file
        name so the order is the same on every run. The list is empty when
        ``dirPath`` does not exist or cannot be read.
    """
    # os.walk yields nothing for a missing or unreadable directory; give
    # next() a default so StopIteration does not escape to the caller.
    _, _, all_fileName_list = next(os.walk(dirPath), (dirPath, [], []))
    return [
        os.path.join(dirPath, fileName)
        for fileName in sorted(all_fileName_list)
        if partOfFileName in fileName
    ]

# 对多张图片做检测,并保存为avi格式的视频文件
def detect_multi_images(weights_h5FilePath, imageFilePath_list, out_aviFilePath=None):
    """Run detection on a list of images, preview each result and optionally record them.

    Every annotated image is resized to 1000x618, shown for about half a
    second in an OpenCV window and, when ``out_aviFilePath`` is given,
    appended as one frame to an AVI file.

    Keys while the preview window has focus:
        space  - pause; the next key press resumes
        Esc, q - stop processing the remaining images

    Args:
        weights_h5FilePath: Weights file passed to ``YoloModel``.
        imageFilePath_list: Paths of the images to process, in display order.
        out_aviFilePath: Destination video file, or ``None`` to skip recording.
    """
    space_key, esc_key, quit_key = ord(' '), 27, ord('q')
    yolo_model = YoloModel(weights_h5FilePath=weights_h5FilePath)
    windowName = 'detect_multi_images_result'
    width = 1000
    height = 618
    display_size = (width, height)
    videoWriter = None
    try:
        cv2.namedWindow(windowName, cv2.WINDOW_NORMAL)
        cv2.resizeWindow(windowName, width, height)
        if out_aviFilePath is not None:
            fourcc = cv2.VideoWriter_fourcc('M', 'P', 'E', 'G')
            videoWriter = cv2.VideoWriter(out_aviFilePath, fourcc, 1.3, display_size)
        for imageFilePath in imageFilePath_list:
            image = Image.open(imageFilePath)
            out_image = yolo_model.detect_image(image)
            # Force three channels so the axis flip below really is RGB -> BGR.
            # Image.LANCZOS is the filter formerly reachable as
            # Image.ANTIALIAS, an alias that Pillow 10 removed.
            resized_image = out_image.convert('RGB').resize(display_size, Image.LANCZOS)
            bgr_frame = np.array(resized_image)[..., ::-1]
            cv2.imshow(windowName, bgr_frame)
            if videoWriter is not None:
                videoWriter.write(bgr_frame)
            # Mask to the low byte: some platforms set higher bits in the key code.
            pressKey = cv2.waitKey(500) & 0xFF
            if pressKey == space_key:
                # Paused: block until the next key press. That key is then
                # checked against the quit keys as well.
                pressKey = cv2.waitKey(0) & 0xFF
            if pressKey in (esc_key, quit_key):
                break
    finally:
        # Close the model, the video writer and the OpenCV windows on the way
        # out, even if detection raised.
        yolo_model.close_session()
        if videoWriter is not None:
            videoWriter.release()
        cv2.destroyAllWindows()
# 解析运行代码文件时传入的参数
import argparse  # dirPath, image_suffix,  weights_h5FilePath,  imageFilePath_list,  out_aviFilePath
def parse_args(argv=None):
    """Parse the command-line options of the multi-image detection script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read the process command line, which is
            the behaviour of calling ``parse_args()`` with no arguments.

    Returns:
        argparse.Namespace with the attributes ``dirPath``, ``image_suffix``
        and ``weights_h5FilePath``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--dirPath', type=str, help='directory path',
                        default='../resources/n01440764')
    parser.add_argument('--image_suffix', type=str, help='image file suffix',
                        default='.JPEG')
    parser.add_argument('-w', '--weights_h5FilePath', type=str, help='weights .h5 file path',
                        default='../resources/trained_weights.h5')
    return parser.parse_args(argv)

# 主函数
if __name__ == '__main__':
    # Script entry point: read the CLI options, collect the matching image
    # files, then run detection and record the result video.
    cli_options = parse_args()
    matched_image_paths = get_filePathList(cli_options.dirPath, cli_options.image_suffix)
    detect_multi_images(
        cli_options.weights_h5FilePath,
        matched_image_paths,
        '../resources/fish_output_2.avi',
    )

