【论文笔记】A Neural Representation of Sketch Drawings

谷歌的论文,基于seq2seq+VAE编码并生成手绘序列
https://arxiv.org/pdf/1704.03477.pdf
本文主要是论文的概述翻译,记录

文章目录

    • 1.Introduction
    • 2.Related Work
    • 3.方法
      • 3.1 数据集
      • 3.2 Sketch-RNN
      • 3.3Unconditional Generation
      • 3.4 Training
    • 4.Experiments
      • 4.1 Conditional Reconstruction
      • 4.2 Latent Space Interpolation
      • 4.3 Sketch Drawing Analogies(类比)
      • 4.4 Predicting Different Endings of Incomplete Sketches
    • 5.Applications and Future Work;6.结论;略
    • 附录

1.Introduction

  • 生成模型的发展:GANs(Generative Adversarial Networks)、VI(Variational Inference)、AR(Autoregressive)
  • 当前多用于处理图像像素数据(pixel images),而人的理解是矢量序列的,本文即提出用于矢量图像的生成模型(手绘草图)
  • 文本贡献:适用于线序列的条件/非条件生成模型框架;提出的sketch-RNN可以生成矢量格式的有意义的图像;开发了一种可以使训练更鲁棒的方法;将矢量图映射到了潜在空间;最后讨论了本文可能的应用领域

2.Related Work

  • 对于模仿绘画,有通过既定的文件执行绘画的机器人、和基于强化学习的方法,并非生成
  • 神经网络用于生成的多是栅格图像;早期对于线的有HMM模型方法;最近有基于RNN的Mixture Densify Network及其改进的方法用于生成连续数据点、和汉字
  • 最近有用Sequence-to-Sequence 模型结合VAE(变分自编码器)的用于英语语言编码到潜在空间的研究
  • 最后提了一些现有的公开数据集

3.方法

3.1 数据集

  • 取自QuickDraw应用,有20s以内绘制出的草图,上百类,每类有70k的训练样本,及2.5k验证,2.5k测试
  • 数据序列组织为[dx,dy,p1,p2,p3],分别表示x、y方向的变化,p1、p2、p3表示继续绘制、结束子序列、结束绘图三个状态
    【论文笔记】A Neural Representation of Sketch Drawings_第1张图片

3.2 Sketch-RNN

【论文笔记】A Neural Representation of Sketch Drawings_第2张图片

  • 基本结构为Sequence-to-Sequence 变分自编码器
  • 其中编码器用双向RNN,以草图作为输入,潜在空间向量作为输出(用常规的VAE,先出 μ \mu μ σ \sigma σ,再按正态分布算 z z z)。

h → = e n c o d e → ( S ) , h ← = e n c o d e ← ( S r e v e r s e ) , h = [ h → ; h ← ] h_\rightarrow=encode_\rightarrow(S), h_\leftarrow=encode_\leftarrow(S_{reverse}), h=[h_\rightarrow;h_\leftarrow] h=encode(S),h=encode(Sreverse),h=[h;h]
μ = W μ h + b μ , σ ^ = W σ h + b σ , σ = e x p ( σ ^ 2 ) , z = μ + σ ⊙ N ( 0 , 1 ) \mu=W_{\mu}h+b_{\mu}, \hat{\sigma}=W_{\sigma}h+b_{\sigma}, \sigma=exp(\frac{\hat{\sigma}}{2}), z=\mu+\sigma \odot \mathcal{N}(0,1) μ=Wμh+bμ,σ^=Wσh+bσ,σ=exp(2σ^),z=μ+σN(0,1)

  • 解码器用自回归RNN,以序列后一点作为当前点的输出;由于之前是双向RNN编码,所以z先过一个tanh得到解码器的初始状态
    [ h 0 ; c 0 ] = t a n h ( W z z + b z ) [h_0;c_0]=tanh(W_zz+b_z) [h0;c0]=tanh(Wzz+bz)

  • S 0 S_0 S0定义为(0,0,1,0,0)

  • (dx,dy)通过M元正态分布的高斯混合模型(GMM)计算概率,(q1,q2,q3)作为类别来计算,M也是一个类别分布,是GMM的混合权重
    p ( △ x , △ y ) = ∑ j = 1 M N ( △ x , △ y ∣ μ x , j , μ y , j , σ x , j , σ y , j , ρ x y , j ) , w h e r e   ∑ j = 1 M Π j = 1 p(\triangle x,\triangle y)=\sum_{j=1}^{M} \mathcal{N}(\triangle x,\triangle y | \mu_{x,j},\mu_{y,j},\sigma_{x,j},\sigma_{y,j},\rho_{xy,j}),where \ \sum_{j=1}^{M}\Pi_j=1 p(x,y)=j=1MN(x,yμx,j,μy,j,σx,j,σy,j,ρxy,j),where j=1MΠj=1

  • 因此,解码器的输出维度为5M+M+3,即6M+3维
    x i = [ S i − 1 ; z ] , [ h i ; c i ] = f o r w a r d ( x i , [ h i − 1 ; c i − 1 ] ) , y i = W y h i + b y , y i ∈ R 6 M + 3   [ ( Π ^ μ x μ y σ ^ x σ ^ y ρ ^ x y ) 1 . . . ( Π ^ μ x μ y σ ^ x σ ^ y ρ ^ x y ) M ( q 1 ^ q 2 ^ q 3 ^ ) ] = y i x_i=[S_{i-1};z],[h_i;c_i]=forward(x_i,[h_{i-1};c_{i-1}]),y_i=W_yh_i+b_y,y_i\in \mathbb{R}^{6M+3} \\ \ \\ [(\hat{\Pi} \mu_x \mu_y \hat{\sigma}_x \hat{\sigma}_y \hat{\rho}_xy)_1...(\hat{\Pi} \mu_x \mu_y \hat{\sigma}_x \hat{\sigma}_y \hat{\rho}_xy)_M(\hat{q_1}\hat{q_2}\hat{q_3})]=y_i xi=[Si1;z],[hi;ci]=forward(xi,[hi1;ci1]),yi=Wyhi+by,yiR6M+3 [(Π^μxμyσ^xσ^yρ^xy)1...(Π^μxμyσ^xσ^yρ^xy)M(q1^q2^q3^)]=yi

  • 为了使标准差值非负,使用exp和tanh来约束为-1~1之间
    σ x = exp ⁡ ( σ ^ x ) , σ y = exp ⁡ ( σ ^ y ) , ρ x y = tanh ⁡ ( ρ ^ x y ) \sigma_x=\exp(\hat{\sigma}_x),\sigma_y=\exp(\hat{\sigma}_y),\rho_{xy}=\tanh(\hat{\rho}_{xy}) σx=exp(σ^x),σy=exp(σ^y),ρxy=tanh(ρ^xy)

  • 类别分布概率计算为
    q k = exp ⁡ ( q ^ k ) ∑ j = 1 3 exp ⁡ ( q ^ j ) , k ∈ { 1 , 2 , 3 } Π k = exp ⁡ ( Π ^ k ) ∑ j = 1 M exp ⁡ ( Π ^ J ) , k ∈ { 1 , . . . , M } q_k=\frac{\exp(\hat{q}_k)}{\sum_{j=1}^{3}\exp(\hat{q}_j)},k\in\{1,2,3\}\\ \Pi_k=\frac{\exp(\hat{\Pi}_k)}{\sum_{j=1}^{M}\exp(\hat{\Pi}_J)},k\in\{1,...,M\} qk=j=13exp(q^j)exp(q^k),k{1,2,3}Πk=j=1Mexp(Π^J)exp(Π^k),k{1,...,M}

  • 本方法存在着p1,p2,p3状态数据不平衡问题,通用的方法是样本加权,但这样并不适用于多类别数据集,本文的解决方法是设定最大长度,实际结束后的都用(0,0,0,0,1)来标记

  • 在训练阶段,我们每次获取本时间步的结果。而在生成阶段,我们将本时间步的输出结果作为下一时间步的输入,直到输出的p3=1或达到最大长度为止

  • 设置了一个温度参数 τ \tau τ,来增加序列的随机性, τ \tau τ取值在0~1之间,约接近0,模型结果越确定。

3.3Unconditional Generation

  • 我们可以只训练模型的解码器,没有编码器、没有输入、没有潜在空间向量,设置初始隐藏状态为0,那么会得到一个纯生成的模型

3.4 Training

  • 模型采用变分自编码器的方法,其损失函数由重建损失 L R L_R LR和KL散度损失 L K L L_{KL} LKL组成。

  • 对于重建损失,分别为偏移量 ( △ x , △ y ) (\triangle x,\triangle y) (x,y)的对数损失 L s L_s Ls和画笔状态 ( p 1 , p 2 , p 3 ) (p_1,p_2,p_3) (p1,p2,p3)的对数损失 L p L_p Lp(注: N s N_s Ns为序列实际长度)
    L s = − 1 N max ⁡ ∑ i = 1 N s log ⁡ ( ∑ j = 1 M Π j , i N ( △ x i , △ y i ∣ μ x , j , i , μ y , j , i , σ x , j , i , σ y , j , i , ρ x y , j , i ) ) L p = − 1 N max ⁡ ∑ i = 1 N m a x ∑ k = 1 3 p k , i log ⁡ ( q k , i ) , L R = L s + L p L_s=-\frac{1}{N_{\max}} \sum_{i=1}^{N_s} \log(\sum_{j=1}^{M} \Pi_{j,i} \mathcal{N}(\triangle x_i,\triangle y_i|\mu_{x,j,i},\mu_{y,j,i},\sigma_{x,j,i},\sigma_{y,j,i},\rho_{xy,j,i}))\\ L_p=-\frac{1}{N_{\max}} \sum_{i=1}^{N_{max}}\sum_{k=1}^{3} p_{k,i} \log(q_{k,i}),L_R=L_s+L_p Ls=Nmax1i=1Nslog(j=1MΠj,iN(xi,yiμx,j,i,μy,j,i,σx,j,i,σy,j,i,ρxy,j,i))Lp=Nmax1i=1Nmaxk=13pk,ilog(qk,i),LR=Ls+Lp

  • KL散度损失度量的是潜在向量 z z z和独立同分布的高斯向量之间的差异,(可以使不同草图在潜在空间中距离更近,使插值有意义)
    L K L = − 1 2 N z ( 1 + σ ^ − μ 2 − exp ⁡ ( σ ^ ) ) L o s s = L R + w K L L K L L_{KL}=-\frac{1}{2N_z} (1+\hat{\sigma}-\mu^2-\exp(\hat{\sigma}) )\\ Loss=L_R+w_{KL}L_{KL} LKL=2Nz1(1+σ^μ2exp(σ^))Loss=LR+wKLLKL

4.Experiments

  • 分别尝试了多类别和单类别,以及不同 w K L w_{KL} wKL
  • 编码器用双向LSTM,解码器用HyperLSTM

4.1 Conditional Reconstruction

  • 单独训练猫/猪的数据,设置不同的温度 τ \tau τ τ \tau τ越小重建约稳定;另外模型训练后能起到一定的修正作用;即使是输入个牙刷,重建时也会同时保留二者的特征
    【论文笔记】A Neural Representation of Sketch Drawings_第3张图片

4.2 Latent Space Interpolation

4.3 Sketch Drawing Analogies(类比)

  • 通过在潜在空间插值得到草图的变化过程,且设置更高的 w K L w_{KL} wKL能够产生更好的数据流形关系;
    【论文笔记】A Neural Representation of Sketch Drawings_第4张图片

4.4 Predicting Different Endings of Incomplete Sketches

  • 一个应用点,根据初始的几笔,来补充完整的草图

5.Applications and Future Work;6.结论;略

附录

  • 数据预处理时,将偏移量 ( △ x , △ y ) (\triangle x,\triangle y) (x,y)缩放为方差为1的大小。不执行0均值操作(因为均值本身就很小)。

  • 在计算KL散度损失时,引入退火算法,效果更好

  • 模型设置方面,编码器512个神经元,解码器2048个。用M=20的混合组成;设置recurrent dropout保留90%;batch_size=100;Adam,学习率为0.0001,梯度裁剪为1;KL_{min}=0.2,R=0.99999(模拟退火的参数?)

  • 点的数量不能多于300个,本文用了道格拉斯普克法算法将数据点压缩到200个以下

  • 对于复杂的图像,重建效果较差,且更倾向于圆滑的效果

  • 类别数不宜过多

  • 其他略

你可能感兴趣的:(深度学习,论文笔记)