# PyTorch AlexNet

# -*- coding: utf-8 -*-
"""
Created on Tue Mar  5 17:24:47 2019

@author: Admin
"""
import torch
import torch.nn as nn
from layers import MaskedConv2d
class Net(nn.Module):
    """AlexNet-style CNN whose five convolution layers are ``MaskedConv2d``.

    Layout: conv1, conv2 and conv5 are each followed by ReLU and 2x2
    max-pooling; conv3 and conv4 are followed by ReLU only. Three fully
    connected layers (fc1, fc2, out) then map to a 10-way output.

    ``fc1`` expects ``4 * 4 * 256`` flattened features. With the kernel,
    stride and padding settings below, a 32x32 input shrinks as
    32 -> 33 -> 16 -> 17 -> 8 -> 8 -> 8 -> 9 -> 4. Inputs are therefore
    assumed to be (N, 3, 32, 32), presumably CIFAR-10. TODO: confirm
    against the data pipeline.
    """

    def __init__(self):
        # The class is named ``Net``; referencing a non-existent ``AlexNet``
        # here raised NameError on construction.
        super(Net, self).__init__()

        self.conv1 = MaskedConv2d(in_channels=3, out_channels=96, kernel_size=2, stride=1, padding=1)
        self.relu1 = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool2d(2)

        self.conv2 = MaskedConv2d(in_channels=96, out_channels=256, kernel_size=2, stride=1, padding=1)
        self.relu2 = nn.ReLU(inplace=True)
        self.maxpool2 = nn.MaxPool2d(2)

        self.conv3 = MaskedConv2d(in_channels=256, out_channels=384, kernel_size=3, stride=1, padding=1)
        self.relu3 = nn.ReLU(inplace=True)

        self.conv4 = MaskedConv2d(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=1)
        self.relu4 = nn.ReLU(inplace=True)

        self.conv5 = MaskedConv2d(in_channels=384, out_channels=256, kernel_size=2, stride=1, padding=1)
        self.relu5 = nn.ReLU(inplace=True)
        self.maxpool5 = nn.MaxPool2d(2)

        self.fc1 = nn.Linear(4 * 4 * 256, 4096)
        self.fc2 = nn.Linear(4096, 1024)
        self.out = nn.Linear(1024, 10)

    def forward(self, inputs):
        """Run the network on a batch of images.

        Args:
            inputs: image batch, assumed (N, 3, 32, 32) so that the flattened
                conv output has the 4 * 4 * 256 features ``fc1`` expects.

        Returns:
            Tuple ``(out, network)``. ``out`` is the (N, 10) output of the
            last linear layer. ``network`` is the (N, 1024) output of ``fc2``,
            i.e. the penultimate features.
        """
        # Stages 1-2: conv -> ReLU -> 2x2 max-pool.
        network = self.maxpool1(self.relu1(self.conv1(inputs)))
        network = self.maxpool2(self.relu2(self.conv2(network)))
        # Stages 3-4: conv -> ReLU. k=3, p=1, s=1 keeps the spatial size.
        network = self.relu3(self.conv3(network))
        network = self.relu4(self.conv4(network))
        # Stage 5: conv -> ReLU -> 2x2 max-pool. A 32x32 input ends as a
        # 4x4x256 map here, matching fc1's input size.
        network = self.maxpool5(self.relu5(self.conv5(network)))
        network = network.view(network.size(0), -1)
        # NOTE(review): there is no nonlinearity between fc1, fc2 and out, so
        # these layers compose to a single affine map. Kept as-is so the
        # returned features are unchanged. Consider inserting ReLU (and
        # dropout) after fc1 and fc2.
        network = self.fc1(network)
        network = self.fc2(network)
        out = self.out(network)
        return out, network

    def set_masks(self, masks):
        """Install one mask on each of the five masked convolution layers.

        Args:
            masks: indexable sequence of at least five NumPy arrays.
                ``masks[i]`` is given to ``conv{i + 1}``. Each array is
                wrapped with ``torch.from_numpy``, which shares memory with
                the array. Required shape and dtype are defined by
                ``MaskedConv2d.set_mask``, presumably the conv weight's
                shape. TODO: confirm in ``layers``.

        Raises:
            IndexError: if ``masks`` has fewer than five entries. Masks for
                the earlier layers are already set when this happens.
        """
        masked_convs = (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5)
        for index, conv in enumerate(masked_convs):
            conv.set_mask(torch.from_numpy(masks[index]))

# You may also be interested in: PyTorch