记录一下学习深度学习的一些。本篇简述如何使用 cleanlab 清洗分类数据
所使用环境:
首先将上篇所说的猫狗大战的训练集,按猫狗分为0/1两个目录,父目录命名为 train,接着将两类图片分类,按名字检索一下就可以分开。
因为图片的大小不一,所以在进行交叉验证训练的时候,将图片缩放至统一尺寸 224 × 224 224 \times 224 224×224,方便训练。
train_dataset = datasets.ImageFolder(
traindir,
transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]),
)
val_dataset = datasets.ImageFolder(
valdir,
transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]),
)
psx
python train_crossval.py -a efficientnet-b6 -o output -j 4 -b 20 --epochs 10 --lr 0.001 --pretrained --world-size 1 --rank 0 --dist-url tcp://localhost:8001 --multiprocessing-distributed --cvn 5 --cv 0 --num-classes 2 data
修改参数 cv
,直至 k 折交叉验证均完成,相应的 out-of-sample 预测概率 P ^ k , i \hat{P}_{k,i} P^k,i 的 npy
文件会生成至 output
参数所指定的目录。这里有一点不同的是,官方提供的实例是使用 val 目录下的数据集来调优的,而我用交叉验证的 holdout 部分,所以效果可能会有所下降。
npy
文件,生成完整的 psx
python train_crossval.py -a efficientnet-b6 --cvn 5 -o output --num-classes 11 --combine-folds data
python get_error_labels.py --psx output/train__model_efficientnet-b6__pyx.npy -db output/test.db -t test data/train
Overall accuracy: 98.12%
Total: 101
可以看到通过 cleanlab,25,000 张训练图像中有 101 的错误标注,通过图来验证一下。
原始标签为猫的,cleanlab 预测为狗的:
nrows = math.ceil(num_cats_dogs / ncols)
plt.figure(figsize=(24,16))
for i, (name, given, p_given, pred, p_pred) in enumerate(cats_dogs, start=1):
path = os.path.join('data/train', '0', name)
img = Image.open(path).convert("RGB")
plt.subplot(nrows, ncols, i)
plt.imshow(img)
plt.axis('off')
plt.subplots_adjust(top=1,bottom=0,left=0,right=1,hspace=0,wspace=0)
plt.show()
原始标签为狗,cleanlab 预测为猫的:
nrows = math.ceil(num_dogs_cats / ncols)
plt.figure(figsize=(24,16))
for i, (name, given, p_given, pred, p_pred) in enumerate(dogs_cats, start=1):
path = os.path.join('data/train', '1', name)
img = Image.open(path).convert("RGB")
plt.subplot(nrows, ncols, i)
plt.imshow(img)
plt.axis('off')
plt.subplots_adjust(top=1,bottom=0,left=0,right=1,hspace=0,wspace=0)
plt.show()
可以看到上述结果中,有部分被标注错误了的,还有一部分的卡通图及其他非猫/狗图,还有一部分同时出现两种的,当然也还有有一部分出错了,但还是具有一定的参考意义。
将使用 cleanlab 清洗数据的流程写入了 sh,可以通过修改 sh 的相关变量,再调用 sh 来完成数据清洗。其中,train.sh 为单标签分类的清洗流程,train_m.sh 为多标签的清洗流程。