数据下载:需要在 EchoNet-Dynamic 官网提交申请后获取;源码同样由斯坦福大学公开发布。
数据集由心脏跳动的超声心动图视频组成。我们可以对视频中的每一帧图像进行分割,再根据心脏跳动过程中心室面积的变化(该数据集标注的是左心室)来辅助诊断。
整体代码比较简单,模型均直接使用 PyTorch(torchvision)内置的分割模型;这项研究的价值主要在于拓展了深度学习的应用领域,并公开了一个新的数据集。
配置参数解读
Args:
    num_epochs (int, optional): 训练轮数(epoch 数)。Defaults to 50.
    modelname (str, optional): 分割模型名称,取值为 ``"deeplabv3_resnet50"``、``"deeplabv3_resnet101"``、``"fcn_resnet50"`` 或 ``"fcn_resnet101"``,均为 PyTorch(torchvision)内置模型。
    pretrained (bool, optional): 是否加载预训练权重。Defaults to False.
    output (str or None, optional): 输出文件目录。
    device (str or None, optional): 运行设备(CPU/GPU)。
    n_train_patients (int or None, optional): 用于训练的病人数量(对训练集进行子采样);为 None 时使用全部训练数据。
    num_workers (int, optional): 数据加载使用的子进程数。
    batch_size (int, optional): 每个批次的样本数。
    seed (int, optional): 随机种子。
    lr_step_period (int or None, optional): 学习率衰减周期,即每隔多少个 epoch 衰减一次学习率。
    save_segmentation (bool, optional): 是否保存分割结果。
    block_size (int, optional): 保存分割结果时每次同时分割的帧数;视频较长时分块处理,以控制显存占用。
    run_test (bool, optional): 是否在测试集上进行测试。
任务流程:
def _mask_overlap(logits, trace):
    """Return per-sample pixel intersection and union of predicted vs. reference masks.

    Args:
        logits (torch.Tensor): Raw model outputs for one channel, shape
            (batch, H, W); a pixel is predicted positive when its logit is > 0
            (equivalently sigmoid > 0.5).
        trace (torch.Tensor): Reference (human) segmentation with the same
            shape as ``logits``; a pixel is positive when its value is > 0.

    Returns:
        tuple[np.ndarray, np.ndarray]: Intersection and union pixel counts,
        one entry per sample (summed over the last two axes).
    """
    # Threshold on the device and copy only the boolean masks to the CPU,
    # once per tensor, instead of transferring the float tensors repeatedly.
    pred = (logits > 0.).cpu().numpy()
    truth = (trace > 0.).cpu().numpy()
    return (np.logical_and(pred, truth).sum((1, 2)),
            np.logical_or(pred, truth).sum((1, 2)))


def _bernoulli_entropy(p):
    """Entropy (in nats) of a Bernoulli(p) variable, defined as 0 at p == 0 or p == 1.

    ``math.log(0)`` raises ``ValueError``; the degenerate cases are handled
    explicitly so a display-only baseline cannot crash an epoch.
    """
    if p <= 0. or p >= 1.:
        return 0.
    return -p * math.log(p) - (1 - p) * math.log(1 - p)


def _dice_score(inter, union):
    """Dice coefficient 2*|A and B| / (|A| + |B|), computed as 2 * inter / (union + inter).

    Returns NaN when both masks are empty instead of evaluating 0 / 0.
    """
    denom = union + inter
    if denom == 0:
        return float("nan")
    return 2 * inter / denom


def run_epoch(model, dataloader, train, optim, device):
    """Run one epoch of training/evaluation for segmentation.

    Every batch provides diastolic ("large") and systolic ("small") frames with
    their human traces. Both frame types are segmented by ``model``, scored with
    summed binary cross-entropy on output channel 0, and compared with the
    traces via pixel intersection / union.

    Args:
        model (torch.nn.Module): Model to train/evaluate; called as
            ``model(frames)["out"]``.
        dataloader (torch.utils.data.DataLoader): Dataloader for dataset; each
            item is ``(_, (large_frame, small_frame, large_trace, small_trace))``.
        train (bool): Whether or not to train model (enables gradients and
            optimizer steps).
        optim (torch.optim.Optimizer): Optimizer; only used when ``train``.
        device (torch.device): Device to run on.

    Returns:
        tuple: ``(loss, large_inter, large_union, small_inter, small_union)``.
        ``loss`` is the epoch's BCE loss (mean of the diastolic and systolic
        losses) divided by the number of trace pixels seen. The other four are
        numpy arrays holding one intersection / union pixel count per sample,
        for the diastolic and systolic frames respectively.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no batches.
    """
    total = 0.
    n_pixels = 0  # trace pixels per frame type seen so far (batch * H * W summed)
    pos = 0
    neg = 0
    pos_pix = 0
    neg_pix = 0
    model.train(train)

    large_inter = 0
    large_union = 0
    small_inter = 0
    small_union = 0
    large_inter_list = []
    large_union_list = []
    small_inter_list = []
    small_union_list = []

    with torch.set_grad_enabled(train):
        with tqdm.tqdm(total=len(dataloader)) as pbar:
            for (_, (large_frame, small_frame, large_trace, small_trace)) in dataloader:
                # Count pixels inside/outside the human segmentation (both frame
                # types); feeds a constant-prediction entropy baseline.
                pos += (large_trace == 1).sum().item()
                pos += (small_trace == 1).sum().item()
                neg += (large_trace == 0).sum().item()
                neg += (small_trace == 0).sum().item()
                # The same human-trace counts kept per pixel location (summed over
                # the batch axis); feeds a location-aware entropy baseline.
                pos_pix += (large_trace == 1).sum(0).cpu().numpy()
                pos_pix += (small_trace == 1).sum(0).cpu().numpy()
                neg_pix += (large_trace == 0).sum(0).cpu().numpy()
                neg_pix += (small_trace == 0).sum(0).cpu().numpy()

                # Run prediction for diastolic frames and compute loss
                large_frame = large_frame.to(device)
                large_trace = large_trace.to(device)
                y_large = model(large_frame)["out"][:, 0, :, :]
                loss_large = torch.nn.functional.binary_cross_entropy_with_logits(
                    y_large, large_trace, reduction="sum")
                inter, union = _mask_overlap(y_large, large_trace)
                large_inter += inter.sum()
                large_union += union.sum()
                large_inter_list.extend(inter)
                large_union_list.extend(union)

                # Run prediction for systolic frames and compute loss
                small_frame = small_frame.to(device)
                small_trace = small_trace.to(device)
                y_small = model(small_frame)["out"][:, 0, :, :]
                loss_small = torch.nn.functional.binary_cross_entropy_with_logits(
                    y_small, small_trace, reduction="sum")
                inter, union = _mask_overlap(y_small, small_trace)
                small_inter += inter.sum()
                small_union += union.sum()
                small_inter_list.extend(inter)
                small_union_list.extend(union)

                # Take gradient step if training
                loss = (loss_large + loss_small) / 2
                if train:
                    optim.zero_grad()
                    loss.backward()
                    optim.step()

                # Accumulate losses. The loss is a pixel sum averaged over the two
                # frame types, so normalize by the pixel count of one trace tensor
                # rather than a hard-coded 112 x 112 frame size.
                batch_pixels = large_trace.numel()
                total += loss.item()
                n_pixels += batch_pixels
                p = pos / (pos + neg)
                # Laplace smoothing keeps every entry strictly inside (0, 1).
                p_pix = (pos_pix + 1) / (pos_pix + neg_pix + 2)

                # Show info on progress bar: running loss (batch loss) / constant
                # baseline, per-location baseline, diastolic Dice, systolic Dice.
                pbar.set_postfix_str("{:.4f} ({:.4f}) / {:.4f} {:.4f}, {:.4f}, {:.4f}".format(
                    total / n_pixels,
                    loss.item() / batch_pixels,
                    _bernoulli_entropy(p),
                    (-p_pix * np.log(p_pix) - (1 - p_pix) * np.log(1 - p_pix)).mean(),
                    _dice_score(large_inter, large_union),
                    _dice_score(small_inter, small_union)))
                pbar.update()

    return (total / n_pixels,
            np.array(large_inter_list),
            np.array(large_union_list),
            np.array(small_inter_list),
            np.array(small_union_list),
            )