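# Effective Receptive Field (ERF): visualize the gradient of a single output
# pixel of a stack of 3x3 all-ones convolutions with respect to the input.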

import torch
from torch import nn
from torchvision.utils import make_grid
from torch.nn import init

import matplotlib.pyplot as plt


class Model(nn.Module):
    """Stack of bias-free, single-channel convolutions with all-ones weights.

    Because every kernel weight is 1 and there is no bias, each output pixel
    is the sum of the input values it is connected to. Each input value is
    counted once per path through the stack. This makes the network a simple
    probe for receptive-field experiments.

    Args:
        num_layers: Number of convolution layers to chain.
        kernel_size: Side length of each square kernel. Must be a positive
            odd integer so that ``padding = kernel_size // 2`` keeps the
            spatial size unchanged ("same" padding).

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, num_layers=18, kernel_size=3):
        super().__init__()

        # Only odd kernels are size-preserving with symmetric k // 2 padding.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )

        convs = [
            nn.Conv2d(1, 1, kernel_size, padding=kernel_size // 2, bias=False)
            for _ in range(num_layers)
        ]
        for conv in convs:
            # Constant weights remove initialization randomness, so the output
            # reflects only the connectivity of the stack.
            init.ones_(conv.weight)

        self.layers = nn.Sequential(*convs)

    def forward(self, x):
        """Apply all layers in order.

        Expects a single-channel input (e.g. ``(N, 1, H, W)``); the spatial
        size is preserved.
        """
        return self.layers(x)


# Input resolution and the output pixel whose receptive field is measured.
size = 256
center = size // 2

model = Model()
# An all-ones input isolates the network's connectivity: the input gradient
# below then depends only on the weights and paths, not on input content.
x = torch.ones(1, 1, size, size, requires_grad=True)
y = model(x)

# Keep only the centre output unit. Every other output position is zeroed,
# so the backward pass yields d y[center, center] / d x.
mask = torch.zeros_like(y)
mask[0, 0, center, center] = 1
y = y * mask

# Summing gives an upstream gradient of ones, independent of the values in
# x. (Passing x positionally to torch.autograd.backward would use it as
# `grad_tensors`, which is only equivalent while x is all ones.)
y.sum().backward()

# Stack input, masked output and input gradient as three panels. Detach so
# they can be converted to NumPy for plotting.
panels = torch.cat([x, y, x.grad], dim=0).detach()
# Scale each panel by its own maximum. The raw values differ by many orders
# of magnitude between panels, and imshow clips RGB floats to [0, 1]. A zero
# maximum is replaced by 1 to avoid producing NaNs.
peak = panels.amax(dim=(1, 2, 3), keepdim=True)
panels = panels / torch.where(peak > 0, peak, torch.ones_like(peak))

image = make_grid(panels)
image = image.permute(1, 2, 0)  # CHW -> HWC for matplotlib
plt.imshow(image)
plt.show()


# See also: Effective Receptive Field (ERF).