此博文主要是为了用RNN 做图像分类,来了解pytorch的RNN用于训练图像时的用法。
mnist数据集中的图像:图像1由28*28个像素点组成如下图所示
对于此图像我们可以将每张图像看作是长28的序列,序列中的每个元素的特征维度为28.
import torch
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms as tfs
from torchvision.datasets import MNIST
data_tf = tfs.Compose([
tfs.ToTensor(),
tfs.Normalize([0.5], [0.5]) # 标准化
])
train_set = MNIST(r'C:\Users\Administrator.SKY-20180518VHY\Desktop\pytorch\data', train=True, transform=data_tf)
test_set = MNIST(r'C:\Users\Administrator.SKY-20180518VHY\Desktop\pytorch\data', train=False, transform=data_tf)
train_data = DataLoader(train_set, 64, True, num_workers=4)
test_data = DataLoader(test_set, 128, False, num_workers=4)
class rnn_classify(nn.Module):
def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
super(rnn_classify, self).__init__()
self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers) # 使用两层 lstm
self.classifier = nn.Linear(hidden_feature, num_class) # 将最后一个 rnn 的输出使用全连接得到最后的分类结果
def forward(self, x):
'''
x 大小为 (batch, 1, 28, 28),所以我们需要将其转换成 RNN 的输入形式,即 (28, batch, 28)
'''
x = x.squeeze() # 去掉 (batch, 1, 28, 28) 中的 1,变成 (batch, 28, 28)
x = x.permute(2, 0, 1) # 将最后一维放到第一维,变成 (28, batch, 28)
out, _ = self.rnn(x) # 使用默认的隐藏状态,得到的 out 是 (28, batch, hidden_feature)
out = out[-1, :, :] # 取序列中的最后一个,大小是 (batch, hidden_feature)
out = self.classifier(out) # 得到分类结果
return out
net = rnn_classify()
criterion = nn.CrossEntropyLoss()
optimzier = torch.optim.Adadelta(net.parameters(), 1e-1)
from datetime import datetime
import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable
def get_acc(output, label):
total = output.shape[0]
_, pred_label = output.max(1)
num_correct = (pred_label == label).sum().item()
return num_correct / total
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
if torch.cuda.is_available():
net = net.cuda()
prev_time = datetime.now()
for epoch in range(num_epochs):
train_loss = 0
train_acc = 0
net = net.train()
for im, label in train_data:
if torch.cuda.is_available():
im = Variable(im.cuda()) # (bs, 3, h, w)
label = Variable(label.cuda()) # (bs, h, w)
else:
im = Variable(im)
label = Variable(label)
# forward
output = net(im)
loss = criterion(output, label)
# backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss += loss.item()
train_acc += get_acc(output, label)
cur_time = datetime.now()
h, remainder = divmod((cur_time - prev_time).seconds, 3600)
m, s = divmod(remainder, 60)
time_str = "Time %02d:%02d:%02d" % (h, m, s)
if valid_data is not None:
valid_loss = 0
valid_acc = 0
net = net.eval()
for im, label in valid_data:
if torch.cuda.is_available():
im = Variable(im.cuda(), volatile=True)
label = Variable(label.cuda(), volatile=True)
else:
im = Variable(im, volatile=True)
label = Variable(label, volatile=True)
output = net(im)
loss = criterion(output, label)
valid_loss += loss.item()
valid_acc += get_acc(output, label)
epoch_str = (
"Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
% (epoch, train_loss / len(train_data),
train_acc / len(train_data), valid_loss / len(valid_data),
valid_acc / len(valid_data)))
else:
epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
(epoch, train_loss / len(train_data),
train_acc / len(train_data)))
prev_time = cur_time
print(epoch_str + time_str)
train(net, train_data, test_data, 10, optimzier, criterion)
完整代码:
import torch
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader
from datetime import datetime
import torch.nn.functional as F
from torchvision import transforms as tfs
from torchvision.datasets import MNIST
# 定义数据
data_tf = tfs.Compose([
tfs.ToTensor(),
tfs.Normalize([0.5], [0.5]) # 标准化
])
train_set = MNIST(r'C:\Users\Administrator.SKY-20180518VHY\Desktop\pytorch\data', train=True, transform=data_tf)
test_set = MNIST(r'C:\Users\Administrator.SKY-20180518VHY\Desktop\pytorch\data', train=False, transform=data_tf)
train_data = DataLoader(train_set, 64, True, num_workers=4)
test_data = DataLoader(test_set, 128, False, num_workers=4)
# 定义模型
class rnn_classify(nn.Module):
def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
super(rnn_classify, self).__init__()
self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers) # 使用两层 lstm
self.classifier = nn.Linear(hidden_feature, num_class) # 将最后一个 rnn 的输出使用全连接得到最后的分类结果
def forward(self, x):
'''
x 大小为 (batch, 1, 28, 28),所以我们需要将其转换成 RNN 的输入形式,即 (28, batch, 28)
'''
x = x.squeeze() # 去掉 (batch, 1, 28, 28) 中的 1,变成 (batch, 28, 28)
x = x.permute(2, 0, 1) # 将最后一维放到第一维,变成 (28, batch, 28)
out, _ = self.rnn(x) # 使用默认的隐藏状态,得到的 out 是 (28, batch, hidden_feature)
out = out[-1, :, :] # 取序列中的最后一个,大小是 (batch, hidden_feature)
out = self.classifier(out) # 得到分类结果
return out
# 定义网络、损失函数、优化器
net = rnn_classify()
criterion = nn.CrossEntropyLoss()
optimzier = torch.optim.Adadelta(net.parameters(), 1e-1)
# 定义训练函数
def get_acc(output, label):
total = output.shape[0]
_, pred_label = output.max(1)
num_correct = (pred_label == label).sum().item()
return num_correct / total
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
if torch.cuda.is_available():
net = net.cuda()
prev_time = datetime.now()
for epoch in range(num_epochs):
train_loss = 0
train_acc = 0
net = net.train()
for im, label in train_data:
if torch.cuda.is_available():
im = Variable(im.cuda()) # (bs, 3, h, w)
label = Variable(label.cuda()) # (bs, h, w)
else:
im = Variable(im)
label = Variable(label)
# forward
output = net(im)
loss = criterion(output, label)
# backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss += loss.item()
train_acc += get_acc(output, label)
cur_time = datetime.now()
h, remainder = divmod((cur_time - prev_time).seconds, 3600)
m, s = divmod(remainder, 60)
time_str = "Time %02d:%02d:%02d" % (h, m, s)
if valid_data is not None:
valid_loss = 0
valid_acc = 0
net = net.eval()
for im, label in valid_data:
if torch.cuda.is_available():
im = Variable(im.cuda(), volatile=True)
label = Variable(label.cuda(), volatile=True)
else:
im = Variable(im, volatile=True)
label = Variable(label, volatile=True)
output = net(im)
loss = criterion(output, label)
valid_loss += loss.item()
valid_acc += get_acc(output, label)
epoch_str = (
"Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
% (epoch, train_loss / len(train_data),
train_acc / len(train_data), valid_loss / len(valid_data),
valid_acc / len(valid_data)))
else:
epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
(epoch, train_loss / len(train_data),
train_acc / len(train_data)))
prev_time = cur_time
print(epoch_str + time_str)
注:参考廖星宇大神的《深度学习之pytorch》