首先奉上sparse autoencoder的官方资料:https://web.stanford.edu/class/cs294a/sparseAutoencoder.pdf
为了看懂后面的代码,先对一些有困惑的代码做一些解释。
1.参考文献1:https://ask.csdn.net/questions/749205
关于python在list中使用for i in range()的问题,list是用方括号[]表示的列表,例如
[7 for i in range(3)]
则会生成一个长度为3且全部值均为7的列表
2.参考文献2:https://www.cnblogs.com/sbj123456789/p/9231571.html
关于pytorch中tensor的squeeze()和unsqueeze()函数
squeeze(arg)表示第arg维度值为1,则去掉该维度,否则tensor保持不变,看以下例子:
unsqueeze(arg)与squeeze(arg)作用相反,表示在第arg维增加一个维度为1的维度,看以下代码:
好了,有了以上的铺垫再看懂sparse autoencoder的代码基本问题不大了,至少对我来说是这个样子。需要指出的是,虽然叫做稀疏自编码模型,但模型上仍然是一个普通的自编码模型,只不过是在损失函数上增加了一些东西体现稀疏约束,在实现自编码模型的时候与传统的自编码模型并没有任何不同!ok,先定义一个最简单的自编码模型,为了对应全文,我们仍然起名叫做SparseAutoencoderModel.py:
import torch
import torch.nn as nn
import torch.nn.functional as F
class SparseAutoencoder(nn.Module):
def __init__(self,n_inp,n_hidden):
super(SparseAutoencoder,self).__init__()
self.encoder=nn.Linear(n_inp,n_hidden)
self.decoder=nn.Linear(n_hidden,n_inp)
def forward(self,x):
encoded=F.sigmoid(self.encoder(x))
decoded=F.sigmoid(self.decoder(encoded))
return encoded,decoded
在定义好基础模型以后,即可训练一个SparseAutoencoder.再重复一遍,所谓的sparse只不过是在损失函数中添加了体现sparse的东西!TrainSparseAutoencoderModel.py的代码如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import Variable
import os
import matplotlib.pyplot as plt
import math
from SparseAutoencoderModel import SparseAutoencoder
# Define some global constants
BATCH_SIZE=32
BETA=3
RHO=0.01 #体现稀疏性的一个约束,读本文最上方吴恩达关于稀疏自编码的讲义即可得知其含义
N_INP=784
N_HIDDEN=300
N_EPOCHS=1
USE_SPARSE=True #是否使用稀疏自编码,如果是,就在损失函数中添加一个KL散度的损失
#生成长度为N_HIDDEN值为RHO的list并转为tensor并添加一个维度
rho=torch.FloatTensor([RHO for _ in range(N_HIDDEN)]).unsqueeze(0)
root='./data'
if not os.path.exists(root):
os.makedir(root)
transform=transforms.Compose([torch.ToTensor()])
train_set=datasets.FashionMNIST(root=root,train=True,transform=transform,download=True)
test_set=datasets.FashionMNIST(root=root,train=False,transform=transform,download=True)
train_loader=torch.utils.data.DataLoader(dataset=train_set,
batch_size=BATCHSIZE,
shuffle=True)
test_loader=torch.utils.data.DataLoader(dataset=test_set,
batch_size=BATCHSIZE,
shuffle=False)
auto_encoder=SparseAutoencoder(N_INP,N_HIDDEN)
optimizer=optim.Adam(auto_encoder.parameters(),lr=1e-3)
#定义一个kl损失来体现稀疏约束
def kl_divergence(p,q)
p=F.softmax(p,dim=1)
q=F.softmax(q,dim=1)
s1=torch.sum(p*torch.log(p/q))
s2=torch.sum((1-p)*torch.log((1-p)/(1-q)))
return s1+s2
#以下代码纯属为了可视化 set plot and view data for visualization
# N_COLS*N_ROWS=32正好等于我们设置的BATCHSIZE
N_COLS=8
N_ROWS=4
# 看看本文最上面的铺垫,就明白下面这一行代码是为了把每张图像放进view_data里面
view_data=[test_set[i][0] for i in range(N_ROWS*N_COLS)]
plt.figure(figsize=(20,4)) #设置figure的大小,但是这个“大小”具体是如何定义的我没有查
for epoch in range(EPOCHS):
for b_index, (x,_) in enumerate(train_loader)
#把原始图像拉成列向量,原因在于定义的encoder输入是一个784维向量,而不是28x28原始图像
# x.size()返回值应该是torch.Size([32,1,28,28]),其中32表示BATCHSIZE,即一次输入图像
# 的数量,1表示图像的通道数,28x28表示图像尺寸 x.size()[0]返回32
# x.view(x.size()[0],-1)把输入重塑为32行,每一行维度为1x28x28,-1参数会让view自己算
# 第二个参数应该为多少维度,或者说-1的位置表示所有维度大小乘积/x.size()[0]
# 执行完以下代码后,x.shape变为torch.Size([32,784])
x=x.view(x.size()[0],-1)
x=Variable(x) #转换为Variable使之能生成计算图(最新版废弃Variable?)
#关注encoded的尺寸,其shape为[32,300],每一行对应一张图像
encoded,decoded=auto_encoder(x)
MSE_loss=(x-decoded)**2
# 注意,上面的MSE_loss的shape仍然为(32,784)
# MSE_loss.view(1,-1)后其shape变成[1,25088]向量的形式,sum(1)对列求和
MSE_loss=MSE_loss.view(1,-1).sum(1)/BATCH_SIZE
if USE_SPARSE:
#若keepdim值为True,则在输出张量中,除了被操作的dim维度值降为1,其它维度与输入
# 张量input相同。否则,dim维度相当于被执行torch.squeeze()维度压缩操作,
# 导致此维度消失,最终输出张量会比输入张量少一个维度。被操作对象encoded原来的尺寸
# 为32x300,经过以下代码后rho_hat的尺寸为1x300,正好和前面的rho对应
rho_hat=torch.sum(encoded,dim=0,keepdim=True)
#希望每个样本对应的平均激活都接近于rho,其中rho是我们预定义的一个稀疏指标
sparsity_penalty=BETA*kl_divergence(rho,rho_hat)
loss=MSE_loss+sparsity_penalty
else:
loss=MSE_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("Epoch: {:3d}, Loss: {:.4f}".format(epoch+1,loss.data[0]))
# 以下代码是为了可视化
for i in range(N_ROWS*N_COLS):
# original image
r=i//N_COLS
C=i%N_COLS+1
ax=plt.subplot(2*N_ROWS,N_COLS,2*r*N_COLS+c)
plt.imshow(view_data[i].squeeze())
plt.gray
ax.get_xaxis().set_visible(False)
ay.get_yaxis().set_visible(False)
# reconstructed image
ax=plt.subplot(2*N_ROWS,N_COLS,2*r*N_COLS+c+N_COLS)
x=Variable(view_data[x])
e,y=autoencoder(x.view(1,-1))
plt.imshow(y.detach().squeeze().numpy().reshape(28,28))
plt.gray()
ax.get_xaxis().set_visible(False)
ax_get_yaxis().set_visible(False)
plt.show()
以上关键代码可与吴恩达老师的讲义内容对应起来