Preface: In my spare time I wrote an autoencoder (AE) network to remove mosaic censoring from images (with the caveat that it only handles the kind of mosaic present in the training samples; here I target pixelated faces). The samples are 8,000 uncensored face images and the 8,000 corresponding mosaicked versions (taken from the CelebA face dataset). Preprocessing: the faces are resized to 256x256 and then mosaicked, and the image filenames in each folder are written to a label.txt file so the data can be read in. The main program is split into four files: the network definition net.py, the data reader dataset.py, the training script train.py, and the test script test.py. In each block of code I have annotated the places where I think beginners are most likely to trip up.
This is a practice network, offered only for beginners to reference and discuss; I hope the annotated code below helps newcomers with those points. (Judging from how the folders are used in train.py, trainA holds the clean faces and trainB the mosaicked ones.)
(1)dataset.py
'Build the sample list from a self-made label.txt of image filenames'
import torch
import os
import numpy as np
import cv2
from torch.utils.data import Dataset,DataLoader
class GetData(Dataset):
def __init__(self,path1,path2):
super(GetData,self).__init__()
self.path1 = path1
self.path2 = path2
self.dataset1 = []
self.dataset2 = []
self.dataset1.extend(open(os.path.join(self.path1,'label.txt')).readlines())
self.dataset2.extend(open(os.path.join(self.path2,'label.txt')).readlines())
    def __getitem__(self, index):  # index is the sample index supplied by the DataLoader, not something you set yourself
        str1 = self.dataset1[index].strip()  # e.g. dataset[0] is the first sample; batching happens in the DataLoader
str2 = self.dataset2[index].strip()
imgpath1 = os.path.join(self.path1,str1)
imgpath2 = os.path.join(self.path2,str2)
im1 = cv2.imread(imgpath1)
im2 = cv2.imread(imgpath2)
        'Do not transpose imgdata here: cv2.imshow() expects HWC, so a transposed (CHW) array would display incorrectly'
imgdata1 = torch.Tensor((im1 / 255. - 0.5))
imgdata2 = torch.Tensor((im2 / 255. - 0.5))
return imgdata1,imgdata2
def __len__(self):
return len(self.dataset1)
'cv2 reads images as BGR in HWC order'
if __name__ == '__main__':
dataset= GetData(r'C:\Users\87419\Desktop\VAE1\data\trainA',r'C:\Users\87419\Desktop\VAE1\data\trainB')
    dataloader = DataLoader(dataset, batch_size=200, shuffle=True)  # DataLoader adds the batch dimension N, so 3-D samples become 4-D batches; when flattening to count elements, multiply by N
    for i, (imgdata, _) in enumerate(dataloader):
        # print(imgdata.numpy().shape)  # NHWC, e.g. (200, 256, 256, 3) with batch_size=200
        'imgdata is a 4-D NHWC batch; indexing imgdata[k] gives the k-th sample of that batch, a 3-D HWC array'
        im1 = imgdata[0].numpy().reshape((256, 256, 3))  # because of the (im1 / 255. - 0.5) normalisation, the displayed colours differ from the original image
cv2.imshow('a',im1)
cv2.waitKey(0)
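For reference, here is a minimal sketch (my addition, not part of dataset.py) of the conventional HWC-to-CHW round trip with permute; 'face.jpg' is just a placeholder filename:
import cv2
import torch

# cv2 loads HWC/BGR uint8; PyTorch convolutions expect CHW float tensors.
im = cv2.imread('face.jpg')                       # placeholder path, HWC BGR uint8
t = torch.from_numpy(im).float() / 255. - 0.5     # HWC float in [-0.5, 0.5]
chw = t.permute(2, 0, 1)                          # CHW; add a batch dim with chw.unsqueeze(0) before feeding Conv2d
back = (chw.permute(1, 2, 0) + 0.5).numpy()       # back to HWC in [0, 1] for display
cv2.imshow('roundtrip', back)                     # float images are shown assuming a [0, 1] range
cv2.waitKey(0)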
(2)net.py
import torch
import torch.nn as nn
'Use fully-connected layers when classification is needed; use convolutions to extract features'
class AutoEncoder(nn.Module):
def __init__(self):
super(AutoEncoder,self).__init__()
self.encoder = nn.Sequential(
nn.Conv2d(in_channels=3,
out_channels=16,
kernel_size=3,
stride=1,
padding=1), # (3,256,256)-->(16,256,256)
nn.PReLU(),
nn.Conv2d(in_channels=16,
out_channels=16,
kernel_size=4,
stride=2,
padding=1), # (16,256,256)-->(16,128,128)
nn.PReLU(),
nn.Conv2d(in_channels=16,
out_channels=32,
kernel_size=3,
stride=1,
padding=1), # (32,128,128)
nn.PReLU(),
nn.Conv2d(in_channels=32,
out_channels=32,
kernel_size=4,
stride=2,
padding=1), # (32,64,64)
nn.PReLU(),
nn.Conv2d(in_channels=32,
out_channels=64,
kernel_size=3,
stride=1,
padding=1), # (64,64,64)
nn.PReLU(),
nn.Conv2d(in_channels=64,
out_channels=64,
kernel_size=4,
stride=2,
padding=1), # (64,32,32)
nn.PReLU(),
nn.Conv2d(in_channels=64,
out_channels=128,
kernel_size=4,
stride=2,
padding=1), # (128,16,16)
nn.PReLU(),
nn.Conv2d(in_channels=128,
out_channels=128,
kernel_size=4,
stride=2,
padding=1), # (128,8,8)
nn.PReLU(),
nn.Conv2d(in_channels=128,
out_channels=256,
kernel_size=3,
stride=1,
padding=1), # (256,8,8)
nn.PReLU(),
nn.Conv2d(in_channels=256,
out_channels=256,
kernel_size=4,
stride=2,
padding=1) # (256,4,4)
)
self.decoder = nn.Sequential(
nn.ConvTranspose2d(256, 256, kernel_size=3, stride=3, padding=2), #(256,4,4)-->(256,8,8)
nn.PReLU(),
nn.ConvTranspose2d(256, 128, kernel_size=3, stride=1, padding=1), # (256,8,8)-->(128,8,8)
nn.PReLU(),
nn.ConvTranspose2d(128, 128,kernel_size=2,stride= 2), # (128,8,8)-->(128,16,16)
nn.PReLU(),
nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2), # (128,16,16)-->(64,32,32)
nn.PReLU(),
nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2), # (64,32,32)-->(64,64,64)
nn.PReLU(),
nn.ConvTranspose2d(64, 32, kernel_size=1, stride=1), # (64,64,64)-->(32,64,64)
nn.PReLU(),
nn.ConvTranspose2d(32, 32, kernel_size=2, stride=2), # (32,64,64)-->(32,128,128)
nn.PReLU(),
nn.ConvTranspose2d(32, 16, kernel_size=1, stride=1), # (32,128,128)-->(16,128,128)
nn.PReLU(),
nn.ConvTranspose2d(16, 16, kernel_size=2, stride=2), # (16,128,128)-->(16,256,256)
nn.PReLU(),
nn.ConvTranspose2d(16, 3, kernel_size=1, stride=1), # (16,256,256)-->(3,256,256)
nn.Tanh()
)
def forward(self, x):
encoded = self.encoder(x)
decoded = self.decoder(encoded)
return encoded, decoded
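A quick shape check I find useful (my addition, not in the original post): run a dummy batch through the network and confirm the shapes annotated in the comments above.
if __name__ == '__main__':
    net = AutoEncoder()
    x = torch.randn(1, 3, 256, 256)   # dummy NCHW input
    encoded, decoded = net(x)
    print(encoded.shape)              # expected: torch.Size([1, 256, 4, 4])
    print(decoded.shape)              # expected: torch.Size([1, 3, 256, 256])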
(3)train.py
import torch
import torch.optim as optim
from net import AutoEncoder
import torch.nn as nn
import os
import cv2
import numpy as np
from dataset import GetData
from torch.utils.data import DataLoader
auto = AutoEncoder()
auto.cuda()
LR = 0.001
BATCH_SIZE = 180
EPOCHES = 200
optimizer = optim.Adam(auto.parameters(),lr=LR)
loss_f = nn.MSELoss()  # mean squared error
# Resume from an existing checkpoint once, before training starts; loading it inside
# train() would re-read the file on every batch and throw away the updates made since
# the last save (note the save at the end of each epoch writes to the relative path 'auto.pkl').
if os.path.exists(r'C:\Users\87419\Desktop\VAE2\auto.pkl'):
    auto.load_state_dict(torch.load(r'C:\Users\87419\Desktop\VAE2\auto.pkl'))
def train(x,_x):  # x: mosaicked images, _x: clean (uncensored) images
#################################################
    'This block displays one clean input image, to check that the data fed to the loss looks right'
# _x = _x[0].detach().cpu().data.numpy()
# _x = _x.reshape(256,256,3)
# cv2.imshow('aa', _x)
# cv2.waitKey(0)
################################################
encoded, decoded = auto(x)
loss = loss_f(decoded,_x)
return loss
for i in range(EPOCHES):
print('epoch:',i)
    dataset = GetData(r'C:\Users\87419\Desktop\VAE1\data\trainB', r'C:\Users\87419\Desktop\VAE1\data\trainA')  # path1 = mosaicked inputs (trainB), path2 = clean targets (trainA)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
for j,(imgdata1,imgdata2) in enumerate(dataloader):
        'Each iteration processes BATCH_SIZE images at once, so the number of iterations j per epoch is total_images / BATCH_SIZE'
        'A DataLoader batch imgdata comes out NHWC here; imgdata[0] indexes the first sample along N (a 3-D HWC array), not the value N, and likewise imgdata[9] is the tenth sample'
        'To make the batch size meaningful, move the whole batch with imgdata1.cuda(), not a single sample with imgdata1[0].cuda()'
imgdata1_ = imgdata1.cuda()
imgdata2_ = imgdata2.cuda()
        'Note: on a contiguous tensor, view() and reshape() behave identically; neither transposes axes. reshape(-1,3,256,256) merely reinterprets the NHWC memory as NCHW rather than moving the channel axis; the conventional conversion would be permute(0,3,1,2) (see the sketch after this script). Whichever layout you pick, use the same one in test.py.'
# imgdata1_ = imgdata1_.view(-1,3,256,256)
# imgdata2_ = imgdata2_.view(-1,3,256,256)
imgdata1_ = imgdata1_.reshape(-1,3,256,256)
imgdata2_ = imgdata2_.reshape(-1,3,256,256)
# print(imgdata1_.shape) # [180, 3 ,256, 256]
loss = train(imgdata1_,imgdata2_)
optimizer.zero_grad()
loss.backward()
optimizer.step()
        print('j:', j, '|', 'loss:', loss.item())
torch.save(auto.state_dict(),'auto.pkl')
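As an aside (my own sketch, not the author's code): if you prefer a true NHWC-to-NCHW transpose over the reshape trick above, permute is the conventional tool; it moves axes instead of reinterpreting memory, so the convolutions see the actual image layout. If you switch to it, do so in both train.py and test.py.
import torch

def to_nchw(batch_hwc):
    '''Convert an NHWC float batch (what GetData + DataLoader produce) into NCHW
    by transposing axes rather than reinterpreting memory.'''
    return batch_hwc.permute(0, 3, 1, 2).contiguous()

fake = torch.rand(2, 256, 256, 3) - 0.5   # stand-in for one DataLoader batch
print(to_nchw(fake).shape)                # torch.Size([2, 3, 256, 256])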
(4)test.py
import torch
from net import AutoEncoder
import cv2
import os
from PIL import Image
import numpy as np
auto = AutoEncoder()
auto.load_state_dict(torch.load('auto.pkl'))
auto.cuda()
count = 0
# x_path = r'C:\Users\87419\Desktop\VAE1\data\test'
x_path = r'C:\Users\87419\Desktop\cg1\dama'
for name in os.listdir(x_path):
count += 1
    im1 = cv2.imread(os.path.join(x_path,name))  # test images must already be 256x256 to match the reshape below
arr = np.array(im1)
arr_ = torch.Tensor(arr /255. - 0.5)
arr_ = arr_.cuda()
arr_ = arr_.reshape(-1,3,256,256)
encoded = auto.encoder(arr_)
decoded=auto.decoder(encoded)
    img = (decoded.detach().cpu().numpy() + 0.5) * 255
    'The network works in NCHW; convert back to HWC before writing the image out'
    img = img.reshape(256, 256, 3)
# img = img[:, :, ::-1] # BGR->RGB
# print(img)
# cv2.imshow('{}'.format(name),img)
# cv2.waitKey(0)
# cv2.imwrite(os.path.join(r'C:\Users\87419\Desktop\VAE1\data\save',name),img)
cv2.imwrite(os.path.join(r'C:\Users\87419\Desktop\cg1\restore',name),img)
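One thing I would add here (a suggestion of mine, not in the original test.py): (decoded + 0.5) * 255 is a float array whose values can fall outside [0, 255], because Tanh spans [-1, 1] while the training targets only span [-0.5, 0.5], so it is safer to clip and cast to uint8 before cv2.imwrite; wrapping inference in torch.no_grad() also avoids building the autograd graph.
import numpy as np

def to_uint8_image(decoded_tensor):
    '''Turn the decoder output (shape (1, 3, 256, 256), values roughly in [-0.5, 0.5])
    into a uint8 HWC image, keeping the same reshape convention as the rest of the pipeline.'''
    img = (decoded_tensor.detach().cpu().numpy() + 0.5) * 255.
    return np.clip(img.reshape(256, 256, 3), 0, 255).astype(np.uint8)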
The preprocessing scripts mentioned in the preface follow.
(1)resize.py
'resize all images to 256x256 with cv2'
import cv2
import os
import glob
path = r'C:\Users\87419\Desktop\cg1\img\*.jpg'
for i in glob.glob(path):
im1 = cv2.imread(i)
im2 = cv2.resize(im1, (256, 256))
cv2.imwrite(os.path.join(r'C:\Users\87419\Desktop\cg1\resize', os.path.basename(i)), im2)
(2)dama.py
import os
from PIL import Image
import numpy as np
outdir = r'C:\Users\87419\Desktop\cg1\dama'
count = 0
path = r'C:\Users\87419\Desktop\cg1\resize'
x_names = os.listdir(path)
x_names.sort(key=lambda i: int(i[:-4]))
for i in x_names:
im1 = Image.open(os.path.join(path, i))
arr = np.array(im1)
h = arr.shape[0]
w = arr.shape[1]
for j in range(int((1 / 3) * w), int((2 / 3) * w), 1):
for k in range(int((1 / 3) * h), int((2 / 3) * h), 1):
im1.putpixel((j, k), (255, 0, 0))
count += 1
print(count)
im1.save(os.path.join(outdir, '{}.jpg'.format(count)))
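The nested putpixel loops work, but the same centre block can be painted with one NumPy slice, which is much faster across 8,000 images (my sketch; 'face.jpg' and 'face_dama.jpg' are placeholder names, and it assumes plain RGB input like the loops above):
from PIL import Image
import numpy as np

im = Image.open('face.jpg')                                 # placeholder path
arr = np.array(im)
h, w = arr.shape[0], arr.shape[1]
arr[h // 3: 2 * h // 3, w // 3: 2 * w // 3] = (255, 0, 0)   # rows are the height axis, columns the width axis
Image.fromarray(arr).save('face_dama.jpg')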
(3)ToTxt.py
import os
def ListFilesToTxt(dir, file, wildcard, recursion):
exts = wildcard.split(" ")
files = os.listdir(dir)
files.sort(key=lambda x: int(x[:-4]))
for name in files:
fullname = os.path.join(dir, name)
        if os.path.isdir(fullname) and recursion:
ListFilesToTxt(fullname, file, wildcard, recursion)
else:
for ext in exts:
if (name.endswith(ext)):
file.write(name + "\n")
break
def Test():
dir = r"C:\Users\87419\Desktop\VAE1\data\trainA" # 读入
outfile = "label.txt" # 写入
# wildcard = ".jpg .txt .exe .dll .lib" # 要读取的文件类型
wildcard = ".jpg"
    # open() raises an exception if the file cannot be created, so no extra check is needed
    with open(outfile, "w") as file:
        ListFilesToTxt(dir, file, wildcard, 1)
if __name__ == '__main__':
Test()
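For comparison, a shorter equivalent I sometimes use (a sketch under the same assumption of purely numeric file names such as 1.jpg, 2.jpg, ...):
import os

dir = r"C:\Users\87419\Desktop\VAE1\data\trainA"
names = sorted((n for n in os.listdir(dir) if n.endswith(".jpg")),
               key=lambda n: int(n[:-4]))
with open("label.txt", "w") as f:
    f.write("\n".join(names) + "\n")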