LSTM text classification with tflearn

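I recently came across tflearn. Compared with raw tensorflow, it wraps the underlying interfaces heavily: the API is easy to follow and the source code is nicely written. This example implements an LSTM network in just a few dozen lines, so take a look. Deploying it online should also be very simple, and for people who are not familiar with tensorflow it is a really handy tool: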

 
    
    import os

    import tflearn
    from tflearn.data_utils import to_categorical, pad_sequences
    from tflearn.datasets import imdb

    print(os.path.dirname(__file__))

    # The dataset is downloaded automatically if it is not already present.
    # If the download fails, fetch it from the command line instead:
    #   curl -O "http://www.iro.umontreal.ca/~lisa/deep/data/imdb.pkl"
    # Note: load_data returns (train, valid, test), so the variable named
    # `test` below actually holds the 10% validation split.
    train, test, _ = imdb.load_data(path="imdb.pkl", n_words=10000,
                                    valid_portion=0.1)
    trainx, trainy = train
    testx, testy = test
    print(trainx[0:2])
    print(testx[0:2])

    # pad_sequences converts every sequence to a fixed length (see its source).
    # It can pad either at the front or at the back, and the fill value is
    # configurable (0, 1, ...). The input here is a matrix of ints.
    trainx = pad_sequences(trainx, maxlen=100, value=0.)
    testx = pad_sequences(testx, maxlen=100, value=0.)

    # to_categorical turns the class labels into one-hot rows such as
    # [0, 1] and [1, 0].
    trainy = to_categorical(trainy, nb_classes=2)
    testy = to_categorical(testy, nb_classes=2)

    # Input shape: the sequence length was fixed at 100 above;
    # None is the batch dimension.
    net = tflearn.input_data([None, 100])
    # input_dim is the vocabulary size, i.e. how many of the most frequent
    # words in the corpus were kept (usually the top N).
    net = tflearn.embedding(net, input_dim=10000, output_dim=128)
    # 128 is n_units, the size of the LSTM layer; dropout=0.8
    # (tflearn treats this value as the keep probability).
    net = tflearn.lstm(net, 128, dropout=0.8)
    # Two output classes.
    net = tflearn.fully_connected(net, 2, activation='softmax')
    net = tflearn.regression(net, optimizer='adam', learning_rate=0.001,
                             loss='categorical_crossentropy')

    # Training
    # Initialize the model.
    model = tflearn.DNN(net, tensorboard_verbose=0,
                        tensorboard_dir="/tmp/tflearn_logs/")
    # show_metric=True prints the accuracy while training.
    model.fit(trainx, trainy, validation_set=(testx, testy), show_metric=True,
              batch_size=32)
    # model.predict()
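
To make the two preprocessing steps above more concrete, here is a minimal pure-Python sketch of what padding and one-hot encoding do. It is not tflearn's implementation: `pad_sequence` and `one_hot` are hypothetical helper names used only for illustration, and truncating from the end when a sequence is too long is an assumption made for this sketch.

```python
def pad_sequence(seq, maxlen, value=0, padding="post"):
    """Pad (or truncate) one list of ints to exactly maxlen items.

    Illustrative only: truncation keeps the first maxlen items, which is
    an assumption of this sketch, not a statement about tflearn.
    """
    if padding not in ("pre", "post"):
        raise ValueError("padding must be 'pre' or 'post'")
    seq = list(seq)[:maxlen]               # keep at most maxlen items
    fill = [value] * (maxlen - len(seq))   # how much is still missing
    return fill + seq if padding == "pre" else seq + fill


def one_hot(labels, nb_classes):
    """Turn integer class labels into one-hot rows."""
    return [[1 if i == label else 0 for i in range(nb_classes)]
            for label in labels]


print(pad_sequence([5, 8, 2], maxlen=5))                 # [5, 8, 2, 0, 0]
print(pad_sequence([5, 8, 2], maxlen=5, padding="pre"))  # [0, 0, 5, 8, 2]
print(one_hot([0, 1], nb_classes=2))                     # [[1, 0], [0, 1]]
```

In the real script, `pad_sequences` and `to_categorical` from `tflearn.data_utils` do this work on whole arrays at once; the sketch only shows the shape of the result that the network expects.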


Output of a training run:

 
    
    Training Step: 7032 | total loss: 0.12957 | time: 92.288s
    | Adam | epoch: 010 | loss: 0.12957 - acc: 0.9725 -- iter: 22272/22500
    Training Step: 7033 | total loss: 0.11797 | time: 92.396s
    | Adam | epoch: 010 | loss: 0.11797 - acc: 0.9752 -- iter: 22304/22500
    Training Step: 7034 | total loss: 0.11303 | time: 92.513s
    | Adam | epoch: 010 | loss: 0.11303 - acc: 0.9746 -- iter: 22336/22500
    Training Step: 7035 | total loss: 0.11608 | time: 92.632s
    | Adam | epoch: 010 | loss: 0.11608 - acc: 0.9740 -- iter: 22368/22500
    Training Step: 7036 | total loss: 0.10833 | time: 92.764s
    | Adam | epoch: 010 | loss: 0.10833 - acc: 0.9766 -- iter: 22400/22500
    Training Step: 7037 | total loss: 0.09968 | time: 92.877s
    | Adam | epoch: 010 | loss: 0.09968 - acc: 0.9789 -- iter: 22432/22500
    Training Step: 7038 | total loss: 0.09557 | time: 93.015s
    | Adam | epoch: 010 | loss: 0.09557 - acc: 0.9779 -- iter: 22464/22500
    Training Step: 7039 | total loss: 0.09463 | time: 93.138s
    | Adam | epoch: 010 | loss: 0.09463 - acc: 0.9739 -- iter: 22496/22500
    Training Step: 7040 | total loss: 0.09200 | time: 95.779s
    | Adam | epoch: 010 | loss: 0.09200 - acc: 0.9734 | val_loss: 0.76267 - val_acc: 0.7944 -- iter: 22500/22500
    --
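
The step counter in this log can be checked against the settings in the script. The sketch below assumes the run covered 10 epochs, as the `epoch: 010` field on the last lines suggests, with 22500 training samples and `batch_size=32`; `total_steps` is a hypothetical helper written for this check, not part of tflearn.

```python
import math


def total_steps(n_samples, batch_size, n_epochs):
    """Number of optimizer steps when the last, smaller batch is also used."""
    steps_per_epoch = math.ceil(n_samples / batch_size)
    return steps_per_epoch * n_epochs


# 22500 samples in batches of 32: 703 full batches plus one batch of 4
print(math.ceil(22500 / 32))        # 704
print(total_steps(22500, 32, 10))   # 7040
```

Under those assumptions the result agrees with the final `Training Step: 7040` line, and the last batch of 4 samples explains why `iter` jumps from 22496 to 22500.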


To inspect the run in tensorboard:

 
   
    tensorboard --logdir="/tmp/tflearn_logs/"
