目录
一、tensorboard的使用
writer.add_scalar()
writer.add_image()
二、transforms的使用
三、常见的transforms
3.1 __call__()的作用
3.2 ToTensor() / Normalize() / Resize()
四、使用torchvision中自带的图像数据集
ps
取消 PyCharm 代码补全的大小写匹配（Settings → Editor → General → Code Completion，取消勾选 Match case）
- 关注输入和输出类型
- 多看官方文档
- 关注方法需要什么参数
- 不知道返回值是什么时：1) print()；2) print(type())；3) debug 断点调试
from torch.utils.tensorboard import SummaryWriter
# ImportError: TensorBoard logging requires TensorBoard with Python summary writer installed. This should be available in 1.14 or above.
# 解决方法：需要先安装 tensorboard（pip install tensorboard）
def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
"""Add scalar data to summary.
Args:
tag (string): Data identifier
scalar_value (float or string/blobname): Value to save
global_step (int): Global step value to record
walltime (float): Optional override default walltime (time.time())
with seconds after epoch of event
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from PIL import Image
# Demo: log one image and a scalar curve (y = 2x) to TensorBoard under ./logs.
writer = SummaryWriter("logs")
try:
    image_path = "data/train/ants_image/6240329_72c01e663e.jpg"
    # Use a context manager so the image file handle is released once the
    # pixel data has been copied into the numpy array.
    with Image.open(image_path) as img_PIL:
        img_array = np.array(img_PIL)
    print(type(img_array))
    print(img_array.shape)
    # A numpy array built from a PIL image is laid out height-first
    # (presumably (H, W, 3) for an RGB JPEG -- check the printed shape), so
    # dataformats='HWC' overrides add_image's default 'CHW'.
    writer.add_image("train", img_array, 1, dataformats='HWC')

    # y = 2x: add_scalar(tag, scalar_value -> y axis, global_step -> x axis)
    for i in range(100):
        writer.add_scalar("y=2x", 2 * i, i)
finally:
    # Always flush and close the event file, even if a step above fails.
    writer.close()
def add_image(self, tag, img_tensor, global_step=None, walltime=None, dataformats='CHW'):
"""Add image data to summary.
Note that this requires the ``pillow`` package.
Args:
## title
tag (string): Data identifier
## 图像的数据类型
img_tensor (torch.Tensor, numpy.array, or string/blobname): Image data
global_step (int): Global step value to record
walltime (float): Optional override default walltime (time.time())
seconds after epoch of event
Shape:
img_tensor: Default is :math:`(3, H, W)`. You can use ``torchvision.utils.make_grid()`` to
convert a batch of tensor into 3xHxW format or call ``add_images`` and let us do the job.
Tensor with :math:`(1, H, W)`, :math:`(H, W)`, :math:`(H, W, 3)` is also suitable as long as
corresponding ``dataformats`` argument is passed. e.g. CHW, HWC, HW.
通过 add_image 可以直观地看到训练过程中给 model 提供了哪些数据。
transforms主要是对图片进行一些变换。
from torchvision import transforms
from PIL import Image
# Convert a PIL image into a torch.Tensor with transforms.ToTensor().
# The path must point to an existing image file -- an empty string cannot be
# opened by Image.open.
img_path = "data/train/ants_image/6240329_72c01e663e.jpg"
img = Image.open(img_path)

# 1. How a transform is used: instantiate it first, then call the instance
#    like a function (the call goes through the object's __call__ method).
tensor_trans = transforms.ToTensor()
tensor_img = tensor_trans(img)
print(tensor_img)
- 类中定义了 __call__() 后，实例化得到的对象可以像函数一样直接调用：对象(参数)，不需要写成 对象.方法名(参数)；
class Person:
    """Toy class contrasting ``__call__`` with an ordinary method.

    Calling the instance directly (``person("x")``) runs ``__call__``;
    ``hello`` must be reached through attribute access (``person.hello("x")``).
    """

    def __call__(self, name):
        # Triggered by `instance(name)` -- no method name is written.
        message = "__call__" + "hello " + name
        print(message)

    def hello(self, name):
        # Regular method: has to be invoked as `instance.hello(name)`.
        message = "hello " + name
        print(message)


person = Person()
person("zhangsan")    # prints: __call__hello zhangsan
person.hello("lisi")  # prints: hello lisi
官方文档：https://pytorch.org/vision/stable/
使用torchvision中自带的图像数据集
import torchvision
from torch.utils.tensorboard import SummaryWriter
# Compose chains transforms; here it only holds ToTensor, so every sample
# image is returned as a torch.Tensor instead of a PIL image.
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

# train=True -> training split; data is stored under root, download=True
# fetches it when needed.
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
# train=False -> test split
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)

# Inspecting samples: each item is an (img, target) pair, and
# test_set.classes maps the integer target to its class name.
# print(test_set[0])
# print(test_set.classes)
#
# img, target = test_set[0]
# print(img)
# print(target)
# print(test_set.classes[target])
# NOTE: img.show() only works on a PIL image, i.e. for a dataset built
# without the ToTensor transform; with dataset_transform, img is a Tensor.
# img.show()

# Log the first 10 test images under the tag "test_set", using the loop
# index as the global step. ToTensor produces channel-first images, which
# matches add_image's default dataformats='CHW'.
writer = SummaryWriter("p10")
try:
    for i in range(10):
        img, target = test_set[i]
        writer.add_image("test_set", img, i)
finally:
    # Flush and close the event file even if loading a sample fails.
    writer.close()