# PyTorch LSTM 训练 CIFAR-10 数据集 (training an LSTM classifier on CIFAR-10)

import torch
import torchvision
from torchvision import transforms
from torch import nn
from torch import optim
from torch.autograd import Variable
import torch.nn.functional as F

# ToTensor turns each PIL image into a float tensor in [0, 1] of shape
# (channels, height, width); Normalize then maps every RGB channel with
# mean 0.5 / std 0.5, i.e. into roughly [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# download=False: nothing is fetched; the CIFAR-10 files are expected to
# already exist under ./data.
trainsets = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainsets, batch_size=100, shuffle=True)
testsets = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)
# Must wrap the held-out split (testsets), not trainsets; otherwise "test"
# metrics would be computed on training data.
testloader = torch.utils.data.DataLoader(testsets, batch_size=100, shuffle=False)

class Net(nn.Module):
    """LSTM classifier that reads its input one time step at a time.

    The output of the top LSTM layer at the last time step is projected to
    ``num_classes`` logits by a linear layer.

    Args:
        input_size: features per time step. The default 32 * 3 = 96 is one
            32-pixel image row with 3 channel values per pixel.
        hidden_size: size of the LSTM hidden state.
        num_layers: number of stacked LSTM layers.
        num_classes: number of output logits.
    """

    def __init__(self, input_size=32 * 3, hidden_size=128, num_layers=3,
                 num_classes=10):
        super().__init__()
        # batch_first=True: the LSTM consumes (batch, seq_len, input_size).
        # Attribute names are kept as-is so saved state_dicts still load.
        self.LSTM = nn.LSTM(input_size, hidden_size, batch_first=True,
                            num_layers=num_layers)
        self.output = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        Args:
            x: either a sequence batch (batch, seq_len, input_size), passed to
                the LSTM unchanged, or an image batch
                (batch, channels, height, width). An image batch is turned into
                a sequence of rows: time step t is image row t with each
                pixel's channel values kept together, giving
                (batch, height, width * channels); width * channels must then
                equal ``input_size``.
        """
        if x.dim() == 4:
            # (B, C, H, W) -> (B, H, W, C) -> (B, H, W*C). A plain .view() on
            # the channel-major layout would instead slice channel planes into
            # chunks, so a step would not correspond to one colour row.
            x = x.permute(0, 2, 3, 1).reshape(x.size(0), x.size(2), -1)
        out, _ = self.LSTM(x)
        # Classify from the top layer's output at the final time step.
        return self.output(out[:, -1, :])

def to_row_sequence(images):
    """Turn an image batch (B, C, H, W) into an LSTM sequence (B, H, W*C).

    Time step t is image row t, with each pixel's channel values adjacent.
    A plain ``images.view(-1, H, W * C)`` on the (B, C, H, W) layout would
    instead cut the channel-major memory into consecutive W*C-value chunks.
    A time step would then hold several rows of a single channel, or
    straddle two channels, rather than one colour row.
    """
    batch, channels, height, width = images.shape
    return images.permute(0, 2, 3, 1).reshape(batch, height, width * channels)


if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = Net().to(device)
    Loss = nn.CrossEntropyLoss()
    # NOTE(review): lr=0.01 is high for Adam on a stacked LSTM (the usual
    # default is 1e-3) -- consider lowering it if the loss does not decrease.
    Opt = optim.Adam(net.parameters(), lr=0.01)
    log_every = 100  # batches between printed running-average losses

    for epoch in range(100):
        net.train()
        running_loss = 0.0
        for step, (data, label) in enumerate(trainloader, start=1):
            data = to_row_sequence(data).to(device)
            label = label.to(device)
            out = net(data)
            loss = Loss(out, label)
            Opt.zero_grad()
            loss.backward()
            Opt.step()
            running_loss += loss.item()
            if step % log_every == 0:
                print(f'epoch {epoch + 1} step {step}: '
                      f'loss {running_loss / log_every:.4f}')
                running_loss = 0.0

        # Evaluate on the test loader after each epoch; eval() + no_grad()
        # disable training-only behaviour and skip autograd bookkeeping.
        net.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data, label in testloader:
                data = to_row_sequence(data).to(device)
                label = label.to(device)
                pred = net(data).argmax(dim=1)
                correct += (pred == label).sum().item()
                total += label.size(0)
        print(f'epoch {epoch + 1}: test accuracy {correct / max(total, 1):.4f}')

 

# 你可能感兴趣的：人工智能，深度学习，计算机视觉 (related topics: artificial intelligence, deep learning, computer vision)