MIT License
Copyright (c) 2019 PJ-Javis
目录:
链接: https://pan.baidu.com/s/1WVVo6VDv5NkjgdYfkIO7Fw 提取码: 14mi
unet.py
import torch.nn as nn
import torch
from torch import autograd
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by BatchNorm and ReLU.

    The first convolution maps ``in_ch`` to ``out_ch`` channels; the second
    keeps ``out_ch`` channels. ``padding=1`` preserves the spatial size.
    """

    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        stages = []
        # Same layer order as a hand-written Sequential:
        # conv, bn, relu, conv, bn, relu (indices 0..5).
        for stage_in in (in_ch, out_ch):
            stages.append(nn.Conv2d(stage_in, out_ch, 3, padding=1))
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

    def forward(self, input):
        """Apply the conv/BN/ReLU stack to ``input`` and return the result."""
        features = self.conv(input)
        return features
class Unet(nn.Module):
    """U-Net style encoder/decoder producing a per-pixel sigmoid output.

    The encoder is four ``DoubleConv`` + 2x2 max-pool stages (64, 128, 256,
    512 channels) followed by a 1024-channel ``DoubleConv``. The decoder
    upsamples with stride-2 transposed convolutions and, at each level,
    concatenates the matching encoder feature map before another
    ``DoubleConv``. A final 1x1 convolution maps to ``out_ch`` channels.

    Args:
        in_ch: number of channels of the input tensor.
        out_ch: number of channels of the output tensor.
    """

    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()
        # Encoder (attribute names are kept so saved state dicts still load).
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        # Decoder: each DoubleConv takes the upsampled map concatenated with
        # the skip connection, hence twice the upsampled channel count.
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        self.conv10 = nn.Conv2d(64, out_ch, 1)

    @staticmethod
    def _merge(upsampled, skip):
        """Concatenate ``upsampled`` with ``skip`` along the channel axis.

        Max-pooling floors odd sizes, so after the stride-2 upsampling the
        decoder map can be one pixel smaller than the encoder map. In that
        case the upsampled map is zero-padded to the skip map's height and
        width; when the sizes already match nothing is changed.
        """
        pad_h = skip.size(2) - upsampled.size(2)
        pad_w = skip.size(3) - upsampled.size(3)
        if pad_h or pad_w:
            # F.pad order is (left, right, top, bottom) for the last two dims.
            upsampled = nn.functional.pad(
                upsampled,
                [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
            )
        return torch.cat([upsampled, skip], dim=1)

    def forward(self, x):
        """Run the network on ``x`` and return sigmoid activations.

        The output has ``out_ch`` channels and the same height and width as
        ``x``; heights/widths that are not multiples of 16 are handled by
        padding inside the skip-connection merge.
        """
        # Contracting path.
        c1 = self.conv1(x)
        c2 = self.conv2(self.pool1(c1))
        c3 = self.conv3(self.pool2(c2))
        c4 = self.conv4(self.pool3(c3))
        c5 = self.conv5(self.pool4(c4))
        # Expanding path with skip connections.
        c6 = self.conv6(self._merge(self.up6(c5), c4))
        c7 = self.conv7(self._merge(self.up7(c6), c3))
        c8 = self.conv8(self._merge(self.up8(c7), c2))
        c9 = self.conv9(self._merge(self.up9(c8), c1))
        c10 = self.conv10(c9)
        # Functional sigmoid avoids building a new nn.Sigmoid module per call.
        return torch.sigmoid(c10)
dataset.py
import torch.utils.data as data
import PIL.Image as Image
import os
def make_dataset(root):
    """Return the list of ``(image_path, mask_path)`` pairs found in ``root``.

    Files are expected to be named ``000.png`` / ``000_mask.png``,
    ``001.png`` / ``001_mask.png``, and so on, numbered from zero.

    The number of pairs is the number of ``*_mask.png`` entries, so
    unrelated files in the directory (e.g. ``.DS_Store``) do not inflate the
    count and produce paths to files that do not exist.

    Args:
        root: directory containing the images and their masks.

    Returns:
        A list of ``(image_path, mask_path)`` tuples in index order.
    """
    pair_count = sum(
        1 for entry in os.listdir(root) if entry.endswith("_mask.png")
    )
    return [
        (
            os.path.join(root, "%03d.png" % i),
            os.path.join(root, "%03d_mask.png" % i),
        )
        for i in range(pair_count)
    ]
class LiverDataset(data.Dataset):
    """Dataset yielding ``(image, mask)`` pairs listed by ``make_dataset``.

    Args:
        root: directory passed to ``make_dataset`` to enumerate the pairs.
        transform: optional callable applied to the opened image.
        target_transform: optional callable applied to the opened mask.
    """

    def __init__(self, root, transform=None, target_transform=None):
        self.imgs = make_dataset(root)
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Open the pair at ``index`` and return it after the transforms."""
        image_path, mask_path = self.imgs[index]
        # Both files are opened before any transform runs.
        sample = Image.open(image_path)
        target = Image.open(mask_path)
        image_fn = self.transform
        mask_fn = self.target_transform
        if image_fn is not None:
            sample = image_fn(sample)
        if mask_fn is not None:
            target = mask_fn(target)
        return sample, target

    def __len__(self):
        """Return the number of (image, mask) pairs."""
        pair_total = len(self.imgs)
        return pair_total
main.py
import numpy as np
import torch
import argparse
from torch.utils.data import DataLoader
from torch import autograd, optim
from torchvision.transforms import transforms
from unet import Unet
from dataset import LiverDataset
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Chain the preprocessing steps. Normalize computes
# channel = (channel - mean) / std for each of the three channels separately.
x_transforms = transforms.Compose([
    transforms.ToTensor(),  # -> [0, 1]
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # -> [-1, 1]
])
# The mask only needs to be converted to a tensor.
y_transforms = transforms.ToTensor()
# NOTE: the command-line parser is built once, next to its add_argument()
# calls at the bottom of the script; the unused duplicate that used to be
# created here has been removed.
def train_model(model, criterion, optimizer, dataload, num_epochs=20):
    """Train ``model`` for ``num_epochs`` epochs and checkpoint every epoch.

    Each batch is moved to the module-level ``device``, a forward/backward
    pass is run and the optimizer is stepped. After every epoch the model's
    state dict is written to ``weights_<epoch>.pth`` in the current working
    directory.

    Args:
        model: the network to optimise.
        criterion: loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        dataload: iterable of ``(x, y)`` batches that supports ``len()``.
        num_epochs: number of passes over ``dataload``.

    Returns:
        The same ``model`` object, after training.
    """
    # Number of batches per epoch. len(dataload) stays correct with
    # drop_last=True or a custom batch_sampler (where batch_size is None),
    # unlike a manual ceil-division of the dataset size by batch_size.
    # It does not change between epochs, so it is computed once.
    steps_per_epoch = len(dataload)
    # Make sure BatchNorm/Dropout layers are in training mode.
    model.train()
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        epoch_loss = 0  # sum of the per-batch losses of this epoch
        for step, (x, y) in enumerate(dataload, start=1):
            inputs = x.to(device)
            labels = y.to(device)
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            batch_loss = loss.item()
            epoch_loss += batch_loss
            print("%d/%d,train_loss:%0.3f" % (step, steps_per_epoch, batch_loss))
        print("epoch %d loss:%0.3f" % (epoch, epoch_loss))
        torch.save(model.state_dict(), 'weights_%d.pth' % epoch)
    return model
# Train the model
def train():
    """Train a 3-channel-in, 1-channel-out ``Unet`` on ``data/train``.

    Uses binary cross-entropy loss, Adam with default settings, and the
    batch size taken from the parsed command-line ``args``.
    """
    net = Unet(3, 1).to(device)
    samples_per_batch = args.batch_size
    loss_fn = torch.nn.BCELoss()
    adam = optim.Adam(net.parameters())
    train_set = LiverDataset(
        "data/train",
        transform=x_transforms,
        target_transform=y_transforms,
    )
    loader = DataLoader(
        train_set,
        batch_size=samples_per_batch,
        shuffle=True,
        num_workers=4,
    )
    train_model(net, loss_fn, adam, loader)
# Display the model's predictions
def test():
    """Load the checkpoint at ``args.ckp`` and plot predictions for data/val.

    The model runs on the CPU in eval mode; each prediction is squeezed to
    drop singleton dimensions and shown with matplotlib in interactive mode.
    """
    net = Unet(3, 1)
    state = torch.load(args.ckp, map_location='cpu')
    net.load_state_dict(state)
    val_set = LiverDataset(
        "data/val",
        transform=x_transforms,
        target_transform=y_transforms,
    )
    loader = DataLoader(val_set, batch_size=1)
    net.eval()
    # Imported lazily so that training does not require matplotlib.
    import matplotlib.pyplot as plt
    plt.ion()
    with torch.no_grad():
        for image, _mask in loader:
            prediction = net(image)
            prediction_map = prediction.squeeze().numpy()
            plt.imshow(prediction_map)
            # Short pause lets the interactive figure refresh.
            plt.pause(0.01)
        plt.show()
# Command-line entry point: ``python main.py train`` or
# ``python main.py test [--ckp PATH]``.
if __name__ == "__main__":
    parse = argparse.ArgumentParser()
    parse.add_argument("action", type=str, choices=["train", "test"],
                       help="train or test")
    parse.add_argument("--batch_size", type=int, default=1)
    # Defaults to the checkpoint written after the last of the 20 default
    # training epochs, but can now be overridden from the command line.
    parse.add_argument("--ckp", type=str, default="weights_19.pth",
                       help="the path of model weight file")
    # ``args`` stays a module-level name because train() and test() read it.
    args = parse.parse_args()
    # Dispatch on the requested action instead of always running test().
    if args.action == "train":
        train()
    else:
        test()
测试结果: