# PyTorch 实现简单的数字分类 (simple number classification with PyTorch)

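# 目标:让模型学会把 0.1~0.5 之间的数归为一类(类别 0),把大于 0.5 的数归为另一类(类别 1)。
# Goal: classify numbers in [0.1, 0.5] as class 0 and numbers greater than 0.5 as class 1.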


import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.autograd import Variable


# Training inputs. Both ranges lie inside [0, 1]; the original author stresses
# that keeping the numbers normalized like this is very important for training.
# `steps` is passed explicitly: older PyTorch releases defaulted it to 100
# (deprecated), newer releases require it.
train_x_1 = torch.linspace(0.1, 0.5, steps=100)  # class 0: 0.1 .. 0.5 inclusive
# Class 1 must be strictly greater than 0.5. linspace includes both endpoints,
# so generate 101 points and drop the first one (0.5 itself); otherwise 0.5
# would appear in both classes with conflicting labels.
train_x_2 = torch.linspace(0.5, 1.0, steps=101)[1:]

train_x = torch.cat((train_x_1, train_x_2))
print(train_x.size())
# Add a feature dimension: (N,) -> (N, 1), one input feature per sample.
train_x = torch.unsqueeze(train_x, 1)
print(train_x.size())

# Integer (int64) class indices, the target format CrossEntropyLoss expects.
# Kept 1-D with shape (N,) to pair with an (N, num_classes) model output.
label_1 = torch.zeros_like(train_x_1, dtype=torch.long)
label_2 = torch.ones_like(train_x_2, dtype=torch.long)
label = torch.cat((label_1, label_2))


# Classifier: 1 input feature -> 5 hidden units -> 2 class scores (logits).
# No Softmax layer here: CrossEntropyLoss applies log-softmax internally and
# expects raw logits, so a Softmax in the model would apply softmax twice,
# squeezing the scores into [0, 1] and weakening the training signal
# (nn.Softmax() without `dim` is also deprecated). Note that two stacked Linear
# layers with no activation between them are still one affine map overall,
# which is enough for this 1-D problem with a single threshold.
model = torch.nn.Sequential(
    torch.nn.Linear(1, 5),
    torch.nn.Linear(5, 2),
)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Full-batch gradient descent over the whole training set. No `Variable`
# wrapper is needed: plain tensors take part in autograd directly.
num_epochs = 10000
for epoch in range(num_epochs):
    predict = model(train_x)  # (N, 2) logits
    loss = criterion(predict, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Log periodically instead of on every iteration (10000 lines of output).
    if epoch % 1000 == 0 or epoch == num_epochs - 1:
        print('Loss:', loss.item())

# Interactive check: read up to 200 numbers from stdin and print the class
# probabilities and the predicted class for each one.
model.eval()
for _ in range(200):
    try:
        raw = input('输入测试数据')
    except EOFError:
        # stdin was closed (e.g. piped input ran out): stop instead of crashing.
        break
    try:
        value = float(raw)
    except ValueError:
        print('输入无效,请输入一个数字:', raw)
        continue
    x = torch.tensor([[value]])  # shape (1, 1): a batch with a single sample
    with torch.no_grad():  # inference only; no autograd graph needed
        probs = torch.softmax(model(x), dim=1)
    print('预测结果:', probs, '类别:', probs.argmax(dim=1).item())





# 你可能感兴趣的:(PyTorch实现简单的数字分类)  (related-article link from the original web page)