Transfer learning with PyTorch: telling bees from ants

Dataset
This dataset is a very small subset of ImageNet.
Download link:

https://download.pytorch.org/tutorial/hymenoptera_data.zip

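After downloading, extract the archive into ./data so that the layout matches data_path in etc.py:

unzip hymenoptera_data.zip -d ./data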
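
If you would rather download and extract from Python, here is a minimal stdlib-only sketch. It assumes the archive's top-level folder is hymenoptera_data/, so that extracting into ./data produces the data_path used in etc.py. The helper names fetch_dataset and extract_archive are hypothetical.

```python
import os
import urllib.request
import zipfile

URL = "https://download.pytorch.org/tutorial/hymenoptera_data.zip"


def extract_archive(zip_path, dest_dir):
    """Extract zip_path into dest_dir and return the sorted top-level entries."""
    os.makedirs(dest_dir, exist_ok=True)
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(dest_dir)
        # First path component of every member, e.g. "hymenoptera_data".
        tops = {name.split("/")[0] for name in zf.namelist()}
    return sorted(tops)


def fetch_dataset(dest_dir="./data", url=URL):
    """Download the archive (skipped if it is already there) and extract it."""
    os.makedirs(dest_dir, exist_ok=True)
    zip_path = os.path.join(dest_dir, os.path.basename(url))
    if not os.path.exists(zip_path):
        urllib.request.urlretrieve(url, zip_path)
    return extract_archive(zip_path, dest_dir)
```

Calling fetch_dataset() once before main.py should leave the images under ./data/hymenoptera_data, provided the archive layout matches the assumption above.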

Folder structure:

./data/hymenoptera_data/
    ->train/
        ->ants/
            ->*.jpg
        ->bees/
            ->*.jpg
    ->val/
        ->ants/
            ->*.jpg
        ->bees/
            ->*.jpg
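
torchvision.datasets.ImageFolder treats each sub-folder of train/ and val/ as a class and assigns indices in sorted folder-name order, so ants becomes 0 and bees becomes 1. The stdlib-only sketch below mirrors that mapping and counts the images per class, which is a quick way to check the layout before training. scan_split is a hypothetical helper, and it only counts the .jpg files shown in the tree, whereas ImageFolder accepts several image extensions.

```python
import os


def scan_split(split_dir, ext=".jpg"):
    """Return (class_to_idx, counts) for one split folder such as .../train."""
    # Sub-folders are classes; sorting them gives ImageFolder-style indices.
    classes = sorted(
        entry.name for entry in os.scandir(split_dir) if entry.is_dir()
    )
    class_to_idx = {name: i for i, name in enumerate(classes)}
    counts = {}
    for name in classes:
        folder = os.path.join(split_dir, name)
        counts[name] = sum(
            1 for f in os.listdir(folder) if f.lower().endswith(ext)
        )
    return class_to_idx, counts
```

For example, scan_split("./data/hymenoptera_data/train") should report the mapping {"ants": 0, "bees": 1} together with the number of images in each folder.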

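Training uses data augmentation (RandomResizedCrop and RandomHorizontalFlip) together with normalization.
Validation only resizes, center-crops and normalizes the images, with no random augmentation.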
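
transforms.Normalize(mean, std) maps each channel value x to (x - mean) / std. By this point ToTensor has already scaled x to [0, 1]. The imshow helper in test_graph.py reverses the mapping with std * x + mean and clips the result before plotting. The pure-Python sketch below applies both directions to a single RGB pixel, using the ImageNet statistics from the config.

```python
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalize(pixel, mean=MEAN, std=STD):
    """Per-channel (x - mean) / std, as transforms.Normalize does."""
    return [(x - m) / s for x, m, s in zip(pixel, mean, std)]


def denormalize(pixel, mean=MEAN, std=STD):
    """Inverse mapping std * x + mean, clipped to [0, 1] as imshow does."""
    return [min(max(s * x + m, 0.0), 1.0) for x, m, s in zip(pixel, mean, std)]
```

A pixel equal to the mean becomes [0.0, 0.0, 0.0] after normalization. Denormalizing a normalized pixel recovers the original values up to floating-point error.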
Output after 24 epochs with finetune=True (every layer is trained)
[Figure 1: output of the finetune=True run]

run the model with finetune=True
config:
std : [0.229, 0.224, 0.225]
dataset : hymenoptera_data
finetune : True
train_load_check_point_file : True
image_size : 224
num_workers : 4
device : cuda:0
epochs : 25
step_size : 7
resize : 256
momentum : 0.9
mean : [0.485, 0.456, 0.406]
batch_size : 8
gamma : 0.1
data_path : ./data/hymenoptera_data
learn_rate : 0.001
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AvgPool2d(kernel_size=7, stride=1, padding=0)
  (fc): Linear(in_features=512, out_features=2, bias=True)
)
[1/24] [Train Loss:0.5185 Acc:0.7213] [Val Loss:0.2674 Acc:0.9085] [Best Epoch:1 Acc:0.9085] [1.3285s 1.3285s]
[2/24] [Train Loss:0.3609 Acc:0.8361] [Val Loss:0.1733 Acc:0.9412] [Best Epoch:2 Acc:0.9412] [1.3079s 2.6365s]
[3/24] [Train Loss:0.2823 Acc:0.8934] [Val Loss:0.1625 Acc:0.9477] [Best Epoch:3 Acc:0.9477] [1.2812s 3.9176s]
[4/24] [Train Loss:0.1906 Acc:0.9262] [Val Loss:0.1639 Acc:0.9477] [Best Epoch:3 Acc:0.9477] [1.3068s 5.2244s]
[5/24] [Train Loss:0.2429 Acc:0.9057] [Val Loss:0.2559 Acc:0.9150] [Best Epoch:3 Acc:0.9477] [1.3273s 6.5517s]
[6/24] [Train Loss:0.2161 Acc:0.9057] [Val Loss:0.2018 Acc:0.9281] [Best Epoch:3 Acc:0.9477] [1.3289s 7.8806s]
[7/24] [Train Loss:0.2494 Acc:0.9139] [Val Loss:0.2553 Acc:0.9150] [Best Epoch:3 Acc:0.9477] [1.3054s 9.1860s]
[8/24] [Train Loss:0.1451 Acc:0.9508] [Val Loss:0.2175 Acc:0.9150] [Best Epoch:3 Acc:0.9477] [1.3425s 10.5285s]
[9/24] [Train Loss:0.1422 Acc:0.9590] [Val Loss:0.2163 Acc:0.9150] [Best Epoch:3 Acc:0.9477] [1.3253s 11.8538s]
[10/24] [Train Loss:0.1500 Acc:0.9467] [Val Loss:0.2364 Acc:0.9216] [Best Epoch:3 Acc:0.9477] [1.3283s 13.1821s]
[11/24] [Train Loss:0.1840 Acc:0.9221] [Val Loss:0.2118 Acc:0.9216] [Best Epoch:3 Acc:0.9477] [1.3313s 14.5134s]
[12/24] [Train Loss:0.1532 Acc:0.9426] [Val Loss:0.2108 Acc:0.9281] [Best Epoch:3 Acc:0.9477] [1.3792s 15.8925s]
[13/24] [Train Loss:0.1586 Acc:0.9303] [Val Loss:0.2082 Acc:0.9281] [Best Epoch:3 Acc:0.9477] [1.2652s 17.1577s]
[14/24] [Train Loss:0.1772 Acc:0.9262] [Val Loss:0.2242 Acc:0.9216] [Best Epoch:3 Acc:0.9477] [1.2985s 18.4562s]
[15/24] [Train Loss:0.1209 Acc:0.9672] [Val Loss:0.2173 Acc:0.9216] [Best Epoch:3 Acc:0.9477] [1.3271s 19.7833s]
[16/24] [Train Loss:0.1571 Acc:0.9426] [Val Loss:0.2082 Acc:0.9281] [Best Epoch:3 Acc:0.9477] [1.3196s 21.1029s]
[17/24] [Train Loss:0.2401 Acc:0.9016] [Val Loss:0.2140 Acc:0.9346] [Best Epoch:3 Acc:0.9477] [1.3619s 22.4648s]
[18/24] [Train Loss:0.1579 Acc:0.9467] [Val Loss:0.2133 Acc:0.9281] [Best Epoch:3 Acc:0.9477] [1.2715s 23.7362s]
[19/24] [Train Loss:0.1147 Acc:0.9631] [Val Loss:0.2195 Acc:0.9216] [Best Epoch:3 Acc:0.9477] [1.3313s 25.0675s]
[20/24] [Train Loss:0.1963 Acc:0.9098] [Val Loss:0.2144 Acc:0.9150] [Best Epoch:3 Acc:0.9477] [1.3578s 26.4253s]
[21/24] [Train Loss:0.1867 Acc:0.9221] [Val Loss:0.2094 Acc:0.9216] [Best Epoch:3 Acc:0.9477] [1.2682s 27.6935s]
[22/24] [Train Loss:0.1368 Acc:0.9467] [Val Loss:0.2142 Acc:0.9281] [Best Epoch:3 Acc:0.9477] [1.3807s 29.0742s]
[23/24] [Train Loss:0.1819 Acc:0.9221] [Val Loss:0.2000 Acc:0.9281] [Best Epoch:3 Acc:0.9477] [1.3199s 30.3941s]
[24/24] [Train Loss:0.1716 Acc:0.9303] [Val Loss:0.2122 Acc:0.9281] [Best Epoch:3 Acc:0.9477] [1.3263s 31.7204s]

Output after 24 epochs with finetune=False (pretrained layers frozen, only the new fc layer is trained)
[Figure 2: output of the finetune=False run]

run the model with finetune=False
config:
std : [0.229, 0.224, 0.225]
dataset : hymenoptera_data
finetune : False
train_load_check_point_file : True
image_size : 224
num_workers : 4
device : cuda:0
epochs : 25
step_size : 7
resize : 256
momentum : 0.9
mean : [0.485, 0.456, 0.406]
batch_size : 8
gamma : 0.1
data_path : ./data/hymenoptera_data
learn_rate : 0.001
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AvgPool2d(kernel_size=7, stride=1, padding=0)
  (fc): Linear(in_features=512, out_features=2, bias=True)
)
[1/24] [Train Loss:0.5838 Acc:0.6393] [Val Loss:0.2962 Acc:0.9020] [Best Epoch:1 Acc:0.9020] [1.1856s 1.1856s]
[2/24] [Train Loss:0.3283 Acc:0.8934] [Val Loss:0.2175 Acc:0.9281] [Best Epoch:2 Acc:0.9281] [1.2195s 2.4051s]
[3/24] [Train Loss:0.3166 Acc:0.8566] [Val Loss:0.1928 Acc:0.9346] [Best Epoch:3 Acc:0.9346] [1.1106s 3.5157s]
[4/24] [Train Loss:0.2932 Acc:0.8975] [Val Loss:0.1743 Acc:0.9477] [Best Epoch:4 Acc:0.9477] [1.2303s 4.7460s]
[5/24] [Train Loss:0.3207 Acc:0.8320] [Val Loss:0.1753 Acc:0.9542] [Best Epoch:5 Acc:0.9542] [1.2099s 5.9559s]
[6/24] [Train Loss:0.2776 Acc:0.8770] [Val Loss:0.1751 Acc:0.9412] [Best Epoch:5 Acc:0.9542] [1.1920s 7.1479s]
[7/24] [Train Loss:0.2705 Acc:0.8811] [Val Loss:0.1822 Acc:0.9542] [Best Epoch:5 Acc:0.9542] [1.1728s 8.3207s]
[8/24] [Train Loss:0.1962 Acc:0.9180] [Val Loss:0.1821 Acc:0.9477] [Best Epoch:5 Acc:0.9542] [1.1578s 9.4785s]
[9/24] [Train Loss:0.1874 Acc:0.9221] [Val Loss:0.1754 Acc:0.9608] [Best Epoch:9 Acc:0.9608] [1.2263s 10.7048s]
[10/24] [Train Loss:0.1853 Acc:0.9221] [Val Loss:0.1743 Acc:0.9608] [Best Epoch:9 Acc:0.9608] [1.2255s 11.9303s]
[11/24] [Train Loss:0.2535 Acc:0.8852] [Val Loss:0.1808 Acc:0.9608] [Best Epoch:9 Acc:0.9608] [1.2431s 13.1734s]
[12/24] [Train Loss:0.2674 Acc:0.8893] [Val Loss:0.1741 Acc:0.9608] [Best Epoch:9 Acc:0.9608] [1.1599s 14.3333s]
[13/24] [Train Loss:0.2844 Acc:0.8893] [Val Loss:0.1766 Acc:0.9542] [Best Epoch:9 Acc:0.9608] [1.2288s 15.5621s]
[14/24] [Train Loss:0.1607 Acc:0.9467] [Val Loss:0.1908 Acc:0.9477] [Best Epoch:9 Acc:0.9608] [1.2446s 16.8068s]
[15/24] [Train Loss:0.1954 Acc:0.9303] [Val Loss:0.1785 Acc:0.9608] [Best Epoch:9 Acc:0.9608] [1.2252s 18.0319s]
[16/24] [Train Loss:0.2241 Acc:0.9016] [Val Loss:0.1734 Acc:0.9608] [Best Epoch:9 Acc:0.9608] [1.2243s 19.2562s]
[17/24] [Train Loss:0.2126 Acc:0.9098] [Val Loss:0.1827 Acc:0.9477] [Best Epoch:9 Acc:0.9608] [1.1610s 20.4172s]
[18/24] [Train Loss:0.1669 Acc:0.9303] [Val Loss:0.1779 Acc:0.9608] [Best Epoch:9 Acc:0.9608] [1.1824s 21.5996s]
[19/24] [Train Loss:0.2127 Acc:0.9180] [Val Loss:0.1720 Acc:0.9608] [Best Epoch:9 Acc:0.9608] [1.1214s 22.7209s]
[20/24] [Train Loss:0.2483 Acc:0.8975] [Val Loss:0.1696 Acc:0.9608] [Best Epoch:9 Acc:0.9608] [1.1469s 23.8678s]
[21/24] [Train Loss:0.2560 Acc:0.8852] [Val Loss:0.1639 Acc:0.9608] [Best Epoch:9 Acc:0.9608] [1.1381s 25.0059s]
[22/24] [Train Loss:0.2112 Acc:0.9180] [Val Loss:0.1777 Acc:0.9542] [Best Epoch:9 Acc:0.9608] [1.1740s 26.1799s]
[23/24] [Train Loss:0.2900 Acc:0.8852] [Val Loss:0.1691 Acc:0.9542] [Best Epoch:9 Acc:0.9608] [1.1615s 27.3414s]
[24/24] [Train Loss:0.2246 Acc:0.8975] [Val Loss:0.1746 Acc:0.9608] [Best Epoch:9 Acc:0.9608] [1.1800s 28.5213s]
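
To compare the two runs numerically, you can parse back the lines written by record.save_epoch_status. The logs above give a best validation accuracy of 0.9477 at epoch 3 with finetuning, and 0.9608 at epoch 9 without. Below is a minimal sketch based on a regular expression written against the exact line format above. parse_epoch_line and best_val are hypothetical helpers.

```python
import re

# Matches lines such as
# [3/24] [Train Loss:0.2823 Acc:0.8934] [Val Loss:0.1625 Acc:0.9477] [Best Epoch:3 Acc:0.9477] [1.2812s 3.9176s]
LINE_RE = re.compile(
    r"\[(\d+)/(\d+)\] "
    r"\[Train Loss:([\d.]+) Acc:([\d.]+)\] "
    r"\[Val Loss:([\d.]+) Acc:([\d.]+)\] "
    r"\[Best Epoch:(\d+) Acc:([\d.]+)\] "
    r"\[([\d.]+)s ([\d.]+)s\]"
)


def parse_epoch_line(line):
    """Return a dict of the logged values, or None for non-epoch lines."""
    m = LINE_RE.match(line.strip())
    if m is None:
        return None
    g = m.groups()
    return {
        "epoch": int(g[0]),
        "train_loss": float(g[2]),
        "train_acc": float(g[3]),
        "val_loss": float(g[4]),
        "val_acc": float(g[5]),
        "best_epoch": int(g[6]),
        "best_acc": float(g[7]),
        "epoch_time": float(g[8]),
        "total_time": float(g[9]),
    }


def best_val(lines):
    """(epoch, val_acc) of the first epoch that reaches the top val accuracy."""
    rows = [r for r in map(parse_epoch_line, lines) if r is not None]
    best = max(rows, key=lambda r: r["val_acc"])
    return best["epoch"], best["val_acc"]
```

record._write_output opens the output file in append mode, so repeated runs with the same settings accumulate in one file. Split the lines per run before calling best_val.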

The implementation consists of the following files:

ants_bees_data_set.py
graph.py
record.py
test_graph.py
etc.py
main.py
status.py
train_graph.py

ants_bees_data_set.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : ants_bees_data_set.py
# Create date : 2019-01-30 13:51
# Modified date : 2019-02-01 21:48
# Author : DARREN
# Describe : not set
#####################################
from __future__ import division
from __future__ import print_function

import os
import torch
from torchvision import datasets, transforms

def _get_image_dataset(data_path, kind, data_transforms):
    return datasets.ImageFolder(os.path.join(data_path, kind), data_transforms[kind])

def _get_a_dataloader(image_datasets, kind, config):
    batch_size = config["batch_size"]
    num_workers = config["num_workers"]
    return torch.utils.data.DataLoader(image_datasets[kind], batch_size=batch_size, shuffle=True, num_workers=num_workers)

def _get_dataloaders(image_datasets,config):
    dataloaders = {}
    dataloaders["train"] = _get_a_dataloader(image_datasets, "train", config)
    dataloaders["val"] = _get_a_dataloader(image_datasets,"val", config)
    return dataloaders

def _get_data_transforms(config):
    img_size = config["image_size"]
    resize = config["resize"]
    mean = config["mean"]
    std = config["std"]
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(img_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ]),
        'val': transforms.Compose([
            transforms.Resize(resize),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ]),
    }
    return data_transforms

def _get_image_datasets(config):
    data_path = config["data_path"]
    data_transforms = _get_data_transforms(config)
    image_datasets = {}
    image_datasets["train"] = _get_image_dataset(data_path,"train", data_transforms)
    image_datasets["val"] = _get_image_dataset(data_path,"val",data_transforms)
    return image_datasets

def _get_dataset_sizes(image_datasets):
    # Named image_datasets so the torchvision `datasets` module is not shadowed.
    dataset_sizes = {}
    dataset_sizes["train"] = len(image_datasets["train"])
    dataset_sizes["val"] = len(image_datasets["val"])
    return dataset_sizes

def _get_data_dict(dataloaders, dataset_sizes, class_names):
    data_dict = {}
    data_dict["dataloaders"] = dataloaders
    data_dict["dataset_sizes"] = dataset_sizes
    data_dict["class_names"] = class_names
    return data_dict

def get_dataset_info_dict(config):
    image_datasets = _get_image_datasets(config)

    dataloaders = _get_dataloaders(image_datasets,config)
    dataset_sizes = _get_dataset_sizes(image_datasets)
    class_names = image_datasets['train'].classes

    data_dict = _get_data_dict(dataloaders, dataset_sizes, class_names)
    return data_dict
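
For validation, Resize(256) with a single int scales the shorter image side to 256 and keeps the aspect ratio. CenterCrop(224) then cuts a 224x224 box from the middle. The pure-Python sketch below computes those sizes. It truncates the long side with int() and rounds the crop offsets with round(), which follows my reading of torchvision's implementation. The exact pixel may differ between versions. Both helper names are hypothetical.

```python
def resize_shorter_side(width, height, size=256):
    """Output (width, height) of Resize(size) when size is a single int."""
    short, long_ = (width, height) if width <= height else (height, width)
    new_short, new_long = size, int(size * long_ / short)
    return (new_short, new_long) if width <= height else (new_long, new_short)


def center_crop_box(width, height, crop=224):
    """(left, top, right, bottom) of a crop x crop box centred in the image."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop
```

For a 500x375 landscape image, the sketch gives a 341x256 resized image and a crop box of (58, 16, 282, 240).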

graph.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : graph.py
# Create date : 2019-01-30 14:25
# Modified date : 2019-02-01 18:19
# Author : DARREN
# Describe : not set
#####################################
from __future__ import division
from __future__ import print_function

import os
import torch
import torch.nn as nn
import torchvision
import torch.optim as optim
from torch.optim import lr_scheduler

import record
import status

def _add_last_layer(model, config):
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 2)
    model = model.to(config["device"])
    return model

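def _get_model(config):
    finetune = config["finetune"]
    # ImageNet-pretrained ResNet-18. torchvision 0.13+ deprecates pretrained=True
    # in favor of weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1.
    model = torchvision.models.resnet18(pretrained=True)
    if not finetune:
        # Fixed feature extractor: freeze every pretrained layer. The fc layer
        # created afterwards in _add_last_layer is trainable by default.
        for param in model.parameters():
            param.requires_grad = False
    return _add_last_layer(model, config)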

def _get_optimizer(config, model):
    learn_rate = config["learn_rate"]
    momentum = config["momentum"]
    optimizer_ft = optim.SGD(model.parameters(), lr=learn_rate, momentum=momentum)
    return optimizer_ft

def _get_scheduler(optimizer, config):
    step_size = config["step_size"]
    gamma = config["gamma"]
    return lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

class TransferLearnGraph(object):
    def __init__(self, data_dict, config):
        super(TransferLearnGraph, self).__init__()
        self.config = config
        self.data_dict = data_dict
        self.graph_dict = self._init_graph_dict(config)
        self.status_dict = status.get_status_dict()
        self._load_train_model()

    def _save_trained_model(self):
        model_dict = self._get_model_dict()
        file_full_path = record.get_check_point_file_full_path(self.config)
        torch.save(model_dict, file_full_path)

    def _init_graph_dict(self, config):
        graph_dict = {}
        graph_dict["model"] = _get_model(config)
        graph_dict["criterion"] = nn.CrossEntropyLoss()
        graph_dict["optimizer"] = _get_optimizer(config, graph_dict["model"])
        graph_dict["scheduler"] = _get_scheduler(graph_dict["optimizer"], config)
        return graph_dict

    def _get_model_dict(self):
        model_dict = {}
        model_dict["model"] = self.graph_dict["model"].state_dict()
        model_dict["criterion"] = self.graph_dict["criterion"].state_dict()
        model_dict["optimizer"] = self.graph_dict["optimizer"].state_dict()
        model_dict["scheduler"] = self.graph_dict["scheduler"].state_dict()

        model_dict["status_dict"] = self.status_dict
        model_dict["config"] = self.config
        return model_dict

    def _load_model_dict(self, checkpoint):
        self.graph_dict["model"].load_state_dict(checkpoint["model"])
        self.graph_dict["criterion"].load_state_dict(checkpoint["criterion"])
        self.graph_dict["optimizer"].load_state_dict(checkpoint["optimizer"])
        self.graph_dict["scheduler"].load_state_dict(checkpoint["scheduler"])

        self.status_dict = checkpoint["status_dict"]
        self.config = checkpoint["config"]

    def _load_train_model(self):
        file_full_path = record.get_check_point_file_full_path(self.config)
        if os.path.exists(file_full_path) and self.config["train_load_check_point_file"]:
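            # map_location lets a checkpoint saved on the GPU load on a CPU-only machine.
            checkpoint = torch.load(file_full_path, map_location=self.config["device"])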
            self._load_model_dict(checkpoint)
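
The scheduler built in _get_scheduler uses step_size=7 and gamma=0.1. StepLR therefore multiplies the learning rate by 0.1 every 7 scheduler steps. Assuming scheduler.step() is called once at the end of each epoch, the rate used in (0-indexed) epoch e is learn_rate * gamma ** (e // step_size). A pure-Python sketch of that closed form follows. step_lr and schedule are hypothetical names.

```python
def step_lr(base_lr, epoch, step_size=7, gamma=0.1):
    """Learning rate StepLR yields for a 0-indexed epoch."""
    return base_lr * gamma ** (epoch // step_size)


def schedule(base_lr=0.001, epochs=25, step_size=7, gamma=0.1):
    """Per-epoch learning rates for a whole run with the etc.py settings."""
    return [step_lr(base_lr, e, step_size, gamma) for e in range(epochs)]
```

With the etc.py values, epochs 0 to 6 train at 0.001, epochs 7 to 13 at 1e-4, epochs 14 to 20 at 1e-5, and the remaining epochs at 1e-6.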

record.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : record.py
# Create date : 2019-01-30 21:37
# Modified date : 2019-02-01 21:49
# Author : DARREN
# Describe : not set
#####################################
from __future__ import division
from __future__ import print_function

import os

def _get_param_str(config):
    # pylint: disable=bad-continuation
    param_str = "%s_%s_%s_%s_%s" % (
                                config["dataset"],
                                config["image_size"],
                                config["batch_size"],
                                config["learn_rate"],
                                config["finetune"],
                                )
    # pylint: enable=bad-continuation
    return param_str

def get_check_point_path(config):
    param_str = _get_param_str(config)
    directory = "%s/save/%s/" % (config["data_path"], param_str)
    if not os.path.exists(directory):
        os.makedirs(directory)
    return directory

def get_check_point_file_full_path(config):
    path = get_check_point_path(config)
    param_str = _get_param_str(config)
    file_full_path = "%s%scheckpoint.tar" % (path, param_str)
    return file_full_path

def _write_output(config, con):
    save_path = get_check_point_path(config)
    file_full_path = "%s/output" % save_path
    # Append mode: every run's log accumulates in the same file.
    with open(file_full_path, "a") as f:
        f.write("%s\n" % con)

def record_dict(config, dic):
    save_content(config, "config:")
    for key in dic:
        dic_str = "%s : %s" % (key, dic[key])
        save_content(config, dic_str)

def save_content(config, con):
    print(con)
    _write_output(config, con)

def save_epoch_status(status_dict, config):
    num_epochs = config["epochs"]
    epoch = status_dict["epoch"]
    train_epoch_loss = status_dict["train_epoch_loss"]
    train_epoch_acc = status_dict["train_epoch_acc"]
    val_epoch_loss = status_dict["val_epoch_loss"]
    val_epoch_acc = status_dict["val_epoch_acc"]
    best_epoch = status_dict["best_epoch"]
    best_acc = status_dict["best_acc"]
    epoch_elapsed_time = status_dict["epoch_eplapsed_time"]
    so_far_elapsed_time = status_dict["so_far_elapsed_time"]

    # pylint: disable=bad-continuation
    save_str = '[%s/%s] [Train Loss:%.4f Acc:%.4f] [Val Loss:%.4f Acc:%.4f] [Best Epoch:%s Acc:%.4f] [%.4fs %.4fs]' % (
                            epoch,
                            num_epochs - 1,
                            train_epoch_loss,
                            train_epoch_acc,
                            val_epoch_loss,
                            val_epoch_acc,
                            best_epoch,
                            best_acc,
                            epoch_elapsed_time,
                            so_far_elapsed_time
                            )

    # pylint: enable=bad-continuation
    save_content(config, save_str)
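
get_check_point_path and get_check_point_file_full_path encode five values into the directory and file name: the dataset, image size, batch size, learning rate and finetune flag. This means the two runs in main.py write separate checkpoints. The side-effect-free sketch below reproduces the same naming without the os.makedirs call, and shows the resulting path for the config in etc.py. checkpoint_path is a hypothetical name.

```python
import os


def checkpoint_path(config):
    """Same naming scheme as record.get_check_point_file_full_path, no mkdir."""
    param_str = "%s_%s_%s_%s_%s" % (
        config["dataset"],
        config["image_size"],
        config["batch_size"],
        config["learn_rate"],
        config["finetune"],
    )
    directory = os.path.join(config["data_path"], "save", param_str)
    return os.path.join(directory, param_str + "checkpoint.tar")


example = {
    "dataset": "hymenoptera_data",
    "data_path": "./data/hymenoptera_data",
    "image_size": 224,
    "batch_size": 8,
    "learn_rate": 0.001,
    "finetune": True,
}
```

For the example config, the path is ./data/hymenoptera_data/save/hymenoptera_data_224_8_0.001_True/hymenoptera_data_224_8_0.001_Truecheckpoint.tar. Because train_load_check_point_file is True, TransferLearnGraph loads this file again on the next run with identical settings.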

test_graph.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : test_graph.py
# Create date : 2019-02-01 17:21
# Modified date : 2019-02-01 21:49
# Author : DARREN
# Describe : not set
#####################################
from __future__ import division
from __future__ import print_function

import numpy as np
import matplotlib.pyplot as plt
from graph import TransferLearnGraph
import torch
import record

def imshow(inp, config, title=None):
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array(config["mean"])
    std = np.array(config["std"])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)

def _test_model(dataloader, model, config):
    device = config["device"]
    with torch.no_grad():
        inputs, labels = next(iter(dataloader))
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        return inputs, preds

def _show_test_images(inputs, class_names, preds, config):
    num_images = inputs.size(0)
    # Two images per row; round up so an odd-sized batch still fits the grid.
    num_rows = (num_images + 1) // 2
    fig = plt.figure()
    for j in range(num_images):
        ax = plt.subplot(num_rows, 2, j + 1)
        ax.axis('off')
        ax.set_title('predicted: {}'.format(class_names[preds[j].item()]))
        imshow(inputs.cpu()[j], config)

    save_path = record.get_check_point_path(config)
    name = "test_images.jpg"
    full_path_name = "%s/%s" % (save_path, name)
    plt.savefig(full_path_name)
    # plt.show()
    # Close the figure so the second run in main.py starts from a blank canvas.
    plt.close(fig)

def run_test(dataloader, model, class_names, config):
    model.eval()
    inputs, preds = _test_model(dataloader, model, config)
    _show_test_images(inputs, class_names, preds, config)

class TestTransferLearnGraph(TransferLearnGraph):
    def __init__(self, data_dict, config):
        super(TestTransferLearnGraph, self).__init__(data_dict, config)

    def test_the_model(self):
        dataloader = self.data_dict["dataloaders"]['val']

        model = self.graph_dict["model"]
        model.load_state_dict(self.status_dict["best_model_wts"])

        class_names = self.data_dict["class_names"]
        run_test(dataloader, model, class_names, self.config)
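
torch.max(outputs, 1) returns both the largest logit in each row and its index along dimension 1. The code keeps only the index, which is the predicted class id, and class_names turns it into a label. ImageFolder sorts class_names as ['ants', 'bees']. The same step in plain Python, for illustration only (predict_labels is hypothetical):

```python
def predict_labels(logits, class_names):
    """Row-wise argmax over logits, mapped to class names (first max wins)."""
    labels = []
    for row in logits:
        best = max(range(len(row)), key=lambda i: row[i])
        labels.append(class_names[best])
    return labels
```

For logits [[0.1, 2.0], [3.0, -1.0]] the predictions are ["bees", "ants"].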

etc.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : etc.py
# Create date : 2019-01-30 15:17
# Modified date : 2019-02-01 18:27
# Author : DARREN
# Describe : not set
#####################################
from __future__ import division
from __future__ import print_function

import torch

config = {}

config["dataset"] = "hymenoptera_data"
config["data_path"] = "./data/%s" % config["dataset"]

config["epochs"] = 25
config["batch_size"] = 8
config["num_workers"] = 4
config["image_size"] = 224
config["resize"] = 256

config["device"] = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

config["finetune"] = True
config["learn_rate"] = 0.001
config["momentum"] = 0.9
config["step_size"] = 7
config["gamma"] = 0.1

config["mean"] = [0.485, 0.456, 0.406]
config["std"] = [0.229, 0.224, 0.225]

config["train_load_check_point_file"] = True

main.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : main.py
# Create date : 2019-01-30 13:35
# Modified date : 2019-02-01 18:33
# Author : DARREN
# Describe : not set
#####################################
from __future__ import division
from __future__ import print_function

import ants_bees_data_set
from train_graph import TrainTransferLearnGraph
from test_graph import TestTransferLearnGraph
from etc import config
import record

def run_experiment(finetune):
    # Train and then test one configuration; the two runs differ only in finetune.
    config["finetune"] = finetune
    record.save_content(config, "run the model with finetune=%s" % finetune)
    record.record_dict(config, config)
    data_dict = ants_bees_data_set.get_dataset_info_dict(config)

    g = TrainTransferLearnGraph(data_dict, config)
    g.train_the_model()

    test_g = TestTransferLearnGraph(data_dict, config)
    test_g.test_the_model()

def run():
    run_experiment(finetune=True)
    run_experiment(finetune=False)

# DataLoader workers (num_workers=4) re-import this module on platforms that
# spawn processes (Windows, macOS), so the entry point must be guarded.
if __name__ == "__main__":
    run()

status.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : status.py
# Create date : 2019-02-01 13:41
# Modified date : 2019-02-01 21:51
# Author : DARREN
# Describe : not set
#####################################
from __future__ import division
from __future__ import print_function

import copy

def get_status_dict():
    status_dict = {}
    status_dict["best_acc"] = 0.0
    status_dict["best_model_wts"] = None
    status_dict["elapsed_time"] = 0.0
    status_dict["epoch"] = 0
    status_dict["train_epoch_loss"] = 0.0
    status_dict["train_epoch_acc"] = 0.0
    status_dict["val_epoch_loss"] = 0.0
    status_dict["val_epoch_acc"] = 0.0
    status_dict["best_epoch"] = 0
    status_dict["epoch_eplapsed_time"] = 0.0
    status_dict["so_far_elapsed_time"] = 0.0
    return status_dict

def val_epoch_update_status_dict(val_epoch_loss, val_epoch_acc, epoch, model, status_dict):
    status_dict["val_epoch_loss"] = val_epoch_loss
    status_dict["val_epoch_acc"] = val_epoch_acc

    if val_epoch_acc > status_dict["best_acc"]:
        status_dict["best_epoch"] = epoch
        status_dict["best_acc"] = val_epoch_acc
        status_dict["best_model_wts"] = copy.deepcopy(model.state_dict())

def train_epoch_update_status_dict(train_epoch_loss, train_epoch_acc, status_dict):
    status_dict["train_epoch_loss"] = train_epoch_loss
    status_dict["train_epoch_acc"] = train_epoch_acc

def update_eplapsed_time(start, end, status_dict):
    status_dict["epoch_eplapsed_time"] = end - start
    status_dict["so_far_elapsed_time"] += status_dict["epoch_eplapsed_time"]

def update_epoch(epoch, status_dict):
    status_dict["epoch"] = epoch
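
val_epoch_update_status_dict replaces the saved weights only when the validation accuracy is strictly higher than the best so far. That is why the finetune log above still reports "Best Epoch:3" at epoch 4, even though both epochs reach 0.9477. The framework-free sketch below applies that rule. BestTracker is a hypothetical class, and a plain dict stands in for model.state_dict().

```python
import copy


class BestTracker(object):
    """Keep the epoch, accuracy and a deep copy of the weights of the best epoch."""

    def __init__(self):
        self.best_epoch = 0
        self.best_acc = 0.0
        self.best_weights = None

    def update(self, epoch, val_acc, weights):
        # Strictly greater: a tie does not move the best epoch forward.
        if val_acc > self.best_acc:
            self.best_epoch = epoch
            self.best_acc = val_acc
            # Deep copy so later training steps cannot mutate the snapshot.
            self.best_weights = copy.deepcopy(weights)
        return self.best_epoch
```

The deep copy matters because state_dict() returns references to the live parameter tensors, and the next optimizer step would overwrite them in place.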

train_graph.py
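
In the listing that follows, _train_a_step returns the batch's mean loss multiplied by the batch size, and _train_a_epoch divides the running sum by the size of the training set. This yields the true per-image average even when the last batch is smaller than batch_size. A plain mean of batch means would not. A pure-Python sketch of that arithmetic with made-up numbers (epoch_average is a hypothetical helper):

```python
def epoch_average(batch_means, batch_sizes):
    """Per-sample average from per-batch mean values weighted by batch size."""
    total = sum(m * n for m, n in zip(batch_means, batch_sizes))
    return total / sum(batch_sizes)
```

With batch mean losses [0.5, 0.5, 2.0] over batches of 8, 8 and 4 images, the per-image average is 16 / 20 = 0.8. The unweighted mean of the three batch means would instead give 1.0.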

#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : train_graph.py
# Create date : 2019-02-01 17:22
# Modified date : 2019-02-01 21:52
# Author : DARREN
# Describe : not set
#####################################
from __future__ import division
from __future__ import print_function

import copy
import time
import torch
from graph import TransferLearnGraph
import status
import record

class TrainTransferLearnGraph(TransferLearnGraph):
    def __init__(self, data_dict, config):
        super(TrainTransferLearnGraph, self).__init__(data_dict, config)

    def _run_a_epoch(self, epoch):
        status.update_epoch(epoch, self.status_dict)
        start = time.time()
        self._train_a_epoch()
        self._eval_a_epoch()
        end = time.time()

        status.update_eplapsed_time(start, end, self.status_dict)
        record.save_epoch_status(self.status_dict, self.config)
        self._save_trained_model()

    def _train_a_step(self, inputs, labels):
        model = self.graph_dict["model"]
        criterion = self.graph_dict["criterion"]
        optimizer = self.graph_dict["optimizer"]
        device = self.config["device"]

        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # The criterion returns the batch mean; scale by the batch size to get the batch sum.
        l = loss.item() * inputs.size(0)
        corrects = torch.sum(preds == labels.data)
        return l, corrects

    def _train_a_epoch(self):
        dataloaders = self.data_dict["dataloaders"]
        dataset_sizes = self.data_dict["dataset_sizes"]
        model = self.graph_dict["model"]
        scheduler = self.graph_dict["scheduler"]
        running_loss = 0.0
        running_corrects = 0

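        model.train()
        for inputs, labels in dataloaders["train"]:
            loss, corrects = self._train_a_step(inputs, labels)
            running_loss += loss
            running_corrects += corrects
        # Since PyTorch 1.1, scheduler.step() should follow optimizer.step(),
        # so advance the learning-rate schedule once at the end of the epoch.
        scheduler.step()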

        train_epoch_loss = running_loss / dataset_sizes["train"]
        train_epoch_acc = running_corrects.double() / dataset_sizes["train"]
        status.train_epoch_update_status_dict(train_epoch_loss, train_epoch_acc, self.status_dict)

    def _eval_a_step(self, inputs, labels):
        model = self.graph_dict["model"]
        criterion = self.graph_dict["criterion"]
        device = self.config["device"]
        inputs = inputs.to(device)
        labels = labels.to(device)

        with torch.set_grad_enabled(False):
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

        l = loss.item() * inputs.size(0)
        corrects = torch.sum(preds == labels.data)
        return l, corrects

    def _eval_a_epoch(self):
        epoch = self.status_dict["epoch"]
        dataloaders = self.data_dict["dataloaders"]
        dataset_sizes = self.data_dict["dataset_sizes"]
        model = self.graph_dict["model"]

        dataloader = dataloaders["val"]
        dataset_size = dataset_sizes["val"]
        running_loss = 0.0
        running_corrects = 0

        model.eval()
        for inputs, labels in dataloader:
            loss, corrects = self._eval_a_step(inputs, labels)
            running_loss += loss
            running_corrects += corrects

        val_epoch_loss = running_loss / dataset_size
        val_epoch_acc = running_corrects.double() / dataset_size

        status.val_epoch_update_status_dict(val_epoch_loss, val_epoch_acc, epoch, model, self.status_dict)

    def train_the_model(self):
        model = self.graph_dict["model"]
        record.save_content(self.config, model)
        num_epochs = self.config["epochs"]

        self.status_dict["best_model_wts"] = copy.deepcopy(model.state_dict())
        epoch_start = self.status_dict["epoch"]

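        # Runs epochs epoch_start + 1 through num_epochs - 1, so a fresh run
        # with epochs=25 trains 24 epochs.
        for epoch in range(epoch_start + 1, num_epochs):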
            self._run_a_epoch(epoch)
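
`_train_a_step` and `_eval_a_step` return `loss.item() * inputs.size(0)` and the number of correct predictions, and the epoch methods divide the running totals by the dataset size. Assuming the criterion averages over the batch (the default `reduction='mean'` of `nn.CrossEntropyLoss`; the criterion is built outside this section), this yields the per-sample mean even when the last batch is short. The pure-Python sketch below mirrors that bookkeeping; `epoch_metrics` and `argmax` are hypothetical helpers, with `argmax` standing in for the index returned by `torch.max(outputs, 1)`.

```python
def argmax(row):
    # Index of the largest score, like torch.max(outputs, 1)[1] for one sample.
    return max(range(len(row)), key=row.__getitem__)


def epoch_metrics(batches):
    """batches: list of (mean_loss, logits, labels) tuples, one per batch.

    mean_loss plays the role of loss.item() from a criterion that averages
    over the batch (assumed reduction='mean')."""
    running_loss = 0.0
    running_corrects = 0
    dataset_size = 0
    for mean_loss, logits, labels in batches:
        batch_size = len(labels)
        running_loss += mean_loss * batch_size  # undo the batch mean
        preds = [argmax(row) for row in logits]
        running_corrects += sum(p == y for p, y in zip(preds, labels))
        dataset_size += batch_size
    return running_loss / dataset_size, running_corrects / dataset_size


batches = [
    (0.5, [[2.0, 1.0], [0.1, 0.9], [1.5, 0.2], [0.3, 0.4]], [0, 1, 1, 1]),
    (2.0, [[0.2, 0.8]], [0]),  # short final batch
]
loss, acc = epoch_metrics(batches)
print(loss, acc)  # 0.8 0.6
```

Averaging the two batch means directly would give (0.5 + 2.0) / 2 = 1.25, overweighting the single-sample final batch.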

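`train_the_model` loops over `range(epoch_start + 1, num_epochs)`, and `get_status_dict` starts `epoch` at 0, so a fresh run with `epochs : 25` trains epochs 1 to 24, which matches the 24-epoch output shown earlier. The config values `learn_rate : 0.001`, `step_size : 7` and `gamma : 0.1` look like the arguments of `torch.optim.lr_scheduler.StepLR`; the code that builds the scheduler is not in this section, so that is an assumption. StepLR's closed form is `base_lr * gamma ** (last_epoch // step_size)`. The sketch below tabulates the learning rate per epoch, assuming a fresh scheduler stepped once at the end of every epoch; all function names are hypothetical.

```python
def epochs_to_run(epoch_start, num_epochs):
    # Mirrors the loop bounds in train_the_model().
    return list(range(epoch_start + 1, num_epochs))


def step_lr(base_lr, gamma, step_size, steps_taken):
    # Closed form of StepLR after `steps_taken` calls to scheduler.step().
    return base_lr * gamma ** (steps_taken // step_size)


def lr_schedule(epoch_start=0, num_epochs=25, base_lr=0.001,
                gamma=0.1, step_size=7):
    """Map each epoch number to its learning rate, assuming the scheduler
    starts fresh and is stepped once at the end of every epoch."""
    schedule = {}
    for steps_taken, epoch in enumerate(epochs_to_run(epoch_start, num_epochs)):
        schedule[epoch] = step_lr(base_lr, gamma, step_size, steps_taken)
    return schedule


schedule = lr_schedule()
print(len(schedule))  # 24
for epoch in (1, 7, 8, 15, 22, 24):
    print(epoch, schedule[epoch])
```

Under this ordering the rate drops at epochs 8, 15 and 22. With the original placement of `scheduler.step()` at the start of `_train_a_epoch`, the scheduler has already stepped once before epoch 1, so each drop happens one epoch earlier (at epochs 7, 14 and 21), and PyTorch 1.1 and later warn about that call order.
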
GitHub:
https://github.com/darr/transfer_learn
