pytorch lightning之加载权重

迁移学习

使用预训练的模型分两种情况,一种是pytorch训练的,lightning的模型也包括其中,另外一种是第三方的。第三方的详见对应的使用说明。

pytorch模型

LightningModules 也是nn.Module的子类,可以直接使用

LightningModule

调用load_from_checkpoint()方法,它是LightningModule中实现的一个方法。

class Encoder(torch.nn.Module):
    ...


class AutoEncoder(LightningModule):
    def __init__(self):
        self.encoder = Encoder()
        self.decoder = Decoder()


class CIFAR10Classifier(LightningModule):
    def __init__(self):
        # init the pretrained LightningModule
        # 加载预训练权重load_from_checkpoint
        self.feature_extractor = AutoEncoder.load_from_checkpoint(PATH)
        self.feature_extractor.freeze()

        # the autoencoder outputs a 100-dim representation and CIFAR-10 has 10 classes
        self.classifier = nn.Linear(100, 10)

    def forward(self, x):
        representations = self.feature_extractor(x)
        x = self.classifier(representations)
        ...

你可能感兴趣的:(pytorch,lightning)