EfficientNet的PyTorch版本的使用和训练方法

EfficientNet的PyTorch版本的训练和评估使用方法

1. EfficientNet 的PyTorch版本的测试和使用

第三方PyTorch代码

# pytorch 的efficientNet安装
Install via pip:
pip install efficientnet_pytorch

Or install from source:
git clone https://github.com/lukemelas/EfficientNet-PyTorch
cd EfficientNet-PyTorch
pip install -e .
# 使用示例
Usage

Loading pretrained models

Load an EfficientNet:
# Build the EfficientNet-B0 architecture by name.
# NOTE(review): presumably no pretrained weights are loaded here, in contrast
# to `from_pretrained` used in the next example - confirm in the library docs.
from efficientnet_pytorch import EfficientNet
model = EfficientNet.from_name('efficientnet-b0')

Load a pretrained EfficientNet:
# Build EfficientNet-B0 through `from_pretrained`, i.e. with pretrained weights.
from efficientnet_pytorch import EfficientNet
model = EfficientNet.from_pretrained('efficientnet-b0')
# Example: change the number of output channels from 1000 to 50.
from efficientnet_pytorch import EfficientNet
from torch import nn
model = EfficientNet.from_pretrained('efficientnet-b5')
# Read the input width of the existing final fully-connected layer ...
feature = model._fc.in_features
# ... and replace that layer with a new one that has 50 outputs.
model._fc = nn.Linear(in_features=feature,out_features=50,bias=True)
# Print the module structure so the new head can be inspected.
print(model)

2.提取特征使用

提取特征时可以使用model.extract_features来实现
from efficientnet_pytorch import EfficientNet
model = EfficientNet.from_pretrained('efficientnet-b0')

# ... image preprocessing as in the classification example ...
# NOTE: `img` is not defined in this snippet; it has to be produced by the
# preprocessing step referred to above.
print(img.shape) # torch.Size([1, 3, 224, 224])

# Run the model's feature extractor on the image batch and inspect the result.
features = model.extract_features(img)
print(features.shape) # torch.Size([1, 1280, 7, 7])

3.EfficientNet的pytorch版本的训练方法

# 训练代码
from torchvision import datasets, transforms
import torch
from torch import nn
import torch.optim as optim
import argparse
import warnings
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data.dataloader import default_collate  
from efficientnet_pytorch import EfficientNet #EfficientNet的使用需要倒入的库
from label_smooth import LabelSmoothSoftmaxCE
import os
# Restrict this process to GPUs 0 and 1.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
# Suppress all Python warnings.
warnings.filterwarnings("ignore")

# --- Training schedule -------------------------------------------------------
# Total number of epochs to train for.
EPOCH = 150
# Number of passes over the dataset already completed (epoch to start from).
pre_epoch = 0
# Mini-batch size.
batch_size = 20

# --- Task setup --------------------------------------------------------------
# Number of target classes.
num_classes = 50
# True selects feature extraction, False selects fine-tuning.
feature_extract = False
# Root directory of the dataset.
data_dir = "../data/"

def my_collate_fn(batch):
    """Collate a batch after dropping samples whose data is ``None``.

    Args:
        batch: sequence of samples, each shaped like ``(data, label)``.

    Returns:
        The result of ``default_collate`` on the remaining samples, or an
        empty ``torch.Tensor`` when no sample is left.
    """
    usable = [sample for sample in batch if sample[0] is not None]
    if not usable:
        return torch.Tensor()
    # Stack the surviving samples the standard way.
    return default_collate(usable)


# Side length (in pixels) that images are resized / cropped to by the transforms.
input_size = 380
# Load EfficientNet-B4 with pretrained weights.
net = EfficientNet.from_pretrained('efficientnet-b4')
# Replace the classification head with a new layer that has `num_classes`
# outputs. Assigning to `net._fc.out_features` is not enough: it only changes
# an attribute, the layer's weight matrix keeps its pretrained shape, and the
# network would keep producing the original number of logits.
net._fc = nn.Linear(in_features=net._fc.in_features, out_features=num_classes, bias=True)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Spread each batch over all visible GPUs when more than one is available.
if torch.cuda.device_count() > 1:
    net = nn.DataParallel(net)
# To resume from a saved checkpoint, load it here:
# net.load_state_dict(torch.load('./model/net_pre.pth'))
net = net.to(device)
# 数据预处理部分
# Image pre-processing pipelines, keyed by dataset split.
data_transforms = {}

# Training: resize and centre-crop to the network input size, apply a small
# random translation and a random horizontal flip, then convert to a
# normalised tensor.
data_transforms['train'] = transforms.Compose([
    transforms.Resize(input_size),
    transforms.CenterCrop(input_size),
    transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Validation: same resize / crop / normalisation, without random augmentation.
data_transforms['val'] = transforms.Compose([
    transforms.Resize(input_size),
    transforms.CenterCrop(input_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# One ImageFolder dataset per split, read from <data_dir>/train and <data_dir>/val.
image_datasets = {
    split: datasets.ImageFolder(os.path.join(data_dir, split), data_transforms[split])
    for split in ('train', 'val')
}
# One shuffled DataLoader per split; samples whose data is None are dropped by
# the custom collate function.
dataloaders_dict = {
    split: torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=12,
        collate_fn=my_collate_fn,
    )
    for split, dataset in image_datasets.items()
}
# Mapping from class name to label index for the training set.
b = image_datasets['train'].class_to_idx

# Command-line interface: --outf is the folder that receives the outputs.
parser = argparse.ArgumentParser(description='PyTorch DeepNetwork Training')
parser.add_argument(
    '--outf',
    default='./model/model',
    help='folder to output images and model checkpoints',
)

args = parser.parse_args()
# Fine-tuning hands the optimiser every parameter of the network; feature
# extraction collects only the parameters that still require gradients.
params_to_update = [] if feature_extract else net.parameters()

print("Params to learn:")
for name, param in net.named_parameters():
    if not param.requires_grad:
        continue
    if feature_extract:
        params_to_update.append(param)
    print("\t", name)

def main():
    """Train the module-level ``net`` and validate it after every epoch.

    Relies on names defined at module level: ``net``, ``device``,
    ``dataloaders_dict``, ``params_to_update``, ``args``, ``pre_epoch`` and
    ``EPOCH``.

    Side effects:
        * ``./txt/log.txt`` - running loss / accuracy, one line per training batch.
        * ``./txt/acc.txt`` - validation accuracy, one line per epoch.
        * ``./txt/best_acc.txt`` - rewritten whenever validation accuracy improves.
        * ``<args.outf>/net_XXX.pth`` - ``state_dict`` checkpoint after every epoch.
    """
    LR = 1e-3  # initial learning rate
    best_acc = 0  # best validation accuracy (percent) seen so far
    print("Start Training, DeepNetwork!")

    # Create the output folders up front so that a missing directory cannot
    # abort the run after an epoch of training has already been spent.
    for out_dir in ("./txt", args.outf):
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    # criterion
    criterion = LabelSmoothSoftmaxCE()
    # optimizer
    optimizer = optim.Adam(params_to_update, lr=LR, betas=(0.9, 0.999), eps=1e-9)
    # scheduler: mode='max' because it is stepped with the validation accuracy
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.7, patience=3, verbose=True)

    train_loader = dataloaders_dict['train']
    val_loader = dataloaders_dict['val']
    # Batches per epoch; only used to derive the global iteration counter.
    length = len(train_loader)

    with open("./txt/acc.txt", "w") as f, open("./txt/log.txt", "w") as f2:
        for epoch in range(pre_epoch, EPOCH):
            print('\nEpoch: %d' % (epoch + 1))
            net.train()
            sum_loss = 0.0
            correct = 0
            total = 0

            for i, (inputs, target) in enumerate(train_loader):
                inputs, target = inputs.to(device), target.to(device)

                # forward + backward + parameter update
                optimizer.zero_grad()
                output = net(inputs)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

                # Report the running loss and accuracy after every batch.
                sum_loss += loss.item()
                _, predicted = torch.max(output.detach(), 1)
                total += target.size(0)
                correct += predicted.eq(target).sum().item()

                iteration = i + 1 + epoch * length
                mean_loss = sum_loss / (i + 1)
                train_acc = 100. * correct / total
                print('[epoch:%d, iter:%d] Loss: %.03f | Acc: %.3f%% '
                      % (epoch + 1, iteration, mean_loss, train_acc))
                f2.write('%03d  %05d |Loss: %.03f | Acc: %.3f%% '
                         % (epoch + 1, iteration, mean_loss, train_acc))
                f2.write('\n')
                f2.flush()

            # Measure validation accuracy once the epoch is finished.
            print("Waiting Test!")
            net.eval()  # switching once per epoch is enough
            correct = 0
            total = 0
            with torch.no_grad():
                for images, labels in val_loader:
                    images, labels = images.to(device), labels.to(device)
                    outputs = net(images)
                    # Index of the highest-scoring class for every sample.
                    _, predicted = torch.max(outputs, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
            acc = 100. * correct / total
            print('测试分类准确率为:%.3f%%' % acc)
            scheduler.step(acc)

            # Save a checkpoint after every epoch and append the accuracy to acc.txt.
            print('Saving model......')
            torch.save(net.state_dict(), '%s/net_%03d.pth' % (args.outf, epoch + 1))
            f.write("EPOCH=%03d,Accuracy= %.3f%%" % (epoch + 1, acc))
            f.write('\n')
            f.flush()

            # Keep best_acc.txt pointing at the best epoch so far.
            if acc > best_acc:
                with open("./txt/best_acc.txt", "w") as f3:
                    f3.write("EPOCH=%d,best_acc= %.3f%%" % (epoch + 1, acc))
                best_acc = acc
    print("Training Finished, TotalEPOCH=%d" % EPOCH)


# Start training only when this file is executed as a script.
if __name__ == "__main__":
    main()

参考
https://zhuanlan.zhihu.com/p/95881337
https://blog.csdn.net/qq_38410428/article/details/103694759

你可能感兴趣的:(算法,深度学习,python)