元学习示例:maml_miniimagenet无法下载数据的解决办法

        在learn2learn库中,有一个示例:maml_miniimagenet.py。在运行该示例时,总是不能下载到miniimagenet数据,这里可以直接从源码中找到下载链接。源码文件为learn2learn/vision/datasets/mini_imagenet.py。具体代码部分如下

    def __init__(
        self,
        root,
        mode='train',
        transform=None,
        target_transform=None,
        download=False,
    ):
        super(MiniImagenet, self).__init__()
        self.root = os.path.expanduser(root)
        if not os.path.exists(self.root):
            os.mkdir(self.root)
        self.transform = transform
        self.target_transform = target_transform
        self.mode = mode
        self._bookkeeping_path = os.path.join(self.root, 'mini-imagenet-bookkeeping-' + mode + '.pkl')
        if self.mode == 'test':
            google_drive_file_id = '1wpmY-hmiJUUlRBkO9ZDCXAcIpHEFdOhD'
            dropbox_file_link = 'https://www.dropbox.com/s/ye9jeb5tyz0x01b/mini-imagenet-cache-test.pkl?dl=1'
            zenodo_file_link = 'https://zenodo.org/record/7978538/files/mini-imagenet-cache-test.pkl'
        elif self.mode == 'train':
            google_drive_file_id = '1I3itTXpXxGV68olxM5roceUMG8itH9Xj'
            dropbox_file_link = 'https://www.dropbox.com/s/9g8c6w345s2ek03/mini-imagenet-cache-train.pkl?dl=1'
            zenodo_file_link = 'https://zenodo.org/record/7978538/files/mini-imagenet-cache-train.pkl'
        elif self.mode == 'validation':
            google_drive_file_id = '1KY5e491bkLFqJDp0-UWou3463Mo8AOco'
            dropbox_file_link = 'https://www.dropbox.com/s/ip1b7se3gij3r1b/mini-imagenet-cache-validation.pkl?dl=1'
            zenodo_file_link = 'https://zenodo.org/record/7978538/files/mini-imagenet-cache-validation.pkl'
        else:
            raise ValueError('Needs to be train, test or validation')

        pickle_file = os.path.join(self.root, 'mini-imagenet-cache-' + mode + '.pkl')
        try:
            if not self._check_exists() and download:
                print('Downloading mini-ImageNet --', mode)
                download_file(dropbox_file_link, pickle_file)
            with open(pickle_file, 'rb') as f:
                self.data = pickle.load(f)
        except Exception:
            try:
                if not self._check_exists() and download:
                    print('Downloading mini-ImageNet --', mode)
                    download_pkl(google_drive_file_id, self.root, mode)
                with open(pickle_file, 'rb') as f:
                    self.data = pickle.load(f)
            except pickle.UnpicklingError:
                if not self._check_exists() and download:
                    print('Download failed. Re-trying mini-ImageNet --', mode)
                    download_file(dropbox_file_link, pickle_file)
                with open(pickle_file, 'rb') as f:
                    self.data = pickle.load(f)

        self.x = torch.from_numpy(self.data["image_data"]).permute(0, 3, 1, 2).float()
        self.y = np.ones(len(self.x))

        # TODO Remove index_classes from here
        self.class_idx = index_classes(self.data['class_dict'].keys())
        for class_name, idxs in self.data['class_dict'].items():
            for idx in idxs:
                self.y[idx] = self.class_idx[class_name]

        从上边的代码中可以发现,learn2learn一共提供了三种下载地址,但实际上我们国内所能访问的只有最后一种zenodo,至少在我的电脑上google drive和drop box都是无法连接的。而learn2learn下载部分的代码其实没有使用zemodo这个地址,只是使用了前两种,导致运行程序时不会自动的下载数据。

        解决办法很简单,就是直接将zenodo开头的这三个地址直接复制到浏览器中,就可以将train,validation,test这三个pkl文件下载到了。将下载好的文件放到home/data路径下,那么再次运行程序的时候就不需要下载了,就可以开心的进行训练了!

你可能感兴趣的:(元学习,learn2learn)