pytorch——单张图片读取、dataloader等处理

        首先,无论是train还是test,无论是整个数据集还是单张图片,第一步都是先加载数据,此后才是做一些处理。这篇文章就是对pytorch中数据的读取和处理做一个简单的学习记录。

1 对于单张图片读取(一般用于简单检测)

        事实上,可以用opencv或者PIL(pillow)读取图片,然后再进行处理。

        但和torchvision比较相近的是用PIL读取,然后同样可以用transforms的一些函数。这里的图像预处理过程(比如说图片大小)等就可以跟训练中的validation保持一致。

from PIL import image
from torchvision import transforms

img = Image.open(img_path)
img = transforms.Resize(448)(img)  # 保持长宽比的resize方法
# img = transforms.Resize((448,448))(img)  # 直接resize成正方形的方法
img = transforms.ToTensor()(img)
img = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(img)

img_ = img.unsqueeze(0)  # 拓展维度, 拓展batch_size那一维
img_ = img_.to(device)

# 推理过程
output = net(img_)  # net是提前读取的模型
pred_index = int(torch.argmax(output_com, dim=1))

2 对于图片数据集batch化读取(用于训练or批量测试)

        2.1 torchvision自带的dataloader

        

你可能感兴趣的:(pytorch,深度学习,人工智能)