MINE Reproduction [DV, fGAN, infoNCE]

Environment

pytorch 1.5.1
torchvision 0.6.1
numpy 1.18.5
matplotlib 3.3.0

Main usage of the class

# -*- coding: UTF-8 -*-
"""=================================================
@IDE     :Pycharm
@Date    :2020.9.22
@Desc    :This module implements mutual information neural estimation (MINE). It is widely used
          to measure the distance between two distributions, e.g. in GANs or self-supervised
          deep learning. See more details in readme.pdf
=================================================="""
import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.autograd as autograd
import torch.nn.functional as F
import math


def log_sum_exp(x, axis=None):
    """
    :param x: pytorch tensor
    :param axis: the dimension to reduce over
    :return: y = log(sum(exp(x))), computed stably via the max-shift trick
    """
    x_max = torch.max(x, axis, keepdim=True)[0]
    y = torch.log((torch.exp(x - x_max)).sum(axis)) + x_max.squeeze(axis)
    return y
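
# Illustrative sanity check (not part of the original module): the helper above matches
# PyTorch's built-in torch.logsumexp and avoids the overflow of the naive formula, e.g.
#   x = torch.tensor([[1000., 1000.], [0., 0.]])
#   torch.log(torch.exp(x).sum(1))   -> tensor([   inf, 0.6931])
#   log_sum_exp(x, 1)                -> tensor([1000.6931, 0.6931])
#   torch.logsumexp(x, dim=1)        -> tensor([1000.6931, 0.6931])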


def fGAN(p_samples, q_samples, measure=None):
    """
    :param p_samples: positive samples
    :param q_samples: negative samples
    :param measure: the divergence measure; GAN, JSD, X2, KL and H2 are currently supported
    :return: expectations over the positive samples and the negative samples
    """
    log_2 = math.log(2.)

    if measure == 'GAN':
        Ep = - F.softplus(-p_samples)
        Eq = F.softplus(-q_samples) + q_samples
    elif measure == 'JSD':
        Ep = log_2 - F.softplus(-p_samples)
        Eq = F.softplus(-q_samples) + q_samples - log_2  # Note JSD will be shifted
    elif measure == 'X2':
        Ep = p_samples ** 2
        Eq = 0.5 * ((torch.sqrt(q_samples ** 2) + 1.) ** 2)
    elif measure == 'KL':
        Ep = p_samples
        Eq = torch.exp(q_samples - 1.)
    # elif measure == 'RKL':
    #     Ep = -torch.exp(-p_samples)
    #     Eq = q_samples - 1.
    elif measure == 'H2':
        Ep = 1. - torch.exp(-p_samples)
        Eq = torch.exp(q_samples) - 1.
    else:
        raise NotImplementedError("Unknown measure! GAN, JSD, X2, KL and H2 are currently supported.")

    return Ep, Eq


class T_net(nn.Module):
    """
    define neural network to estimate function T
    """

    def __init__(self, num_input, num_hidden):
        """
        :param num_input: the dimension of input layer
        :param num_hidden: the dimension of hidden layer
        """
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(num_input, num_hidden),
            nn.ELU(),
            nn.Linear(num_hidden, num_hidden),
            nn.ELU(),
            nn.Linear(num_hidden, 1)
        )
        self._init()

    def _init(self):
        """Init network parameters."""
        net = self.model
        for m in net.modules():
            if isinstance(m, nn.Conv2d):
                init.xavier_uniform_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=1e-3)

    def forward(self, x, z):
        """Forward propagation"""
        T_input = torch.stack([x, z], dim=-1)
        return self.model(T_input)


class MINE:
    def __init__(self, num_input, num_hidden=128, lr=1e-3, ma_rate=0.01,
                 mode='DV', measure='GAN', device='cpu'):
        """
        :param num_input: input dimension of the statistics network T (the stacked (x, z) pair; 2 in the example below)
        :param num_hidden: hidden unit dimension of the network
        :param lr: learning rate used to update MINE
        :param ma_rate: moving average rate in mode 'DV'
        :param mode: the estimator used to compute mutual information:
                    -DV, fGAN, NCE
        :param measure: each mode supports different measures:
                    -DV: naive, ma (see more details in https://arxiv.org/pdf/1801.04062.pdf)
                    -fGAN: GAN, JSD, X2, KL, H2 (see more details in https://arxiv.org/pdf/1606.00709.pdf)
                    -NCE: info (see more details in https://arxiv.org/pdf/1808.06670.pdf)
        :param device: cuda or cpu

        note that the DV- and NCE-based estimators need a large number of negative samples
        """
        self.device = device
        self.update_mode = mode
        self.measure = measure
        self.ma_rate = ma_rate
        self.T = T_net(num_input=num_input,
                       num_hidden=num_hidden,
                       ).to(self.device)
        self.optimizer = optim.Adam(self.T.parameters(), lr=lr)
        self.et_ma = 1.

    def get_loss(self, samples):
        """
        :param samples: the samples used to compute the loss, in 'torch.float32' type. It should be a dict containing
        joint samples and marginal samples; in this implementation the marginal is simply another sample set of one
        of the two variables. See more details in https://arxiv.org/pdf/1801.04062.pdf
        :return: loss: the loss term
                 lb: the lower bound of the mutual information (in DV mode)
        """
        joint_samples = samples['joint']
        marginal_samples = samples['marginal']

        if self.update_mode == 'DV':
            t = torch.mean(self.T(joint_samples[:, 0], joint_samples[:, 1])).to(self.device)
            et = torch.exp(self.T(joint_samples[:, 0], marginal_samples)).to(self.device)
            lb = t - torch.log(torch.mean(et))
            if self.measure == 'naive':
                loss = -lb
            elif self.measure == 'ma':
                # correct the biased gradient with an exponential moving average of e^T
                et_ma = (1 - self.ma_rate) * self.et_ma + self.ma_rate * torch.mean(et)
                self.et_ma = et_ma.detach()
                loss = -(t - (1 / et_ma.mean()).detach() * torch.mean(et))
            else:
                raise NotImplementedError("Unknown measure! DV supports 'naive' and 'ma' only.")
        elif self.update_mode == 'fGAN':
            Txy = self.T(joint_samples[:, 0], joint_samples[:, 1])
            Tx_y = self.T(joint_samples[:, 0], marginal_samples)

            Ep, Eq = fGAN(Txy, Tx_y, self.measure)

            loss = - Ep.mean() + Eq.mean()
            lb = loss

        elif self.update_mode == 'NCE':
            batch_size = joint_samples.shape[0]
            joint_samples_ = joint_samples[:, 0].expand(batch_size, batch_size)
            marginal_samples_ = marginal_samples.expand(batch_size, batch_size)

            et = torch.exp(self.T(joint_samples_, marginal_samples_.T))
            Eq = torch.log(torch.sum(et, dim=0))

            loss = -(self.T(joint_samples[:, 0], joint_samples[:, 1]) - Eq).mean()
            lb = loss
        else:
            raise NotImplementedError("Unknown mode! DV, fGAN and NCE are currently supported.")

        return loss, lb

    def update(self, samples):
        """
        :param samples: same as in "get_loss"
        :return: same as in "get_loss"
        note that this update function assumes the mutual information estimate is the only loss term;
        write your own update step if there are additional loss terms
        """
        # update the parameter of neural network by back propagation
        loss, lb = self.get_loss(samples)

        self.optimizer.zero_grad()
        autograd.backward(loss)
        self.optimizer.step()

        return loss, lb

The estimator takes samples in dictionary form: {'joint': (x, z), 'marginal': z}, where (x, z) and z are of 'torch.float32' type, and returns the mutual information neural estimate of the inputs. For example:

import torch
import numpy as np
from Mine import MINE

# define data (x, z) in multi_normal distribution with correlation 0.4
data = np.random.multivariate_normal(mean=[0, 0],
                                     cov=[[1, 0.4], [0.4, 1]],
                                     size=10000)


def _totorch(x):
    """
    convert numpy arrays (or other array-likes) to a torch.float32 tensor
    """
    return torch.tensor(x, dtype=torch.float32) \
        if not torch.is_tensor(x) else x.float()


# sample minibatch data from dataset
def sample_batch(data, bs):
    data_batch = {}
    rn_joint = np.random.choice(range(data.shape[0]), size=bs, replace=False)
    data_batch['joint'] = _totorch(data[rn_joint, :])
    rn_marginal = np.random.choice(range(data.shape[0]), size=bs, replace=False)
    data_batch['marginal'] = _totorch(data[rn_marginal, 1])
    return data_batch


mine = MINE(num_input=2,
            mode='fGAN',
            measure='JSD',
            lr=1e-3,
            device='cpu')
data_batch = sample_batch(data, bs=64)

loss, lb = mine.get_loss(data_batch)
print(loss)

with output:

tensor(0.0001, grad_fn=<AddBackward0>)

See more details on the function parameters in MINE.py.
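
A minimal end-to-end training sketch (reusing the data and sample_batch helper above; the DV/'ma' settings, batch size and iteration count are illustrative choices, not taken from the repository):

mine_dv = MINE(num_input=2, mode='DV', measure='ma', lr=1e-3, device='cpu')

# ground truth for the rho = 0.4 Gaussian above: -0.5 * log(1 - 0.4**2) ≈ 0.087 nats
for it in range(5000):
    batch = sample_batch(data, bs=256)
    loss, lb = mine_dv.update(batch)      # one gradient step on the statistics network
    if (it + 1) % 1000 == 0:
        print(f'iter {it + 1}: estimated MI lower bound = {lb.item():.4f}')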

Supporting material for the different modes

DV

According to the Donsker-Varadhan representation of the KL divergence,
$$D_{KL}(\mathbb{P}\|\mathbb{Q})=\sup_{T:\Omega\rightarrow\mathbb{R}}\mathbb{E}_{\mathbb{P}}[T]-\log\mathbb{E}_{\mathbb{Q}}[e^T]$$
maximizing this bound over T yields an estimate of the KL divergence. The pseudocode of the algorithm is shown below.
[Figure 1: pseudocode of the MINE algorithm]
Here (x, z) corresponds to sample['joint'] in the code and z to sample['marginal']. To stabilize training, a moving average (measure='ma') can be used to correct the gradient:
$$\begin{aligned} e^T_{ma} &= (1-\gamma)\,e^T_{ma}+\gamma\, e^T \\ \mathcal{L} &= -\frac{1}{b}\sum T(x,z)+\frac{\sum e^T}{\sum e^T_{ma}} \end{aligned}$$
Here $\gamma$ is the moving-average rate, corresponding to 'ma_rate' in the code. For more details see the paper:

https://arxiv.org/pdf/1801.04062.pdf
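
Why the moving average helps (following the MINE paper): the gradient of the DV objective on a minibatch is

$$\nabla_\theta\mathcal{L}=\mathbb{E}[\nabla_\theta T]-\frac{\mathbb{E}[e^{T}\,\nabla_\theta T]}{\mathbb{E}[e^{T}]},$$

and estimating the denominator $\mathbb{E}[e^{T}]$ on the same small minibatch makes the gradient biased. Replacing it with the exponential moving average $e^T_{ma}$ reduces this bias, which is what the 'ma' branch of get_loss implements by multiplying the detached factor $1/e^T_{ma}$ with the minibatch mean of $e^T$.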

fGAN

Besides the DV lower bound, f-divergences can also be used to estimate mutual information:
$$\mathcal{L}(\theta,\omega)=\mathbb{E}_{x\sim P}[g_f(V_{\omega}(x))]+\mathbb{E}_{x\sim Q_{\theta}}[-f^*(g_f(V_{\omega}(x)))]$$
For different f-divergences, one simply chooses the corresponding output activation $g_f$ and convex conjugate $f^*$.

[Figure 2: table of output activations $g_f$ and conjugates $f^*$ for different f-divergences]

For more details see the paper: https://arxiv.org/pdf/1606.00709.pdf
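
As a quick check that the fGAN helper above instantiates this objective, take the Jensen-Shannon case. The f-GAN paper uses $g_f(v)=\log 2-\log(1+e^{-v})$ and $f^*(t)=-\log(2-e^{t})$, which simplify to

$$g_f(v)=\log 2-\mathrm{softplus}(-v),\qquad f^*(g_f(v))=\mathrm{softplus}(v)-\log 2 .$$

These are exactly the Ep and Eq terms of the 'JSD' branch (using $\mathrm{softplus}(-v)+v=\mathrm{softplus}(v)$), and the returned loss = -Ep.mean() + Eq.mean() is the negative of $\mathcal{L}(\theta,\omega)$, so minimizing the loss maximizes the bound.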

infoNCE

The NCE loss commonly used in NLP can also be applied here, turning the prediction problem into a classification problem:
$$\hat{I}^{(\mathrm{infoNCE})}_{\omega,\psi}:=\mathbb{E}_{\mathbb{P}}\!\left[T_{\omega,\psi}(x,E_{\psi}(x))-\mathbb{E}_{\tilde{\mathbb{P}}}\!\left[\log\sum_{x^{\prime}}e^{T_{\omega,\psi}(x^{\prime},E_{\psi}(x))}\right]\right]$$
(Note: this part of the code has not been fully tested and may contain problems; use with caution.)

For more details see the paper: https://arxiv.org/pdf/1808.06670.pdf
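
For readability, here is a minimal, more explicit (and slower) sketch of what the NCE branch of get_loss computes. T is the T_net critic defined above, while nce_lower_bound, x, z and z_marg are illustrative names, not part of the repository; x, z and z_marg are 1-D float tensors of equal length:

import torch

def nce_lower_bound(T, x, z, z_marg):
    b = x.shape[0]
    # positive scores T(x_j, z_j) for the paired samples, shape (b,)
    pos = T(x, z).squeeze(-1)
    # scores[i, j] = T(x_j, z_marg_i): every x paired with every marginal sample
    scores = torch.stack([T(x, zi.repeat(b)) for zi in z_marg], dim=0).squeeze(-1)
    # contrast each positive score against the log-sum-exp over the marginal samples
    eq = torch.logsumexp(scores, dim=0)
    return (pos - eq).mean()   # the class returns the negative of this as the loss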

Experiment

To compare the algorithms above, first generate data from a two-dimensional Gaussian with covariance 0.6:
[Figure 3: scatter plot of the generated 2-D Gaussian data]
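
For reference (not shown in the original figures), the mutual information of a bivariate Gaussian with unit variances and correlation $\rho$ has the closed form

$$I(X;Z)=-\tfrac{1}{2}\log\left(1-\rho^{2}\right),$$

i.e. about 0.223 nats for $\rho=0.6$ here (and about 0.087 nats for the $\rho=0.4$ example earlier), which gives a target value for the curves below.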
Then compare the mutual information estimates of the different methods:
[Figure 4: MI estimates of the different methods over training]
Verify that $2D_{JS}=2D_{JSD}-\log 4$:
[Figure 5: verification plot]

TODO

  • More tests of infoNCE
  • Add a demo that plugs the mutual information estimate into a GAN

Or visit my GitHub:
https://github.com/BrutalStark/MINE
