The application scenario: we want to factor a sparse matrix into the product of two low-rank matrices. Besides reducing dimensionality, each of the two low-rank matrices carries its own meaning.
Take recommendation as an example: the user-item click matrix R is sparse. We factor it into two low-rank matrices, a user-feature matrix U and an item-feature matrix V, and these features are exactly the latent semantic information.
Stated a bit more formally:
The rating matrix R is approximately low-rank, i.e. the m*n matrix can be approximated by the product of U (m*k) and V (n*k):
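R_{m×n} ≈ U_{m×k} · (V_{n×k})^T

where each row of U is one user's latent feature vector u_i and each row of V is one item's latent feature vector v_j (this is also how the code below stores them).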
Here U is the user-preference feature matrix and V is the item feature matrix.
The loss function of the matrix factorization:
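Written out so that it matches the implementation later in the post (entries equal to 0 are treated as missing and skipped; λ is the regularization coefficient):

L(U, V) = Σ_{(i,j): R_ij > 0} [ (R_ij − u_i^T v_j)^2 + (λ/2)(‖u_i‖^2 + ‖v_j‖^2) ]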
1. There are two unknowns; fix one of them and minimize over the other.
2. With A fixed, take the partial derivative with respect to B and set it to zero to find the extremum; then swap roles and keep optimizing.
With the loss function above, U and V are solved for as follows:
1. Initialize U and keep it fixed.
2. Take the derivative to obtain the closed-form solution for V, and compute V.
3. Fix V.
4. Take the derivative to obtain the closed-form solution for U, and compute U.
5. Repeat steps 2-4 until the loss meets the stopping condition or the step budget is exhausted.
The loss function of least squares is the squared loss, which is also the most common loss function in regression learning. Regression means fitting the data with some hypothesis h, which may be linear or nonlinear.
If a regression problem uses the squared loss, it can be solved by least squares.
The principle behind the (alternating) least-squares algorithm: fix one variable, take the partial derivative to obtain a closed-form formula, and solve for the extremum; then apply the same trick to the other variable. This way of thinking is worth learning.
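As a minimal, self-contained sketch of plain least squares (ordinary linear regression, separate from the matrix-factorization code below; the data here is made up purely for illustration), setting the gradient of the squared loss to zero gives the normal equations:

import numpy as np

# Toy data: y is roughly 3*x + 1 plus noise (made-up example).
rng = np.random.RandomState(0)
x = rng.rand(50)
y = 3 * x + 1 + 0.1 * rng.randn(50)

# Design matrix with a bias column, so h(x) = w[0]*x + w[1].
X = np.column_stack([x, np.ones_like(x)])

# Normal equations: (X^T X) w = X^T y, i.e. w = (X^T X)^(-1) X^T y.
w = np.linalg.solve(np.dot(X.T, X), np.dot(X.T, y))
print(w)  # roughly [3, 1]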
Once U and V have been obtained, their product fills in the missing entries of the sparse R. U itself can also be used to compute user-user similarity and hence to cluster users; the same can be done with V for items, achieving the goal of "grouping similar users together and similar items together".
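A rough sketch of how the learned factors would be used; the U and V here are random placeholders standing in for the factors produced by the code later in the post:

import numpy as np

# Placeholder factors for illustration only; in practice they come from ALS / gradient descent.
U = np.random.rand(5, 2)   # 5 users, k = 2
V = np.random.rand(4, 2)   # 4 items, k = 2

# Filled-in rating matrix: R_hat[i, j] is the predicted score of user i on item j.
R_hat = np.dot(U, V.T)

# Cosine similarity between users based on their latent features;
# the same computation on V gives item-item similarity.
U_norm = U / np.linalg.norm(U, axis=1, keepdims=True)
user_sim = np.dot(U_norm, U_norm.T)
print(R_hat.shape, user_sim.shape)  # (5, 4) (5, 5)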
In the loss function, the first part is the squared-error term and the second part is the regularization term, which involves u_i and v_j.
Alternating least squares fixes one factor and solves for the other, so only one of the two variables is free at a time; taking the partial derivative and setting it to zero gives the minimum.
Taking the partial derivative with respect to v_j:
At first I found this derivation hard to follow; the main difficulty is how to differentiate expressions involving transposes.
An example makes it clear: when a term such as u_i^T v_j is differentiated with respect to v_j, the transpose has to be handled first, which is why the factor u_i^T shows up in the intermediate steps.
Continuing the derivation of Equation (1):
The U matrix is obtained by the same method.
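The end result of this derivation (and of the symmetric one for u_i), written to match the als() implementation below; constant factors are absorbed into λ, I is the k×k identity, R_{:,j} is the j-th column of R and R_{i,:} its i-th row:

v_j = (U^T U + λI)^{-1} U^T R_{:,j}
u_i = (V^T V + λI)^{-1} V^T R_{i,:}^T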
Here the intermediate steps of the gradient-descent derivation are spelled out in a little more detail:
Take the partial derivative of the loss function.
Since the factors are updated one element at a time (a single p_ik or q_kj), we need an update formula for each individual element.
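For a single observed entry R_ij, with prediction error e_ij = R_ij − Σ_k p_ik q_kj, the per-element updates used by the gd() code below are

p_ik ← p_ik + α (2 e_ij q_kj − λ p_ik)
q_kj ← q_kj + α (2 e_ij p_ik − λ q_kj)

where α is the learning rate and λ the regularization coefficient.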
Putting the loss function defined above into code:
import numpy as np


def loss(R, P, Q, K, lambd=0.02):
    """Regularized squared loss over the observed (non-zero) entries of R."""
    QT = Q.T
    total = 0.0
    for i in range(len(R)):
        for j in range(len(R[0])):
            if R[i, j] > 0:  # a 0 entry means "not rated", so skip it
                # (observed - predicted)^2
                total += (R[i, j] - np.dot(P[i, :], QT[:, j])) ** 2
                # L2 regularization on the factors involved in this entry
                for k in range(K):
                    total += (lambd / 2) * (P[i, k] ** 2 + QT[k, j] ** 2)
    return total
def gd(R, P, Q, K, alpha=0.001, lambd=0.02, steps=100):
    """Gradient descent on the observed entries of R; P and Q are updated in place."""
    QT = Q.T  # view of Q, so writing into QT also updates Q
    for ite in range(steps):
        for i in range(len(P)):
            for j in range(len(R[0])):
                if R[i, j] > 0:
                    # prediction error for this observed entry
                    error = R[i, j] - np.dot(P[i, :], QT[:, j])
                    # move p_ik and q_kj along the negative gradient
                    for k in range(K):
                        pik = P[i, k]  # keep the old value so both updates use the same point
                        P[i, k] += alpha * (2 * error * QT[k, j] - lambd * pik)
                        QT[k, j] += alpha * (2 * error * pik - lambd * QT[k, j])
        lossc = loss(R, P, Q, K, lambd)
        print("gradient descent,iterate:", ite, "loss:", lossc)
    return P, Q
def als(R, P, Q, K, lambd=0.02, steps=100):
    """Alternating least squares; P and Q are updated in place.

    Note: the closed-form solve below uses full rows/columns of R,
    so missing entries are effectively treated as ratings of 0.
    """
    E = np.eye(K)  # K x K identity for the lambda * I regularization term
    for step in range(steps):
        # Fix Q, solve each user row in closed form:
        #   p_i = (Q^T Q + lambda * I)^(-1) Q^T R[i, :]
        A = np.dot(Q.T, Q) + lambd * E
        for i in range(len(P)):
            P[i, :] = np.linalg.solve(A, np.dot(Q.T, R[i, :]))
        # Fix P, solve each item row in closed form:
        #   q_j = (P^T P + lambda * I)^(-1) P^T R[:, j]
        A = np.dot(P.T, P) + lambd * E
        for j in range(len(Q)):
            Q[j, :] = np.linalg.solve(A, np.dot(P.T, R[:, j]))
        lossvalue = loss(R, P, Q, K, lambd)
        print("als,iteration:", step, "loss:", lossvalue)
        if lossvalue < 0.01:
            break
    return P, Q
if __name__ == "__main__":
    # 0 marks a missing rating
    R = np.array([
        [5, 3, 0, 1],
        [4, 0, 0, 1],
        [1, 1, 0, 5],
        [1, 0, 0, 4],
        [0, 1, 5, 4],
    ])  # shape (5, 4), k = 2
    N, M, K = R.shape[0], R.shape[1], 2
    print(N, M, K)  # 5 4 2
    U = np.random.rand(N, K)  # user-feature matrix
    V = np.random.rand(M, K)  # item-feature matrix
    # loss before any optimization
    losst = loss(R, U, V, K)
    # gradient descent (updates U and V in place)
    gd(R, U, V, K, alpha=0.001, lambd=0.02)
    # ALS, starting from the matrices left by gradient descent
    P, Q = als(R, U, V, K, lambd=0.02)
    # np.dot(P, Q.T) is the filled-in rating matrix
Run results:
5 4 2
gradient descent,iterate: 0 loss: 92.3719163472
gradient descent,iterate: 1 loss: 92.0064419565
gradient descent,iterate: 2 loss: 91.638861411
gradient descent,iterate: 3 loss: 91.2691298415
gradient descent,iterate: 4 loss: 90.8972092357
gradient descent,iterate: 5 loss: 90.5230684137
gradient descent,iterate: 6 loss: 90.1466830058
gradient descent,iterate: 7 loss: 89.7680354342
gradient descent,iterate: 8 loss: 89.3871148957
gradient descent,iterate: 9 loss: 89.0039173456
gradient descent,iterate: 10 loss: 88.6184454813
gradient descent,iterate: 11 loss: 88.230708725
gradient descent,iterate: 12 loss: 87.8407232053
gradient descent,iterate: 13 loss: 87.4485117355
gradient descent,iterate: 14 loss: 87.0541037888
gradient descent,iterate: 15 loss: 86.6575354697
gradient descent,iterate: 16 loss: 86.2588494798
gradient descent,iterate: 17 loss: 85.8580950782
gradient descent,iterate: 18 loss: 85.455328036
gradient descent,iterate: 19 loss: 85.0506105826
gradient descent,iterate: 20 loss: 84.644011345
gradient descent,iterate: 21 loss: 84.2356052782
gradient descent,iterate: 22 loss: 83.8254735868
gradient descent,iterate: 23 loss: 83.4137036363
gradient descent,iterate: 24 loss: 83.000388855
gradient descent,iterate: 25 loss: 82.5856286247
gradient descent,iterate: 26 loss: 82.1695281598
gradient descent,iterate: 27 loss: 81.7521983756
gradient descent,iterate: 28 loss: 81.3337557442
gradient descent,iterate: 29 loss: 80.914322138
gradient descent,iterate: 30 loss: 80.4940246613
gradient descent,iterate: 31 loss: 80.0729954683
gradient descent,iterate: 32 loss: 79.6513715696
gradient descent,iterate: 33 loss: 79.2292946243
gradient descent,iterate: 34 loss: 78.806910721
gradient descent,iterate: 35 loss: 78.3843701448
gradient descent,iterate: 36 loss: 77.9618271326
gradient descent,iterate: 37 loss: 77.5394396156
gradient descent,iterate: 38 loss: 77.1173689505
gradient descent,iterate: 39 loss: 76.6957796387
gradient descent,iterate: 40 loss: 76.2748390345
gradient descent,iterate: 41 loss: 75.8547170432
gradient descent,iterate: 42 loss: 75.4355858094
gradient descent,iterate: 43 loss: 75.0176193952
gradient descent,iterate: 44 loss: 74.6009934517
gradient descent,iterate: 45 loss: 74.1858848816
gradient descent,iterate: 46 loss: 73.772471496
gradient descent,iterate: 47 loss: 73.3609316653
gradient descent,iterate: 48 loss: 72.9514439659
gradient descent,iterate: 49 loss: 72.5441868228
gradient descent,iterate: 50 loss: 72.1393381505
gradient descent,iterate: 51 loss: 71.7370749924
gradient descent,iterate: 52 loss: 71.3375731602
gradient descent,iterate: 53 loss: 70.9410068748
gradient descent,iterate: 54 loss: 70.5475484097
gradient descent,iterate: 55 loss: 70.1573677376
gradient descent,iterate: 56 loss: 69.7706321829
gradient descent,iterate: 57 loss: 69.3875060799
gradient descent,iterate: 58 loss: 69.0081504386
gradient descent,iterate: 59 loss: 68.6327226198
gradient descent,iterate: 60 loss: 68.2613760192
gradient descent,iterate: 61 loss: 67.8942597638
gradient descent,iterate: 62 loss: 67.5315184194
gradient descent,iterate: 63 loss: 67.1732917125
gradient descent,iterate: 64 loss: 66.8197142655
gradient descent,iterate: 65 loss: 66.4709153481
gradient descent,iterate: 66 loss: 66.1270186443
gradient descent,iterate: 67 loss: 65.7881420365
gradient descent,iterate: 68 loss: 65.4543974072
gradient descent,iterate: 69 loss: 65.125890459
gradient descent,iterate: 70 loss: 64.8027205538
gradient descent,iterate: 71 loss: 64.4849805704
gradient descent,iterate: 72 loss: 64.1727567828
gradient descent,iterate: 73 loss: 63.8661287575
gradient descent,iterate: 74 loss: 63.5651692711
gradient descent,iterate: 75 loss: 63.2699442482
gradient descent,iterate: 76 loss: 62.980512719
gradient descent,iterate: 77 loss: 62.6969267969
gradient descent,iterate: 78 loss: 62.4192316759
gradient descent,iterate: 79 loss: 62.1474656469
gradient descent,iterate: 80 loss: 61.8816601336
gradient descent,iterate: 81 loss: 61.6218397459
gradient descent,iterate: 82 loss: 61.368022352
gradient descent,iterate: 83 loss: 61.1202191673
gradient descent,iterate: 84 loss: 60.8784348593
gradient descent,iterate: 85 loss: 60.6426676692
gradient descent,iterate: 86 loss: 60.4129095477
gradient descent,iterate: 87 loss: 60.1891463043
gradient descent,iterate: 88 loss: 59.971357771
gradient descent,iterate: 89 loss: 59.7595179765
gradient descent,iterate: 90 loss: 59.5535953326
gradient descent,iterate: 91 loss: 59.3535528301
gradient descent,iterate: 92 loss: 59.159348244
gradient descent,iterate: 93 loss: 58.9709343462
gradient descent,iterate: 94 loss: 58.788259126
gradient descent,iterate: 95 loss: 58.6112660156
gradient descent,iterate: 96 loss: 58.4398941211
gradient descent,iterate: 97 loss: 58.2740784571
gradient descent,iterate: 98 loss: 58.1137501841
gradient descent,iterate: 99 loss: 57.9588368481
als,iteration: 0 loss: 21.5978330148
als,iteration: 1 loss: 17.3428069177
als,iteration: 2 loss: 16.3502902377
als,iteration: 3 loss: 16.2712840045
als,iteration: 4 loss: 16.2936489968
als,iteration: 5 loss: 16.3098592641
als,iteration: 6 loss: 16.3174969147
als,iteration: 7 loss: 16.3211261631
als,iteration: 8 loss: 16.3231790031
als,iteration: 9 loss: 16.32465131
als,iteration: 10 loss: 16.3259224131
als,iteration: 11 loss: 16.3271331615
als,iteration: 12 loss: 16.3283352168
als,iteration: 13 loss: 16.3295472843
als,iteration: 14 loss: 16.3307759489
als,iteration: 15 loss: 16.3320233491
als,iteration: 16 loss: 16.3332899979
als,iteration: 17 loss: 16.3345758181
als,iteration: 18 loss: 16.3358805223
als,iteration: 19 loss: 16.3372037526
als,iteration: 20 loss: 16.3385451309
als,iteration: 21 loss: 16.339904278
als,iteration: 22 loss: 16.3412808197
als,iteration: 23 loss: 16.3426743898
als,iteration: 24 loss: 16.3440846298
als,iteration: 25 loss: 16.3455111898
als,iteration: 26 loss: 16.3469537282
als,iteration: 27 loss: 16.348411911
als,iteration: 28 loss: 16.3498854124
als,iteration: 29 loss: 16.3513739138
als,iteration: 30 loss: 16.3528771043
als,iteration: 31 loss: 16.3543946799
als,iteration: 32 loss: 16.3559263435
als,iteration: 33 loss: 16.3574718051
als,iteration: 34 loss: 16.359030781
als,iteration: 35 loss: 16.3606029939
als,iteration: 36 loss: 16.3621881728
als,iteration: 37 loss: 16.3637860527
als,iteration: 38 loss: 16.3653963745
als,iteration: 39 loss: 16.3670188849
als,iteration: 40 loss: 16.3686533359
als,iteration: 41 loss: 16.3702994852
als,iteration: 42 loss: 16.3719570957
als,iteration: 43 loss: 16.3736259354
als,iteration: 44 loss: 16.3753057774
als,iteration: 45 loss: 16.3769963994
als,iteration: 46 loss: 16.3786975841
als,iteration: 47 loss: 16.3804091188
als,iteration: 48 loss: 16.3821307953
als,iteration: 49 loss: 16.3838624095
als,iteration: 50 loss: 16.385603762
als,iteration: 51 loss: 16.3873546572
als,iteration: 52 loss: 16.3891149038
als,iteration: 53 loss: 16.3908843144
als,iteration: 54 loss: 16.3926627052
als,iteration: 55 loss: 16.3944498966
als,iteration: 56 loss: 16.3962457123
als,iteration: 57 loss: 16.3980499797
als,iteration: 58 loss: 16.3998625296
als,iteration: 59 loss: 16.4016831964
als,iteration: 60 loss: 16.4035118176
als,iteration: 61 loss: 16.405348234
als,iteration: 62 loss: 16.4071922895
als,iteration: 63 loss: 16.4090438312
als,iteration: 64 loss: 16.4109027091
als,iteration: 65 loss: 16.412768776
als,iteration: 66 loss: 16.4146418879
als,iteration: 67 loss: 16.4165219033
als,iteration: 68 loss: 16.4184086834
als,iteration: 69 loss: 16.4203020923
als,iteration: 70 loss: 16.4222019965
als,iteration: 71 loss: 16.424108265
als,iteration: 72 loss: 16.4260207694
als,iteration: 73 loss: 16.4279393836
als,iteration: 74 loss: 16.429863984
als,iteration: 75 loss: 16.4317944491
als,iteration: 76 loss: 16.4337306598
als,iteration: 77 loss: 16.4356724992
als,iteration: 78 loss: 16.4376198524
als,iteration: 79 loss: 16.4395726066
als,iteration: 80 loss: 16.4415306513
als,iteration: 81 loss: 16.4434938777
als,iteration: 82 loss: 16.4454621792
als,iteration: 83 loss: 16.4474354508
als,iteration: 84 loss: 16.4494135896
als,iteration: 85 loss: 16.4513964946
als,iteration: 86 loss: 16.4533840662
als,iteration: 87 loss: 16.455376207
als,iteration: 88 loss: 16.457372821
als,iteration: 89 loss: 16.4593738139
als,iteration: 90 loss: 16.4613790932
als,iteration: 91 loss: 16.4633885677
als,iteration: 92 loss: 16.465402148
als,iteration: 93 loss: 16.4674197461
als,iteration: 94 loss: 16.4694412756
als,iteration: 95 loss: 16.4714666513
als,iteration: 96 loss: 16.4734957898
als,iteration: 97 loss: 16.4755286087
als,iteration: 98 loss: 16.4775650272
als,iteration: 99 loss: 16.4796049658
Starting from sparse matrix factorization, this post constructed the squared loss function, introduced two ways of solving it (alternating least squares and stochastic gradient descent), worked through the derivatives, and used Python to observe how the two methods behave in practice.
The takeaway: to understand an algorithm, beyond the application scenario and the mathematical principles, you also need hands-on practice; the three reinforce one another and lead to a somewhat deeper understanding.