Table of Contents
- Experiment Requirements
- 1. Loading the Fashion-MNIST Dataset
- 2. Reading Mini-Batches with DataLoader
- 3. Building the Model
- 4. Loss Function and Optimizer
- 5. Accuracy and Loss on the Test Set
- 6. Model Training and Testing
- Experiment Results
Experiment Requirements
- Use torch.nn to implement softmax regression, and train and test it on the Fashion-MNIST dataset.
- Analyze the results from several angles, such as the loss and the accuracy on both the training set and the test set.
1. Loading the Fashion-MNIST Dataset
import torch
import torchvision
import torchvision.transforms as transforms
from torch import optim
from tqdm import tqdm

mnist_train = torchvision.datasets.FashionMNIST(
    root='~/Datasets/FashionMNIST',
    train=True,
    download=True,
    transform=transforms.ToTensor()
)
mnist_test = torchvision.datasets.FashionMNIST(
    root='~/Datasets/FashionMNIST',
    train=False,
    download=True,
    transform=transforms.ToTensor()
)
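As a quick sanity check (my addition, not required by the experiment), the snippet below prints the dataset sizes and inspects one sample; after ToTensor each image is a 1x28x28 float tensor with pixel values in [0, 1], and the label is an integer class index.

print(len(mnist_train), len(mnist_test))    # 60000 10000
feature, label = mnist_train[0]
print(feature.shape, feature.dtype, label)  # torch.Size([1, 28, 28]) torch.float32, label in 0-9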
2. Reading Mini-Batches with DataLoader
batch_size = 256
train_iter = torch.utils.data.DataLoader(
    mnist_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0
)
test_iter = torch.utils.data.DataLoader(
    mnist_test,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0
)
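To confirm the mini-batch pipeline works as intended, one can pull a single batch from train_iter (a small check of my own, assuming the loaders above have been created); with batch_size = 256 the features come out as (256, 1, 28, 28) and the labels as (256,).

X, y = next(iter(train_iter))
print(X.shape, y.shape)  # torch.Size([256, 1, 28, 28]) torch.Size([256])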
3. Building the Model
num_inputs = 784
num_outputs = 10

class softmaxnet(torch.nn.Module):
    def __init__(self, n_features, n_labels):
        super(softmaxnet, self).__init__()
        self.linear = torch.nn.Linear(n_features, n_labels)

    def softmax(self, X):
        X_exp = X.exp()
        partition = X_exp.sum(dim=1, keepdim=True)
        return X_exp / partition

    def forward(self, x):
        x_ = x.view((-1, num_inputs))  # flatten each 1x28x28 image into a 784-dim vector
        y_ = self.linear(x_)
        y_hat = self.softmax(y_)
        return y_hat
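The softmax method computes softmax(x)_j = exp(x_j) / sum_k exp(x_k) along each row, so every output row is a probability distribution over the 10 classes. The following sketch (my addition, using a throwaway instance) verifies this on random input:

check_net = softmaxnet(num_inputs, num_outputs)  # temporary instance just for this check
out = check_net(torch.rand(4, 1, 28, 28))
print(out.shape)       # torch.Size([4, 10])
print(out.sum(dim=1))  # each entry is (numerically) 1.0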
4. Loss Function and Optimizer
net = softmaxnet(num_inputs, num_outputs)
lr = 0.1
loss = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=lr)
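Note that torch.nn.CrossEntropyLoss combines LogSoftmax and NLLLoss internally and expects raw, unnormalized scores (logits). Because forward above already applies softmax, the loss here is computed on probabilities rather than logits; the model can still learn, but convergence is typically slower than with plain logits. A common alternative, shown below only as a hypothetical sketch and not the setup used in this experiment, is to return the linear output directly and apply softmax afterwards only when probabilities are needed:

class softmaxnet_logits(torch.nn.Module):  # hypothetical variant, not used in this experiment
    def __init__(self, n_features, n_labels):
        super(softmaxnet_logits, self).__init__()
        self.linear = torch.nn.Linear(n_features, n_labels)

    def forward(self, x):
        # return raw logits; CrossEntropyLoss applies log-softmax itself
        return self.linear(x.view((-1, num_inputs)))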
5. Accuracy and Loss on the Test Set
def get_test_info(data_iter, net):
    right_count, all_count, loss_sum = 0.0, 0, 0.0
    with torch.no_grad():  # no gradients are needed during evaluation
        for x, y in data_iter:
            y_ = net(x)
            l = loss(y_, y)
            right_count += (y_.argmax(dim=1) == y).sum().item()
            loss_sum += l.item() * y.shape[0]
            all_count += y.shape[0]
    return right_count / all_count, loss_sum / all_count
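As a point of reference (my addition), calling this function before any training reports the performance of the randomly initialized model; with 10 balanced classes the accuracy should be roughly 0.1.

acc0, loss0 = get_test_info(test_iter, net)
print('initial test acc %.3f, test loss %.4f' % (acc0, loss0))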
6. Model Training and Testing
num_epoch = 5
for epoch in range(num_epoch):
    train_r_num, train_all_num, train_loss_sum = 0.0, 0, 0.0
    for X, y in tqdm(train_iter):
        y_ = net(X)
        l = loss(y_, y)
        l.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_r_num += (y_.argmax(dim=1) == y).sum().item()
        train_all_num += y.shape[0]
        train_loss_sum += l.item() * y.shape[0]
    test_acc, test_ave_loss = get_test_info(test_iter, net)
    print('epoch %d, train loss %.4f, train acc %.3f' % (epoch + 1, train_loss_sum / train_all_num, train_r_num / train_all_num))
    print('test loss %.4f, test acc %.3f' % (test_ave_loss, test_acc))
Experiment Results