PyTorch入门程序

转载自我的个人网站 https://wzw21.cn/2022/02/20/hello-pytorch/

在 PyTorch For Audio and Music Processing 入门代码的基础上添加了一些注释和新的内容

  1. Download dataset
  2. Create data loader
  3. Build model
  4. Train
  5. Save trained model
  6. Load model
  7. Predict
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
def download_mnist_datasets(root="data", download=True):
  """Return the MNIST train split and the MNIST test split (used here for validation).

  Args:
    root: directory torchvision uses to store / look up the MNIST files.
      Defaults to "data", the location that used to be hard-coded.
    download: fetch the files into ``root`` when they are not already there.

  Returns:
    (train_data, val_data) torchvision MNIST datasets whose images are passed
    through ``ToTensor()``.
  """
  def _load(train):
    # One place for the shared constructor arguments; only the split differs.
    return datasets.MNIST(
        root=root,
        download=download,
        train=train,
        transform=ToTensor(),
    )

  train_data = _load(train=True)
  val_data = _load(train=False)  # the MNIST test split serves as the validation set
  return train_data, val_data
class SimpleNet(nn.Module):
  """Two-layer fully connected classifier for 28x28 inputs with 10 output classes.

  ``forward`` returns raw logits (unnormalised class scores). That is what
  ``nn.CrossEntropyLoss`` expects, because the loss applies log-softmax itself.
  Returning softmax probabilities from ``forward`` and then using that loss
  would apply softmax twice. The gradients become very flat, and the loss
  cannot fall below log(1 + 9/e) ≈ 1.46 for 10 classes.
  Call ``predict_proba`` when actual probabilities are needed.
  """

  def __init__(self):
    super().__init__()
    self.flatten = nn.Flatten()
    self.dense_layers = nn.Sequential(
        nn.Linear(28*28, 256),  # fully connected layer (in_features, out_features)
        nn.ReLU(),
        nn.Linear(256, 10),
    )
    # nn.Softmax has no parameters, so keeping it leaves the state_dict keys
    # unchanged (old checkpoints still load); it is only used by predict_proba.
    self.softmax = nn.Softmax(dim=1)

  def forward(self, input_data):
    """Flatten ``input_data`` and return per-class logits of shape (batch, 10)."""
    flattened_data = self.flatten(input_data)
    return self.dense_layers(flattened_data)

  def predict_proba(self, input_data):
    """Return softmax class probabilities (each row sums to 1)."""
    return self.softmax(self(input_data))

PyTorch needs more code for this than TensorFlow 2.x or Keras!

def train_one_epoch(model, data_loader, loss_fn, optimizer, device):
  """Run one optimisation pass over ``data_loader`` and report loss and accuracy.

  Args:
    model: module mapping a batch of inputs to per-class scores (batch, num_classes).
    data_loader: iterable of (inputs, targets) batches.
    loss_fn: loss taking (scores, targets); assumed to use mean reduction
      (the ``nn.CrossEntropyLoss`` default).
    optimizer: optimizer over ``model``'s parameters.
    device: device the batches are moved to before the forward pass.

  Returns:
    (mean_loss, accuracy) as Python floats over every sample seen, or
    (nan, nan) when the loader yields no samples.
  """
  model.train()  # switch to training mode
  loss_sum = 0.
  correct = 0
  num_samples = 0
  for inputs, targets in data_loader:
    inputs, targets = inputs.to(device), targets.to(device)

    # forward pass; calling the module runs forward() plus any hooks
    predictions = model(inputs)
    loss = loss_fn(predictions, targets)

    # backpropagate loss and update weights
    optimizer.zero_grad()  # reset grads
    loss.backward()  # calculate grads
    optimizer.step()  # update weights

    batch_size = targets.size(0)
    # Weight the batch-mean loss by batch size so a short final batch does not
    # skew the epoch average.
    loss_sum += loss.item() * batch_size
    # .item() keeps the counter a plain Python int instead of a device tensor.
    correct += (predictions.detach().argmax(dim=1) == targets).sum().item()
    num_samples += batch_size

  if num_samples:
    mean_loss = loss_sum / num_samples
    accuracy = correct / num_samples
  else:
    mean_loss = accuracy = float("nan")
  print(f"Train loss: {mean_loss:.4f}, train accuracy: {accuracy:.4f}")
  return mean_loss, accuracy

def val_one_epoch(model, data_loader, loss_fn, device):
  """Evaluate ``model`` on ``data_loader`` without updating it; report loss and accuracy.

  Args:
    model: module mapping a batch of inputs to per-class scores (batch, num_classes).
    data_loader: iterable of (inputs, targets) batches.
    loss_fn: loss taking (scores, targets); assumed to use mean reduction
      (the ``nn.CrossEntropyLoss`` default).
    device: device the batches are moved to before the forward pass.

  Returns:
    (mean_loss, accuracy) as Python floats over every sample seen, or
    (nan, nan) when the loader yields no samples.
  """
  model.eval()  # switch to eval mode
  loss_sum = 0.
  correct = 0
  num_samples = 0
  with torch.no_grad():  # no gradients are needed for evaluation
    for inputs, targets in data_loader:
      inputs, targets = inputs.to(device), targets.to(device)

      predictions = model(inputs)
      loss = loss_fn(predictions, targets)

      batch_size = targets.size(0)
      # Weight the batch-mean loss by batch size so a short final batch does
      # not skew the average.
      loss_sum += loss.item() * batch_size
      # .item() keeps the counter a plain Python int instead of a device tensor.
      correct += (predictions.argmax(dim=1) == targets).sum().item()
      num_samples += batch_size

  if num_samples:
    mean_loss = loss_sum / num_samples
    accuracy = correct / num_samples
  else:
    mean_loss = accuracy = float("nan")
  print(f"Validation loss: {mean_loss:.4f}, validation accuracy: {accuracy:.4f}")
  return mean_loss, accuracy

def train(model, train_data_loader, val_data_loader, loss_fn, optimizer, device, epochs):
  """Fit ``model`` for ``epochs`` epochs, validating after each training pass.

  Progress is printed per epoch (1-based numbering), followed by a final
  completion message.
  """
  for epoch_number in range(1, epochs + 1):
    print(f"Epoch {epoch_number}")
    # one optimisation pass, then a gradient-free evaluation pass
    train_one_epoch(model, train_data_loader, loss_fn, optimizer, device)
    val_one_epoch(model, val_data_loader, loss_fn, device)
    print("-----------------------")
  print("Training is done")
def predict(model, input, target, class_mapping):
  """Classify a single sample and translate predicted/expected indices into labels.

  Args:
    model: classifier returning scores of shape (1, num_classes) for ``input``.
    input: one sample with a leading dimension of size 1. In this script it is
      an MNIST image of shape torch.Size([1, 28, 28]).
    target: expected class index, as an int or a 0-dim integer tensor.
    class_mapping: sequence mapping class index -> label.

  Returns:
    (predicted_label, expected_label).
  """
  model.eval()  # change to eval mode
  # Move the sample to the device holding the model's weights, so a
  # GPU-resident model also works with a CPU tensor. A model without
  # parameters leaves the input where it is.
  first_param = next(model.parameters(), None)
  if first_param is not None:
    input = input.to(first_param.device)
  with torch.no_grad():  # don't have to calculate grads here
    predictions = model(input)  # shape (1, num_classes)
    predicted_index = predictions[0].argmax(0).item()
  predicted = class_mapping[predicted_index]
  expected = class_mapping[int(target)]
  return predicted, expected
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

# Training hyper-parameters.
BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 1e-3

# Label for each class index: index i maps to the digit string str(i).
class_mapping = [str(digit) for digit in range(10)]
Using cuda device
# Load the MNIST train/validation splits, downloading them if needed.
train_data, val_data = download_mnist_datasets()
print("Dataset downloaded")

# Data loaders that batch the samples. The training loader reshuffles every
# epoch, so each epoch sees different mini-batches rather than one fixed order.
# The validation loader keeps a fixed order because it is only evaluated.
train_data_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
val_data_loader = DataLoader(val_data, batch_size=BATCH_SIZE)
Dataset downloaded
# Instantiate the network on the chosen device, together with its objective and optimiser.
simple_net = SimpleNet().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=simple_net.parameters(), lr=LEARNING_RATE)

# Fit the model, running a validation pass after every epoch.
train(
    model=simple_net,
    train_data_loader=train_data_loader,
    val_data_loader=val_data_loader,
    loss_fn=loss_fn,
    optimizer=optimizer,
    device=device,
    epochs=EPOCHS,
)

# Persist only the learned parameters; a SimpleNet must be re-created before loading them.
torch.save(simple_net.state_dict(), "simple_net.pth")
print("Model saved")
# Other saving options:
#   torch.save(simple_net, "simple_net.pth")  # pickle the whole module object
#   torch.save({"net": simple_net.state_dict(),
#               "optimizer": optimizer.state_dict(),
#               "epoch": EPOCHS}, "checkpoint.pth")  # resumable training checkpoint
Epoch 1
Train loss: 1.5717, train accuracy: 0.9036
Validation loss: 1.5280, validation accuracy: 0.9388
-----------------------
Epoch 2
Train loss: 1.5148, train accuracy: 0.9506
Validation loss: 1.5153, validation accuracy: 0.9507
-----------------------
Epoch 3
Train loss: 1.5008, train accuracy: 0.9629
Validation loss: 1.5016, validation accuracy: 0.9625
-----------------------
Epoch 4
Train loss: 1.4924, train accuracy: 0.9707
Validation loss: 1.4958, validation accuracy: 0.9680
-----------------------
Epoch 5
Train loss: 1.4871, train accuracy: 0.9760
Validation loss: 1.4919, validation accuracy: 0.9702
-----------------------
Epoch 6
Train loss: 1.4837, train accuracy: 0.9789
Validation loss: 1.4884, validation accuracy: 0.9742
-----------------------
Epoch 7
Train loss: 1.4811, train accuracy: 0.9814
Validation loss: 1.4885, validation accuracy: 0.9736
-----------------------
Epoch 8
Train loss: 1.4787, train accuracy: 0.9837
Validation loss: 1.4896, validation accuracy: 0.9724
-----------------------
Epoch 9
Train loss: 1.4771, train accuracy: 0.9851
Validation loss: 1.4884, validation accuracy: 0.9739
-----------------------
Epoch 10
Train loss: 1.4758, train accuracy: 0.9863
Validation loss: 1.4889, validation accuracy: 0.9732
-----------------------
Training is done
Model saved
# Rebuild the architecture and restore the trained weights saved earlier.
reloaded_simple_net = SimpleNet()
# map_location="cpu" lets a checkpoint written from a CUDA model load on a
# CPU-only machine; the freshly built model above lives on the CPU anyway.
state_dict = torch.load("simple_net.pth", map_location="cpu")
reloaded_simple_net.load_state_dict(state_dict)

# Classify the first validation sample. The names avoid shadowing the built-in input().
sample_input, sample_target = val_data[0]
predicted, expected = predict(reloaded_simple_net, sample_input, sample_target, class_mapping)
print(f"Predicted: '{predicted}', expected: '{expected}'")
Predicted: '7', expected: '7'

你可能感兴趣的:(操作步骤与实践,pytorch,深度学习,python)