pytorch基于resnet18预训练模型用于自己的训练数据集进行迁移学习

本文记录利用resnet18预训练模型进行迁移学习,在自己的训练数据集上进行重新训练。相关代码重点部分分别介绍如下:

model=torchvision.models.resnet18(pretrained=True)
num_features=model.fc.in_features
model.fc=nn.Linear(num_features,num_classes)
model=model.to(device)

说明:上述代码首先加载resnet18预训练模型,然后根据训练数据集中的分类数量num_classes修改模型的输出

建议如果有NVIDIA GPU则训练时尽量使用GPU,否则训练时间太久。

如果是使用GPU,则请指定device为GPU,例如:

device=torch.device("cuda" if torch.cuda.is_available() else 'cpu')

然后将model以及输入数据都是用to函数指定到GPU上进行运算:

model=model.to(device)
    for i,(images,labels) in enumerate(train_loader):
        images=images.to(device)
        labels=labels.to(device)

需要指出的是,使用GPU训练出来的模型,如果后续用于分类的时候使用的是CPU,则需要将模型进行转化,否则会报错:

torch.load('tensors.pt', map_location='cpu')

以上为利用resnet18预训练模型训练自己的数据集形成新的特征模型的关键代码,如需完整代码,请关注如下公众号 健哥聊量化,关注之后直接键入: resnet18 即可得到下载链接。该代码中分为两个文件,一个是模型训练的源代码,一个是利用该模型进行分类的源代码。

-------------------- 正文到此结束------------------------

推荐一个公众号:健哥聊量化,会持续推出股票相关基础知识,以及python实现的一些基本的分析代码。欢迎大家关注,二维码如下:

相关文章列表如下:

  • 股票基础知识----- K线形态

  • 股票K线形态 ----早晨之星

  • “早晨之星”实际操作篇---通达信软件为例

  • 牛刀小试----python+tushare进行股票分析

  • 股票K线形态----黄昏之星

  • 股票K线形态-----墓碑线

  • 股票K线形态-----多方炮

  • 股票K线形态-----红三兵

  • 股票K线形态----三只乌鸦

  • 股票K线形态-----锤头线、吊颈线、倒锤头线

你可能感兴趣的:(python,pytorch,人工智能,机器学习,python)