深度学习,制作类似cifar10图像数据集

 

在学习卷积神经网络的时候,遇到了cifar10图像数据集,用着挺好,但不想局限于固定的几种图像的识别,所以就有了自己制作数据集来识别的想法。

 

一、cifar10数据集。

据原网站介绍,数据集为二进制。将cifar10解压后,得到data_batch_1等数据集,打开看一下:

import pickle

f = open('./data_batch_1','rb') #以二进制读模式打开

d = pickle.load(f)

print(d)

可知数据集为dict型,主要有datalabels等四种键值。

 

二、爬取图片

首先要感谢被爬网站的开放性和包容心,潭州教育坚持对爬虫技术的无私分享以及博主Jimmy

 

import requests

import urllib.parse

import threading

# 设置最大线程锁(与电脑配置和带宽有关)

thread_lock = threading.BoundedSemaphore(value=10)

def get_page(url):

    page = requests.get(url)

    page = page.content

    page = page.decode('utf-8')

    # 将 bytes 转换成 字符串

    return page

 

def pages_from_duitang(label):

    pages = []

    #找到图片链接规律

url = 'https://www.duitang.com/napi/blog/list/by_search/?

kw={}&start={}&limit=1000'

    #将中文转成url编码

    label = urllib.parse.quote(label)

    for index in range(0,3600,100):

        u = url.format(label, index)

        print(u)

        page = get_page(u)

        pages.append(page)

    return pages

def findall_in_page(page,startpart,endpart):

    all_strings = []

    end = 0

    while page.find(startpart,end) != -1:

         start = page.find(startpart, end) + len(startpart)

         end = page.find(endpart,start)

         string = page[start:end]

         all_strings.append(string)

    return all_strings

 

def pic_urls_from_pages(pages):

    pic_urls = []

    for page in pages:

        urls = findall_in_page(page, 'path":"', '"')

        pic_urls.extend(urls)

    return pic_urls

 

def download_pics(url,n):

    r = requests.get(url)

    path = 'pics/fish/' + str(n) + '.jpg'

    with open(path, 'wb') as f:

        f.write(r.content)

    #下载结束,解锁

    thread_lock.release()

 

def main(label):

    pages = pages_from_duitang(label)

    pic_urls = pic_urls_from_pages(pages)

    n = 0

    for url in pic_urls:

        n += 1

        print('正在下载第 {} 张图片'.format(n))

        #上锁

        thread_lock.acquire()

        t = threading.Thread(target=download_pics, args = (url, n))

        t.start()

 

main('鱼')

 

三、制作数据集

from PIL import Image

import numpy as np

import pickle,glob,os

 

arr = [[]]

#number of pictures

n = 1

for infile in glob.glob('D:/py/pics/trees/*.jpg'):

        file,ext = os.path.splitext(infile)#分离文件名和扩展名

        Img = Image.open(infile)

        print(Img.mode,file)#图片尺寸和文件名(用于调试过程中定位错误)

    

if Img.mode != 'RGB':#将所有非'RGB'通道图片转化为RGB

        Img = Img.convert('RGB')

    width = Img.size[0]

    height = Img.size[1]

 

    print('{} imagesize is:{} X {}'.format(n,width,height))

    n += 1

 

Img = Img.resize([32,32],Image.ANTIALIAS)

#抗锯齿的过滤属性,这些都是为了保证剪切图片的时候,最大降低失真度,这样出

#的图片体积就稍微大些了。

    r,g,b = Img.split()

    r_array = np.array(r).reshape([1024])

    g_array = np.array(g).reshape([1024])

    b_array = np.array(b).reshape([1024])

    merge_array = np.concatenate((r_array,g_array,b_array))

    if arr == [[]]:

        arr = [merge_array]

        continue

    #拼接

arr = np.concatenate((arr,[merge_array]),axis=0)

    #打乱顺序

arr = np.random.shuffle(arr)

#生成标签

labelset = np.zeros((arr.shape[0],))

labelset = np.reshape(labelset,[arr.shape[0],])

 

#字典分割出训练集和测试集

train_dic = {'data':arr[:2000],'labels':labelset[:2000]}

test_dic = {'data':arr[2000:],'labels':labelset[2000:]}

 

f = open('./data_batch_8','wb')#二进制写模式打开,如果不存在,直接生成

pickle.dump(train_dic,f,protocol=2)

#序列化操作

#由于阿里云平台用的是Python2.7版本,我的是3.6,所以要进行退档操作protocol=2

 

g = open('./test_batch_1','wb')

pickle.dump(test_dic,g,protocol=2)

 

 

四、训练和测试

由于本机硬件水平较低,采用阿里云平台进行测试,根据自己的数据集规模,调整平台提供的代码。经测试,精度达到76%。对于这个结果还是相当满意的,因为数据集中干扰太多,没有进行筛选。

 深度学习,制作类似cifar10图像数据集_第1张图片

 

五、问题

在制作数据集过程中,遇到两个问题:

1、r,g,b = img.split():(已解决)

valueError:too many values to unpackexpected 3

unpack的个数对不上,比如:a,b = tuple(1,2,3) 就会报出这个错误

通过Img.mode发现有的图片是1”、“L”、“P”和“RGBA模式,需要convert。

2、r,g,b = img.split():(待解决)

OSError:cannot identify image file:路径+格式

暂时理解为系统兼容性问题

 

你可能感兴趣的:(深度学习)