MemSeg Inference Code

A minimal script that loads a trained MemSeg model (capsule category) together with its memory bank, runs it on a single MVTec AD test image, and saves the predicted anomaly mask.

import cv2
import torch
import yaml
from timm import create_model
from torchvision import transforms

from data import create_dataset, create_dataloader
from models import MemSeg

def minmax_scaling(img):
    """Rescale a tensor to [0, 255] and convert to uint8 (helper for saving raw score maps)."""
    return (((img - img.min()) / (img.max() - img.min())) * 255).to(torch.uint8)

# Load the training/dataset configuration
cfg = yaml.load(open('./configs.yaml', 'r'), Loader=yaml.FullLoader)

# Build the MVTec AD 'capsule' test set from the config. The single-image demo
# below does not need it, but the batch-inference sketch at the end uses it.
testset = create_dataset(
        datadir=cfg['DATASET']['datadir'],
        target='capsule',
        is_train=False,
        resize=cfg['DATASET']['resize'],
        texture_source_dir=cfg['DATASET']['texture_source_dir'],
        structure_grid_size=cfg['DATASET']['structure_grid_size'],
        transparency_range=cfg['DATASET']['transparency_range'],
        perlin_scale=cfg['DATASET']['perlin_scale'],
        min_perlin_scale=cfg['DATASET']['min_perlin_scale'],
        perlin_noise_threshold=cfg['DATASET']['perlin_noise_threshold']
    )

# Load the memory bank saved during training and move all of its features to CPU
memory_bank = torch.load(r"C:\Code\MemSeg\saved_model\MemSeg-capsule\memory_bank.pt",
                         map_location='cpu')
memory_bank.device = 'cpu'
for k in memory_bank.memory_information.keys():
    memory_bank.memory_information[k] = memory_bank.memory_information[k].cpu()
# ImageNet-pretrained backbone used as a multi-scale feature extractor
feature_extractor = create_model(
        cfg['MODEL']['feature_extractor_name'],
        pretrained=True,
        features_only=True
    )
# Assemble MemSeg from the memory bank and the backbone
model = MemSeg(
        memory_bank=memory_bank,
        feature_extractor=feature_extractor
    )
# ImageNet normalization, matching the pretrained backbone
transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                mean = (0.485, 0.456, 0.406),
                std  = (0.229, 0.224, 0.225)
            )
        ])
# Load the trained weights and switch to evaluation mode (freezes BatchNorm/Dropout)
model.load_state_dict(torch.load(r"C:\Code\MemSeg\saved_model\MemSeg-capsule\best_model.pt", map_location='cpu'))
model.eval()
# Read one MVTec AD test image, convert BGR -> RGB and resize to the network input size
img = cv2.imread(r"C:\coco\mvtec_anomaly_detection\capsule\test\crack\014.png")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, dsize=(256, 256))
# Normalize, add a batch dimension and run the model; softmax over the two
# output channels gives per-pixel normal/anomaly probabilities
input_i = transform(img)
with torch.no_grad():
    output_i = model(input_i.unsqueeze(0))
output_i = torch.nn.functional.softmax(output_i, dim=1)

# Channel 1 is the anomaly probability map; scale to [0, 255] before saving
mask = output_i[0][1]
cv2.imwrite("mask.png", (mask.numpy() * 255).astype("uint8"))
# Optional: export the model to TorchScript for deployment
# trace_script_module = torch.jit.trace(model, input_i.unsqueeze(0))
# trace_script_module.save('net.torchscript')
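
# Optional (a sketch, not from the original post): run inference over the whole
# test set built above. Each sample is assumed to yield (image, synthetic_mask,
# target) with images already normalized, as in the repo's training pipeline;
# check data/ if the format differs. The repo's create_dataloader helper could
# be used instead of a plain DataLoader.
from torch.utils.data import DataLoader

testloader = DataLoader(testset, batch_size=8, shuffle=False)
with torch.no_grad():
    for inputs, masks, targets in testloader:
        outputs = torch.nn.functional.softmax(model(inputs), dim=1)
        scores = outputs[:, 1].amax(dim=(1, 2))  # per-image anomaly score
        print(scores.tolist())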
