论文地址：http://iizuka.cs.tsukuba.ac.jp/projects/colorization/en/
PyTorch 源代码：https://github.com/shufanwu/colorNet-pytorch
神经网络模型
模型代码：colornet.py
import torch.nn as nn
import torch.nn.functional as F
import torch
class LowLevelFeatNet(nn.Module):
    """Shared low-level feature extractor: six conv -> batch-norm -> ReLU stages.

    Takes single-channel images, widens them to 512 channels and shrinks the
    spatial size roughly 8x (three stride-2, padding-1 convolutions).  The
    same weights encode both inputs:

    * ``x1`` -- features for the mid-level (local) branch;
    * ``x2`` -- features for the global branch.

    In training mode ``x2`` is ignored and the global features are a copy of
    the ``x1`` features (the training script passes the same tensor twice);
    in eval mode ``x2`` is encoded separately.
    """

    def __init__(self):
        super(LowLevelFeatNet, self).__init__()
        # (in_channels, out_channels, stride) for each stage.  Modules are
        # registered as conv1, bn1, conv2, bn2, ... in that order so saved
        # state_dicts keep loading.
        stages = [(1, 64, 2), (64, 128, 1), (128, 128, 2),
                  (128, 256, 1), (256, 256, 2), (256, 512, 1)]
        for idx, (c_in, c_out, step) in enumerate(stages, start=1):
            setattr(self, 'conv%d' % idx,
                    nn.Conv2d(c_in, c_out, kernel_size=3, stride=step, padding=1))
            setattr(self, 'bn%d' % idx, nn.BatchNorm2d(c_out))

    def _encode(self, img):
        """Run ``img`` through the six conv -> batch-norm -> ReLU stages."""
        for idx in range(1, 7):
            conv = getattr(self, 'conv%d' % idx)
            norm = getattr(self, 'bn%d' % idx)
            img = F.relu(norm(conv(img)))
        return img

    def forward(self, x1, x2):
        local_feats = self._encode(x1)
        if not self.training:
            return local_feats, self._encode(x2)
        # Training: reuse the local features for the global branch.
        return local_feats, local_feats.clone()
class MidLevelFeatNet(nn.Module):
    """Mid-level feature branch: two 3x3 conv -> batch-norm -> ReLU stages.

    Keeps the spatial size (stride 1, padding 1) and reduces 512 input
    channels to 256 output channels.
    """

    def __init__(self):
        super(MidLevelFeatNet, self).__init__()
        self.conv1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(512)
        self.conv2 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(256)

    def forward(self, x):
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in stages:
            x = F.relu(norm(conv(x)))
        return x
class GlobalFeatNet(nn.Module):
    """Global feature branch: four conv stages followed by three FC stages.

    Two stride-2 convolutions shrink the spatial size 4x.  ``fc1`` expects
    exactly ``512 * 7 * 7 = 25088`` flattened features per sample, i.e. a
    28x28 input map (what a 224x224 image yields after ``LowLevelFeatNet``).

    Returns:
        ``(output_512, output_256)``: a ``(batch, 512)`` vector (used by the
        classification head) and a ``(batch, 256)`` vector (fused into the
        colorization decoder).
    """

    def __init__(self):
        super(GlobalFeatNet, self).__init__()
        self.conv1 = nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(512)
        self.conv2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(512)
        self.conv3 = nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(512)
        self.conv4 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(512)
        self.fc1 = nn.Linear(25088, 1024)
        self.bn5 = nn.BatchNorm1d(1024)
        self.fc2 = nn.Linear(1024, 512)
        self.bn6 = nn.BatchNorm1d(512)
        self.fc3 = nn.Linear(512, 256)
        self.bn7 = nn.BatchNorm1d(256)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        # Flatten per sample, keeping the batch dimension explicit.
        # ``view(-1, 25088)`` would silently regroup features across samples
        # whenever the map is not 7x7 but the total size divides evenly.
        # With this form, fc1 raises on a wrong input size.
        x = x.view(x.size(0), -1)
        x = F.relu(self.bn5(self.fc1(x)))
        output_512 = F.relu(self.bn6(self.fc2(x)))
        output_256 = F.relu(self.bn7(self.fc3(output_512)))
        return output_512, output_256
class ClassificationNet(nn.Module):
    """Classification head: 512 -> 256 -> 205 with batch norm.

    The 205 outputs are presumably the Places205 scene categories (the
    validation script reads ``../places205``) -- confirm against the dataset.

    Returns:
        Log-probabilities of shape ``(batch, 205)``, normalized over dim 1.
    """

    def __init__(self):
        super(ClassificationNet, self).__init__()
        self.fc1 = nn.Linear(512, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, 205)
        self.bn2 = nn.BatchNorm1d(205)

    def forward(self, x):
        x = F.relu(self.bn1(self.fc1(x)))
        # Normalize over the class dimension explicitly; calling log_softmax
        # without ``dim`` is deprecated and warns.  Feeding these
        # log-probabilities to F.cross_entropy is harmless: log_softmax is
        # idempotent, so the extra log_softmax there changes nothing.
        return F.log_softmax(self.bn2(self.fc2(x)), dim=1)
class ColorizationNet(nn.Module):
    """Fusion layer plus colorization decoder.

    ``forward(mid_input, global_input)``:

    * ``mid_input`` -- ``(batch, 256, H, W)`` mid-level feature map;
    * ``global_input`` -- ``(batch, 256)`` global feature vector.

    The global vector is copied to every pixel and concatenated with the
    local features.  Each pixel's 512-d vector goes through ``fc1``/``bn1``.
    The result is then decoded to 2 channels and upsampled 8x in total
    (three nearest-neighbour 2x upsamplings), giving ``(batch, 2, 8H, 8W)``.
    """

    def __init__(self):
        super(ColorizationNet, self).__init__()
        self.fc1 = nn.Linear(512, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.conv1 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv2 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(32)
        self.conv5 = nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1)
        # Non-deprecated equivalent of nn.UpsamplingNearest2d.  It has no
        # parameters, so saved state_dicts are unaffected.
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, mid_input, global_input):
        height = mid_input.size(2)
        width = mid_input.size(3)
        # Broadcast the global vector to every spatial position.
        global_map = global_input.unsqueeze(2).unsqueeze(2).expand_as(mid_input)
        fusion_layer = torch.cat((mid_input, global_map), 1)
        # Apply fc1/bn1 per pixel: move channels last, flatten positions into
        # rows, then restore (batch, 256, H, W).
        # NOTE(review): no activation follows bn1 here; confirm this is
        # intended before changing it (existing checkpoints depend on it).
        fusion_layer = fusion_layer.permute(2, 3, 0, 1).contiguous()
        fusion_layer = fusion_layer.view(-1, 512)
        fusion_layer = self.bn1(self.fc1(fusion_layer))
        fusion_layer = fusion_layer.view(height, width, -1, 256)
        x = fusion_layer.permute(2, 3, 0, 1).contiguous()

        x = F.relu(self.bn2(self.conv1(x)))
        x = self.upsample(x)
        x = F.relu(self.bn3(self.conv2(x)))
        x = F.relu(self.bn4(self.conv3(x)))
        x = self.upsample(x)
        # NOTE(review): the sigmoid sits on the hidden conv4 stage, while
        # conv5's output is unbounded.  The paper puts the sigmoid on the
        # output layer.  Left unchanged because trained checkpoints depend
        # on it.
        x = torch.sigmoid(self.bn5(self.conv4(x)))
        x = self.upsample(self.conv5(x))
        return x
class ColorNet(nn.Module):
    """Complete colorization model that wires the sub-networks together.

    ``forward(x1, x2)``:

    * ``x1`` -- single-channel input for the local (mid-level) path;
    * ``x2`` -- single-channel input for the global path.  It is only
      encoded in eval mode; in training mode ``LowLevelFeatNet`` reuses the
      ``x1`` features.

    Returns ``(class_output, output)``: class log-probabilities from the
    classification head and the 2-channel colour prediction.
    """

    def __init__(self):
        super(ColorNet, self).__init__()
        self.low_lv_feat_net = LowLevelFeatNet()
        self.mid_lv_feat_net = MidLevelFeatNet()
        self.global_feat_net = GlobalFeatNet()
        self.class_net = ClassificationNet()
        self.upsample_col_net = ColorizationNet()

    def forward(self, x1, x2):
        local_feats, global_feats = self.low_lv_feat_net(x1, x2)
        mid_feats = self.mid_lv_feat_net(local_feats)
        class_input, fusion_vec = self.global_feat_net(global_feats)
        class_output = self.class_net(class_input)
        colour = self.upsample_col_net(mid_feats, fusion_vec)
        return class_output, colour
训练代码：train.py
import os
import traceback
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import transforms
import numpy as np
from myimgfolder import TrainImageFolder
from colornet import ColorNet
# Training-time augmentation: resize so the shorter side is 256 px, take a
# random 224x224 crop, then randomly mirror horizontally.  There is
# deliberately no ToTensor() here; presumably TrainImageFolder converts the
# PIL image itself (TODO confirm in myimgfolder.py).
original_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
    ]
)
def train(epoch):
    """Train ``color_model`` for one epoch over ``train_loader``.

    Uses the module-level ``color_model``, ``train_loader``, ``optimizer`` and
    ``have_cuda`` created in the ``__main__`` block.  Each batch yields
    ``(data, classes)``, where ``data[0]`` is the single-channel input image
    and ``data[1]`` is the ab-channel target.

    The loss is the mean squared error between the predicted and target ab
    channels, plus 1/300 of the classification cross-entropy.
    Side effects:

    * each batch's loss is appended to ``./message.txt``;
    * parameters are saved to ``colornet_params.pkl`` every 500 batches and
      always at the end of the epoch;
    * any exception is printed to stderr and written to ``log.txt``, and it
      ends the epoch instead of propagating.

    Args:
        epoch: 1-based epoch number, used only in progress messages.
    """
    color_model.train()
    try:
        with open('./message.txt', 'a') as messagefile:
            for batch_idx, (data, classes) in enumerate(train_loader):
                original_img = data[0].unsqueeze(1).float()
                img_ab = data[1].float()
                if have_cuda:
                    original_img = original_img.cuda()
                    img_ab = img_ab.cuda()
                    classes = classes.cuda()

                # In training mode the same tensor feeds both the local and
                # the global branch.
                class_output, output = color_model(original_img, original_img)
                # Mean squared error over every element of the prediction.
                ems_loss = torch.pow(img_ab - output, 2).sum() / output.numel()
                # Float divisor: ``1/300`` would be 0 under Python 2.
                cross_entropy_loss = F.cross_entropy(class_output, classes) / 300.0
                loss = ems_loss + cross_entropy_loss
                loss_value = loss.item()
                messagefile.write('loss: %.9f\n' % loss_value)

                optimizer.zero_grad()
                # One backward pass on the summed loss gives the same
                # gradients as two passes with retain_graph, without keeping
                # the graph alive.
                loss.backward()
                optimizer.step()

                # ``data`` is the (gray, ab) pair, so count samples from the
                # image tensor rather than ``len(data)``.
                seen = batch_idx * len(original_img)
                total = len(train_loader.dataset)
                percent = 100. * batch_idx / len(train_loader)
                print('Train Epoch: {}[{}/{}({:.0f}%)]\tLoss: {:.9f}\n'.format(
                    epoch, seen, total, percent, loss_value))
                if batch_idx % 500 == 0:
                    message = 'Train Epoch:%d\tPercent:[%d/%d (%.0f%%)]\tLoss:%.9f\n' % (
                        epoch, seen, total, percent, loss_value)
                    messagefile.write(message)
                    torch.save(color_model.state_dict(), 'colornet_params.pkl')
                # Flush every batch so logged progress survives a crash.
                messagefile.flush()
    except Exception:
        traceback.print_exc()
        with open('log.txt', 'w') as logfile:
            logfile.write(traceback.format_exc())
    finally:
        torch.save(color_model.state_dict(), 'colornet_params.pkl')
if __name__ == '__main__':
    have_cuda = torch.cuda.is_available()
    epochs = 3
    data_dir = "../images256/"
    train_set = TrainImageFolder(data_dir, original_transform)
    train_set_size = len(train_set)
    train_set_classes = train_set.classes
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=32, shuffle=True, num_workers=4)
    color_model = ColorNet()
    if os.path.exists('./colornet_params.pkl'):
        # Load onto the CPU first so a checkpoint saved from a GPU run can be
        # resumed on a CPU-only machine; the model moves to the GPU below.
        color_model.load_state_dict(
            torch.load('colornet_params.pkl', map_location='cpu'))
    if have_cuda:
        color_model.cuda()
    # Build the optimizer after .cuda() so it references the moved parameters.
    optimizer = optim.Adadelta(color_model.parameters())
    for epoch in range(1, epochs + 1):
        train(epoch)
测试代码：val.py
import os
import torch
from torch.autograd import Variable
from torchvision.utils import make_grid, save_image
from skimage.color import lab2rgb
from skimage import io
from colornet import ColorNet
from myimgfolder import ValImageFolder
import numpy as np
import matplotlib.pyplot as plt
def val():
    """Colorize every image from ``val_loader`` and save the results.

    Uses the module-level ``color_model``, ``val_loader`` and ``have_cuda``
    created in the ``__main__`` block.  Each batch provides two
    single-channel images:

    * ``data[0]`` -- input for the local path;
    * ``data[1]`` -- input for the global path (presumably a rescaled copy;
      see ValImageFolder).

    For the i-th image overall, the grayscale input is written to
    ``./gray/<i>.jpg`` and the colorized result to ``./colorimg/<i>.jpg``.
    """
    color_model.eval()
    # plt.imsave does not create missing directories.
    for out_dir in ('./gray', './colorimg'):
        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)
    i = 0
    # no_grad replaces the removed ``Variable(..., volatile=True)``, so no
    # autograd graph is recorded during inference.
    with torch.no_grad():
        for data, _ in val_loader:
            original_img = data[0].unsqueeze(1).float()
            scale_img = data[1].unsqueeze(1).float()
            gray_batch = original_img  # CPU copy kept for saving
            height = original_img.size(2)
            width = original_img.size(3)
            if have_cuda:
                original_img, scale_img = original_img.cuda(), scale_img.cuda()
            _, output = color_model(original_img, scale_img)
            # The decoder output can exceed the input size when it is not a
            # multiple of 8, so crop before stacking L with the predicted ab.
            color_img = torch.cat((original_img, output[:, :, 0:height, 0:width]), 1)
            color_img = color_img.cpu().numpy().transpose((0, 2, 3, 1))
            # Name each image by its own index so a batch_size > 1 does not
            # overwrite gray images or misalign the gray and colour files.
            for gray, lab in zip(gray_batch, color_img):
                gray_name = './gray/' + str(i) + '.jpg'
                pic = gray.squeeze().numpy().astype(np.float64)
                plt.imsave(gray_name, pic, cmap='gray')
                # Rescale to Lab units (L in [0, 100], ab centred on 0).
                # This assumes the channels were normalized to [0, 1];
                # TODO confirm against myimgfolder.
                lab[:, :, 0:1] = lab[:, :, 0:1] * 100
                lab[:, :, 1:3] = lab[:, :, 1:3] * 255 - 128
                rgb = lab2rgb(lab.astype(np.float64))
                color_name = './colorimg/' + str(i) + '.jpg'
                plt.imsave(color_name, rgb)
                i += 1
    # Saving color_img directly with make_grid()/save_image() gives wrong
    # colours: those helpers treat the three channels as RGB in [0, 1],
    # whereas color_img holds Lab values.  Convert with lab2rgb() first.
if __name__ == '__main__':
    data_dir = "../places205"
    have_cuda = torch.cuda.is_available()
    val_set = ValImageFolder(data_dir)
    val_set_size = len(val_set)
    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=1, shuffle=False, num_workers=1)
    color_model = ColorNet()
    # Load onto the CPU first so a checkpoint saved from a GPU run also works
    # on a CPU-only machine; the model moves to the GPU below if available.
    color_model.load_state_dict(
        torch.load('colornet_params.pkl', map_location='cpu'))
    if have_cuda:
        color_model.cuda()
    val()