在TensorFlow框架下实现DBN网络

前言
在上一篇博客‘Windows下安装Tensorflow’的基础上,实现深度学习网络——DBN网络,以经典的手写字体识别为例子。

一、下载手写字体数据集,官方网址为:http://yann.lecun.com/exdb/mnist/
下载:
train-images-idx3-ubyte.gz: training set images (9912422 bytes)
train-labels-idx1-ubyte.gz: training set labels (28881 bytes)
t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)
t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)

二、参考github上的DBN实现源码为:https://github.com/myme5261314/dbn_tf

原作者实现的过程:https://gist.github.com/myme5261314/005ceac0483fc5a581cc

注:博主选择的python编辑器为pycharm编辑器,在windows下安装tensorflow后,需要在pycharm中配置路径,具体参考:http://blog.csdn.net/wx7788250/article/details/60877166

三,将github上下载的代码导入到pycharm中,会有几处错误,下面一一修改:
问题1:官方PIL(python image library)目前只支持python2.7及以下版本,所以当导入代码后,会提示ImportError: No module named 'Image'也就是没有import到image这个库。
解决办法:下载非官方库从http://www.lfd.uci.edu/~gohlke/pythonlibs/ 该网址上提供了支持64位系统的PIL文件,网站上叫做Pillow,下载下来,是个 .whl 结尾的文件,这个其实就是python使用的一种压缩文件,后缀名改成zip,可以打开。
这个需要用 pip 安装,在cmd下,进入该.whl下载的文件夹,输入

pip install Pillow‑4.1.1‑cp35‑cp35m‑win_amd64.whl

即可安装。
注意,这里有一段 Pillow is a replacement for PIL, the Python Image Library, which provides image processing functionality and supports many file formats.
Use from PIL import Image instead of import Image.
意思就是说,要用 ‘ from PIL import Image’ 代替 ‘import Image’
也就是将导入的程序代码中import Image 改为 from PIL import Image 即可。

from PIL import Image 

问题2:代码中会有几处print的错误,由于版本问题在python3.X中,print的输出要加‘()’,所以修改后的样子为:

  print (sess.run(
            err_sum, feed_dict={X: trX, rbm_w: n_w, rbm_vb: n_vb, rbm_hb: n_hb}))

问题3: 修改上述问题后,看似没有错误了,运行代码rbm_MNIST_test.py发现依旧存在问题,问题描述如下:

C:\Users\Administrator\AppData\Local\Programs\Python\Python35\python.exe D:/hlDL/tensorflow-DBN/dbn_tf-master/rbm_MNIST_test.py
Traceback (most recent call last):
  File "D:/hlDL/tensorflow-DBN/dbn_tf-master/rbm_MNIST_test.py", line 16, in <module>
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
Extracting MNIST_data/train-images-idx3-ubyte.gz
  File "D:\hlDL\tensorflow-DBN\dbn_tf-master\input_data.py", line 150, in read_data_sets
    train_images = extract_images(local_file)
  File "D:\hlDL\tensorflow-DBN\dbn_tf-master\input_data.py", line 39, in extract_images
    buf = bytestream.read(rows * cols * num_images)
  File "C:\Users\Administrator\AppData\Local\Programs\Python\Python35\lib\gzip.py", line 274, in read
    return self._buffer.read(size)
TypeError: only integer scalar arrays can be converted to a scalar index

修改方式:在input_data.py中找到代码

def _read32(bytestream):
  dt = numpy.dtype(numpy.uint32).newbyteorder('>')
  return numpy.frombuffer(bytestream.read(4), dtype=dt)

修改为:

def _read32(bytestream):
  dt = numpy.dtype(numpy.uint32).newbyteorder('>')
  return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]

在末尾出加上[0],再次运行,问题得到解决。这个问题似乎是新版本的Numpy存在的,最近的更新中,将单个元素数组作为一个标量来进行索引。

四、运行结果
在TensorFlow框架下实现DBN网络_第1张图片
注:根据具体代码,输出结果每10000次输出一次精度,共迭代60000次,同时每10000次输出一个识别结果的image,命名为:rbm_x.png。这里贴出第一幅识别结果和最后一次识别结果图如下:
在TensorFlow框架下实现DBN网络_第2张图片
rbm_0
在TensorFlow框架下实现DBN网络_第3张图片
rbm_5

你可能感兴趣的:(深度学习)