Writing an FCN in PyTorch for semantic segmentation of handbags.

Dataset download:
training data (https://github.com/yunlongdong/FCN-pytorch-easiest/tree/master/last), to be placed in the bag_data folder
ground-truth labels (https://github.com/yunlongdong/FCN-pytorch-easiest/tree/master/last_msk), to be placed in the bag_data_msk folder
Project directory structure:

Training data:

Training labels:
As this handbag dataset shows, the task is a binary one: we only need to separate two classes, handbag and background. In the labels, black marks the handbag and white marks the irrelevant background.
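Since the labels are ordinary image files, it is worth confirming that they really contain only two gray levels before relying on the /255 binarization used in the dataset code below. A minimal check; the file name is a placeholder for any mask in your bag_data_msk folder:

```python
import cv2
import numpy as np

# Read one ground-truth mask as grayscale (placeholder file name).
mask = cv2.imread('bag_data_msk/0.jpg', 0)
print(np.unique(mask))  # ideally only two values, e.g. [0 255]
```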
### BagData.py
```python
import os

import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

def onehot(data, n):
    buf = np.zeros(data.shape + (n, ))
    nmsk = np.arange(data.size) * n + data.ravel()
    buf.ravel()[nmsk] = 1
    return buf

class BagDataset(Dataset):

    def __init__(self, transform=None):
        self.transform = transform

    def __len__(self):
        return len(os.listdir('bag_data'))

    def __getitem__(self, idx):
        # sort so that the idx-th image and the idx-th mask always match up
        img_name = sorted(os.listdir('bag_data'))[idx]
        imgA = cv2.imread('bag_data/' + img_name)
        imgA = cv2.resize(imgA, (160, 160))
        imgB = cv2.imread('bag_data_msk/' + img_name, 0)
        imgB = cv2.resize(imgB, (160, 160))
        imgB = imgB / 255
        imgB = imgB.astype('uint8')
        imgB = onehot(imgB, 2)  # binary task: only handbag vs. background, so n = 2
        imgB = imgB.transpose(2, 0, 1)  # imgB skips transform, so convert (H,W,C) to (C,H,W) by hand
        imgB = torch.FloatTensor(imgB)
        if self.transform:
            imgA = self.transform(imgA)  # ToTensor already yields (C,H,W) for imgA
        return imgA, imgB

bag = BagDataset(transform)

train_size = int(0.9 * len(bag))  # 90% of the data goes to the training set
test_size = len(bag) - train_size
train_dataset, test_dataset = random_split(bag, [train_size, test_size])  # train/test split

train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=True, num_workers=4)

if __name__ == '__main__':
    for train_batch in train_dataloader:
        print(train_batch)
    for test_batch in test_dataloader:
        print(test_batch)
```
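Before moving on, it helps to check the tensor shapes batch by batch rather than printing whole tensors. A minimal sketch, assuming bag_data and bag_data_msk are already populated:

```python
# Quick shape sanity check for the dataloaders defined above.
from BagData import train_dataloader

imgs, masks = next(iter(train_dataloader))
print(imgs.shape)   # expected: torch.Size([4, 3, 160, 160])
print(masks.shape)  # expected: torch.Size([4, 2, 160, 160])
```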
With the dataset-loading code in place, the onehot function deserves a closer look.

1. One-hot encoding of the dataset labels:

One-hot encoding turns each pixel's class index into a vector with a single 1 at the position of that class. The function:
```python
def onehot(data, n):
    buf = np.zeros(data.shape + (n, ))
    nmsk = np.arange(data.size) * n + data.ravel()
    buf.ravel()[nmsk] = 1
    return buf
```
The input data is the label read as a grayscale image, and n is the number of segmentation classes (2 for this dataset).

buf = np.zeros(data.shape + (n, ))  # if data has shape (a, b), this builds an all-zero array of shape (a, b, n)

Since n is 2, think of it as two stacked (a, b) planes: one plane for the handbag class and one for the background.

nmsk = np.arange(data.size)*n + data.ravel()

This line is the clever one. Suppose data.size is 5; then arange(5) is (0, 1, 2, 3, 4), which simply indexes the elements of data. arange(5)*2 is (0, 2, 4, 6, 8): the flat position of channel 0 of each pixel inside the flattened (a, b, 2) buffer. Because data is a label and has been binarized, its values are all 0 or 1; data.ravel() flattens it to one dimension, so arange(5)*2 + data.ravel() adds 1 for handbag pixels and 0 for background pixels. For example, if the third and fifth pixels are handbag, the result is (0, 2, 5, 6, 9): those two pixels point at channel 1, the rest at channel 0.

buf.ravel()[nmsk] = 1

Writing a 1 at each of those flat positions sets exactly one channel per pixel. Using the example again, positions (0, 2, 5, 6, 9) of the flattened buffer become 1: the (a, b) label has been projected into a buffer of twice that length, i.e. a one-hot label of shape (a, b, 2). The reason for doing this: the dataset is two-class, so the network output will have shape (a, b, 2), and the label must have the same form as the network output before the two can be compared to compute the loss.
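A tiny worked example makes the indexing concrete; the 2x2 label below is made up for illustration:

```python
import numpy as np

def onehot(data, n):
    buf = np.zeros(data.shape + (n, ))
    nmsk = np.arange(data.size) * n + data.ravel()
    buf.ravel()[nmsk] = 1
    return buf

label = np.array([[0, 1],
                  [1, 0]], dtype='uint8')  # made-up 2x2 binary mask
encoded = onehot(label, 2)                 # shape (2, 2, 2)
print(encoded[..., 0])          # class-0 plane: [[1. 0.] [0. 1.]]
print(encoded[..., 1])          # class-1 plane: [[0. 1.] [1. 0.]]
print(encoded.argmax(axis=-1))  # recovers the original label
```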
### FCN.py
```python
import torch
import torch.nn as nn
from torchvision import models
from torchvision.models.vgg import VGG

class FCNs(nn.Module):

    def __init__(self, pretrained_net, n_class):
        super().__init__()
        self.n_class = n_class
        self.pretrained_net = pretrained_net
        self.relu = nn.ReLU(inplace=True)
        self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn1 = nn.BatchNorm2d(512)
        self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn2 = nn.BatchNorm2d(256)
        self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn5 = nn.BatchNorm2d(32)
        self.classifier = nn.Conv2d(32, n_class, kernel_size=1)
        # classifier is a 1x1 conv that reduces channels from 32 to n_class

    def forward(self, x):
        output = self.pretrained_net(x)
        x5 = output['x5']  # feature map after the 5th maxpool, 1/32 resolution
        x4 = output['x4']  # 1/16 resolution
        x3 = output['x3']  # 1/8 resolution
        x2 = output['x2']  # 1/4 resolution
        x1 = output['x1']  # 1/2 resolution

        score = self.bn1(self.relu(self.deconv1(x5)))     # upsample to 1/16
        score = score + x4                                # skip connection from x4
        score = self.bn2(self.relu(self.deconv2(score)))  # upsample to 1/8
        score = score + x3                                # skip connection from x3
        score = self.bn3(self.relu(self.deconv3(score)))  # upsample to 1/4
        score = score + x2                                # skip connection from x2
        score = self.bn4(self.relu(self.deconv4(score)))  # upsample to 1/2
        score = score + x1                                # skip connection from x1
        score = self.bn5(self.relu(self.deconv5(score)))  # back to full resolution
        score = self.classifier(score)                    # (N, n_class, H, W)
        return score

class VGGNet(VGG):

    def __init__(self, pretrained=True, model='vgg16', requires_grad=True, remove_fc=True, show_params=False):
        super().__init__(make_layers(cfg[model]))
        self.ranges = ranges[model]

        if pretrained:
            self.load_state_dict(getattr(models, model)(pretrained=True).state_dict())

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

        # delete the redundant fully-connected layers (the classifier) to save memory
        if remove_fc:
            del self.classifier

        if show_params:
            for name, param in self.named_parameters():
                print(name, param.size())

    def forward(self, x):
        output = {}
        # collect the output of each maxpooling stage (VGG has 5 of them)
        for idx, (begin, end) in enumerate(self.ranges):
            # self.ranges = ((0, 5), (5, 10), (10, 17), (17, 24), (24, 31)) for vgg16
            for layer in range(begin, end):
                x = self.features[layer](x)
            output["x%d" % (idx + 1)] = x
        return output

ranges = {
    'vgg11': ((0, 3), (3, 6), (6, 11), (11, 16), (16, 21)),
    'vgg13': ((0, 5), (5, 10), (10, 15), (15, 20), (20, 25)),
    'vgg16': ((0, 5), (5, 10), (10, 17), (17, 24), (24, 31)),
    'vgg19': ((0, 5), (5, 10), (10, 19), (19, 28), (28, 37))
}

# VGG network configurations
cfg = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

# build the VGG feature extractor from a config list
def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)

'''
VGG-16 layer listing (the self.features module):
Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace)
  (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (18): ReLU(inplace)
  (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (20): ReLU(inplace)
  (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (22): ReLU(inplace)
  (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (25): ReLU(inplace)
  (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (27): ReLU(inplace)
  (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (29): ReLU(inplace)
  (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
'''

if __name__ == "__main__":
    pass
```
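As a sanity check on the architecture: the encoder halves the resolution five times (160 -> 5), and the five transposed convolutions bring it back up, so a 160x160 input should come out as an (N, 2, 160, 160) score map. A minimal sketch, assuming FCN.py is importable and using randomly initialized weights:

```python
import torch
from FCN import FCNs, VGGNet

vgg_model = VGGNet(pretrained=False, requires_grad=True)  # random weights, just for the shape check
fcn_model = FCNs(pretrained_net=vgg_model, n_class=2)

x = torch.randn(1, 3, 160, 160)  # dummy input batch
with torch.no_grad():
    score = fcn_model(x)
print(score.shape)  # expected: torch.Size([1, 2, 160, 160])
```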
2. The training code:
### train.py
```python
import os
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from BagData import test_dataloader, train_dataloader
from FCN import FCNs, VGGNet

def train(epo_num=50, show_vgg_params=False):

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    vgg_model = VGGNet(requires_grad=True, show_params=show_vgg_params)
    fcn_model = FCNs(pretrained_net=vgg_model, n_class=2)
    fcn_model = fcn_model.to(device)
    criterion = nn.BCELoss().to(device)
    optimizer = optim.SGD(fcn_model.parameters(), lr=1e-2, momentum=0.7)

    all_train_iter_loss = []
    all_test_iter_loss = []

    os.makedirs('checkpoints', exist_ok=True)

    # start timing
    prev_time = datetime.now()

    for epo in range(epo_num):

        train_loss = 0
        fcn_model.train()
        for index, (bag, bag_msk) in enumerate(train_dataloader):
            # bag.shape is torch.Size([4, 3, 160, 160])
            # bag_msk.shape is torch.Size([4, 2, 160, 160])
            bag = bag.to(device)
            bag_msk = bag_msk.to(device)

            optimizer.zero_grad()
            output = fcn_model(bag)
            output = torch.sigmoid(output)  # output.shape is torch.Size([4, 2, 160, 160])
            loss = criterion(output, bag_msk)
            loss.backward()
            iter_loss = loss.item()
            all_train_iter_loss.append(iter_loss)
            train_loss += iter_loss
            optimizer.step()

            # decode predictions and labels back to (N, 160, 160) class maps
            output_np = output.cpu().detach().numpy().copy()  # output_np.shape = (4, 2, 160, 160)
            output_np = np.argmax(output_np, axis=1)
            bag_msk_np = bag_msk.cpu().detach().numpy().copy()  # bag_msk_np.shape = (4, 2, 160, 160)
            bag_msk_np = np.argmax(bag_msk_np, axis=1)

        test_loss = 0
        fcn_model.eval()
        with torch.no_grad():
            for index, (bag, bag_msk) in enumerate(test_dataloader):
                bag = bag.to(device)
                bag_msk = bag_msk.to(device)

                output = fcn_model(bag)
                output = torch.sigmoid(output)  # output.shape is torch.Size([4, 2, 160, 160])
                loss = criterion(output, bag_msk)
                iter_loss = loss.item()
                all_test_iter_loss.append(iter_loss)
                test_loss += iter_loss

                output_np = output.cpu().detach().numpy().copy()  # output_np.shape = (4, 2, 160, 160)
                output_np = np.argmax(output_np, axis=1)
                bag_msk_np = bag_msk.cpu().detach().numpy().copy()  # bag_msk_np.shape = (4, 2, 160, 160)
                bag_msk_np = np.argmax(bag_msk_np, axis=1)

        cur_time = datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        prev_time = cur_time

        print('epoch train loss = %f, epoch test loss = %f, %s'
              % (train_loss / len(train_dataloader), test_loss / len(test_dataloader), time_str))

        if np.mod(epo, 5) == 0:
            torch.save(fcn_model, 'checkpoints/fcn_model_{}.pt'.format(epo))
            print('saving checkpoints/fcn_model_{}.pt'.format(epo))

if __name__ == "__main__":
    train(epo_num=100, show_vgg_params=False)
```
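The loop above only tracks the BCE loss; for segmentation, pixel accuracy is often more interpretable. A minimal sketch of how it could be computed from the decoded class maps (pixel_acc is a hypothetical helper, not part of the original code):

```python
import numpy as np

def pixel_acc(pred, target):
    # pred, target: integer class maps of shape (N, H, W)
    return (pred == target).mean()

# usage inside either loop, after the argmax decoding:
# acc = pixel_acc(output_np, bag_msk_np)
# print('batch pixel accuracy = %f' % acc)
```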
### t.py
```python
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms

from FCN import FCNs, VGGNet  # needed so torch.load can unpickle the saved model class

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.load('checkpoints/fcn_model_95.pt', map_location=device)  # load the trained model
model = model.to(device)
model.eval()  # switch BatchNorm layers to inference mode

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

if __name__ == '__main__':
    img_name = r'bag3.jpg'  # image to predict
    imgA = cv2.imread(img_name)
    imgA = cv2.resize(imgA, (160, 160))

    imgA = transform(imgA)
    imgA = imgA.to(device)
    imgA = imgA.unsqueeze(0)
    with torch.no_grad():
        output = model(imgA)
        output = torch.sigmoid(output)
    output_np = output.cpu().detach().numpy().copy()
    print(output_np.shape)  # (1, 2, 160, 160)
    output_np = np.argmax(output_np, axis=1)
    print(output_np.shape)  # (1, 160, 160)

    plt.subplot(1, 2, 1)
    #plt.imshow(np.squeeze(bag_msk_np[0, ...]), 'gray')
    #plt.subplot(1, 2, 2)
    plt.imshow(np.squeeze(output_np[0, ...]), 'gray')
    plt.pause(3)
```
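If you would rather save the predicted mask than pop up a matplotlib window, the class map can be written out directly. A small sketch that continues from the end of t.py; the output file name is arbitrary:

```python
# Scale the {0, 1} class map to {0, 255} and save it as an image.
pred_mask = (np.squeeze(output_np[0, ...]) * 255).astype('uint8')
cv2.imwrite('bag3_pred.jpg', pred_mask)
```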
Input image:

Output:
Project code download: https://download.csdn.net/download/u014453898/11244794
To run everything: run train.py first to train and save the model, then run t.py to make predictions.