While learning deep learning and PyTorch, you will come across open-source projects of all sizes whose code is organized quite differently: some are fairly involved, others deliberately simple. This post walks through the most basic building blocks of a deep learning project, split into three modules: net.py, dataset.py, and train.py.
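The code below assumes a project layout roughly like the following; the data paths come from dataset.py and train.py, and the sample file names are only placeholders:

.
├── net.py
├── dataset.py
├── train.py
├── checkpoints/           # created by train.py to hold saved weights
└── data/
    └── train/
        ├── image/         # e.g. 0001.jpg
        └── mask/          # e.g. 0001_mask.png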
# net.py
import torch
import torch.nn as nn


class DoubleConv(nn.Module):
    '''[conv, bn, relu] * 2'''
    def __init__(self, in_channel, out_channel):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class MetaNet(nn.Module):
    '''[down_sample, up_sample]'''
    def __init__(self, in_channel, num_class):
        super(MetaNet, self).__init__()
        # encoder: four 2x down-samplings, so H and W shrink by a factor of 16
        self.down_sample = nn.Sequential(
            DoubleConv(in_channel, 64),
            nn.MaxPool2d(kernel_size=2),
            DoubleConv(64, 128),
            nn.MaxPool2d(kernel_size=2),
            DoubleConv(128, 256),
            nn.MaxPool2d(kernel_size=2),
            DoubleConv(256, 512),
            nn.MaxPool2d(kernel_size=2),
            DoubleConv(512, 1024)
        )
        # decoder: four 2x up-samplings, then a 1x1 conv to num_class channels
        self.up_sample = nn.Sequential(
            nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2),
            DoubleConv(512, 512),
            nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),
            DoubleConv(256, 256),
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            DoubleConv(128, 128),
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            DoubleConv(64, 64),
            nn.Conv2d(64, num_class, kernel_size=1)
        )

    def forward(self, x):
        return self.up_sample(self.down_sample(x))


# quick sanity check
# net = MetaNet(3, 23)
# x = torch.rand([2, 3, 1024, 1024])
# y = net(x)
# print(y.shape)  # torch.Size([2, 23, 1024, 1024])
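Because the encoder stacks four MaxPool2d layers and the decoder four stride-2 transposed convolutions, the spatial size is divided by 16 and then restored, so the input height and width have to be multiples of 16; this is exactly why dataset.py below crops its images. A quick shape check, assuming a 1072x1920 input like the cropped frames used later:

net = MetaNet(3, 23)
x = torch.rand(1, 3, 1072, 1920)
y = net(x)
print(y.shape)  # torch.Size([1, 23, 1072, 1920])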
# dataset.py
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import numpy as np
from PIL import Image
import os
import os.path as osp


class MetaDataset(Dataset):
    def __init__(self, image_dir, mask_dir):
        super(MetaDataset, self).__init__()
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.ids = [osp.splitext(filename)[0] for filename in os.listdir(image_dir)]

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        name = self.ids[idx]
        img = Image.open(osp.join(self.image_dir, name + '.jpg'))
        mask = Image.open(osp.join(self.mask_dir, name + '_mask.png'))
        assert img.size == mask.size, 'img.size != mask.size'
        transform = transforms.ToTensor()
        img_copy = transform(img.copy())                     # [3, H, W], e.g. [3, 1080, 1920]
        mask_copy = torch.from_numpy(np.array(mask.copy()))  # [H, W], e.g. [1080, 1920]
        # PIL's size is (width, height), e.g. (1920, 1080)
        w, h = img.size
        # crop H and W down to multiples of 16 so they survive the four 2x down-samplings
        h_crop, w_crop = h % 16, w % 16
        img_crop = img_copy[:, h_crop//2: h - (h_crop - h_crop//2), w_crop//2: w - (w_crop - w_crop//2)]
        mask_crop = mask_copy[h_crop//2: h - (h_crop - h_crop//2), w_crop//2: w - (w_crop - w_crop//2)]
        return {
            'image': img_crop,
            'mask': mask_crop,
            'name': name
        }


# quick sanity check
# dataset = MetaDataset('./data/train/image', './data/train/mask')
# for i in range(len(dataset)):
#     image = dataset[i]['image']
#     mask = dataset[i]['mask']
#     print(image.shape)  # e.g. torch.Size([3, 1072, 1920])
#     print(mask.shape)   # e.g. torch.Size([1072, 1920])
#     break
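One assumption worth spelling out: the mask PNGs are expected to store an integer class id (0 to num_class - 1) per pixel rather than RGB colours, because np.array(mask) is used directly as the CrossEntropyLoss target in train.py. A minimal way to check that assumption (the file name here is just a placeholder):

import numpy as np
from PIL import Image

m = np.array(Image.open('./data/train/mask/0001_mask.png'))  # placeholder file name
print(m.dtype, m.shape, m.min(), m.max())  # expect integer values in [0, num_class - 1]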
# train.py
import os
import os.path as osp

import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from net import MetaNet
from dataset import MetaDataset

if __name__ == '__main__':
    # model
    model = MetaNet(3, 23)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # dataset
    train_batch_size = 2
    val_batch_size = 1
    image_dir = osp.join('data', 'train', 'image')
    mask_dir = osp.join('data', 'train', 'mask')
    dataset = MetaDataset(image_dir, mask_dir)
    n_val = 31
    n_train = len(dataset) - n_val
    train_set, val_set = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train_set, batch_size=train_batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_set, batch_size=val_batch_size, shuffle=True, drop_last=True)

    # optimizer
    learning_rate = 1e-4
    weight_decay = 1e-8
    momentum = 0.999
    optimizer = optim.RMSprop(model.parameters(), lr=learning_rate,
                              weight_decay=weight_decay, momentum=momentum)
    criterion = nn.CrossEntropyLoss()

    # train
    n_epoch = 50
    checkpoint_dir = './checkpoints'
    os.makedirs(checkpoint_dir, exist_ok=True)
    best_val_loss = 1e9
    for epoch in range(n_epoch):
        model.train()
        train_loss = 0
        for batch in tqdm(train_loader):
            images, masks = batch['image'], batch['mask']
            # tensor.to() is not in-place, so the result must be assigned back
            images = images.to(device=device, dtype=torch.float32)
            masks = masks.to(device=device, dtype=torch.long)
            preds = model(images)           # [N, num_class, H, W]
            loss = criterion(preds, masks)  # masks: [N, H, W] with class ids
            train_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print('epoch{}: train_loss={}'.format(epoch, train_loss))

        # validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in tqdm(val_loader):
                images, masks = batch['image'], batch['mask']
                images = images.to(device=device, dtype=torch.float32)
                masks = masks.to(device=device, dtype=torch.long)
                preds = model(images)
                loss = criterion(preds, masks)
                val_loss += loss.item()
        print('epoch{}: valid_loss={}'.format(epoch, val_loss))

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            state_dict = model.state_dict()
            torch.save(state_dict, osp.join(checkpoint_dir, 'epoch_{}.pth'.format(epoch)))
            print('saving current best model to epoch_{}.pth'.format(epoch))
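train.py only saves state_dicts, so at inference time the weights have to be loaded back into a freshly constructed MetaNet. A minimal sketch, assuming one of the saved checkpoints (the checkpoint file name below is a placeholder):

import torch
from net import MetaNet
from dataset import MetaDataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MetaNet(3, 23)
model.load_state_dict(torch.load('./checkpoints/epoch_0.pth', map_location=device))  # placeholder name
model.to(device)
model.eval()

dataset = MetaDataset('./data/train/image', './data/train/mask')
image = dataset[0]['image']                          # [3, H, W]
with torch.no_grad():
    pred = model(image.unsqueeze(0).to(device))      # [1, 23, H, W]
    mask_pred = pred.argmax(dim=1).squeeze(0).cpu()  # [H, W] predicted class ids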