Training on the CityScapes Dataset with Segmentation Transformer (SETR), PyTorch Version: Step by Step

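The official Segmentation Transformer source code is built on the MMSegmentation framework, which makes it hard to read and learn from. If you want to use the official version, this post is not for you.

Here I use a reimplementation of Segmentation Transformer published on GitHub.

Segmentation Transformer source download link: https://github.com/gupta-abhay/setr-pytorch . This source has no training script, so I swapped the SETR network in for the DeepLabV3 network inside a DeepLabV3 training project. For how to use and modify the DeepLabV3 model, see this post: https://blog.csdn.net/qq_41964545/article/details/115252939?spm=1001.2014.3001.5501

Let's start adapting the source code.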

1. Download the Segmentation Transformer source

After downloading the source, unzip it into the DeepLabV3 folder.

[Figure 1]
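To confirm that the unzip step put things where the later imports expect them, a small check like the following can help. It is only a sketch: the `setr/SETR.py` location is inferred from the `from setr.SETR import SETR_Naive_S` import used later in this post, and `missing_paths` is a hypothetical helper name.

```python
from pathlib import Path

def missing_paths(project_dir, required):
    """Return the entries of `required` that do not exist under project_dir."""
    root = Path(project_dir)
    return [rel for rel in required if not (root / rel).exists()]

if __name__ == "__main__":
    # These relative paths are assumptions based on the imports used in train.py.
    required = ["setr/SETR.py", "datasets.py"]
    print(missing_paths(".", required))
```

An empty list means both files were found relative to the current directory.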

2. Modify datasets.py

Prepare the CityScapes labels as described in the previous DeepLabV3 post.
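For reference, the class count of 20 used in this post is consistent with the 19 standard Cityscapes training classes plus one catch-all class for everything else. A minimal sketch of that remapping with a NumPy lookup table follows. The id-to-trainId pairs follow the public Cityscapes label definitions; sending every remaining id to class 19 is an assumption about how the previous post prepares the labels, and `remap_labels` is a hypothetical helper.

```python
import numpy as np

# Cityscapes label id -> train id for the 19 evaluated classes.
ID_TO_TRAIN_ID = {7: 0, 8: 1, 11: 2, 12: 3, 13: 4, 17: 5, 19: 6, 20: 7,
                  21: 8, 22: 9, 23: 10, 24: 11, 25: 12, 26: 13, 27: 14,
                  28: 15, 31: 16, 32: 17, 33: 18}
OTHER_CLASS = 19  # assumption: every remaining id goes to one extra class

def remap_labels(label_img):
    """Map a uint8 array of raw Cityscapes ids to train ids in [0, 19]."""
    lut = np.full(256, OTHER_CLASS, dtype=np.uint8)
    for raw_id, train_id in ID_TO_TRAIN_ID.items():
        lut[raw_id] = train_id
    # Fancy indexing applies the lookup table to every pixel at once.
    return lut[label_img]
```

If your prepared labels contain values outside 0..19, the loss function will fail on them, so it is worth checking the maximum label value once before training.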

SETR comes with three decoder designs. Here I use the simplest one, Naive, in the form of the SETR_Naive_S model shown below. The source shows that the image size used for training on CityScapes is 768*768. The first change is to set the number of classes to 20.

[Figure 2]
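The fixed 768*768 input matters because a ViT-style encoder splits the image into a fixed grid of patches. A quick sketch of that arithmetic, assuming a patch size of 16 (an assumption for illustration; check the value in the SETR source you downloaded):

```python
def num_patches(img_size, patch_size):
    """Number of non-overlapping square patches in a square image."""
    if img_size % patch_size != 0:
        raise ValueError("img_size must be divisible by patch_size")
    per_side = img_size // patch_size
    return per_side * per_side

print(num_patches(768, 16))  # 48 * 48 = 2304
```

Feeding the network any other image size changes this count, which is why every dataset class below resizes to 768*768.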

Next, datasets.py needs to be modified.

Start with the DatasetTrain class:

[Figure 3]

For data augmentation I kept only the random flip and commented out the "randomly scale the img and the label" and "random crop from the img and label" parts. You can adjust this to your own needs, but make sure the returned image is 768*768.
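If you do re-enable scaling or cropping, one way to keep the 768*768 guarantee is to crop the image and the label with the same window and fail loudly when the input is too small. A minimal NumPy sketch, where `random_crop_pair` is a hypothetical helper and not part of the original datasets.py:

```python
import numpy as np

def random_crop_pair(img, label_img, size=768, rng=None):
    """Crop the same random size x size window from img (H, W, 3) and label_img (H, W)."""
    rng = rng or np.random.default_rng()
    h, w = label_img.shape[:2]
    if h < size or w < size:
        raise ValueError("input is smaller than the crop size")
    # The upper bound is exclusive, so +1 lets the window touch the border
    # and also covers the case where the input is exactly size x size.
    start_y = int(rng.integers(0, h - size + 1))
    start_x = int(rng.integers(0, w - size + 1))
    window = (slice(start_y, start_y + size), slice(start_x, start_x + size))
    return img[window], label_img[window]
```

Note that `np.random.randint(low=0, high=0)` raises an error, so the commented-out crop in the listing below only works when the scaled image is strictly larger than the crop in both dimensions.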

The same applies to the DatasetVal class:

[Figure 4]

The same applies to the DatasetSeq class:

[Figure 5]

The complete datasets.py is as follows:

# camera-ready

import torch
import torch.utils.data

import numpy as np
import cv2
import os

train_dirs = ["jena/", "zurich/", "weimar/", "ulm/", "tubingen/", "stuttgart/",
              "strasbourg/", "monchengladbach/", "krefeld/", "hanover/",
              "hamburg/", "erfurt/", "dusseldorf/", "darmstadt/", "cologne/",
              "bremen/", "bochum/", "aachen/"]
val_dirs = ["frankfurt/", "munster/", "lindau/"]
test_dirs = ["berlin", "bielefeld", "bonn", "leverkusen", "mainz", "munich"]

class DatasetTrain(torch.utils.data.Dataset):
    def __init__(self, cityscapes_data_path, cityscapes_meta_path):
        self.img_dir = cityscapes_data_path + "/leftImg8bit/train/"
        self.label_dir = cityscapes_meta_path + "/label_imgs/"

        self.img_h = 1024
        self.img_w = 2048

        # self.new_img_h = 512
        # self.new_img_w = 1024
        self.new_img_h = 768
        self.new_img_w = 768

        self.examples = []
        for train_dir in train_dirs:
            train_img_dir_path = self.img_dir + train_dir

            file_names = os.listdir(train_img_dir_path)
            for file_name in file_names:
                img_id = file_name.split("_leftImg8bit.png")[0]

                img_path = train_img_dir_path + file_name

                label_img_path = self.label_dir + img_id + ".png"

                example = {}
                example["img_path"] = img_path
                example["label_img_path"] = label_img_path
                example["img_id"] = img_id
                self.examples.append(example)

        self.num_examples = len(self.examples)

    def __getitem__(self, index):
        example = self.examples[index]

        img_path = example["img_path"]
        img = cv2.imread(img_path, -1) # (shape: (1024, 2048, 3))
        # resize img with nearest-neighbor interpolation (want the image to
        # still match label_img, which we resize below):
        img = cv2.resize(img, (self.new_img_w, self.new_img_h),
                         interpolation=cv2.INTER_NEAREST) # (shape: (768, 768, 3))

        label_img_path = example["label_img_path"]
        label_img = cv2.imread(label_img_path, -1) # (shape: (1024, 2048))
        # resize label_img with nearest-neighbor interpolation (want the
        # resulting image to still only contain valid class ids):
        label_img = cv2.resize(label_img, (self.new_img_w, self.new_img_h),
                               interpolation=cv2.INTER_NEAREST) # (shape: (768, 768))

        # flip the img and the label with 0.5 probability:
        flip = np.random.randint(low=0, high=2)
        if flip == 1:
            img = cv2.flip(img, 1)
            label_img = cv2.flip(label_img, 1)

        ########################################################################
        # randomly scale the img and the label:
        ########################################################################
        # scale = np.random.uniform(low=0.7, high=2.0)
        # new_img_h = int(scale*self.new_img_h)
        # new_img_w = int(scale*self.new_img_w)

        # resize img without interpolation (want the image to still match
        # label_img, which we resize below):
        # img = cv2.resize(img, (new_img_w, new_img_h),
        #                  interpolation=cv2.INTER_NEAREST) # (shape: (new_img_h, new_img_w, 3))

        # resize label_img without interpolation (want the resulting image to
        # still only contain pixel values corresponding to an object class):
        # label_img = cv2.resize(label_img, (new_img_w, new_img_h),
        #                        interpolation=cv2.INTER_NEAREST) # (shape: (new_img_h, new_img_w))
        ########################################################################

        # # # # # # # # debug visualization START
        # print (scale)
        # print (new_img_h)
        # print (new_img_w)
        #
        # cv2.imshow("test", img)
        # cv2.waitKey(0)
        #
        # cv2.imshow("test", label_img)
        # cv2.waitKey(0)
        # # # # # # # # debug visualization END

        ########################################################################
        # select a 256x256 random crop from the img and label:
        ########################################################################
        # start_x = np.random.randint(low=0, high=(new_img_w - 256))
        # end_x = start_x + 256
        # start_y = np.random.randint(low=0, high=(new_img_h - 256))
        # end_y = start_y + 256

        # start_x = np.random.randint(low=0, high=(new_img_w - 768))
        # end_x = start_x + 768
        # start_y = np.random.randint(low=0, high=(new_img_h - 768))
        # end_y = start_y + 768

        # img = img[start_y:end_y, start_x:end_x] # (shape: (256, 256, 3))
        # label_img = label_img[start_y:end_y, start_x:end_x] # (shape: (256, 256))
        ########################################################################

        # # # # # # # # debug visualization START
        # print (img.shape)
        # print (label_img.shape)
        #
        # cv2.imshow("test", img)
        # cv2.waitKey(0)
        #
        # cv2.imshow("test", label_img)
        # cv2.waitKey(0)
        # # # # # # # # debug visualization END

        # normalize the img (with the mean and std for the pretrained ResNet):
        img = img/255.0
        img = img - np.array([0.485, 0.456, 0.406])
        img = img/np.array([0.229, 0.224, 0.225]) # (shape: (768, 768, 3))
        img = np.transpose(img, (2, 0, 1)) # (shape: (3, 768, 768))
        img = img.astype(np.float32)

        # convert numpy -> torch:
        img = torch.from_numpy(img) # (shape: (3, 768, 768))
        label_img = torch.from_numpy(label_img) # (shape: (768, 768))

        return (img, label_img)

    def __len__(self):
        return self.num_examples

class DatasetVal(torch.utils.data.Dataset):
    def __init__(self, cityscapes_data_path, cityscapes_meta_path):
        self.img_dir = cityscapes_data_path + "/leftImg8bit/val/"
        self.label_dir = cityscapes_meta_path + "/label_imgs/"

        self.img_h = 1024
        self.img_w = 2048

        self.new_img_h = 768
        self.new_img_w = 768

        self.examples = []
        for val_dir in val_dirs:
            val_img_dir_path = self.img_dir + val_dir

            file_names = os.listdir(val_img_dir_path)
            for file_name in file_names:
                img_id = file_name.split("_leftImg8bit.png")[0]

                img_path = val_img_dir_path + file_name

                # only the path is stored here; the label is read in __getitem__
                label_img_path = self.label_dir + img_id + ".png"

                example = {}
                example["img_path"] = img_path
                example["label_img_path"] = label_img_path
                example["img_id"] = img_id
                self.examples.append(example)

        self.num_examples = len(self.examples)

    def __getitem__(self, index):
        example = self.examples[index]

        img_id = example["img_id"]

        img_path = example["img_path"]
        img = cv2.imread(img_path, -1) # (shape: (1024, 2048, 3))
        # resize img with nearest-neighbor interpolation (want the image to
        # still match label_img, which we resize below):
        img = cv2.resize(img, (self.new_img_w, self.new_img_h),
                         interpolation=cv2.INTER_NEAREST) # (shape: (768, 768, 3))

        label_img_path = example["label_img_path"]
        label_img = cv2.imread(label_img_path, -1) # (shape: (1024, 2048))
        # resize label_img with nearest-neighbor interpolation (want the
        # resulting image to still only contain valid class ids):
        label_img = cv2.resize(label_img, (self.new_img_w, self.new_img_h),
                               interpolation=cv2.INTER_NEAREST) # (shape: (768, 768))

        # # # # # # # # debug visualization START
        # cv2.imshow("test", img)
        # cv2.waitKey(0)
        #
        # cv2.imshow("test", label_img)
        # cv2.waitKey(0)
        # # # # # # # # debug visualization END

        # normalize the img (with the mean and std for the pretrained ResNet):
        img = img/255.0
        img = img - np.array([0.485, 0.456, 0.406])
        img = img/np.array([0.229, 0.224, 0.225]) # (shape: (768, 768, 3))
        img = np.transpose(img, (2, 0, 1)) # (shape: (3, 768, 768))
        img = img.astype(np.float32)

        # convert numpy -> torch:
        img = torch.from_numpy(img) # (shape: (3, 768, 768))
        label_img = torch.from_numpy(label_img) # (shape: (768, 768))

        return (img, label_img, img_id)

    def __len__(self):
        return self.num_examples

class DatasetSeq(torch.utils.data.Dataset):
    def __init__(self, cityscapes_data_path, cityscapes_meta_path, sequence):
        self.img_dir = cityscapes_data_path + "/leftImg8bit/demoVideo/stuttgart_" + sequence + "/"
        # self.img_dir = cityscapes_data_path + "/leftImg8bit/" + sequence + "/"

        self.img_h = 1024
        self.img_w = 2048

        self.new_img_h = 768
        self.new_img_w = 768

        self.examples = []

        file_names = os.listdir(self.img_dir)
        for file_name in file_names:
            img_id = file_name.split("_leftImg8bit.png")[0]

            img_path = self.img_dir + file_name

            example = {}
            example["img_path"] = img_path
            example["img_id"] = img_id
            self.examples.append(example)

        self.num_examples = len(self.examples)

    def __getitem__(self, index):
        example = self.examples[index]

        img_id = example["img_id"]

        img_path = example["img_path"]
        img = cv2.imread(img_path, -1) # (shape: (1024, 2048, 3))
        # resize img with nearest-neighbor interpolation:
        img = cv2.resize(img, (self.new_img_w, self.new_img_h),
                         interpolation=cv2.INTER_NEAREST) # (shape: (768, 768, 3))

        # normalize the img (with the mean and std for the pretrained ResNet):
        img = img/255.0
        img = img - np.array([0.485, 0.456, 0.406])
        img = img/np.array([0.229, 0.224, 0.225]) # (shape: (768, 768, 3))
        img = np.transpose(img, (2, 0, 1)) # (shape: (3, 768, 768))
        img = img.astype(np.float32)

        # convert numpy -> torch:
        img = torch.from_numpy(img) # (shape: (3, 768, 768))

        return (img, img_id)

    def __len__(self):
        return self.num_examples

class DatasetThnSeq(torch.utils.data.Dataset):
    def __init__(self, thn_data_path):
        self.img_dir = thn_data_path + "/"

        self.examples = []

        file_names = os.listdir(self.img_dir)
        for file_name in file_names:
            img_id = file_name.split(".png")[0]

            img_path = self.img_dir + file_name

            example = {}
            example["img_path"] = img_path
            example["img_id"] = img_id
            self.examples.append(example)

        self.num_examples = len(self.examples)

    def __getitem__(self, index):
        example = self.examples[index]

        img_id = example["img_id"]

        img_path = example["img_path"]
        img = cv2.imread(img_path, -1) # (shape: (512, 1024, 3))

        # normalize the img (with mean and std for the pretrained ResNet):
        img = img/255.0
        img = img - np.array([0.485, 0.456, 0.406])
        img = img/np.array([0.229, 0.224, 0.225]) # (shape: (512, 1024, 3))
        img = np.transpose(img, (2, 0, 1)) # (shape: (3, 512, 1024))
        img = img.astype(np.float32)

        # convert numpy -> torch:
        img = torch.from_numpy(img) # (shape: (3, 512, 1024))

        return (img, img_id)

    def __len__(self):
        return self.num_examples
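Every dataset class above applies the same normalization, and the test script later inverts it to write images back out. The round trip can be sketched in isolation as follows. `normalize` and `denormalize` are hypothetical helper names; the mean and std values are the ones used in the code above.

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])

def normalize(img_uint8):
    """(H, W, 3) uint8 -> (3, H, W) float32, as done in __getitem__."""
    img = img_uint8 / 255.0
    img = (img - MEAN) / STD
    return np.transpose(img, (2, 0, 1)).astype(np.float32)

def denormalize(img_chw):
    """(3, H, W) float32 -> (H, W, 3) uint8, inverting normalize()."""
    img = np.transpose(img_chw, (1, 2, 0)) * STD + MEAN
    # Round before casting so float error does not turn e.g. 127.9999 into 127.
    return np.clip(np.rint(img * 255.0), 0, 255).astype(np.uint8)
```

The rounding step is an addition of this sketch. The test script in this post casts directly with `astype(np.uint8)`, which truncates instead of rounding.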

3. Modify train.py

Load the model:

[Figure 6]

The modified train.py:

# camera-ready

import sys
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
from datasets import DatasetTrain, DatasetVal # (this needs to be imported before torch, because cv2 needs to be imported before torch for some reason)

sys.path.append("/root/deeplabv3/model")
from model.deeplabv3 import DeepLabV3

sys.path.append("/root/deeplabv3/utils")
from utils.utils import add_weight_decay

from setr.SETR import SETR_Naive_S

import torch
import torch.utils.data
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F

import numpy as np
import pickle
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import cv2

import time

if __name__ == "__main__":
    # NOTE! NOTE! change this to not overwrite all log data when you train the model:
    # network = DeepLabV3(model_id=1, project_dir="E:/master/master1/RSISS/deeplabv3/deeplabv3").cuda()
    # x = Variable(torch.randn(2,3,256,256)).cuda() 
    # print(x.shape)
    # y = network(x)
    # print(y.shape)
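    model_id = "1"

    # make sure the output directories used below exist before anything is written:
    os.makedirs("training_logs/model_1/checkpoints", exist_ok=True)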

    num_epochs = 1000
    batch_size = 1
    learning_rate = 0.0001

    # network = DeepLabV3(model_id, project_dir="E:/master/master1/RSISS/deeplabv3/deeplabv3").cuda()
    # SETR_Naive_S() returns a pair; the second element is the network itself:
    _, network = SETR_Naive_S()
    network.cuda()

    train_dataset = DatasetTrain(cityscapes_data_path="/mnt/cityscapes",
                                cityscapes_meta_path="/mnt/cityscapes")
    val_dataset = DatasetVal(cityscapes_data_path="/mnt/cityscapes",
                            cityscapes_meta_path="/mnt/cityscapes")

    num_train_batches = int(len(train_dataset)/batch_size)
    num_val_batches = int(len(val_dataset)/batch_size)
    print ("num_train_batches:", num_train_batches)
    print ("num_val_batches:", num_val_batches)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                            batch_size=batch_size, shuffle=True,
                                            num_workers=1)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                            batch_size=batch_size, shuffle=False,
                                            num_workers=1)

    params = add_weight_decay(network, l2_value=0.0001)
    optimizer = torch.optim.Adam(params, lr=learning_rate)

    # with open("/mnt/cityscapes/class_weights.pkl", "rb") as file: # (needed for python3)
    #     class_weights = np.array(pickle.load(file))
    # class_weights = torch.from_numpy(class_weights)
    # class_weights = Variable(class_weights.type(torch.FloatTensor)).cuda()

    # loss function
    loss_fn = nn.CrossEntropyLoss(weight=None)

    epoch_losses_train = []
    epoch_losses_val = []
    for epoch in range(num_epochs):
        print ("###########################")
        print ("######## NEW EPOCH ########")
        print ("###########################")
        print ("epoch: %d/%d" % (epoch+1, num_epochs))

        ############################################################################
        # train:
        ############################################################################
        network.train() # (set in training mode, this affects BatchNorm and dropout)
        batch_losses = []
        for step, (imgs, label_imgs) in enumerate(train_loader):
            #current_time = time.time()

            imgs = Variable(imgs).cuda() # (shape: (batch_size, 3, img_h, img_w))
            # print("imgs.shape: ",imgs.shape)
            label_imgs = Variable(label_imgs.type(torch.LongTensor)).cuda() # (shape: (batch_size, img_h, img_w))
            # print("label_imgs.shape: ",label_imgs.shape)

            outputs = network(imgs) # (shape: (batch_size, num_classes, img_h, img_w))
            # print(outputs)
            # print("shape of label_imgs: ",label_imgs.shape)
            # print("shape of outputs: ",outputs.shape)

            # compute the loss:
            loss = loss_fn(outputs, label_imgs)
            loss_value = loss.data.cpu().numpy()
            batch_losses.append(loss_value)

            # optimization step:
            optimizer.zero_grad() # (reset gradients)
            loss.backward() # (compute gradients)
            optimizer.step() # (perform optimization step)

            #print (time.time() - current_time)

        epoch_loss = np.mean(batch_losses)
        epoch_losses_train.append(epoch_loss)
        # with open("%s/epoch_losses_train.pkl" % network.model_dir, "wb") as file:
        #     pickle.dump(epoch_losses_train, file)
        with open("training_logs/model_1/epoch_losses_train.pkl" , "wb") as file:
            pickle.dump(epoch_losses_train, file)
        print ("train loss: %g" % epoch_loss)
        plt.figure(1)
        plt.plot(epoch_losses_train, "k^")
        plt.plot(epoch_losses_train, "k")
        plt.ylabel("loss")
        plt.xlabel("epoch")
        plt.title("train loss per epoch")
        plt.savefig("training_logs/model_1/epoch_losses_train.png")
        plt.close(1)

        print ("####")

        ############################################################################
        # val:
        ############################################################################
        network.eval() # (set in evaluation mode, this affects BatchNorm and dropout)
        batch_losses = []
        for step, (imgs, label_imgs, img_ids) in enumerate(val_loader):
            with torch.no_grad(): # (corresponds to setting volatile=True in all variables, this is done during inference to reduce memory consumption)
                imgs = Variable(imgs).cuda() # (shape: (batch_size, 3, img_h, img_w))
                label_imgs = Variable(label_imgs.type(torch.LongTensor)).cuda() # (shape: (batch_size, img_h, img_w))

                outputs = network(imgs) # (shape: (batch_size, num_classes, img_h, img_w))

                # compute the loss:
                loss = loss_fn(outputs, label_imgs)
                loss_value = loss.data.cpu().numpy()
                batch_losses.append(loss_value)

        epoch_loss = np.mean(batch_losses)
        epoch_losses_val.append(epoch_loss)
        with open("training_logs/model_1/epoch_losses_val.pkl" , "wb") as file:
            pickle.dump(epoch_losses_val, file)
        print ("val loss: %g" % epoch_loss)
        plt.figure(1)
        plt.plot(epoch_losses_val, "k^")
        plt.plot(epoch_losses_val, "k")
        plt.ylabel("loss")
        plt.xlabel("epoch")
        plt.title("val loss per epoch")
        plt.savefig("training_logs/model_1/epoch_losses_val.png" )
        plt.close(1)

        # save the model weights to disk:
        checkpoint_path = "training_logs/model_1/checkpoints/model_1" +"_epoch_" + str(epoch+1) + ".pth"
        torch.save(network.state_dict(), checkpoint_path)
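Since the script pickles the per-epoch losses, the choice of which checkpoint to load later can be made from those files. A small sketch, where `best_epoch` and `load_losses` are hypothetical helpers and the file path is the one written by the training loop above:

```python
import pickle

def best_epoch(losses):
    """Return (1-based epoch number, loss) for the smallest loss in the list."""
    if not losses:
        raise ValueError("no losses recorded")
    idx = min(range(len(losses)), key=lambda i: float(losses[i]))
    return idx + 1, float(losses[idx])

def load_losses(path):
    """Load a pickled list of per-epoch losses."""
    with open(path, "rb") as f:
        return pickle.load(f)

# Example (paths as written by train.py above):
# epoch, loss = best_epoch(load_losses("training_logs/model_1/epoch_losses_val.pkl"))
# checkpoint = "training_logs/model_1/checkpoints/model_1_epoch_%d.pth" % epoch
```

The epoch number is 1-based so that it lines up with the checkpoint file names.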

 

After that, training can start. The training run looks like this:

This model has a very large number of parameters. With 24 GB of GPU memory I could only just train it with batch_size set to 1. Below is the loss after training for 6 epochs.

[Figure 7]

[Figure 8]
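A rough back-of-the-envelope estimate helps explain why memory is tight even at batch_size 1: self-attention stores one score matrix per head whose side length is the number of patches. The sketch below assumes a patch size of 16, 12 attention heads, and float32 activations. All three are assumptions for illustration, so substitute the values from your model.

```python
def attention_scores_mib(img_size, patch_size, num_heads, batch_size=1, bytes_per_value=4):
    """Size in MiB of the attention score tensor for one transformer layer."""
    tokens = (img_size // patch_size) ** 2
    num_values = batch_size * num_heads * tokens * tokens
    return num_values * bytes_per_value / 2**20

print(attention_scores_mib(768, 16, 12))  # 243.0
```

That figure is per layer, and training keeps activations from every layer for the backward pass, so the total grows with both depth and batch size.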

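4. Modify the test script run_on_seg.py

To stay compatible with the rest of the original DeepLabV3 code, I then changed many of the file save paths.

The modified run_on_seg.py is as follows: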

# camera-ready

import sys

from setr.SETR import SETR_Naive_S
from datasets import DatasetSeq # (this needs to be imported before torch, because cv2 needs to be imported before torch for some reason)


from utils.utils import label_img_to_color

import torch
import torch.utils.data
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F

import numpy as np
import pickle
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import cv2

import os

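if __name__ == "__main__":

    batch_size = 2

    # make sure the output directory used below exists before images are written:
    os.makedirs("training_logs/model_eval_seq", exist_ok=True)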

    # network = DeepLabV3("eval_seq", project_dir="E:/master/master1/RSISS/deeplabv3/deeplabv3").cuda()
    # SETR_Naive_S() returns a pair; the second element is the network itself:
    _, network = SETR_Naive_S()
    network.cuda()
    network.load_state_dict(torch.load("training_logs/model_1/checkpoints/model_1_epoch_6.pth"))

    for sequence in ["00", "01", "02"]:
        print (sequence)

        val_dataset = DatasetSeq(cityscapes_data_path="/mnt/cityscapes",
                                cityscapes_meta_path="/mnt/cityscapes",
                                sequence=sequence)

        num_val_batches = int(len(val_dataset)/batch_size)
        print ("num_val_batches:", num_val_batches)

        val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                                batch_size=batch_size, shuffle=False,
                                                num_workers=1)

        network.eval() # (set in evaluation mode, this affects BatchNorm and dropout)
        unsorted_img_ids = []
        for step, (imgs, img_ids) in enumerate(val_loader):
            with torch.no_grad(): # (corresponds to setting volatile=True in all variables, this is done during inference to reduce memory consumption)
                imgs = Variable(imgs).cuda() # (shape: (batch_size, 3, img_h, img_w))

                outputs = network(imgs) # (shape: (batch_size, num_classes, img_h, img_w))

                ####################################################################
                # save data for visualization:
                ####################################################################
                outputs = outputs.data.cpu().numpy() # (shape: (batch_size, num_classes, img_h, img_w))
                pred_label_imgs = np.argmax(outputs, axis=1) # (shape: (batch_size, img_h, img_w))
                pred_label_imgs = pred_label_imgs.astype(np.uint8)

                for i in range(pred_label_imgs.shape[0]):
                    pred_label_img = pred_label_imgs[i] # (shape: (img_h, img_w))
                    img_id = img_ids[i]
                    img = imgs[i] # (shape: (3, img_h, img_w))

                    img = img.data.cpu().numpy()
                    img = np.transpose(img, (1, 2, 0)) # (shape: (img_h, img_w, 3))
                    img = img*np.array([0.229, 0.224, 0.225])
                    img = img + np.array([0.485, 0.456, 0.406])
                    img = img*255.0
                    img = img.astype(np.uint8)

                    pred_label_img_color = label_img_to_color(pred_label_img)
                    overlayed_img = 0.35*img + 0.65*pred_label_img_color
                    overlayed_img = overlayed_img.astype(np.uint8)

                    img_h = overlayed_img.shape[0]
                    img_w = overlayed_img.shape[1]

                    cv2.imwrite("training_logs/model_eval_seq" + "/" + img_id + ".png", img)
                    cv2.imwrite("training_logs/model_eval_seq" + "/" + img_id + "_pred.png", pred_label_img_color)
                    cv2.imwrite("training_logs/model_eval_seq" + "/" + img_id + "_overlayed.png", overlayed_img)

                    unsorted_img_ids.append(img_id)

        ############################################################################
        # create visualization video:
        ############################################################################
        out = cv2.VideoWriter("training_logs/model_eval_seq/stuttgart_%s_combined.avi" % (sequence), cv2.VideoWriter_fourcc(*"MJPG"), 20, (2*img_w, 2*img_h))
        sorted_img_ids = sorted(unsorted_img_ids)
        for img_id in sorted_img_ids:
            img = cv2.imread("training_logs/model_eval_seq" + "/" + img_id + ".png", -1)
            pred_img = cv2.imread("training_logs/model_eval_seq" + "/" + img_id + "_pred.png", -1)
            overlayed_img = cv2.imread("training_logs/model_eval_seq" + "/" + img_id + "_overlayed.png", -1)

            combined_img = np.zeros((2*img_h, 2*img_w, 3), dtype=np.uint8)

            combined_img[0:img_h, 0:img_w] = img
            combined_img[0:img_h, img_w:(2*img_w)] = pred_img
            combined_img[img_h:(2*img_h), (int(img_w/2)):(img_w + int(img_w/2))] = overlayed_img

            out.write(combined_img)

        out.release()
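The frame layout built in the loop above (input top-left, prediction top-right, overlay centered on the bottom row) can be tried on small dummy arrays without any video I/O. A minimal sketch, with `combine_frames` as a hypothetical helper name:

```python
import numpy as np

def combine_frames(img, pred_img, overlayed_img):
    """Place three (H, W, 3) uint8 images on a (2H, 2W, 3) black canvas."""
    img_h, img_w = img.shape[:2]
    canvas = np.zeros((2 * img_h, 2 * img_w, 3), dtype=np.uint8)
    canvas[0:img_h, 0:img_w] = img                      # top-left: input
    canvas[0:img_h, img_w:2 * img_w] = pred_img         # top-right: prediction
    left = img_w // 2                                   # center the overlay
    canvas[img_h:2 * img_h, left:left + img_w] = overlayed_img
    return canvas
```

The frame size passed to `cv2.VideoWriter` has to match the canvas, which is why the script uses `(2*img_w, 2*img_h)` there.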

Here are the predictions after 6 epochs of training.

Input image:

[Figure 9]

Result:

[Figure 10]

[Figure 11]

This model has a very large number of parameters. It needs a lot of training, or it should start from a model pretrained on a large dataset and then be trained further, before its full performance shows.
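When starting from pretrained weights, the checkpoint rarely matches the new model exactly (for instance, the classification head differs when the class count changes). One common approach is to keep only the entries whose name and shape match and load those non-strictly. The sketch below shows just the filtering step on plain dictionaries, with NumPy arrays standing in for tensors. `filter_compatible` is a hypothetical helper and the key names are made up for the example. With PyTorch, the filtered dictionary would typically be passed to `load_state_dict(..., strict=False)`.

```python
import numpy as np

def filter_compatible(pretrained, target):
    """Keep pretrained entries whose key exists in target with the same shape."""
    return {k: v for k, v in pretrained.items()
            if k in target and tuple(v.shape) == tuple(target[k].shape)}

# Hypothetical example: the encoder matches, the 19-class head does not fit 20 classes.
pretrained = {"encoder.w": np.zeros((4, 4)), "head.w": np.zeros((19, 4))}
target = {"encoder.w": np.ones((4, 4)), "head.w": np.ones((20, 4))}
print(sorted(filter_compatible(pretrained, target)))  # ['encoder.w']
```

Layers that are filtered out keep their random initialization and are learned from scratch.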

If you found this post helpful, a like would be appreciated!

 

 

 

 
