# ============================== Train.py ==============================
import numpy as np
np.set_printoptions(threshold=np.inf)
# threshold: total number of array elements above which numpy summarizes with "..."; np.inf makes arrays print in full (useful when dumping masks for debugging)
import os
import time
import torch.nn as nn
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import torch.optim as optim
import torchvision.models as models
import random
from tools.my_dataset import MyDataset
from tools.unet import UNet
from tools.set_seed import set_seed
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]  # directory that contains this script
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# A torchvision preprocessing pipeline (Resize(256) / RandomCrop / ToTensor / Normalize) for the
# train and valid sets was sketched here but is disabled: MyDataset is built without a transform.
set_seed()  # seed random / numpy / torch so runs are reproducible
def compute_dice(y_pred, y_true):  # Dice coefficient
    """Compute the Dice coefficient between a predicted mask and a ground-truth mask.

    Both inputs are rounded to the nearest integer (so probabilities/booleans become 0/1)
    before comparison.

    :param y_pred: array-like (numpy array, CPU tensor, nested lists) with values in [0, 1];
        4-d (N, C, H, W) in this project, but any shape matching y_true works
    :param y_true: array-like with the same shape as y_pred, values in [0, 1]
    :return: float, 2 * |pred AND true| / (|pred| + |true|); 1.0 when both masks are empty
    """
    # Go through float64 so bool arrays/tensors round cleanly, then to int 0/1.
    y_pred = np.round(np.asarray(y_pred, dtype=np.float64)).astype(int)
    y_true = np.round(np.asarray(y_true, dtype=np.float64)).astype(int)
    denominator = np.sum(y_pred) + np.sum(y_true)
    if denominator == 0:
        # Both masks are empty, so they agree perfectly. Without this guard 0/0 gives nan
        # (plus a RuntimeWarning), which would poison the running dice means.
        return 1.0
    return float(np.sum(y_pred[y_true == 1]) * 2.0 / denominator)
if __name__ == "__main__":
    # ============================ step 0/5 hyper-parameters ============================
    LR = 0.01  # initial learning rate
    BATCH_SIZE = 8  # batch size
    max_epoch = 1  # number of training epochs
    start_epoch = 0  # index of the first epoch
    lr_step = 150  # StepLR step_size: the LR is multiplied by gamma every lr_step epochs
    val_interval = 3  # run validation every val_interval epochs
    checkpoint_interval = 20  # save a checkpoint every checkpoint_interval epochs
    vis_num = 10  # number of validation images visualized after training
    mask_thres = 0.5  # probability threshold for a foreground pixel

    # ============================ step 1/5 data ============================
    train_dir = os.path.join(BASE_DIR, "..", "data", "blood", "train")
    valid_dir = os.path.join(BASE_DIR, "..", "data", "blood", "valid")
    # No augmentation is applied: MyDataset is built without a transform.
    train_set = MyDataset(train_dir)
    valid_set = MyDataset(valid_dir)
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    valid_loader = DataLoader(valid_set, batch_size=1, shuffle=True, drop_last=False)

    # ============================ step 2/5 model ============================
    # 3-channel input, 1-channel output; init_features is the channel count of the first
    # feature map (64 in the standard U-Net).
    net = UNet(in_channels=3, out_channels=1, init_features=32)
    net.to(device)

    # ============================ step 3/5 loss ============================
    loss_fn = nn.MSELoss()  # mean squared error between the sigmoid output and the 0/1 mask

    # ============================ step 4/5 optimizer ============================
    optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_step, gamma=0.1)

    # ============================ step 5/5 training ============================
    train_curve = list()  # training loss, one entry per iteration
    valid_curve = list()  # mean validation loss, one entry per validation run
    train_dice_curve = list()  # training dice, one entry per iteration
    valid_dice_curve = list()  # mean validation dice, one entry per validation run
    for epoch in range(start_epoch, max_epoch):
        train_loss_total = 0.
        net.train()
        for batch_idx, (inputs, labels) in enumerate(train_loader):
            # .to(device) is a no-op when device is the CPU, so no cuda check is needed.
            inputs, labels = inputs.to(device), labels.to(device)

            # forward
            outputs = net(inputs)
            loss = loss_fn(outputs, labels)

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # statistics
            loss_value = loss.item()
            train_dice = compute_dice(outputs.ge(mask_thres).cpu().numpy(), labels.cpu().numpy())
            train_curve.append(loss_value)
            train_dice_curve.append(train_dice)
            train_loss_total += loss_value
            # get_last_lr() reports the LR actually in use; get_lr() is the scheduler's internal
            # update formula, which shows an already-decayed value on step boundaries and warns
            # when called outside scheduler.step().
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] running_loss: {:.4f}, mean_loss: {:.4f} "
                  "running_dice: {:.4f} lr:{}".format(epoch, max_epoch, batch_idx + 1, len(train_loader), loss_value,
                                                      train_loss_total / (batch_idx + 1), train_dice,
                                                      scheduler.get_last_lr()))
        scheduler.step()

        # save a checkpoint
        if (epoch + 1) % checkpoint_interval == 0:
            checkpoint = {"model_state_dict": net.state_dict(),
                          "optimizer_state_dict": optimizer.state_dict(),
                          "epoch": epoch}
            path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
            torch.save(checkpoint, path_checkpoint)

        # validate the model
        if (epoch + 1) % val_interval == 0:
            net.eval()
            valid_loss_total = 0.
            valid_dice_total = 0.
            with torch.no_grad():
                for inputs, labels in valid_loader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    outputs = net(inputs)
                    valid_loss_total += loss_fn(outputs, labels).item()
                    valid_dice_total += compute_dice(outputs.ge(mask_thres).cpu().numpy(),
                                                     labels.cpu().numpy())
            valid_loss_mean = valid_loss_total / len(valid_loader)
            valid_dice_mean = valid_dice_total / len(valid_loader)
            valid_curve.append(valid_loss_mean)
            valid_dice_curve.append(valid_dice_mean)
            print("Valid:\t Epoch[{:0>3}/{:0>3}] mean_loss: {:.4f} dice_mean: {:.4f}".format(
                epoch, max_epoch, valid_loss_mean, valid_dice_mean))

    # visualize predictions on the first vis_num validation images
    # eval(): use BatchNorm running statistics even if validation never ran (net may still be in
    # train mode, where a batch of one image would be normalized by its own statistics).
    net.eval()
    with torch.no_grad():  # no autograd graph: less memory, faster
        for idx, (inputs, _) in enumerate(valid_loader):
            if idx >= vis_num:
                break
            inputs = inputs.to(device)
            outputs = net(inputs)
            mask_pred = outputs.ge(mask_thres).cpu().numpy().astype("uint8")
            # inputs hold raw 0-255 RGB values, so a plain cast recovers a displayable image
            img_hwc = inputs.cpu().numpy()[0].transpose((1, 2, 0)).astype("uint8")
            plt.subplot(121).imshow(img_hwc)
            plt.subplot(122).imshow(mask_pred.squeeze() * 255, cmap="gray")
            plt.show()
            plt.pause(0.5)
            plt.close()

    def plot_curves(train_values, valid_values, ylabel):
        """Plot per-iteration training values together with per-validation-run values."""
        train_x = range(len(train_values))
        # Validation values are recorded once every val_interval epochs, so convert their
        # indices to iteration counts to share the x axis with the training curve.
        valid_x = np.arange(1, len(valid_values) + 1) * len(train_loader) * val_interval
        plt.plot(train_x, train_values, label='Train')
        plt.plot(valid_x, valid_values, label='Valid')
        plt.legend(loc='upper right')
        plt.ylabel(ylabel)
        plt.xlabel('Iteration')
        plt.title("Plot in {} epochs".format(max_epoch))
        plt.show()

    plot_curves(train_curve, valid_curve, 'loss value')  # loss curve
    plot_curves(train_dice_curve, valid_dice_curve, 'dice value')  # dice curve
    torch.cuda.empty_cache()
# ============================== Inference.py ==============================
import os
import time
import random
import torch.nn as nn
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import torch.optim as optim
import torchvision.models as models
from tools.set_seed import set_seed
from tools.my_dataset import MyDataset
from tools.unet import UNet
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]  # directory that contains this script
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
set_seed()  # seed random / numpy / torch (default seed 1) so the image shuffle is reproducible
def compute_dice(y_pred, y_true):
    """Compute the Dice coefficient between a predicted mask and a ground-truth mask.

    Both inputs are rounded to the nearest integer (so probabilities/booleans become 0/1)
    before comparison.

    :param y_pred: array-like (numpy array, CPU tensor, nested lists) with values in [0, 1];
        4-d (N, C, H, W) in this project, but any shape matching y_true works
    :param y_true: array-like with the same shape as y_pred, values in [0, 1]
    :return: float, 2 * |pred AND true| / (|pred| + |true|); 1.0 when both masks are empty
    """
    # Go through float64 so bool arrays/tensors round cleanly, then to int 0/1.
    y_pred = np.round(np.asarray(y_pred, dtype=np.float64)).astype(int)
    y_true = np.round(np.asarray(y_true, dtype=np.float64)).astype(int)
    denominator = np.sum(y_pred) + np.sum(y_true)
    if denominator == 0:
        # Both masks are empty, so they agree perfectly; avoids 0/0 -> nan (and a RuntimeWarning).
        return 1.0
    return float(np.sum(y_pred[y_true == 1]) * 2.0 / denominator)
def get_img_name(img_dir, format="jpg"):
    """List the files in ``img_dir`` whose names end with ``format``, skipping "*matte.png" masks.

    :param img_dir: str, directory to scan (not recursive)
    :param format: str, required file-name suffix, e.g. "png"
    :return: list of bare file names (not full paths), in os.listdir order
    :raises ValueError: if no file matches
    """
    img_names = [
        name for name in os.listdir(img_dir)
        if name.endswith(format) and not name.endswith("matte.png")
    ]
    if not img_names:
        raise ValueError("{}下找不到{}格式数据".format(img_dir, format))
    return img_names
def get_model(m_path):
    """Build a UNet(3 -> 1 channels, 32 initial features) and load weights from a checkpoint.

    :param m_path: path to a checkpoint dict holding the weights under "model_state_dict"
        (the format Train.py saves); it is loaded onto the CPU
    :return: the UNet with the checkpoint's weights loaded
    """
    from collections import OrderedDict

    unet = UNet(in_channels=3, out_channels=1, init_features=32)
    checkpoint = torch.load(m_path, map_location="cpu")
    # Strip a leading "module." from parameter names (e.g. weights saved from a wrapped model)
    # so they match the bare UNet's keys.
    prefix = 'module.'
    state_dict = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in checkpoint['model_state_dict'].items()
    )
    unet.load_state_dict(state_dict)
    return unet
if __name__ == "__main__":
    img_dir = os.path.join(BASE_DIR, "..", "..", "data", "PortraitDataset", "valid")
    model_path = "checkpoint_399_epoch.pkl"
    time_total = 0
    num_infer = 5  # number of images to run inference on
    mask_thres = .5  # probability threshold for a foreground pixel

    # 1. data
    img_names = get_img_name(img_dir, format="png")
    random.shuffle(img_names)
    num_img = len(img_names)

    # 2. model
    unet = get_model(model_path)
    unet.to(device)
    unet.eval()

    # 3. inference on the first num_infer shuffled images (the previous `idx > num_infer`
    # check ran one extra image). no_grad: no autograd graph is needed for inference.
    selected = img_names[:num_infer]
    with torch.no_grad():
        for idx, img_name in enumerate(selected):
            path_img = os.path.join(img_dir, img_name)

            # step 1/4 : path --> img_chw
            # Bilinear resize to match MyDataset's training preprocessing; `with` closes the file.
            with Image.open(path_img) as img_file:
                img_rgb = img_file.convert('RGB').resize((224, 224), Image.BILINEAR)
            img_hwc = np.array(img_rgb)
            img_chw = img_hwc.transpose((2, 0, 1))

            # step 2/4 : img --> (1, 3, 224, 224) float tensor of raw 0-255 values (as in training)
            img_tensor = torch.from_numpy(np.ascontiguousarray(img_chw)).float().unsqueeze(0).to(device)

            # step 3/4 : tensor --> mask probabilities
            # CUDA kernels run asynchronously; synchronize so the timer measures the forward pass
            # itself rather than just the kernel launch.
            if device.type == "cuda":
                torch.cuda.synchronize()
            time_tic = time.time()
            outputs = unet(img_tensor)
            if device.type == "cuda":
                torch.cuda.synchronize()
            time_toc = time.time()

            # step 4/4 : visualization
            mask_pred = outputs.ge(mask_thres).cpu().numpy().astype("uint8")
            plt.subplot(121).imshow(img_hwc)
            plt.subplot(122).imshow(mask_pred.squeeze() * 255, cmap="gray")
            plt.show()
            plt.close()

            time_s = time_toc - time_tic
            time_total += time_s
            print('{:d}/{:d}: {} {:.3f}s '.format(idx + 1, num_img, img_name, time_s))

    if selected:
        print('mean inference time: {:.3f}s over {:d} images'.format(time_total / len(selected), len(selected)))
# ============================== my_dataset.py ==============================
import numpy as np
np.set_printoptions(threshold=np.inf)
# threshold: total number of array elements above which numpy summarizes with "..."; np.inf makes arrays print in full
import numpy as np
import torch
import os
import random
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
# Seed Python's `random` at import time; it drives the shuffle of the mask path list in
# MyDataset._get_img_path. NOTE(review): any later random.seed call (e.g. set_seed()) overrides this.
random.seed(1)
class MyDataset(Dataset):
    """Segmentation dataset of (image, binary mask) pairs read from a single directory.

    Masks are the files ending in "_mask.gif". The matching image has the same stem with a ".tif"
    extension (e.g. "x_mask.gif" -> "x.tif"). Images are resized to in_size x in_size with bilinear
    interpolation; masks use nearest-neighbour so they stay binary.

    :param data_dir: directory containing the images and masks
    :param transform: optional callable applied separately to the image array (3, H, W) and to
        the mask array (1, H, W). It must return a numpy array.
        NOTE(review): a random augmentation would be drawn independently for the image and the
        mask, misaligning them — use a deterministic transform or one that shares its parameters.
    :param in_size: side length in pixels of the square resize
    """

    def __init__(self, data_dir, transform=None, in_size = 224):
        super(MyDataset, self).__init__()
        self.data_dir = data_dir
        self.transform = transform
        self.label_path_list = list()
        self.in_size = in_size
        # collect the mask paths
        self._get_img_path()

    def __getitem__(self, index):
        """Return (image, mask) float tensors shaped (3, in_size, in_size) and (1, in_size, in_size).

        Image values are raw 0-255 RGB intensities (no normalization); mask values are 0 or 1.
        """
        path_label = self.label_path_list[index]
        path_img = path_label[:-len("_mask.gif")] + ".tif"

        # Context managers close the files: Pillow may keep the handle open after loading
        # (e.g. for multi-frame formats such as GIF), leaking descriptors over many samples.
        with Image.open(path_img) as img_file:
            img_pil = img_file.convert('RGB').resize((self.in_size, self.in_size), Image.BILINEAR)
        # Networks expect channel-first (C, H, W); numpy images are (H, W, C).
        img_chw = np.array(img_pil).transpose((2, 0, 1))

        # label: single-channel grayscale
        with Image.open(path_label) as label_file:
            label_pil = label_file.convert('L').resize((self.in_size, self.in_size), Image.NEAREST)
        label_hw = np.array(label_pil)
        label_hw[label_hw != 0] = 1  # any non-zero pixel is foreground -> binary labels
        label_chw = label_hw[np.newaxis, :, :]

        if self.transform is not None:
            # The inputs are already numpy arrays (the old `.numpy()` call raised AttributeError).
            img_chw = self.transform(img_chw)
            label_chw = self.transform(label_chw)

        # ascontiguousarray: torch.from_numpy rejects negative strides (e.g. from a flip transform)
        img_chw_tensor = torch.from_numpy(np.ascontiguousarray(img_chw)).float()
        label_chw_tensor = torch.from_numpy(np.ascontiguousarray(label_chw)).float()
        return img_chw_tensor, label_chw_tensor

    def __len__(self):
        return len(self.label_path_list)

    def _get_img_path(self):
        """Fill self.label_path_list with the shuffled full paths of the "*_mask.gif" files.

        :raises Exception: if data_dir contains no mask file
        """
        path_list = [os.path.join(self.data_dir, name) for name in os.listdir(self.data_dir)
                     if name.endswith("_mask.gif")]  # files ending in _mask.gif are the masks
        random.shuffle(path_list)
        if len(path_list) == 0:
            # self.data_dir: the bare name `data_dir` is undefined here and raised NameError
            raise Exception("\ndata_dir:{} is a empty dir! Please checkout your path to images!".format(self.data_dir))
        self.label_path_list = path_list
# ============================== set_seed.py ==============================
import random
import torch
import numpy as np
def set_seed(seed=1):
    """Seed Python's `random`, NumPy, PyTorch and PyTorch's CUDA generator with one value.

    :param seed: int seed shared by every generator (default 1)
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seeder in seeders:
        seeder(seed)
# ============================== unet.py ==============================
from collections import OrderedDict
import torch
import torch.nn as nn
class UNet(nn.Module):
    """U-Net for per-pixel segmentation with a sigmoid output.

    The encoder has four (conv block -> 2x max-pool) stages followed by a bottleneck block. The
    decoder has four (2x transposed conv -> concat with the matching encoder features -> conv
    block) stages and ends with a 1x1 conv. Channel widths are init_features * (1, 2, 4, 8, 16).

    Input height/width need not be divisible by 16: when pooling floors an odd size, the upsampled
    decoder map is zero-padded to the encoder map's size before concatenation, so the output has
    the same spatial size as the input.
    """

    def __init__(self, in_channels=3, out_channels=1, init_features=32):
        super(UNet, self).__init__()
        features = init_features
        # NOTE: attribute names and registration order define state_dict keys and the
        # parameters() order used by saved checkpoints; keep them unchanged.
        self.encoder1 = UNet._block(in_channels, features, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = UNet._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")
        self.upconv4 = nn.ConvTranspose2d(
            features * 16, features * 8, kernel_size=2, stride=2
        )
        self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(
            features * 8, features * 4, kernel_size=2, stride=2
        )
        self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=2, stride=2
        )
        self.decoder1 = UNet._block(features * 2, features, name="dec1")
        self.conv = nn.Conv2d(
            in_channels=features, out_channels=out_channels, kernel_size=1
        )

    def forward(self, x):
        """
        :param x: tensor of shape (N, in_channels, H, W)
        :return: tensor of shape (N, out_channels, H, W) with values in (0, 1)
        """
        # encoder
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))
        # bottleneck
        bottleneck = self.bottleneck(self.pool4(enc4))
        # decoder: upsample, join with the skip connection, refine
        dec4 = self.decoder4(UNet._cat_skip(self.upconv4(bottleneck), enc4))
        dec3 = self.decoder3(UNet._cat_skip(self.upconv3(dec4), enc3))
        dec2 = self.decoder2(UNet._cat_skip(self.upconv2(dec3), enc2))
        dec1 = self.decoder1(UNet._cat_skip(self.upconv1(dec2), enc1))
        return torch.sigmoid(self.conv(dec1))

    @staticmethod
    def _cat_skip(upsampled, skip):
        """Concatenate (upsampled, skip) along channels, zero-padding upsampled to skip's H/W.

        Max-pooling floors odd sizes, so the 2x upsampled map can be one pixel smaller than the
        skip map. For sizes divisible by 16 the padding is zero and this is a plain torch.cat.
        """
        diff_y = skip.size(2) - upsampled.size(2)
        diff_x = skip.size(3) - upsampled.size(3)
        if diff_y or diff_x:
            upsampled = nn.functional.pad(
                upsampled,
                [diff_x // 2, diff_x - diff_x // 2, diff_y // 2, diff_y - diff_y // 2],
            )
        return torch.cat((upsampled, skip), dim=1)

    @staticmethod
    def _block(in_channels, features, name):
        """Two (3x3 conv -> BatchNorm -> ReLU) layers; sub-module names are prefixed with `name`."""
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        name + "conv1",
                        nn.Conv2d(
                            in_channels=in_channels,  # depth of each kernel
                            out_channels=features,  # number of kernels = output feature-map depth
                            kernel_size=3,
                            padding=1,
                            bias=False,  # the following BatchNorm supplies the bias
                        ),
                    ),
                    (name + "norm1", nn.BatchNorm2d(num_features=features)),
                    (name + "relu1", nn.ReLU(inplace=True)),
                    (
                        name + "conv2",
                        nn.Conv2d(
                            in_channels=features,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm2", nn.BatchNorm2d(num_features=features)),
                    (name + "relu2", nn.ReLU(inplace=True)),
                ]
            )
        )