class ToTensor:
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript.

    A PIL Image or numpy.ndarray laid out as (H x W x C) with values in [0, 255] becomes a
    torch.FloatTensor laid out as (C x H x W) with values in [0.0, 1.0]. The rescaling happens
    when the PIL Image has one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) or when
    the numpy.ndarray has dtype = np.uint8.

    In the other cases, tensors are returned without scaling.

    .. note::
        Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
        transforming target image masks. See the `references`_ for implementing the transforms for image masks.

    .. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
    """

    def __init__(self) -> None:
        _log_api_usage_once(self)

    def __call__(self, pic):
        """Delegate the conversion to ``F.to_tensor``.

        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        converted = F.to_tensor(pic)
        return converted

    def __repr__(self) -> str:
        # The transform carries no parameters, so the repr is just the class name and "()".
        return type(self).__name__ + "()"
函数主要的功能就是:Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
主要就是通过F.to_tensor()来实现的(torchvision.transforms.functional as F)
我们再看看F.to_tensor()是如何实现功能的:
def to_tensor(pic):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    This function does not support torchscript.
    See :class:`~torchvision.transforms.ToTensor` for more details.

    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

    Returns:
        Tensor: Converted image. Byte tensors are cast to the default float dtype and
        divided by 255; tensors of any other dtype are returned unscaled.

    Raises:
        TypeError: If ``pic`` is neither a PIL Image nor an ndarray.
        ValueError: If ``pic`` is an ndarray that ``_is_numpy_image`` rejects.
    """
    if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
        _log_api_usage_once(to_tensor)

    if not F_pil._is_pil_image(pic) and not _is_numpy(pic):
        raise TypeError(f"pic should be PIL Image or ndarray. Got {type(pic)}")

    if _is_numpy(pic):
        if not _is_numpy_image(pic):
            raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")

    float_dtype = torch.get_default_dtype()

    if isinstance(pic, np.ndarray):
        # A 2-D array gets a trailing axis so that the three-axis transpose below applies.
        hwc = pic[:, :, None] if pic.ndim == 2 else pic
        chw = torch.from_numpy(hwc.transpose((2, 0, 1))).contiguous()
        # backward compatibility: only byte tensors are rescaled
        if not isinstance(chw, torch.ByteTensor):
            return chw
        return chw.to(dtype=float_dtype).div(255)

    if accimage is not None and isinstance(pic, accimage.Image):
        # accimage copies its pixels into a caller-provided float32 buffer.
        buffer = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(buffer)
        return torch.from_numpy(buffer).to(dtype=float_dtype)

    # handle PIL Image
    np_dtype = {"I": np.int32, "I;16": np.int16, "F": np.float32}.get(pic.mode, np.uint8)
    img = torch.from_numpy(np.array(pic, np_dtype, copy=True))
    if pic.mode == "1":
        img = 255 * img

    height, width = pic.size[1], pic.size[0]
    img = img.view(height, width, len(pic.getbands()))
    # put it from HWC to CHW format
    img = img.permute((2, 0, 1)).contiguous()

    if not isinstance(img, torch.ByteTensor):
        return img
    return img.to(dtype=float_dtype).div(255)
# Excerpt: the numpy.ndarray branch of to_tensor, quoted out of its function context.
if isinstance(pic, np.ndarray):
# handle numpy array
# A 2-D array gets a trailing axis of length 1 so the three-axis transpose below applies.
if pic.ndim == 2:
pic = pic[:, :, None]
# Move the last axis to the front (axes order 2, 0, 1) and make the tensor contiguous.
img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
# backward compatibility
# Byte tensors are cast to default_float_dtype and divided by 255; other dtypes are returned as-is.
if isinstance(img, torch.ByteTensor):
return img.to(dtype=default_float_dtype).div(255)
else:
return img
首先通过 torch.from_numpy(pic) 转化为 tensor 类型;如果得到的是 uint8 类型(torch.ByteTensor),再通过 img.to(dtype=default_float_dtype) 转换为默认的浮点类型(例如 torch.float32、torch.float64 等)并除以 255,否则直接返回。
# Excerpt: the accimage and PIL Image branches of to_tensor, quoted out of their function context.
# accimage branch: copy the pixels into a float32 buffer, then cast to default_float_dtype.
if accimage is not None and isinstance(pic, accimage.Image):
nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
pic.copyto(nppic)
return torch.from_numpy(nppic).to(dtype=default_float_dtype)
# PIL Image branch: modes "I", "I;16" and "F" map to a specific numpy dtype; every other mode uses uint8.
mode_to_nptype = {"I": np.int32, "I;16": np.int16, "F": np.float32}
img = torch.from_numpy(np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True))
# Mode "1" images are multiplied by 255 before the reshape.
if pic.mode == "1":
img = 255 * img
img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
# put it from HWC to CHW format
img = img.permute((2, 0, 1)).contiguous()
# Byte tensors are cast to default_float_dtype and divided by 255; other dtypes are returned as-is.
if isinstance(img, torch.ByteTensor):
return img.to(dtype=default_float_dtype).div(255)
else:
return img
总结:PIL.Image 和 np.ndarray 转化为 torch.Tensor 类型,
最后都是通过 torch.from_numpy() 转化的,其中 PIL.Image 需要先转换为 numpy 类型(np.array(pic))。
class PILToTensor:
    """Convert a ``PIL Image`` to a tensor of the same type. This transform does not support torchscript.

    A PIL Image laid out as (H x W x C) becomes a Tensor laid out as (C x H x W).
    """

    def __init__(self) -> None:
        _log_api_usage_once(self)

    def __call__(self, pic):
        """Delegate the conversion to ``F.pil_to_tensor``.

        .. note::

            A deep copy of the underlying array is performed.

        Args:
            pic (PIL Image): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        converted = F.pil_to_tensor(pic)
        return converted

    def __repr__(self) -> str:
        # The transform carries no parameters, so the repr is just the class name and "()".
        return type(self).__name__ + "()"
函数主要的功能就是:Converts a PIL Image **(H x W x C)** to a Tensor of shape (C x H x W)
主要就是通过F.pil_to_tensor()来实现的(torchvision.transforms.functional as F)
我们再看看F.pil_to_tensor()是如何实现功能的:
def pil_to_tensor(pic):
    """Convert a ``PIL Image`` to a tensor of the same type.

    This function does not support torchscript.
    See :class:`~torchvision.transforms.PILToTensor` for more details.

    .. note::

        A deep copy of the underlying array is performed.

    Args:
        pic (PIL Image): Image to be converted to tensor.

    Returns:
        Tensor: Converted image, permuted from HWC to CHW. No value rescaling is applied.

    Raises:
        TypeError: If ``pic`` is not a PIL Image.
    """
    if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
        _log_api_usage_once(pil_to_tensor)

    if not F_pil._is_pil_image(pic):
        raise TypeError(f"pic should be PIL Image. Got {type(pic)}")

    if accimage is not None and isinstance(pic, accimage.Image):
        # accimage format is always uint8 internally, so always return uint8 here
        buffer = np.zeros([pic.channels, pic.height, pic.width], dtype=np.uint8)
        pic.copyto(buffer)
        return torch.as_tensor(buffer)

    # handle PIL Image
    pixels = torch.as_tensor(np.array(pic, copy=True))
    height, width = pic.size[1], pic.size[0]
    hwc = pixels.view(height, width, len(pic.getbands()))
    # put it from HWC to CHW format
    return hwc.permute((2, 0, 1))
关键代码:
# Excerpt: the PIL Image branch of pil_to_tensor, quoted out of its function context.
# handle PIL Image
# np.array(..., copy=True) copies the image data before torch.as_tensor wraps it.
img = torch.as_tensor(np.array(pic, copy=True))
img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
# put it from HWC to CHW format
img = img.permute((2, 0, 1))
return img
这里先以深拷贝的方式把 PIL Image 复制为 np.array,然后再利用 torch.as_tensor 转化为 torch.Tensor 类型。
class ToPILImage:
    """Convert a tensor or an ndarray to PIL Image. This transform does not support torchscript.

    Accepts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape H x W x C and
    produces a PIL Image while preserving the value range.

    Args:
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
            If ``mode`` is ``None`` (default) there are some assumptions made about the input data:

            - If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
            - If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
            - If the input has 2 channels, the ``mode`` is assumed to be ``LA``.
            - If the input has 1 channel, the ``mode`` is determined by the data type (i.e ``int``, ``float``,
              ``short``).

    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
    """

    def __init__(self, mode=None):
        _log_api_usage_once(self)
        self.mode = mode

    def __call__(self, pic):
        """Delegate the conversion to ``F.to_pil_image`` using the stored mode.

        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.

        Returns:
            PIL Image: Image converted to PIL Image.
        """
        converted = F.to_pil_image(pic, self.mode)
        return converted

    def __repr__(self) -> str:
        name = self.__class__.__name__
        # The mode is only shown when one was explicitly configured.
        if self.mode is None:
            return f"{name}()"
        return f"{name}(mode={self.mode})"
函数功能:
把(Tensor 或者 np.ndarray)类型的数据转换为 PIL image。主要实现是利用 F.to_pil_image(pic,mode)。
def to_pil_image(pic, mode=None):
    """Convert a tensor or an ndarray to PIL Image. This function does not support torchscript.

    See :class:`~torchvision.transforms.ToPILImage` for more details.

    Args:
        pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. Tensors are read
            as C x H x W (or H x W), ndarrays as H x W x C (or H x W).
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
            When ``None``, the mode is inferred from the channel count and the dtype.

    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes

    Returns:
        PIL Image: Image converted to PIL Image.

    Raises:
        TypeError: If ``pic`` is neither a Tensor nor an ndarray, or if no mode can be
            inferred for the input dtype.
        ValueError: If ``pic`` is not 2- or 3-dimensional, has more than 4 channels, or
            ``mode`` is not permitted for the channel count / dtype of the input.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(to_pil_image)

    if not isinstance(pic, (torch.Tensor, np.ndarray)):
        raise TypeError(f"pic should be Tensor or ndarray. Got {type(pic)}.")

    elif isinstance(pic, torch.Tensor):
        if pic.ndimension() not in {2, 3}:
            raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndimension()} dimensions.")

        elif pic.ndimension() == 2:
            # if 2D image, add channel dimension (CHW)
            pic = pic.unsqueeze(0)

        # check number of channels
        if pic.shape[-3] > 4:
            raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-3]} channels.")

    elif isinstance(pic, np.ndarray):
        if pic.ndim not in {2, 3}:
            raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")

        elif pic.ndim == 2:
            # if 2D image, add channel dimension (HWC)
            pic = np.expand_dims(pic, 2)

        # check number of channels
        if pic.shape[-1] > 4:
            raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-1]} channels.")

    npimg = pic
    if isinstance(pic, torch.Tensor):
        # Floating-point tensors are scaled by 255 and cast to bytes unless mode "F" was requested.
        if pic.is_floating_point() and mode != "F":
            pic = pic.mul(255).byte()
        # CHW -> HWC, the layout used for the channel checks below.
        npimg = np.transpose(pic.cpu().numpy(), (1, 2, 0))

    if not isinstance(npimg, np.ndarray):
        # Fixed: the message was a plain string, so the type was never interpolated.
        raise TypeError(f"Input pic must be a torch.Tensor or NumPy ndarray, not {type(npimg)}")

    if npimg.shape[2] == 1:
        # Single channel: the mode follows from the dtype alone.
        expected_mode = None
        npimg = npimg[:, :, 0]
        if npimg.dtype == np.uint8:
            expected_mode = "L"
        elif npimg.dtype == np.int16:
            expected_mode = "I;16"
        elif npimg.dtype == np.int32:
            expected_mode = "I"
        elif npimg.dtype == np.float32:
            expected_mode = "F"
        if mode is not None and mode != expected_mode:
            # Fixed: report the dtype of the input rather than the ``np.dtype`` class itself.
            raise ValueError(
                f"Incorrect mode ({mode}) supplied for input type {npimg.dtype}. Should be {expected_mode}"
            )
        mode = expected_mode

    elif npimg.shape[2] == 2:
        permitted_2_channel_modes = ["LA"]
        if mode is not None and mode not in permitted_2_channel_modes:
            raise ValueError(f"Only modes {permitted_2_channel_modes} are supported for 2D inputs")

        if mode is None and npimg.dtype == np.uint8:
            mode = "LA"

    elif npimg.shape[2] == 4:
        permitted_4_channel_modes = ["RGBA", "CMYK", "RGBX"]
        if mode is not None and mode not in permitted_4_channel_modes:
            raise ValueError(f"Only modes {permitted_4_channel_modes} are supported for 4D inputs")

        if mode is None and npimg.dtype == np.uint8:
            mode = "RGBA"

    else:
        permitted_3_channel_modes = ["RGB", "YCbCr", "HSV"]
        if mode is not None and mode not in permitted_3_channel_modes:
            raise ValueError(f"Only modes {permitted_3_channel_modes} are supported for 3D inputs")

        if mode is None and npimg.dtype == np.uint8:
            mode = "RGB"

    # No mode was given and none could be inferred from the dtype.
    if mode is None:
        raise TypeError(f"Input type {npimg.dtype} is not supported")

    return Image.fromarray(npimg, mode=mode)
关键代码:
npimg = pic
------------------------------------------
if isinstance(pic, torch.Tensor):
if pic.is_floating_point() and mode != "F":
pic = pic.mul(255).byte()
npimg = np.transpose(pic.cpu().numpy(), (1, 2, 0))
-------------------------------------------------
Image.fromarray(npimg, mode=mode)
理解就是如果是Tensor类型就先转换为numpy.ndarray,然后再通过
Image.fromarray(npimg,mode=mode)来进行转换,具体细节分析可看源代码。