pytorch框架--网络方面--完整训练示例

完整训练示例

使用pytorch自带数据集,构建简单网络,进行训练

import torch
import torchvision
from torch import nn
# 导入记好了,2维卷积,2维最大池化,展成1维,全连接层,构建网络结构辅助工具
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# 加载数据
# 参数:下载保存路径、train=训练集(True)或者测试集(False)、download=在线(True) 或者 本地(False)、数据类型转换
test_data = torchvision.datasets.CIFAR10("./dataset",
                                         train=False,
                                         download=True,
                                         transform=torchvision.transforms.ToTensor())
# # 格式打包
# # 参数:数据、1组几个、下一轮是否打乱、进程个数、最后一组是否凑成一组
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=True)


class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = Sequential(
            # 输入, 输出, 卷积核、补几圈零
            Conv2d(3, 32, (5, 5), padding=2),
            # 池化核
            MaxPool2d(2),
            Conv2d(32, 32, (5, 5), padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, (5, 5), padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10))

    def forward(self, x):
        x = self.model1(x)
        return x


if __name__ == '__main__':
    # #测试
    # tudui = Tudui()
    # # 验证网络 须知输入图像,设定全1矩阵测试
    # input = torch.ones((64, 3, 32, 32))
    # output = tudui(input)
    # print(output.shape)
    # # 绘制网络结构图
    # writer = SummaryWriter("log")
    # # 参数:网络结构对象、输入图像矩阵
    # writer.add_graph(tudui, input)
    # writer.close()

    # 模型加载
    tudui = Tudui()
    # 损失函数
    loss = nn.CrossEntropyLoss()
    # 优化器
    optim = torch.optim.Adam(tudui.parameters(), lr=0.0001)
    for epoch in range(20):
        # 每一轮损失
        running_loss = 0.0
        for data in test_loader:
            # 加载数据
            imgs, targets = data
            # 将数据放入网络
            outputs = tudui(imgs)

            # 损失函数:网络输出(预测)、标签
            result_loss = loss(outputs, targets)

            # 优化器 梯度清零
            optim.zero_grad()
            # 反向传播
            result_loss.backward()
            # 调用优化器
            optim.step()

            # 累计损失
            running_loss += result_loss
            # \r{} 可不换行直接显示,加上[]是为了好看一点
            print("\r[{}]".format(result_loss), end="")
        print(running_loss)

你可能感兴趣的:(Pytorch框架,python,pytorch,计算机视觉)