Understanding Spectral Normalization

"Spectral Normalization for Generative Adversarial Networks" [1] is a paper by Takeru Miyato, published in February 2018, that applies spectral theory to GANs. Already in 2017, the third author of that paper, Yuichi Yoshida, had published a well-known paper on Spectral Norm Regularization [2]; if interested, see also my previous blog post: https://blog.csdn.net/StreamRock/article/details/83539937
Papers [1] and [2] discuss, from different angles, how the spectral norm of the parameter matrices affects the generalization of a multi-layer neural network, and they give two different remedies: the former normalizes the Discriminator's parameter matrices, while the latter can be added to any multi-layer network (a spectral-norm regularization term is added when updating the gradients). Building on a reading of [1], this post looks at how spectral normalization is implemented.

1. The Lipschitz stability constraint for GANs

GANs are powerful but hard to train. The difficulties show up mainly as: 1) mode collapse, where the generated samples end up covering only a few modes; 2) non-convergence, where the Discriminator reaches its ideal state early in training and can always perfectly tell real from fake, so it no longer provides useful gradients to the Generator and training stalls. In "Towards Principled Methods for Training Generative Adversarial Networks" [4] and "Wasserstein GAN" [5], Martin Arjovsky analyzed in detail why GANs are hard to train and proposed a new loss, the Wasserstein distance:
W(P_r,P_g)=\inf_{\gamma\in\prod(P_r,P_g)}E_{(x,y)\sim \gamma}[\Vert x-y\Vert]\qquad(1)
In practice, the Wasserstein distance is computed through its dual form:
W(P_r,P_g)=\sup_{\Vert f \Vert_{Lip}\le 1}E_{x\sim P_r}[f(x)]-E_{x\sim P_g}[f(x)]\qquad(2)
Equation (2) only requires f(·) to satisfy a Lipschitz constraint. In a GAN, the Discriminator's mapping function can play the role of f(·) in (2), and a GAN with this constraint added gets a new name: WGAN.
Introducing the Wasserstein distance and turning a traditional GAN into a WGAN brings real benefits, because the Wasserstein distance has the following properties:
1. W(P_r, P_g) ≥ 0, with equality exactly when the distributions P_r and P_g coincide;
2. W(P_r, P_g) is symmetric, an advantage over the commonly used KL divergence, which is not;
3. even when the supports of P_r and P_g do not intersect, it still serves as a distance measuring their difference, and under mild conditions it is differentiable, so it supports back-propagation (see the small numerical check after this list).
Once the Discriminator of a WGAN is trained with this distance, the convergence problems that plague traditional GAN training disappear and training becomes stable. The strategy is also easy to implement: simply impose the Lipschitz constraint on the parameter matrices of the traditional GAN's Discriminator; almost nothing else needs to change.
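As a quick numerical illustration of property 3 above (my own sketch, not taken from [1] or [5]; it assumes scipy is installed), the 1-D Wasserstein distance between two sets of samples stays finite and shrinks smoothly as the distributions move toward each other, even while their supports barely overlap:

import numpy as np
from scipy.stats import wasserstein_distance

rng = np.random.default_rng(0)
real = rng.normal(loc=0.0, scale=1.0, size=10000)         # samples from P_r
for shift in [8.0, 4.0, 2.0, 0.0]:
    fake = rng.normal(loc=shift, scale=1.0, size=10000)    # samples from P_g
    # empirical 1-D Wasserstein distance between the two sample sets
    print(shift, wasserstein_distance(real, fake))
# The distance decreases smoothly toward 0 as P_g approaches P_r,
# which is what gives the Generator useful gradients in a WGAN.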


In plain terms, the Lipschitz constraint requires that over the whole domain of f(·),
\frac{\Vert f(x)-f(x') \Vert_2}{\Vert x-x' \Vert_2} \le M \qquad(3)
where M is a constant. A function f(·) satisfying (3) cannot change too quickly: its gradient is always bounded, and even at its steepest it stays within M.


WGAN was the first to require that the Discriminator's parameter matrices satisfy a Lipschitz constraint, but its method is rather crude: it directly clips the entries of the parameter matrices so that none exceeds a given value. This does enforce the Lipschitz constraint, but while cutting off the peaks it also destroys the structure of the matrix, i.e. the relative proportions between its entries. To address this, [1] proposes a method that satisfies the Lipschitz condition without destroying the matrix structure: Spectral Normalization.
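A small sketch of the difference between the two approaches (my own illustration, not code from [1]; it assumes a recent PyTorch with torch.linalg): clipping rescales each matrix entry by a different factor and thus distorts the proportions between entries, while dividing the whole matrix by its spectral norm rescales every entry by the same factor:

import torch

torch.manual_seed(0)
W = torch.randn(4, 4) * 3

clipped = W.clamp(-0.5, 0.5)              # WGAN-style weight clipping
sigma = torch.linalg.svdvals(W)[0]        # largest singular value of W
normalized = W / sigma                    # spectral normalization

# Clipping rescales each entry by a different factor, distorting the matrix:
print((clipped / W).flatten())
# Spectral normalization rescales every entry by the same factor 1/sigma,
# so the proportions between entries are preserved:
print((normalized / W).flatten())
# and the normalized matrix has spectral norm exactly 1:
print(torch.linalg.svdvals(normalized)[0])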

2. Analysis of a multi-layer network

To simplify the analysis, the Discriminator can be viewed as a multi-layer network, since CNNs can be seen as a special case of multi-layer networks. For layer n of a multi-layer network, the input-output relation can be written as:
\mathbf x_n = a_n(W_n\mathbf x_{n-1}+\mathbf b_n)\qquad(4)
where a_n(·) is the layer's nonlinear activation, here taken to be ReLU, W_n is the layer's parameter matrix, and b_n is its bias. To simplify the derivation, the bias b_n is dropped, and (4) becomes:
\mathbf x_n = D_n W_n\mathbf x_{n-1} \qquad(5)
where D_n is a diagonal matrix representing the effect of ReLU: a diagonal entry is 0 when the corresponding input is negative and 1 when it is positive. The input-output relation of an N-layer network can then be written as:
f(\mathbf x)=D_NW_N\cdots D_1W_1 \mathbf x \qquad(6)
The Lipschitz constraint is a requirement on the gradient of f(x):
\Vert \nabla_x(f(\mathbf x)) \Vert_2 = \Vert D_NW_N\cdots D_1W_1 \Vert_2 \le \Vert D_N \Vert_2 \Vert W_N \Vert_2 \cdots \Vert D_1 \Vert_2 \Vert W_1 \Vert_2 \qquad(7)
Here ‖W‖ denotes the spectral norm of the matrix W, defined as:
\sigma(A):=\max_{\Vert h \Vert \neq 0} \frac{\Vert Ah \Vert_2}{\Vert h \Vert_2}=\max_{\Vert h \Vert = 1} \Vert Ah \Vert_2 \qquad(8)
σ(W) is the largest singular value of W. For a diagonal matrix D, σ(D) = max(d_1, ..., d_n), i.e. the largest of its diagonal entries. Thus (7) can be bounded as:
\Vert \nabla_x(f(\mathbf x)) \Vert_2 \le \prod_{i=1}^N \sigma(W_i) \qquad(9)
since the spectral norm of the diagonal matrix produced by ReLU is at most 1. To make f(x) satisfy the Lipschitz constraint, (7) can be normalized:
\Vert \nabla_x(f(\mathbf x)) \Vert_2 = \Vert D_N \frac{W_N}{\sigma(W_N)}\cdots D_1\frac{W_1}{\sigma(W_1)} \Vert_2 \le \prod_{i=1}^N \frac{\sigma(W_i)}{\sigma(W_i)} = 1\qquad(10)
So dividing each layer's parameter matrix by its spectral norm is enough to satisfy the Lipschitz constraint with constant 1, and this is where Spectral Normalization comes from.
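A quick numerical check of (10) (my own sketch, again assuming torch.linalg is available): build a small ReLU network whose weight matrices are each divided by their spectral norm, and verify that the ratio in (3) never exceeds 1 on random input pairs:

import torch

torch.manual_seed(0)
# random weights for a 3-layer ReLU network, each divided by its spectral norm
Ws = [torch.randn(64, 32), torch.randn(64, 64), torch.randn(1, 64)]
Ws = [W / torch.linalg.svdvals(W)[0] for W in Ws]

def f(x):
    # x: (batch, 32) -> (batch, 1); biases omitted, as in (5)
    h = torch.relu(x @ Ws[0].T)
    h = torch.relu(h @ Ws[1].T)
    return h @ Ws[2].T

x1 = torch.randn(1000, 32)
x2 = torch.randn(1000, 32)
ratio = (f(x1) - f(x2)).norm(dim=1) / (x1 - x2).norm(dim=1)
print(ratio.max())   # <= 1, as predicted by (10)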

3. Implementing spectral normalization

To obtain each layer's spectral norm, one would have to compute the singular values of W_i, which costs a lot of computation. Instead, the power iteration method can be used to approximate it; the iteration runs as follows:
1. v_l^0 ← a random Gaussian vector
2. loop over k:
       u_l^k ← W_l v_l^{k-1},  normalize: u_l^k ← u_l^k / ‖u_l^k‖
       v_l^k ← (W_l)^T u_l^k,  normalize: v_l^k ← v_l^k / ‖v_l^k‖
   end loop
3. σ_l(W) = (u_l^k)^T W_l v_l^k
Once the spectral norm is obtained, every entry of the parameter matrix is divided by it to achieve normalization. In fact, after enough iterations, u^k converges to the left singular vector of W associated with its largest singular value, i.e. the leading eigenvector of WW^T, so that:
WW^T \mathbf u = \sigma(W)^2\,\mathbf u \Rightarrow \mathbf u^T WW^T \mathbf u = \sigma(W)^2, \text{ since } \Vert \mathbf u \Vert = 1
\sigma(W) = \mathbf u^T W \mathbf v, \text{ with } \mathbf v = \frac{W^T \mathbf u}{\Vert W^T \mathbf u \Vert} = \frac{W^T \mathbf u}{\sigma(W)}
A concrete PyTorch implementation of spectral normalization can be found in [3]; the relevant parts are excerpted below.
1. Computing the spectral norm

import torch
import torch.nn.functional as F

# l2-normalize a vector; eps avoids division by zero
def _l2normalize(v, eps=1e-12):
    return v / (torch.norm(v) + eps)

def max_singular_value(W, u=None, Ip=1):
    """
    Power iteration for estimating the largest singular value of W.
    W: 2-D weight matrix; u: current estimate of the left singular vector,
    shape (1, W.size(0)); Ip: number of power-iteration steps.
    """
    if not Ip >= 1:
        raise ValueError("Power iteration should be a positive integer")
    if u is None:
        # initial random vector (note: hard-coded to the GPU in the original repo)
        u = torch.FloatTensor(1, W.size(0)).normal_(0, 1).cuda()
    _u = u
    for _ in range(Ip):
        # v <- normalize(W^T u), then u <- normalize(W v)  (row-vector convention)
        _v = _l2normalize(torch.matmul(_u, W.data), eps=1e-12)
        _u = _l2normalize(torch.matmul(_v, torch.transpose(W.data, 0, 1)), eps=1e-12)
    # sigma ≈ u^T W v
    sigma = torch.sum(F.linear(_u, torch.transpose(W.data, 0, 1)) * _v)
    return sigma, _u
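A usage sketch for max_singular_value (my own check, not part of [3]): start from an explicit u so the snippet also runs without a GPU, call the function repeatedly with Ip=1 while reusing u, and compare against the exact largest singular value from torch.linalg.svdvals. This mirrors how the layers below keep their u buffer across forward passes.

W = torch.randn(128, 64)
exact = torch.linalg.svdvals(W)[0]        # exact largest singular value, for reference
u = torch.randn(1, W.size(0))             # explicit start vector (avoids the .cuda() branch)
for step in range(5):
    sigma, u = max_singular_value(W, u, Ip=1)   # one power-iteration step per call
    print(step, float(sigma), float(exact))
# sigma converges to the exact value within a few steps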

2. Building the layers with normalization
Linear layer:

class SNLinear(Linear):   # Linear here is torch.nn.Linear
    def __init__(self, in_features, out_features, bias=True):
        super(SNLinear, self).__init__(in_features, out_features, bias)
        # persistent estimate of the left singular vector, updated on every forward pass
        self.register_buffer('u', torch.Tensor(1, out_features).normal_())

    @property
    def W_(self):
        # flatten the weight to 2-D, estimate its spectral norm by power iteration,
        # and return the spectrally normalized weight W / sigma(W)
        w_mat = self.weight.view(self.weight.size(0), -1)
        sigma, _u = max_singular_value(w_mat, self.u)
        self.u.copy_(_u)
        return self.weight / sigma

    def forward(self, input):
        return F.linear(input, self.W_, self.bias)
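A quick sanity check for SNLinear (my own sketch; it assumes SNLinear is defined as above with Linear imported from torch.nn, and that torch.linalg is available): after a few forward passes the effective weight W_ has spectral norm close to 1.

layer = SNLinear(64, 32)
for _ in range(10):
    w_bar = layer.W_                             # each access runs one power-iteration step and updates u
print(float(torch.linalg.svdvals(w_bar)[0]))     # ≈ 1
y = layer(torch.randn(8, 64))
print(y.shape)                                   # torch.Size([8, 32])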

Convolutional layer:

class SNConv2d(conv._ConvNd):   # conv and _pair come from torch.nn.modules.conv / torch.nn.modules.utils
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        # note: recent PyTorch versions also expect a padding_mode argument here
        super(SNConv2d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias)
        # persistent estimate of the left singular vector for the power iteration
        self.register_buffer('u', torch.Tensor(1, out_channels).normal_())

    @property
    def W_(self):
        # reshape the 4-D kernel (out_channels, in_channels, kH, kW) into a 2-D matrix
        # of shape (out_channels, in_channels*kH*kW), estimate its spectral norm,
        # then divide the kernel by sigma
        w_mat = self.weight.view(self.weight.size(0), -1)
        sigma, _u = max_singular_value(w_mat, self.u)
        self.u.copy_(_u)
        return self.weight / sigma

    def forward(self, input):
        return F.conv2d(input, self.W_, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
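The same kind of check for SNConv2d (my own sketch; it assumes `from torch.nn.modules import conv`, `from torch.nn.modules.utils import _pair`, and a PyTorch version compatible with the _ConvNd call above): the flattened kernel matrix ends up with spectral norm close to 1, and the layer otherwise behaves like an ordinary Conv2d.

conv_sn = SNConv2d(3, 16, kernel_size=3, padding=1)
for _ in range(10):
    w_bar = conv_sn.W_                           # updates the u buffer
w_mat = w_bar.view(w_bar.size(0), -1)            # same flattening used inside W_
print(float(torch.linalg.svdvals(w_mat)[0]))     # ≈ 1
y = conv_sn(torch.randn(2, 3, 32, 32))
print(y.shape)                                   # torch.Size([2, 16, 32, 32])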

The construction of these two layers shows the two pieces: computing the spectral norm, and the normalized layers that apply it. These layers can then be dropped into the Discriminator, as follows:

class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, hidden_channels=None, use_BN=False, downsample=False):
        super(ResBlock, self).__init__()
        # default the hidden width to the input width when not given
        hidden_channels = in_channels if hidden_channels is None else hidden_channels
        self.downsample = downsample
        self.resblock = self.make_res_block(in_channels, out_channels, hidden_channels, use_BN, downsample)
        self.residual_connect = self.make_residual_connect(in_channels, out_channels)

    def make_res_block(self, in_channels, out_channels, hidden_channels, use_BN, downsample):
        model = []
        if use_BN:
            model += [nn.BatchNorm2d(in_channels)]
        model += [nn.ReLU()]
        model += [SNConv2d(in_channels, hidden_channels, kernel_size=3, padding=1)]
        model += [nn.ReLU()]
        model += [SNConv2d(hidden_channels, out_channels, kernel_size=3, padding=1)]
        if downsample:
            model += [nn.AvgPool2d(2)]
        return nn.Sequential(*model)

    def make_residual_connect(self, in_channels, out_channels):
        # 1x1 convolution on the shortcut path, downsampled to match the main branch
        model = [SNConv2d(in_channels, out_channels, kernel_size=1, padding=0)]
        if self.downsample:
            model += [nn.AvgPool2d(2)]
        return nn.Sequential(*model)

    def forward(self, input):
        return self.resblock(input) + self.residual_connect(input)

class OptimizedBlock(nn.Module):
    # first block of the discriminator: no pre-activation ReLU, downsamples by 2
    def __init__(self, in_channels, out_channels):
        super(OptimizedBlock, self).__init__()
        self.res_block = self.make_res_block(in_channels, out_channels)
        self.residual_connect = self.make_residual_connect(in_channels, out_channels)

    def make_res_block(self, in_channels, out_channels):
        model = []
        model += [SNConv2d(in_channels, out_channels, kernel_size=3, padding=1)]
        model += [nn.ReLU()]
        model += [SNConv2d(out_channels, out_channels, kernel_size=3, padding=1)]
        model += [nn.AvgPool2d(2)]
        return nn.Sequential(*model)

    def make_residual_connect(self, in_channels, out_channels):
        model = []
        model += [SNConv2d(in_channels, out_channels, kernel_size=1, padding=0)]
        model += [nn.AvgPool2d(2)]
        return nn.Sequential(*model)

    def forward(self, input):
        return self.res_block(input) + self.residual_connect(input)

class SNResDiscriminator(nn.Module):
    def __init__(self, ndf=64, ndlayers=4):
        super(SNResDiscriminator, self).__init__()
        self.res_d = self.make_model(ndf, ndlayers)
        # final feature width is ndf * 2**ndlayers (1024 for the defaults ndf=64, ndlayers=4)
        self.fc = nn.Sequential(SNLinear(ndf * (2 ** ndlayers), 1), nn.Sigmoid())

    def make_model(self, ndf, ndlayers):
        model = []
        model += [OptimizedBlock(3, ndf)]
        tndf = ndf
        for i in range(ndlayers):
            model += [ResBlock(tndf, tndf * 2, downsample=True)]
            tndf *= 2
        model += [nn.ReLU()]
        return nn.Sequential(*model)

    def forward(self, input):
        out = self.res_d(input)
        # global average pooling over the remaining spatial dimensions
        out = F.avg_pool2d(out, out.size(3), stride=1)
        out = out.view(out.size(0), -1)
        return self.fc(out)
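A smoke test for the whole discriminator (my own sketch; it assumes the classes above are defined, `import torch.nn as nn` is in scope, and the SNConv2d layer works in your PyTorch version): feed a small batch of 64x64 RGB images and check that one score per image comes out.

D = SNResDiscriminator(ndf=64, ndlayers=4)
x = torch.randn(4, 3, 64, 64)      # a batch of 4 RGB images
print(D(x).shape)                  # torch.Size([4, 1]): one probability per image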

The discriminator SNResDiscriminator is built from two building blocks, ResBlock and OptimizedBlock, both of which use the SNConv2d layer to obtain spectrally normalized convolutions. The SNConv2d implementation uses `@property def W_(self)`, which was new to me and is worth studying further.

Summary:

For GAN training to proceed stably, the Discriminator's mapping function needs to satisfy a Lipschitz constraint. [1] shows that the spectral norm can be used to enforce this constraint and, from there, derives the normalization scheme described above. The whole construction is very elegant and well worth studying.


[1] Spectral Normalization for Generative Adversarial Networks, Takeru Miyato et al., 2018.2 (arXiv:1802.05957v1)
[2] Spectral Norm Regularization for Improving the Generalizability of Deep Learning, Yuichi Yoshida et al., National Institute of Informatics, 2017.5 (arXiv:1705.10941v1)
[3] https://github.com/godisboy/SN-GAN
[4] Towards Principled Methods for Training Generative Adversarial Networks, Martin Arjovsky et al., 2017 (arXiv:1701.04862)
[5] Wasserstein GAN, Martin Arjovsky et al., 2017 (arXiv:1701.07875)
