How to Train a Neural Network on a Device with Limited RAM or GPU Memory

There are a few practical options: clear memory once per epoch and load a fresh chunk of the dataset; save the model once per epoch and reload those weights at the start of the next epoch; or call torch.cuda.empty_cache() before each input batch so cached GPU memory is freed promptly. The code below implements the first and third options; the second is sketched right after this paragraph.
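A minimal sketch of the second option, saving once per epoch and reloading the weights at the start of the next; the model net and the path checkpoint.pth are placeholder assumptions, not names from the original code:

import os

import torch

net = torch.nn.Linear(10, 2)  # hypothetical stand-in for the real network
ckpt_path = "checkpoint.pth"  # assumed checkpoint path

for epoch in range(60):
    # Reload last epoch's weights so only the current epoch's state has to
    # stay resident between epochs.
    if os.path.exists(ckpt_path):
        net.load_state_dict(torch.load(ckpt_path))
    # ... run one epoch of training here ...
    torch.save(net.state_dict(), ckpt_path)  # persist once per epoch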


# Inside the trainer's training method: rebuild the DataLoader every epoch
# so only one 256-sample chunk of the dataset is held in memory at a time.
for epoch in range(60):
    torch.cuda.empty_cache()  # release cached GPU memory before the next chunk
    if epoch:  # the epoch-0 chunk is loaded in the trainer's __init__ (see below)
        train_data = MyDataset(epoch)
        self.train_loader = DataLoader(dataset=train_data, batch_size=16,
                                       shuffle=True, num_workers=10)
        self.train_length = len(self.train_loader)
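Since the if epoch: guard skips epoch 0, the loop assumes the first chunk's loader already exists; a minimal sketch of how the trainer's __init__ might set that up (the class name Trainer is hypothetical, not from the original post):

class Trainer:
    def __init__(self):
        # Load the epoch-0 chunk up front; the loop above swaps in a new
        # chunk at every later epoch.
        train_data = MyDataset(epoch=0)
        self.train_loader = DataLoader(dataset=train_data, batch_size=16,
                                       shuffle=True, num_workers=10)
        self.train_length = len(self.train_loader)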
import os

import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader


def default_loader(path):
    # Read the image at `path` and convert it to 3-channel RGB.
    return Image.open(path).convert('RGB')


class MyDataset(Dataset):
    # Constructor with default arguments; `epoch` selects which 256-sample
    # chunk of precomputed labels to load into memory.
    def __init__(self, epoch=0, transform=None, target_transform=None, loader=default_loader):
        path_list = os.listdir(
            "/home/chenyang/PycharmProjects/openpose_pruning/openpose_openface_net/save_model_weight")

        imgs = []

        # Keep only the saved .pth label files, sorted so each epoch's
        # 256-item slice is deterministic.
        path_list = sorted(p for p in path_list if p.endswith(".pth"))
        for one_image in path_list[epoch * 256:(epoch + 1) * 256]:
            data = torch.load(
                "/home/chenyang/PycharmProjects/openpose_pruning/openpose_openface_net/save_model_weight/" + one_image)

            # imgs holds (image path, PAF label, heatmap label) tuples; the label
            # file is named "<image name>.pth", so dropping the last four
            # characters recovers the image file name under train2017.
            imgs.append(("/home/chenyang/PycharmProjects/coco2017/train2017/" + one_image[:-4],
                         data.get("pafs"), data.get("heatmaps")))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        fn, label1, label2 = self.imgs[index]
        # Use the loader defined above to read the image from disk.
        img = self.loader(fn)
        if self.transform is not None:
            img = self.transform(img)
        return img, label1, label2

    def __len__(self):
        return len(self.imgs)
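For completeness, a sketch of how the dataset could feed training with the per-batch variant of torch.cuda.empty_cache() mentioned at the top; the transform is an assumption (the original code passes none), and the forward/backward steps are left as placeholders:

import torchvision.transforms as transforms

transform = transforms.ToTensor()  # assumed transform, not from the original post
train_data = MyDataset(epoch=0, transform=transform)
train_loader = DataLoader(dataset=train_data, batch_size=16, shuffle=True, num_workers=10)

for img, pafs, heatmaps in train_loader:
    torch.cuda.empty_cache()  # clear cached GPU memory before each input batch
    img = img.cuda()
    # ... forward pass, compute loss against (pafs, heatmaps), backward, step ...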
