import argparse
import os
import numpy as np
import math
import itertools
import time
import datetime
import sys
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
from models import *
from datasets import *
import torch.nn as nn
import torch.nn.functional as F
import torch
from PIL import Image  # provides Image.BICUBIC used in transforms_ below
1. The itertools library
Iterators (and generators) are a common and convenient data structure in Python. Compared with lists, their biggest advantage is lazy evaluation: values are computed only when needed, which improves both ergonomics and runtime efficiency. This is why in Python 3 operations such as map and filter return iterators instead of lists.
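A minimal sketch of that lazy behaviour: mapping over an infinite itertools.count costs nothing until values are actually consumed.
import itertools

squares = map(lambda x: x * x, itertools.count(1))  # infinite and lazy
first_five = itertools.islice(squares, 5)           # still lazy, nothing computed yet
print(list(first_five))                             # [1, 4, 9, 16, 25]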
2. The datetime library
datetime.date.today() returns the current system date.
datetime.date.fromtimestamp(time.time()) converts a timestamp into a date.
datetime.datetime.now() returns the current system time.
current_time.replace(2016, 5, 12) returns the current time with the specified fields replaced.
datetime.datetime.strptime("21/11/06 16:30", "%d/%m/%y %H:%M") parses a string into a datetime.
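The calls above as one runnable snippet:
import time
import datetime

print(datetime.date.today())                       # current system date
print(datetime.date.fromtimestamp(time.time()))    # timestamp -> date
now = datetime.datetime.now()                      # current system time
print(now.replace(2016, 5, 12))                    # same time, year/month/day replaced
print(datetime.datetime.strptime('21/11/06 16:30', '%d/%m/%y %H:%M'))  # 2006-11-21 16:30:00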
parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=0, help='epoch to start training from')
parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
parser.add_argument('--dataset_name', type=str, default="facades", help='name of the dataset')
parser.add_argument('--batch_size', type=int, default=1, help='size of the batches')
parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate')
parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of second order momentum of gradient')
parser.add_argument('--decay_epoch', type=int, default=100, help='epoch from which to start lr decay')
parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
parser.add_argument('--img_height', type=int, default=256, help='size of image height')
parser.add_argument('--img_width', type=int, default=256, help='size of image width')
parser.add_argument('--channels', type=int, default=3, help='number of image channels')
parser.add_argument('--sample_interval', type=int, default=500, help='interval between sampling of images from generators')
parser.add_argument('--checkpoint_interval', type=int, default=-1, help='interval between model checkpoints')
opt = parser.parse_args()
print(opt)
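print(opt) echoes the parsed Namespace, so each run logs its full configuration. A typical invocation might look like this (assuming the script is saved as pix2pix.py, as in the source repo; the Namespace line is abbreviated):
$ python pix2pix.py --dataset_name facades --n_epochs 200
Namespace(b1=0.5, b2=0.999, batch_size=1, channels=3, checkpoint_interval=-1, dataset_name='facades', ...)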
# Loss functions
criterion_GAN = torch.nn.MSELoss()
criterion_pixelwise = torch.nn.L1Loss()
# Loss weight of L1 pixel-wise loss between translated image and real image
lambda_pixel = 100
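Later in the training loop (further down in the source script, not shown in this excerpt) the two criteria are combined into the generator objective as in the pix2pix paper: the adversarial term plus the L1 term scaled by lambda_pixel. Roughly, with valid being an all-ones target tensor for the patch discriminator:
fake_B = generator(real_A)
loss_GAN = criterion_GAN(discriminator(fake_B, real_A), valid)  # fool the discriminator
loss_pixel = criterion_pixelwise(fake_B, real_B)                # stay close to ground truth
loss_G = loss_GAN + lambda_pixel * loss_pixel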
Mean squared error (MSE) loss
The element-wise formula is: loss(x_i, y_i) = (x_i - y_i)^2
>>> loss_fn = torch.nn.MSELoss(reduction='none')  # keep per-element losses, no averaging
>>> input = torch.randn(3, 4)
>>> target = torch.randn(3, 4)
>>> loss = loss_fn(input, target)
>>> print(input); print(target); print(loss)
tensor([[-1.3524, 0.5194, 1.0586, 0.1549],
[ 1.6697, 0.4262, 0.0257, 1.1458],
[ 0.6460, 0.3691, 0.5229, -2.1614]])
tensor([[-0.8691, 1.7308, 1.0579, 0.2359],
[-0.3626, -0.7589, -0.0547, 0.9764],
[ 0.3606, -0.5090, -0.9875, -0.6050]])
tensor([[ 2.3352e-01, 1.4676e+00, 4.3082e-07, 6.5639e-03],
[ 4.1302e+00, 1.4044e+00, 6.4638e-03, 2.8681e-02],
[ 8.1454e-02, 7.7096e-01, 2.2813e+00, 2.4222e+00]])
The L1 (mean absolute error) loss works the same way; its element-wise formula is: loss(x_i, y_i) = |x_i - y_i|
>>> import torch
>>> loss_fn = torch.nn.L1Loss(reduction='none')  # keep per-element losses, no averaging
>>> input = torch.randn(3, 4)
>>> target = torch.randn(3, 4)
>>> loss = loss_fn(input, target)
>>> print(input); print(target); print(loss)
tensor([[-0.2028, 1.0140, -0.9712, 1.6227],
[ 1.0678, -1.3599, -1.1543, 1.6353],
[-0.1146, -0.2229, 0.1262, -0.8661]])
tensor([[-0.7508, -0.7450, 0.0223, -0.8037],
[ 1.3009, 0.3976, -0.3933, 0.6665],
[ 0.0281, 1.9780, -1.6017, -1.6238]])
tensor([[ 0.5479, 1.7590, 0.9935, 2.4265],
[ 0.2331, 1.7575, 0.7610, 0.9688],
[ 0.1427, 2.2009, 1.7279, 0.7577]])
# Initialize generator and discriminator
generator = GeneratorUNet()
discriminator = Discriminator()
if opt.epoch != 0:
    # Load pretrained models
    generator.load_state_dict(torch.load('saved_models/%s/generator_%d.pth' % (opt.dataset_name, opt.epoch)))
    discriminator.load_state_dict(torch.load('saved_models/%s/discriminator_%d.pth' % (opt.dataset_name, opt.epoch)))
else:
    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)
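weights_init_normal is defined in the repo's models.py. A sketch of the DCGAN-style initialization it applies (reproduced from the source repo from memory; treat the details as approximate):
def weights_init_normal(m):
    # DCGAN-style init: N(0, 0.02) for conv weights,
    # N(1, 0.02) for BatchNorm scale and zeros for its bias
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)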
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
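Inside the training loop (not shown in this excerpt) each optimizer is used in the standard three-step pattern, alternating between generator and discriminator updates:
optimizer_G.zero_grad()   # clear stale generator gradients
loss_G.backward()         # backprop the combined GAN + L1 loss
optimizer_G.step()        # update generator weights

optimizer_D.zero_grad()   # same pattern for the discriminator loss
loss_D.backward()
optimizer_D.step()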
PyTorch provides two methods, state_dict() and load_state_dict(), for saving and loading model parameters: the former returns the model's parameters as a dictionary, and the latter loads such a dictionary back into a model.
1. First, read the current model's parameters:
model_dict = model.state_dict()
2. Load the pretrained weights and keep only the entries whose names also exist in the current model:
pre_dict = torch.load('path')
pre_dict = {k: v for k, v in pre_dict.items() if k in model_dict}
3. Update the current parameters with the pretrained ones:
model_dict.update(pre_dict)
4. Load the merged parameters back into the model:
model.load_state_dict(model_dict)
generator.load_state_dict(torch.load('saved_models/%s/generator_%d.pth' % (opt.dataset_name, opt.epoch)))
This single line combines all of the steps above; no filtering is needed here because every key in the checkpoint matches the model. A toy demonstration of the four-step version follows below.
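A self-contained toy demonstration of the four steps (the Old/New modules are hypothetical, purely for illustration):
import torch
import torch.nn as nn

class Old(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)
        self.fc2 = nn.Linear(4, 2)

class New(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)   # name matches Old -> weights copied
        self.head = nn.Linear(4, 3)  # new layer -> keeps its fresh init

pre_dict = Old().state_dict()        # stand-in for torch.load('path')
model = New()
model_dict = model.state_dict()
pre_dict = {k: v for k, v in pre_dict.items() if k in model_dict}
model_dict.update(pre_dict)          # overwrite only the matching keys (fc1.*)
model.load_state_dict(model_dict)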
# Configure dataloaders
transforms_ = [transforms.Resize((opt.img_height, opt.img_width), Image.BICUBIC),
               transforms.ToTensor(),
               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
dataloader = DataLoader(ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_),
                        batch_size=opt.batch_size, shuffle=True, num_workers=opt.n_cpu)
val_dataloader = DataLoader(ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_, mode='val'),
                            batch_size=10, shuffle=True, num_workers=1)
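ImageDataset comes from the repo's datasets.py. Each facades file stores the paired images side by side, and the dataset splits them into the 'A' and 'B' halves. A sketch of that behaviour (reconstructed from the source repo from memory; details approximate):
import glob
import os
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms

class ImageDataset(Dataset):
    """Each image file holds A|B side by side; split it into the two halves."""
    def __init__(self, root, transforms_=None, mode='train'):
        self.transform = transforms.Compose(transforms_)
        self.files = sorted(glob.glob(os.path.join(root, mode) + '/*.*'))

    def __getitem__(self, index):
        img = Image.open(self.files[index % len(self.files)])
        w, h = img.size
        img_A = img.crop((0, 0, w // 2, h))    # left half
        img_B = img.crop((w // 2, 0, w, h))    # right half
        return {'A': self.transform(img_A), 'B': self.transform(img_B)}

    def __len__(self):
        return len(self.files)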
def sample_images(batches_done):
    """Saves a generated sample from the validation set"""
    imgs = next(iter(val_dataloader))
    # Tensor is defined further down in the full script as torch.cuda.FloatTensor
    # when CUDA is available, otherwise torch.FloatTensor
    real_A = Variable(imgs['B'].type(Tensor))
    real_B = Variable(imgs['A'].type(Tensor))
    fake_B = generator(real_A)
    # Stack input, generated output and ground truth along the height axis
    img_sample = torch.cat((real_A.data, fake_B.data, real_B.data), -2)
    save_image(img_sample, 'images/%s/%s.png' % (opt.dataset_name, batches_done), nrow=5, normalize=True)
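torch.cat(..., -2) concatenates along the height dimension of the N×C×H×W batch, so each saved PNG shows the input, the generated output and the ground truth stacked vertically. A quick shape check:
import torch
a = torch.zeros(10, 3, 256, 256)     # one validation batch (batch_size=10)
stacked = torch.cat((a, a, a), -2)   # concatenate along H (dim -2)
print(stacked.shape)                 # torch.Size([10, 3, 768, 256])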
Generated sample at batch 24000: (image omitted)
Source code: https://github.com/eriklindernoren/PyTorch-GAN/tree/master/implementations/pix2pix