Bug in SA-SSD's ssd_rotate_head.py

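torch 1.4.0 behaves differently from torch 1.1.0 here: since PyTorch 1.3.0, torch.nn.functional.grid_sample takes an align_corners argument that defaults to False, whereas the behavior of older versions corresponds to align_corners=True.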

Change bilinear_interpolate_torch_gridsample in ssd_rotate_head.py to:

import torch


def bilinear_interpolate_torch_gridsample(image, samples_x, samples_y):
    # image: C x H x W feature map; samples_x / samples_y: C x N pixel coordinates
    C, H, W = image.shape
    image = image.unsqueeze(1)  # reshape to C x 1 x H x W

    samples_x = samples_x.unsqueeze(2)
    samples_x = samples_x.unsqueeze(3)  # C x N x 1 x 1
    samples_y = samples_y.unsqueeze(2)
    samples_y = samples_y.unsqueeze(3)  # C x N x 1 x 1

    samples = torch.cat([samples_x, samples_y], 3)  # C x N x 1 x 2, last dim is (x, y)
    samples[:, :, :, 0] = (samples[:, :, :, 0] / (W - 1))  # normalize to between 0 and 1
    samples[:, :, :, 1] = (samples[:, :, :, 1] / (H - 1))  # normalize to between 0 and 1
    samples = samples * 2 - 1  # normalize to between -1 and 1

    # -1 and 1 now correspond to the centers of the corner pixels (0 and W-1 / H-1),
    # which is the align_corners=True convention; the output is C x 1 x N x 1
    return torch.nn.functional.grid_sample(image, samples, align_corners=True)

Note the added align_corners=True!

Without it, training runs for a while and then fails with a CUDA illegal memory access error.
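To see why the normalization in the function pairs with align_corners=True, here is a minimal sketch in plain Python (no GPU or torch needed) of how grid_sample converts a normalized grid value back to a pixel coordinate under each setting. The two formulas follow the convention described in the PyTorch grid_sample documentation (with align_corners=True, -1 and 1 are the centers of the corner pixels; with False, they are the outer edges of the corner pixels). The helper names and the width W = 200 are hypothetical, chosen only for illustration.

```python
def normalize_like_sa_ssd(x: float, size: int) -> float:
    """Map a pixel coordinate in [0, size - 1] to [-1, 1], as bilinear_interpolate_torch_gridsample does."""
    return x / (size - 1) * 2 - 1


def unnormalize(g: float, size: int, align_corners: bool) -> float:
    """Map a grid value in [-1, 1] back to a pixel coordinate (hypothetical helper)."""
    if align_corners:
        # -1 and 1 are the centers of the first and last pixels
        return (g + 1) / 2 * (size - 1)
    # -1 and 1 are the outer edges of the first and last pixels
    return ((g + 1) * size - 1) / 2


if __name__ == "__main__":
    W = 200  # hypothetical feature-map width
    for x in (0.0, 100.0, 199.0):
        g = normalize_like_sa_ssd(x, W)
        print(x, unnormalize(g, W, True), unnormalize(g, W, False))
```

With align_corners=True every coordinate maps back to itself. With align_corners=False, x = 0 and x = W - 1 land at -0.5 and W - 0.5, half a pixel outside the intended sample points, and interior points drift as well (x = 100 maps to about 100.0025 here).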
