The official Segmentation Transformer (SETR) source code is built on the MMSegmentation framework, which makes it hard to read and learn from; if you want to use the official version, this post won't be of much help.
Here I use a PyTorch reimplementation of Segmentation Transformer that someone published on GitHub.
Segmentation Transformer source download: https://github.com/gupta-abhay/setr-pytorch . This repo ships without a training script, so I plugged the SETR network into my DeepLabV3 training code in place of the DeepLabV3 model. For how that DeepLabV3 code is set up and modified, see this post: https://blog.csdn.net/qq_41964545/article/details/115252939?spm=1001.2014.3001.5501
Now let's adapt the source code.
After downloading the source, extract it into the DeeplabV3 folder.
Prepare the labels of the CityScapes dataset as described in the previous DeepLabV3 post.
SETR comes with three decoder designs; here I use the simplest one, the Naive decoder, in the form of the SETR_Naive_S model. Reading the source shows that the CityScapes training images are expected to be 768*768, so the first change to make is setting the number of classes to 20.
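A quick sanity check on this configuration (a minimal sketch; it assumes SETR_Naive_S returns an (aux, model) tuple, as it is used in the training script below, and that num_classes has already been edited to 20 in setr/SETR.py):

import torch
from setr.SETR import SETR_Naive_S

_, network = SETR_Naive_S()  # second element is the model, as used in train.py below
x = torch.randn(1, 3, 768, 768)  # CityScapes training images are 768*768
with torch.no_grad():
    y = network(x)
print(y.shape)  # expected: torch.Size([1, 20, 768, 768]) after setting num_classes = 20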
Next comes datasets.py.
First, modify the DatasetTrain class.
For data augmentation I kept only the random flip; I commented out the "randomly scale the img and the label" and "random crop from the img and label" parts. You can adjust these to your own needs, as long as every returned image is 768*768 (a small guard for this is sketched below).
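If you re-enable the scaling or cropping code, a cheap guard at the end of __getitem__ (while img is still H x W x C, before the transpose) catches size mistakes early. This is a sketch, not part of the original code:

assert img.shape[:2] == (768, 768), "img must be 768x768 for SETR_Naive_S"
assert label_img.shape == (768, 768), "label_img must match the img size"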
The same changes apply to the DatasetVal class,
and likewise to the DatasetSeq class.
The full datasets.py:
# camera-ready
import torch
import torch.utils.data
import numpy as np
import cv2
import os
train_dirs = ["jena/", "zurich/", "weimar/", "ulm/", "tubingen/", "stuttgart/",
              "strasbourg/", "monchengladbach/", "krefeld/", "hanover/",
              "hamburg/", "erfurt/", "dusseldorf/", "darmstadt/", "cologne/",
              "bremen/", "bochum/", "aachen/"]
val_dirs = ["frankfurt/", "munster/", "lindau/"]
test_dirs = ["berlin", "bielefeld", "bonn", "leverkusen", "mainz", "munich"]
class DatasetTrain(torch.utils.data.Dataset):
    def __init__(self, cityscapes_data_path, cityscapes_meta_path):
        self.img_dir = cityscapes_data_path + "/leftImg8bit/train/"
        self.label_dir = cityscapes_meta_path + "/label_imgs/"

        self.img_h = 1024
        self.img_w = 2048

        # self.new_img_h = 512
        # self.new_img_w = 1024
        self.new_img_h = 768
        self.new_img_w = 768

        self.examples = []
        for train_dir in train_dirs:
            train_img_dir_path = self.img_dir + train_dir

            file_names = os.listdir(train_img_dir_path)
            for file_name in file_names:
                img_id = file_name.split("_leftImg8bit.png")[0]

                img_path = train_img_dir_path + file_name
                label_img_path = self.label_dir + img_id + ".png"

                example = {}
                example["img_path"] = img_path
                example["label_img_path"] = label_img_path
                example["img_id"] = img_id
                self.examples.append(example)

        self.num_examples = len(self.examples)

    def __getitem__(self, index):
        example = self.examples[index]

        img_path = example["img_path"]
        img = cv2.imread(img_path, -1) # (shape: (1024, 2048, 3))
        # resize img without interpolation (want the image to still match
        # label_img, which we resize below):
        img = cv2.resize(img, (self.new_img_w, self.new_img_h),
                         interpolation=cv2.INTER_NEAREST) # (shape: (768, 768, 3))

        label_img_path = example["label_img_path"]
        label_img = cv2.imread(label_img_path, -1) # (shape: (1024, 2048))
        # resize label_img without interpolation (want the resulting image to
        # still only contain pixel values corresponding to an object class):
        label_img = cv2.resize(label_img, (self.new_img_w, self.new_img_h),
                               interpolation=cv2.INTER_NEAREST) # (shape: (768, 768))

        # flip the img and the label with 0.5 probability:
        flip = np.random.randint(low=0, high=2)
        if flip == 1:
            img = cv2.flip(img, 1)
            label_img = cv2.flip(label_img, 1)

        ########################################################################
        # randomly scale the img and the label (disabled, see the note above):
        ########################################################################
        # scale = np.random.uniform(low=0.7, high=2.0)
        # new_img_h = int(scale*self.new_img_h)
        # new_img_w = int(scale*self.new_img_w)

        # # resize img without interpolation (want the image to still match
        # # label_img, which we resize below):
        # img = cv2.resize(img, (new_img_w, new_img_h),
        #                  interpolation=cv2.INTER_NEAREST) # (shape: (new_img_h, new_img_w, 3))

        # # resize label_img without interpolation (want the resulting image to
        # # still only contain pixel values corresponding to an object class):
        # label_img = cv2.resize(label_img, (new_img_w, new_img_h),
        #                        interpolation=cv2.INTER_NEAREST) # (shape: (new_img_h, new_img_w))
        ########################################################################

        # # # # # # # # debug visualization START
        # print (scale)
        # print (new_img_h)
        # print (new_img_w)
        #
        # cv2.imshow("test", img)
        # cv2.waitKey(0)
        #
        # cv2.imshow("test", label_img)
        # cv2.waitKey(0)
        # # # # # # # # debug visualization END

        ########################################################################
        # select a 768x768 random crop from the img and label (disabled):
        ########################################################################
        # start_x = np.random.randint(low=0, high=(new_img_w - 768))
        # end_x = start_x + 768
        # start_y = np.random.randint(low=0, high=(new_img_h - 768))
        # end_y = start_y + 768

        # img = img[start_y:end_y, start_x:end_x] # (shape: (768, 768, 3))
        # label_img = label_img[start_y:end_y, start_x:end_x] # (shape: (768, 768))
        ########################################################################

        # # # # # # # # debug visualization START
        # print (img.shape)
        # print (label_img.shape)
        #
        # cv2.imshow("test", img)
        # cv2.waitKey(0)
        #
        # cv2.imshow("test", label_img)
        # cv2.waitKey(0)
        # # # # # # # # debug visualization END

        # normalize the img (with the mean and std for the pretrained ResNet):
        img = img/255.0
        img = img - np.array([0.485, 0.456, 0.406])
        img = img/np.array([0.229, 0.224, 0.225]) # (shape: (768, 768, 3))
        img = np.transpose(img, (2, 0, 1)) # (shape: (3, 768, 768))
        img = img.astype(np.float32)

        # convert numpy -> torch:
        img = torch.from_numpy(img) # (shape: (3, 768, 768))
        label_img = torch.from_numpy(label_img) # (shape: (768, 768))

        return (img, label_img)

    def __len__(self):
        return self.num_examples
class DatasetVal(torch.utils.data.Dataset):
    def __init__(self, cityscapes_data_path, cityscapes_meta_path):
        self.img_dir = cityscapes_data_path + "/leftImg8bit/val/"
        self.label_dir = cityscapes_meta_path + "/label_imgs/"

        self.img_h = 1024
        self.img_w = 2048

        self.new_img_h = 768
        self.new_img_w = 768

        self.examples = []
        for val_dir in val_dirs:
            val_img_dir_path = self.img_dir + val_dir

            file_names = os.listdir(val_img_dir_path)
            for file_name in file_names:
                img_id = file_name.split("_leftImg8bit.png")[0]

                img_path = val_img_dir_path + file_name
                label_img_path = self.label_dir + img_id + ".png"

                example = {}
                example["img_path"] = img_path
                example["label_img_path"] = label_img_path
                example["img_id"] = img_id
                self.examples.append(example)

        self.num_examples = len(self.examples)

    def __getitem__(self, index):
        example = self.examples[index]

        img_id = example["img_id"]

        img_path = example["img_path"]
        img = cv2.imread(img_path, -1) # (shape: (1024, 2048, 3))
        # resize img without interpolation (want the image to still match
        # label_img, which we resize below):
        img = cv2.resize(img, (self.new_img_w, self.new_img_h),
                         interpolation=cv2.INTER_NEAREST) # (shape: (768, 768, 3))

        label_img_path = example["label_img_path"]
        label_img = cv2.imread(label_img_path, -1) # (shape: (1024, 2048))
        # resize label_img without interpolation (want the resulting image to
        # still only contain pixel values corresponding to an object class):
        label_img = cv2.resize(label_img, (self.new_img_w, self.new_img_h),
                               interpolation=cv2.INTER_NEAREST) # (shape: (768, 768))

        # # # # # # # # debug visualization START
        # cv2.imshow("test", img)
        # cv2.waitKey(0)
        #
        # cv2.imshow("test", label_img)
        # cv2.waitKey(0)
        # # # # # # # # debug visualization END

        # normalize the img (with the mean and std for the pretrained ResNet):
        img = img/255.0
        img = img - np.array([0.485, 0.456, 0.406])
        img = img/np.array([0.229, 0.224, 0.225]) # (shape: (768, 768, 3))
        img = np.transpose(img, (2, 0, 1)) # (shape: (3, 768, 768))
        img = img.astype(np.float32)

        # convert numpy -> torch:
        img = torch.from_numpy(img) # (shape: (3, 768, 768))
        label_img = torch.from_numpy(label_img) # (shape: (768, 768))

        return (img, label_img, img_id)

    def __len__(self):
        return self.num_examples
class DatasetSeq(torch.utils.data.Dataset):
    def __init__(self, cityscapes_data_path, cityscapes_meta_path, sequence):
        self.img_dir = cityscapes_data_path + "/leftImg8bit/demoVideo/stuttgart_" + sequence + "/"
        # self.img_dir = cityscapes_data_path + "/leftImg8bit/" + sequence + "/"

        self.img_h = 1024
        self.img_w = 2048

        self.new_img_h = 768
        self.new_img_w = 768

        self.examples = []

        file_names = os.listdir(self.img_dir)
        for file_name in file_names:
            img_id = file_name.split("_leftImg8bit.png")[0]

            img_path = self.img_dir + file_name

            example = {}
            example["img_path"] = img_path
            example["img_id"] = img_id
            self.examples.append(example)

        self.num_examples = len(self.examples)

    def __getitem__(self, index):
        example = self.examples[index]

        img_id = example["img_id"]

        img_path = example["img_path"]
        img = cv2.imread(img_path, -1) # (shape: (1024, 2048, 3))
        # resize img without interpolation:
        img = cv2.resize(img, (self.new_img_w, self.new_img_h),
                         interpolation=cv2.INTER_NEAREST) # (shape: (768, 768, 3))

        # normalize the img (with the mean and std for the pretrained ResNet):
        img = img/255.0
        img = img - np.array([0.485, 0.456, 0.406])
        img = img/np.array([0.229, 0.224, 0.225]) # (shape: (768, 768, 3))
        img = np.transpose(img, (2, 0, 1)) # (shape: (3, 768, 768))
        img = img.astype(np.float32)

        # convert numpy -> torch:
        img = torch.from_numpy(img) # (shape: (3, 768, 768))

        return (img, img_id)

    def __len__(self):
        return self.num_examples
class DatasetThnSeq(torch.utils.data.Dataset):
    def __init__(self, thn_data_path):
        self.img_dir = thn_data_path + "/"

        self.examples = []

        file_names = os.listdir(self.img_dir)
        for file_name in file_names:
            img_id = file_name.split(".png")[0]

            img_path = self.img_dir + file_name

            example = {}
            example["img_path"] = img_path
            example["img_id"] = img_id
            self.examples.append(example)

        self.num_examples = len(self.examples)

    def __getitem__(self, index):
        example = self.examples[index]

        img_id = example["img_id"]

        img_path = example["img_path"]
        img = cv2.imread(img_path, -1) # (shape: (512, 1024, 3))

        # normalize the img (with mean and std for the pretrained ResNet):
        img = img/255.0
        img = img - np.array([0.485, 0.456, 0.406])
        img = img/np.array([0.229, 0.224, 0.225]) # (shape: (512, 1024, 3))
        img = np.transpose(img, (2, 0, 1)) # (shape: (3, 512, 1024))
        img = img.astype(np.float32)

        # convert numpy -> torch:
        img = torch.from_numpy(img) # (shape: (3, 512, 1024))

        return (img, img_id)

    def __len__(self):
        return self.num_examples
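With datasets.py in place, one sample can be checked before training (a quick script next to datasets.py; the paths match the ones used in train.py below):

from datasets import DatasetTrain

train_dataset = DatasetTrain(cityscapes_data_path="/mnt/cityscapes",
                             cityscapes_meta_path="/mnt/cityscapes")
img, label_img = train_dataset[0]
print(img.shape)        # torch.Size([3, 768, 768])
print(label_img.shape)  # torch.Size([768, 768])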
Loading the model
The modified train.py:
# camera-ready
import sys
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
from datasets import DatasetTrain, DatasetVal # (this needs to be imported before torch, because cv2 needs to be imported before torch for some reason)
sys.path.append("/root/deeplabv3/model")
from model.deeplabv3 import DeepLabV3
sys.path.append("/root/deeplabv3/utils")
from utils.utils import add_weight_decay
from setr.SETR import SETR_Naive_S
import torch
import torch.utils.data
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import pickle
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import cv2
import time
if __name__ == "__main__":
    # NOTE! NOTE! change this to not overwrite all log data when you train the model:
    # network = DeepLabV3(model_id=1, project_dir="E:/master/master1/RSISS/deeplabv3/deeplabv3").cuda()
    # x = Variable(torch.randn(2,3,256,256)).cuda()
    # print(x.shape)
    # y = network(x)
    # print(y.shape)
    model_id = "1"

    num_epochs = 1000
    batch_size = 1
    learning_rate = 0.0001

    # network = DeepLabV3(model_id, project_dir="E:/master/master1/RSISS/deeplabv3/deeplabv3").cuda()
    _, network = SETR_Naive_S()
    network.cuda()

    train_dataset = DatasetTrain(cityscapes_data_path="/mnt/cityscapes",
                                 cityscapes_meta_path="/mnt/cityscapes")
    val_dataset = DatasetVal(cityscapes_data_path="/mnt/cityscapes",
                             cityscapes_meta_path="/mnt/cityscapes")

    num_train_batches = int(len(train_dataset)/batch_size)
    num_val_batches = int(len(val_dataset)/batch_size)
    print ("num_train_batches:", num_train_batches)
    print ("num_val_batches:", num_val_batches)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=1)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                             batch_size=batch_size, shuffle=False,
                                             num_workers=1)

    params = add_weight_decay(network, l2_value=0.0001)
    optimizer = torch.optim.Adam(params, lr=learning_rate)

    # with open("/mnt/cityscapes/class_weights.pkl", "rb") as file: # (needed for python3)
    #     class_weights = np.array(pickle.load(file))
    # class_weights = torch.from_numpy(class_weights)
    # class_weights = Variable(class_weights.type(torch.FloatTensor)).cuda()

    # loss function:
    loss_fn = nn.CrossEntropyLoss(weight=None)

    epoch_losses_train = []
    epoch_losses_val = []
    for epoch in range(num_epochs):
        print ("###########################")
        print ("######## NEW EPOCH ########")
        print ("###########################")
        print ("epoch: %d/%d" % (epoch+1, num_epochs))

        ############################################################################
        # train:
        ############################################################################
        network.train() # (set in training mode, this affects BatchNorm and dropout)
        batch_losses = []
        for step, (imgs, label_imgs) in enumerate(train_loader):
            #current_time = time.time()

            imgs = Variable(imgs).cuda() # (shape: (batch_size, 3, img_h, img_w))
            label_imgs = Variable(label_imgs.type(torch.LongTensor)).cuda() # (shape: (batch_size, img_h, img_w))

            outputs = network(imgs) # (shape: (batch_size, num_classes, img_h, img_w))

            # compute the loss:
            loss = loss_fn(outputs, label_imgs)
            loss_value = loss.data.cpu().numpy()
            batch_losses.append(loss_value)

            # optimization step:
            optimizer.zero_grad() # (reset gradients)
            loss.backward() # (compute gradients)
            optimizer.step() # (perform optimization step)

            #print (time.time() - current_time)

        epoch_loss = np.mean(batch_losses)
        epoch_losses_train.append(epoch_loss)
        # with open("%s/epoch_losses_train.pkl" % network.model_dir, "wb") as file:
        #     pickle.dump(epoch_losses_train, file)
        with open("training_logs/model_1/epoch_losses_train.pkl", "wb") as file:
            pickle.dump(epoch_losses_train, file)
        print ("train loss: %g" % epoch_loss)
        plt.figure(1)
        plt.plot(epoch_losses_train, "k^")
        plt.plot(epoch_losses_train, "k")
        plt.ylabel("loss")
        plt.xlabel("epoch")
        plt.title("train loss per epoch")
        plt.savefig("training_logs/model_1/epoch_losses_train.png")
        plt.close(1)
        print ("####")

        ############################################################################
        # val:
        ############################################################################
        network.eval() # (set in evaluation mode, this affects BatchNorm and dropout)
        batch_losses = []
        for step, (imgs, label_imgs, img_ids) in enumerate(val_loader):
            with torch.no_grad(): # (corresponds to setting volatile=True in all variables, this is done during inference to reduce memory consumption)
                imgs = Variable(imgs).cuda() # (shape: (batch_size, 3, img_h, img_w))
                label_imgs = Variable(label_imgs.type(torch.LongTensor)).cuda() # (shape: (batch_size, img_h, img_w))

                outputs = network(imgs) # (shape: (batch_size, num_classes, img_h, img_w))

                # compute the loss:
                loss = loss_fn(outputs, label_imgs)
                loss_value = loss.data.cpu().numpy()
                batch_losses.append(loss_value)

        epoch_loss = np.mean(batch_losses)
        epoch_losses_val.append(epoch_loss)
        with open("training_logs/model_1/epoch_losses_val.pkl", "wb") as file:
            pickle.dump(epoch_losses_val, file)
        print ("val loss: %g" % epoch_loss)
        plt.figure(1)
        plt.plot(epoch_losses_val, "k^")
        plt.plot(epoch_losses_val, "k")
        plt.ylabel("loss")
        plt.xlabel("epoch")
        plt.title("val loss per epoch")
        plt.savefig("training_logs/model_1/epoch_losses_val.png")
        plt.close(1)

        # save the model weights to disk:
        checkpoint_path = "training_logs/model_1/checkpoints/model_1_epoch_" + str(epoch+1) + ".pth"
        torch.save(network.state_dict(), checkpoint_path)
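One caveat: the script saves logs and checkpoints under training_logs/model_1/ but never creates those directories. Create them once by hand before the first run, or add this near the top of train.py (os is already imported there):

os.makedirs("training_logs/model_1/checkpoints", exist_ok=True)  # also creates training_logs/model_1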
After that, training can start. The training process looks like this:
This model has a huge number of parameters; even with 24 GB of GPU memory I could only just train with batch_size = 1. Here is the loss over the first 6 epochs of training:
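Since only batch_size = 1 fits in memory, gradient accumulation is one way to emulate a larger effective batch. This is a sketch of the idea, not part of the original training loop; it would replace the inner training loop above:

accum_steps = 8  # effective batch size = batch_size * accum_steps
optimizer.zero_grad()
for step, (imgs, label_imgs) in enumerate(train_loader):
    imgs = imgs.cuda()
    label_imgs = label_imgs.type(torch.LongTensor).cuda()
    loss = loss_fn(network(imgs), label_imgs) / accum_steps  # scale so gradients average
    loss.backward()  # gradients accumulate across iterations
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()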
3. Testing: modifying run_on_seq.py
To stay compatible with the rest of the original DeepLabV3 code, I then changed a number of file-save paths.
The modified run_on_seq.py:
# camera-ready
import sys
from setr.SETR import SETR_Naive_S
from datasets import DatasetSeq # (this needs to be imported before torch, because cv2 needs to be imported before torch for some reason)
from utils.utils import label_img_to_color
import torch
import torch.utils.data
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import pickle
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import cv2
import os
if __name__ == "__main__":
    batch_size = 2

    # network = DeepLabV3("eval_seq", project_dir="E:/master/master1/RSISS/deeplabv3/deeplabv3").cuda()
    _, network = SETR_Naive_S()
    network.cuda()
    network.load_state_dict(torch.load("training_logs/model_1/checkpoints/model_1_epoch_6.pth"))

    for sequence in ["00", "01", "02"]:
        print (sequence)

        val_dataset = DatasetSeq(cityscapes_data_path="/mnt/cityscapes",
                                 cityscapes_meta_path="/mnt/cityscapes",
                                 sequence=sequence)

        num_val_batches = int(len(val_dataset)/batch_size)
        print ("num_val_batches:", num_val_batches)

        val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                                 batch_size=batch_size, shuffle=False,
                                                 num_workers=1)

        network.eval() # (set in evaluation mode, this affects BatchNorm and dropout)
        unsorted_img_ids = []
        for step, (imgs, img_ids) in enumerate(val_loader):
            with torch.no_grad(): # (corresponds to setting volatile=True in all variables, this is done during inference to reduce memory consumption)
                imgs = Variable(imgs).cuda() # (shape: (batch_size, 3, img_h, img_w))

                outputs = network(imgs) # (shape: (batch_size, num_classes, img_h, img_w))

                ####################################################################
                # save data for visualization:
                ####################################################################
                outputs = outputs.data.cpu().numpy() # (shape: (batch_size, num_classes, img_h, img_w))
                pred_label_imgs = np.argmax(outputs, axis=1) # (shape: (batch_size, img_h, img_w))
                pred_label_imgs = pred_label_imgs.astype(np.uint8)

                for i in range(pred_label_imgs.shape[0]):
                    pred_label_img = pred_label_imgs[i] # (shape: (img_h, img_w))
                    img_id = img_ids[i]
                    img = imgs[i] # (shape: (3, img_h, img_w))

                    # undo the normalization to recover a displayable image:
                    img = img.data.cpu().numpy()
                    img = np.transpose(img, (1, 2, 0)) # (shape: (img_h, img_w, 3))
                    img = img*np.array([0.229, 0.224, 0.225])
                    img = img + np.array([0.485, 0.456, 0.406])
                    img = img*255.0
                    img = img.astype(np.uint8)

                    pred_label_img_color = label_img_to_color(pred_label_img)
                    overlayed_img = 0.35*img + 0.65*pred_label_img_color
                    overlayed_img = overlayed_img.astype(np.uint8)

                    img_h = overlayed_img.shape[0]
                    img_w = overlayed_img.shape[1]

                    cv2.imwrite("training_logs/model_eval_seq/" + img_id + ".png", img)
                    cv2.imwrite("training_logs/model_eval_seq/" + img_id + "_pred.png", pred_label_img_color)
                    cv2.imwrite("training_logs/model_eval_seq/" + img_id + "_overlayed.png", overlayed_img)

                    unsorted_img_ids.append(img_id)

        ############################################################################
        # create visualization video:
        ############################################################################
        out = cv2.VideoWriter("training_logs/model_eval_seq/stuttgart_%s_combined.avi" % sequence,
                              cv2.VideoWriter_fourcc(*"MJPG"), 20, (2*img_w, 2*img_h))
        sorted_img_ids = sorted(unsorted_img_ids)
        for img_id in sorted_img_ids:
            img = cv2.imread("training_logs/model_eval_seq/" + img_id + ".png", -1)
            pred_img = cv2.imread("training_logs/model_eval_seq/" + img_id + "_pred.png", -1)
            overlayed_img = cv2.imread("training_logs/model_eval_seq/" + img_id + "_overlayed.png", -1)

            # 2x2 grid: input top-left, prediction top-right, overlay centered below:
            combined_img = np.zeros((2*img_h, 2*img_w, 3), dtype=np.uint8)
            combined_img[0:img_h, 0:img_w] = img
            combined_img[0:img_h, img_w:(2*img_w)] = pred_img
            combined_img[img_h:(2*img_h), int(img_w/2):(img_w + int(img_w/2))] = overlayed_img

            out.write(combined_img)

        out.release()
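The same caveat as with train.py applies here: the output directory must exist before the images and the video can be written, e.g.:

os.makedirs("training_logs/model_eval_seq", exist_ok=True)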
Finally, the prediction results after 6 epochs of training.
Input image:
Result:
This model has a very large number of parameters, so it needs much more training, or a model pretrained on a large dataset as the starting point, before its full strength shows!
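If you do get hold of weights trained on a larger dataset, warm-starting is straightforward in PyTorch (the checkpoint path below is hypothetical; strict=False skips any keys that don't match, e.g. a head trained for a different number of classes):

state = torch.load("path/to/pretrained_setr.pth", map_location="cpu")  # hypothetical path
missing, unexpected = network.load_state_dict(state, strict=False)
print("missing keys:", missing)
print("unexpected keys:", unexpected)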
If this post helped you, please give it a like!