Reproducing the GCC (Graph Contrastive Clustering) Paper Code

  • Preface
  • I. Graph Contrastive Clustering
  • II. Code Reproduction
    • 1. Notes
    • 2. utils/mypath.py
    • 3. utils/collate.py
    • 4. data/datasets_imagenet_dogs.py
    • 5. data/datasets_imagenet10.py
  • Summary
  • References


Preface

My future studies will mainly revolve around deep clustering, and since contrastive learning has been very popular in recent years, it will come up as well. Graph Contrastive Clustering, the subject of this post, was published at ICCV 2021 and, as the name suggests, combines graphs with contrastive learning. The "graph", however, only refers to neighborhood relations between sample points rather than a graph neural network; in the code it appears as a KNN matrix maintained with a MemoryBank, so readers without any graph-neural-network background can follow along without worry.

I. Graph Contrastive Clustering

[Figure 1: overall architecture of GCC]
The overall structure of the algorithm is quite clear: a CNN backbone first extracts features, an MLP then maps each sample to a vector, i.e. its representation, and the network is trained to convergence by optimizing the RGC and AGC losses. The "Updated" step in the figure corresponds, in the code, to finding each sample's 5 nearest neighbors; those neighbors are then used to compute RGC and AGC, which is the "Guided" part of the figure. The blue arrows merely pull RGC and AGC out of the pipeline to explain them separately: RGC resembles the instance head of Contrastive Clustering (AAAI 2021), and AGC resembles its cluster head. The figure below shows the structure of Contrastive Clustering, and the resemblance is indeed striking:
[Figure 2: architecture of Contrastive Clustering]
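To make the neighbor-discovery ("Updated"/"Guided") step more concrete, here is a minimal sketch of how a memory bank of L2-normalized features can be queried for each sample's 5 nearest neighbors with plain PyTorch. This is my own illustration rather than the repository's code, and the names (bank, topk_neighbors) are made up:

import torch
import torch.nn.functional as F

# Toy memory bank: one L2-normalized feature vector per training sample.
num_samples, feat_dim, k = 1000, 128, 5
bank = F.normalize(torch.randn(num_samples, feat_dim), dim=1)

# Cosine similarity between all pairs (dot product of unit vectors).
sim = bank @ bank.t()
sim.fill_diagonal_(-1.0)                  # exclude each sample itself
_, topk_neighbors = sim.topk(k, dim=1)    # (num_samples, 5) neighbor indices

print(topk_neighbors.shape)               # torch.Size([1000, 5])

The repository itself performs this nearest-neighbor search with faiss (see the notes in the next section), which scales better than the dense similarity matrix used above.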

II. Code Reproduction

1. Notes

My experiments were run on Ubuntu with a single GeForce RTX 3090 and 64 GB of RAM.
The GCC project only runs on Linux, not on Windows, because it depends on faiss, Facebook's open-source similarity-search library, whose GPU-accelerated index (used here to find K nearest neighbors quickly) is only readily available on Linux. Training relies on a MemoryBank and therefore consumes a lot of CPU time and memory, and the project does not support resuming from a checkpoint: the graph structure it maintains lives largely in memory, so once a run is interrupted that state cannot be recovered and training has to start over. Since a run usually takes more than 24 hours, all of this makes training rather painful.
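For reference, below is a minimal sketch of the kind of K-nearest-neighbor search faiss is used for; the feature matrix here is random data and the snippet is only an illustration, not the repository's indexing code:

import faiss
import numpy as np

# Toy feature matrix: one row per sample, L2-normalized so that
# inner product equals cosine similarity.
features = np.random.rand(10000, 128).astype('float32')
features /= np.linalg.norm(features, axis=1, keepdims=True)

index = faiss.IndexFlatIP(features.shape[1])   # exact inner-product index
index.add(features)                            # store all vectors
# Query every sample against the index; ask for k+1 hits because the
# closest match of each sample is the sample itself.
_, indices = index.search(features, 5 + 1)
neighbors = indices[:, 1:]                     # drop the self-match, keep 5 neighbors

With a GPU build of faiss the same index can also be moved onto the GPU (faiss.index_cpu_to_gpu), which is where the GPU acceleration mentioned above comes from.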

Now let's start the reproduction. Running the original repo as-is throws a few errors, so the following fixes are needed.

2. utils/mypath.py

class MyPath(object):
    @staticmethod
    def db_root_dir(database=''):
        db_names = {'cifar-10', 'stl-10', 'cifar-20', 'imagenet', 'imagenet_50', 'imagenet_100', 'imagenet_200', 'tiny_imagenet'}
        assert(database in db_names)

        if database == 'cifar-10':
            return 'gruntdata/dataset'

        elif database == 'cifar-20':
            return 'gruntdata/dataset'

        elif database == 'stl-10':
            return 'gruntdata/dataset'

        elif database == 'tiny_imagenet':
            return 'gruntdata/dataset'

        elif database in ['imagenet', 'imagenet_50', 'imagenet_100', 'imagenet_200']:
            return 'path/to/imagenet/'

        else:
            raise NotImplementedError

gruntdata/dataset is the directory that stores the training datasets. In the original code there was an extra "/" in front of gruntdata, which broke path resolution; removing it fixes the problem.
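As a quick sanity check (my own snippet, not part of the repo), you can confirm that the corrected value resolves relative to the project directory rather than the filesystem root:

import os
from utils.mypath import MyPath

root = MyPath.db_root_dir('cifar-10')
print(root)                   # gruntdata/dataset  (relative to the working directory)
print(os.path.abspath(root))  # e.g. <project dir>/gruntdata/dataset
print(os.path.isdir(root))    # should be True once the dataset is in place
# With the extra '/' the value would have resolved to a directory under the
# filesystem root, which does not exist, hence the path errors.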

3. utils/collate.py

import torch
import numpy as np
import collections
# from torch._six import string_classes, int_classes
string_classes=str
int_classes=int

Depending on your torch version, line 4 of this snippet (the torch._six import) may fail, because torch._six was removed in newer releases. If it does, simply comment it out and add the last two lines as shown; if you don't get the error, you can leave this file alone.
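If you want the same file to work on both old and new torch versions without manual editing, a try/except around the import is a slightly more robust variant of this fix (my suggestion, not the repo's code):

import torch
import numpy as np
import collections

try:
    # older torch versions still provide these aliases
    from torch._six import string_classes, int_classes
except ImportError:
    # torch._six (or these names) no longer exists; the builtins behave the same here
    string_classes = str
    int_classes = int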

4. data/datasets_imagenet_dogs.py

import torchvision
from PIL import Image
import numpy as np
from skimage import io
from torch.utils.data.dataset import Dataset


class ImageNetDogs(Dataset):
    base_folder = 'imagenet-dogs'
    class_names_file = 'class_names.txt'
    train_list = [
        ['ImageNetdogs.h5', '918c2871b30a85fa023e0c44e0bee87f'],
        ['ImageNetdogsAll.h5', '918c2871b30a85fa023e0c44e0bee87f'],
    ]

    splits = ('train', 'test', 'train+unlabeled')

    def __init__(self, split='train',
                 transform=None, target_transform=None, download=False):
        if split not in self.splits:
            raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
                split, ', '.join(self.splits),
            ))

        self.transform = transform
        self.target_transform = target_transform
        self.split = split  # train/test/unlabeled set
        
        self.data, self.targets = self.__loadfile()
        print("Dataset Loaded.")

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            dict: {'image': image, 'target': index of target class, 'meta': dict}
        """
        img, target = self.data[index], self.targets[index]
        img_size = (img.shape[0], img.shape[1])
        img = Image.fromarray(np.uint8(img)).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        out = {'image': img, 'target': target, 'meta': {'im_size': img_size, 'index': index, 'class_name': 'unlabeled'}}

        return out

    def __len__(self):
        return len(self.data)

    def __loadfile(self):
        datas,labels = [],[]
        source_dataset = torchvision.datasets.ImageFolder(root='gruntdata/dataset/ImageNet-dogs/train')

        for line, target in zip(source_dataset.imgs, source_dataset.targets):
            try:
                img = io.imread(line[0])
            except Exception:
                # skip images that cannot be read
                continue
            else:
                datas.append(img)
                labels.append(target)

        return datas, labels

    def extra_repr(self):
        return "Split: {split}".format(**self.__dict__)

The original authors did not explain in detail how to train on the ImageNet-Dogs and ImageNet-10 datasets, so I rewrote the ImageNetDogs class above. It runs correctly provided the ImageNet-Dogs (15 classes) dataset has already been placed under the gruntdata/dataset directory.
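A short usage sketch of the rewritten class, assuming you run it from the project root so that data.datasets_imagenet_dogs is importable and the images sit under gruntdata/dataset/ImageNet-dogs/train; the Resize size here is just an illustrative choice:

import torchvision.transforms as transforms
from data.datasets_imagenet_dogs import ImageNetDogs

transform = transforms.Compose([
    transforms.Resize((96, 96)),
    transforms.ToTensor(),
])

dataset = ImageNetDogs(split='train', transform=transform)
sample = dataset[0]
print(len(dataset))             # number of images that loaded successfully
print(sample['image'].shape)    # torch.Size([3, 96, 96])
print(sample['target'], sample['meta']['index'])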

5. data/datasets_imagenet10.py

import torchvision
from PIL import Image
import os
import os.path
import numpy as np
from skimage import io, color

import torchvision.datasets as datasets

from torch.utils.data.dataset import Dataset

class ImageNet10(Dataset):
    """`ImageNet10 `_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``stl10_binary`` exists.
        split (string): One of {'train', 'test', 'unlabeled', 'train+unlabeled'}.
            Accordingly dataset is selected.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    """
    base_folder = 'imagenet-10'
    class_names_file = 'class_names.txt'
    train_list = [
        ['ImageNet10_112.h5', '918c2871b30a85fa023e0c44e0bee87f'],
    ]

    splits = ('train', 'test')

    def __init__(self, split='train',
                 transform=None, target_transform=None, download=False):
        if split not in self.splits:
            raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
                split, ', '.join(self.splits),
            ))

        self.transform = transform
        self.target_transform = target_transform
        self.split = split  # train/test/unlabeled set

        self.data, self.targets = self.__loadfile()
        print("Dataset Loaded.")

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            dict: {'image': image, 'target': index of target class, 'meta': dict}
        """
        img, target = self.data[index], self.targets[index]
        img_size = (img.shape[0], img.shape[1])
        img = Image.fromarray(np.uint8(img)).convert('RGB')
        # class_name = self.classes[target]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        out = {'image': img, 'target': target, 'meta': {'im_size': img_size, 'index': index, 'class_name': 'unlabeled'}}

        return out

    def __len__(self):
        return len(self.data)

    def __loadfile(self):
        datas,labels = [],[]
        source_dataset = torchvision.datasets.ImageFolder(root='gruntdata/dataset/ImageNet-10/train/')

        for line, tar in zip(source_dataset.imgs, source_dataset.targets):
            try:
                img = io.imread(line[0])
                # img = color.gray2rgb(img)
            except Exception:
                # report and skip images that cannot be read
                print(line[0])
                continue
            else:
                datas.append(img)
                labels.append(tar)

        return datas, labels

    def extra_repr(self):
        return "Split: {split}".format(**self.__dict__)

The code that loads ImageNet-10 is exactly the same as the ImageNet-Dogs loader, apart from the hard-coded dataset root.
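Because the two classes differ only in their hard-coded dataset root, the shared logic could optionally be factored into a small base class. The sketch below is my own simplification (it drops the unused split/target_transform/download arguments) and is not part of the original repository:

import numpy as np
import torchvision
from PIL import Image
from skimage import io
from torch.utils.data.dataset import Dataset


class FolderClusteringDataset(Dataset):
    """Shared loader: reads an ImageFolder tree and returns GCC-style sample dicts."""
    root = None  # subclasses set the ImageFolder directory

    def __init__(self, transform=None):
        self.transform = transform
        self.data, self.targets = [], []
        source = torchvision.datasets.ImageFolder(root=self.root)
        for (path, _), target in zip(source.imgs, source.targets):
            try:
                self.data.append(io.imread(path))
                self.targets.append(target)
            except Exception:
                continue  # skip unreadable images, as the original loaders do

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        size = (img.shape[0], img.shape[1])
        img = Image.fromarray(np.uint8(img)).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return {'image': img, 'target': target,
                'meta': {'im_size': size, 'index': index, 'class_name': 'unlabeled'}}


class ImageNet10(FolderClusteringDataset):
    root = 'gruntdata/dataset/ImageNet-10/train/'


class ImageNetDogs(FolderClusteringDataset):
    root = 'gruntdata/dataset/ImageNet-dogs/train'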


Summary

If nothing unexpected comes up, you should now be able to get the project running smoothly. If you hit problems, feel free to discuss them in the comments.

References

GitHub collection of deep-clustering papers and code
Contrastive Clustering (original paper)
Graph Contrastive Clustering (original paper)
