import random
import clip
from utils import NoiseImageDataSet, initialize_network, eval_request, NoiseDataSetWithClip
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
import os
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from multiprocessing import Process
from config import config
import shutil
from torch.optim import lr_scheduler
from utils.loss import PSNRLoss, MSELoss
from experiments.base_enum import NoieType
import torchvision.transforms as transforms
from PIL import Image
from experiments.model import model_bn_large_attention
import numpy as np
from random import random
def image_to_numpy(img: "Image.Image") -> "np.ndarray":
    """Convert an image into a NumPy array.

    Used as the last step of the transform pipeline so that the result can be
    handed to ``torch.from_numpy``.

    Args:
        img: Anything ``np.array`` accepts; in this file it is the PIL image
            coming out of the preceding torchvision transforms.

    Returns:
        A new ``numpy.ndarray`` built by ``np.array(img)``.
    """
    # ``Image`` in this file is the ``PIL.Image`` *module*, so the class is
    # spelled ``Image.Image``; the annotations are strings so they are never
    # evaluated at definition time.
    return np.array(img)
class RandomResizedCrop:
    """Apply a randomly sized crop-and-resize to a PIL image with probability ``p``.

    When the transform fires, a region spanning 50-100% of each side is drawn,
    the output size is set to 80-100% of that region, and the actual cropping
    is delegated to ``torchvision.transforms.RandomResizedCrop`` with a fixed
    ``scale`` and ``ratio`` derived from the drawn region.

    Args:
        p: Probability of applying the crop. The image is returned unchanged
            whenever ``random() > p``.
    """

    def __init__(self, p):
        self.p = p

    def __call__(self, img: "Image.Image"):
        """Return ``img`` unchanged or a randomly cropped and resized copy."""
        # Skip the augmentation with probability (1 - p).
        if random() > self.p:
            return img

        width, height = img.size[:2]

        # Region of interest: 50-100% of each side, drawn independently.
        # Clamp to one pixel so a tiny input can never produce a zero-sized
        # side (which would make the ratio below divide by zero).
        base_w = max(1, int(width * (0.5 + 0.5 * random())))
        base_h = max(1, int(height * (0.5 + 0.5 * random())))

        # Output size: 80-100% of the region of interest, again at least 1 px.
        final_w = max(1, int(base_w * (0.8 + 0.2 * random())))
        final_h = max(1, int(base_h * (0.8 + 0.2 * random())))

        ratio = base_w / base_h
        # Use the side that is relatively smaller compared with the source
        # image as the scale factor.
        if ratio > width / height:
            scale = base_h / height
        else:
            scale = base_w / width

        # NOTE(review): torchvision documents ``scale`` as a fraction of the
        # image *area*, while the value computed here is a ratio of side
        # lengths - confirm this is the intended crop size.
        cropper = transforms.RandomResizedCrop(
            size=(final_h, final_w),
            scale=(scale, scale),
            ratio=(ratio, ratio),
        )
        return cropper(img)
# ---------------------------------------------------------------------------
# Single-image inference: load the trained weights, run one test image through
# the network, then display and save the result.
# ---------------------------------------------------------------------------
checkpoint_path = 'C:/Users/qhq/PycharmProjects/lossy-compression-denoise-master/lossy-compression-denoise-master/checkpoint/l2/145.pth'
img_path = 'C:/Users/qhq/PycharmProjects/lossy-compression-denoise-master/lossy-compression-denoise-master/samper_test/2.jpg'
output_path = r'C:\Users\qhq\PycharmProjects\lossy-compression-denoise-master\lossy-compression-denoise-master\samper_test\1_1.jpg'


def _prepare_state_dict(state_dict, prefix='module.'):
    """Rename and filter checkpoint keys in place so they match the bare model.

    * A leading ``prefix`` is removed from every key that has one (presumably
      the checkpoint was written from a DataParallel-wrapped model).  The
      original code used ``str.strip('module')``, which strips a *set of
      characters* from both ends and can therefore also eat the tail of a key.
    * Keys containing ``'track'`` are dropped because they have no counterpart
      in the model being loaded.

    Args:
        state_dict: Mapping of parameter names to tensors; modified in place.
        prefix: Key prefix to remove.

    Returns:
        The same ``state_dict`` object, for convenience.
    """
    for old_key in list(state_dict):
        value = state_dict.pop(old_key)
        new_key = old_key
        if old_key.startswith(prefix):
            new_key = old_key[len(prefix):]
            print(new_key)
        if 'track' in new_key:
            continue
        state_dict[new_key] = value
    return state_dict


# 1. Build the network and restore its weights.
model = model_bn_large_attention.UNet_n2n_un()
# SECURITY: torch.load unpickles the file - only load checkpoints you trust.
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only box.
all_dict = torch.load(checkpoint_path, map_location='cpu')
# The weights live under the 'net' entry; a missing entry now raises KeyError.
state_dict = _prepare_state_dict(all_dict['net'])
model.load_state_dict(state_dict)
# Inference mode: freeze BatchNorm statistics and disable dropout.
model.eval()
print("load true!")

# 2. Read the test image; the context manager closes the file handle and
#    convert('RGB') guarantees a 3-channel H x W x C array below.
with Image.open(img_path) as source_image:
    img = source_image.convert('RGB')

# NOTE(review): the colour jitter is random, so repeated runs on the same
# image give different inputs - confirm this is wanted at test time.
transfor_valid = transforms.Compose(
    [
        # RandomResizedCrop(0.9),
        transforms.RandomApply(
            [
                transforms.ColorJitter(
                    brightness=(0.7, 1.3),
                    contrast=(0.6, 1.4),
                    saturation=(0.7, 1.3),
                    hue=(-0.15, 0.15),
                )
            ],
            p=0.7,
        ),
        image_to_numpy,
    ]
)
img = transfor_valid(img)

# 3. H x W x C numpy array -> 1 x C x H x W float tensor.
img_tor = torch.from_numpy(img)
img_tor = img_tor.permute([2, 0, 1])
img_batch = img_tor.unsqueeze(dim=0)  # add the batch dimension

# 4. Forward pass without building an autograd graph.
with torch.no_grad():
    outputs = model(img_batch.to(torch.float32))
print(outputs.shape)

# 5. 1 x C x H x W -> H x W x C.  The inverse of permute([2, 0, 1]) is
#    permute(1, 2, 0); permute(2, 1, 0) would transpose the picture.
outputs = outputs.squeeze(dim=0)
outputs = outputs.permute(1, 2, 0)
outputs = outputs.detach().numpy()
# print(outputs.shape)  # [1280,1280,3]

# Clip before the cast so out-of-range values saturate instead of wrapping.
pil_image = Image.fromarray(np.clip(outputs, 0, 255).astype(np.uint8))
pil_image.show()
pil_image.save(output_path)
1.定义模型,实例化模型
# Instantiate the UNet_n2n_un network with its default constructor arguments.
model= model_bn_large_attention.UNet_n2n_un()
2.加载模型参数(自己训练保存的检查点里,可能包含一些与当前模型结构不对应的参数,需要先进行处理)
模型参数是以字典的格式存储的,所以可以用 for k in list(all_dict.keys()): 遍历并打印各个键进行查看
加载模型参数的一般写法是 model.load_state_dict(torch.load('模型参数保存的位置.pth'))
model = model_bn_large_attention.UNet_n2n_un()
# SECURITY: torch.load unpickles the file - only load checkpoints you trust.
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
all_dict = torch.load('C:/Users/qhq/PycharmProjects/lossy-compression-denoise-master/lossy-compression-denoise-master/checkpoint/l2/145.pth', map_location='cpu')
# The weights are stored under the 'net' entry of the checkpoint dictionary.
state_dict = all_dict['net']

# Remove a leading 'module.' from the parameter names.  str.strip('module')
# must not be used for this: it strips a set of characters from BOTH ends of
# the key and can corrupt names that end in one of those letters.
prefix = 'module.'
for k in list(state_dict.keys()):
    if k.startswith(prefix):
        le = k[len(prefix):]
        print(le)
        state_dict[le] = state_dict.pop(k)

# Drop the entries whose name contains 'track'; they have no counterpart in
# the model being loaded.
flag1 = 'track'
for k in list(state_dict.keys()):
    if flag1 in k:
        del state_dict[k]

model.load_state_dict(state_dict)
3.打开需要输入测试的图片:
# Open the test image located at img_path as a PIL image.
img = Image.open(img_path)
对输入的图片进行预处理:以0.7的概率随机做颜色抖动(ColorJitter),然后转成numpy数组(这一步不会改变图片的尺寸)
# Pre-processing pipeline: colour jitter applied with probability 0.7,
# followed by conversion of the PIL image to a numpy array.
transfor_valid = transforms.Compose([
    # RandomResizedCrop(0.9),
    transforms.RandomApply(
        [
            transforms.ColorJitter(
                brightness=(0.7, 1.3),
                contrast=(0.6, 1.4),
                saturation=(0.7, 1.3),
                hue=(-0.15, 0.15),
            ),
        ],
        p=0.7,
    ),
    image_to_numpy,
])
img = transfor_valid(img)
把预处理后的numpy数组转成torch张量(numpy数组没有permute方法,转成torch张量之后才能使用),并用permute把通道维调到最前面,即由[高,宽,通道数]变为[通道数,高,宽]。
由于输入到模型的数据格式为[batch_size,通道数,高,宽],所以还需要用.unsqueeze方法在最前面增加batch_size这一维度。
# `img` is already the numpy array returned by transfor_valid in the previous
# step; the pipeline must not be applied a second time, because that would
# re-run the random colour jitter on an array that is no longer a PIL image.
img_tor = torch.from_numpy(img)
img_tor = img_tor.permute([2, 0, 1])  # H x W x C -> C x H x W
img_batch = img_tor.unsqueeze(dim=0)  # add the batch dimension
4.将符合输入格式的torch数据输入到model中
# Switch to inference mode (fixed BatchNorm statistics, no dropout) and skip
# autograd bookkeeping, since no gradients are needed for a forward-only pass.
model.eval()
with torch.no_grad():
    outputs = model(img_batch.to(torch.float32))
5.为了显示出图片,又需要进行降维处理,调换维度位置,最后转成PIL可以显示的图片数据
outputs = outputs.squeeze(dim=0)  # drop the batch dimension
# C x H x W -> H x W x C.  This is the inverse of the permute([2, 0, 1]) used
# on the input; permute(2, 1, 0) would swap height and width and transpose
# the picture.
outputs = outputs.permute(1, 2, 0)
outputs = outputs.detach().numpy()
# print(outputs.shape)  # [1280,1280,3]
# Clip before casting so out-of-range values saturate instead of wrapping
# around, then build a displayable PIL image from the numpy array.
pil_image = Image.fromarray(np.clip(outputs, 0, 255).astype(np.uint8))
pil_image.show()