# 构建一个简单的 VGG 网络 (Build a simple VGG network)

import torch
from torch.autograd import Variable
from torch import nn

def vgg_block(num_convs, input_channels, output_channels):
    """Build one VGG stage: ``num_convs`` (3x3 conv + ReLU) pairs, then a 2x2 max-pool.

    Every convolution uses ``kernel_size=3, padding=1``, which preserves the
    spatial size; the trailing ``MaxPool2d(2, 2)`` halves height and width
    (rounding down).

    Args:
        num_convs: Number of convolution layers in the stage; must be >= 1.
        input_channels: Channel count expected by the first convolution.
        output_channels: Channel count produced by every convolution in the stage.

    Returns:
        An ``nn.Sequential`` laid out as
        ``[Conv2d, ReLU, Conv2d, ReLU, ..., MaxPool2d]``.

    Raises:
        ValueError: If ``num_convs`` is less than 1.
    """
    if num_convs < 1:
        raise ValueError(f"num_convs must be >= 1, got {num_convs}")

    layers = []
    in_c = input_channels
    for _ in range(num_convs):
        layers.append(nn.Conv2d(in_c, output_channels, kernel_size=3, padding=1))
        # In-place is safe: each conv output is consumed only by this ReLU.
        layers.append(nn.ReLU(inplace=True))
        # Only the first conv sees input_channels; the rest chain output->output.
        in_c = output_channels
    layers.append(nn.MaxPool2d(2, 2))
    return nn.Sequential(*layers)

def vgg_stack(numconvs, channels):
    """Chain several VGG stages (see ``vgg_block``) into one ``nn.Sequential``.

    Args:
        numconvs: Iterable giving the number of 3x3 convolutions per stage.
        channels: Iterable of ``(in_channels, out_channels)`` pairs, one per
            stage, in the same order as ``numconvs``. Consecutive stages are
            expected to chain (a stage's ``in_channels`` equal to the previous
            stage's ``out_channels``); this is not checked here.

    Returns:
        An ``nn.Sequential`` whose i-th child is the i-th VGG stage.

    Raises:
        ValueError: If ``numconvs`` and ``channels`` have different lengths.
            Plain ``zip`` would otherwise silently drop the extra entries and
            build a shorter network than requested.
    """
    # Materialize so lengths can be compared even when generators are passed.
    numconvs = tuple(numconvs)
    channels = tuple(channels)
    if len(numconvs) != len(channels):
        raise ValueError(
            f"numconvs has {len(numconvs)} entries but channels has "
            f"{len(channels)}; they must match one-to-one"
        )
    blocks = [vgg_block(n, c[0], c[1]) for n, c in zip(numconvs, channels)]
    return nn.Sequential(*blocks)

# Stage layout (1, 1, 2, 2, 2 convs; 64-128-256-512-512 channels) matches the
# convolutional part of VGG-11 (configuration "A" in Simonyan & Zisserman).
vgg_net = vgg_stack((1, 1, 2, 2, 2), ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512)))

if __name__ == "__main__":
    print(vgg_net)
    print(vgg_net[0])
    # Smoke test. Each of the five stages halves H and W, so a 256x256 input
    # should come out as 8x8 with 512 channels. Plain tensors replace the
    # deprecated torch.autograd.Variable wrapper, and no_grad() avoids building
    # an autograd graph for this forward-only check.
    test_x = torch.zeros(1, 3, 256, 256)
    with torch.no_grad():
        test_y = vgg_net(test_x)
    print(test_y.shape)

# 你可能感兴趣的：深度学习 (You may also be interested in: deep learning)