(PyTorch) Reproducing GoogLeNet on the CIFAR-10 Dataset

model.py

import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter


class GoogLeNet(nn.Module):
    def __init__(self, num_classes=1000, aux_logits=True, init_weights=False):
        super(GoogLeNet, self).__init__()
        self.aux_logits = aux_logits

        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)  # BasicConv2d: conv + ReLU, defined below
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
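        # ceil_mode=True rounds the output size up, e.g. ceil((112 - 3) / 2 + 1) = 56 instead of 55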

        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)  # Inception block, defined below
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        if self.aux_logits:
            self.aux1 = InceptionAux(512, num_classes)  # auxiliary classifier, defined below
            self.aux2 = InceptionAux(528, num_classes)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, num_classes)
        if init_weights:
            self._initialize_weights()

    # forward pass
    def forward(self, x):
        # N x 3 x 224 x 224
        x = self.conv1(x)  # N x 64 x 112 x 112
        x = self.maxpool1(x)  # N x 64 x 56 x 56

        x = self.conv2(x)  # N x 64 x 56 x 56
        x = self.conv3(x)  # N x 192 x 56 x 56
        x = self.maxpool2(x)  # N x 192 x 28 x 28

        x = self.inception3a(x)  # N x 256 x 28 x 28
        x = self.inception3b(x)  # N x 480 x 28 x 28
        x = self.maxpool3(x)  # N x 480 x 14 x 14

        x = self.inception4a(x)  # N x 512 x 14 x 14
        if self.training and self.aux_logits:  # skipped in eval mode
            aux1 = self.aux1(x)
        x = self.inception4b(x)  # N x 512 x 14 x 14
        x = self.inception4c(x)  # N x 512 x 14 x 14
        x = self.inception4d(x)  # N x 528 x 14 x 14
        if self.training and self.aux_logits:  # skipped in eval mode
            aux2 = self.aux2(x)
        x = self.inception4e(x)  # N x 832 x 14 x 14
        x = self.maxpool4(x)  # N x 832 x 7 x 7

        x = self.inception5a(x)  # N x 832 x 7 x 7
        x = self.inception5b(x)  # N x 1024 x 7 x 7

        x = self.avgpool(x)  # N x 1024 x 1 x 1
        x = torch.flatten(x, 1)  # N x 1024
        x = self.dropout(x)
        x = self.fc(x)  # N x 1000 (num_classes)

        if self.training and self.aux_logits:  # skipped in eval mode
            return x, aux2, aux1
        return x

    # weight initialization
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


# Inception block with four parallel branches
class Inception(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
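        # ch1x1: branch1 width; ch3x3red/ch3x3: 1x1 reduction then 3x3 conv (branch2);
        # ch5x5red/ch5x5: 1x1 reduction then 5x5 conv (branch3); pool_proj: 1x1 conv after max-pool (branch4)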
        super(Inception, self).__init__()

        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)

        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, kernel_size=1),
            BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)  # padding keeps the output size equal to the input size
        )

        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch5x5red, kernel_size=1),
            BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)  # padding keeps the output size equal to the input size
        )

        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, pool_proj, kernel_size=1)
        )

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch4(x)
        # concatenate the four branches along the channel dimension
        outputs = [branch1, branch2, branch3, branch4]
        return torch.cat(outputs, 1)


# Auxiliary classifier InceptionAux: avgpool + conv + fc1 + fc2
class InceptionAux(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)  # output[batch, 128, 4, 4]
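        # with a 14x14 input, AvgPool2d(5, stride=3) yields (14 - 5) // 3 + 1 = 4, so flatten gives 128 * 4 * 4 = 2048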
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        x = self.averagePool(x)  # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
        x = self.conv(x)  # N x 128 x 4 x 4
        x = torch.flatten(x, 1)
        x = F.dropout(x, 0.5, training=self.training)  # N x 2048
        x = F.relu(self.fc1(x), inplace=True)
        x = F.dropout(x, 0.5, training=self.training)  # N x 1024
        x = self.fc2(x)  # N x num_classes
        return x

# writer = SummaryWriter("model_logs")


# BasicConv2d: conv + ReLU
class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x


# model = GoogLeNet()
# writer.add_graph(model,torch.rand(1, 3, 224, 224))
# writer.close()
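
A quick way to verify the listing above is to push a dummy batch through the network and check the output shapes. The following is a minimal sketch (assuming model.py is on the import path), not part of the original scripts:

import torch
from model import GoogLeNet

model = GoogLeNet(num_classes=10, aux_logits=True, init_weights=True)
x = torch.rand(2, 3, 224, 224)  # dummy batch of two 224x224 RGB images

model.train()
logits, aux2, aux1 = model(x)  # training mode returns main + two auxiliary outputs
print(logits.shape, aux2.shape, aux1.shape)  # each torch.Size([2, 10])

model.eval()
with torch.no_grad():
    out = model(x)  # eval mode returns only the main classifier output
print(out.shape)  # torch.Size([2, 10])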

train.py

import os
import sys
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torchvision
import json
import torch.optim as optim
from tqdm import tqdm
from model import GoogLeNet

# use the GPU if available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("using {} device.".format(device))

data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    "val": transforms.Compose([transforms.Resize((224, 224)),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}
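# Note: CIFAR-10 images are only 32x32; RandomResizedCrop / Resize scale them up to
# the 224x224 input size that this GoogLeNet implementation expects.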

train_dataset = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=data_transform["train"])
train_num = len(train_dataset)
# class_to_idx: {'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4,
#                'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
class_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in class_list.items())
# write the index-to-class mapping to a json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

batch_size = 16
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)

val_dataset = torchvision.datasets.CIFAR10(root='data', train=False, transform=data_transform["val"])
val_num = len(val_dataset)
val_loader = torch.utils.data.DataLoader(val_dataset,
                                         batch_size=batch_size, shuffle=False,
                                         num_workers=0)
val_steps = len(val_loader)
# model: CIFAR-10 has 10 classes; keep both auxiliary classifiers for training
net = GoogLeNet(num_classes=10, aux_logits=True, init_weights=True)
net.to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0003)

train_steps = len(train_loader)
writer = SummaryWriter("logs")
# start training
print('Start training...')
save_path = 'googleNet.pth'
epochs = 100

initepoch = 1  # start from epoch 1 when there is no checkpoint (or resume is disabled)
resume = True  # whether to continue training from the last saved state
if resume:
    if os.path.isfile(save_path):
        print("Resume from checkpoint...")
        checkpoint = torch.load(save_path)
        net.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # checkpoint['epoch'] is the last finished (0-indexed) epoch, so resume at the next 1-indexed one
        initepoch = checkpoint['epoch'] + 2
        print("====>loaded checkpoint (epoch {})".format(checkpoint['epoch'] + 1))
    else:
        print("====>no checkpoint found.")

for epoch in range(initepoch - 1, epochs):
    # train
    print("-------- starting epoch {} --------".format(epoch + 1))
    train_bar = tqdm(train_loader, file=sys.stdout)
    net.train()
    train_acc = 0.0
    running_loss = 0.0
    for step, data in enumerate(train_bar, start=0):
        images, labels = data
        optimizer.zero_grad()  # clear gradients from the previous step
        # forward pass; in training mode the model returns three outputs: main classifier + 2 auxiliary classifiers
        logits, aux_logits2, aux_logits1 = net(images.to(device))
        # compute the loss of each of the three classifiers
        loss0 = loss_function(logits, labels.to(device))
        loss1 = loss_function(aux_logits1, labels.to(device))
        loss2 = loss_function(aux_logits2, labels.to(device))
        loss = loss0 + loss1 * 0.3 + loss2 * 0.3  # the 0.3 weight on the auxiliary losses follows the paper
        loss.backward()  # backpropagate to compute gradients
        optimizer.step()  # update the parameters

        running_loss += loss.item()  # accumulate the loss
        # report training progress
        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)
        predict = torch.max(logits, dim=1)[1]
        train_acc += torch.eq(predict, labels.to(device)).sum().item()
    train_loss = running_loss / train_steps
    train_accurate = train_acc / train_num
    # validate
    running_loss = 0.0
    net.eval()
    val_acc = 0.0  # accumulate accurate number / epoch
    with torch.no_grad():
        val_bar = tqdm(val_loader, file=sys.stdout)
        for step, val_data in enumerate(val_bar):
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))  # in eval mode only the main classifier output is returned
            loss = loss_function(outputs, val_labels.to(device))
            running_loss += loss.item()
            predict_y = torch.max(outputs, dim=1)[1]
            val_acc += (predict_y == val_labels.to(device)).sum().item()
    val_accurate = val_acc / val_num
    val_loss = running_loss / val_steps
    print('[epoch %d] train_loss: %.3f val_loss: %.3f train_accuracy: %.3f val_accuracy: %.3f' %
          (epoch + 1, train_loss, val_loss, train_accurate, val_accurate))
    writer.add_scalars('loss',
                       {'train': train_loss, 'val': val_loss}, global_step=epoch)
    writer.add_scalars('acc',
                       {'train': train_accurate, 'val': val_accurate}, global_step=epoch)
    # save a checkpoint after every epoch
    checkpoint = {"model_state_dict": net.state_dict(),
                  "optimizer_state_dict": optimizer.state_dict(),
                  "epoch": epoch}
    torch.save(checkpoint, save_path)
    print("checkpoint saved")


print('Finished Training')

writer.close()

The script saves a checkpoint every epoch and can resume training from it, and the logged curves can be inspected with TensorBoard (tensorboard --logdir logs).
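
As an illustration, a checkpoint saved by train.py can be loaded for single-image inference roughly as follows. This is a minimal sketch, not part of the original post; the image path test.jpg is a placeholder:

import json
import torch
from PIL import Image
from torchvision import transforms
from model import GoogLeNet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# the same preprocessing as the "val" transform in train.py
transform = transforms.Compose([transforms.Resize((224, 224)),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

img = transform(Image.open("test.jpg").convert("RGB")).unsqueeze(0)  # placeholder image path

with open('class_indices.json', 'r') as f:
    cla_dict = json.load(f)  # {"0": "airplane", "1": "automobile", ...}

net = GoogLeNet(num_classes=10, aux_logits=True).to(device)
net.load_state_dict(torch.load('googleNet.pth', map_location=device)['model_state_dict'])
net.eval()
with torch.no_grad():
    output = net(img.to(device))
    predict = torch.max(output, dim=1)[1].item()
print(cla_dict[str(predict)])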
