[ICLR 2018]Learning Latent Permutations with Gumbel-Sinkhorn Networks

结论

将离散的排序问题变成求取一个置换矩阵,实现ML下的排序。

The Sinkhorn Operator

现在已有证明,在有温度系数的softmax中:
s o f t m a x τ ( x ) u = e x p ( x i / τ ) ∑ j = 1 e x p ( x j / τ ) softmax_{\tau}(x)_u=\frac{exp(x_i/\tau)}{\sum_{j=1}exp(x_j/\tau)} softmaxτ(x)u=j=1exp(xj/τ)exp(xi/τ)
τ → 0 + \tau \to 0^+ τ0+,上式的结果就会变成x中最大值的one-hot结果。
粗糙的理解是,分母接近0的时候会无限放大整个值,也会无限放大各值之间的差距。因为是反比例函数,所以数值越大放大得越快,最后导致相比最大的数的结果,其他的结果都会变得相对非常小。再加上softmax一归一,就出现了最大值的位置非常接近1,其他接近0的情况。

定义Sinkhorn Operator为 S ( X ) S(X) S(X),其中 X X X N N N维度的方阵。
S 0 ( X ) = e x p ( X ) S^0(X)=exp(X) S0(X)=exp(X)
S l ( X ) = τ c ( τ r ( S l − 1 ( X ) ) ) S^l(X)=\tau_c(\tau_r(S^{l-1}(X))) Sl(X)=τc(τr(Sl1(X)))
S ( X ) = lim ⁡ l → ∞ S l ( X ) S(X)=\lim_{l \to \infty}S^l(X) S(X)=limlSl(X)
其中 τ r ( X ) = X ⊘ ( X 1 N 1 N ⊤ ) , τ c ( X ) = X ⊘ ( 1 N 1 N ⊤ X ) \tau_r(X)=X\oslash(X1_N1^{\top}_N), \tau_c(X)=X\oslash(1_N1^{\top}_NX) τr(X)=X(X1N1N),τc(X)=X(1N1NX)
⊘ \oslash 表示的是每个对应的元素相除: C = A ⊘ B → C i j = A i j / B i j C=A\oslash B \to C_{ij}=A_{ij}/B_{ij} C=ABCij=Aij/Bij。而 1 N 1_N 1N表示的是全为1的列向量。
因此上面的 τ \tau τ实际上就是行和列的均一。
可以证明 S ( X ) S(X) S(X)必然收敛到一个叫Birkhoff polytope的空间上,记作:
B N = { P ∈ [ 0 , 1 ] ∈ R N , N , P 1 N = 1 N , P ⊤ 1 N = 1 N } \mathcal{B}_N=\{P \in [0,1] \in \mathbb{R}^{N,N}, P1_N=1_N, P^{\top}1_N=1_N\} BN={P[0,1]RN,N,P1N=1N,P1N=1N}
也就是横竖都只有一个1的方阵。

严谨起见,需要梯度下降的训练时用S,而测试时则使用正常的匈牙利算法。
_匈牙利算法:给一个矩阵,每一列选一个数,并保证每列选的行数不同,且所有选择的数加起来最小。代码中使用scipy库中的optimize._linear_sum_assignment 来实现。

可以看下这个操作的代码:

def log_sinkhorn(log_alpha, n_iter):
    for _ in range(n_iter):
        #先把x作为e的幂变换回来,然后sum,再log回去
        log_alpha = log_alpha - torch.logsumexp(log_alpha, -1, keepdim=True)
        log_alpha = log_alpha - torch.logsumexp(log_alpha, -2, keepdim=True)
        return log_alpha.exp()
    
log_sinkhorn(torch.log(X), n_iter=20) #进行二十次Sinkhorn operator 

相比于直接把X扔进去迭代,先把它映射到log空间中可以提高稳定性(应该也可以加快收敛)
循环里可以做减法是因为 log ⁡ ( a b ) = l o g ( a ) − l o g ( b ) \log(\frac{a}{b})=log(a)-log(b) log(ba)=log(a)log(b)
绝了

Sinkhorn Network

目标就是把混乱的矩阵 X ~ \tilde{X} X~转换为正常的矩阵 X i = P θ , X ~ i − 1 X ~ i + ϵ i X_i=P^{-1}_{\theta, \tilde{X}_i}\tilde{X}_i+ \epsilon_i Xi=Pθ,X~i1X~i+ϵi,其中 ϵ i \epsilon_i ϵi是噪声,文中说这个噪音可以保证所有结果都是唯一的,这样才能保证 P P P的收敛。毕竟如果每次结果扔进去都是一样的,结果就不会变了。
噪声代码:

def sample_gumbel(shape, device='cpu', eps=1e-20):
    u = torch.rand(shape, device=device)
    return -torch.log(-torch.log(u + eps) + eps)

# 生成一个3x3的噪声矩阵
sample_gumbel((3, 3))

因为rand出来的是 [ 0 , 1 ) [0, 1) [0,1)的随机数,所以要加个非常小的数避免过定义域。
这个取噪声的方法很神奇,套两次log。。。

目标就是找到最好的重构:
f ( θ , X , X ~ ) = ∑ i = 1 M ∣ ∣ X i − P θ , X ~ i − 1 X ~ i ∣ ∣ 2 f(\theta, X, \tilde{X})=\sum^M_{i=1}||X_i-P^{-1}_{\theta, \tilde{X}_i} \tilde{X}_i||^2 f(θ,X,X~)=i=1M∣∣XiPθ,X~i1X~i2
这个就可以视作损失函数。而关键的 P θ , X ~ = S ( g ( X ~ , θ ) / τ ) P_{\theta, \tilde{X}}=S(g(\tilde{X}, \theta)/\tau) Pθ,X~=S(g(X~,θ)/τ),也就是上面那个式子迭代出的转置矩阵,至于 g g g其实就是个映射函数。

这里文章说明,因为前面有丰富的映射关系,所以 P θ , X ~ i − 1 X ~ i P^{-1}_{\theta, \tilde{X}_i} \tilde{X}_i Pθ,X~i1X~i可以换用 P θ , X ~ i ⊤ X ~ i P^{\top}_{\theta, \tilde{X}_i} \tilde{X}_i Pθ,X~iX~i。(既然这样干脆别换啊,转置和原本的矩阵有啥区别啊)

尝试复现了算法,目标改成了数字排序。能收敛,但只能收敛一点点,非常不稳定,不知道是代码问题还是算法问题。

你可能感兴趣的:(机器学习,算法,python)