StarGAN 原理与代码分析

下载:

# Clone the StarGAN repository and enter its directory.
git clone https://github.com/yunjey/StarGAN.git
cd StarGAN/

下载celebA训练数据:

# Run the repository's download script to fetch the CelebA training data.
bash download.sh

训练:

# Train on CelebA with 5 target attributes (c_dim=5) at 128x128 resolution;
# samples, logs, model checkpoints and results are written under stargan_celebA/.
python main.py --mode='train' --dataset='CelebA' --c_dim=5 --image_size=128 \
                 --sample_path='stargan_celebA/samples' --log_path='stargan_celebA/logs' \
                 --model_save_path='stargan_celebA/models' --result_path='stargan_celebA/results'

代码分析

生成网络

第一个卷积层:输入为图像与 label 在通道维上的拼接,3 表示图像的 3 个通道,c_dim 为 label 的维度。

# Stem: a 7x7 convolution (stride 1, padding 3, so H and W are unchanged)
# over the image concatenated with the label maps (3 image channels +
# c_dim label channels), followed by InstanceNorm and ReLU.
layers = [
    nn.Conv2d(3 + c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False),
    nn.InstanceNorm2d(conv_dim, affine=True),
    nn.ReLU(inplace=True),
]

两个 stride=2 的卷积层,用于下采样(每层将特征图的宽高减半、通道数翻倍):

# Down-Sampling
# Two stride-2 4x4 convolutions; each halves H and W (for even sizes) and
# doubles the channel count.
curr_dim = conv_dim
for _ in range(2):
    next_dim = curr_dim * 2
    layers += [
        nn.Conv2d(curr_dim, next_dim, kernel_size=4, stride=2, padding=1, bias=False),
        nn.InstanceNorm2d(next_dim, affine=True),
        nn.ReLU(inplace=True),
    ]
    curr_dim = next_dim

瓶颈部分:由 repeat_num 个残差块组成,通道数与分辨率保持不变:

# Bottleneck
# repeat_num residual blocks, all at curr_dim channels (blocks are created
# and appended in order).
layers.extend(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim) for _ in range(repeat_num))

残差块的结构如下(两个 3×3 卷积,输出再与输入相加):

class ResidualBlock(nn.Module):
    """Residual block: conv-IN-ReLU-conv-IN with an identity skip connection.

    Both convolutions are 3x3 with stride 1 and padding 1, so the spatial
    size is preserved. Because the input is added to the branch output,
    ``dim_out`` is expected to equal ``dim_in``.
    """

    def __init__(self, dim_in, dim_out):
        super(ResidualBlock, self).__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=False)
        stages = [
            nn.Conv2d(dim_in, dim_out, **conv_kwargs),
            nn.InstanceNorm2d(dim_out, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim_out, dim_out, **conv_kwargs),
            nn.InstanceNorm2d(dim_out, affine=True),
        ]
        # Same attribute name and layer order as before, so state_dict keys
        # (main.0 ... main.4) stay compatible with saved checkpoints.
        self.main = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x + main(x)``."""
        residual = self.main(x)
        return x + residual

两个转置卷积层,用于上采样(每层将特征图的宽高加倍、通道数减半):

# Up-Sampling
# Two stride-2 4x4 transposed convolutions; each doubles H and W and halves
# the channel count.
for _ in range(2):
    half_dim = curr_dim // 2
    layers += [
        nn.ConvTranspose2d(curr_dim, half_dim, kernel_size=4, stride=2, padding=1, bias=False),
        nn.InstanceNorm2d(half_dim, affine=True),
        nn.ReLU(inplace=True),
    ]
    curr_dim = half_dim

最后一层卷积将通道数映射为 3(即 RGB 图像),再经 Tanh 将输出限制在 [-1, 1]:

# Output layer: 7x7 conv back to 3 channels, then Tanh bounds values to [-1, 1].
layers += [
    nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False),
    nn.Tanh(),
]
self.main = nn.Sequential(*layers)

前向传播时,先将 label 向量 c 在空间上复制成与图像同样大小的特征图,再与输入图像 x 在通道维上拼接:

def forward(self, x, c):
    """Run the generator on images ``x`` conditioned on domain labels ``c``.

    Each row of ``c`` is tiled over the height and width of ``x`` and
    appended to the image as extra channels, so ``self.main`` receives
    ``x.size(1) + c.size(1)`` input channels.
    """
    batch, label_dim = c.size(0), c.size(1)
    height, width = x.size(2), x.size(3)
    # (N, c_dim) -> (N, c_dim, 1, 1) -> broadcast view of shape (N, c_dim, H, W)
    label_maps = c.unsqueeze(2).unsqueeze(3).expand(batch, label_dim, height, width)
    return self.main(torch.cat([x, label_maps], dim=1))

判别网络

判别网络的输入为图像,输出两部分:一是判别输入图像的真假,二是预测输入图像的类别(属性)。

class Discriminator(nn.Module):
    """Discriminator. PatchGAN, with an auxiliary attribute classifier.

    ``repeat_num`` stride-2 4x4 convolutions (LeakyReLU slope 0.01) reduce
    the input by a factor of ``2 ** repeat_num``. The first maps 3 -> conv_dim
    channels and each later one doubles the channel count. Two heads read
    the shared features:

    * ``conv1`` (3x3, 1 output channel) scores each spatial position of the
      feature map as real/fake;
    * ``conv2`` (kernel as large as the remaining feature map) produces
      ``c_dim`` attribute logits per image.

    Args:
        image_size: input height/width; must be at least ``2 ** repeat_num``.
        conv_dim: channel count of the first convolution.
        c_dim: number of attribute logits from the classifier head.
        repeat_num: number of stride-2 convolutions.

    Raises:
        ValueError: if ``image_size`` is too small for ``repeat_num``.
    """
    def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6):
        super(Discriminator, self).__init__()

        # Input layer: 3-channel image -> conv_dim channels, spatial size halved.
        layers = []
        layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
        layers.append(nn.LeakyReLU(0.01, inplace=True))

        # repeat_num - 1 further stride-2 layers, each doubling the channels.
        curr_dim = conv_dim
        for i in range(1, repeat_num):
            layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.01, inplace=True))
            curr_dim = curr_dim * 2

        # Side length of the feature map after repeat_num halvings; conv2 uses
        # it as its kernel size so the classifier output is 1x1 per image.
        k_size = int(image_size / np.power(2, repeat_num))
        if k_size < 1:
            raise ValueError(
                'image_size={} is too small for repeat_num={}: need at least {}'.format(
                    image_size, repeat_num, 2 ** repeat_num))
        self.main = nn.Sequential(*layers)
        self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(curr_dim, c_dim, kernel_size=k_size, bias=False)

    def forward(self, x):
        """Return ``(patch real/fake scores, attribute logits)`` for images ``x``.

        For inputs of size ``image_size`` the logits have shape (N, c_dim).
        """
        h = self.main(x)
        out_real = self.conv1(h)
        out_aux = self.conv2(h)
        # conv2 yields (N, c_dim, 1, 1). Flatten it to (N, c_dim) explicitly:
        # a bare squeeze() also drops the batch dim when N == 1 (and the class
        # dim when c_dim == 1), so the logits would no longer match (N, c_dim)
        # label tensors in the classification loss.
        out_cls = out_aux.view(out_aux.size(0), out_aux.size(1))
        # out_real keeps the original squeeze() so its shape is unchanged for
        # existing callers (the visible ones only reduce it with torch.mean).
        return out_real.squeeze(), out_cls

conv1 输出通道数为 1,用于判别输入的真假(PatchGAN:对特征图上每个位置输出一个真假分数);conv2 输出维度为 c_dim,用于预测输入图像的 label(各属性的 logits)。

训练数据,损失函数,参数更新,

输入包括

real_x,real_c,fake_c

fake_c(目标域标签)并不是凭空随机生成的,而是把当前 batch 的真实标签随机打乱顺序得到的:

# Generate target-domain labels: the batch's own labels in a random order.
fake_label = real_label[torch.randperm(real_label.size(0))]
if self.dataset == 'CelebA':
    # CelebA labels are 0/1 attribute vectors and are used directly as the
    # generator conditions (as independent copies).
    real_c, fake_c = real_label.clone(), fake_label.clone()

训练判别网络

将真实图像输入判别网络,

# Compute loss with real images
# D returns (real/fake scores, attribute logits). D should score real images
# high, so the negated mean score is minimized (WGAN-style critic loss).
out_src, out_cls = self.D(real_x)
d_loss_real = - torch.mean(out_src)

判别网络的输入为真实图像时,输出 out_cls 为各属性的 logits(尚未经过 sigmoid),与真实标签 real_label 计算二元交叉熵损失:

if self.dataset == 'CelebA':
    # Multi-label attribute classification of real images against their true
    # labels. size_average=False sums over batch and attributes; dividing by
    # the batch size gives a per-image loss.
    # NOTE(review): size_average is the legacy PyTorch flag; newer releases
    # express this as reduction='sum' — confirm against the installed version.
    d_loss_cls = F.binary_cross_entropy_with_logits(
        out_cls, real_label, size_average=False) / real_x.size(0)

将真实图像 real_x 与目标标签 fake_c 输入生成网络,得到生成图像 fake_x:

# Translate the real images to the sampled target-domain labels.
fake_x = self.G(real_x, fake_c)

将生成图像输入判别网络(先用 Variable(fake_x.data) 截断梯度,使这一步不会更新生成网络):

# Re-wrap the raw data so no gradient flows back into G during the D step.
fake_x = Variable(fake_x.data)
out_src, out_cls = self.D(fake_x)
# D should score generated images low, so their mean score is minimized.
d_loss_fake = torch.mean(out_src)

总的损失函数为,

# Backward + Optimize
# Total D loss: adversarial terms on real and fake images plus the
# attribute-classification term weighted by lambda_cls.
d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls

根据d_loss更新判别网络参数,

# Backward + Optimize
d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
# Clear stale gradients first (reset_grad is defined elsewhere — presumably
# zeroes the optimizers' gradients; verify).
self.reset_grad()
d_loss.backward()
# Only the discriminator optimizer steps here.
self.d_optimizer.step()

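计算梯度惩罚(WGAN-GP):对每个样本随机采样插值系数 alpha(注意 alpha 不是惩罚权重,惩罚权重是 lambda_gp),在 real_x 与 fake_x 之间线性插值,将插值图像输入判别网络,计算判别输出对插值图像的梯度,并惩罚梯度 L2 范数偏离 1 的程度,得到梯度惩罚损失 d_loss_gp: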

# Compute gradient penalty
# One random interpolation coefficient per image, broadcast over C, H, W.
# NOTE(review): `.cuda()` hard-codes a GPU; this fails on CPU-only runs.
alpha = torch.rand(real_x.size(0), 1, 1, 1).cuda().expand_as(real_x)
# Points on the line between real and fake images, built from `.data` (no
# history) and marked requires_grad so D's output can be differentiated
# with respect to them.
interpolated = Variable(alpha * real_x.data + (1 - alpha) * fake_x.data, requires_grad=True)
out, out_cls = self.D(interpolated)  # out_cls is not used by the penalty

# Gradient of the summed real/fake scores (grad_outputs of ones) w.r.t. the
# interpolated images. create_graph=True keeps this gradient differentiable
# so the penalty itself can be back-propagated into D.
grad = torch.autograd.grad(outputs=out,
                           inputs=interpolated,
                           grad_outputs=torch.ones(out.size()).cuda(),
                           retain_graph=True,
                           create_graph=True,
                           only_inputs=True)[0]

# Per-sample L2 norm of the flattened gradient; penalize its deviation from 1.
grad = grad.view(grad.size(0), -1)
grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
d_loss_gp = torch.mean((grad_l2norm - 1)**2)

根据梯度惩罚损失 d_loss_gp(乘以权重 lambda_gp)再次优化判别网络:

# Backward + Optimize
# A separate D update that uses only the weighted gradient penalty.
d_loss = self.lambda_gp * d_loss_gp
self.reset_grad()
d_loss.backward()
self.d_optimizer.step()

训练生成网络

生成网络的作用是:输入原始域的图像和目标域标签,生成目标域的图像;反过来,输入目标域的图像和原始域标签,又可以生成原始域的图像。

将原图像与目标标签 fake_c 输入生成网络得到 fake_x,再将 fake_x 与原始标签 real_c 输入生成网络得到重建图像 rec_x,希望 rec_x 与原图像尽量相似(重建损失,即 L1 距离):

# Original-to-target and target-to-original domain
# Translate to the target labels, then back to the original labels; the round
# trip should reproduce the input (cycle consistency).
fake_x = self.G(real_x, fake_c)
rec_x = self.G(fake_x, real_c)
# Compute losses
# Reconstruction loss: mean absolute (L1) difference between input and reconstruction.
g_loss_rec = torch.mean(torch.abs(real_x - rec_x))

将 fake_x 输入判别网络(此处不截断梯度,梯度会传回生成网络):

# fake_x comes straight from G here (not re-wrapped), so gradients reach G.
out_src, out_cls = self.D(fake_x)
# Adversarial loss for G: push D's scores on generated images up.
g_loss_fake = - torch.mean(out_src)

计算损失函数,

# Adversarial loss for G (the same expression as in the previous snippet).
g_loss_fake = - torch.mean(out_src)

对于 fake_x,其目标 label 为 fake_label。将 fake_x 输入判别网络,得到预测的属性 logits out_cls,与 fake_label 计算二元交叉熵损失:

# Attribute loss for G: D's classifier should recognize the target attributes
# (fake_label) in the generated images. Summed, then divided by the batch size.
g_loss_cls = F.binary_cross_entropy_with_logits(
    out_cls, fake_label, size_average=False) / fake_x.size(0)

生成网络参数更新,

# Backward + Optimize
# Total G loss: adversarial + lambda_rec * reconstruction + lambda_cls * classification.
g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
self.reset_grad()
g_loss.backward()
# Only the generator optimizer steps here.
self.g_optimizer.step()

训练数据处理

以 CelebA 数据为例,下载后的数据包括属性标签文件(list_attr_celeba.txt)和图像。

文件的第一行为图像的总数,为202599.

第二行为 40 种属性(类别)的名称:

5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes Bald Bangs Big_Lips Big_Nose Black_Hair Blond_Hair Blurry Brown_Hair Bushy_Eyebrows Chubby Double_Chin Eyeglasses Goatee Gray_Hair Heavy_Makeup High_Cheekbones Male Mouth_Slightly_Open Mustache Narrow_Eyes No_Beard Oval_Face Pale_Skin Pointy_Nose Receding_Hairline Rosy_Cheeks Sideburns Smiling Straight_Hair Wavy_Hair Wearing_Earrings Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young

从第三行起,每行为图像名以及对应的 40 个属性的 label,取值为 1 或 -1;预处理时将 1 转为 1,-1 转为 0。

000001.jpg -1 1 1 -1 -1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 1 -1 1 -1 -1 1 -1 -1 1 -1 -1 -1 1 1 -1 1 -1 1 -1 -1 1

list_attr_celeba.txt文件提取函数为,

def preprocess(self):
    """Parse the CelebA attribute file into train/test filename and label lists.

    Reads ``self.lines`` (lines of list_attr_celeba.txt: line 0 is the image
    count, line 1 the attribute names, then one ``<filename> <values...>``
    row per image with values 1 / -1) and fills:

    * ``self.attr2idx`` / ``self.idx2attr``: attribute name <-> column index
      (both dicts must already exist on ``self``);
    * ``self.selected_attrs``: the attributes used for training;
    * ``self.test_filenames`` / ``self.test_labels``: the first 2000 rows
      after a random shuffle;
    * ``self.train_filenames`` / ``self.train_labels``: all remaining rows.

    Each label is a list of 0/1 ints ordered like ``self.selected_attrs``
    (1 where the file has ``1``, otherwise 0). Blank rows are ignored.

    Raises:
        KeyError: if a selected attribute is missing from the header line.
    """
    attrs = self.lines[1].split()
    for i, attr in enumerate(attrs):
        self.attr2idx[attr] = i
        self.idx2attr[i] = attr

    self.selected_attrs = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young']
    self.train_filenames = []
    self.train_labels = []
    self.test_filenames = []
    self.test_labels = []

    # Number of shuffled rows held out for testing.
    num_test = 2000
    # Look up each selected attribute's column once. Building labels in
    # selected_attrs order (instead of scanning the file's columns) keeps the
    # label layout tied to that list, and a misspelled attribute raises
    # KeyError instead of silently producing a shorter label.
    selected_cols = [self.attr2idx[attr] for attr in self.selected_attrs]

    # Image rows start at line 3; skip blank lines such as a trailing newline.
    lines = [line for line in self.lines[2:] if line.strip()]
    random.shuffle(lines)   # random train/test split
    for i, line in enumerate(lines):
        splits = line.split()
        filename = splits[0]   # image name
        values = splits[1:]    # attribute values, one per header column
        # '1' -> 1, anything else ('-1') -> 0
        label = [1 if values[col] == '1' else 0 for col in selected_cols]

        # The first num_test shuffled rows form the test set (the previous
        # `(i + 1) < 2000` check held out only 1999 images).
        if i < num_test:
            self.test_filenames.append(filename)
            self.test_labels.append(label)
        else:
            self.train_filenames.append(filename)
            self.train_labels.append(label)

self.selected_attrs表示我们训练选用的任务类别集合.最后得到图像名数组self.train_filenames,及其对应的label数组 self.train_labels.

之后采用from torch.utils.data import DataLoader加载训练数据,

data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle)

# Keep the first four batches (images and labels) as a fixed sample set.
# NOTE(review): the loader built above is `data_loader`, while the loop reads
# `self.data_loader`; presumably the solver stores that same loader on `self`
# — confirm at the call site.
fixed_x, real_c = [], []
for _, (images, labels) in zip(range(4), self.data_loader):
    fixed_x.append(images)
    real_c.append(labels)

读取后,fixed_x 为图像 batch 的列表,real_c 为对应 label 的列表。每个图像 batch 的形状为 (batch_size, 3, image_size, image_size),每个 label batch 的形状为 (batch_size, len(self.selected_attrs))。

将这 4 个 batch 拼接,得到固定的输入图像、label 以及目标 label 列表(fixed_c_list),用于训练过程中的采样可视化:

# Fixed inputs and target domain labels for debugging.
# Stack the collected batches along dim 0 (e.g. four batches of 16 images at
# 128x128 give a (64, 3, 128, 128) tensor — depends on the loader settings).
fixed_x = self.to_var(torch.cat(fixed_x, dim=0), volatile=True)
real_c = torch.cat(real_c, dim=0)

if self.dataset == 'CelebA':
    # One target-label tensor per single- or multi-attribute edit.
    fixed_c_list = self.make_celeb_labels(real_c)

labellist生成函数为,

def make_celeb_labels(self, real_c):
    """Build the fixed target-label sets used for CelebA sampling/testing.

    ``real_c`` holds one label row per image with columns
    [black_hair, blond_hair, brown_hair, male, young]. Returns a list of
    tensors shaped like ``real_c``, each wrapped with
    ``self.to_var(..., volatile=True)``:

    * ``self.c_dim`` single-attribute edits: for index 0-2 every row's hair
      one-hot is replaced by that colour; for later indices the attribute
      at that index is flipped (1 -> 0, anything else -> 1);
    * only when ``self.dataset == 'CelebA'``, 4 multi-attribute edits:
      hair->brown + gender, hair->brown + age, gender + age, and all three.

    ``real_c`` itself is not modified.
    """
    hair_colours = [torch.FloatTensor([1, 0, 0]),  # black hair
                    torch.FloatTensor([0, 1, 0]),  # blond hair
                    torch.FloatTensor([0, 0, 1])]  # brown hair

    def flipped(col):
        # 0 where the value is 1, else 1 (same rule as `0 if v == 1 else 1`).
        return (col != 1).type_as(col)

    fixed_c_list = []

    # Single attribute transfer, applied to whole columns at once.
    for idx in range(self.c_dim):
        edited = real_c.clone()
        if idx < 3:
            edited[:, :3] = hair_colours[idx]
        else:
            edited[:, idx] = flipped(edited[:, idx])
        fixed_c_list.append(self.to_var(edited, volatile=True))

    # Multi-attribute transfer (H+G, H+A, G+A, H+G+A) as explicit flag triples.
    if self.dataset == 'CelebA':
        combos = [(True, True, False),   # hair + gender
                  (True, False, True),   # hair + age
                  (False, True, True),   # gender + age
                  (True, True, True)]    # hair + gender + age
        for change_hair, change_gender, change_age in combos:
            edited = real_c.clone()
            if change_hair:
                edited[:, :3] = hair_colours[2]   # hair colour to brown
            if change_gender:
                edited[:, 3] = flipped(edited[:, 3])
            if change_age:
                edited[:, 4] = flipped(edited[:, 4])
            fixed_c_list.append(self.to_var(edited, volatile=True))
    return fixed_c_list

fixed_c_list 的长度为 c_dim + 4 = 5 + 4 = 9(c_dim 个单属性变换,加上 4 个多属性组合变换)。

训练时,fake_label 的生成方式为:对 0 到 batch_size-1 的索引做随机排列,再按该排列从 real_label 中取值,即把本 batch 的标签打乱后作为目标标签:

# Start training
start_time = time.time()
for e in range(start, self.num_epochs):
    for i, (real_x, real_label) in enumerate(self.data_loader):

        # Generate fake labels randomly (target domain labels):
        # a random permutation of this batch's own labels.
        rand_idx = torch.randperm(real_label.size(0))
        fake_label = real_label[rand_idx]

        if self.dataset == 'CelebA':
            # CelebA labels are already 0/1 attribute vectors.
            real_c = real_label.clone()
            fake_c = fake_label.clone()
        else:
            # Other datasets: labels are converted by self.one_hot to length
            # c_dim (presumably integer class ids -> one-hot; verify one_hot).
            real_c = self.one_hot(real_label, self.c_dim)
            fake_c = self.one_hot(fake_label, self.c_dim)

        # Convert tensor to variable
        real_x = self.to_var(real_x)  # e.g. (16, 3, 128, 128) for batch_size 16 at 128x128
        real_c = self.to_var(real_c)  # e.g. (16, 5); input for the generator
        fake_c = self.to_var(fake_c)  # e.g. (16, 5); target labels for the generator
        real_label = self.to_var(real_label)   # this is same as real_c if dataset == 'CelebA'
        fake_label = self.to_var(fake_label)

你可能感兴趣的:(深度学习,图像处理)