百度网盘提取码:f1j7
Pytorch-gpu==1.10.1
Python==3.8
import os
import torch
from torch.utils import data
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.utils import draw_segmentation_masks
class MaskDataset(data.Dataset):
    """Dataset of (image, binary mask) pairs read from two parallel path lists.

    Args:
        image_paths: Sequence of image file paths.
        mask_paths: Sequence of mask file paths; ``mask_paths[i]`` is the
            mask that belongs to ``image_paths[i]``.
        transform: Callable applied to the RGB image and to the mask image.
            It is expected to return a tensor (the mask result is compared
            against 0 and cast to ``long``).
    """

    def __init__(self, image_paths, mask_paths, transform):
        super().__init__()
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(img_tensor, label_tensor)`` for the sample at ``index``.

        The image is converted to RGB before the transform is applied. The
        mask is passed through the same transform, every value greater than
        zero becomes class 1 (everything else class 0), and the leading
        axis is removed when it has size 1.
        """
        # Context managers release the file handles deterministically
        # instead of leaving them to the garbage collector.
        with Image.open(self.image_paths[index]) as pil_img:
            img_tensor = self.transform(pil_img.convert('RGB'))

        with Image.open(self.mask_paths[index]) as pil_label:
            label_tensor = self.transform(pil_label)

        # Binarize without mutating the transform output in place.
        label_tensor = (label_tensor > 0).long()
        # Only drop the leading (channel) axis: a bare squeeze() would also
        # remove a spatial axis that happens to have size 1.
        label_tensor = torch.squeeze(label_tensor, dim=0)
        return img_tensor, label_tensor

    def __len__(self):
        """Return the number of samples, i.e. the number of mask paths."""
        return len(self.mask_paths)
def _collect_image_mask_paths(split_dir):
    """Return parallel, sorted lists of image paths and matte paths in ``split_dir``.

    A file counts as a matte when its name contains ``'matte'`` and the part
    of its name before the first ``'_'`` plus ``'.png'`` is also present in
    the same directory; that ``.png`` file is the matching image.
    """
    # A set makes each membership test O(1) instead of scanning the listing.
    file_names = set(os.listdir(split_dir))
    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    matte_names = sorted(
        name for name in file_names
        if 'matte' in name and name.split('_')[0] + '.png' in file_names
    )
    image_paths = [os.path.join(split_dir, name.split('_')[0] + '.png') for name in matte_names]
    mask_paths = [os.path.join(split_dir, name) for name in matte_names]
    return image_paths, mask_paths


def load_data(dataset_path=r'/Users/leeakita/Desktop/hk', batch_size=8, image_size=(256, 256)):
    """Build the training and testing data loaders.

    Args:
        dataset_path: Root directory that contains the ``training`` and
            ``testing`` sub-directories.
        batch_size: Batch size used by both loaders.
        image_size: ``(height, width)`` every image and mask is resized to.

    Returns:
        A ``(train_dl, test_dl)`` tuple of ``DataLoader`` objects. The
        training loader shuffles its samples; the testing loader does not.

    Raises:
        OSError: If either sub-directory cannot be listed.
    """
    train_image_paths, train_label_paths = _collect_image_mask_paths(os.path.join(dataset_path, 'training'))
    test_image_paths, test_label_paths = _collect_image_mask_paths(os.path.join(dataset_path, 'testing'))

    # The same transform is handed to the dataset for images and masks.
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor()
    ])

    train_ds = MaskDataset(image_paths=train_image_paths, mask_paths=train_label_paths, transform=transform)
    test_ds = MaskDataset(image_paths=test_image_paths, mask_paths=test_label_paths, transform=transform)
    train_dl = data.DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)
    test_dl = data.DataLoader(dataset=test_ds, batch_size=batch_size)
    return train_dl, test_dl
if __name__ == '__main__':
    # Visual smoke test: overlay the ground-truth mask of one training
    # sample on its image and display the result.
    train_loader, _test_loader = load_data()
    batch_images, batch_labels = next(iter(train_loader))

    sample_idx = 5
    # draw_segmentation_masks is given a uint8 image and a boolean mask that
    # carries a leading mask axis.
    image_u8 = (batch_images[sample_idx] * 255).to(torch.uint8)
    mask_bool = batch_labels[sample_idx].unsqueeze(0).to(torch.bool)

    overlay = draw_segmentation_masks(image=image_u8, masks=mask_bool, alpha=0.6, colors=['red'])
    # Channels-first -> channels-last for matplotlib.
    plt.imshow(overlay.permute(1, 2, 0).numpy())
    plt.show()
from torch import nn
import torch
class ConvBlock(nn.Module):
    """Conv2d followed by BatchNorm2d and an in-place ReLU.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution zero-padding.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv_bn_relu = nn.Sequential(*layers)

    def forward(self, x):
        """Apply convolution, batch normalization and ReLU to ``x``."""
        features = self.conv_bn_relu(x)
        return features
class DecodeConvBlock(nn.Module):
    """Transposed convolution with an optional BatchNorm2d + ReLU tail.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Transposed-convolution kernel size.
        stride: Transposed-convolution stride.
        padding: Transposed-convolution padding.
        out_padding: Value passed as ``output_padding``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, out_padding=1):
        super().__init__()
        self.de_conv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=out_padding,
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_act=True):
        """Run the transposed convolution on ``x``.

        When ``is_act`` is true the result is batch-normalized and passed
        through ReLU; otherwise the raw transposed-convolution output is
        returned.
        """
        upsampled = self.de_conv(x)
        if not is_act:
            return upsampled
        return torch.relu(self.bn(upsampled))
class EncodeBlock(nn.Module):
    """Encoder stage built from four ConvBlocks plus a strided shortcut.

    ``conv1`` and ``short_cut`` both use stride 2, so they apply the same
    spatial reduction to the input.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = ConvBlock(in_channels, out_channels, stride=2)
        self.conv2 = ConvBlock(out_channels, out_channels)
        self.conv3 = ConvBlock(out_channels, out_channels)
        self.conv4 = ConvBlock(out_channels, out_channels)
        self.short_cut = ConvBlock(in_channels, out_channels, stride=2)

    def forward(self, x):
        """Encode ``x`` and return the sum of the two convolution stages."""
        # First stage: strided convolution followed by a second convolution.
        main_path = self.conv2(self.conv1(x))
        # Strided projection of the input, added before the second stage.
        residual = self.short_cut(x)
        refined = self.conv4(self.conv3(main_path + residual))
        return main_path + refined
class DecodeBlock(nn.Module):
    """Decoder stage: 1x1 conv, transposed conv, then another 1x1 conv.

    The first 1x1 convolution reduces the channel count to
    ``in_channels // 4``; the last one maps it to ``out_channels``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        bottleneck = in_channels // 4
        self.conv1 = ConvBlock(in_channels, bottleneck, kernel_size=1, padding=0)
        self.de_conv = DecodeConvBlock(bottleneck, bottleneck)
        self.conv3 = ConvBlock(bottleneck, out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        """Apply the three layers to ``x`` in order and return the result."""
        return self.conv3(self.de_conv(self.conv1(x)))
class LinkNet(nn.Module):
    """Encoder-decoder segmentation network with additive skip connections.

    Args:
        in_channels: Number of channels of the input image (default 3).
        num_classes: Number of output channels / classes (default 2).

    ``forward`` returns raw, unnormalized per-pixel class scores. Earlier
    revisions passed the final layer through BatchNorm + ReLU; checkpoints
    trained that way should be retrained. ``state_dict`` keys are unchanged.
    """

    def __init__(self, in_channels=3, num_classes=2):
        super().__init__()
        self.init_conv = ConvBlock(in_channels=in_channels, out_channels=64, stride=2, kernel_size=7, padding=3)
        self.init_maxpool = nn.MaxPool2d(kernel_size=(2, 2))

        self.encode_1 = EncodeBlock(in_channels=64, out_channels=64)
        self.encode_2 = EncodeBlock(in_channels=64, out_channels=128)
        self.encode_3 = EncodeBlock(in_channels=128, out_channels=256)
        self.encode_4 = EncodeBlock(in_channels=256, out_channels=512)

        self.decode_4 = DecodeBlock(in_channels=512, out_channels=256)
        self.decode_3 = DecodeBlock(in_channels=256, out_channels=128)
        self.decode_2 = DecodeBlock(in_channels=128, out_channels=64)
        self.decode_1 = DecodeBlock(in_channels=64, out_channels=64)

        self.deconv_out1 = DecodeConvBlock(in_channels=64, out_channels=32)
        self.conv_out = ConvBlock(in_channels=32, out_channels=32)
        self.deconv_out2 = DecodeConvBlock(in_channels=32, out_channels=num_classes, kernel_size=2, padding=0,
                                           out_padding=0)

    def forward(self, x):
        """Return per-pixel class logits for the image batch ``x``.

        Each decoder output is added to the encoder output of matching
        channel count, so the two must also match spatially. An input such
        as 256x256 satisfies this and yields an output of the same
        height and width.
        """
        x = self.init_conv(x)
        x = self.init_maxpool(x)

        e1 = self.encode_1(x)
        e2 = self.encode_2(e1)
        e3 = self.encode_3(e2)
        e4 = self.encode_4(e3)

        # Skip connections: add the encoder feature map to the decoder output.
        d4 = self.decode_4(e4)
        d3 = self.decode_3(d4 + e3)
        d2 = self.decode_2(d3 + e2)
        d1 = self.decode_1(d2 + e1)

        f1 = self.deconv_out1(d1)
        f2 = self.conv_out(f1)
        # The classification head must emit raw logits: BatchNorm + ReLU here
        # would clamp every class score to be non-negative.
        f3 = self.deconv_out2(f2, is_act=False)
        return f3
import torch
from data_loader import load_data
from model_loader import LinkNet
from torch import nn
from torch import optim
import tqdm
import os
# Environment configuration: use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the data.
train_dl, test_dl = load_data()

# Build the model.
model = LinkNet()
model = model.to(device=device)

# Training configuration.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.7)


def _batch_iou(labels, logits):
    """Return the foreground IoU of one batch as a Python float.

    Non-zero labels and non-zero argmax predictions count as foreground.
    When neither contains any foreground pixel the union is empty; that
    case is scored as 1.0 instead of 0/0 = NaN so that a single empty batch
    cannot poison the running mean.
    """
    target = labels.bool()
    predicted = torch.argmax(input=logits, dim=1).bool()
    union = torch.logical_or(target, predicted).sum().item()
    if union == 0:
        return 1.0
    intersection = torch.logical_and(target, predicted).sum().item()
    return intersection / union


# Start training.
for epoch in range(100):
    # Training mode: BatchNorm uses batch statistics and updates its
    # running statistics.
    model.train()
    train_tqdm = tqdm.tqdm(iterable=train_dl, total=len(train_dl))
    train_tqdm.set_description_str('Train epoch: {:3d}'.format(epoch))
    train_loss_total, train_iou_total, train_batches = 0.0, 0.0, 0
    for train_images, train_labels in train_tqdm:
        train_images, train_labels = train_images.to(device), train_labels.to(device)
        pred = model(train_images)
        loss = loss_fn(pred, train_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            train_iou_total += _batch_iou(train_labels, pred)
        # .item() detaches the value so no autograd graph is kept alive.
        train_loss_total += loss.item()
        train_batches += 1
        train_tqdm.set_postfix({
            'train loss': train_loss_total / train_batches,
            'train iou': train_iou_total / train_batches
        })
    train_tqdm.close()
    lr_scheduler.step()

    # Evaluation mode: BatchNorm uses its running statistics and is not
    # updated by the test data.
    model.eval()
    with torch.no_grad():
        test_tqdm = tqdm.tqdm(iterable=test_dl, total=len(test_dl))
        test_tqdm.set_description_str('Test epoch: {:3d}'.format(epoch))
        test_loss_total, test_iou_total, test_batches = 0.0, 0.0, 0
        for test_images, test_labels in test_tqdm:
            test_images, test_labels = test_images.to(device), test_labels.to(device)
            test_pred = model(test_images)
            # CrossEntropyLoss expects raw logits; applying softmax first
            # would normalize twice and make the value incomparable with
            # the training loss.
            test_loss = loss_fn(test_pred, test_labels)
            test_iou_total += _batch_iou(test_labels, test_pred)
            test_loss_total += test_loss.item()
            test_batches += 1
            test_tqdm.set_postfix({
                'test loss': test_loss_total / test_batches,
                'test iou': test_iou_total / test_batches
            })
        test_tqdm.close()

# Save the model.
os.makedirs('model_data', exist_ok=True)
torch.save(model.state_dict(), os.path.join('model_data', 'model.pth'))
import torch
import os
import matplotlib.pyplot as plt
from torchvision.utils import draw_segmentation_masks
from data_loader import load_data
from model_loader import LinkNet
# Load the data.
train_dl, test_dl = load_data()

# Build the model and restore the trained weights on the CPU.
model = LinkNet()
model_state_dict = torch.load(os.path.join('model_data', 'model.pth'), map_location='cpu')
model.load_state_dict(model_state_dict)
# Inference must run in eval mode: otherwise BatchNorm normalizes with the
# statistics of the current batch and keeps updating its running statistics.
model.eval()

# Run the prediction on the first test batch.
images, labels = next(iter(test_dl))
index = 2
with torch.no_grad():
    pred = model(images)
pred = torch.argmax(input=pred, dim=1)

# Overlay the predicted mask of the selected sample on its image.
result = draw_segmentation_masks(image=torch.as_tensor(data=images[index] * 255, dtype=torch.uint8),
                                 masks=torch.as_tensor(data=pred[index], dtype=torch.bool),
                                 alpha=0.8, colors=['red'])
plt.figure(figsize=(8, 8), dpi=500)
plt.axis('off')
# Channels-first -> channels-last for matplotlib.
plt.imshow(result.permute(1, 2, 0).numpy())
plt.savefig('result.png')
plt.show()