UCF Action Recognition: Dataset and Preprocessing (PyTorch)

Dataset overview

Dataset homepage: https://www.crcv.ucf.edu/data/UCF101.php
Video data covering 101 action classes.

  • I downloaded UCF11 instead, which is smaller; the processed clips run at 27.97 FPS with a frame size of (240, 320).

Data processing

Video frames can be extracted so that the data can be handled with image-based methods.
We write a class to load and preprocess the video data.
The basic idea:

  • Load a video file and randomly sample 16 consecutive frames
  • Resize and crop the frames
  • Convert them to tensors and assemble them into batches

PyTorch data loading workflow

Reference: https://www.jianshu.com/p/6e22d21c84be
We create several transform classes:
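The snippets below assume the following imports are in scope. This is a minimal set matching the calls used in the code; the exact selection is my assumption.

import os
import random

import cv2
import numpy as np
import torch
from skimage import transform
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms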

  • Normalize the RGB values (e.g., subtract the per-channel mean)
class ClipNormal(object):
    def __init__(self, b=104, g=117, r=123):
        self.means = np.array((r, g, b))

    def __call__(self, sample):
        # __call__ makes the instance callable like a function
        video_con, video_label = sample['video_con'], sample['video_label']
        new_video_con = video_con - self.means
        return {'video_con': new_video_con, 'video_label': video_label}
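Each transform takes and returns a dict with the keys 'video_con' and 'video_label', so they can be chained. A quick sanity check on a dummy clip (the shapes and label here are made up for illustration):

# dummy clip: 16 frames of 240x320 BGR plus a fake label
dummy = {'video_con': np.random.randint(0, 256, (16, 240, 320, 3)),
         'video_label': 0}
normalized = ClipNormal()(dummy)
print(normalized['video_con'].shape)  # (16, 240, 320, 3); the means broadcast over the last axis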
  • Rescale the frames
class Rescale(object):
    """Rescale each frame in a sample to a given size.

    Args:
        output_size (tuple or int): if tuple, the output size; if int, the size of the
            smaller image edge, keeping the aspect ratio unchanged.
    """
    def __init__(self, output_size=(182, 242)):
        assert isinstance(output_size, (int, tuple)), "size must be int or tuple"
        self.output_size = output_size

    def __call__(self, sample):
        video_con, video_label = sample['video_con'], sample['video_label']

        h, w = video_con.shape[1], video_con.shape[2]
        if isinstance(self.output_size, int):
            if h < w:
                o_h, o_w = self.output_size, int(self.output_size * w / h)
            else:
                o_h, o_w = int(self.output_size * h / w), self.output_size
        else:
            o_h, o_w = self.output_size

        rescale_video = []
        for i in range(16):
            image = video_con[i, :, :, :]
            ## transform.resize is given uint8 input here; other dtypes gave wrong results
            img = transform.resize(image.astype("uint8"), (o_h, o_w))
            rescale_video.append(img)

        rescale_video = np.array(rescale_video)

        return {"video_con": rescale_video, "video_label": video_label}
  • Randomly crop a region
class RandCrop(object):
    """Randomly crop each frame in a sample.

    Args:
        output_size (tuple or int): if int, a square crop is made.
    """
    def __init__(self, output_size=(160, 160)):
        assert isinstance(output_size, (int, tuple)), "size must be int or tuple"
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            self.output_size = output_size

    def __call__(self, sample):
        video_con, video_label = sample['video_con'], sample['video_label']

        h, w = video_con.shape[1], video_con.shape[2]
        o_h, o_w = self.output_size

        # one random offset shared by all 16 frames
        top = np.random.randint(0, h - o_h)
        left = np.random.randint(0, w - o_w)

        crop_video = []
        for i in range(16):
            image = video_con[i, :, :, :]
            img = image[top:top + o_h, left:left + o_w, :]
            crop_video.append(img)

        crop_video = np.array(crop_video)

        return {"video_con": crop_video, "video_label": video_label}
  • Finally, convert to tensors
class ToTensor(object):
    def __call__(self, sample):
        video_con, video_label = sample['video_con'], sample['video_label']
        # numpy clip: frames x H x W x C
        # torch clip: frames x C x H x W
        video_con = video_con.transpose((0, 3, 1, 2))
        return {"video_con": torch.from_numpy(video_con), "video_label": torch.FloatTensor([video_label])}

With the transform classes in place, the next step is wrapping the dataset itself

class UCF11(Dataset):
    def __init__(self, dir_name, transform=None):
        self.dir_name = dir_name
        self.transform = transform
        self.video_class = {}
        self.base_dir = os.path.join(os.path.dirname(__file__), dir_name)
        # map each action class name to an integer label
        self.label_dir = {k: v for v, k in enumerate(os.listdir(self.base_dir))}
        for each in os.listdir(self.base_dir):
            for i_class in os.listdir(os.path.join(self.base_dir, each)):
                if i_class != "Annotation":
                    for i_video in os.listdir(os.path.join(self.base_dir, each, i_class)):
                        self.video_class[i_video] = each
        # __getitem__ works by index, so keep the video names in a list (or a DataFrame)
        self.video_list = list(self.video_class.keys())

    def __len__(self):
        return len(self.video_class)

    ## fetch 16 frames for one sample
    def __getitem__(self, video_idx):
        video_name = self.video_list[video_idx]
        # video_name[:-7] strips the trailing "_xx.mpg" to recover the group folder name
        video_path = os.path.join(self.base_dir, self.video_class[video_name], video_name[:-7], video_name)
        video_label = self.video_class[video_name]
        video_con = self.get_video_i(video_path)
        sample = {'video_con': video_con, 'video_label': self.label_dir[video_label]}
        if self.transform:
            sample = self.transform(sample)
        return sample

    def get_video_i(self, video_path):
        vd = cv2.VideoCapture(video_path)
        # total number of frames in the video
        frames_num = int(vd.get(cv2.CAP_PROP_FRAME_COUNT))
        # random start index so that 16 consecutive frames still fit
        image_start = random.randint(1, frames_num - 17)
        success, frame = vd.read()
        i = 1
        video_cut = []
        while success:
            if i >= image_start and i < image_start + 16:
                video_cut.append(frame)
            elif i >= image_start + 16:
                break
            i += 1
            success, frame = vd.read()
        vd.release()
        return np.array(video_cut)
                            
if __name__ == '__main__':
    myUCF11 = UCF11("UCF11_updated_mpg", transform=transforms.Compose([ClipNormal(), Rescale(), RandCrop(), ToTensor()]))
    dataloader = DataLoader(myUCF11, batch_size=8, shuffle=True, num_workers=0)
    ## a Dataset must implement __len__ and __getitem__; the DataLoader indexes it from 0 to len(self) - 1
    for i_batch, sample_batched in enumerate(dataloader):
        print(i_batch, sample_batched['video_con'].size(), sample_batched['video_label'].size())
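With the default transform sizes above, each iteration should print a video_con of size (8, 16, 3, 160, 160) and a video_label of size (8, 1), assuming every sampled video has at least 17 frames so that the 16-frame window always fits.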

Several points worth noting

  • Each transform implements __call__, which makes an instance directly callable like a function: myClass(para) is equivalent to myClass.__call__(para) (see the sketch after this list).
  • skimage.transform.resize is given uint8 input here; with other dtypes the call misbehaved and the values came out wrong. Note that resize also rescales uint8 input to floats in [0, 1], which is why the values become very small after the conversion.
  • __getitem__ is called when the object is indexed: myClass[0] is equivalent to myClass.__getitem__(0).
  • After wrapping the data as a Dataset, __len__ and __getitem__ must be implemented; __getitem__ is called with indices from 0 to len(self) - 1, so the video names have to be stored in a list or DataFrame before the DataLoader can be used.
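A minimal sketch of the two dunder methods, using a toy class made up purely for illustration (not part of the dataset code):

class Doubler(object):
    def __call__(self, x):
        return 2 * x
    def __getitem__(self, idx):
        return idx * 10

d = Doubler()
print(d(3), d.__call__(3))      # 6 6  -- calling the instance invokes __call__
print(d[2], d.__getitem__(2))   # 20 20 -- indexing the instance invokes __getitem__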

Code adapted from: https://www.jianshu.com/p/4ebf2a82017b
With the data pipeline in place, the next step is building the model.
