这篇文章将介绍怎么使用 hypernetworks 来完成一些实验,本实验基于 https://github.com/g1910/HyperNetworks.git 。
PrimaryNetwork 是主要观察的类,重点关注其 .forward 方法中如何由 hypernetwork 生成各残差块的卷积参数。
class PrimaryNetwork(nn.Module):
    """ResNet-style CIFAR-10 classifier whose residual-block conv weights are
    produced on the fly by a shared hypernetwork instead of being stored as
    ordinary parameters."""

    def __init__(self, z_dim=64):
        super(PrimaryNetwork, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.z_dim = z_dim
        # Single hypernetwork shared by all 36 embeddings below.
        self.hope = HyperNetwork(z_dim=self.z_dim)
        # Tile counts (out/in channel multiples of the 16x16 base kernel)
        # for each of the 36 generated weight tensors — two per residual block.
        self.zs_size = ([[1, 1]] * 12
                        + [[2, 1]] + [[2, 2]] * 11
                        + [[4, 2]] + [[4, 4]] * 11)
        # (in_channels, out_channels) of the 18 residual blocks.
        self.filter_size = ([[16, 16]] * 6 + [[16, 32]] + [[32, 32]] * 5
                            + [[32, 64]] + [[64, 64]] * 5)

        self.res_net = nn.ModuleList()
        for idx, (f_in, f_out) in enumerate(self.filter_size):
            # Blocks 6 and 12 (channel transitions) downsample spatially.
            self.res_net.append(
                ResNetBlock(f_in, f_out,
                            downsample=(idx > 5 and idx % 6 == 0)))

        # One learnable embedding per generated weight tensor.
        self.zs = nn.ModuleList(
            Embedding(size, self.z_dim) for size in self.zs_size)

        self.global_avg = nn.AvgPool2d(8)
        self.final = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        # Core of the hypernetwork scheme: each residual block receives two
        # weight tensors (w1, w2) generated by pushing its embeddings through
        # the shared hypernetwork, rather than owning its own conv parameters.
        for idx, block in enumerate(self.res_net):
            w1 = self.zs[2 * idx](self.hope)
            w2 = self.zs[2 * idx + 1](self.hope)
            x = block(x, w1, w2)
        x = self.global_avg(x)
        return self.final(x.view(-1, 64))
同样重要的还有 HyperNetwork 类,它负责把嵌入向量映射为卷积核权值:
class HyperNetwork(nn.Module):
    """Two-layer linear hypernetwork: maps a z_dim embedding vector to a conv
    kernel of shape (out_size, in_size, f_size, f_size).

    FIX vs. original: the parameters were created with a hardcoded ``.cuda()``,
    which crashes on CPU-only machines and wrongly bakes device placement into
    construction. Parameters are now created device-agnostically; the caller's
    ``net.cuda()`` / ``net.to(device)`` still moves them to the GPU.
    """

    def __init__(self, f_size=3, z_dim=64, out_size=16, in_size=16):
        super(HyperNetwork, self).__init__()
        self.z_dim = z_dim        # embedding dimensionality
        self.f_size = f_size      # spatial kernel size (f_size x f_size)
        self.out_size = out_size  # output channels of the generated kernel
        self.in_size = in_size    # input channels of the generated kernel

        # torch.fmod(randn, 2) bounds initial values to (-2, 2), matching the
        # reference implementation's initialization.
        self.w1 = Parameter(torch.fmod(
            torch.randn(self.z_dim, self.out_size * self.f_size * self.f_size), 2))
        self.b1 = Parameter(torch.fmod(
            torch.randn(self.out_size * self.f_size * self.f_size), 2))
        self.w2 = Parameter(torch.fmod(
            torch.randn(self.z_dim, self.in_size * self.z_dim), 2))
        self.b2 = Parameter(torch.fmod(
            torch.randn(self.in_size * self.z_dim), 2))

    def forward(self, z):
        """Map embedding ``z`` (shape (z_dim,)) to a conv kernel tensor.

        Returns a tensor of shape (out_size, in_size, f_size, f_size).
        """
        # First projection: z -> one z_dim-vector per input channel.
        h_in = torch.matmul(z, self.w2) + self.b2
        h_in = h_in.view(self.in_size, self.z_dim)
        # Second projection: each row -> out_size * f_size * f_size weights.
        h_final = torch.matmul(h_in, self.w1) + self.b1
        return h_final.view(self.out_size, self.in_size, self.f_size, self.f_size)
import torch
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.nn as nn
import argparse
import torch.optim as optim
from primary_net import PrimaryNetwork
########### Data Loader ###############
# CIFAR-10 per-channel mean / std used for input normalization.
_cifar_mean = (0.4914, 0.4822, 0.4465)
_cifar_std = (0.2023, 0.1994, 0.2010)

# Training inputs get standard CIFAR-10 augmentation (pad-and-crop + flip);
# test inputs are only tensorized and normalized.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_cifar_mean, _cifar_std),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_cifar_mean, _cifar_std),
])

trainset = torchvision.datasets.CIFAR10(
    root='../data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=4)

testset = torchvision.datasets.CIFAR10(
    root='../data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=128, shuffle=False, num_workers=4)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
#############
# Command-line options: --resume/-r restarts training from the checkpoint.
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--resume', '-r', action='store_true',
                    help='resume from checkpoint')
args = parser.parse_args()
############
net = PrimaryNetwork()

best_accuracy = 0.
if args.resume:
    # Restore weights and the best accuracy reached so far.
    ckpt = torch.load('./hypernetworks_cifar_paper.pth')
    net.load_state_dict(ckpt['net'])
    best_accuracy = ckpt['acc']
net.cuda()

# Optimization schedule from the HyperNetworks paper: Adam, learning rate
# halved at the listed *iteration* counts (hence stepping the scheduler per
# batch below, not per epoch).
learning_rate = 0.002
weight_decay = 0.0005
milestones = [168000, 336000, 400000, 450000, 550000, 600000]
max_iter = 1000000

optimizer = optim.Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=milestones, gamma=0.5)
criterion = nn.CrossEntropyLoss()

total_iter = 0
epochs = 0
print_freq = 50
while total_iter < max_iter:
    net.train()  # re-enable train-mode BatchNorm after each evaluation pass
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        # FIX: Variable is deprecated since PyTorch 0.4; tensors track
        # gradients directly.
        inputs, labels = inputs.cuda(), labels.cuda()

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        lr_scheduler.step()  # milestones are iteration counts — step per batch

        # FIX: loss.data[0] raises a TypeError on PyTorch >= 0.4;
        # loss.item() is the supported way to read a scalar loss.
        running_loss += loss.item()
        if i % print_freq == (print_freq - 1):
            print("[Epoch %d, Total Iterations %6d] Loss: %.4f" % (epochs + 1, total_iter + 1, running_loss / print_freq))
            running_loss = 0.0
        total_iter += 1
    epochs += 1

    # Evaluate on the test set.
    # FIX: switch to eval mode so BatchNorm uses running statistics, and
    # disable gradient tracking to save memory during inference.
    net.eval()
    correct = 0.
    total = 0.
    with torch.no_grad():
        for tdata in testloader:
            timages, tlabels = tdata
            toutputs = net(timages.cuda())
            _, predicted = torch.max(toutputs.cpu(), 1)
            total += tlabels.size(0)
            # FIX: .item() keeps `correct` a Python number; summing a bool
            # tensor into a float would turn `accuracy` into a tensor.
            correct += (predicted == tlabels).sum().item()

    accuracy = (100. * correct) / total
    print('After epoch %d, accuracy: %.4f %%' % (epochs, accuracy))
    if accuracy > best_accuracy:
        print('Saving model...')
        state = {
            'net': net.state_dict(),
            'acc': accuracy
        }
        torch.save(state, './hypernetworks_cifar_paper.pth')
        best_accuracy = accuracy

print('Finished Training')