Backbone共享参数,代码出现的一个错误

代码地址:https://github.com/GeoZcx/A-deeply-supervised-image-fusion-network-for-change-detection-in-remote-sensing-images/blob/master/pytorch%20version/DSIFN.py

这块是共享一个backbone特征提取网络

class DSIFN(nn.Module):
    def __init__(self, model_A, model_B):
        super().__init__()
        self.t1_base = model_A
        self.t2_base = model_B
        self.sa1 = SpatialAttention()
        self.sa2= SpatialAttention()
        self.sa3 = SpatialAttention()
        self.sa4 = SpatialAttention()
        self.sa5 = SpatialAttention()

我跑代码,将其改成

class DSIFN(nn.Module):
    def __init__(self): # 删掉两个参数
        super().__init__()
        self.t1_base = vgg16_base() # 直接在这里定义,将其写死
        self.t2_base = vgg16_base()
        
        self.sa1 = SpatialAttention()
        self.sa2= SpatialAttention()
        self.sa3 = SpatialAttention()
        self.sa4 = SpatialAttention()
        self.sa5 = SpatialAttention()

因为是共享参数的骨干网络,如果这样写的话,就是不共享参数的网络了

正确做法应该是,初始化的时候,用同一个vgg来初始化

vgg = vgg16_base()
net = DSIFN(vgg, vgg) # 共享同一个vgg

你可能感兴趣的:(深度学习,深度学习,人工智能,机器学习)