#!/usr/bin/env python
# encoding: utf-8
'''
@author: taoshouzheng
@contact: [email protected]
@file: 1 lstm + linear.py
@time: 2019/11/7 9:15
'''
import torch
import torch.nn as nn
from torch.nn import init
import random
import numpy as np
from torch import optim
from sklearn.metrics import classification_report
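
# A toy model: a single-layer LSTM whose final hidden state is passed through a
# Linear layer to produce class scores.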
class MyNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MyNet, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        # nn.LSTM defaults to batch_first=False, so inputs are (seq_len, batch, input_size)
        self.lstm = nn.LSTM(self.input_size, self.hidden_size)
        self.linear = nn.Linear(self.hidden_size, self.output_size)
    def init_state(self, batch_size, hidden_size):
        # Random initial hidden and cell states, shaped (num_layers, batch, hidden)
        h_init = torch.rand(1, batch_size, hidden_size)
        c_init = torch.rand(1, batch_size, hidden_size)
        return h_init, c_init
    def forward(self, x, h, c):
        # Classify from the last hidden state rather than the per-step outputs
        output, (new_h, new_c) = self.lstm(x, (h, c))
        result = self.linear(new_h)
        return result
    def prediction(self, x, h, c):
        # Same forward pass, but return hard class labels for evaluation
        with torch.no_grad():
            output, (new_h, new_c) = self.lstm(x, (h, c))
            result = self.linear(new_h)
            result = result.squeeze(0)          # (1, batch, classes) -> (batch, classes)
            label1 = torch.max(result, 1)[1]    # index of the highest score per sample
        return label1
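
# Helper: set every parameter to the same constant so repeated runs start from
# identical weights.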
def reset_weights(model):
    for weight in model.parameters():
        init.constant_(weight, 0.5)
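
# Build the model (input_size=10, hidden_size=20, 3 output classes) and reset
# all of its weights to 0.5.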
net = MyNet(10, 20, 3)
reset_weights(net)
epochs = 100
# Input: a fixed all-ones batch of shape (seq_len=5, batch=200, input_size=10)
inputs = torch.ones(5, 200, 10)
print('shape of inputs')
print(inputs.shape)
# Labels: 200 random class indices drawn from {0, 1, 2}
label = []
for i in range(200):
    label.append(random.choice([0, 1, 2]))
label = np.array(label, dtype=np.int64)
target = torch.LongTensor(label)
print('shape of target')
print(target.shape)
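
# Initial hidden/cell state for the batch of 200, the loss (summed over the
# batch rather than averaged, because reduction='sum'), and an SGD optimizer
# with momentum.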
h_init, c_init = net.init_state(200, 20)
criterion = nn.CrossEntropyLoss(reduction='sum')
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
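
# Training loop: repeatedly push the fixed input through the network and fit
# the random labels.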
for i in range(epochs):
    output = net(inputs, h_init, c_init)
    output = output.squeeze(0)            # (1, 200, 3) -> (200, 3)
    _, pred = torch.max(output, 1)        # predicted class per sample
    loss = criterion(output, target)
    print('epoch', i + 1, ':', loss.item())
    optimizer.zero_grad()                 # clear gradients from the previous step
    loss.backward()
    optimizer.step()
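
# Not part of the original script: a quick sanity check of training accuracy on
# the last epoch's predictions (the labels are random, so a value around 1/3 is
# expected).
train_acc = (pred == target).float().mean().item()
print('training accuracy on the random labels:', train_acc)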
# Evaluation: switch to eval mode and score the model on a fresh set of random labels
net.eval()
inputs = torch.ones(5, 200, 10)
# Labels
label = []
for i in range(200):
    label.append(random.choice([0, 1, 2]))
y_test = np.array(label, dtype=np.int64)
y_pred = net.prediction(inputs, h_init, c_init).numpy()
print(type(y_pred))
print(y_pred.shape)
# digits sets the number of decimal places in the report; support is how many times each label occurs
ans = classification_report(y_test, y_pred, digits=5)
print(ans)