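My upcoming study will focus mainly on deep clustering, and since contrastive learning has been very popular in recent years, it will come up as well. The paper covered here, Graph Contrastive Clustering (GCC), appeared at ICCV 2021. As the name suggests, it combines graphs with contrastive learning. The "graph" here, however, only captures neighbor relationships between samples; it is not a graph neural network. In the code it shows up as a KNN matrix maintained through a MemoryBank, so readers with no graph neural network background can follow along without worry.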
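The overall structure of the algorithm is quite clear: a CNN backbone first extracts features, an MLP then maps them to one vector per sample, i.e. its representation, and the network is trained to convergence by optimizing the RGC and AGC losses. The "Updated" step in the figure corresponds in the code to finding each sample's 5 nearest neighbors; these neighbors are then used to compute RGC and AGC, which is the "Guided" step in the figure. The blue arrows merely indicate that RGC and AGC are pulled out to be explained separately. RGC resembles the Instance Head in the AAAI 2021 paper Contrastive Clustering, and AGC resembles its Cluster Head. The figure below shows the architecture of Contrastive Clustering, and the two are indeed very similar: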
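The experiment environment was Ubuntu, with a single GeForce RTX 3090 and 64 GB of RAM.
The GCC project code runs only on Linux, not on Windows, because it depends on faiss, an open-source library from Facebook that provides GPU-accelerated vector search and can be used to find K nearest neighbors quickly; as far as I know, its GPU build is available for Linux only. Because training involves the MemoryBank, it consumes a great deal of CPU time and memory. The project also does not support resuming from a checkpoint: it maintains a graph structure that lives largely in memory, so once training restarts, that in-memory state cannot be recovered and the run cannot be continued. On top of that, a training run typically takes more than 24 hours. All of this makes training painful.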
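For intuition about what the MemoryBank's KNN matrix holds, here is a minimal sketch of neighbor mining. It is not the project's code: brute-force NumPy stands in for the faiss search, cosine similarity is an assumption, and the names `mine_neighbors` and `features` are hypothetical.

```python
import numpy as np


def mine_neighbors(features: np.ndarray, k: int = 5) -> np.ndarray:
    """Return an (N, k) array holding the indices of each row's k nearest neighbors.

    Brute-force cosine similarity; a hypothetical stand-in for the faiss search.
    """
    # L2-normalize so that the dot product equals cosine similarity.
    norms = np.linalg.norm(features, axis=1, keepdims=True)
    normalized = features / np.clip(norms, 1e-12, None)
    similarity = normalized @ normalized.T
    # Exclude each sample from its own neighbor list.
    np.fill_diagonal(similarity, -np.inf)
    # Sort by descending similarity and keep the top-k indices.
    return np.argsort(-similarity, axis=1)[:, :k]


# Toy "memory bank": two well-separated groups of three points each.
features = np.array([
    [1.0, 0.0], [0.9, 0.1], [0.8, 0.2],
    [0.0, 1.0], [0.1, 0.9], [0.2, 0.8],
])
neighbors = mine_neighbors(features, k=2)
```

Each row of `neighbors` lists the samples that would guide the contrastive losses for that sample. The brute-force version needs an N x N similarity matrix, which is one reason a dedicated search library is used on real datasets.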
Next, let's reproduce the code. Running the original project code as-is produces a few errors, so the following fixes are needed.
class MyPath(object):
    @staticmethod
    def db_root_dir(database=''):
        db_names = {'cifar-10', 'stl-10', 'cifar-20', 'imagenet', 'imagenet_50',
                    'imagenet_100', 'imagenet_200', 'tiny_imagenet'}
        assert(database in db_names)

        if database == 'cifar-10':
            return 'gruntdata/dataset'
        elif database == 'cifar-20':
            return 'gruntdata/dataset'
        elif database == 'stl-10':
            return 'gruntdata/dataset'
        elif database == 'tiny_imagenet':
            return 'gruntdata/dataset'
        elif database in ['imagenet', 'imagenet_50', 'imagenet_100', 'imagenet_200']:
            return 'path/to/imagenet/'
        else:
            raise NotImplementedError
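gruntdata/dataset is the directory that stores the training datasets. The original code had an extra "/" on gruntdata, which prevented the path from being read correctly; removing it fixes the problem.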
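To see why a single slash matters, here is a small sketch. It assumes the extra "/" was a leading slash (the text above does not say where it sat), and `project_root` is a hypothetical location used only for illustration.

```python
import posixpath

# Hypothetical project root, used only for illustration.
project_root = "/home/user/GCC"

absolute_style = "/gruntdata/dataset"  # assumed form of the original, broken path
relative_style = "gruntdata/dataset"   # corrected path from the snippet above

# Joining with an absolute path discards the project root entirely.
broken = posixpath.join(project_root, absolute_style)
# A relative path is resolved underneath the project root.
fixed = posixpath.join(project_root, relative_style)
```

Because the corrected path is relative, it is resolved against the current working directory, so the training script should be launched from the directory that contains gruntdata.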
import torch
import numpy as np
import collections
# from torch._six import string_classes, int_classes
string_classes=str
int_classes=int
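Depending on your torch version, line 4 (the torch._six import) may raise an error. If it does, comment it out and add the last two lines shown above. If this part runs without errors, leave it as it is.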
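If you would rather not edit the file differently for each torch version, a version-tolerant variant is possible. This is a sketch of one option, not something the original project does: it tries the old import first and falls back to the built-in types when the module is missing.

```python
try:
    # Present in older PyTorch releases only.
    from torch._six import string_classes, int_classes
except ImportError:
    # Fallback for environments where torch._six (or torch itself) is unavailable.
    string_classes = str
    int_classes = int

# Both names can be used with isinstance() either way.
is_text = isinstance("cifar-10", string_classes)
is_integer = isinstance(3, int_classes)
```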
import torchvision
from PIL import Image
import numpy as np
from skimage import io
from torch.utils.data.dataset import Dataset


class ImageNetDogs(Dataset):
    base_folder = 'imagenet-dogs'
    class_names_file = 'class_names.txt'
    train_list = [
        ['ImageNetdogs.h5', '918c2871b30a85fa023e0c44e0bee87f'],
        ['ImageNetdogsAll.h5', '918c2871b30a85fa023e0c44e0bee87f'],
    ]
    splits = ('train', 'test', 'train+unlabeled')

    def __init__(self, split='train',
                 transform=None, target_transform=None, download=False):
        if split not in self.splits:
            raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
                split, ', '.join(self.splits),
            ))
        self.transform = transform
        self.target_transform = target_transform
        self.split = split  # train/test/unlabeled set

        # All images are read into memory once, up front.
        self.data, self.targets = self.__loadfile()
        print("Dataset Loaded.")

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            dict: {'image': image, 'target': index of target class, 'meta': dict}
        """
        img, target = self.data[index], self.targets[index]
        img_size = (img.shape[0], img.shape[1])
        # Convert to a 3-channel PIL image so grayscale files also work.
        img = Image.fromarray(np.uint8(img)).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        out = {'image': img, 'target': target,
               'meta': {'im_size': img_size, 'index': index, 'class_name': 'unlabeled'}}
        return out

    def __len__(self):
        return len(self.data)

    def __loadfile(self):
        datas, labels = [], []
        source_dataset = torchvision.datasets.ImageFolder(
            root='gruntdata/dataset/ImageNet-dogs/train')
        for line, target in zip(source_dataset.imgs, source_dataset.targets):
            try:
                img = io.imread(line[0])
            except Exception:
                # Skip files that cannot be read.
                # print(line[0])
                continue
            else:
                datas.append(img)
                labels.append(target)
        return datas, labels

    def extra_repr(self):
        return "Split: {split}".format(**self.__dict__)
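The paper's authors do not explain in detail how to train on the ImageNet-dog and ImageNet10 datasets, so I rewrote the ImageNetDogs class here. It only runs correctly if the ImageNet-dog15 dataset has already been placed under the gruntdata/dataset directory (the loader reads gruntdata/dataset/ImageNet-dogs/train).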
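The class above relies on torchvision.datasets.ImageFolder, which expects one sub-folder per class under the root and assigns integer labels in sorted order of the folder names. The standard-library sketch below imitates that label assignment so the expected layout is easy to check; it is not torchvision's code, and `list_classes` and the breed_* folder names are hypothetical.

```python
import os
import tempfile


def list_classes(root: str) -> dict:
    """Map each class sub-folder name to an index, in sorted order."""
    names = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return {name: idx for idx, name in enumerate(names)}


# Build a throwaway layout with placeholder class folder names.
with tempfile.TemporaryDirectory() as tmp:
    train_root = os.path.join(tmp, "ImageNet-dogs", "train")
    for class_name in ("breed_c", "breed_a", "breed_b"):
        os.makedirs(os.path.join(train_root, class_name))
    class_to_idx = list_classes(train_root)
```

Running `list_classes` on the real train folder is a quick way to confirm that the number of class folders matches the number of clusters you intend to train with.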
import torchvision
from PIL import Image
import numpy as np
from skimage import io
from torch.utils.data.dataset import Dataset


class ImageNet10(Dataset):
    """ImageNet-10 dataset, read from an image folder.

    Args:
        split (string): One of {'train', 'test'}. Stored on the instance; the
            loader below always reads the train folder.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): Unused; the dataset must already be on disk.
    """
    base_folder = 'imagenet-10'
    class_names_file = 'class_names.txt'
    train_list = [
        ['ImageNet10_112.h5', '918c2871b30a85fa023e0c44e0bee87f'],
    ]
    splits = ('train', 'test')

    def __init__(self, split='train',
                 transform=None, target_transform=None, download=False):
        if split not in self.splits:
            raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
                split, ', '.join(self.splits),
            ))
        self.transform = transform
        self.target_transform = target_transform
        self.split = split  # train/test set

        # All images are read into memory once, up front.
        self.data, self.targets = self.__loadfile()
        print("Dataset Loaded.")

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            dict: {'image': image, 'target': index of target class, 'meta': dict}
        """
        img, target = self.data[index], self.targets[index]
        img_size = (img.shape[0], img.shape[1])
        # Convert to a 3-channel PIL image so grayscale files also work.
        img = Image.fromarray(np.uint8(img)).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        out = {'image': img, 'target': target,
               'meta': {'im_size': img_size, 'index': index, 'class_name': 'unlabeled'}}
        return out

    def __len__(self):
        return len(self.data)

    def __loadfile(self):
        datas, labels = [], []
        source_dataset = torchvision.datasets.ImageFolder(
            root='gruntdata/dataset/ImageNet-10/train/')
        for line, tar in zip(source_dataset.imgs, source_dataset.targets):
            try:
                img = io.imread(line[0])
            except Exception:
                # Report and skip files that cannot be read.
                print(line[0])
                continue
            else:
                datas.append(img)
                labels.append(tar)
        return datas, labels

    def extra_repr(self):
        return "Split: {split}".format(**self.__dict__)
The code for reading ImageNet10 is essentially the same as the code for reading the ImageNet-dog dataset; only the dataset folder differs.
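Since the two classes share their loading logic, that logic could live in one helper. The sketch below shows the idea under stated assumptions: `load_samples` and `fake_reader` are hypothetical names, and the fake reader stands in for skimage.io.imread so the example runs without any image files.

```python
from typing import Callable, Iterable, List, Tuple


def load_samples(
    samples: Iterable[Tuple[str, int]],
    read_image: Callable[[str], object],
) -> Tuple[List[object], List[int]]:
    """Read every (path, label) pair, skipping files that fail to load."""
    images, labels = [], []
    for path, label in samples:
        try:
            image = read_image(path)
        except Exception:
            # Unreadable file: skip it so images and labels stay aligned.
            continue
        images.append(image)
        labels.append(label)
    return images, labels


def fake_reader(path: str) -> str:
    """Hypothetical reader that fails on one file, standing in for io.imread."""
    if path.endswith("broken.jpg"):
        raise OSError("cannot decode " + path)
    return "pixels:" + path


images, labels = load_samples(
    [("a.jpg", 0), ("broken.jpg", 1), ("c.jpg", 2)], fake_reader
)
```

In the real classes, `samples` would be the `imgs` list of an ImageFolder and `read_image` would be io.imread, with the folder path passed in as the only difference between the two datasets.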
With these changes the project code should run without trouble. If you hit any problems, feel free to discuss them in the comments.
Collection of deep clustering papers and code on GitHub
Contrastive Clustering (original paper)
Graph Contrastive Clustering (original paper)