目录
本周完成的计划
论文阅读
Abstract(摘要)
1. Introduction(绪论)
2. Approach(方法)
3. Discussions(讨论)
4. Experiments(实验)
4.1. Setup(设置)
4.2. Results(结果)
4.3. Improving Full- and Few-Supervision(改进完全监督和少数监督)
4.4. Empirical Study(实证研究)
4.5. Qualitative Results(定性结果)
5. Conclusion(结论)
病理组织分割项目
1.代码
2.结果
Semi-Supervised Semantic Segmentation with Cross Pseudo Supervision(带有交叉伪监督的半监督式语义分割法)
在本文中,我们通过探索标记数据和额外的无标记数据来研究半监督的语义分割问题。我们提出了一种新的一致性正则化方法,称为交叉伪监督(CPS)。我们的方法对同一输入图像的两个分割网络施加了不同初始化的一致性。从一个扰动的分割网络输出的伪单热标签图,被用来监督另一个具有标准交叉熵损失的分割网络,反之亦然。CPS的一致性有两个作用:鼓励两个扰动网络对同一输入图像的预测的高度相似性,并通过使用带有伪标签的未标记数据来扩大训练数据。实验结果表明,我们的方法在Cityscapes和PASCAL VOC 2012上实现了最先进的半监督分割性能。
图像语义分割是计算机视觉中的一项基本识别任务。语义分割的训练数据需要像素级的人工标注,与其他视觉任务(如图像分类和物体检测)相比,其成本要高得多。这使得半监督性分割成为一个重要的问题,通过使用标记的数据以及额外的未标记的数据来学习分割模型。它通过各种扰动来执行预测的一致性,例如,通过增强输入图像的输入扰动[11,19],特征扰动[27]和网络扰动[18]。自我训练也被研究用于半监督性分割[6, 43, 42, 9, 13, 25]。它在未标记的图像上加入了从标记图像上训练的分割模型获得的伪分割图,以扩大训练数据,并重新训练分割模型。我们提出了一种新颖而简单的带有网络扰动的一致性正则化方法,称为交叉伪监督。所提出的方法是将已标记和未标记的图像输入两个结构相同但初始化不同的分割网络。这两个网络在标记数据上的输出分别由相应的地面真实分割图监督。我们的主要观点在于交叉伪监督,它强制两个分割网络之间的一致性。每个输入图像的分割网络估计一个分割结果,称为伪分割图。伪分割图被用作监督其他分割网络的额外信号。
交叉伪监督方案的好处在于两个方面。一方面,像以前的一致性正则化一样,所提出的方法鼓励不同的初始化网络对同一输入图像的预测是一致的,并且预测决策边界位于低密度区域。另一方面,在后期的优化阶段,伪分割变得稳定,并且比只在标记数据上进行正常监督训练的结果更准确。伪标记数据的行为就像扩大训练数据一样,从而提高了分割网络的训练质量。在Cityscapes和PASCAL VOC 2012两个基准上的不同设置的实验结果表明,所提出的交叉伪监督方法优于现有的半监督分割的一致性方案。我们的方法在两个基准上都达到了最先进的半监督分割性能。
给定一组由N个标记图像组成的Dl和一组由M个未标记图像组成的Du,半监督的语义分割任务旨在通过探索标记和未标记的图像来学习分割网络。
交叉伪监督:提出的方法由两个平行的分割网络组成
这两个网络具有相同的结构,它们的权重,即θ1和θ2,被初始化得不同。输入X是有相同的增量,P1(P2)是分割置信度图,是softmax归一化后的网络输出。提出的方法在逻辑上说明如下1:
这里Y1(Y2)是预测的一热标签图,称为伪分割图。在每个位置i,标签向量y1i(y2i)是由相应的置信度向量p1i(p2i)计算得出的单热向量。 我们的方法的完整版本如图1(a)所示,我们在上述方程中没有包括损失监督。
训练目标包含两个损失:监督损失Ls和交叉伪监督损失Lcps。监督损失Ls是使用标准的像素级交叉熵损失对两个平行分割网络上的标记图像制定的。
其中'ce是交叉熵损失函数,y∗1i(y∗2i)是地面真相。W和H代表输入图像的宽度和高度。交叉伪监督损失是双向的。一个是从f(θ1)到f(θ2)。我们用一个网络f(θ1)输出的像素级一热标签图Y1来监督另一个网络f(θ2)的像素级置信度图P2,另一个是从f(θ2)到f(θ1)。无标签数据的交叉伪监督损失写为:
我们还以同样的方式定义了标签数据上的交叉伪监督损失Llcps。整个交叉伪监督损失是有标签数据和无标签数据的损失的组合。Lcps = Llcps + Lucps。整个训练目标被写成:
与CutMix增强方案的结合。CutMix增强方案被应用于半监督性分割的平均教师框架。我们还在我们的方法中应用了CutMix增强技术。我们将CutMix图像输入两个网络f(θ1)和f(θ2)。我们使用类似于[11]的方式从两个网络中生成伪分割图:将两张源图像(用于生成CutMix图像)输入每个分割网络,并将这两张伪分割图作为另一个分割网络的监督。
我们讨论了我们的方法与几个相关工作的关系,如下所示:
交叉概率一致性:
两个受扰动网络的可选一致性是交叉概率一致性:概率向量(来自像素置信度图)应该相似(如图1(b)所示)。
损失函数被写成:
这里用一个损失的例子来强加一致性。其他损失,如KL divergence,以及对中间特征的一致性也可以使用。我们用D来表示有标签的集合Dl和无标签的集合Du的联合。与特征/概率的一致性相似,提议的交叉伪监督一致性也期望两个扰动的分割网络之间的一致性。特别是,我们的方法在某种意义上通过使用伪标签探索未标记的数据来增强训练数据。表4所示的实证结果表明,交叉伪监督优于交叉概率一致性。
Mean teacher:
Mean teacher最初用于半监督分类,最近用于半监督分割,例如在CutMix Seg中。 具有不同增强的未标记图像被馈入具有相同结构的两个网络:一个是学生f(θ),另一个是平均教师f(θ),参数是学生网络参数θ的移动平均值:
我们使用X1和X2来表示X的不同增强版本。一致性正则化旨在将学生网络预测的X1的概率图P1与教师网络预测的X2的概率图P2对齐。在训练过程中,我们和P2一起监督P1,并对教师网络应用无反向传播算法。在下图中,我们使用来表示“无反向传播”。我们没有将损失监管包括在上述等式中,并在图1 (c)中说明了完整版本。
单网伪监管:
我们考虑我们的方法的降级版本,单网络伪监管,其中两个网络是相同的:
其结构类似于图1 (d ),唯一的区别是两个流的输入是相同的,而不是一个弱增强和一个强增强。
我们用-从Y到P来表示损失监督。实证结果表明,单网伪监管表现不佳。主要原因是来自同一网络的伪标签的监督倾向于学习网络本身以更好地逼近伪标签,因此网络可能会向错误的方向收敛。相反,来自另一个网络的交叉伪标签的监督,由于网络扰动而不同于来自网络本身的伪标签,能够以一定的概率远离错误的方向学习网络。换句话说,两个网络之间的伪标签的扰动在某种意义上充当正则化器,而不会过度拟合错误的方向。此外,我们研究了单一网络伪监督的方式类似于[11]与切割混合扩增。伪监督的反向传播仅对剪切混合图像进行。结果表明,我们的方法执行得更好(表6),这意味着网络扰动是有帮助的,尽管在[11]中已经存在来自CutMix增强的方式的扰动。
PseudoSeg :
与FixMatch类似,PseudoSeg 应用弱增强图像Xw来生成伪分割图,该伪分割图用于监督来自相同网络的具有相同参数的强增强图像Xs的输出。Xw和Xs基于相同的输入图像X,PseudoSeg仅在处理强增强图像Xs的路径上进行反向传播(如图1 (d)所示)。它的逻辑形式如下:
我们用从Yw到Ps来表示损失监督,上述方式类似于单网伪监管。不同之处在于伪分割图来自弱增强,并且它监督强增强上的训练。我们猜测,除了基于弱增强的分割图更准确之外,另一个原因与我们的方法相同:来自弱增强的伪分割图也对伪监督引入了额外的扰动。
数据集:
PASCAL VOC 2012 是一个标准的以对象为中心的语义分割数据集,它由13,000多幅图像组成,包含20个对象类和1个背景类。标准训练、验证和测试集分别由1464、1449和1456幅图像组成。我们遵循之前的工作,使用扩充集 (10582幅图像)作为我们的完整训练集。
Cityscapes 主要是为城市场景理解而设计的。官方的划分有2975张图片用于训练,500张用于验证,1525张用于测试。每幅图像的分辨率为2048 × 1024,并使用19个语义类别的像素级标签进行精细注释。
我们遵循指导协作训练(GCT) 的划分协议,通过对整个集合的1/2、1/4、1/8和1/16进行随机子采样,将整个训练集合分为两组,作为标记集合,并将剩余图像视为未标记集合。
评价:
我们使用平均交并(mIoU)度量来评估分割性能。对于所有分区协议,我们仅通过单一规模测试报告了1456 PASCAK VOC 2012 val集(或500 Cityscapes val集)的结果。在我们的方法中,我们仅使用一个网络来生成评估结果。
实施细节:
我们基于PyTorch框架实现了我们的方法。我们使用在ImageNet上预先训练的相同权重和两个分割头(DeepLabv3+)的权重来初始化两个分割网络中两个主干的权重。我们采用带动量的小批量SGD,用Sync-BN训练我们的模型。动量固定为0.9,重量衰减设定为0.0005。我们采用多学习率策略,初始学习率乘以。
对于在完整训练集上训练的监督基线,如果没有指定,我们使用随机水平翻转和多尺度作为数据扩充。我们将PASCAL VOC 2012训练60个时期,基础学习率设置为0.01,将Cityscapes训练240个时期,基础学习率设置为0.04。OHEM损失用于城市景观。
对基线的改进:
在图2中,我们展示了与所有分区协议下的监督基线相比,我们的方法的改进。所有的方法都是基于带有ResNet-50或ResNet-101的DeepLabv3+。
图2 (a)显示了我们的方法在ResNet50的城市景观上始终优于监督基线。具体而言,在1/16、1/8、1/4和1/2分区协议下,我们的方法w/o CutMix增强相对于基线方法w/o CutMix增强的改进分别为4.89%、4.07%、2.74%和2.42%。
图2 (b)显示了我们的方法在使用ResNet-101的Cityscapes上相对于基线方法的增益:在1/16、1/8、1/4和1/2分区协议下分别为3.70%、3.52%、2.11%和2.02%。
图2还显示了CutMix增强带来的改进。我们可以看到,CutMix在1/16和1/8分区下比在1/4和1/2分区下带来了更多的增益。例如,在使用ResNet-101的Cityscapes上,在1/16、1/8和1/2分区协议下,CutMix增强带来的额外增益分别为4.22%、1.91%和0.13%。
与SOTA的比:
我们将我们的方法与最近的一些半监督分割方法进行了比较,这些方法包括:不同分割协议下的Meat-Teacher (MT)、交叉一致性训练(CCT) 、引导式协作训练(GCT) 和CutMix-Seg。具体来说,我们采用CutMix-Seg的官方开源实现。对于MT和GCT,我们使用[17]中的实现。为了公平起见,我们使用相同的架构和分区协议对它们进行了比较。
PASCAL VOC 2012:
表1显示了PASCAL VOC 2012的比较结果。我们可以看到,在所有分区上,使用ResNet-50和ResNet-101,我们的方法w/o CutMix增强始终优于其他方法,除了使用强CutMix增强的CutMix-Seg。
我们的方法w/ CutMix增强性能最佳,并在所有分区协议下创造了新的技术水平。例如,我们的方法w/ CutMix增强在1/16分区协议下分别以ResNet-50和ResNet-101优于cut mix-Seg 3.08%和1.92%。结果表明,我们的交叉伪监督方案优于CutMix-Seg中使用的平均教师方案。当比较我们的方法w/o和w/ CutMix增强的结果时,我们有以下观察结果:CutMix增强对于具有较少标记数据的场景更重要。例如,使用ResNet-50,1/16分区下的增益3.77%高于1/8分区下的0.47%。
Cityscapes:
表2说明了城市景观估值集的比较结果。我们没有CutMix-Seg的结果,因为官方的CutMix-Seg实现只支持单GPU训练,由于GPU内存的限制,在Cityscapes上用DeepLabv3+运行CutMix-Seg是不可行的。与其他SOTA方法相比,我们的方法在ResNet-50和ResNet-101骨干网的所有分区协议中取得了最佳性能。例如,我们的方法加上CutMix增强,在ResNet-101骨干的1/2分区下获得了80.08%,比GCT高出1.50%。我们在表3中报告了HRNet的其他结果。
例如,我们的方法加上CutMix增强,在ResNet-101骨干的1/2分区下获得了80.08%,比GCT高出1.50%。我们在表3中报告了HRNet的其他结果。
完全监督:
我们使用完整的Cityscapes训练集(2,975张图片)来验证我们的方法,并从Cityscapes粗略集中随机抽取3,000张图片作为无标签集。对于无标签集,我们不使用他们的粗略注释的地面真相。图3说明了对Cityscapes估值集的单尺度评估结果。
改进完全监督的基线。基线模型(DeepLabv3+与ResNet-101和HRNet-W48)是使用完整的Cityscapes训练集训练的。我们的方法使用了来自Cityscapes粗集的3,000张图片作为额外的无标签集进行训练。我们的方法的优越性意味着我们的方法在相对较大的标记数据上运行良好。我们可以看到,即使有大量的标记数据,我们的方法仍然可以从无标记数据的训练中受益,而且我们的方法在最先进的分割网络HRNet上也有很好的效果。
很少监督:
我们通过遵循PseudoSeg 中采用的相同分区协议,在PASCAL VOC 2012上研究了我们方法的性能。PseudoSeg随机采样标准训练集中1/2、1/4、1/8和1/16的图像(大约1.5k个图像)来构建标记集。标准训练集中的剩余图像以及扩充集中的图像(大约9k个图像)被用作未标记集。我们只报告我们的方法的结果,即增加CutMix,因为CutMix对很少的监督很重要。结果如表5所示,除使用ResNet-50的CCT外,所有方法都使用ResNet101作为主干。
我们可以看到,我们的方法表现最好,并且在少数标记的情况下再次优于CutMix Seg。 我们的方法也比PseudoSeg更好,PseudoSeg使用一个复杂的方案来计算伪分割图。 我们认为,原因来自于我们的方法使用了网络扰动和交叉伪监督,而PseudoSeg使用了单一网络的输入扰动。
交叉伪监督。我们在表4中调查了对有标签集(Llcps)或无标签集(Lucps)应用建议的交叉伪监督损失的影响。
在PASCAL VOC 2012和Cityscapes上进行不同损失组合的消融研究。这些结果是在1/8数据分区协议下得到的,对其他分区协议的观察是一致的。Ls表示对标记集的监督损失。Llcps(Lucps)表示在有标签(无标签)的集合上的交叉伪监督损失。Llcpc(Lucpc)表示在有标签(无标签)集上的交叉概率一致性损失。在有标签和无标签的数据上,交叉伪监督损失的整体性能是最好的。
我们可以看到,在大多数情况下,无标签集上的交叉伪监督损失比有标签集上的交叉伪监督损失带来的改进更为显著。在有标签集和无标签集上,交叉伪监督损失的表现总体上是最好的。
与交叉概率一致性的比较:
我们将我们的方法与表4的最后2行的交叉概率一致性进行比较。我们可以看到,在这两个基准上,我们的交叉伪监督胜过交叉概率一致性。例如,在Cityscapes上,当用ResNet-50(ResNet-101)应用于有标签和无标签的集合时,交叉伪监督比交叉概率一致性要好2.36%(1.94%)。
权衡权重 λ:
我们研究了不同的λ的影响,它被用来平衡监督损失和交叉伪监督损失,如公式6所示。
从图4中,我们可以看到,λ=1.5在PASCAL VOC 2012上表现最好,λ=6在Cityscapes上表现最好。我们在所有的实验中使用λ=1.5和λ=6的方法。
单一网络的伪监督与交叉伪监督:
我们在表6中比较了所提出的方法和单网络伪监督
与PASCAL VOC 2012 val上的单网络伪监督比较,SPS=单网络伪监督,所有的方法都是基于DeepLabv3+是与ResNet50。我们可以看到,在有CutMix增强和无CutMix增强这两种情况下,我们的方法都优于单网络的伪监督。带有CutMix增强功能的单网络伪监督类似于FixMatch[28]在语义分割中的应用(如在PseudoSeg中的应用)。我们认为,这是我们的方法优于PseudoSeg的主要原因之一。
与自我训练的结合/比较:
我们实证研究了我们的方法和传统的自我训练的结合,这两个基准的结果总结在表7中。
我们可以看到,自我训练和我们的方法相结合的结果比只用我们的方法和只用自我训练的结果都好。这种优越性意味着我们的方法是对自我训练的补充。
由于自我训练方案由多个阶段组成(对有标签的集合进行训练→对无标签的集合预测伪标签→对有标签和无标签的集合进行重新训练,并使用伪标签),它比我们的方法需要更多的训练历时。为了与自我训练进行更公平的比较,我们对我们的方法进行了更多的历时训练(表示为Ours+),以确保我们的训练历时也能与自我训练相比较。
根据图5所示的结果,我们可以看到,在各种分区协议下,我们的+一直优于自我训练。我们猜测,原因在于我们方法中的一致性正则化。
图6是在PASCAL VOC 2012上的一些细分结果的可视化
例如,在第一行中,仅有监督的方法(第(c)栏)错误地将许多牛的像素归类为马的像素,而我们没有CutMix增强的方法(第(d)栏)修复了这些错误。在第2行,监督基线和我们的方法(没有CutMix增强)都将一些狗的像素误标为马的像素,而我们的方法(有CutMix增强)成功地纠正了这些错误。
我们提出了一种简单而有效的半监督分割方法,即交叉伪监督。我们的方法通过使用一个网络获得的单次伪分割图来监督另一个网络,使两个具有相同结构和不同初始化的网络之间保持一致。另一方面,带有伪分割图的未标记数据在后期训练阶段更加准确,可以作为扩大训练数据来提高性能。
import os
import math
import numpy as np
import glob
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import time
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from torch.utils.data import Dataset, DataLoader
import torchvision
import torch
from glob import glob
from tqdm import tqdm
import random
from albumentations import (
HorizontalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine,VerticalFlip,
IAASharpen, IAAEmboss, RandomContrast, RandomBrightness, Flip, OneOf, Compose,Resize,Sharpen,PiecewiseAffine,Emboss,RandomBrightnessContrast
) # 图像变换函数
from torchvision import transforms
import random
# 设标签宽W,长H
# 计算混淆矩阵
def fast_hist(a, b, n):
#--------------------------------------------------------------------------------#
# a是转化成一维数组的标签,形状(H×W,);b是转化成一维数组的预测结果,形状(H×W,)
#--------------------------------------------------------------------------------#
k = (a >= 0) & (a < n)
#--------------------------------------------------------------------------------#
# np.bincount计算了从0到n**2-1这n**2个数中每个数出现的次数,返回值形状(n, n)
# 返回中,写对角线上的为分类正确的像素点
#--------------------------------------------------------------------------------#
return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)
# 计算每个类别的平均iou
def per_class_iu(hist):
return np.diag(hist) / np.maximum((hist.sum(1) + hist.sum(0) - np.diag(hist)), 1)
# 计算准确率
def per_class_PA_Recall(hist):
return np.diag(hist) / np.maximum(hist.sum(1), 1)
def per_class_Precision(hist):
return np.diag(hist) / np.maximum(hist.sum(0), 1)
def per_Accuracy(hist):
return np.sum(np.diag(hist)) / np.maximum(np.sum(hist), 1)
def set_seed(seed=1):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
set_seed(403) # 设置随机种子
# 数据的根路径
# root_path = r"data"
root_path = r"D:\BaiduNetdiskDownload\good_dataset\data"
images_path = os.path.join(root_path,"rgbs_colorNormalized_patchs")
labels_path = os.path.join(root_path,"palette_masks_patchs")
print("images_path = ",images_path)
print("labels_path = ",labels_path)
# 获取图片
data_img = os.listdir(images_path)
# 获取图片对应的标签
data_label = os.listdir(labels_path)
print("data_img = ",len(data_img))
print("labels_img = ",len(data_label))
print("data_img = ",(data_img[0]))
print("labels_img = ",(data_label[0]))
print("data_img = ",data_img[:3])
print("data_label = ",data_label[:3])
data_imgs = []
data_labels = []
# 随机打乱
random.shuffle(data_label)
# random.shuffle(data_labels)
# 获取标签对应图片,防止图片数量比标签多
for i in range(len(data_label)):
img_mask = data_label[i]
img = img_mask.replace('_mask.png','.jpg')
# 有的原图也是png结尾的
img_path = os.path.join(root_path,"rgbs_colorNormalized_patchs",img)
if not os.path.exists(img_path):
img = img_mask.replace('_mask.png', '.png')
img_path = os.path.join(root_path, "rgbs_colorNormalized_patchs", img)
data_imgs.append(img_path)
data_labels.append(os.path.join(root_path, "palette_masks_patchs", img_mask))
# 查看数据集长度
print("data_imgs = ",len(data_imgs))
print("data_labels = ",len(data_labels))
# 打印一下前3张图片
print("data_imgs = ",(data_imgs[:3]))
print("data_labels = ",(data_labels[:3]))
# 查看一下数据
lb=data_labels[22]
print('lb = ',lb)
lb=Image.open(lb)
lb_tensor1 = np.array(lb)
print("lb_tensor1 = ",np.unique(lb_tensor1))
print("lb_tensor1.shape = ",(lb_tensor1.shape))
# 简单的数据增强操作
joint_transformer = Compose([
RandomRotate90(p=0.3),
Flip(),
OneOf([
HorizontalFlip(p=0.3),
VerticalFlip(p=0.3),
Transpose(p=0.2)
], p=0.5),#按照归一化的概率选择执行哪一个
],p=0.3)
img_transformer=transforms.Compose([
transforms.Resize((256,256)),
transforms.ToTensor()
])
label_transformer=transforms.Compose([
transforms.Resize((256,256))
])
# 自定义的数据集类
class HistoBreastdataset(Dataset):
def __init__(self, img, mask, transformer=None, label_transformer=None, joint_transformer=None):
"""
:param img: 图片路径
:param mask: 标签路径
:param transformer: 对图像进行的变换
"""
self.img = img
self.mask = mask
self.transformer = transformer
self.label_transformer = label_transformer
self.joint_transformer = joint_transformer
def __getitem__(self, index):
img_path = self.img[index] # 单个图片路径
mask_path = self.mask[index] # 单个标签路径
# print("img_path = ",img_path)
# print("mask_path = ",mask_path)
img_open = Image.open(img_path)
mask_open = Image.open(mask_path)
# plt.imshow(np.array(img_open))
# plt.show()
# plt.imshow(np.array(mask_open))
# plt.show()
# 对图像和标签同时做处理
if self.joint_transformer != None:
augmented = self.joint_transformer(image=np.array(img_open),
mask=np.array(mask_open)) # 这个包比较方便,能把mask也一并做掉
img = augmented['image'] # 参考https://github.com/albumentations-team/albumentations
mask = augmented['mask']
else:
img = np.array(img_open)
mask = np.array(mask_open)
# print("mask = ",(mask.shape))
# plt.imshow(img)
# plt.show()
# plt.imshow(mask)
# plt.show()
img_pil = Image.fromarray(img)
mask_pil = Image.fromarray(mask)
# print("mask_pil = ",mask_pil)
img_tensor = self.transformer(img_pil)
mask_img = self.label_transformer(mask_pil)
##########################################
#########这一步最关键,有两个坑#############
#####1.作为分割的label,label必须是二维######
#####2.作为分割的label,label必须是long类型的#
mask_np = np.array(mask_img)
mask_bool5 = (mask_np == 5)
mask_np[mask_bool5] = 0
mask_tensor = torch.from_numpy(mask_np)
# print("mask_tensor.shape = ",mask_tensor.shape)
mask_tensor_squeeze = torch.squeeze(mask_tensor).type(torch.long)
# print("mask_tensor.shape = ",mask_tensor.shape)
return img_tensor, mask_tensor_squeeze
def __len__(self):
return len(self.img)
# [256,256,1]
# [0,2,4,5,6]
# [[255,0,255,0]
# [255,0,0,0],
# [1,2,3]]
# 划分训练集、验证集和测试集 8:1:1
print("data_imgs = ",len(data_imgs))
print("data_labels = ",len(data_labels))
print("*"*100)
# 训练集、验证集、测试集划分比例 6:2:2
train_length = 0.90 * len(data_imgs) # 0 - 0.6 * len(data_imgs)
validate_length = 0.99 * len(data_imgs) # 0.6 * len(data_imgs) - 0.8 * len(data_imgs)
test_length = 1 * len(data_imgs) # 0.8 * len(data_imgs) - 1 * len(data_imgs)
print("train_length = ",train_length)
print("validate_length = ",validate_length)
print("test_length = ",test_length)
train_img = data_imgs[0:int(train_length)]
train_label = data_labels[0:int(train_length)]
validate_img = data_imgs[int(train_length):int(validate_length)]
validate_label = data_labels[int(train_length):int(validate_length)]
test_img = data_imgs[int(validate_length):int(test_length)]
test_label = data_labels[int(validate_length):int(test_length)]
print("len(train_img) = ",int(len(train_img)) , " len(train_label) = ",int(len(train_label)))
print("len(validate_img) = ",int(len(validate_img)) , " len(validate_img) = ",int(len(validate_img)))
print("len(test_img) = ",int(len(test_img)) , " len(test_label) = ",int(len(test_label)))
print("len(train_img) + len(validate_img) + len(test_img) = ",int(len(train_img)) + int(len(validate_img)) + int(len(test_img)))
print("len(train_label) + len(validate_label) + len(test_label) = ",int(len(train_label)) + int(len(validate_label)) + int(len(test_label)))
# 打印前5个数据
(print(train_img[:5]))
(print(train_label[:5]))
# 创建数据集对象
train_dataset=HistoBreastdataset(train_img,train_label,img_transformer,label_transformer,joint_transformer)
# 只对数据集做增强
validate_dataset=HistoBreastdataset(validate_img,validate_label,img_transformer,label_transformer)
test_dataset=HistoBreastdataset(test_img,test_label,img_transformer,label_transformer)
# 查看图片和标签的shape
print(train_dataset[0][0].shape)
print(train_dataset[0][1].shape)
# 创建对应的DataLoader类
dl_train=DataLoader(train_dataset,batch_size=8,shuffle=True,drop_last=True)
dl_validate=DataLoader(validate_dataset,batch_size=8,shuffle=True,drop_last=True)
dl_test=DataLoader(test_dataset,batch_size=4,shuffle=True,drop_last=True)
# 查看一个batch的数据
img,lable=next(iter(dl_train))
print("img.shape = ",img.shape)
print("lable.shape = ",lable.shape)
# 可视化一个batch的数据
img,label=next(iter(dl_train))
plt.figure(figsize=(12,8))
for i,(img,label) in enumerate(zip(img[:4],label[:4])):
print("shape = ",label.shape)
# cal_factor(label)
img_np=img.permute(1,2,0).numpy()
label_np=label.numpy()
plt.subplot(2,4,i+1)
plt.imshow(img_np)
plt.subplot(2,4,i+5)
plt.imshow(label_np)
plt.show()
# 定义Unet模型
class DoubleConv(nn.Sequential):
def __init__(self, in_channels, out_channels, mid_channels=None):
if mid_channels is None:
mid_channels = out_channels
super(DoubleConv, self).__init__(
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
class Down(nn.Sequential):
def __init__(self, in_channels, out_channels):
super(Down, self).__init__(
nn.MaxPool2d(2, stride=2),
DoubleConv(in_channels, out_channels)
)
class Up(nn.Module):
def __init__(self, in_channels, out_channels, bilinear=True):
super(Up, self).__init__()
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else:
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
# [N, C, H, W]
diff_y = x2.size()[2] - x1.size()[2]
diff_x = x2.size()[3] - x1.size()[3]
# padding_left, padding_right, padding_top, padding_bottom
x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
diff_y // 2, diff_y - diff_y // 2])
x = torch.cat([x2, x1], dim=1)
x = self.conv(x)
return x
class OutConv(nn.Sequential):
def __init__(self, in_channels, num_classes):
super(OutConv, self).__init__(
nn.Conv2d(in_channels, num_classes, kernel_size=1)
)
class UNet(nn.Module):
def __init__(self,
in_channels: int = 1,
num_classes: int = 2,
bilinear: bool = True,
base_c: int = 64):
super(UNet, self).__init__()
self.in_channels = in_channels
self.num_classes = num_classes
self.bilinear = bilinear
self.in_conv = DoubleConv(in_channels, base_c)
self.down1 = Down(base_c, base_c * 2)
self.down2 = Down(base_c * 2, base_c * 4)
self.down3 = Down(base_c * 4, base_c * 8)
factor = 2 if bilinear else 1
self.down4 = Down(base_c * 8, base_c * 16 // factor)
self.up1 = Up(base_c * 16, base_c * 8 // factor, bilinear)
self.up2 = Up(base_c * 8, base_c * 4 // factor, bilinear)
self.up3 = Up(base_c * 4, base_c * 2 // factor, bilinear)
self.up4 = Up(base_c * 2, base_c, bilinear)
self.out_conv = OutConv(base_c, num_classes)
def forward(self, x):
x1 = self.in_conv(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits = self.out_conv(x)
return {"out": logits}
def create_model(num_classes):
model = UNet(in_channels=3, num_classes=num_classes, base_c=64)
return model
# 创建unet
model = create_model(num_classes=5)
# 加载预训练模型
model.load_state_dict(torch.load(r"./logs/unet__bcss_pretrain_model65.pth"))
device = torch.device("cuda")
model=model.to(device)
# 带权重的交叉熵损失函数
# weights = torch.tensor([1,1,5,1,1],dtype=torch.float).cuda()
# loss_fn=nn.CrossEntropyLoss(weights)
loss_fn=nn.CrossEntropyLoss()
# 优化器
# optimizer=torch.optim.Adam(model.parameters(),lr=0.0001,weight_decay = 0.0001)
optimizer=torch.optim.Adam(model.parameters(),lr=0.0001)
# optimizer=torch.optim.SGD(model.parameters(),lr=0.005,momentum=0.9,weight_decay=0.0001)
# 训练函数
name_classes = ["other", "tumor", "stroma", "lymphocytic_infiltrate", "necrosis_or_debris"]
num_classes = 5
from tqdm import tqdm
def fit(epoch, model, trainloader, testloader,writer):
correct = 0
total = 0
running_loss = 0
epoch_iou = []
train_dice = []
test_dice = []
# 混淆矩阵
hist = np.zeros((num_classes, num_classes))
model.train()
for x, y in tqdm(trainloader):
x, y = x.to(device), y.to(device)
# print("x.shape = ",x.shape)
y_pred = model(x)
# print("x.shape = ",x.shape)
y_pred = y_pred['out']
# print("*"*100)
# print("y_pred.shape = ",y_pred.shape)
# print("y.shape = ",y.shape)
# print("y_pred = ",y_pred)
loss = loss_fn(y_pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
with torch.no_grad():
y_pred = torch.argmax(y_pred, dim=1)
# print("y.shape = ",y.shape)
# print("y_pred.shape = ",y_pred.shape)
writer.add_image('train/label', y[0].unsqueeze(0).cpu().numpy(), epoch)
writer.add_image('train/pred', y_pred[0].unsqueeze(0).cpu().numpy(), epoch)
correct += (y_pred == y).sum().item()
total += y.size(0)
running_loss += loss.item()
# 计算验证集的混淆矩阵
hist += fast_hist(y_pred.cpu().numpy().flatten(), y.cpu().numpy().flatten(), num_classes)
# 计算验证集的每个类别的iou和miou
epoch_class_iou = per_class_iu(hist)
mean_iou = np.mean(epoch_class_iou)
# name_classes = ["other","tumor","stroma","lymphocytic_infiltrate","necrosis_or_debris"]
print("========================================训练集===================================================>")
print("other_iou : ", epoch_class_iou[0], "tumor_iou : ", epoch_class_iou[1], "stroma_iou : ", epoch_class_iou[2],
"lymphocytic_infiltrate_iou : ", epoch_class_iou[3], "necrosis_or_debris_iou : ", epoch_class_iou[4])
print("mean_iou = ", mean_iou)
# 保存模型
torch.save(model.state_dict(), 'unet__bcss_pretrain_model_with_lym_nec_again{}.pth'.format(epoch))
epoch_loss = running_loss / len(trainloader.dataset)
epoch_acc = correct / (total * 256 * 256)
# 保存训练损失到tensorboard中
writer.add_scalar('loss/train_loss',epoch_loss,epoch)
# 保存指标
writer.add_scalar('train/train_loss',epoch_loss,epoch)
writer.add_scalar('train/other_iou',epoch_class_iou[0],epoch)
writer.add_scalar('train/tumor_iou',epoch_class_iou[1],epoch)
writer.add_scalar('train/stroma_iou',epoch_class_iou[2],epoch)
writer.add_scalar('train/lymphocytic_infiltrate_iou',epoch_class_iou[3],epoch)
writer.add_scalar('train/necrosis_or_debris_iou',epoch_class_iou[4],epoch)
writer.add_scalar('train/mean_iou',mean_iou,epoch)
test_correct = 0
test_total = 0
test_running_loss = 0
# 混淆矩阵
hist2 = np.zeros((num_classes, num_classes))
model.eval()
with torch.no_grad():
for x, y in tqdm(testloader):
x, y = x.to(device), y.to(device)
y_pred = model(x)
y_pred = y_pred['out']
loss = loss_fn(y_pred, y)
y_pred = torch.argmax(y_pred, dim=1)
# 保存图片
writer.add_image('valid/label', y[0].unsqueeze(0).cpu().numpy(), epoch)
writer.add_image('valid/pred', y_pred[0].unsqueeze(0).cpu().numpy(), epoch)
test_correct += (y_pred == y).sum().item()
test_total += y.size(0)
test_running_loss += loss.item()
# 计算验证集的混淆矩阵
hist2 += fast_hist(y_pred.cpu().numpy().flatten(), y.cpu().numpy().flatten(), num_classes)
epoch_test_loss = test_running_loss / len(testloader.dataset)
epoch_test_acc = test_correct / (test_total * 256 * 256)
# 计算验证集的每个类别的iou和miou
epoch_class_iou2 = per_class_iu(hist2)
mean_iou2 = np.mean(epoch_class_iou2)
# name_classes = ["other","tumor","stroma","lymphocytic_infiltrate","necrosis_or_debris"]
print("========================================测试集===================================================>")
print("other_iou : ", epoch_class_iou2[0], "tumor_iou : ", epoch_class_iou2[1], "stroma_iou : ",
epoch_class_iou2[2], "lymphocytic_infiltrate_iou : ", epoch_class_iou2[3], "necrosis_or_debris_iou : ",
epoch_class_iou2[4])
print("mean_iou = ", mean_iou2)
# 保存训练损失到tensorboard中
writer.add_scalar('loss/valid_loss', epoch_test_loss, epoch)
# 保存指标
writer.add_scalar('valid/valid_other_iou',epoch_class_iou2[0],epoch)
writer.add_scalar('valid/valid_tumor_iou',epoch_class_iou2[1],epoch)
writer.add_scalar('valid/valid_stroma_iou',epoch_class_iou2[2],epoch)
writer.add_scalar('valid/valid_lymphocytic_infiltrate_iou',epoch_class_iou2[3],epoch)
writer.add_scalar('valid/valid_necrosis_or_debris_iou',epoch_class_iou2[4],epoch)
writer.add_scalar('valid/valid_mean_iou',mean_iou2,epoch)
print('epoch: ', epoch,
'loss: ', round(epoch_loss, 3),
'accuracy:', round(epoch_acc, 3),
'test_loss: ', round(epoch_test_loss, 3),
'test_accuracy:', round(epoch_test_acc, 3)
)
return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc
# 训练的epochs
epochs = 200
train_loss = []
train_acc = []
test_loss = []
test_acc = []
# 开始训练
writer = SummaryWriter('./runs_lym_nec_again')
for epoch in range(0,epochs):
epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(epoch,
model,
dl_train,
dl_validate,writer)
train_loss.append(epoch_loss)
train_acc.append(epoch_acc)
test_loss.append(epoch_test_loss)
test_acc.append(epoch_test_acc)