Learning to Resize Images for Computer Vision Tasks
文章主旨：不再用固定的插值方法缩放输入图片，而是用一个可学习的网络来调整输入图片的大小。
Kaggle discussion (SETI Breakthrough Listen): https://www.kaggle.com/c/seti-breakthrough-listen/discussion/246558
class CNNWithResizer(nn.Module):
    """Learnable image resizer in front of a single-channel timm backbone.

    Follows "Learning to Resize Images for Computer Vision Tasks": the input is
    mapped to ``out_size`` by a small convolutional network whose output is added
    to a plain bilinear resize of the input. The result is fed to a timm backbone
    (created with ``in_chans=1``) whose classifier head is replaced by
    ``nn.Identity`` and followed by a new ``nn.Linear`` head.

    Args:
        cfg: config object; ``cfg.model_name`` selects the timm backbone and
            ``cfg.target_size`` is the output width of the final linear layer.
        pretrained: forwarded to ``timm.create_model``.
        n: number of feature channels inside the resizer.
        slope: negative slope of the resizer's LeakyReLU activations.
        r: number of residual blocks in the resizer (must be >= 1).
        out_size: (height, width) produced by the resizer and fed to the backbone.

    Raises:
        ValueError: if ``r < 1`` or the backbone has no recognised classifier head.
    """

    def __init__(self, cfg, pretrained=False, n=16, slope=0.1, r=1, out_size=(256, 256)):
        super().__init__()
        if r < 1:
            raise ValueError(f"r must be >= 1, got {r}")
        self.cfg = cfg
        self.n = n
        self.slope = slope
        self.r = r
        self.out_size = tuple(out_size)
        self.cnn = timm.create_model(self.cfg.model_name, pretrained=pretrained, in_chans=1)
        nb_ft = self._strip_classifier()
        # Registers block1, block2, extra_blocks, block3, block4 in that order;
        # extra_blocks is empty when r == 1, so it then adds no parameters.
        self._build_resizer()
        self.fc = nn.Linear(nb_ft, self.cfg.target_size)

    def _strip_classifier(self):
        """Replace the backbone's classifier head with ``nn.Identity``.

        Returns:
            The removed head's ``in_features``.

        Raises:
            ValueError: if none of the known head attribute names is present.
        """
        for attr in ("fc", "_fc", "classifier", "last_linear", "head"):
            if hasattr(self.cnn, attr):
                nb_ft = getattr(self.cnn, attr).in_features
                setattr(self.cnn, attr, nn.Identity())
                return nb_ft
        raise ValueError(
            f"cannot find a classifier head on backbone {type(self.cnn).__name__}; "
            "expected one of: fc, _fc, classifier, last_linear, head"
        )

    def _make_res_block(self):
        """Return one residual branch: conv3x3-BN-LeakyReLU-conv3x3-BN (size preserving)."""
        return nn.Sequential(
            nn.Conv2d(self.n, self.n, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
            nn.BatchNorm2d(self.n),
            nn.LeakyReLU(negative_slope=self.slope),
            nn.Conv2d(self.n, self.n, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
            nn.BatchNorm2d(self.n),
        )

    def _build_resizer(self):
        """Create the resizer sub-networks from ``self.n``, ``self.slope`` and ``self.r``."""
        n, slope = self.n, self.slope
        self.block1 = nn.Sequential(
            # padding=3 keeps the spatial size for a 7x7 kernel; padding the 1x1
            # conv would only add a zero border, so it gets none.
            nn.Conv2d(1, n, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False),
            nn.LeakyReLU(negative_slope=slope),
            nn.Conv2d(n, n, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False),
            nn.LeakyReLU(negative_slope=slope),
            nn.BatchNorm2d(n),
        )
        # r residual blocks in total, each with its own weights.
        self.block2 = self._make_res_block()
        self.extra_blocks = nn.ModuleList(self._make_res_block() for _ in range(self.r - 1))
        self.block3 = nn.Sequential(
            nn.Conv2d(n, n, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
            nn.BatchNorm2d(n),
        )
        self.block4 = nn.Sequential(
            nn.Conv2d(n, 1, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False),
        )

    def forward(self, x):
        """Resize ``x`` with the learned resizer, then run the backbone and head.

        Args:
            x: 4-D single-channel batch (block1 expects 1 input channel);
                presumably (B, 1, H, W) - TODO confirm with the data pipeline.

        Returns:
            Output of ``self.fc`` on the backbone features; presumably
            (B, cfg.target_size) - depends on the backbone's pooled output.
        """
        # Plain bilinear resize of the input; the resizer learns a residual on top of it.
        res1 = F.interpolate(x, size=self.out_size, mode="bilinear", align_corners=False)
        # Features computed at input resolution, then resized to the target size.
        feat = F.interpolate(self.block1(x), size=self.out_size, mode="bilinear", align_corners=False)
        out = feat
        for block in (self.block2, *self.extra_blocks):
            out = out + block(out)
        # Long skip from the resized features, independent of how many resblocks ran.
        out = self.block3(out) + feat
        out = self.block4(out) + res1
        out = self.cnn(out)
        return self.fc(out)
Blog post: https://www.kushajveersingh.com/blog/data-augmentation-with-resizer-network-for-image-classification
Code: https://github.com/KushajveerSingh/resize_network_cv