Dive Into MindSpore – ImageFolderDataset For Dataset LoadMindSpore精讲系列–数据集加载之ImageFolderDataset本文开发环境Ubuntu 20.04Python 3.8MindSpore 1.7.0本文内容摘要先看API简单示例深入探究本文总结遇到问题本文参考1. 先看API
下面对主要参数做简单介绍:dataset_dir – 数据集目录num_samples – 读取的样本数,通常选用默认值即可num_paraller_workers – 读取数据采用的线程数,一般为CPU线程数的1/4到1/2shuffle – 是否打乱数据集,还是按顺序读取,默认为None。这里一定要注意,默认None并非是不打乱数据集,这个参数的默认值有点让人困惑。extensions – 图片文件扩展名,可以为多个即list。如[“.JPEG”, “.png”],则读取文件夹相应扩展名的图片文件。if empty, read everything under the dir.class_indexing – 文件夹名到label的索引映射字典decode – 是否对图片数据进行解码,默认为False,即不解码num_shards – 分布式场景下使用,可以认为是GPU或NPU的卡数shard_id – 同上面参数在分布式场景下配合使用,可以认为是GPU或NPU卡的ID2. 简单示例本文使用的是Fruits 360数据集Kaggle 下载地址启智平台 下载地址) – 对于无法访问kaggle的读者,可以采用启智平台。2.1 解压数据将Fruits 360数据集下载后,会得到archive.zip文件,使用unzip -x archive.zip命令进行解压。在同级目录下得到两个文件夹fruits-360_dataset和fruits-360-original-size。使用命令tree -d -L 3 .对数据情况进行简单查看,输出内容如下:.
├── fruits-360_dataset
│ └── fruits-360
│ ├── Lemon
│ ├── papers
│ ├── Test
│ ├── test-multiple_fruits
│ └── Training
└── fruits-360-original-size
└── fruits-360-original-size
├── Meta
├── Papers
├── Test
├── Training
└── Validation
本文将使用fruits-360_dataset文件夹。2.2 最简用法下面对fruits-360_dataset文件夹下的训练集fruits-360/Training进行加载。代码如下:参考1中参数介绍,需要将shuffle参数显示设置为False,否则无法复现。from mindspore.dataset import ImageFolderDataset
def dataset_load(dataset_dir, shuffle=False, decode=False):
dataset = ImageFolderDataset(
dataset_dir=dataset_dir, shuffle=shuffle, decode=decode)
data_size = dataset.get_dataset_size()
print("data size: {}".format(data_size), flush=True)
data_iter = dataset.create_dict_iterator()
item = None
for data in data_iter:
item = data
break
# 打印数据
print(item, flush=True)
def main():
# 注意替换为个人路径
train_dataset_dir = "{your_path}/fruits-360_dataset/fruits-360/Training"
#####################
# test decode param #
#####################
dataset_load(dataset_dir=train_dataset_dir, shuffle=False, decode=False)
if name == "__main__":
main()
将以上代码保存到load.py文件,使用如下命令运行:python3 load.py
输出内容如下:数据集大小为67692,因为该文件夹下只有图片文件,也可以认为有67692个图片。数据包含两个字段:image和label。image字段在decode参数为默认值False时,不对图片解码,所以可以认为是二进制数据,且其shape为一维的。label字段已经进行了数值化转换。data size: 67692
{'image': Tensor(shape=[4773], dtype=UInt8, value= [255, 216, 255, 224, 0, 16, 74, 70, 73, 70, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 255, 219, 0, 67,
0, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 2, 2, 2, 2, 4, 3, 2, 2, 2, 2, 5, 4,
4, 3, 4, 6, 5, 6, 6, 6, 5, 6, 6, 6, 7, 9, 8, 6, 7, 9, 7, 6, 6, 8, 11, 8,
......
251, 94, 126, 219, 218, 84, 16, 178, 91, 197, 168, 248, 91, 193, 130, 70, 243, 163, 144, 177, 104, 229, 186, 224,
121, 120, 1, 92, 34, 146, 78, 229, 201, 92, 21, 175, 220, 146, 112, 51, 65, 32, 117, 52, 112, 69, 117, 66,
10, 10, 200, 241, 234, 213, 157, 105, 243, 72, 40, 162, 138, 178, 2, 138, 40, 160, 2, 138, 40, 160, 2, 138,
40, 160, 2, 138, 40, 160, 2, 138, 40, 160, 2, 138, 40, 160, 2, 138, 40, 160, 15, 255, 217]), 'label': Tensor(shape=[], dtype=Int32, value= 0)}
2.3 是否解码下面将decode参数设置为True,来看看数据情况。将如下代码dataset_load(dataset_dir=train_dataset_dir, shuffle=False, decode=False)
修改为dataset_load(dataset_dir=train_dataset_dir, shuffle=False, decode=True)
使用如下命令,重新运行load.py文件。python3 load.py
输出内容如下:数据集大小同2.2一致。数据包含两个字段:image和label。因为decode参数设置为True,已经对图片进行了解码,可以看到image字段的数据维度和数值已经有了变化。label字段同2.2。data size: 67692
{'image': Tensor(shape=[100, 100, 3], dtype=UInt8, value=
[[[254, 255, 255],
[254, 255, 255],
[254, 255, 255],
...
[255, 255, 255],
[255, 255, 255],
[255, 255, 255]]]), 'label': Tensor(shape=[], dtype=Int32, value= 0)}
- 深入探究在深入探究部分,本文来详细研究一下class_indexing参数,看看这个参数有什么意义。首先本文提出一种异常情况,即训练集内的某个类别文件夹,在验证集/测试集不存在(可能因为数据极度不平衡或人为错误)。那么数据的标签id还能否对应好。3.1 正常测试集针对测试集,我们先做一次label统计。代码如下:import json
from mindspore.dataset import ImageFolderDataset
def label_check(dataset_dir, shuffle=False, decode=False, class_indexing=None):
dataset = ImageFolderDataset(
dataset_dir=dataset_dir, shuffle=shuffle, decode=decode, class_indexing=class_indexing)
data_size = dataset.get_dataset_size()
print("data size: {}".format(data_size), flush=True)
data_iter = dataset.create_dict_iterator()
label_dict = {}
for data in data_iter:
label_id = data["label"].asnumpy().tolist()
label_dict[label_id] = label_dict.get(label_id, 0) + 1
# 打印数据
print("====== label dict ======\n{}".format(label_dict), flush=True)
def main():
# 注意替换为个人路径
test_dataset_dir = "{your_path}/fruits-360_dataset/fruits-360/Test"
label_check(dataset_dir=test_dataset_dir, shuffle=False, decode=False, class_indexing=None)
if name == "__main__":
main()
将上述代码保存到check.py文件,运行命令:python3 check.py
输出内容如下:数据集大小为22688总共标签id为131data size: 22688
====== label dict ======
{0: 164, 1: 148, 2: 160, 3: 164, 4: 161, 5: 164, 6: 152, 7: 164, 8: 164, 9: 144, 10: 166, 11: 164, 12: 219, 13: 164, 14: 143, 15: 166, 16: 166, 17: 152, 18: 166, 19: 150, 20: 154, 21: 166, 22: 164, 23: 164, 24: 166, 25: 234, 26: 164, 27: 246, 28: 246, 29: 164, 30: 164, 31: 164, 32: 153, 33: 166, 34: 166, 35: 150, 36: 154, 37: 130, 38: 156, 39: 166, 40: 156, 41: 234, 42: 99, 43: 166, 44: 328, 45: 164, 46: 166, 47: 166, 48: 164, 49: 158, 50: 166, 51: 164, 52: 166, 53: 157, 54: 166, 55: 166, 56: 156, 57: 157, 58: 166, 59: 164, 60: 166, 61: 166, 62: 166, 63: 166, 64: 166, 65: 142, 66: 102, 67: 166, 68: 246, 69: 164, 70: 164, 71: 160, 72: 218, 73: 178, 74: 150, 75: 155, 76: 146, 77: 160, 78: 164, 79: 166, 80: 164, 81: 246, 82: 164, 83: 164, 84: 232, 85: 166, 86: 234, 87: 102, 88: 166, 89: 222, 90: 237, 91: 166, 92: 166, 93: 148, 94: 234, 95: 222, 96: 222, 97: 164, 98: 164, 99: 166, 100: 163, 101: 166, 102: 151, 103: 142, 104: 304, 105: 164, 106: 153, 107: 150, 108: 151, 109: 150, 110: 150, 111: 166, 112: 164, 113: 166, 114: 164, 115: 162, 116: 164, 117: 246, 118: 166, 119: 166, 120: 246, 121: 225, 122: 246, 123: 160, 124: 164, 125: 228, 126: 127, 127: 153, 128: 158, 129: 249, 130: 157}
3.2 异常测试集为了进行测试,人为制造一些异常,将Test文件夹下的Lemon数据文件夹移动到上层目录。命令如下:cd {your_path}/fruits-360_dataset/fruits-360/Test
mv Lemon ../
3.2.1 未指定class_indexing再次运行3.1中的check.py文件,输出内容如下:数据大小为22524总共标签id为130data size: 22524
====== label dict ======
{0: 164, 1: 148, 2: 160, 3: 164, 4: 161, 5: 164, 6: 152, 7: 164, 8: 164, 9: 144, 10: 166, 11: 164, 12: 219, 13: 164, 14: 143, 15: 166, 16: 166, 17: 152, 18: 166, 19: 150, 20: 154, 21: 166, 22: 164, 23: 164, 24: 166, 25: 234, 26: 164, 27: 246, 28: 246, 29: 164, 30: 164, 31: 164, 32: 153, 33: 166, 34: 166, 35: 150, 36: 154, 37: 130, 38: 156, 39: 166, 40: 156, 41: 234, 42: 99, 43: 166, 44: 328, 45: 164, 46: 166, 47: 166, 48: 164, 49: 158, 50: 166, 51: 164, 52: 166, 53: 157, 54: 166, 55: 166, 56: 156, 57: 157, 58: 166, 59: 166, 60: 166, 61: 166, 62: 166, 63: 166, 64: 142, 65: 102, 66: 166, 67: 246, 68: 164, 69: 164, 70: 160, 71: 218, 72: 178, 73: 150, 74: 155, 75: 146, 76: 160, 77: 164, 78: 166, 79: 164, 80: 246, 81: 164, 82: 164, 83: 232, 84: 166, 85: 234, 86: 102, 87: 166, 88: 222, 89: 237, 90: 166, 91: 166, 92: 148, 93: 234, 94: 222, 95: 222, 96: 164, 97: 164, 98: 166, 99: 163, 100: 166, 101: 151, 102: 142, 103: 304, 104: 164, 105: 153, 106: 150, 107: 151, 108: 150, 109: 150, 110: 166, 111: 164, 112: 166, 113: 164, 114: 162, 115: 164, 116: 246, 117: 166, 118: 166, 119: 246, 120: 225, 121: 246, 122: 160, 123: 164, 124: 228, 125: 127, 126: 153, 127: 158, 128: 249, 129: 157}
解读:仔细观察,可以看出3.2.1中的数据标签id已经同3.1中不同,也就是说如果我们是在训练后进行测试,那么标签id已经出错,测试结果肯定相当糟糕。3.2.2 指定class_indexing备注:这里我们默认训练数据集也使用了class_indexing字典文件进行数据加载,或者加载的标签ID与我们后期生成的一致。为了能够与训练集的标签id保持一致,我们先利用训练集来生成class_indexing字典文件。生成代码如下:import json
import os
def make_class_indexing_file(dataset_dir, class_indexing_file):
class_names = []
for dir_or_file in os.listdir(dataset_dir):
if os.path.isfile(dir_or_file):
continue
class_names.append(dir_or_file)
sorted_class_names = sorted(class_names)
print("num_classes: {}\n{}".format(len(sorted_class_names), "\n".join(sorted_class_names)), flush=True)
class_indexing_dict = dict(zip(sorted_class_names, list(range(len(sorted_class_names)))))
print("class_indexing dict: {}".format(class_indexing_dict), flush=True)
with open(class_indexing_file, "w", encoding="UTF8") as fp:
json.dump(class_indexing_dict, fp, indent=4, separators=(",", ": "))
def main():
train_dataset_dir = "{your_path}/Fruits_360/fruits-360_dataset/fruits-360/Training"
class_indexing_file = "{your_path}/Fruits_360/fruits-360_dataset/class_indexing.json"
make_class_indexing_file(dataset_dir=dataset_dir, class_indexing_file=class_indexing_file)
if name == "__main__":
main()
保存代码到make_class_indexing.py文件,运行命令:python3 make_class_indexing.py
备注:生成的字典文件为{your_path}/Fruits_360/fruits-360_dataset/class_indexing.json,读者可自行更改路径。有了字典文件,再次修改check.py文件,修改为:import json
from mindspore.dataset import ImageFolderDataset
def label_check(dataset_dir, shuffle=False, decode=False, class_indexing=None):
dataset = ImageFolderDataset(
dataset_dir=dataset_dir, shuffle=shuffle, decode=decode, class_indexing=class_indexing)
data_size = dataset.get_dataset_size()
print("data size: {}".format(data_size), flush=True)
data_iter = dataset.create_dict_iterator()
label_dict = {}
for data in data_iter:
label_id = data["label"].asnumpy().tolist()
label_dict[label_id] = label_dict.get(label_id, 0) + 1
# 打印数据
print("====== label dict ======\n{}".format(label_dict), flush=True)
def load_class_indexing_file(class_indexing_file):
with open(class_indexing_file, "r", encoding="UTF8") as fp:
class_indexing_dict = json.load(fp)
print("====== class_indexing_dict: ======\n{}".format(class_indexing_dict), flush=True)
return class_indexing_dict
def main():
# 注意替换为个人路径
test_dataset_dir = "{your_path}/fruits-360_dataset/fruits-360/Test"
class_indexing_file = "{your_path}/fruits-360_dataset/class_indexing.json"
class_indexing_dict = load_class_indexing_file(class_indexing_file)
label_check(dataset_dir=test_dataset_dir, shuffle=False, decode=False, class_indexing=class_indexing_dict)
if name == "__main__":
main()
再次运行check.py文件,输出内容如下:数据大小同3.2.1中相同数据总标签id为131其中标签id为59数据为零,也就是我们上面移除的数据。data size: 22524
====== label dict ======
{0: 164, 1: 148, 2: 160, 3: 164, 4: 161, 5: 164, 6: 152, 7: 164, 8: 164, 9: 144, 10: 166, 11: 164, 12: 219, 13: 164, 14: 143, 15: 166, 16: 166, 17: 152, 18: 166, 19: 150, 20: 154, 21: 166, 22: 164, 23: 164, 24: 166, 25: 234, 26: 164, 27: 246, 28: 246, 29: 164, 30: 164, 31: 164, 32: 153, 33: 166, 34: 166, 35: 150, 36: 154, 37: 130, 38: 156, 39: 166, 40: 156, 41: 234, 42: 99, 43: 166, 44: 328, 45: 164, 46: 166, 47: 166, 48: 164, 49: 158, 50: 166, 51: 164, 52: 166, 53: 157, 54: 166, 55: 166, 56: 156, 57: 157, 58: 166, 60: 166, 61: 166, 62: 166, 63: 166, 64: 166, 65: 142, 66: 102, 67: 166, 68: 246, 69: 164, 70: 164, 71: 160, 72: 218, 73: 178, 74: 150, 75: 155, 76: 146, 77: 160, 78: 164, 79: 166, 80: 164, 81: 246, 82: 164, 83: 164, 84: 232, 85: 166, 86: 234, 87: 102, 88: 166, 89: 222, 90: 237, 91: 166, 92: 166, 93: 148, 94: 234, 95: 222, 96: 222, 97: 164, 98: 164, 99: 166, 100: 163, 101: 166, 102: 151, 103: 142, 104: 304, 105: 164, 106: 153, 107: 150, 108: 151, 109: 150, 110: 150, 111: 166, 112: 164, 113: 166, 114: 164, 115: 162, 116: 164, 117: 246, 118: 166, 119: 166, 120: 246, 121: 225, 122: 246, 123: 160, 124: 164, 125: 228, 126: 127, 127: 153, 128: 158, 129: 249, 130: 157}
- 本文总结本文主要讲解了MindSpore中的ImageFolderDataset数据集接口,并对其中的两个参数decode和class_indexing进行了深入探究。一个小建议:笔者建议用户在使用ImageFolderDataset进行数据集加载时,人为指定class_indexing参数。毕竟相关字典文件的生成并没有几行代码,但对于类别数不一致的预训练模型(比如ImageNet22k和1k)或测试集出现人为问题的情况,可以有更好的保留空间。5. 遇到问题shuffle参数默认为None,却是对数据集进行了打乱,有点让人费解。6. 本文参考官方文档本文为原创文章,版权归作者所有,未经授权不得转载!