transforms.RandomResizedCrop:先按照设定的面积缩放比例(scale)和宽高比(ratio)范围随机裁剪图片,然后将裁剪出的区域缩放到指定大小(size)。
下面主要解释 get_params 函数如何计算裁剪位置信息(i, j, h, w),以及 forward 函数的执行流程:
class RandomResizedCrop(torch.nn.Module):
    """Crop a randomly chosen region of an image and resize it to ``size``.

    The region is picked by :meth:`get_params` from the configured ``scale``
    (fraction of the source area) and ``ratio`` (width / height) ranges;
    :meth:`forward` then crops that region and resizes it.
    """

    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=InterpolationMode.BILINEAR):
        """Store the output size, the sampling ranges and the interpolation mode."""
        super().__init__()
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
        # NOTE: argument checks and warnings were intentionally left out of this excerpt.
        self.scale = scale
        self.ratio = ratio
        self.interpolation = interpolation

    @staticmethod
    def get_params(
        img: Tensor, scale: List[float], ratio: List[float]
    ) -> Tuple[int, int, int, int]:
        """Pick the crop rectangle ``(i, j, h, w)`` for ``img``.

        Args:
            img: Source image; its size is read via ``F._get_image_size``.
            scale: ``(min, max)`` bounds for the crop area as a fraction of
                the source area.
            ratio: ``(min, max)`` bounds for the crop's width / height ratio.

        Returns:
            ``(i, j, h, w)``: top row, left column, height and width of the
            region to cut from the source image.
        """
        # The helper's result is unpacked as (width, height).
        img_w, img_h = F._get_image_size(img)
        full_area = img_h * img_w
        # The aspect ratio is drawn uniformly in log space and mapped back with exp.
        log_bounds = torch.log(torch.tensor(ratio))

        for _attempt in range(10):
            # Target crop area = source area * random fraction within `scale`.
            area_fraction = torch.empty(1).uniform_(scale[0], scale[1]).item()
            crop_area = full_area * area_fraction

            # Random width / height ratio within `ratio`.
            log_aspect = torch.empty(1).uniform_(log_bounds[0], log_bounds[1])
            aspect = torch.exp(log_aspect).item()

            # Solve w * h = area and w / h = aspect:
            #   w = sqrt(area * aspect), h = sqrt(area / aspect)
            crop_w = int(round(math.sqrt(crop_area * aspect)))
            crop_h = int(round(math.sqrt(crop_area / aspect)))

            # A candidate that does not fit inside the source is discarded;
            # after 10 misses the loop ends and the central crop below is used.
            if not (0 < crop_w <= img_w and 0 < crop_h <= img_h):
                continue

            top = torch.randint(0, img_h - crop_h + 1, size=(1,)).item()
            left = torch.randint(0, img_w - crop_w + 1, size=(1,)).item()
            return top, left, crop_h, crop_w

        # Fallback to central crop: clamp the source's own aspect ratio into
        # the allowed range and keep the largest region with that ratio.
        source_ratio = float(img_w) / float(img_h)
        narrowest, widest = min(ratio), max(ratio)
        if source_ratio < narrowest:
            crop_w = img_w
            crop_h = int(round(crop_w / narrowest))
        elif source_ratio > widest:
            crop_h = img_h
            crop_w = int(round(crop_h * widest))
        else:
            # The source ratio is already within range: use the whole image.
            crop_w, crop_h = img_w, img_h

        top = (img_h - crop_h) // 2
        left = (img_w - crop_w) // 2
        return top, left, crop_h, crop_w

    def forward(self, img):
        """Crop ``img`` at a freshly sampled rectangle and resize it to ``self.size``."""
        top, left, crop_h, crop_w = self.get_params(img, self.scale, self.ratio)
        return F.resized_crop(img, top, left, crop_h, crop_w, self.size, self.interpolation)
def resized_crop(
    img: Tensor, top: int, left: int, height: int, width: int, size: List[int],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR
) -> Tensor:
    """Crop the given image and resize it to the desired size.

    The region with top-left corner ``(top, left)`` and extent
    ``height`` x ``width`` is cut out first; that region is then resized to
    ``size`` using ``interpolation``.
    """
    cropped = crop(img, top, left, height, width)
    return resize(cropped, size, interpolation)