Deep InfoMax

Deep InfoMax

    • 2、算法模型
    • 3、目标函数

好特征的基本原则应当是“能够从整个数据集中辨别出该样本出来”,也就是说,提取出该样本(最)独特的信息。

2、算法模型

Deep InfoMax_第1张图片
上图是一个图像数据的基本编码器模型,编码器设为 E E E,输入一张图片 X X X,经过几层卷积,得到一个 M ∗ M ∗ c h a n n e l s M*M*channels MMchannels的特征图,再将这些特征经过卷积展开或全连接层的计算,得到一个一维的特征向量 Y Y Y,就是输入数据的高层语义信息。
论文指出,基于 D I M DIM DIM的编码器目标是:
在这里插入图片描述
Deep InfoMax_第2张图片
Deep InfoMax_第3张图片
Deep InfoMax_第4张图片
Deep InfoMax_第5张图片
Deep InfoMax_第6张图片

3、目标函数

Deep InfoMax_第7张图片

Deep InfoMax_第8张图片
Deep InfoMax_第9张图片

# y是经过卷积和多层线性层最后输出的vector, M是经过卷积和一层线性输出的vector.
# y.shape=(None,64), M.shape=(None,128,26,26)
y, M = self.encoder(x)

# rotate images to create pairs for comparison.
# 批量中第一个图像放在最后一位,其他的顺着向上移动一位. 相当于打乱了batch中图像的顺序
#M_prime就是将每个batch的第一张图片对应的中间层特征置于该batch特征的末尾(由#此构造出DeepInfoMax中,用于生成"Fake" pair的another image)
M_prime = torch.cat((M[1:], M[0].unsqueeze(0)), dim=0)  # unsqueeze 新增一个维度. M_prime.shape = (None,128,26,26)

y_expand = y.unsqueeze(-1).unsqueeze(-1)    # (None, 64, 1, 1)
y_expand = y_expand.expand(-1, -1, 26, 26)  # (None, 64, 26, 26) 最后两维是复制出来的,使其具有空间结构

y_M = torch.cat((M, y_expand), dim=1)       # (None, 196, 26, 26)  将 M 和 y_expand 叠加
y_M_prime = torch.cat((M_prime, y_expand), dim=1)  # (None,196,26,26) 将 M_prime 和 y_expand 叠加

# local 最终输出的是一个 (None,1,26,26)的vector. 有多个位置, 每个位置代表一个score
Ej = -F.softplus(-self.local_d(y_M)).mean()       # T(x, E_{\phi}(x))    score of anchor and positive
Em = F.softplus(self.local_d(y_M_prime)).mean()   # T(x', E_{\phi}(x))   score of anchor and negative
LOCAL = -1.*(Ej - Em) * self.beta                 # -1 * Jensen-Shannon MI estimator

# global 最终输出的是 (None,1), 将 y,m 合并成一个score.
Ej = -F.softplus(-self.global_d(y, M)).mean()
Em = F.softplus(self.global_d(y, M_prime)).mean()
GLOBAL = -1.*(Ej - Em) * self.alpha

prior = torch.rand_like(y)

# 只更新 prior_d, 不更新 encoder
term_a = torch.log(self.prior_d(prior)).mean()      # 先验经过discriminator (None,64) -> (None,1) 输出是 sigmoid
term_b = torch.log(1.0 - self.prior_d(y.detach())).mean()    # 真实的y经过 discriminator
PRIOR = -1.*(term_a + term_b) * self.gamma          # 二分类损失. 希望 prior_d 能够区分 prior 和 y

# 只更新 encoder, 不更新 prior_d
PRIOR_encoder = -1. * torch.log(self.prior_d(y)).mean() * self.gamma
return LOCAL + GLOBAL + PRIOR_encoder, PRIOR

1、Learning Deep Representations by Mutual Information Estimation and Maximization
2、原作者的源码
3、Deep InfoMax Pytorch
4、深度学习中的互信息:无监督提取特征
5、Keras 实现的一个版本
6、F-GAN & MINE
7、自编码器的最佳特征:最大化互信息

你可能感兴趣的:(图网络)