代码来自:https://github.com/ChuHan89/WSSS-Tissue?tab=readme-ov-file
借助了一些人工智能
1_train_stage1.py
该代码是弱监督语义分割(WSSS)流程的 Stage1 训练与测试脚本,核心任务是通过 多标签分类模型 生成图像级标签,为后续生成伪掩码(Pseudo-Masks)提供基础。代码分为 train_phase
和 test_phase
两个阶段,支持 渐进式Dropout注意力(PDA) 和 Visdom可视化监控。
import os import numpy as np import argparse import importlib from visdom import Visdom # 可视化工具 import torch import torch.nn.functional as F from torch.backends import cudnn # CUDA加速 from torch.utils.data import DataLoader from torchvision import transforms # 数据预处理 from tool import pyutils, torchutils # 自定义工具包 from tool.GenDataset import Stage1_TrainDataset # 自定义数据集类 from tool.infer_fun import infer # 测试阶段推理函数 cudnn.enabled = True # 启用CUDA加速(自动优化卷积算法)
关键细节:
cudnn.enabled=True
:启用cuDNN加速,自动选择最优卷积实现。
pyutils
和 torchutils
:项目自定义工具模块(包含优化器、计时器等)。
Visdom
:用于实时可视化训练过程中的损失和准确率曲线。
compute_acc
def compute_acc(pred_labels, gt_labels): pred_correct_count = 0 for pred_label in pred_labels: # 遍历预测标签 if pred_label in gt_labels: # 判断是否在真实标签中 pred_correct_count += 1 union = len(gt_labels) + len(pred_labels) - pred_correct_count # 并集大小 acc = round(pred_correct_count/union, 4) # 交并比(IoU)式准确率 return acc
功能:计算预测标签与真实标签的 交并比准确率(IoU-like Accuracy)。
数学公式:
Acc=预测正确的标签数预测标签数+真实标签数−预测正确的标签数Acc=预测标签数+真实标签数−预测正确的标签数预测正确的标签数示例:
预测标签:[0, 2]
,真实标签:[2, 3]
正确数:1(标签2),并集:2 + 2 - 1 = 3 → Acc = 1/3 ≈ 0.333
train_phase
def train_phase(args): viz = Visdom(env=args.env_name) # 创建Visdom环境(用于可视化) model = getattr(importlib.import_module(args.network), 'Net')(args.init_gama, n_class=args.n_class) print(vars(args)) # 打印所有输入参数
关键细节:
动态模型加载:通过 importlib
从字符串 args.network
(如 "network.resnet38_cls"
)动态加载模型类 Net
。
PDA参数:args.init_gama
控制渐进式Dropout注意力的初始强度(值越大,注意力区域越集中)。
Visdom环境:通过 env=args.env_name
隔离不同实验的可视化结果。
transform_train = transforms.Compose([ transforms.RandomHorizontalFlip(p=0.5), # 50%概率水平翻转 transforms.RandomVerticalFlip(p=0.5), # 50%概率垂直翻转 transforms.ToTensor() # 转为Tensor(范围[0,1]) ]) train_dataset = Stage1_TrainDataset( data_path=args.trainroot, # 训练集路径(如'datasets/BCSS-WSSS/train/') transform=transform_train, dataset=args.dataset # 数据集标识(如'bcss') ) train_data_loader = DataLoader( train_dataset, batch_size=args.batch_size, # 批大小(默认20) shuffle=True, # 打乱数据顺序 num_workers=args.num_workers, # 数据加载子进程数(默认10) pin_memory=False, # 不锁页内存(适用于小批量数据) drop_last=True # 丢弃最后不足一个batch的数据 )
关键细节:
数据增强策略:仅使用翻转操作,避免复杂变换干扰分类模型的学习。
自定义数据集类:Stage1_TrainDataset
需实现图像和标签的加载逻辑(如解析XML或CSV文件)。
max_step = (len(train_dataset) // args.batch_size) * args.max_epoches # 总迭代次数 param_groups = model.get_parameter_groups() # 获取模型参数分组(通常按网络层分组) optimizer = torchutils.PolyOptimizer( [ {'params': param_groups[0], 'lr': args.lr, 'weight_decay': args.wt_dec}, # 主干网络(低学习率) {'params': param_groups[1], 'lr': 2*args.lr, 'weight_decay': 0}, # 中间层(较高学习率) {'params': param_groups[2], 'lr': 10*args.lr, 'weight_decay': args.wt_dec}, # 分类头(高学习率) {'params': param_groups[3], 'lr': 20*args.lr, 'weight_decay': 0} # 特殊模块(最高学习率) ], lr=args.lr, weight_decay=args.wt_dec, max_step=max_step # 控制学习率衰减 )
关键细节:
参数分组:不同网络层(如ResNet38的卷积层、全连接层)使用不同的学习率,分类头通常需要更高学习率以快速适应新任务。
Poly学习率衰减:学习率按公式 lr=base_lr×(1−stepmax_step)powerlr=base_lr×(1−max_stepstep)power 衰减,默认 power=0.9
。
if args.weights[-7:] == '.params': # MXNet格式权重(如'init_weights/ilsvrc-cls_rna-a1_cls1000_ep-0001.params') import network.resnet38d weights_dict = network.resnet38d.convert_mxnet_to_torch(args.weights) # 转换权重格式 model.load_state_dict(weights_dict, strict=False) # 非严格加载(允许部分参数不匹配) elif args.weights[-4:] == '.pth': # PyTorch格式权重 weights_dict = torch.load(args.weights) model.load_state_dict(weights_dict, strict=False) else: print('random init') # 随机初始化(无预训练)
关键细节:
MXNet转换:项目可能基于早期MXNet实现,需将预训练权重转换为PyTorch格式。
strict=False
:允许模型结构与权重文件部分不匹配(如分类头维度不同)。
model = model.cuda() # 将模型移至GPU avg_meter = pyutils.AverageMeter('loss', 'avg_ep_EM', 'avg_ep_acc') # 统计训练指标 timer = pyutils.Timer("Session started: ") # 计时器(计算剩余时间) for ep in range(args.max_epoches): # 遍历每个epoch model.train() args.ep_index = ep # 当前epoch索引(可能用于回调) ep_count = 0 # 当前epoch累计样本数 ep_EM = 0 # 完全匹配(Exact Match)次数 ep_acc = 0 # 累计准确率 for iter, (filename, data, label) in enumerate(train_data_loader): # 遍历每个batch img = data # 图像数据(未使用filename) label = label.cuda(non_blocking=True) # 标签移至GPU(异步传输) # 控制PDA的启用(前3个epoch禁用) enable_PDA = 1 if ep > 2 else 0 # 前向传播(返回分类输出、特征图、概率) x, feature, y = model(img.cuda(), enable_PDA) # 转换为CPU numpy数组以计算指标 prob = y.cpu().data.numpy() # 预测概率(shape=[batch_size, n_class]) gt = label.cpu().data.numpy() # 真实标签(shape=[batch_size, n_class]) # 遍历batch内每个样本计算指标 for num, one in enumerate(prob): ep_count += 1 pass_cls = np.where(one > 0.5)[0] # 预测标签(概率>0.5的类别) true_cls = np.where(gt[num] == 1)[0] # 真实标签(one-hot编码中为1的类别) # 统计Exact Match(完全匹配) if np.array_equal(pass_cls, true_cls): ep_EM += 1 # 计算交并比式准确率 acc = compute_acc(pass_cls, true_cls) ep_acc += acc # 计算当前batch的平均指标 avg_ep_EM = round(ep_EM / ep_count, 4) avg_ep_acc = round(ep_acc / ep_count, 4) # 计算多标签分类损失 loss = F.multilabel_soft_margin_loss(x, label) # x为模型原始输出(未经过sigmoid) # 更新统计指标 avg_meter.add({ 'loss': loss.item(), 'avg_ep_EM': avg_ep_EM, 'avg_ep_acc': avg_ep_acc }) # 反向传播与优化 optimizer.zero_grad() # 清空梯度 loss.backward() # 计算梯度 optimizer.step() # 更新参数 torch.cuda.empty_cache() # 清理GPU缓存(防止内存泄漏) # 每100步打印日志并更新Visdom if (optimizer.global_step) % 100 == 0 and (optimizer.global_step) != 0: timer.update_progress(optimizer.global_step / max_step) # 更新剩余时间估计 print( 'Epoch:%2d' % (ep), 'Iter:%5d/%5d' % (optimizer.global_step, max_step), 'Loss:%.4f' % (avg_meter.get('loss')), 'avg_ep_EM:%.4f' % (avg_meter.get('avg_ep_EM')), 'avg_ep_acc:%.4f' % (avg_meter.get('avg_ep_acc')), 'lr: %.4f' % (optimizer.param_groups[0]['lr']), 'Fin:%s' % (timer.str_est_finish()), flush=True ) # 更新Visdom图表 viz.line( [avg_meter.pop('loss')], [optimizer.global_step], win='loss', update='append', opts=dict(title='loss') ) # 同理更新 'Acc_exact' 和 'Acc' 图表... # 每epoch后调整PDA的gama参数 if model.gama > 0.65: model.gama = model.gama * 0.98 # 逐步衰减注意力强度 print('Gama of progressive dropout attention is: ', model.gama) # 保存最终模型 torch.save( model.state_dict(), os.path.join(args.save_folder, 'stage1_checkpoint_trained_on_'+args.dataset+'.pth') )
关键细节:
渐进式Dropout注意力(PDA):
前3个epoch禁用(enable_PDA=0
),让模型初步学习基础特征。
gama
初始值为1,逐渐衰减(gama *= 0.98
),控制注意力区域的聚焦程度。
损失函数:F.multilabel_soft_margin_loss
结合Sigmoid和交叉熵,适用于多标签分类。
指标计算:
Exact Match (EM):预测标签与真实标签完全一致的样本比例(严格指标)。
IoU式准确率:反映预测与真实标签的重合程度(宽松指标)。
Visdom集成:实时可视化损失和准确率曲线,便于监控训练状态。
test_phase
def test_phase(args): # 加载生成CAM的模型变体(Net_CAM) model = getattr(importlib.import_module(args.network), 'Net_CAM')(n_class=args.n_class) model = model.cuda() # 加载训练阶段保存的权重 args.weights = os.path.join(args.save_folder, 'stage1_checkpoint_trained_on_'+args.dataset+'.pth') weights_dict = torch.load(args.weights) model.load_state_dict(weights_dict, strict=False) model.eval() # 设置为评估模式(禁用Dropout和BatchNorm的随机性) # 调用自定义推理函数(评估模型在测试集上的性能) score = infer(model, args.testroot, args.n_class) print(score) # 输出评估结果(如mAP、IoU等) # 可选:保存最终模型(可能包含CAM生成能力) torch.save(model.state_dict(), ...)
关键细节:
模型变体:Net_CAM
可能修改了网络结构以输出类别激活图(Class Activation Map)。
评估指标:infer
函数内部可能计算mAP(平均精度)、像素级IoU等指标。
严格模式:strict=False
允许加载部分权重(如分类头维度不同)。
if __name__ == '__main__': parser = argparse.ArgumentParser() # 训练参数 parser.add_argument("--batch_size", default=20, type=int) parser.add_argument("--max_epoches", default=20, type=int) parser.add_argument("--network", default="network.resnet38_cls", type=str) parser.add_argument("--lr", default=0.01, type=float) parser.add_argument("--num_workers", default=10, type=int) parser.add_argument("--wt_dec", default=5e-4, type=float) # 权重衰减(L2正则化) # 实验命名与可视化 parser.add_argument("--session_name", default="Stage 1", type=str) # 实验名称(日志标识) parser.add_argument("--env_name", default="PDA", type=str) # Visdom环境名 parser.add_argument("--model_name", default='PDA', type=str) # 模型保存名称 # 数据集与模型结构 parser.add_argument("--n_class", default=4, type=int) # 类别数(如BCSS为4类) parser.add_argument("--weights", default='init_weights/ilsvrc-cls_rna-a1_cls1000_ep-0001.params', type=str) parser.add_argument("--trainroot", default='datasets/BCSS-WSSS/train/', type=str) parser.add_argument("--testroot", default='datasets/BCSS-WSSS/test/', type=str) parser.add_argument("--save_folder", default='checkpoints/', type=str) # PDA参数 parser.add_argument("--init_gama", default=1, type=float) # 初始注意力强度 # 数据集标识 parser.add_argument("--dataset", default='bcss', type=str) # 数据集缩写(影响保存文件名) args = parser.parse_args() train_phase(args) # 执行训练 test_phase(args) # 执行测试
关键参数说明:
--network
:模型定义文件路径(如 network.resnet38_cls
对应 network/resnet38_cls.py
)。
--init_gama
:PDA的初始强度,影响注意力机制的随机丢弃率。
--weights
:预训练权重路径(支持MXNet和PyTorch格式)。