Code: https://github.com/bat67/pytorch-FCN-easiest-demo
Reference: 全卷积网络 FCN 详解 (a detailed FCN explainer)
FCN Explained, with a Simple PyTorch Implementation (Detailed Code Walkthrough)
FCN performs end-to-end image segmentation: it removes the fully connected layers of a classification CNN, upsamples with transposed convolutions (deconvolution) to produce pixel-level predictions, and adds skip connections to improve segmentation quality.
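As a quick illustration of the upsampling idea (a minimal sketch of my own, not code from the repo): a stride-2 transposed convolution doubles the spatial size of a feature map, which is how these networks climb back from coarse features to pixel-level predictions.

import torch
import torch.nn as nn

# With kernel_size=3, stride=2, padding=1, output_padding=1:
# H_out = (H_in - 1)*2 - 2*1 + 1*(3 - 1) + 1 + 1 = 2*H_in
deconv = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2,
                            padding=1, output_padding=1)
x = torch.randn(1, 512, 5, 5)  # e.g. the deepest VGG feature map of a 160x160 input
print(deconv(x).shape)         # torch.Size([1, 256, 10, 10])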
Project directory layout
Before training, activate the conda environment (named "tensorflow" here) and start the visdom visualization server:
activate tensorflow
python -m visdom.server
(Figures) Epoch 1: train result, train label, train loss curve.
(Figures) Epoch 99: train result, test result.
Folder overview
The main .py files are listed below. In the directory screenshot, position 1 is the folder of training images and position 2 is the path where model checkpoints are saved.
In the model names, FCN32s means the deepest feature map, at 1/32 of the input resolution, is upsampled 32x back to full size on its own.
FCN16s additionally fuses the 1/16-resolution feature map (one skip connection) before upsampling.
FCN8s fuses the 1/8-resolution feature map as well (two skip connections) before upsampling.
Segmentation quality: FCN8s > FCN16s > FCN32s, because the finer feature maps carried through the skips preserve more spatial detail. (A quick shape check follows below.)
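For a 160x160 input, VGG's five max-poolings leave a 5x5 feature map (1/32 of the input); five stride-2 deconvolutions then restore the full 160x160 resolution. A minimal shape check (my own sketch; the channel count 8 is arbitrary):

import torch
import torch.nn as nn

up = nn.Sequential(*[
    nn.ConvTranspose2d(8, 8, kernel_size=3, stride=2,
                       padding=1, output_padding=1)
    for _ in range(5)])  # 5 -> 10 -> 20 -> 40 -> 80 -> 160
print(up(torch.randn(1, 8, 5, 5)).shape)  # torch.Size([1, 8, 160, 160])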
FCN.py (note: train.py imports this module as FCN)
import torch
import torch.nn as nn
from torchvision import models
from torchvision.models.vgg import VGG


class FCN32s(nn.Module):
    def __init__(self, pretrained_net, n_class):
        super().__init__()
        self.n_class = n_class
        self.pretrained_net = pretrained_net
        self.relu = nn.ReLU(inplace=True)
        # five stride-2 deconvolutions: each doubles the spatial size (32x overall)
        self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn1 = nn.BatchNorm2d(512)
        self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn2 = nn.BatchNorm2d(256)
        self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn5 = nn.BatchNorm2d(32)
        self.classifier = nn.Conv2d(32, n_class, kernel_size=1)

    def forward(self, x):
        output = self.pretrained_net(x)
        x5 = output['x5']  # 1/32 feature map from the deepest pooling stage
        score = self.bn1(self.relu(self.deconv1(x5)))
        score = self.bn2(self.relu(self.deconv2(score)))
        score = self.bn3(self.relu(self.deconv3(score)))
        score = self.bn4(self.relu(self.deconv4(score)))
        score = self.bn5(self.relu(self.deconv5(score)))
        score = self.classifier(score)
        return score
class FCN16s(nn.Module):
    def __init__(self, pretrained_net, n_class):
        super().__init__()
        self.n_class = n_class
        self.pretrained_net = pretrained_net
        self.relu = nn.ReLU(inplace=True)
        self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn1 = nn.BatchNorm2d(512)
        self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn2 = nn.BatchNorm2d(256)
        self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn5 = nn.BatchNorm2d(32)
        self.classifier = nn.Conv2d(32, n_class, kernel_size=1)

    def forward(self, x):
        output = self.pretrained_net(x)
        x5 = output['x5']  # 1/32 feature map
        x4 = output['x4']  # 1/16 feature map
        score = self.relu(self.deconv1(x5))
        score = self.bn1(score + x4)  # skip connection: fuse the 1/16 features
        score = self.bn2(self.relu(self.deconv2(score)))
        score = self.bn3(self.relu(self.deconv3(score)))
        score = self.bn4(self.relu(self.deconv4(score)))
        score = self.bn5(self.relu(self.deconv5(score)))
        score = self.classifier(score)
        return score
class FCN8s(nn.Module):
    def __init__(self, pretrained_net, n_class):
        super().__init__()
        self.n_class = n_class
        self.pretrained_net = pretrained_net
        self.relu = nn.ReLU(inplace=True)
        self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn1 = nn.BatchNorm2d(512)
        self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn2 = nn.BatchNorm2d(256)
        self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn5 = nn.BatchNorm2d(32)
        self.classifier = nn.Conv2d(32, n_class, kernel_size=1)

    def forward(self, x):
        output = self.pretrained_net(x)
        x5 = output['x5']  # 1/32 feature map
        x4 = output['x4']  # 1/16 feature map
        x3 = output['x3']  # 1/8 feature map
        score = self.relu(self.deconv1(x5))
        score = self.bn1(score + x4)  # first skip connection (1/16)
        score = self.relu(self.deconv2(score))
        score = self.bn2(score + x3)  # second skip connection (1/8)
        score = self.bn3(self.relu(self.deconv3(score)))
        score = self.bn4(self.relu(self.deconv4(score)))
        score = self.bn5(self.relu(self.deconv5(score)))
        score = self.classifier(score)
        return score
class FCNs(nn.Module):
    def __init__(self, pretrained_net, n_class):
        super().__init__()
        # basic settings
        self.n_class = n_class
        self.pretrained_net = pretrained_net
        self.relu = nn.ReLU(inplace=True)
        self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn1 = nn.BatchNorm2d(512)
        self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn2 = nn.BatchNorm2d(256)
        self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn5 = nn.BatchNorm2d(32)
        # classifier is a 1x1 conv that reduces the channels from 32 to n_class
        self.classifier = nn.Conv2d(32, n_class, kernel_size=1)

    def forward(self, x):
        output = self.pretrained_net(x)
        x5 = output['x5']  # 1/32 feature map
        x4 = output['x4']  # 1/16 feature map
        x3 = output['x3']  # 1/8 feature map
        x2 = output['x2']  # 1/4 feature map
        x1 = output['x1']  # 1/2 feature map
        score = self.bn1(self.relu(self.deconv1(x5)))
        score = score + x4  # skip connection at 1/16
        score = self.bn2(self.relu(self.deconv2(score)))
        score = score + x3  # skip connection at 1/8
        score = self.bn3(self.relu(self.deconv3(score)))
        score = score + x2  # skip connection at 1/4
        score = self.bn4(self.relu(self.deconv4(score)))
        score = score + x1  # skip connection at 1/2
        score = self.bn5(self.relu(self.deconv5(score)))
        score = self.classifier(score)
        return score
class VGGNet(VGG):
    def __init__(self, pretrained=True, model='vgg16', requires_grad=True, remove_fc=True, show_params=False):
        super().__init__(make_layers(cfg[model]))
        self.ranges = ranges[model]
        if pretrained:
            # load torchvision's pretrained weights for the chosen VGG variant
            # (getattr avoids the exec() call used in the original)
            self.load_state_dict(getattr(models, model)(pretrained=True).state_dict())
        if not requires_grad:
            for param in super().parameters():
                param.requires_grad = False
        # delete the redundant fully-connected (classifier) layers to save memory
        if remove_fc:
            del self.classifier
        if show_params:
            for name, param in self.named_parameters():
                print(name, param.size())

    def forward(self, x):
        output = {}
        # collect the output of each max-pooling stage (VGG has 5 of them)
        for idx, (begin, end) in enumerate(self.ranges):
            # e.g. self.ranges = ((0, 5), (5, 10), (10, 17), (17, 24), (24, 31)) for vgg16
            for layer in range(begin, end):
                x = self.features[layer](x)
            output["x%d" % (idx + 1)] = x
        return output
# dictionaries (key-value maps): layer index ranges of each pooling stage per VGG variant
ranges = {
    'vgg11': ((0, 3), (3, 6), (6, 11), (11, 16), (16, 21)),
    'vgg13': ((0, 5), (5, 10), (10, 15), (15, 20), (20, 25)),
    'vgg16': ((0, 5), (5, 10), (10, 17), (17, 24), (24, 31)),
    'vgg19': ((0, 5), (5, 10), (10, 19), (19, 28), (28, 37))
}

# VGG-Net configuration: numbers are conv output channels, 'M' is a max-pooling layer
cfg = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

# build the VGG feature extractor from a cfg entry
def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)
'''
VGG-16 feature layers (the printed self.features Sequential):
Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU(inplace)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU(inplace)
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU(inplace)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU(inplace)
(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(18): ReLU(inplace)
(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU(inplace)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU(inplace)
(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(25): ReLU(inplace)
(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(27): ReLU(inplace)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU(inplace)
(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
'''
if __name__ == "__main__":
    pass
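A quick smoke test for the models defined above (my own sketch; it assumes the file is saved as FCN.py, matching the import in train.py, and passes pretrained=False so no weights are downloaded):

import torch
from FCN import VGGNet, FCNs

vgg_model = VGGNet(pretrained=False, requires_grad=True)  # VGG16 backbone, FC layers removed
fcn_model = FCNs(pretrained_net=vgg_model, n_class=2)

x = torch.randn(1, 3, 160, 160)  # same size as the bag images
y = fcn_model(x)
print(y.shape)  # torch.Size([1, 2, 160, 160]): one score map per class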
train.py
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import visdom

from BagData import test_dataloader, train_dataloader
from FCN import FCN8s, FCN16s, FCN32s, FCNs, VGGNet
def train(epo_num=50, show_vgg_params=False):
    vis = visdom.Visdom()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    vgg_model = VGGNet(requires_grad=True, show_params=show_vgg_params)
    fcn_model = FCNs(pretrained_net=vgg_model, n_class=2)
    fcn_model = fcn_model.to(device)
    # binary cross-entropy on sigmoid outputs, against the one-hot masks
    criterion = nn.BCELoss().to(device)
    optimizer = optim.SGD(fcn_model.parameters(), lr=1e-2, momentum=0.7)

    all_train_iter_loss = []
    all_test_iter_loss = []

    # start timing
    prev_time = datetime.now()
    for epo in range(epo_num):
        train_loss = 0
        fcn_model.train()
        for index, (bag, bag_msk) in enumerate(train_dataloader):
            # bag.shape is torch.Size([4, 3, 160, 160])
            # bag_msk.shape is torch.Size([4, 2, 160, 160])
            bag = bag.to(device)
            bag_msk = bag_msk.to(device)

            optimizer.zero_grad()
            output = fcn_model(bag)
            output = torch.sigmoid(output)  # output.shape is torch.Size([4, 2, 160, 160])
            loss = criterion(output, bag_msk)
            loss.backward()
            iter_loss = loss.item()
            all_train_iter_loss.append(iter_loss)
            train_loss += iter_loss
            optimizer.step()

            output_np = output.cpu().detach().numpy().copy()  # (4, 2, 160, 160)
            # argmax recovers the per-pixel class (the masks are proper one-hot
            # after the onehot.py fix below; the original used argmin to compensate
            # for inverted labels)
            output_np = np.argmax(output_np, axis=1)
            bag_msk_np = bag_msk.cpu().detach().numpy().copy()  # (4, 2, 160, 160)
            bag_msk_np = np.argmax(bag_msk_np, axis=1)

            # print and visualize every 15 iterations
            if np.mod(index, 15) == 0:
                print('epoch {}, {}/{}, train loss is {}'.format(epo, index, len(train_dataloader), iter_loss))
                vis.images(output_np[:, None, :, :], win='train_pred', opts=dict(title='train prediction'))
                vis.images(bag_msk_np[:, None, :, :], win='train_label', opts=dict(title='label'))
                vis.line(all_train_iter_loss, win='train_iter_loss', opts=dict(title='train iter loss'))
        test_loss = 0
        fcn_model.eval()
        with torch.no_grad():
            for index, (bag, bag_msk) in enumerate(test_dataloader):
                bag = bag.to(device)
                bag_msk = bag_msk.to(device)

                # no optimizer.zero_grad() here: no gradient step happens during evaluation
                output = fcn_model(bag)
                output = torch.sigmoid(output)  # output.shape is torch.Size([4, 2, 160, 160])
                loss = criterion(output, bag_msk)
                iter_loss = loss.item()
                all_test_iter_loss.append(iter_loss)
                test_loss += iter_loss

                output_np = output.cpu().detach().numpy().copy()  # (4, 2, 160, 160)
                output_np = np.argmax(output_np, axis=1)
                bag_msk_np = bag_msk.cpu().detach().numpy().copy()  # (4, 2, 160, 160)
                bag_msk_np = np.argmax(bag_msk_np, axis=1)

                if np.mod(index, 15) == 0:
                    print(r'Testing... Open http://localhost:8097/ to see test result.')
                    vis.images(output_np[:, None, :, :], win='test_pred', opts=dict(title='test prediction'))
                    vis.images(bag_msk_np[:, None, :, :], win='test_label', opts=dict(title='label'))
                    vis.line(all_test_iter_loss, win='test_iter_loss', opts=dict(title='test iter loss'))
        cur_time = datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        prev_time = cur_time

        print('epoch train loss = %f, epoch test loss = %f, %s'
              % (train_loss / len(train_dataloader), test_loss / len(test_dataloader), time_str))

        # save a checkpoint every 5 epochs (the whole model object, not just the state_dict)
        if np.mod(epo, 5) == 0:
            torch.save(fcn_model, 'checkpoints/fcn_model_{}.pt'.format(epo))
            print('saving checkpoints/fcn_model_{}.pt'.format(epo))


if __name__ == "__main__":
    train(epo_num=100, show_vgg_params=False)
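Since train.py saves the whole model object with torch.save(fcn_model, ...), inference only needs torch.load plus the same preprocessing. A hedged sketch (the checkpoint name and test image path are assumptions; the FCN module must be importable so the pickled classes can be restored):

import cv2
import numpy as np
import torch
from torchvision import transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# on newer PyTorch (>= 2.6) you may need torch.load(..., weights_only=False)
model = torch.load('checkpoints/fcn_model_95.pt', map_location=device)  # hypothetical checkpoint
model.eval()

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

img = cv2.resize(cv2.imread('bag_data/0.jpg'), (160, 160))  # hypothetical test image
x = transform(img).unsqueeze(0).to(device)

with torch.no_grad():
    pred = torch.sigmoid(model(x))                   # (1, 2, 160, 160)
mask = pred.argmax(dim=1).squeeze(0).cpu().numpy()   # per-pixel class index
cv2.imwrite('pred_mask.png', (mask * 255).astype(np.uint8))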
onehot.py
import numpy as np

def onehot(data, n):
    # data: integer label map; returns an array of shape data.shape + (n,)
    # with a 1 at each pixel's label index
    buf = np.zeros(data.shape + (n, ))
    # ravel() flattens the array; nmsk is the flat index of each (pixel, label) pair
    nmsk = np.arange(data.size) * n + data.ravel()
    buf.ravel()[nmsk] = 1  # fixed off-by-one: the original indexed nmsk-1, which shifts and inverts labels
    return buf
'''
One-hot ("one-of-N") encoding is a common way to encode categorical features:
N states are represented with N register bits, each state gets its own bit,
and exactly one bit is set at a time.
Reference:
https://blog.csdn.net/Dorothy_Xue/article/details/84641417?utm_medium=distribute.pc_relevant.none-task-blog-2%7Edefault%7EBlogCommendFromMachineLearnPai2%7Edefault-4.control&dist_request_id=&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2%7Edefault%7EBlogCommendFromMachineLearnPai2%7Edefault-4.control
'''
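A quick check of onehot on a tiny label map (the expected output after the off-by-one fix above):

import numpy as np
from onehot import onehot

mask = np.array([[0, 1],
                 [1, 0]])
encoded = onehot(mask, 2)  # shape (2, 2, 2); the last axis is the class
print(encoded[0, 0])       # [1. 0.] -> class 0
print(encoded[0, 1])       # [0. 1.] -> class 1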
BagData.py
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File    : BagData.py
@Time    : 2021/04/18 11:32:40
@Author  : Jian Song
@Contact : [email protected]
@Desc    : None
'''

# here put the import lib
import os
import torch
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
import cv2
from onehot import onehot

# to-tensor conversion plus per-channel normalization (ImageNet mean/std)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
class BagDataset(Dataset):
    def __init__(self, transform=None):
        self.transform = transform

    def __len__(self):
        return len(os.listdir('bag_data'))

    def __getitem__(self, idx):
        img_name = os.listdir('bag_data')[idx]
        # load the input image
        imgA = cv2.imread('bag_data/' + img_name)
        imgA = cv2.resize(imgA, (160, 160))
        # load the segmentation mask (grayscale) and binarize to 0/1
        imgB = cv2.imread('bag_data_msk/' + img_name, 0)
        imgB = cv2.resize(imgB, (160, 160))
        imgB = imgB / 255
        imgB = imgB.astype('uint8')
        # one-hot encode to 2 channels, then move channels first:
        # (160, 160, 2) -> (2, 160, 160)
        imgB = onehot(imgB, 2)
        imgB = imgB.transpose(2, 0, 1)
        # convert from uint8 ndarray to a float tensor
        imgB = torch.FloatTensor(imgB)
        # apply the preprocessing transform to the input image, if given
        if self.transform:
            imgA = self.transform(imgA)
        return imgA, imgB
bag = BagDataset(transform)

# 90% of the data for training, 10% for testing
train_size = int(0.9 * len(bag))
test_size = len(bag) - train_size
train_dataset, test_dataset = random_split(bag, [train_size, test_size])
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=True, num_workers=4)

if __name__ == '__main__':
    for train_batch in train_dataloader:
        print(train_batch)
    for test_batch in test_dataloader:
        print(test_batch)
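A sanity check on one batch from the loaders above (a sketch; run it from the repo root so bag_data/ and bag_data_msk/ exist, and keep it under a __main__ guard because the loaders use num_workers=4):

from BagData import train_dataloader

if __name__ == '__main__':
    imgs, masks = next(iter(train_dataloader))
    print(imgs.shape)   # torch.Size([4, 3, 160, 160]): normalized RGB images
    print(masks.shape)  # torch.Size([4, 2, 160, 160]): one-hot masks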