之前介绍了两篇Nupic的技术细节—-脑皮层学习算法 —nupic的深入学习(一),脑皮层学习算法 —nupic的深入学习(二),但缺少了利用Nupic的具体实例。这篇文章会利用Nupic算法,基于已有的用户访问网站类别数据,预测用户的访问网站的类别。
数据,代码都在Github上。可下载。
列表中的元素就是网站类别
PAGE_CATEGORIES = [
"frontpage", "news", "tech", "local", "opinion", "on-air", "misc", "weather",
"msn-news", "health", "living", "business", "msn-sports", "sports", "summary",
"bbs", "travel"
]
下面就是算法要读取的数据(msnbc990928.zip)
% Different categories found in input file:
frontpage news tech local opinion on-air misc weather msn-news health living business msn-sports sports summary bbs travel
% Sequences:
1 1
2
3 2 2 4 2 2 2 3 3
5
1
6
1 1
6
6 7 7 7 6 6 8 8 8 8
6 9 4 4 4 10 3 10 5 10 4 4 4
1 1 1 11 1 1 1
12 12
数据序列中,每行代表一个用户的点击情况,比如第一行,用户先点击了frontpage 1次,然后点击了news 1次,算法要做的工作是,基于已有的用户点击行为,预测下一刻用户的点击行为。
在Github上下载源码,运行
python webdata.py
算法会完成所有操作,结果会打印在控制台。
算法分为两个架构:1.配置神经网络各个组件的参数;2.依次读取单个用户的数据,训练算法;3. 利用算法预测
分步骤讲述如下:
(一) 配置神经网络各个组件的参数
#网页的类别
# List of page categories used in the dataset
PAGE_CATEGORIES = [
"frontpage", "news", "tech", "local", "opinion", "on-air", "misc", "weather",
"msn-news", "health", "living", "business", "msn-sports", "sports", "summary",
"bbs", "travel"
]
#配置编码器,这里利用SDRCategoryEncoder
# Configure the sensor/input region using the "SDRCategoryEncoder" to encode
# the page category into SDRs suitable for processing directly by the TM
SENSOR_PARAMS = {
"verbosity": 0,
"encoders": {
"page": {
"fieldname": "page",
"name": "page",
"type": "SDRCategoryEncoder",
# The output of this encoder will be passed directly to the TM region,
# therefore the number of bits should match TM's "inputWidth" parameter
"n": 1024,
# Use ~2% sparsity
"w": 21
},
},
}
#配置时间池组件,使算法有学习功能的组件
# Configure the temporal memory to learn a sequence of page SDRs and make
# predictions on the next page of the sequence.
TM_PARAMS = {
"seed": 1960,
# Use "nupic.bindings.algorithms.TemporalMemoryCPP" algorithm
"temporalImp": "tm_cpp",
# Should match the encoder output
"inputWidth": 1024,
"columnCount": 1024,
# Use 1 cell per column for first order prediction.
# Use more cells per column for variable order predictions.
"cellsPerColumn": 1,
}
#配置Classifier组件,使得算法能够输出预测的网站类别
# Configure the output region with a classifier used to decode TM SDRs back
# into pages
CL_PARAMS = {
"implementation": "cpp",
"regionName": "SDRClassifierRegion",
# alpha parameter controls how fast the classifier learns/forgets. Higher
# values make it adapt faster and forget older patterns faster.
"alpha": 0.001,
"steps": 1,
}
#将所有的参数组合在一起,构成完成的Model
#顺序是# page => [encoder] => [TM] => [classifier] => prediction
# Create a simple HTM network that will receive the current page as input, pass
# the encoded page SDR to the temporal memory to learn the sequences and
# interpret the output SDRs from the temporary memory using the SDRClassifier
# whose output will be a list of predicted next pages and their probabilities.
MODEL_PARAMS = {
"version": 1,
"model": "HTMPrediction",
"modelParams": {
"inferenceType": "TemporalMultiStep",
"sensorParams": SENSOR_PARAMS,
# The purpose of the spatial pooler is to create a stable representation of
# the input SDRs. In our case the category encoder output is already a
# stable representation of the category therefore adding the spatial pooler
# to this network would not help and could potentially slow down the
# learning process
"spEnable": False,
"spParams": {},
"tmEnable": True,
"tmParams": TM_PARAMS,
"clParams": CL_PARAMS,
},
}
(二) 依次读取单个用户的数据,训练算法
model.enableLearning() #开启算法学习功能
for count in xrange(LEARNING_RECORDS): #遍历用于训练的用户
# Learn each user session as a single sequence
session = readUserSession(datafile) #读出每个用户的数据
model.resetSequenceStates() #初始化每个用户作为第一个训练样本
for page in session: #对于每个用户,逐次输入序列元素,训练
model.run({"page": page})
(三) 利用算法预测
# Infer one page of the sequence at the time
model.resetSequenceStates() #初始化每个用户作为第一个训练样本
session = readUserSession(datafile) #读取一个用户的点击序列
for page in session: #遍历该序列,预测下一步用户的点击
result = model.run({"page": page})
#取得预测结果,预测结果类似inferences: {'msn-news': 0.43472087316837504, 'misc': 0.04838564363576856, 'weather': 0.05264233343663612, …
inferences = result.inferences["multiStepPredictions"][1]
# Print predictions ordered by probabilities
predicted = sorted(inferences.items(), #将预测结果排序,概率大优先
key=itemgetter(1),
reverse=True)
prediction_table.add_row([page, zip(*predicted)[0]])
print "User Session to Predict: ", session
print prediction_table
### Sample output:
```text
The following table shows the encoded SDRs for every page category in the dataset
+---------------+-----------------------------------------------------------------------------+
| Page Category | Encoded SDR (on bit indices) |
+---------------+-----------------------------------------------------------------------------+
| bbs | [ 19 26 115 171 293 364 390 442 470 477 550 598 624 670 705 719 744 748 |
| | 788 850 956] |
| business | [ 48 104 144 162 213 280 305 355 376 403 435 628 694 724 780 850 854 870 |
| | 891 930 955] |
| frontpage | [ 4 7 35 37 48 91 118 143 155 313 339 410 560 627 736 762 795 864 |
| | 885 889 966] |
| health | [ 50 67 124 209 214 229 288 337 380 402 437 474 566 584 614 |
| | 661 754 840 846 894 1008] |
| living | [195 198 209 219 261 317 332 348 353 369 371 375 399 495 501 556 595 758 |
| | 799 813 920] |
| local | [ 3 48 221 275 284 457 466 516 574 626 645 688 699 761 855 867 899 925 |
| | 942 987 997] |
| misc | [ 40 61 90 106 127 179 202 208 217 373 417 523 577 580 722 751 865 925 |
| | 926 928 938] |
| msn-news | [ 29 71 72 74 149 241 261 263 276 365 465 528 529 575 577 |
| | 661 781 799 830 980 1019] |
| msn-sports | [119 138 150 164 197 263 391 454 510 581 589 614 661 700 724 742 809 886 |
| | 889 978 989] |
| news | [ 18 44 71 109 191 322 333 337 375 402 447 587 653 660 794 |
| | 837 853 913 936 954 1019] |
| on-air | [ 27 80 134 158 187 199 214 286 374 439 445 484 490 590 670 |
| | 771 823 934 952 965 1014] |
| opinion | [163 165 216 241 251 260 307 336 382 449 493 540 607 668 679 717 736 866 |
| | 888 902 981] |
| sports | [ 20 39 65 141 147 230 232 248 332 361 467 476 689 847 851 |
| | 862 866 889 936 958 1010] |
| summary | [ 32 34 106 206 302 340 414 564 566 568 596 619 645 657 761 813 879 888 |
| | 897 944 997] |
| tech | [108 276 327 372 411 431 479 577 592 606 650 690 747 756 763 913 936 949 |
| | 961 981 983] |
| travel | [149 164 179 239 316 319 365 427 437 470 632 729 739 748 787 818 821 824 |
| | 834 906 919] |
| weather | [ 9 12 21 38 45 146 203 205 284 400 471 506 520 532 595 613 621 639 |
| | 805 970 987] |
+---------------+-----------------------------------------------------------------------------+
Start Learning page sequences using the first 10000 user sessions
Learned 10000 Sessions
Finished Learning
Start Inference using a new user session from the dataset
User Session to Predict: ['on-air', 'misc', 'misc', 'misc', 'on-air', 'misc', 'misc', 'misc', 'on-air', 'on-air', 'on-air', 'on-air', 'tech', 'msn-news', 'tech', 'msn-news', 'local', 'tech', 'local', 'local', 'local', 'local', 'local', 'local']
+----------+---------------------------------------------------------------------------------------+
| Page | Prediction |
+----------+---------------------------------------------------------------------------------------+
| on-air | ('on-air', 'misc', 'frontpage', 'news', 'summary', 'msn-news', 'weather', 'local') |
+----------+---------------------------------------------------------------------------------------+
| misc | ('misc', 'frontpage', 'on-air', 'local', 'msn-news', 'msn-sports', 'news', 'sports') |
+----------+---------------------------------------------------------------------------------------+
| misc | ('misc', 'frontpage', 'on-air', 'local', 'msn-news', 'msn-sports', 'news', 'sports') |
+----------+---------------------------------------------------------------------------------------+
| misc | ('misc', 'frontpage', 'on-air', 'local', 'msn-news', 'msn-sports', 'news', 'sports') |
+----------+---------------------------------------------------------------------------------------+
| on-air | ('on-air', 'misc', 'frontpage', 'news', 'summary', 'msn-news', 'weather', 'local') |
+----------+---------------------------------------------------------------------------------------+
| misc | ('misc', 'frontpage', 'on-air', 'local', 'msn-news', 'msn-sports', 'news', 'sports') |
+----------+---------------------------------------------------------------------------------------+
| misc | ('misc', 'frontpage', 'on-air', 'local', 'msn-news', 'msn-sports', 'news', 'sports') |
+----------+---------------------------------------------------------------------------------------+
| misc | ('misc', 'frontpage', 'on-air', 'local', 'msn-news', 'msn-sports', 'news', 'sports') |
+----------+---------------------------------------------------------------------------------------+
| on-air | ('on-air', 'misc', 'frontpage', 'news', 'summary', 'msn-news', 'weather', 'local') |
+----------+---------------------------------------------------------------------------------------+
| on-air | ('on-air', 'misc', 'frontpage', 'news', 'summary', 'msn-news', 'weather', 'local') |
+----------+---------------------------------------------------------------------------------------+
| on-air | ('on-air', 'misc', 'frontpage', 'news', 'summary', 'msn-news', 'weather', 'local') |
+----------+---------------------------------------------------------------------------------------+
| on-air | ('on-air', 'misc', 'frontpage', 'news', 'summary', 'msn-news', 'weather', 'local') |
+----------+---------------------------------------------------------------------------------------+
| tech | ('tech', 'frontpage', 'news', 'msn-news', 'on-air', 'business', 'local', 'sports') |
+----------+---------------------------------------------------------------------------------------+
| msn-news | ('msn-news', 'frontpage', 'local', 'weather', 'misc', 'on-air', 'msn-sports', 'tech') |
+----------+---------------------------------------------------------------------------------------+
| tech | ('tech', 'frontpage', 'news', 'msn-news', 'on-air', 'business', 'local', 'sports') |
+----------+---------------------------------------------------------------------------------------+
| msn-news | ('msn-news', 'frontpage', 'local', 'weather', 'misc', 'on-air', 'msn-sports', 'tech') |
+----------+---------------------------------------------------------------------------------------+
| local | ('local', 'frontpage', 'misc', 'news', 'msn-news', 'on-air', 'weather', 'sports') |
+----------+---------------------------------------------------------------------------------------+
| tech | ('tech', 'frontpage', 'news', 'msn-news', 'on-air', 'business', 'local', 'sports') |
+----------+---------------------------------------------------------------------------------------+
| local | ('local', 'frontpage', 'misc', 'news', 'msn-news', 'on-air', 'weather', 'sports') |
+----------+---------------------------------------------------------------------------------------+
| local | ('local', 'frontpage', 'misc', 'news', 'msn-news', 'on-air', 'weather', 'sports') |
+----------+---------------------------------------------------------------------------------------+
| local | ('local', 'frontpage', 'misc', 'news', 'msn-news', 'on-air', 'weather', 'sports') |
+----------+---------------------------------------------------------------------------------------+
| local | ('local', 'frontpage', 'misc', 'news', 'msn-news', 'on-air', 'weather', 'sports') |
+----------+---------------------------------------------------------------------------------------+
| local | ('local', 'frontpage', 'misc', 'news', 'msn-news', 'on-air', 'weather', 'sports') |
+----------+---------------------------------------------------------------------------------------+
| local | ('local', 'frontpage', 'misc', 'news', 'msn-news', 'on-air', 'weather', 'sports') |
+----------+---------------------------------------------------------------------------------------+
Compute prediction accuracy by checking if the next page in the sequence is within the predicted pages calculated by the model:
- Prediction Accuracy: 0.614173228346
- Accuracy Predicting Top 3 Pages: 0.825196850394
代码中算法利用了10000条数据作为训练数据,然后训练了100条数据,这些都可以自己修改源码。从结果中看到,算法的正确率是0.6左右。如果将算法的预测扩大为3个候选(即预测下一次可能是A,B,或C,而不仅仅预测为单个),则准确率大大增加,为0.8左右。
算法没有与其他预测算法做过比较,因此也无从知道优劣,有兴趣者可以基于该数据集,利用其他预测算法尝试,对比两种算法的准确率。