学习 PyTorch（5）：常用的 transforms

常用的transforms

  • 1. ToTensor()
  • 2. Normalize()

[图 1：学习 PyTorch（5）——常用的 transforms]

1. ToTensor()

2. Normalize()

# 1. ToTensor：把 PIL 图片或 numpy ndarray 类型的数据转换为 tensor 类型数据
from cv2 import imread
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter

# Path of the sample bee image that both demos below read.
img_path = '/'.join(('hymenoptera_data', 'train', 'bees', '85112639_6e860b0469.jpg'))
# Single TensorBoard writer shared by both demos; event files go to ./logs.
writer = SummaryWriter(log_dir='logs')
def totensor(path=None, tag='totensor'):
    """Read an image with OpenCV, convert it with ``ToTensor`` and log it.

    ``transforms.ToTensor`` turns an ``H x W x C`` uint8 ndarray (or a PIL
    image) into a ``C x H x W`` float tensor scaled to ``[0, 1]``.

    Args:
        path: Image file to read. Defaults to the module-level ``img_path``.
        tag: TensorBoard tag the image is logged under.

    Returns:
        The converted ``C x H x W`` tensor, in RGB channel order.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``path``. ``imread`` signals
            failure by returning ``None`` rather than raising.
    """
    if path is None:
        path = img_path
    bgr = imread(path)
    if bgr is None:
        raise FileNotFoundError(f'cv2.imread could not read image: {path}')
    # OpenCV decodes colour images in BGR order, while TensorBoard treats the
    # channels as RGB. Reverse the channel axis so colours display correctly.
    # ``.copy()`` is required: ToTensor goes through ``torch.from_numpy``,
    # which rejects arrays with negative strides such as this reversed view.
    rgb = bgr[:, :, ::-1].copy()
    cv2_tensor = transforms.ToTensor()(rgb)

    writer.add_image(tag, cv2_tensor)
    return cv2_tensor


# 2. Normalize：用均值和标准差对张量图像逐通道做标准化（归一化）
def normalize(mean=(1, 2, 1), std=(1, 1, 6), tag='norm totensor', step=2,
              path=None):
    '''Standardize an image tensor per channel and log the result.

    output[channel] = (input[channel] - mean[channel]) / std[channel]

    Args:
        mean: Per-channel means, one per RGB channel.
        std: Per-channel standard deviations, one per RGB channel.
        tag: TensorBoard tag the normalized image is logged under.
        step: Global step passed to ``add_image``.
        path: Image file to read. Defaults to the module-level ``img_path``.

    Returns:
        The normalized ``C x H x W`` tensor.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``path``.

    NOTE: ``ToTensor`` scales pixels to ``[0, 1]``. With the default
    mean/std, every output value is therefore <= 0. ``add_image`` expects
    float pixels in ``[0, 1]`` and clips values outside that range, so the
    logged image renders black. Pass e.g. ``mean=(0.5,) * 3,
    std=(0.5,) * 3`` to get a visible result.
    '''
    if path is None:
        path = img_path
    bgr = imread(path)
    if bgr is None:
        raise FileNotFoundError(f'cv2.imread could not read image: {path}')
    # OpenCV yields BGR; reverse the channel axis to RGB so that mean[0]/std[0]
    # apply to red and TensorBoard shows true colours. ``.copy()`` avoids the
    # negative strides that ``torch.from_numpy`` (used by ToTensor) rejects.
    rgb = bgr[:, :, ::-1].copy()
    cv2_tensor = transforms.ToTensor()(rgb)
    print(cv2_tensor[0, 0, 0])  # first pixel of channel 0, before normalizing

    trans_norm = transforms.Normalize(mean, std)
    # Call the transform (an nn.Module) rather than ``.forward`` directly.
    norm_tensor = trans_norm(cv2_tensor)
    writer.add_image(tag, norm_tensor, step)
    print(norm_tensor[0, 0, 0])  # the same pixel after normalizing
    return norm_tensor


if __name__ == '__main__':
    # Always close the writer, even if a demo raises, so buffered events are
    # flushed to ./logs and the event file is released.
    try:
        totensor()
        normalize()
    finally:
        writer.close()

你可能感兴趣的：学习pytorch、python、学习、python、开发语言