python中"getitem"详解
今天在学习为深度学习数据预处理的时候用了一下“getitem"方法,发现还挺好用,下面详细解释一下。
getitem(self,key):
把类中的属性定义为序列,可以使用__getitem__()函数输出序列属性中的某个元素,这个方法返回与指定键想关联的值。对序列来说,键应该是0~n-1的整数,其中n为序列的长度。对映射来说,键可以是任何类型。
如果在类中定义了__getitem__()方法,那么它的实例对象(假设为P)就可以以P[key]形式取值,当实例对象做P[key]运算时,就会调用类中的__getitem__()方法。当对类的属性进行下标的操作时,首先会被__getitem__() 拦截,从而执行在__getitem__()方法中设定的操作,如赋值,修改内容,删除内容等。
class Taget:
def __init__(self,id):
self.id=id
def __getitem__(self, item):
print('这个方法被调用')
return self.id
a=Taget('This is id')
print(a.id)
print(a['python'])
>输出:
'这个方法被调用'
'This is id'
在线数据增强
1:数据增强利用随机数
2:增加epoch数
if self.flag == 'train':
# data augment
_scale = 1.0 # not by _scale, and hardcode by height, width = (288, 384)
# _scale = np.random.uniform(0.75, 1.25)
scale = np.int(self.height * _scale)
angle = (np.random.rand()-0.5)*30
rgb = TF.rotate(rgb, angle) # input (256,256) output (256,341)
dep = TF.rotate(dep, angle) # input (192,256) output (192,256)
hflip = np.random.uniform(0.0, 1.0)
vflip = np.random.uniform(0.0, 1.0)
if hflip > 0.5:
rgb = TF.hflip(rgb)
dep = TF.hflip(dep)
if vflip > 0.5:
rgb = TF.vflip(rgb)
dep = TF.vflip(dep)
t_rgb = T.Compose([
T.Resize(scale),
T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
T.ToTensor(),
T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
t_dep = T.Compose([
# T.Resize(scale, Image.NEAREST),
T.Resize(scale, T.InterpolationMode.NEAREST),
self.ToNumpy(),
T.ToTensor()
])
if self.save_png:
print('scale:',scale)
cv2.imwrite('before_rgb.png',np.asarray(rgb)/16*255) # w*h= 640*480
cv2.imwrite('before_dep.png',np.asarray(dep)/16*255) # w*h= 640*480
print('rgb.shape:',np.asarray(rgb).shape)
print('dep.shape:',np.asarray(dep).shape)
#print('rotate size == raw')
rgb = t_rgb(rgb)
dep = t_dep(dep)
if self.save_png:
cv2.imwrite('rgb.png',rgb[0,:,:].numpy()/16*255) # w*h= 341*256
cv2.imwrite('dep.png',dep[0,:,:].numpy()/16*255) # w*h= 341*256
print('rgb[0,:,:].numpy().shape:',rgb[0,:,:].numpy().shape)
print('dep[0,:,:].numpy().shape:',dep[0,:,:].numpy().shape)
self.save_png = False
_, height_nocrop, width_nocrop = rgb.shape
在训练模型时,每一个epoch都会进行数据增强,重新调用transforms.Compose()操作,使得训练数据变换,而因为内部操作的随机性,每一次最后输出的图像都可能会不一样,
因此每次epoch迭代后喂进网络的图像都是你增强后的图像,也都可能是不同的,所以也可以变相的认为训练数据增多了。
裁剪类transforms.CenterCrop - 中心裁剪transforms.RandomCrop - 随机裁剪transforms.RandomResizedCrop - 随机长宽比裁剪transforms.FiveCrop - 上下左右中心裁剪transforms.TenCrop - 上下左右中心裁剪后旋转翻转和旋转transforms.RandomHorizontalFlip - 随机水平翻转transforms.RandomVerticalFlip - 随机垂直翻转transforms.RandomRotation - 随机旋转图像变换transforms.Pad - 填充transforms.ColorJitter - 亮度、对比度和饱和度transforms.Grayscale - 转灰度图transforms.RandomGrayscale - 随机转灰度图transforms.RandomAffine - 随机仿射变换transforms.LinearTransformation - 线性变换transforms.Lambda - 自定义变换transforms.Resizetransforms.Normalize - 标准化transforms的操作transforms.RandomChoice - 随机选择给定 transforms 中的一种transforms.RandomApply - 加上随机概率transforms.RandomOrder - 随机打乱 transforms 操作的顺序
RuntimeError: Cuda extensions are being compiled with a version of Cuda that does not match the version used to compile Pytorch binaries. Pytorch binaries were compiled with Cuda 9.0.176.
解决方式:其错误意思就是cuda和pytorch的版本不对应,但是通过搜索也发现可以不带 --global --option 也能用
于是,修改第三行命令为:
//gitlab
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
// 改为
pip install -v --no-cache-dir ./
//最后的最后,希望大家都能成功安装,冲冲冲!
判断图片存在,图片是否损坏及处理
def _load_png(self, index):
if self.flag != 'test':
input, target = self.pairs[index]
#print('!= test input_:',input)
#print('!= test target_:',target)
input = ops.join('tof_iphone', input)
if input is None:
print("input 图片文件不存在!")
target = ops.join('tof_iphone', target)
if target is None:
print("target 图片文件损坏,加载资源失败!")
print('input:',input)
print('target:',target)
np_in = cv2.imread(input) # 只能判断文件名存在,不能判断是否损坏
print('np_in:',type(np_in))
if type(np_in) == type(np.array([1])):
print("input 图片文件损坏!")
np_in = cv2.cvtColor(np_in, cv2.COLOR_BGR2RGB)
np_tar = cv2.imread(target, cv2.IMREAD_ANYDEPTH)
import cv2
import os
import numpy as np
fn_all = ['data/tof_iphone/','./train/','./test/']
for fn_a in fn_all:
print('fn_a:',fn_a)
for fn in os.walk(fn_a):
if fn[-1] != []:
for i in fn[-1]:
img_f = fn[0] +'/'+ i
#print('img_f',img_f)
if type(cv2.imread(img_f)) != type(np.array([1])):
print('remove img_f',img_f)
#os.remove(img_f)
if __name__ == '__main__':
root_path =
csv_path =
png_path =
#depth 不用png jpg后缀保存,并且转换为.astype('float32') 则数据可用保存float类型
pixel_data = read_iphone(csv_path)
pixel_data = pixel_data.astype('float32')
cv2.imwrite('00006_depth.exr', pixel_data)
cv2.imwrite('00006_depth.png', pixel_data/16*255)
#rgb
dim = (img_w, img_h)
rgb = cv2.imread(png_path)
rgb = cv2.resize(rgb, dim,interpolation = cv2.INTER_LINEAR)
cv2.imwrite('00006.png', rgb)
#spot
spot_mask, x_syn, y_syn = GenerateSpotMask(plt_flag=False)
print('spot_mask:',spot_mask)
print('spot_mask shape:',spot_mask.shape)
dep_spot_mask = pixel_data * spot_mask
cv2.imwrite('00006_spot.png', dep_spot_mask/16*255)
dep_spot_mask = (dep_spot_mask).astype('float32')
cv2.imwrite('00006_spot.exr', dep_spot_mask)
print('dep_spot_mask/255*16:',dep_spot_mask)
#test
#depth 采用cv2.IMREAD_ANYDEPTH 可用读取float类型
exr_depth = cv2.imread("/media/zyt/0DCC12A80DCC12A8/33_2022_test/read_iphone_depth/00006_depth.exr", cv2.IMREAD_ANYDEPTH)
print('exr_depth:',exr_depth/16*255)
cv2.imwrite('00006_exr_depth.png', exr_depth/16*255)
#### depth2xyz
#pc_xyz = depth2xyz(pixel_data, depth_cam_matrix)
#print('pc_xyz',pc_xyz)