基于python的RBM实现

代码主要是用 PyTorch 写的;用 NumPy 也可以,把对应的张量操作换成 NumPy 的版本应该就行。

"""
this code is for RBM based on CD-1 learning
author:media
date:2020-4-23
"""
import torch


class RBM:
    """Bernoulli-Bernoulli restricted Boltzmann machine trained with CD-1.

    Attributes:
        vis_dim: number of visible units.
        hid_dim: number of hidden units.
        lr: learning rate used by ``train``.
        bis_v: visible bias vector, shape ``(vis_dim,)``.
        bis_h: hidden bias vector, shape ``(hid_dim,)``.
        weight: visible-to-hidden weight matrix, shape ``(vis_dim, hid_dim)``.
    """

    def __init__(self, _vis_dim, _hid_dim, lr=0.01):
        """Create an RBM with zero biases and small random weights.

        Args:
            _vis_dim: number of visible units.
            _hid_dim: number of hidden units.
            lr: learning rate (defaults to 0.01, the previously hard-coded value).
        """
        self.vis_dim = _vis_dim
        self.hid_dim = _hid_dim
        self.lr = lr
        self.bis_v = torch.zeros(_vis_dim)
        self.bis_h = torch.zeros(_hid_dim)
        # Small zero-mean weights keep the sigmoids out of saturation early in
        # training; all-positive U[0, 1) weights push every hidden pre-activation
        # far positive once vis_dim is large.
        self.weight = 0.01 * torch.randn(_vis_dim, _hid_dim)

    def _hidden_prob(self, v):
        """Return P(h = 1 | v) for a 2-D batch of visible vectors ``v``."""
        return torch.sigmoid(self.bis_h + v.mm(self.weight))

    def _visible_prob(self, h):
        """Return P(v = 1 | h) for a 2-D batch of hidden vectors ``h``."""
        return torch.sigmoid(self.bis_v + h.mm(self.weight.t()))

    def train(self, x):
        """Apply one CD-1 parameter update using the mini-batch ``x``.

        Args:
            x: float tensor of shape ``(batch, vis_dim)``; presumably binary
                (or in [0, 1]) data — TODO confirm against the data loader.
        """
        # Positive phase: data -> hidden probabilities, then a Bernoulli sample.
        h_prob = self._hidden_prob(x)
        h = torch.bernoulli(h_prob)
        # Negative phase: one Gibbs step down to the visible layer and back up.
        v_out = torch.bernoulli(self._visible_prob(h))
        h_out = self._hidden_prob(v_out)
        batch = x.shape[0]
        # CD-1 gradient <v h>_data - <v h>_recon, averaged over the batch so it
        # is on the same scale as the bias updates below.
        grad_w = (x.t().mm(h_prob) - v_out.t().mm(h_out)) / batch
        self.weight = self.weight + self.lr * grad_w
        self.bis_h = self.bis_h + self.lr * torch.mean(h_prob - h_out, dim=0)
        self.bis_v = self.bis_v + self.lr * torch.mean(x - v_out, dim=0)

    def predict(self, x):
        """Return a sampled binary reconstruction of ``x`` (visible -> hidden -> visible).

        Args:
            x: float tensor of shape ``(batch, vis_dim)``.

        Returns:
            Tensor of shape ``(batch, vis_dim)`` with entries in {0, 1}.
        """
        h = torch.bernoulli(self._hidden_prob(x))
        return torch.bernoulli(self._visible_prob(h))


def train_method_1(n_epoch, model=None, loader=None):
    """Train an RBM for a fixed number of epochs.

    Args:
        n_epoch: number of full passes over the loader.
        model: object with a ``train(batch)`` method; defaults to the
            module-level ``rbm`` (not defined in this snippet — presumably
            created by the calling script).
        loader: re-iterable source of ``(data, label)`` pairs; defaults to the
            module-level ``trainloader``. Labels are ignored.
    """
    if model is None:
        model = rbm
    if loader is None:
        loader = trainloader
    for _ in range(n_epoch):
        # Iterate the loader directly: enumerate() would yield
        # (batch_index, batch) and hand the integer index to train().
        for data, _label in loader:
            model.train(data)


def train_method_2(door, model=None, loader=None, max_epoch=None):
    """Train until the mean per-batch reconstruction MSE drops to ``door``.

    Each epoch trains on every batch and measures the mean-squared error
    between the batch and ``model.predict(batch)``. The epoch error is the
    average of those per-batch errors. At least one epoch is always run.

    Args:
        door: error threshold; training stops once the epoch error is <= door.
        model: object with ``train(batch)`` and ``predict(batch)`` methods;
            defaults to the module-level ``rbm`` (not defined in this snippet —
            presumably created by the calling script).
        loader: re-iterable source of ``(data, label)`` pairs (one pass per
            epoch); defaults to the module-level ``trainloader``. Labels are
            ignored.
        max_epoch: optional cap on the number of epochs; ``None`` (default)
            keeps training until the threshold is reached.

    Returns:
        The last epoch's mean reconstruction error as a Python float.

    Raises:
        ValueError: if the loader yields no batches.
    """
    if model is None:
        model = rbm
    if loader is None:
        loader = trainloader
    epoch = 0
    while True:
        batch_num = 0
        whole_error = 0.0
        # Iterate the loader directly; enumerate() would put the batch index
        # into ``data``.
        for data, _label in loader:
            model.train(data)
            v_out = model.predict(data)
            # torch.nn.MSELoss is a module class; the functional form returns
            # the loss value. .item() keeps the running sum a plain float.
            whole_error += torch.nn.functional.mse_loss(v_out, data).item()
            batch_num += 1
        if batch_num == 0:
            raise ValueError("loader yielded no batches")
        square_error = whole_error / batch_num
        epoch += 1
        if square_error <= door:
            break
        if max_epoch is not None and epoch >= max_epoch:
            break
    return square_error



在代码末尾又写了两个训练方法:第一个以设定的迭代次数(epoch 数)作为终止条件,第二个以可视层重构的均方误差低于阈值作为终止条件。因为这两个是我临时起意加上去的,所以准确性可能有待商榷,尤其是第二个计算均方误差的地方——我懒得自己实现,直接调用了 nn 里计算均方误差的库(我觉得大概率是跑不通的)。
因为小弟我还是初学者,希望抛砖引玉让各位大佬指正我的问题,谢谢!

你可能感兴趣的:(python编程)