之前接触过一些rbm的知识,网上也有很多写得很好的博客,但是也是一知半解,其实这里面包含的东西可以说非常广泛,也非常经典,rbm以及后面的dbn都是hinton老爷子的得意之作,在当时深度神经网络以及反向传播算法还没有那么大规模流行开的时候,可以说也是引领了一番研究这个的热潮,直到现在,hinton老爷子也在宣传他的这个东西,这里面蕴含了很多他自己独到的见解,或者对后面神经网络的发展有所益处。
去年 6 月份写的博文《Yusuke Sugomori 的 C 语言 Deep Learning 程序解读》是囫囵吞枣地读完一个关于 DBN 算法的开源代码后的笔记,当时对其中涉及的算法原理基本不懂。近日再次学习 RBM,觉得有必要将其整理成笔记,算是对那个代码的一个补充。
目录链接
(一)预备知识
(二)网络结构
(三)能量函数和概率分布
(四)对数似然函数
(五)梯度计算公式(六)对比散度算法
(七)RBM 训练算法
(八)RBM 的评估
作者: peghoty
出处: http://blog.csdn.net/itplus/article/details/19408773
欢迎转载/分享, 但请务必声明文章出处.
感觉这个东西跟早期的Hopfield网络以及自编码器很像,但是也有些区别,我这里也有一份基于tensorflow的实现代码,大家可以借鉴一下:
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
#!pip install pillow
from PIL import Image
#import Image
from utils1 import tile_raster_images
import matplotlib.pyplot as plt
#%matplotlib inline
#读入数据
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
#定义参数
vb = tf.placeholder("float", [784])
hb = tf.placeholder("float", [500])
W = tf.placeholder("float", [784, 500])
###构建网络
##前向过程
#输入值
X = tf.placeholder("float", [None, 784])
#隐藏层的概率分布
_h0= tf.nn.sigmoid(tf.matmul(X, W) + hb) #probabilities of the hidden units
#根据隐藏层的概率分布采样得到隐藏层的值
h0 = tf.floor(_h0 + tf.random_uniform(tf.shape(_h0))) #sample_h_given_X
##反向重构
#由隐藏层的值反向得到输入层的概率分布
_v1 = tf.nn.sigmoid(tf.matmul(h0, tf.transpose(W)) + vb)
#根据输入层的概率分布采样得到输入层的值
v1 = tf.floor(_v1 + tf.random_uniform(tf.shape(_v1))) #sample_v_given_h
#再次根据新得到的输入层的值计算隐藏层的概率分布,用于梯度计算
_h1 = tf.nn.sigmoid(tf.matmul(v1, W) + hb)
###参数更新
alpha = 1.0
w_pos_grad = tf.matmul(tf.transpose(X), _h0)
w_neg_grad = tf.matmul(tf.transpose(v1), _h1)
CD = (w_pos_grad - w_neg_grad) / tf.to_float(tf.shape(X)[0])
update_w = W + alpha * CD
update_vb = vb + alpha * tf.reduce_mean(X - v1, 0)
update_hb = hb + alpha * tf.reduce_mean(_h0 - _h1, 0)
###定义错误率
err = tf.reduce_mean(tf.square(X - v1))
###创建一个回话并初始化向量
cur_w = np.zeros([784, 500], np.float32)
cur_vb = np.zeros([784], np.float32)
cur_hb = np.zeros([500], np.float32)
prv_w = np.zeros([784, 500], np.float32)
prv_vb = np.zeros([784], np.float32)
prv_hb = np.zeros([500], np.float32)
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
###查看第一次运行的误差
print(sess.run(err, feed_dict={X: trX, W: prv_w, vb: prv_vb, hb: prv_hb}))
###运行网络
epochs = 5
batchsize = 100
weights = []
errors = []
for epoch in range(epochs):
for start, end in zip( range(0, len(trX), batchsize), range(batchsize, len(trX), batchsize)):
batch = trX[start:end]
cur_w = sess.run(update_w, feed_dict={ X: batch, W: prv_w, vb: prv_vb, hb: prv_hb})
cur_vb = sess.run(update_vb, feed_dict={ X: batch, W: prv_w, vb: prv_vb, hb: prv_hb})
cur_hb = sess.run(update_hb, feed_dict={ X: batch, W: prv_w, vb: prv_vb, hb: prv_hb})
prv_w = cur_w
prv_vb = cur_vb
prv_hb = cur_hb
if start % 10000 == 0:
errors.append(sess.run(err, feed_dict={X: trX, W: cur_w, vb: cur_vb, hb: cur_hb}))
weights.append(cur_w)
print('Epoch: %d' % epoch,'reconstruction error: %f' % errors[-1])
print(weights[-1].T.shape)
print(weights[-1].T)
plt.plot(errors)
plt.xlabel("Batch Number")
plt.ylabel("Error")
plt.show()
###画出可视化结果
tile_raster_images(X=cur_w.T, img_shape=(28, 28), tile_shape=(25, 20), tile_spacing=(1, 1))
image = Image.fromarray(tile_raster_images(X=cur_w.T, img_shape=(28, 28) ,tile_shape=(25, 20), tile_spacing=(1, 1)))
### Plot image
plt.rcParams['figure.figsize'] = (18.0, 18.0)
imgplot = plt.imshow(image)
imgplot.set_cmap('gray')
plt.show()
https://www.cnblogs.com/peghoty/p/3798500.html
https://www.cnblogs.com/kemaswill/p/3269138.html
https://www.cnblogs.com/kemaswill/p/3269138.html
https://www.cnblogs.com/mtcnn/p/9421777.html
https://zhuanlan.zhihu.com/p/22794772
http://blog.sciencenet.cn/blog-110554-876316.html
https://www.cnblogs.com/xiaokangzi/p/4492466.html
https://www.cnblogs.com/tornadomeet/p/3439503.html
https://www.cnblogs.com/kemaswill/p/3203605.html
https://www.cnblogs.com/wn19910213/p/3581024.html
https://www.cnblogs.com/pinard/p/6530523.html
https://recomm.cnblogs.com/blogpost/9422835?page=3
https://www.xuebuyuan.com/3256855.html
http://www.bubuko.com/infodetail-30480.html
https://www.yuanmas.com/info/9ezZNer7z6.html
https://www.cnblogs.com/xiaojingang/articles/4398503.html
https://blog.csdn.net/u010681136/article/details/40189349
https://www.cnblogs.com/gravity/p/4421846.html
https://www.cnblogs.com/ccienfall/
https://www.cnblogs.com/jicanghai/p/5299209.html
https://www.cnblogs.com/rsmx/p/12909838.html
https://www.cnblogs.com/zhangchaoyang/articles/5537643.html?utm_source=tuicool&utm_medium=referral
https://www.cnblogs.com/tornadomeet/archive/2013/03/27/2984725.html
http://www.mamicode.com/info-detail-299807.html
http://www.cnki.com.cn/Article/CJFDTotal-SSJS201609023.htm
https://www.zhihu.com/question/323493963/answer/689921657
https://zhidao.baidu.com/question/1736336773302059747.html
https://www.cnblogs.com/mengqimoli/p/11132551.html
https://baijiahao.baidu.com/s?id=1599798281463567369&wfr=spider&for=pc