pytorch 1.5.1
torchvision 0.6.1
numpy 1.18.5
matplotlib 3.3.0

Main usage of class

import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.autograd as autograd
import torch.nn.functional as F
import math

def log_sum_exp(x, axis=None):
    """Compute ``log(sum(exp(x)))`` in a numerically stable way.

    :param x: pytorch tensor
    :param axis: the dimension to reduce over; ``None`` reduces over every
                 element of ``x``
    :return: y = log(sum(exp(x - x_max))) + x_max, with ``axis`` removed
             (a 0-dim tensor when ``axis`` is None)
    """
    if axis is None:
        # Flatten so that a single reduction covers the whole tensor.
        return torch.logsumexp(x.reshape(-1), dim=0)
    # torch.logsumexp performs the max-shift internally, so the result stays
    # finite for large inputs and the shift broadcasts correctly on any axis.
    return torch.logsumexp(x, dim=axis)

def fGAN(p_samples, q_samples, measure=None):
    """Apply the per-sample f-divergence terms to critic outputs.

    :param p_samples: positive samples
    :param q_samples: negative samples
    :param measure: different measurement, support GAN, JSD, X2, KL, H2 now
    :return: expectation of positive samples and negative samples
             (element-wise terms ``Ep`` and ``Eq``; the caller averages them)
    :raises NotImplementedError: if ``measure`` is not one of the supported names
    """
    log_2 = math.log(2.)

    if measure == 'GAN':
        Ep = - F.softplus(-p_samples)
        Eq = F.softplus(-q_samples) + q_samples
    elif measure == 'JSD':
        Ep = log_2 - F.softplus(-p_samples)
        Eq = F.softplus(-q_samples) + q_samples - log_2  # Note JSD will be shifted
    elif measure == 'X2':
        Ep = p_samples ** 2
        # |q| has the same value as sqrt(q ** 2) but a finite gradient at q == 0.
        Eq = 0.5 * ((torch.abs(q_samples) + 1.) ** 2)
    elif measure == 'KL':
        Ep = p_samples
        Eq = torch.exp(q_samples - 1.)
    # elif measure == 'RKL':
    #     Ep = -torch.exp(-p_samples)
    #     Eq = q_samples - 1.
    elif measure == 'H2':
        Ep = 1. - torch.exp(-p_samples)
        Eq = torch.exp(q_samples) - 1.
    else:
        # Fail loudly here instead of handing (None, None) back to the caller.
        raise NotImplementedError("Check your measurement! Support GAN, JSD, X2, KL, H2 now")

    return Ep, Eq

class T_net(nn.Module):
    """Neural network used to estimate the function T."""

    def __init__(self, num_input, num_hidden):
        """
        :param num_input: the dimension of input layer
        :param num_hidden: the dimension of hidden layer
        """
        # nn.Module must be initialised before sub-modules can be registered.
        super().__init__()
        # NOTE(review): only the three Linear layers are visible in this
        # snippet; if activations sat between them in the upstream code they
        # are lost here -- confirm against the original source.
        self.model = nn.Sequential(
            nn.Linear(num_input, num_hidden),
            nn.Linear(num_hidden, num_hidden),
            nn.Linear(num_hidden, 1),
        )

    def _init(self):
        """Init network parameters."""
        net = self.model
        for m in net.modules():
            if isinstance(m, nn.Conv2d):
                # NOTE(review): the body of this branch is missing from the
                # snippet; Conv2d layers keep their default initialisation.
                pass
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=1e-3)

    def forward(self, x, z):
        """Forward propagation.

        ``x`` and ``z`` are stacked along a new last dimension, so that
        dimension (size 2) is what the first Linear layer receives.
        """
        T_input = torch.stack([x, z], dim=-1)
        return self.model(T_input)

class MINE:
    """Mutual information neural estimator with DV, fGAN and NCE objectives."""

    def __init__(self, num_input, num_hidden=128, lr=1e-3, ma_rate=0.01,
                 mode='DV', measure='GAN', device='cpu'):
        """
        :param num_input: input dimension of X, which should equal to dimension of Y
        :param num_hidden: hidden unit dimension of NN
        :param lr: learning rate to update MINE
        :param ma_rate: moving average rate in mode 'DV'
        :param mode: three basics of calculating mutual information:
                    -DV, fGAN, NCE
        :param measure: for different modes, there are different measurement
                    -DV: naive, ma (see more details in https://arxiv.org/pdf/1801.04062.pdf)
                    -fGAN: GAN, JSD, X2, KL, H2(see more details in https://arxiv.org/pdf/1606.00709.pdf)
                    -NCE: info (see more details in https://arxiv.org/pdf/1808.06670.pdf)
        :param device: cuda or cpu

        please note that DV and NCE based measurements need large number of negative samples
        """
        self.device = device
        self.update_mode = mode
        self.measure = measure
        self.ma_rate = ma_rate
        # NOTE(review): the tail of this call is cut off in the snippet;
        # passing num_hidden and moving the net to `device` is the assumed
        # completion -- confirm against the original source.
        self.T = T_net(num_input=num_input,
                       num_hidden=num_hidden).to(device)
        self.optimizer = optim.Adam(self.T.parameters(), lr=lr)
        self.et_ma = 1.

    def get_loss(self, samples):
        """
        :param samples: the samples to get loss in 'torch.float32' type. it should be a dict which include joint
        samples and marginal samples. in this implement, the marginal is only another sample set of one distribution.
        see more details in https://arxiv.org/pdf/1801.04062.pdf
        :return: loss: the loss term
                 lb: the lower bound of mutual information (in DV mode)
        :raises NotImplementedError: for an unknown mode, or an unknown measure in DV mode
        """
        # Move the inputs (not the outputs) so they live where self.T does.
        joint_samples = samples['joint'].to(self.device)
        marginal_samples = samples['marginal'].to(self.device)
        x_joint, z_joint = joint_samples[:, 0], joint_samples[:, 1]

        # Compare strings by value: `is` only checks object identity and is
        # not guaranteed to hold for equal strings.
        if self.update_mode == 'DV':
            t = torch.mean(self.T(x_joint, z_joint))
            et = torch.exp(self.T(x_joint, marginal_samples))
            lb = t - torch.log(torch.mean(et))
            if self.measure == 'naive':
                loss = -lb
            elif self.measure == 'ma':
                # et use move average; it is stored detached so that the
                # autograd graphs of earlier batches are not kept alive.
                et_ma = ((1 - self.ma_rate) * self.et_ma
                         + self.ma_rate * torch.mean(et)).detach()
                self.et_ma = et_ma
                loss = -(t - torch.mean(et) / et_ma)
            else:
                raise NotImplementedError("Check your measure! DV support naive and ma now.")
        elif self.update_mode == 'fGAN':
            Txy = self.T(x_joint, z_joint)
            Tx_y = self.T(x_joint, marginal_samples)

            Ep, Eq = fGAN(Txy, Tx_y, self.measure)

            loss = - Ep.mean() + Eq.mean()
            # NOTE(review): in this mode `lb` is the loss itself.
            lb = loss

        elif self.update_mode == 'NCE':
            batch_size = joint_samples.shape[0]
            joint_samples_ = x_joint.expand(batch_size, batch_size)
            marginal_samples_ = marginal_samples.expand(batch_size, batch_size)

            et = torch.exp(self.T(joint_samples_, marginal_samples_.T))
            Eq = torch.log(torch.sum(et, dim=0))

            loss = -(self.T(x_joint, z_joint) - Eq).mean()
            # NOTE(review): in this mode `lb` is the loss itself.
            lb = loss
        else:
            raise NotImplementedError("Check your mutual information calculation mode! Support DV, fGAN, NCE now.")

        return loss, lb

    def update(self, samples):
        """
        :param samples: same as function "get_loss"
        :return: same as function "get_loss"
        be attention that the update function is used with only mutual information as the loss term
        you should write your own update function if you have other loss terms
        """
        # update the parameter of neural network by back propagation
        loss, lb = self.get_loss(samples)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss, lb

Takes samples in dictionary form: {'joint': $(x,z)$, 'marginal': $x$}, where $(x,z)$ and $x$ are both of type `torch.float32`, and returns the neural estimate of the mutual information of the inputs. For example:

import torch
import numpy as np
from Mine import MINE

# define data (x, z) in multi_normal distribution with correlation 0.4
# NOTE(review): the sample count is cut off in the original snippet; 1000 is
# a placeholder chosen to be larger than the batch size used in this example.
data = np.random.multivariate_normal(mean=[0, 0],
                                     cov=[[1, 0.4], [0.4, 1]],
                                     size=1000)

def _totorch(x):
    """Convert other types to a ``torch.float32`` tensor.

    A tensor that is already ``torch.float32`` is returned unchanged; any
    other tensor is cast, and non-tensor input is wrapped in a new tensor.
    """
    # The dtype has to be read from the tensor itself: type(x) is a Python
    # class and is never the dtype object torch.float32.
    if isinstance(x, torch.Tensor):
        return x if x.dtype == torch.float32 else x.to(torch.float32)
    return torch.tensor(x, dtype=torch.float32)

# sample minibatch data from dataset
def sample_batch(data, bs):
    """Draw one minibatch from ``data``.

    Two independent sets of ``bs`` distinct row indices are drawn. The first
    selects whole rows for the 'joint' entry; the second selects values from
    column 1 for the 'marginal' entry. Both are passed through ``_totorch``.

    :param data: 2-D array indexed as ``data[rows, cols]``
    :param bs: batch size (number of rows drawn without replacement)
    :return: dict with keys 'joint' and 'marginal'
    """
    n_rows = data.shape[0]

    joint_idx = np.random.choice(range(n_rows), size=bs, replace=False)
    joint = _totorch(data[joint_idx, :])

    marginal_idx = np.random.choice(range(n_rows), size=bs, replace=False)
    marginal = _totorch(data[marginal_idx, 1])

    return {'joint': joint, 'marginal': marginal}

# NOTE(review): the remaining constructor arguments are cut off in the
# original snippet. mode/measure are given explicitly because the defaults
# (mode='DV', measure='GAN') are not a combination get_loss supports;
# 'fGAN'/'JSD' is an assumption -- confirm against the original source.
mine = MINE(num_input=2, mode='fGAN', measure='JSD')
data_batch = sample_batch(data, bs=64)

loss, lb = mine.get_loss(data_batch)

with output:

tensor(0.0001, grad_fn=<AddBackward0>)

support material under different mode


$$D_{KL}(\mathbb{P}\,\|\,\mathbb{Q})=\sup_{T:\Omega\rightarrow\mathbb{R}}\mathbb{E}_{\mathbb{P}}[T]-\log\mathbb{E}_{\mathbb{Q}}[e^{T}]$$
MINE复现[DV, fGAN, infoNCE]_第1张图片
其中 $(x,z)$ 为代码中的 sample['joint'],$z$ 为代码中的 sample['marginal']。为了稳定训练,可以使用移动平均(measure='ma')来对梯度进行修正:
$$\begin{aligned} e^T_{ma} &= (1-\gamma)\,e^T_{ma}+\gamma\, e^T \\ \mathcal{L} &= -\frac{1}{b}\sum{T(x,z)}+\frac{\sum{e^T}}{\sum{e^T_{ma}}} \end{aligned}$$
其中 $\gamma$ 是移动平均的控制量,对应代码中的 'ma_rate'。更多细节请看论文:



$$\mathcal{L}(\theta,\omega)=\mathbb{E}_{x\sim P}\left[g_f(V_{\omega}(x))\right]+\mathbb{E}_{x\sim Q_{\theta}}\left[-f^{*}(g_f(V_{\omega}(x)))\right]$$
对于不同的 f-散度,选择不同的激活函数 $g_f$ 和共轭函数 $f^{*}$ 即可计算不同的 fGAN。

MINE复现[DV, fGAN, infoNCE]_第2张图片



也可以用NLP中常用的NCE loss,将此预测问题转化为二分类问题:
$$\hat{I}^{(\mathrm{infoNCE})}_{\omega,\psi}:=\mathbb{E}_{\mathbb{P}}\left[T_{\omega,\psi}(x,E_{\psi}(x))-\mathbb{E}_{\tilde{\mathbb{P}}}\left[\log\sum_{x^{\prime}}e^{T_{\omega,\psi}(x^{\prime},E_{\psi}(x))}\right]\right]$$
(Note: 此部分代码测试不完全,可能存在问题,谨慎使用。)



MINE复现[DV, fGAN, infoNCE]_第3张图片
MINE复现[DV, fGAN, infoNCE]_第4张图片
验证一下 $2D_{JS}=2D_{JSD}-\log 4$
MINE复现[DV, fGAN, infoNCE]_第5张图片


  • 更多关于infoNCE的测试
  • 添加互信息至GAN的测试demo

