街景字符编码识别-Task02-数据读取与数据扩增

学习目标

  • 学习Python和Pytorch中图像读取

  • 学会扩增方法和使用Pytorch读取赛题数据

以下测试平台及库的版本号依次为: ubuntu 18.04,python 3.6,jupyter 1.0.0,notebook 6.0.3,Pillow  7.1.2,pytorch1.3.1,matplotlib 3.2.1

1.Python的PIL库读取图片

Python有一个强大方便的图像处理库叫PIL(Python Image Library),只支持到Python 2.7。而Pillow是PIL的一个派生分支,如今已经发展成为比PIL本身更具活力的图像处理库。因此再很多程序中,安装时的命令是pip install Pillow,使用时却是需要from  PIL import Image,记住这个流程就好。

在jupyter notebook中,使用PIL读取图片,并对图片进行RGB三通道颜色显示,一通道灰度显示,并对原图片尺寸长宽进行resize。

from  PIL import Image
import numpy as np

path = './input/train/000000.png'#图片路径

im_old=Image.open(path).convert('RGB')    #RGB三通道颜色
img_ndarray_old=np.asarray(im_old,dtype='float')/255   #像素从0-255进行归一化

im_old_grey=Image.open(path).convert('L')  #R灰度单通道颜色
img_ndarray_grey=np.asarray(im_old_grey,dtype='float')/255

imagesize=200
im_new=im.resize((imagesize,imagesize)) #对图片长宽resize,时尺寸为(100,100)
img_ndarray_new=np.asarray(im_old,dtype='float')/255

(1)RGB三通道颜色,图片尺寸为(350, 741, 3)

img_ndarray_old.shape

运行输出:(350, 741, 3) 

长741,宽为350,颜色是rgb的3通道

im_old

运行输出: 

 街景字符编码识别-Task02-数据读取与数据扩增_第1张图片

 (2)灰度图,颜色单通道,图片尺寸为(350, 741)

img_ndarray_grey.shape

运行输出:(350, 741)

im_old_grey

 运行输出: 

街景字符编码识别-Task02-数据读取与数据扩增_第2张图片

 (3)对图片长宽resize,保留三通道,尺寸为(200,200,3)

img_ndarray_new.shape
运行输出:(200, 200, 3)
im_new

 运行输出:

街景字符编码识别-Task02-数据读取与数据扩增_第3张图片

2.Pytorch中图像读取

baseline中的代码

(1)定义好读取图像的继承已有类Dataset的类 SVHNDataset

class SVHNDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label 
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None

    def __getitem__(self, index):
        img = Image.open(self.img_path[index]).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        
        # 设置最长的字符长度为5个
        lbl = np.array(self.img_label[index], dtype=np.int)
        lbl = list(lbl)  + (5 - len(lbl)) * [10]
        return img, torch.from_numpy(np.array(lbl[:5]))

    def __len__(self):
        return len(self.img_path)

(2)定义好训练数据和验证数据的DataLoader

train_path = glob.glob('./input/train/*.png')
train_path.sort()
train_json = json.load(open('./input/train.json'))
train_label = [train_json[x]['label'] for x in train_json]
print(len(train_path), len(train_label))

train_loader = torch.utils.data.DataLoader(
    SVHNDataset(train_path, train_label,
                transforms.Compose([
                    transforms.Resize((64, 128)),
                    transforms.RandomCrop((60, 120)),
                    transforms.ColorJitter(0.3, 0.3, 0.2),
                    transforms.RandomRotation(5),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])), 
    batch_size=40, 
    shuffle=False, 
    num_workers=10,

)

val_path = glob.glob('./input/val/*.png')
val_path.sort()
val_json = json.load(open('./input/val.json'))
val_label = [val_json[x]['label'] for x in val_json]
print(len(val_path), len(val_label))

val_loader = torch.utils.data.DataLoader(
    SVHNDataset(val_path, val_label,
                transforms.Compose([
                    transforms.Resize((60, 120)),
                    # transforms.ColorJitter(0.3, 0.3, 0.2),
                    # transforms.RandomRotation(5),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])), 
    batch_size=40, 
    shuffle=False, 
    num_workers=10,
)

 首先,在步骤(2)中,glob.glob返回所有匹配的文件路径列表, json模块用来读取json文件,返回一个字典dict,而列表和字典中的信息我们可以很容易提取出来,例如图片的数量和其对应的标签。

其次,在步骤(1)中,__getitem__ :返回图像数据和对应标签,即return img, torch.from_numpy(np.array(lbl[:5]))

在pytorch深度学习中,原始图像读取使用PIL读出的数据格式,标签需要转换为pytorch框架自定义的数据格式,在pytorch中,需要转为torch.Tensor,pytorch提供了torch.Tensornumpy.ndarray转换为接口:

import torch
from matplotlib import pyplot as plt

#torch.from_numpy()   numpy.ndarray转为torch.Tensor
atensor = torch.from_numpy(img_ndarray_new)

#atensor.numpy()  获取atensor对象的numpy格式数据
img= atensor.numpy()

#显示图片
plt.figure()
plt.title('image')
plt.imshow(img)
plt.show()

 运行输出:

街景字符编码识别-Task02-数据读取与数据扩增_第4张图片

最后,我们看到baseline的步骤2中有个torch.utils.data.DataLoader,DataLoader获取数据集需要一定的标准格式,
torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    sampler=None,
    batch_sampler=None,
    num_workers=0,
    collate_fn=None,
    pin_memory=False,
    drop_last=False,
    timeout=0,
    worker_init_fn=None,
    multiprocessing_context=None,
)

它的第一个参数需要传入Dataset的实例,其他的参数解释可参看源码,在安装路径中查看类Dataset的源码如下

class Dataset(object):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """
    def __getitem__(self, index):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py

 我们要重写Dataset类的方法,具体的在baseline中,我们定义一个SVHNDataset类,继承了Dataset类,然后重写类的方法__getitem__ :返回图像数据和对应标签 ,方法__len__:返回数据的长度,此处是图片或标签的个数,__add__可以先忽略。
另外的transform模块是专门负责图像预处理、实现图像增强的模块,可完成诸如图片旋转,灰度调节,尺寸变换等等,在使用的过程中transform.Compose()与DataSet类结合起来使用。

至此,我们完成了python与pytorch读取图像及其标签数据的方法,下一篇让我们继续探究如何在Pytorch框架下构建CNN模型并完成训练。

参考

[1]天池竞赛:https://tianchi.aliyun.com/competition/entrance/531795/

[2]https://github.com/datawhalechina/team-learning/

你可能感兴趣的:(街景字符编码识别-Task02-数据读取与数据扩增)