import torch
import torch.nn as nn
import torchvision.models as models
import torch.optim as optim
import torchvision
import matplotlib.pyplot as plt
train_on_gpu = torch.cuda.is_available()
if not train_on_gpu:
    print('CUDA is not available. Training on CPU...')
else:
    print('CUDA is available! Training on GPU...')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 512
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize((0.1307,), (0.3081,))])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist data', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize((0.1307,), (0.3081,))])),
    batch_size=batch_size, shuffle=False)
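# Optional sanity check on the loaders (illustrative only, safe to remove):
# one batch should come out as [batch_size, 1, 28, 28] images plus labels,
# which is the 1-channel input the modified ResNet below expects.
sample_images, sample_labels = next(iter(train_loader))
print('sample batch:', sample_images.shape, sample_labels.shape)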
model = models.resnet18()
num_classes = 10
# Freeze the backbone so only the layers replaced below are trained.
# NOTE: resnet18() here is randomly initialized; if the freezing is meant for
# transfer learning, load pretrained weights instead.
for param in model.parameters():
    param.requires_grad = False
# Replace the first conv layer so the network accepts 1-channel MNIST images
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
# Replace the classifier head; LogSoftmax makes the model output log-probabilities
model.fc = nn.Sequential(
    nn.Dropout(),
    nn.Linear(model.fc.in_features, num_classes),
    nn.LogSoftmax(dim=1)
)
model.to(device)
# print(model)
optimizer = optim.Adam(model.parameters(), lr=0.001)
# The model ends with LogSoftmax, so the matching loss is NLLLoss
# (CrossEntropyLoss would apply log-softmax a second time)
criterion = nn.NLLLoss()
filename = "recognize_handwritten_digits.pt"
def save_checkpoint(epoch, model, optimizer, loss, filename):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    torch.save(checkpoint, filename)
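# A matching loader for the checkpoint format above; a minimal sketch, assuming
# training is resumed with the same model and optimizer classes.
def load_checkpoint(filename, model, optimizer):
    checkpoint = torch.load(filename, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Return the epoch and loss that were saved alongside the weights
    return checkpoint['epoch'], checkpoint['loss']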
num_epochs = 50
train_loss = []
for epoch in range(num_epochs):
    model.train()  # enable dropout during training
    running_loss = 0
    correct = 0
    total = 0
    for i, (inputs, labels) in enumerate(train_loader):
        # Move the data to the device
        inputs, labels = inputs.to(device), labels.to(device)
        # Forward pass
        outputs = model(inputs)
        # Compute the loss and gradients
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        # Update the model parameters
        optimizer.step()
        # Record the loss and accuracy
        running_loss += loss.item()
        train_loss.append(loss.item())
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
    accuracy_train = 100 * correct / total
    # Compute accuracy on the test set
    model.eval()  # disable dropout during evaluation
    with torch.no_grad():
        running_loss_test = 0
        correct_test = 0
        total_test = 0
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss_test += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct_test += (predicted == labels).sum().item()
            total_test += labels.size(0)
        accuracy_test = 100 * correct_test / total_test
    # Print the loss and accuracy for each epoch
    print("Epoch [{}/{}], Train Loss: {:.4f}, Train Accuracy: {:.2f}%, Test Loss: {:.4f}, Test Accuracy: {:.2f}%"
          .format(epoch + 1, num_epochs, running_loss / len(train_loader),
                  accuracy_train, running_loss_test / len(test_loader), accuracy_test))
    save_checkpoint(epoch, model, optimizer, running_loss / len(train_loader), filename)
plt.plot(train_loss, label='Train Loss')
# Add legend and labels
plt.legend()
plt.xlabel('Iteration (batch)')
plt.ylabel('Loss')
plt.title('Training Loss')
# Display the figure
plt.show()
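# Illustrative follow-up (a sketch, assuming the training run above has finished):
# reload the last checkpoint and classify a single test digit. The model outputs
# log-probabilities, so argmax over dim 1 gives the predicted class.
checkpoint = torch.load(filename, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
test_set = torchvision.datasets.MNIST(
    'mnist data', train=False, download=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,))]))
image, label = test_set[0]
with torch.no_grad():
    log_probs = model(image.unsqueeze(0).to(device))
predicted_digit = log_probs.argmax(dim=1).item()
print('true label:', label, 'predicted:', predicted_digit)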