AlexNet猫狗大战遇到的问题及解决方法

#数据集的加工
import cv2
import os
import numpy as np
'''def rebuild(dir):
    for root,dirs,files in os.walk(dir):
        for file in files:
            filepath = os.path.join(root,file)
            try:
                image = cv2.imread(filepath)
                dim = (227,227)
                resized = cv2.resize(image,dim)
                path = "cat_and_dog/cat_and_dog/"+file
                cv2.imwrite(path,resized)
            except:
                print(filepath)
                os.remove(filepath)
        cv2.waitKey(0)
rebuild("cat_and_dog/train")'''
#在这里导入的是图片集的根目录,os对数据集所在的文件夹进行读取,之后的一个for循环重建了图片数据所在的路径,在图片被重构后重新写入给定的位置
#需要提醒的是:笔者在这个代码段中对数据的读写是在一个try区域中,因为在整个数据集中不可避免地包含和出现坏地图片,这里当程序出现异常时,最简单地办法
#就是跳过出现问题地图片继续执行下去,因此在except模块中使用了os.remove函数对图片进行删除
'''第二步:图片数据集转换为Tensorflow专用格式'''
def get_file(file_dir):
    images = []
    temp = []
    for root,sub_folders,files in os.walk(file_dir):
        #image directories 图像目录
        for name in files:
            images.append(os.path.join(root,name))
            #get 10 sub_folders获取10个子文件夹
        for name in sub_folders:
            temp.append(os.path.join(root,name))
       # print(files)
        #assign 10 labels based on the folder names
        #根据文件夹名称分配10个标签
    labels=[]
    for one_folder in temp:
        n_img = len(os.listdir(one_folder))
        letter =  one_folder.split('\\')[-1]
        print(temp)
        if letter == 'cat':
            labels = np.append(labels, n_img * [0])
        else:
            labels = np.append(labels, n_img[1])
        # shuffle洗牌
        temp = np.array([images, labels])
        temp = temp.transpose()
        np.random.shuffle(temp)
        image_list = list(temp[:, 0])
        labels_list = list(temp[:, 1])
        label_list = [int(float(i)) for i in labels_list]
        return image_list, label_list
#上面地代码段中,首先是对数据集文件的位置进行读取,之后根据文件夹名称的不同将处于不同文件夹中的图片标签设置为0或者1,如果由更多分类的话可以根据这个格式
#设置更多的标签类进行保存,而numpy对数组的调整重构了存储有对应文件位置和文件标签的矩阵,并将其返回
get_file("cat_and_dog/cat_and_dog/")

当运行时会出现IndexError: too many indices for array这个错误此时将第二部分程序换成一下程序即可

def get_files(file_dir):  
    cats = []
    label_cats = []
    dogs = []
    label_dogs = []
    for file in os.listdir(file_dir): 
        name = file.split(sep='.')
        if name[0] == 'cat':
            cats.append(file_dir + file)
            label_cats.append(0)
        else:
            dogs.append(file_dir + file)
            label_dogs.append(1)
    print('there are %d cats\nthere are %d dogs' % (len(cats), len(dogs)))

    image_list = np.hstack((cats, dogs))  # 将猫狗图片堆积起来
    label_list = np.hstack((label_cats, label_dogs))  # label也堆积起来

    temp = np.array([image_list, label_list])
    temp = temp.transpose()
    np.random.shuffle(temp) 

    image_list = list(temp[:, 0])
    label_list = list(temp[:, 1])
    label_list = [int(i) for i in label_list]

    return image_list, label_list

你可能感兴趣的:(AlexNet猫狗大战遇到的问题及解决方法)