Transforms 常用的类

Transforms

用于对输入图片的格式进行调整

一些题外话

在transforms.py文件中,点击pycharm左侧靠下的structure按钮可以得到文件的结构(class等)
展开要查看的class可以找到该类包括的方法

Transforms 常用的类_第1张图片

其中很多类都有__call__()这个成员函数
他的作用是:使类的实例化对象可以像调用函数一样使用,形式是:对象名(参数)

# __call__的举例:
class Person(object):
    def __call__(self, friend):
        print('__call__'+'My friend is %s...' % friend)

    def hello(self, friend):
        print('hello '+'My friend is %s...' % friend)

person = Person()
# __call__
person('aaa')
# hello
person.hello('aaa')

开始之前

图片还是用PIL读入,把一些输出结果通过tensorboard查看

from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from PIL import Image

writer = SummaryWriter('logs')
img = Image.open('dataset/train/ants/0013035.jpg')
print(img)

ToTensor

将图片转化为tensor格式

trans_totensor = transforms.ToTensor()
img_tensor = trans_totensor(img)
writer.add_image('ToTensor', img_tensor)

Normalize

以mean,std在三通道上做归一化,一般对图片有固定的数值

trans_norm = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
img_norm = trans_norm(img_tensor)
writer.add_image('Normalize', img_norm, 0)
trans_norm1 = transforms.Normalize([1, 3, 5], [2, 4, 6])
img_norm1 = trans_norm1(img_tensor)
writer.add_image('Normalize', img_norm1, 1)

Resize

输入和输出都是PIL类型
如果参数是(h,w),则将图片大小转化成(h,w)的尺寸,如果参数是一个数,就将短边转化为这个长度,长边等比例变化。

这里resize传入(h,w)

print(img.size)
trans_resize = transforms.Resize((512, 512))
img_resize = trans_resize(img)
print(img_resize)	# 输出是PIL类型
img_resize = trans_totensor(img_resize)
writer.add_image('Resize', img_resize, 0)

Compose

将不同的transforms结合在一起,下一个transforms的输入是上一个的输出,因此需要确定输入与输出的数据类型

这里与resize传入一个整数的情况结合

trans_resize2 = transforms.Resize(512)
trans_compose = transforms.Compose([trans_resize2, trans_totensor])
img_resize2 = trans_compose(img)
writer.add_image('Resize', img_resize2, 1)

由于有__call__,正常写transforms的方式是:将所有用到的transforms使用类调用的方式按序放入列表中,然后放到Compose中,而不用再去把用到的transforms类一一实例化,如官方给的例子:

'''
Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>> ])
'''

最后

对tensorboard别忘了:

writer.close()

你可能感兴趣的:(pytorch,python,pytorch)