图像分类(1),数据预处理

本文介绍如何使用pytorh利用预训练模型进行图像分类,主要参考Transfer Learning Tutorial和

具体代码可以参考Image_classification

  1. 下载代码文件:git clone https://github.com/chenmozxh/pytorch_studying
  2. 下载数据集:wget https://download.pytorch.org/tutorial/hymenoptera_data.zip 
    这个数据集是imagenet的一个小子集,包含ants和bees两个分类
  3. 解压数据集:unzip hymenoptera_data.zip

    数据集结构为:文件夹hymenoptera_data下存在训练集路径train和测试集路径test,train和test下都有ants和bees两个文件夹,即相应的图像。

  4. 运行python3 example1.py就开始训练了,可以看出随着epoch的加深,loss越来越小,而准确率acc越来越高图像分类(1),数据预处理_第1张图片
  5. example1.py代码解析:
      数据导入,使用官方写好的torchvision.datasets.ImageFolder接口实现数据导入。这个函数只需要你提供图像所在文件夹data_dir/train和data_dir/test即可。这两个目录下分别为N个子文件夹,N为分类的类别数,每个文件夹下为这个类别的图像。这样,torchvision.datasets.ImageFloder就会返回一个列表,列表中每一个值都是一个tuple,每个tuple包含图像和标签信息
    def Data_loader(Data_Path):
        data_transforms = {
            'train': transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
             ]),
            'val': transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                #transforms.Resize(224),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ]),
        }
    
        data_dir = Data_Path
        image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
        dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4, shuffle=True, num_workers=4) for x in ['train', 'val']}
        class_names = image_datasets['train'].classes
     
        return dataloaders, image_datasets, class_names
    
    dataloaders, image_datasets, class_names = Data_loader('hymenoptera_data')
    print(image_datasets)
    for e in image_datasets:
        print(e)
        print(image_datasets[e])
        for index, k in enumerate(image_datasets[e]):
            print(type(k), len(k))
            print(index, k[0].size(), k[1])
    
    

    transform对图像进行预处理。torchvision.transform.Compose是用来管理所有的transforms操作的。RandomSizeCrop和RandomHorizontalFlip的输入是PIL Image,也就是用python的PIL Image库读进来图像内容。而Normalize的对象是Tensor,因此需要增加一个ToTensor()用来将图像生成成Tensor。另外,transforms.Scale(256)是resize操作,目前已经被Resize取代。
    ImageFolder只是返回list,list是不能作为模型输入,因此在pytorch中,用另外一个类来封装list,那就是torch.utils.data.DataLoader。这个类将list类型的输入数据,图像和标签分别封装成一个Tensor数据格式,让模型使用。
    另外一个非常重要的类是torch.utils.data.Dataset,这个类是一个抽象类,在pytorch中所有和数据相关的类都要继承这个类来实现,比如torchvision.datasets.ImageFolder和torch.utils.data.DataLoader这两个类。所以,如果数据不是按照上面的格式存储是,需要自定义一个类来读取数据,自定义的这个类必须继承自torch.utils.data.Dataset这个基类。代码如下:

    def default_loader(path):
        try:
            img = Image.open(path)
            return img.convert('RGB')
        except:
            print("Cannot read image: {}".format(path))
    
    
    class customData(Dataset):
        def __init__(self, img_path, txt_path, dataset = '', data_transforms=None, loader = default_loader):
            with open(txt_path) as input_file:
                lines = input_file.readlines()
                #self.img_name = [os.path.join(img_path, line.strip().split('\t')[0]) for line in lines]
                #self.img_label = [int(line.strip().split('\t')[-1]) for line in lines]
                self.img_name = [os.path.join(img_path, line.strip()[:-2]) for line in lines]
                self.img_label = [int(line.strip()[-1:]) for line in lines]
            self.data_transforms = data_transforms
            self.dataset = dataset
            self.loader = loader
    
        def __len__(self):
            return len(self.img_name)
    
        def __getitem__(self, item):
            img_name = self.img_name[item]
            label = self.img_label[item]
            img = self.loader(img_name)
    
            if self.data_transforms is not None:
                try:
                    img = self.data_transforms[self.dataset](img)
                except:
                    print("Cannot transform image: {}".format(img_name))
            return img, label
    
    def Data_loader():
        batch_size = 4
        data_transforms = {
            'train': transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ]),
            'val': transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ]),
        }
    
        image_datasets = {x: customData(img_path='hymenoptera_data_cp/',
                                        txt_path=(x + '.txt'),
                                        data_transforms=data_transforms,
                                        dataset=x) for x in ['train', 'val']}
    
        # wrap your data and label into Tensor
        dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                                     batch_size=batch_size,
                                                     shuffle=True) for x in ['train', 'val']}
    
        dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
    
        return image_datasets, dataloaders
    

     

你可能感兴趣的:(图像分类(1),数据预处理)