AI Foundation Bootcamp Check-in Notes: Lesson 3

Course link: 火爆全网的 OpenMMLab 实战项目来了(基于 PyTorch)
Lesson 3 is a hands-on session, so there is not much theory to write up. These notes mainly record the pieces of code that were not provided in the video.

1. The .ipynb file used to demo mmclassification

In the video, the instructor used the Jupyter notebooks under the cluster/pub/openmmlab-tut directory; they can be found at this address on GitHub.

2. The data-splitting script split_data.py

# Absolute paths are required here, otherwise training will later fail with file-not-found errors
python split_data.py /HOME/scz0bca/run/mmclassification-master/data/flower_dataset /HOME/scz0bca/run/mmclassification-master/data/flower
import os
import sys
import shutil

import numpy as np


def load_data(data_path):
    """Collect the image paths of each class from a directory of class sub-folders."""
    count = 0
    data = {}
    for dir_name in os.listdir(data_path):
        dir_path = os.path.join(data_path, dir_name)
        if not os.path.isdir(dir_path):
            continue
        data[dir_name] = []
        for file_name in os.listdir(dir_path):
            file_path = os.path.join(dir_path, file_name)
            if not os.path.isfile(file_path):
                continue
            data[dir_name].append(file_path)
        count += len(data[dir_name])
        print("{} : {}".format(dir_name, len(data[dir_name])))
    print("total number of images: {}".format(count))
    return data


def copy_dataset(src_img_list, data_index, target_path):
    """Copy the selected images into target_path and return their new paths."""
    target_img_list = []
    for index in data_index:
        src_img = src_img_list[index]
        img_name = os.path.split(src_img)[-1]
        shutil.copy(src_img, target_path)
        target_img_list.append(os.path.join(target_path, img_name))
    return target_img_list


def write_file(data, file_name):
    """Write either a class list or a {label: [image, ...]} dict as an annotation file."""
    if isinstance(data, dict):
        write_data = []
        for lab, img_list in data.items():
            for img in img_list:
                write_data.append("{} {}".format(img, lab))
    else:
        write_data = data
    with open(file_name, "w") as f:
        for line in write_data:
            f.write(line + "\n")
    print("{} write over!".format(file_name))


def split_data(src_data_path, target_data_path, train_rate=0.8):
    """Randomly split each class into train/val subsets and write the annotation files."""
    src_data_dict = load_data(src_data_path)
    classes = []
    train_dataset, val_dataset = {}, {}
    train_count, val_count = 0, 0
    for i, (cls_name, img_list) in enumerate(src_data_dict.items()):
        img_data_size = len(img_list)
        # Shuffle the indices, then take the first train_rate portion for training
        random_index = np.random.choice(img_data_size, img_data_size, replace=False)
        train_data_size = int(img_data_size * train_rate)
        train_data_index = random_index[:train_data_size]
        val_data_index = random_index[train_data_size:]
        train_data_path = os.path.join(target_data_path, "train", cls_name)
        val_data_path = os.path.join(target_data_path, "val", cls_name)
        os.makedirs(train_data_path, exist_ok=True)
        os.makedirs(val_data_path, exist_ok=True)
        classes.append(cls_name)
        train_dataset[i] = copy_dataset(img_list, train_data_index, train_data_path)
        val_dataset[i] = copy_dataset(img_list, val_data_index, val_data_path)
        print("target {} train:{}, val:{}".format(cls_name, len(train_dataset[i]), len(val_dataset[i])))
        train_count += len(train_dataset[i])
        val_count += len(val_dataset[i])
    print("train size:{}, val size:{}, total:{}".format(train_count, val_count, train_count + val_count))
    write_file(classes, os.path.join(target_data_path, "classes.txt"))
    write_file(train_dataset, os.path.join(target_data_path, "train.txt"))
    write_file(val_dataset, os.path.join(target_data_path, "val.txt"))


def main():
    src_data_path = sys.argv[1]
    target_data_path = sys.argv[2]
    split_data(src_data_path, target_data_path, train_rate=0.8)


if __name__ == '__main__':
    main()
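
After the split finishes, the target directory contains train/, val/, classes.txt, train.txt and val.txt. As a quick sanity check (just a sketch, assuming the target directory passed on the command line above), the class list and the first few annotation lines can be printed:

import os

target = "/HOME/scz0bca/run/mmclassification-master/data/flower"  # the target path passed to split_data.py
with open(os.path.join(target, "classes.txt")) as f:
    print("classes:", [line.strip() for line in f])
with open(os.path.join(target, "train.txt")) as f:
    # each line has the form "<image path> <numeric label>"
    for _ in range(3):
        print(f.readline().strip())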

3. The pretrained model xxx.pth

It can be downloaded with the following command:

wget https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_batch256_imagenet_20200708-34ab8f90.pth -P checkpoints
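
To confirm the download is intact, the checkpoint can be opened with PyTorch. This is only a sketch; it assumes the file was saved into checkpoints/ as in the wget command above and that the checkpoint follows the usual OpenMMLab layout with a 'state_dict' entry:

import torch

ckpt = torch.load("checkpoints/resnet18_batch256_imagenet_20200708-34ab8f90.pth", map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)  # fall back to the raw dict if there is no 'state_dict' key
print("number of tensors:", len(state_dict))
print("first keys:", list(state_dict.keys())[:5])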

4. The run.sh used to submit the job

#!/bin/bash
# Load environment modules
module load anaconda/2021.05
module load cuda/11.1
module load gcc/7.3

# Activate the conda environment
source activate mmclassification

# Flush the log buffer (unbuffered Python output)
export PYTHONUNBUFFERED=1

# Train the model
python tools/train.py \
    configs/resnet18/resnet18_b32_flower.py \
    --work-dir work/resnet18_b32_flower
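
While the job runs, mmcv writes a JSON log (one JSON object per line) into the directory given by --work-dir, alongside the plain-text log. The snippet below is a sketch for pulling the validation accuracy out of it; the work directory matches the value above, and the metric key accuracy_top-1 is an assumption based on the default evaluation settings:

import glob
import json
import os

work_dir = "work/resnet18_b32_flower"  # same value as --work-dir above
log_file = sorted(glob.glob(os.path.join(work_dir, "*.log.json")))[-1]  # latest run
with open(log_file) as f:
    for line in f:
        entry = json.loads(line)
        # validation records carry the epoch number and the top-1 accuracy
        if entry.get("mode") == "val":
            print("epoch", entry.get("epoch"), "accuracy_top-1:", entry.get("accuracy_top-1"))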

5. The config file resnet18_b32_flower.py

_base_ = [
    '../_base_/models/resnet18.py',
    '../_base_/datasets/imagenet_bs32.py',
    '../_base_/default_runtime.py'
]
model = dict(
    head=dict(
        num_classes=5,  # five flower classes instead of the 1000 ImageNet classes
        topk=(1,)       # only report top-1 accuracy
    ))
data = dict(
    samples_per_gpu=32,
    workers_per_gpu=2,
    train=dict(
        data_prefix='data/flower/train',
        ann_file='data/flower/train.txt',
        classes='data/flower/classes.txt'
    ),
    val=dict(
        data_prefix='data/flower/val',
        ann_file='data/flower/val.txt',
        classes='data/flower/classes.txt'
    )
)
optimizer = dict(type='SGD', lr=0.001, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
lr_config = dict(
    policy='step',
    step=[1])
runner = dict(type='EpochBasedRunner', max_epochs=100)
# Pretrained model to fine-tune from
load_from = '/HOME/shenpg/run/openmmlab/mmclassification/checkpoints/resnet18_batch256_imagenet_20200708-34ab8f90.pth'
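
Once training finishes, the fine-tuned weights can be tried out directly from Python. The snippet below is only a sketch, not part of the course material: it assumes the mmcls 0.x high-level API (init_model / inference_model), that the latest checkpoint is the latest.pth link the runner leaves in the work directory, and a hypothetical test image path:

from mmcls.apis import inference_model, init_model

config = 'configs/resnet18/resnet18_b32_flower.py'
checkpoint = 'work/resnet18_b32_flower/latest.pth'  # link to the most recently saved epoch
model = init_model(config, checkpoint, device='cpu')

# 'demo/flower.jpg' is a hypothetical test image; replace it with any flower photo
result = inference_model(model, 'demo/flower.jpg')
print(result)  # dict with pred_label, pred_score and pred_class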
