最近看了一圈技术栈,感觉无论是自然语言处理或者说是计算机视觉,网上都有一大堆成体系的教学可供参考。但是反观推荐算法这个方向却是寥寥无几。写这篇文章出于两个目的:1.巩固自己的学习,2.对外输出所学。
何向南老师github:
https://github.com/hexiangnan/neural_collaborative_filtering
我们先看一下数据集组成。
然后今天说的是 load_dataset做了什么事情。
先上代码:
导包
import pandas as pd
import numpy as np
import math
from collections import defaultdict
import heapq
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torch.backends.cudnn as cudnn
import os
载入数据。
def load_dataset(test_num=100):
train_data = pd.read_csv("./ncf_data/ml-1m.train.rating", \
sep='\t', header=None, names=['user', 'item'], \
usecols=[0, 1], dtype={0: np.int32, 1: np.int32})
user_num = train_data['user'].max() + 1
item_num = train_data['item'].max() + 1
train_data = train_data.values.tolist()
#load ratings as a dok matrix
train_mat = sp.dok_matrix((user_num,item_num),dtype=np.float32)
for x in train_data:
train_mat[x[0], x[1]] = 1.0
test_data = []
with open("/data/fjsdata/ctKngBase/ml/ml-1m.test.negative", 'r') as fd:
line = fd.readline()
while line != None and line != '':
arr = line.split('\t')
u = eval(arr[0])[0]
test_data.append([u, eval(arr[0])[1]])#one postive item
for i in arr[1:]:
test_data.append([u, int(i)]) #99 negative items
line = fd.readline()
return train_data, test_data, user_num, item_num, train_mat
先说一下 ml-1m.train.rating 文件
这个文件有列,分别是user,item,评分,时间戳(这个我也记不清是不是了)。
#load ratings as a dok matrix
train_mat = sp.dok_matrix((user_num,item_num),dtype=np.float32)
for x in train_data:
train_mat[x[0], x[1]] = 1.0
上面这段代码是把所有打分交互过的用户,项目,组成一个矩阵,数据结构是这个样子的:(User,Item):1
这里补充一下,哪怕是用户打分只有1分,对应字典也是1.0
如图:
处理test_data
先看一下数据格式:
这里说明一下,由于作者在paper中没有明确说明(也可能是我没仔细看)
这个元组里面是用户项目交互,元组外面的一堆是未交互
所以在这里我们代码意思是把元组拿出来,作为积极,剩下的u对应下面这一串未交互的为消极。
test_data = []
with open("./ncf_data/ml-1m.test.negative", 'r') as fd:
line = fd.readline()
while line != None and line != '':
arr = line.split('\t')
u = eval(arr[0])[0]
test_data.append([u, eval(arr[0])[1]])#one postive item
for i in arr[1:]:
test_data.append([u, int(i)]) #99 negative items
line = fd.readline()