# LSTM autoencoder for feature extraction, followed by KNN classification
# (LSTM搭建自编码器提取特征,KNN分类)
import torch
import torch.nn as nn
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
# Hyperparameters.
# NOTE(review): the training loop visible in this file hard-codes its own
# epoch count (500) and learning rate (0.001) and does not read these two.
EPOCH, LR = 200, 0.005

# Load the iris dataset: `x` holds the feature matrix, `y` the class labels.
data = load_iris()
x, y = data.data, data.target

# Optional hold-out split, currently disabled:
# X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.3)
# print(y_train)
class RNN(torch.nn.Module):
    """LSTM autoencoder: compress each sequence to a small code, then rebuild it.

    The encoder is an LSTM followed by a linear projection down to
    ``code_size`` values; the decoder is a second LSTM followed by a linear
    projection back up to ``input_size`` values.

    Args:
        input_size: Number of features per time step of the input (and of the
            reconstruction).
        hidden_size: Hidden-state width of both LSTMs.
        code_size: Width of the bottleneck code returned as ``encode``.

    The defaults (4, 64, 3) reproduce the sizes this model was originally
    hard-wired to.
    """

    def __init__(self, input_size: int = 4, hidden_size: int = 64,
                 code_size: int = 3) -> None:
        super().__init__()
        # Layers are created in a fixed order (rnn, out, rnn_2, out_2) so that
        # parameter initialisation under a given RNG seed stays reproducible.
        self.rnn = torch.nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,
        )
        self.out = torch.nn.Linear(in_features=hidden_size,
                                   out_features=code_size)
        self.rnn_2 = torch.nn.LSTM(
            input_size=code_size,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,
        )
        self.out_2 = torch.nn.Linear(in_features=hidden_size,
                                     out_features=input_size)

    def forward(self, x: torch.Tensor):
        """Encode ``x`` and reconstruct it from the code.

        Args:
            x: Input of shape ``[batch_size, time_step, input_size]``
                (the LSTMs are built with ``batch_first=True``).

        Returns:
            A tuple ``(encode, decode)`` where ``encode`` has shape
            ``[batch_size, code_size]`` and ``decode`` has shape
            ``[batch_size, input_size]``.
        """
        # h_n is [num_layers, batch_size, hidden_size] even with
        # batch_first=True, so h_n[-1] is the last layer's final hidden state.
        # For this unidirectional LSTM it equals output[:, -1, :].
        _, (h_n, _) = self.rnn(x)
        encode = self.out(h_n[-1])

        # Feed the code to the decoder as a length-1 sequence. The code width
        # is read from the tensor instead of being hard-coded.
        _, (h_n1, _) = self.rnn_2(encode.view(-1, 1, encode.size(-1)))
        decode = self.out_2(h_n1[-1])
        return encode, decode
net = RNN()

# 3. Training. Every optimisation step is a full-batch pass over the dataset.
# NOTE(review): the learning rate and epoch count are hard-coded here; the
# module-level EPOCH / LR constants are not read by this loop.
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
loss_F = torch.nn.MSELoss()

# The data never changes between epochs, so convert it to tensors once rather
# than on every iteration.
x1 = torch.from_numpy(x).unsqueeze(0).float()
x2 = torch.from_numpy(x).unsqueeze(0).float()
# Each sample is presented to the network as a length-1 sequence of 4 features.
inputs = x1.view(-1, 1, 4)
# The reconstruction has no leading singleton axis, so drop it from the target:
# MSELoss then compares equal shapes instead of broadcasting (and warning).
target = x2.squeeze(0)

for epoch in range(500):  # 500 full passes over the data
    _, pred = net(inputs)
    loss = loss_F(pred, target)  # reconstruction error
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print('Epoch: ', epoch, '| train loss: %.4f' % loss.item())
# Encode every sample with the trained network. This is pure inference, so
# skip building the autograd graph.
with torch.no_grad():
    pred, _ = net(x1.view(-1, 1, 4))
print(pred)
print(pred.shape)
# No squeeze is applied: the code tensor is already 2-D (one row per sample),
# which is the layout the classifier below is given.
pred = pred.numpy()
print(pred)

from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import cross_val_score

# Score the learned codes: 6-fold cross-validated accuracy of a 5-NN classifier.
knn = KNeighborsClassifier(n_neighbors=5)
scores = cross_val_score(knn, pred, y, cv=6, scoring='accuracy')
print(scores)