Pytorch 自定义dataloader实现多通道按比例输入不同label数据

Dataloader

当我们使用 PyTorch 进行深度学习模型开发时,为了使用自己的数据集训练网络,往往需要先构建自己的数据集类(Dataset),再交给 DataLoader 装载。DataLoader 用于构建可迭代的数据装载器,可以使用 for img, labels in dataloader 逐批访问数据,从而提取数据用于训练与验证。

pytorch 循环这个 DataLoader 对象,将img, label加载到模型中进行训练。

自定义

自定义的数据集类(本文中的 My_data,继承自 torch.utils.data.Dataset)需要实现以下三个方法,缺一不可:

1.__init__(),主要用来定义数据的预处理
2.__getitem__方法,返回数据的img和label
3.__len__方法,返回数据个数

PyTorch 会通过本文自定义的数据集类,从 ch1、ch2、ch3 三个文件夹按固定比例提取 img 和 mask 中的数据作为模型输入,从而实现多通道按比例输入不同 label 数据。例如 batchSize 为 6 时,每个 batch 中第 1~2 个样本来自 ch1,第 3~4 个来自 ch2,第 5~6 个来自 ch3。
Pytorch 自定义dataloader实现多通道按比例输入不同label数据_第1张图片
Pytorch 自定义dataloader实现多通道按比例输入不同label数据_第2张图片

代码

import os
import numpy as np
import random
import cv2
from torch.utils import data
import glob




class My_data(data.Dataset):
    """Dataset that draws samples from several folders in a fixed ratio.

    Expected layout under ``opt.data_root``::

        ch1/img/*.png    ch1/mask/*.png
        ch2/img/*.png    ch2/mask/*.png
        ch3/img/*.png    ch3/mask/*.png

    Images and masks are paired by their position in the sorted file lists.

    The ``index`` argument of ``__getitem__`` is ignored. Which folder a
    sample comes from depends only on how many samples were served before
    it: inside every run of ``opt.batchSize`` consecutive calls, the first
    ``SAMPLES_PER_CHANNEL`` calls read from ``ch1``, the next
    ``SAMPLES_PER_CHANNEL`` from ``ch2``, and so on, wrapping around to
    ``ch1`` again when the run is longer than one full cycle.

    ``opt`` must provide ``data_root``, ``batchSize`` and ``niter``.

    NOTE(review): the schedule lives in per-instance counters, so it
    presumably only lines up with real batches when the DataLoader's
    batch size equals ``opt.batchSize`` -- confirm against the training
    script.
    """

    # Sub-folders of ``opt.data_root``, in the order they are served.
    CHANNEL_NAMES = ('ch1', 'ch2', 'ch3')
    # Number of consecutive samples taken from one folder before moving on.
    SAMPLES_PER_CHANNEL = 2

    def __init__(self, opt):
        """Store the options object and index the files on disk."""
        super().__init__()
        self.opt = opt
        self.__counter = 0       # total number of samples served so far
        self.batch_counter = 0   # position of the current sample in its batch

        self.load_data()

    def load_data(self):
        """Collect image/mask paths per channel and reset the sampling state."""
        root_path = self.opt.data_root

        self._img_paths = []
        self._mask_paths = []
        self._sizes = []
        self._orders = []    # current random permutation of each channel
        self._cursors = []   # next position to read inside each permutation

        for name in self.CHANNEL_NAMES:
            imgs = glob.glob(os.path.join(root_path, name, 'img', '*.png'))
            masks = glob.glob(os.path.join(root_path, name, 'mask', '*.png'))
            # glob gives no ordering guarantee; check_error sorts both lists
            # in place, which is what pairs img[i] with mask[i].
            self.check_error(imgs, None, masks, name)

            self._img_paths.append(imgs)
            self._mask_paths.append(masks)
            self._sizes.append(len(imgs))
            self._orders.append(self.get_randon_index(len(imgs)))
            self._cursors.append(0)

    def check_error(self, img, fake, mask, path):
        """Sort the given path lists in place and report a length mismatch.

        Args:
            img: list of image paths (sorted in place).
            fake: optional extra list checked the same way, or ``None``.
            mask: list of mask paths (sorted in place).
            path: label printed with the warning.

        A mismatch is only printed, not raised.
        """
        img.sort()
        mask.sort()
        mismatch = len(img) != len(mask)
        if fake is not None:
            fake.sort()
            mismatch = mismatch or len(img) != len(fake)
        if mismatch:
            print("Data Erro!", path)

    def get_data(self, index, img_path, mask_path):
        """Read the image/mask pair at ``index`` with ``cv2.imread``.

        Both files are read with cv2's default flags.

        Raises:
            IOError: if cv2 cannot read either file (``imread`` returns
                ``None`` instead of raising).
        """
        img = cv2.imread(img_path[index])
        if img is None:
            raise IOError(f"Cannot read image file: {img_path[index]}")

        mask = cv2.imread(mask_path[index])
        if mask is None:
            raise IOError(f"Cannot read mask file: {mask_path[index]}")

        return img, mask

    def get_randon_index(self, length):
        """Return the integers ``0 .. length-1`` in random order."""
        return random.sample(range(length), length)

    def image_swap(self, img):
        """Move the last axis of a 3-D array to the front and cast to float32."""
        return np.transpose(img, (2, 0, 1)).astype(np.float32)

    def mask_swap(self, mask):
        """Move the last axis of a 3-D array to the front and cast to int64."""
        # np.long is not available on every NumPy release; np.int64 is.
        return np.transpose(mask, (2, 0, 1)).astype(np.int64)

    def _channel_for_position(self, position):
        """Map a position inside the batch to a channel number (0-based)."""
        # Wrap around so positions past one full cycle still pick a channel.
        return (position // self.SAMPLES_PER_CHANNEL) % len(self.CHANNEL_NAMES)

    def _next_index(self, channel):
        """Return the next file index of ``channel``; reshuffle once exhausted."""
        size = self._sizes[channel]
        if size == 0:
            raise IndexError(
                f"No images found for channel {self.CHANNEL_NAMES[channel]!r}"
            )

        file_index = self._orders[channel][self._cursors[channel]]
        self._cursors[channel] += 1
        # Reshuffle only after every entry of the permutation has been used.
        if self._cursors[channel] >= size:
            self._orders[channel] = self.get_randon_index(size)
            self._cursors[channel] = 0
        return file_index

    def __getitem__(self, index):
        """Return the next ``(image, mask)`` pair; ``index`` is ignored.

        Returns:
            A tuple ``(img, mask)`` of channel-first arrays, each divided
            by 255.
        """
        self.batch_counter = int(self.__counter % self.opt.batchSize)

        channel = self._channel_for_position(self.batch_counter)
        file_index = self._next_index(channel)
        img1, mask = self.get_data(
            file_index, self._img_paths[channel], self._mask_paths[channel]
        )

        img1 = self.image_swap(img1) / 255
        # True division: the returned mask is floating point, not integer.
        mask = self.mask_swap(mask) / 255

        self.__counter = self.__counter + 1
        return img1, mask

    def __len__(self):
        """Return ``opt.niter * opt.batchSize``."""
        return self.opt.niter * self.opt.batchSize


你可能感兴趣的:(pytorch,深度学习,python)