修改pytorch和Keras预训练模型路径

目录

      • 1、Pytorch预训练模型路径修改
      • 2、Keras修改预训练模型位置

1、Pytorch预训练模型路径修改

Pytorch安装目录下有一个hub.py,改文件指定了预训练模型的加载位置。该文件存在于xxx\site-packages\torch,例如我的存在于“C:\ProgramData\Miniconda3\Lib\site-packages\torch”。
打开hub.py文件,找到load_state_dict_from_url函数,其中第二个参数
model_dir用于指定权重文件路径:model_dir (string, optional): directory in which to save the object。将该参数值由None改为权重文件位置即可,例如model_dir=‘D:/Models_Download/torch’。

def load_state_dict_from_url(url, model_dir='D:/Models_Download/torch', map_location=None, progress=True, check_hash=False, file_name=None):
    r"""Loads the Torch serialized object at the given URL.

    If downloaded file is a zip file, it will be automatically
    decompressed.

    If the object is already present in `model_dir`, it's deserialized and
    returned.
    The default value of `model_dir` is ``/checkpoints`` where
    `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`.

    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object
        map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
        progress (bool, optional): whether or not to display a progress bar to stderr.
            Default: True
        check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention
            ``filename-.ext`` where ```` is the first eight or more
            digits of the SHA256 hash of the contents of the file. The hash is used to
            ensure unique names and to verify the contents of the file.
            Default: False
        file_name (string, optional): name for the downloaded file. Filename from `url` will be used if not set.

    Example:
        >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')

    """

2、Keras修改预训练模型位置

Keras安装路径内并没有一个文件来定义预训练模型位置,我只能在调用预训练模型的时候指定模型文件的路径(有没有更好的设置方法?)。

base_model = vgg19.VGG19(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), include_top=False, 
                         weights='D:\\Models_Download\\keras\\vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5')

你可能感兴趣的:(计算机视觉,机器学习,Pytorch,tensorflow,python,深度学习,pytorch)