Unsupervised Network

import torch
import math
from torch import onnx
from torch import nn
import numpy as np
import torch.nn.functional as F
import scipy.ndimage as ndimage

def pth2onnx(model, dummy_input, dynamiconnx):
    torch.set_grad_enabled(False)
    input_names = ["input1"]
    output_names = ["output1"]
    # Export an ONNX model with a dynamic batch dimension
    onnx.export(model=model, args=dummy_input, f=dynamiconnx, input_names=input_names,
                output_names=output_names, verbose=False,
                dynamic_axes=dict([(k, {0: 'batch_size'}) for k in input_names] +
                                  [(k, {0: 'batch_size'}) for k in output_names]),
                keep_initializers_as_inputs=True)
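
# A minimal usage sketch for the exporter (the file name is illustrative);
# uncomment to export a built model:
# model = unsupervisedNet()
# dummy = torch.randn(1, 3, 512, 512)
# pth2onnx(model, dummy, 'model_dynamic.onnx')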

class Convolution(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super(Convolution, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.active = nn.Mish(inplace=True)

    def forward(self,x):
        return self.active(self.bn(self.conv(x)))

class PatchMaker(nn.Module):
    def __init__(self, patchsize, top_k=0, stride=None):
        super(PatchMaker,self).__init__()
        self.patchsize = patchsize
        self.stride = stride
        self.top_k = top_k

    def forward(self, features):
        """Convert a feature map into a tensor of overlapping patches.
        Args:
            features: [torch.Tensor, bs x c x h x w]
        Returns:
            [torch.Tensor, bs x n_patches x c x patchsize x patchsize]
            and, when return_spatial_info is set, the patch grid [n_h, n_w].
        """
        return_spatial_info = True
        padding = int((self.patchsize - 1) / 2)  # 1 for patchsize 3
        unfolder = torch.nn.Unfold(
            kernel_size=self.patchsize, stride=self.stride, padding=padding, dilation=1
        )
        unfolded_features = unfolder(features)
        number_of_total_patches = []
        # Standard sliding-window output size: floor((s + 2p - d*(k-1) - 1) / stride) + 1
        for s in features.shape[-2:]:
            n_patches = (
                s + 2 * padding - 1 * (self.patchsize - 1) - 1
            ) / self.stride + 1
            number_of_total_patches.append(int(n_patches))
        unfolded_features = unfolded_features.reshape(
            *features.shape[:2], self.patchsize, self.patchsize, -1
        )
        unfolded_features = unfolded_features.permute(0, 4, 1, 2, 3)

        if return_spatial_info:
            return unfolded_features, number_of_total_patches
        return unfolded_features
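
# Shape check (illustrative values): a 1x512x64x64 feature map with patchsize=3,
# stride=1 unfolds into [1, 4096, 512, 3, 3] patches with spatial grid [64, 64]:
# patches, grid = PatchMaker(3, stride=1)(torch.randn(1, 512, 64, 64))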

class Resblock(nn.Module):
    def __init__(self,ch):
        super(Resblock,self).__init__()
        self.conv1 = Convolution(ch, ch // 2, 1, 1, 0)
        self.conv2 = Convolution(ch // 2,ch // 2, 3, 1, 1)
        self.conv3 = nn.Conv2d(ch // 2,ch, 1, 1)
        self.relu = nn.ReLU(True)
    def forward(self,x):
        y = self.conv1(x)
        y = self.conv2(y)
        y = self.conv3(y)
        return self.relu(x + y)

class MeanMapper(torch.nn.Module):
    def __init__(self, preprocessing_dim):
        super(MeanMapper, self).__init__()
        self.preprocessing_dim = preprocessing_dim
    def forward(self, features):
        features = features.reshape(len(features), 1, -1)
        return F.adaptive_avg_pool1d(features, self.preprocessing_dim).squeeze(1)
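
# MeanMapper flattens each patch and adaptively average-pools it to a fixed
# width, e.g. [4096, 512, 3, 3] -> [4096, 1536]:
# pooled = MeanMapper(1536)(torch.randn(4096, 512, 3, 3))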

def init_weight(m):
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.xavier_normal_(m.weight)
    elif isinstance(m, torch.nn.Conv2d):
        torch.nn.init.xavier_normal_(m.weight)

class Projection(torch.nn.Module):
    def __init__(self, in_planes, out_planes=None, n_layers=1, layer_type=0):
        super(Projection, self).__init__()

        if out_planes is None:
            out_planes = in_planes
        self.layers = torch.nn.Sequential()
        _in = None
        _out = None
        for i in range(n_layers):
            _in = in_planes if i == 0 else _out
            _out = out_planes
            self.layers.add_module(f"{i}fc",
                                   torch.nn.Linear(_in, _out))
            if i < n_layers - 1:
                # if layer_type > 0:
                #     self.layers.add_module(f"{i}bn",
                #                            torch.nn.BatchNorm1d(_out))
                if layer_type > 1:
                    self.layers.add_module(f"{i}relu",
                                           torch.nn.LeakyReLU(.2))
        self.apply(init_weight)

    def forward(self, x):
        # x = .1 * self.layers(x) + x
        x = self.layers(x)
        return x
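
# With the arguments used below (n_layers=1, layer_type=0), Projection reduces
# to a single Xavier-initialized Linear(in_planes, out_planes).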

class Discriminator(torch.nn.Module):
    def __init__(self, in_planes, n_layers=1, hidden=None):
        super(Discriminator, self).__init__()
        _hidden = in_planes if hidden is None else hidden
        self.body = torch.nn.Sequential()
        for i in range(n_layers-1):
            _in = in_planes if i == 0 else _hidden
            _hidden = int(_hidden // 1.5) if hidden is None else hidden
            self.body.add_module('block%d'%(i+1),
                                 torch.nn.Sequential(
                                     torch.nn.Linear(_in, _hidden),
                                     torch.nn.BatchNorm1d(_hidden),
                                     torch.nn.LeakyReLU(0.2)
                                 ))
        self.tail = torch.nn.Linear(_hidden, 1, bias=False)
        self.apply(init_weight)

    def forward(self,x):
        x = self.body(x)
        x = self.tail(x)
        return x
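
# The discriminator assigns each patch feature a single logit; its negation is
# used as the anomaly score below:
# d = Discriminator(1536, n_layers=2, hidden=1024)
# logits = d(torch.randn(4096, 1536))  # -> [4096, 1]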

class unsupervisedNet(torch.nn.Module):
    def __init__(self):
        super(unsupervisedNet,self).__init__()
        self.batchsize = 1
        self.device = 'cpu'  # or a CUDA device index such as 0
        self.input_shape = [3,512,512]
        self.target_size = self.input_shape[-2:]
        self.patchsize = 3
        self.stride = 1
        self.top_k = 0
        self.input_dims = [512,1024]
        self.output_dim = 1536
        self.smoothing = 4
        self.preprocessing_modules = torch.nn.ModuleList()
        for input_dim in self.input_dims:
            module = MeanMapper(self.output_dim)
            self.preprocessing_modules.append(module)
        self.pre_projection = Projection(self.output_dim, self.output_dim,1,0)

        self.discriminator = Discriminator(self.output_dim, n_layers=2, hidden=1024)

        self.conv1 = Convolution(3,64,7,2,3)
        self.maxpool1 = nn.MaxPool2d(3,2)
        self.conv2 = Convolution(64,128,1,1,0)
        self.conv3 = Convolution(128,128,3,1,1)
        self.conv4 = nn.Conv2d(128,256,1,1,0)
        self.conv5 = nn.Conv2d(64,256,1,1,0)

        self.relu = nn.ReLU(True)

        self.conv6 = Convolution(256,128,1,1,0)
        self.conv7 = Convolution(128,128,3,1,1)
        self.conv8 = nn.Conv2d(128,256,1,1)

        self.conv9 = Convolution(256,128,1,1,0)
        self.conv10 = Convolution(128,128,3,1,1)
        self.conv11 = nn.Conv2d(128,256,1,1)

        self.conv12 = Convolution(256,256,1,1,0)
        self.conv13 = Convolution(256,256,3,2,1)
        self.conv14 = nn.Conv2d(256,512,1,1)
        self.conv15 = nn.Conv2d(256,512,1,2)
        self.conv16 = Resblock(512)
        self.conv17 = Resblock(512)

        self.conv18 = Convolution(512,512,1,1,0)
        self.conv19 = Convolution(512,512,3,2,1)
        self.conv20 = nn.Conv2d(512,1024,1,1)
        self.conv21 = nn.Conv2d(512,1024,3,2,1)

        self.conv22 = Resblock(1024)
        self.conv23 = Resblock(1024)
        self.conv24 = Resblock(1024)
        self.conv25 = Resblock(1024)
        self.conv26 = Resblock(1024)

        self.unfolded_features = []
        self.patch_shapes = []
        self.padding = int((self.patchsize - 1) / 2)
        self.unfolder = nn.Unfold(kernel_size=self.patchsize, stride=self.stride, padding=self.padding, dilation=1)
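
        # conv1..conv26 above hand-build a ResNet-50-style feature extractor;
        # forward() taps it at 512 channels ("output1", 1/8 resolution) and
        # 1024 channels ("output2", 1/16 resolution), mirroring the layer2 /
        # layer3 taps of the backbone variant later in this post.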

    def score(self, x):
        was_numpy = False
        if isinstance(x, np.ndarray):
            was_numpy = True
            x = torch.from_numpy(x)
        while x.ndim > 2:
            x = torch.max(x, dim=-1).values
        if x.ndim == 2:
            if self.top_k > 1:
                x = torch.topk(x, self.top_k, dim=1).values.mean(1)
            else:
                x = torch.max(x, dim=1).values
        if was_numpy:
            return x.numpy()
        return x

    def forward(self,x):
        y = self.conv1(x)
        y = self.maxpool1(y)
        y1 = self.conv2(y)
        y1 = self.conv3(y1)
        y1 = self.conv4(y1)
        y2 = self.conv5(y)
        out1 = self.relu(y1 + y2)
        y3 = self.conv6(out1)
        y3 = self.conv7(y3)
        y3 = self.conv8(y3)
        out2 = self.relu(out1 + y3)
        y4 = self.conv9(out2)
        y4 = self.conv10(y4)
        y4 = self.conv11(y4)
        out3 = self.relu(out2 + y4)
        y5 = self.conv12(out3)
        y5 = self.conv13(y5)
        y5 = self.conv14(y5)
        y6 = self.conv15(out3)
        out4 = self.relu(y5 + y6)
        out4 = self.conv16(out4)
        output1 = self.conv17(out4)
        y7 = self.conv18(output1)
        y7 = self.conv19(y7)
        y7 = self.conv20(y7)
        y8 = self.conv21(output1)
        out5 = self.relu(y7 + y8)
        out5 = self.conv22(out5)
        out5 = self.conv23(out5)
        out5 = self.conv24(out5)
        out5 = self.conv25(out5)
        output2 = self.conv26(out5)

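        # Unfold both feature maps into overlapping 3x3 patches; the loops
        # below recover each map's patch grid via the sliding-window formula
        # floor((s + 2p - d*(k-1) - 1) / stride) + 1.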
        unfolded_features1 = self.unfolder(output1)
        patch_shapes1 = []
        for s in output1.shape[-2:]:
            n_patches = (s + 2 * self.padding - 1 * (self.patchsize - 1) - 1) / self.stride + 1
            patch_shapes1.append(int(n_patches))
        unfolded_features1 = unfolded_features1.reshape(
            *output1.shape[:2], self.patchsize, self.patchsize, -1
        )
        unfolded_features1 = unfolded_features1.permute(0, 4, 1, 2, 3)

        unfolded_features2 = self.unfolder(output2)
        patch_shapes2 = []
        for s in output2.shape[-2:]:
            n_patches = (s + 2 * self.padding - 1 * (self.patchsize - 1) - 1) / self.stride + 1
            patch_shapes2.append(int(n_patches))
        unfolded_features2 = unfolded_features2.reshape(
            *output2.shape[:2], self.patchsize, self.patchsize, -1
        )
        unfolded_features2 = unfolded_features2.permute(0, 4, 1, 2, 3)

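        # Upsample the coarser 32x32 layer-3 patch grid onto the 64x64 layer-2
        # grid so the two feature sets can be fused patch-for-patch.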
        ref_num_patches = patch_shapes1
        _features = unfolded_features2
        patch_dims = patch_shapes2
        _features = _features.reshape(
            _features.shape[0], patch_dims[0], patch_dims[1], *_features.shape[2:]
        )
        _features = _features.permute(0, -3, -2, -1, 1, 2)
        perm_base_shape = _features.shape
        _features = _features.reshape(-1, *_features.shape[-2:])
        _features = F.interpolate(
            _features.unsqueeze(1),
            size=(ref_num_patches[0], ref_num_patches[1]),
            mode="bilinear",
            align_corners=False,
        )
        _features = _features.squeeze(1)
        _features = _features.reshape(
            *perm_base_shape[:-2], ref_num_patches[0], ref_num_patches[1]
        )
        _features = _features.permute(0, -2, -1, 1, 2, 3)
        _features = _features.reshape(len(_features), -1, *_features.shape[-3:])
        unfolded_features2 = _features

        unfolded_features1 = unfolded_features1.reshape(-1,*unfolded_features1.shape[-3:])
        unfolded_features2 = unfolded_features2.reshape(-1, *unfolded_features2.shape[-3:])

        # _features = []
        model1 = self.preprocessing_modules[0]
        feature1 = model1(unfolded_features1)
        model2 = self.preprocessing_modules[1]
        feature2 = model2(unfolded_features2)
        features = torch.stack([feature1,feature2], dim=1)

        features = features.reshape(len(features), 1, -1)
        features = F.adaptive_avg_pool1d(features, self.output_dim)
        features = features.reshape(len(features), -1)  # torch.Size([4096, 1536])

        patch_shapes = []
        patch_shapes.append(patch_shapes1)
        patch_shapes.append(patch_shapes2)

        features = self.pre_projection(features)  # torch.Size([4096, 1536])

        patch_scores = image_scores = -self.discriminator(features)

        patch_scores = patch_scores.cpu().detach().numpy()
        image_scores = image_scores.cpu().detach().numpy()

        image_scores = image_scores.reshape(self.batchsize,-1,*image_scores.shape[1:])
        image_scores = image_scores.reshape(*image_scores.shape[:2], -1)

        image_scores = self.score(image_scores)
        image_scores = image_scores.reshape(self.batchsize,-1,*image_scores.shape[1:])
        scales = patch_shapes[0]
        patch_scores = patch_scores.reshape(1, scales[0], scales[1])
        features = features.reshape(1, scales[0], scales[1], -1)

        with torch.no_grad():
            if isinstance(patch_scores, np.ndarray):
                patch_scores = torch.from_numpy(patch_scores)
            _scores = patch_scores.to(self.device)
            _scores = _scores.unsqueeze(1)
            _scores = F.interpolate(
                _scores, size=self.target_size, mode="bilinear", align_corners=False
            )
            _scores = _scores.squeeze(1)
            patch_scores = _scores.cpu().numpy()

            if isinstance(features, np.ndarray):
                features = torch.from_numpy(features)
            features = features.to(self.device).permute(0, 3, 1, 2)
            # Interpolate in sub-batches if the output would exceed 2**31
            # elements (the int32 indexing limit of some interpolation kernels).
            if self.target_size[0] * self.target_size[1] * features.shape[0] * features.shape[1] >= 2 ** 31:
                subbatch_size = int((2 ** 31 - 1) / (self.target_size[0] * self.target_size[1] * features.shape[1]))
                interpolated_features = []
                for i_subbatch in range(int(features.shape[0] / subbatch_size + 1)):
                    subfeatures = features[i_subbatch * subbatch_size:(i_subbatch + 1) * subbatch_size]
                    subfeatures = subfeatures.unsqueeze(0) if len(subfeatures.shape) == 3 else subfeatures
                    subfeatures = F.interpolate(
                        subfeatures, size=self.target_size, mode="bilinear", align_corners=False
                    )
                    interpolated_features.append(subfeatures)
                features = torch.cat(interpolated_features, 0)
            else:
                features = F.interpolate(
                    features, size=self.target_size, mode="bilinear", align_corners=False
                )
            # features = features.cpu().detach().numpy()
        masks = [ndimage.gaussian_filter(patch_score, sigma=self.smoothing) for patch_score in patch_scores]
        masks = torch.from_numpy(np.stack(masks))
        return masks  # , self.patch_shapes

net = unsupervisedNet()
# net.cuda()
x = torch.randn((1, 3, 512, 512))  # .cuda()
masks = net(x)
print(masks.shape)
# pth2onnx(net,x,'test.onnx')
# trace_script_module = torch.jit.trace(net,x)
# trace_script_module.save('net.torchscript')
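
# Note: forward() converts tensors to numpy and applies scipy's gaussian_filter,
# which torch.jit.trace and ONNX export cannot follow; hence the export lines
# above are commented out. The smoothing would need a torch-native Gaussian
# blur before this model could be exported.

# ---------------------------------------------------------------------------
# Variant 2: the same patch-scoring pipeline, but with features taken from a
# pretrained wide-ResNet-50 backbone instead of the hand-built conv stack.
# ---------------------------------------------------------------------------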
import backbones  # project-local helper that loads pretrained backbones
import copy

# pth2onnx, Convolution, PatchMaker, Resblock, MeanMapper, init_weight,
# Projection and Discriminator are identical to the definitions above and are
# reused here unchanged.

class ForwardHook:
    def __init__(self, hook_dict, layer_name: str, last_layer_to_extract: str):
        self.hook_dict = hook_dict
        self.layer_name = layer_name
        self.raise_exception_to_break = copy.deepcopy(
            layer_name == last_layer_to_extract
        )

    def __call__(self, module, input, output):
        self.hook_dict[self.layer_name] = output
        return None
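
# A typical (hypothetical) hook registration, for reference; the aggregator
# below keeps direct references to the layers instead:
# outputs = {}
# handle = backbone.layer2.register_forward_hook(
#     ForwardHook(outputs, 'layer2', last_layer_to_extract='layer3'))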

class NetworkFeatureAggregator(torch.nn.Module):
    """Efficient extraction of network features."""

    def __init__(self, backbone, layers_to_extract_from, device, train_backbone=False):
        super(NetworkFeatureAggregator, self).__init__()
        """Extraction of network features.

        Runs a network only to the last layer of the list of layers where
        network features should be extracted from.

        Args:
            backbone: torchvision.model
            layers_to_extract_from: [list of str]
        """
        self.layers_to_extract_from = layers_to_extract_from
        self.backbone = backbone
        self.device = device
        self.train_backbone = train_backbone

        for extract_layer in layers_to_extract_from:
            if extract_layer == 'layer2':
                self.network_layer2 = backbone.__dict__["_modules"][extract_layer]
            if extract_layer == 'layer3':
                self.network_layer3 = backbone.__dict__["_modules"][extract_layer]
        print(self.network_layer2, '\n', self.network_layer3)
        self.to(self.device)

    def forward(self, images, eval=True):
        # Placeholder output in case the backbone call fails; the shape matches
        # an ImageNet classification head.
        y = torch.randn((1, 1000))
        if self.train_backbone and not eval:
            y = self.backbone(images)
        else:
            with torch.no_grad():
                try:
                    y = self.backbone(images)
                except Exception:
                    pass
        return y

    # def feature_dimensions(self, input_shape):
    #     """Computes the feature dimensions for all layers given input_shape."""
    #     _input = torch.ones([1] + list(input_shape)).to(self.device)
    #     _output = self(_input)
    #     return [_output[layer].shape[1] for layer in self.layers_to_extract_from]

class unsupervisedNet(torch.nn.Module):
    def __init__(self,batchsize,train):
        super(unsupervisedNet,self).__init__()
        self.batchsize = batchsize
        self.train_backbone = train
        self.backbone_name = 'wideresnet50'
        self.backbone = backbones.load(self.backbone_name)
        self.device = 'cpu'  # or a CUDA device index such as 0
        self.input_shape = [3, 512, 512]

        self.target_size = self.input_shape[-2:]
        self.patchsize = 3
        self.stride = 1
        self.top_k = 0
        self.input_dims = [512, 1024]
        self.output_dim = 1536
        self.smoothing = 4
        self.preprocessing_modules = torch.nn.ModuleList()
        for input_dim in self.input_dims:
            module = MeanMapper(self.output_dim)
            self.preprocessing_modules.append(module)
        self.pre_projection = Projection(self.output_dim, self.output_dim, 1, 0)

        self.discriminator = Discriminator(self.output_dim, n_layers=2, hidden=1024)

        self.layer2 = nn.Sequential(
            self.backbone.conv1,
            self.backbone.bn1,
            self.backbone.relu,
            self.backbone.maxpool,
            self.backbone.layer1,
            self.backbone.layer2,
        )

        self.layer3 = nn.Sequential(
            self.backbone.conv1,
            self.backbone.bn1,
            self.backbone.relu,
            self.backbone.maxpool,
            self.backbone.layer1,
            self.backbone.layer2,
            self.backbone.layer3
        )
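
        # Note: self.layer3 re-runs the full stem and layer2, so forward()
        # computes the shared stem twice; simple, at the cost of redundant work.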

        self.unfolded_features = []
        self.patch_shapes = []
        self.padding = int((self.patchsize - 1) / 2)
        self.unfolder = nn.Unfold(kernel_size=self.patchsize, stride=self.stride, padding=self.padding, dilation=1)
        # self.feature_aggregator = NetworkFeatureAggregator(
        #     self.backbone, self.layers_to_extract_from, self.device, self.train_backbone
        # )

        # self.feature_dimensions = self.feature_aggregator.feature_dimensions(self.input_shape)

    def score(self, x):
        was_numpy = False
        if isinstance(x, np.ndarray):
            was_numpy = True
            x = torch.from_numpy(x)
        while x.ndim > 2:
            x = torch.max(x, dim=-1).values
        if x.ndim == 2:
            if self.top_k > 1:
                x = torch.topk(x, self.top_k, dim=1).values.mean(1)
            else:
                x = torch.max(x, dim=1).values
        if was_numpy:
            return x.numpy()
        return x

    def forward(self,x):
        output1 = self.layer2(x)
        output2 = self.layer3(x)
        unfolded_features1 = self.unfolder(output1)
        patch_shapes1 = []
        for s in output1.shape[-2:]:
            n_patches = (s + 2 * self.padding - 1 * (self.patchsize - 1) - 1) / self.stride + 1
            patch_shapes1.append(int(n_patches))
        unfolded_features1 = unfolded_features1.reshape(
            *output1.shape[:2], self.patchsize, self.patchsize, -1
        )
        unfolded_features1 = unfolded_features1.permute(0, 4, 1, 2, 3)

        unfolded_features2 = self.unfolder(output2)
        patch_shapes2 = []
        for s in output2.shape[-2:]:
            n_patches = (s + 2 * self.padding - 1 * (self.patchsize - 1) - 1) / self.stride + 1
            patch_shapes2.append(int(n_patches))
        unfolded_features2 = unfolded_features2.reshape(
            *output2.shape[:2], self.patchsize, self.patchsize, -1
        )
        unfolded_features2 = unfolded_features2.permute(0, 4, 1, 2, 3)

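        # As in the first variant, upsample the coarser layer-3 patch grid onto
        # the layer-2 grid so both feature sets align patch-for-patch.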
        ref_num_patches = patch_shapes1
        _features = unfolded_features2
        patch_dims = patch_shapes2
        _features = _features.reshape(
            _features.shape[0], patch_dims[0], patch_dims[1], *_features.shape[2:]
        )
        _features = _features.permute(0, -3, -2, -1, 1, 2)
        perm_base_shape = _features.shape
        _features = _features.reshape(-1, *_features.shape[-2:])
        _features = F.interpolate(
            _features.unsqueeze(1),
            size=(ref_num_patches[0], ref_num_patches[1]),
            mode="bilinear",
            align_corners=False,
        )
        _features = _features.squeeze(1)
        _features = _features.reshape(
            *perm_base_shape[:-2], ref_num_patches[0], ref_num_patches[1]
        )
        _features = _features.permute(0, -2, -1, 1, 2, 3)
        _features = _features.reshape(len(_features), -1, *_features.shape[-3:])
        unfolded_features2 = _features

        unfolded_features1 = unfolded_features1.reshape(-1, *unfolded_features1.shape[-3:])
        unfolded_features2 = unfolded_features2.reshape(-1, *unfolded_features2.shape[-3:])

        # _features = []
        model1 = self.preprocessing_modules[0]
        feature1 = model1(unfolded_features1)
        model2 = self.preprocessing_modules[1]
        feature2 = model2(unfolded_features2)
        features = torch.stack([feature1, feature2], dim=1)

        features = features.reshape(len(features), 1, -1)
        features = F.adaptive_avg_pool1d(features, self.output_dim)
        features = features.reshape(len(features), -1)

        patch_shapes = []
        patch_shapes.append(patch_shapes1)
        patch_shapes.append(patch_shapes2)

        features = self.pre_projection(features)  # torch.Size([4096, 1536])
        self.features = features  # stash the projected features (e.g. for training code)

        patch_scores = image_scores = -self.discriminator(features)

        patch_scores = patch_scores.cpu().detach().numpy()
        image_scores = image_scores.cpu().detach().numpy()

        image_scores = image_scores.reshape(self.batchsize, -1, *image_scores.shape[1:])
        image_scores = image_scores.reshape(*image_scores.shape[:2], -1)

        image_scores = self.score(image_scores)
        image_scores = image_scores.reshape(self.batchsize, -1, *image_scores.shape[1:])
        scales = patch_shapes[0]
        # patch_scores = patch_scores.reshape(1, scales[0], scales[1])
        patch_scores = patch_scores.reshape(self.batchsize, scales[0], scales[1])
        # features = features.reshape(1, scales[0], scales[1], -1)
        features = features.reshape(self.batchsize, scales[0], scales[1], -1)

        with torch.no_grad():
            if isinstance(patch_scores, np.ndarray):
                patch_scores = torch.from_numpy(patch_scores)
            _scores = patch_scores.to(self.device)
            _scores = _scores.unsqueeze(1)
            _scores = F.interpolate(
                _scores, size=self.target_size, mode="bilinear", align_corners=False
            )
            _scores = _scores.squeeze(1)
            patch_scores = _scores.cpu().numpy()

            if isinstance(features, np.ndarray):
                features = torch.from_numpy(features)
            features = features.to(self.device).permute(0, 3, 1, 2)
            # Interpolate in sub-batches if the output would exceed 2**31
            # elements (the int32 indexing limit of some interpolation kernels).
            if self.target_size[0] * self.target_size[1] * features.shape[0] * features.shape[1] >= 2 ** 31:
                subbatch_size = int((2 ** 31 - 1) / (self.target_size[0] * self.target_size[1] * features.shape[1]))
                interpolated_features = []
                for i_subbatch in range(int(features.shape[0] / subbatch_size + 1)):
                    subfeatures = features[i_subbatch * subbatch_size:(i_subbatch + 1) * subbatch_size]
                    subfeatures = subfeatures.unsqueeze(0) if len(subfeatures.shape) == 3 else subfeatures
                    subfeatures = F.interpolate(
                        subfeatures, size=self.target_size, mode="bilinear", align_corners=False
                    )
                    interpolated_features.append(subfeatures)
                features = torch.cat(interpolated_features, 0)
            else:
                features = F.interpolate(
                    features, size=self.target_size, mode="bilinear", align_corners=False
                )
            # features = features.cpu().detach().numpy()
        masks = [ndimage.gaussian_filter(patch_score, sigma=self.smoothing) for patch_score in patch_scores]
        masks = torch.from_numpy(np.stack(masks))
        return masks  # , self.patch_shapes

# net = unsupervisedNet(2,False)
# # net.cuda()
# x = torch.randn((2,3,512,512))#.cuda()
# y = net(x)
# print(y.shape)
# pth2onnx(net,x,'test.onnx')
# trace_script_module = torch.jit.trace(net,x)
# trace_script_module.save('net1.torchscript')
