python代碼、深度學習數據預處理、常用code

000

111 深度學習數據預處理

python中"getitem"详解

今天在学习为深度学习数据预处理的时候用了一下“getitem"方法,发现还挺好用,下面详细解释一下。
getitem(self,key):

把类中的属性定义为序列,可以使用__getitem__()函数输出序列属性中的某个元素,这个方法返回与指定键想关联的值。对序列来说,键应该是0~n-1的整数,其中n为序列的长度。对映射来说,键可以是任何类型。

如果在类中定义了__getitem__()方法,那么它的实例对象(假设为P)就可以以P[key]形式取值,当实例对象做P[key]运算时,就会调用类中的__getitem__()方法。当对类的属性进行下标的操作时,首先会被__getitem__() 拦截,从而执行在__getitem__()方法中设定的操作,如赋值,修改内容,删除内容等。

	
class Taget:
    def __init__(self,id):
        self.id=id
  
    def __getitem__(self, item):
        print('这个方法被调用')
        return self.id
  
a=Taget('This is id')
print(a.id)
print(a['python'])
 
>输出:
'这个方法被调用'
'This is id'

在线数据增强

1:数据增强利用随机数
2:增加epoch数

python代碼、深度學習數據預處理、常用code_第1张图片

  if self.flag == 'train':
            # data augment
            _scale = 1.0 # not by _scale, and hardcode by height, width = (288, 384)
            # _scale = np.random.uniform(0.75, 1.25)
            scale = np.int(self.height * _scale)
            angle = (np.random.rand()-0.5)*30 
            rgb = TF.rotate(rgb, angle) # input (256,256)  output (256,341)
            dep = TF.rotate(dep, angle) # input (192,256)  output (192,256)
            hflip = np.random.uniform(0.0, 1.0)
            vflip = np.random.uniform(0.0, 1.0)
            if hflip > 0.5:
                rgb = TF.hflip(rgb)
                dep = TF.hflip(dep)
            if vflip > 0.5:
                rgb = TF.vflip(rgb)
                dep = TF.vflip(dep)
            t_rgb = T.Compose([
                T.Resize(scale),
                T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                T.ToTensor(),
                T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ])
            t_dep = T.Compose([
                # T.Resize(scale, Image.NEAREST),
                T.Resize(scale, T.InterpolationMode.NEAREST),
                self.ToNumpy(),
                T.ToTensor()
            ])
            if self.save_png:
               print('scale:',scale)
               cv2.imwrite('before_rgb.png',np.asarray(rgb)/16*255) # w*h= 640*480
               cv2.imwrite('before_dep.png',np.asarray(dep)/16*255) # w*h= 640*480
               print('rgb.shape:',np.asarray(rgb).shape)
               print('dep.shape:',np.asarray(dep).shape)
               #print('rotate size == raw')
            rgb = t_rgb(rgb)
            dep = t_dep(dep)
            if self.save_png:
               cv2.imwrite('rgb.png',rgb[0,:,:].numpy()/16*255) # w*h= 341*256
               cv2.imwrite('dep.png',dep[0,:,:].numpy()/16*255) # w*h= 341*256
               print('rgb[0,:,:].numpy().shape:',rgb[0,:,:].numpy().shape)
               print('dep[0,:,:].numpy().shape:',dep[0,:,:].numpy().shape)
               self.save_png = False

            _, height_nocrop, width_nocrop = rgb.shape

在训练模型时,每一个epoch都会进行数据增强,重新调用transforms.Compose()操作,使得训练数据变换,而因为内部操作的随机性,每一次最后输出的图像都可能会不一样,
因此每次epoch迭代后喂进网络的图像都是你增强后的图像,也都可能是不同的,所以也可以变相的认为训练数据增多了。

裁剪类transforms.CenterCrop - 中心裁剪transforms.RandomCrop - 随机裁剪transforms.RandomResizedCrop - 随机长宽比裁剪transforms.FiveCrop - 上下左右中心裁剪transforms.TenCrop - 上下左右中心裁剪后旋转翻转和旋转transforms.RandomHorizontalFlip - 随机水平翻转transforms.RandomVerticalFlip - 随机垂直翻转transforms.RandomRotation - 随机旋转图像变换transforms.Pad - 填充transforms.ColorJitter - 亮度、对比度和饱和度transforms.Grayscale - 转灰度图transforms.RandomGrayscale - 随机转灰度图transforms.RandomAffine - 随机仿射变换transforms.LinearTransformation - 线性变换transforms.Lambda - 自定义变换transforms.Resizetransforms.Normalize - 标准化transforms的操作transforms.RandomChoice - 随机选择给定 transforms 中的一种transforms.RandomApply - 加上随机概率transforms.RandomOrder - 随机打乱 transforms 操作的顺序

222 nvidia apex install

RuntimeError: Cuda extensions are being compiled with a version of Cuda that does not match the version used to compile Pytorch binaries. Pytorch binaries were compiled with Cuda 9.0.176.

解决方式:其错误意思就是cuda和pytorch的版本不对应,但是通过搜索也发现可以不带 --global --option 也能用
于是,修改第三行命令为:

//gitlab 
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ 

// 改为
pip install -v --no-cache-dir ./
//最后的最后,希望大家都能成功安装,冲冲冲!

333 处理损坏图片

判断图片存在,图片是否损坏及处理

    def _load_png(self, index):
        if self.flag != 'test':
            input, target = self.pairs[index]
            #print('!= test input_:',input)
            #print('!= test target_:',target)            
            input = ops.join('tof_iphone', input)
            if input is None:
                print("input 图片文件不存在!")
            target = ops.join('tof_iphone', target)
            if target is None:
                print("target 图片文件损坏,加载资源失败!")
            print('input:',input)
            print('target:',target)           
            np_in = cv2.imread(input) # 只能判断文件名存在,不能判断是否损坏
            print('np_in:',type(np_in))
            if type(np_in) == type(np.array([1])):
               print("input 图片文件损坏!")
            np_in = cv2.cvtColor(np_in, cv2.COLOR_BGR2RGB)
            np_tar = cv2.imread(target, cv2.IMREAD_ANYDEPTH)
import cv2
import os
import numpy as np
fn_all = ['data/tof_iphone/','./train/','./test/']
for fn_a in fn_all:
  print('fn_a:',fn_a)
  for fn in os.walk(fn_a):
    if fn[-1] != []:
        for i in fn[-1]:
            img_f = fn[0] +'/'+ i
            #print('img_f',img_f)
            if type(cv2.imread(img_f)) != type(np.array([1])):
                print('remove img_f',img_f)
                #os.remove(img_f)

444 保存深度图float类型

if __name__ == '__main__':
    root_path = 
    csv_path = 
    png_path = 
    #depth   不用png jpg后缀保存,并且转换为.astype('float32')  则数据可用保存float类型
    pixel_data = read_iphone(csv_path)
    pixel_data = pixel_data.astype('float32')
    cv2.imwrite('00006_depth.exr', pixel_data)

    cv2.imwrite('00006_depth.png', pixel_data/16*255)
    #rgb
    dim = (img_w, img_h)
    rgb = cv2.imread(png_path)
    rgb = cv2.resize(rgb, dim,interpolation = cv2.INTER_LINEAR)
    cv2.imwrite('00006.png', rgb)
    #spot
    spot_mask, x_syn, y_syn = GenerateSpotMask(plt_flag=False)
    print('spot_mask:',spot_mask)
    print('spot_mask shape:',spot_mask.shape)
    dep_spot_mask = pixel_data * spot_mask
    cv2.imwrite('00006_spot.png', dep_spot_mask/16*255)
    dep_spot_mask = (dep_spot_mask).astype('float32')
    cv2.imwrite('00006_spot.exr', dep_spot_mask)
    print('dep_spot_mask/255*16:',dep_spot_mask)
    #test
    #depth   采用cv2.IMREAD_ANYDEPTH  可用读取float类型
    exr_depth = cv2.imread("/media/zyt/0DCC12A80DCC12A8/33_2022_test/read_iphone_depth/00006_depth.exr", cv2.IMREAD_ANYDEPTH)
    print('exr_depth:',exr_depth/16*255)
    cv2.imwrite('00006_exr_depth.png', exr_depth/16*255)
    #### depth2xyz
    #pc_xyz = depth2xyz(pixel_data, depth_cam_matrix)
    #print('pc_xyz',pc_xyz)

你可能感兴趣的:(环境配置方案,python,开发语言)