Alias method(别名采样方法)

参考

https://blog.csdn.net/haolexiao/article/details/65157026
https://github.com/tatsuokun/context2vec/blob/master/src/core/loss_func.py

代码(运行前需要先 `import numpy as np`)

class WalkerAlias(object):
    """Draw samples from a discrete distribution with Walker's alias method.

    The constructor builds a table with one column per event.  Every column
    holds at most two events: the column's own event, which is returned when
    the fractional part of the draw is below the column's threshold, and an
    "alias" event that is returned otherwise.  Sampling is then a single
    uniform draw plus two table lookups.

    Attributes:
        threshold: float32 array of length N; ``threshold[c]`` is the share
            of column ``c`` that belongs to the column's own event.
        values: int32 array of length 2 * N; ``values[2 * c]`` is the own
            event of column ``c`` and ``values[2 * c + 1]`` is its alias.

    This is from Chainer's implementation.
    You can find the original code at
    https://github.com/chainer/chainer/blob/v4.4.0/chainer/utils/walker_alias.py
    This class is
        Copyright (c) 2015 Preferred Infrastructure, Inc.
        Copyright (c) 2015 Preferred Networks, Inc.
    """

    def __init__(self, probs):
        """Build the alias table.

        Args:
            probs: Non-empty one-dimensional sequence of non-negative, finite
                weights.  They do not have to sum to one; they are
                normalised here.

        Raises:
            ValueError: If ``probs`` is empty, is not one-dimensional,
                contains a negative or non-finite entry, or sums to zero.
        """
        # np.array copies, so the in-place normalisation below never touches
        # the caller's data.
        prob = np.array(probs, dtype=np.float64)
        if prob.ndim != 1 or prob.size == 0:
            raise ValueError("probs must be a non-empty 1-D sequence")
        if not np.isfinite(prob).all() or (prob < 0).any():
            raise ValueError("probs must be finite and non-negative")
        total = prob.sum()
        if total <= 0:
            raise ValueError("probs must have a positive sum")
        prob /= total

        n = prob.size
        # One threshold per column, two event slots per column.
        threshold = np.empty(n, dtype=np.float32)
        values = np.empty(2 * n, dtype=np.int32)

        # Columns are created in ascending order of probability (ties broken
        # by event index, hence the stable sort).  ``filled`` is the first
        # column that still has free room; columns before it are complete.
        filled = 0
        for column, event in enumerate(np.argsort(prob, kind="stable")):
            # Scale so that the average event has mass exactly 1, i.e. one
            # full column.  Plain Python floats keep the running arithmetic
            # independent of NumPy's scalar promotion rules.
            mass = float(prob[event]) * n
            # An event heavier than one column donates its surplus to the
            # earlier, lighter columns, topping each of them up to 1.
            while mass > 1 and filled < column:
                values[2 * filled + 1] = event
                # Use the stored (float32) threshold so the donated amount
                # matches what sample() will actually see.
                mass -= 1.0 - float(threshold[filled])
                filled += 1
            threshold[column] = mass
            values[2 * column] = event

        # Columns that were never topped up alias to themselves.  Rounding
        # can leave their threshold a hair below 1; a self-alias guarantees
        # that sliver cannot be handed to an unrelated event.
        values[2 * filled + 1::2] = values[2 * filled::2]

        self.threshold = threshold
        self.values = values

    def sample(self, shape):
        """Draw event indices from the distribution.

        Args:
            shape: Output shape, passed to ``np.random.uniform`` as its
                ``size`` argument (an int, a tuple of ints, or None for a
                single draw).

        Returns:
            int32 event indices with the requested shape.
        """
        n = len(self.threshold)
        # A uniform draw in [0, n): the integer part selects the column, the
        # fractional part is a fresh uniform draw in [0, 1) inside it.
        scaled = np.asarray(np.random.uniform(0, 1, shape)) * n
        column = scaled.astype(np.int32)
        fraction = scaled - column
        # 0 -> the column's own event, 1 -> its alias.  ``>=`` (rather than
        # ``>``) makes a column with threshold 0 always yield its alias, so
        # an event with zero probability is never returned.
        use_alias = (fraction >= self.threshold[column]).astype(np.int32)

        return self.values[column * 2 + use_alias]

如有错误,欢迎指正!

你可能感兴趣的:(NLP,预训练模型,context2vec,nlp)