通过图片数据集生成标签csv文件

通过图片数据集生成labels的csv文件

导入包

import os
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
from torchvision.io import read_image
import torch
import torchvision
import numpy as np
from PIL import Image
import pandas as pd

创建一个三维度的数组保存csv文件的信息

三个维度分别是文件路径,文件名,类别标签(数值化)

info_array = []#3维的数组,三个维度分别是文件路径,文件名,类别标签(数值化)

os.listdir()遍历文件夹

dataset_dir = '/Users/logic/Documents/多模态/数据集/Caltech-101/RandomizedCaltech101'
#读取文件路径下的多个类别(每个类别一个文件夹)
classes = os.listdir(dataset_dir)
print("image classes length:",len(classes))
for kindname in classes:
    #获取每个类别文件夹的路径
    if(kindname.startswith('.')):
        print("pass .DStore file")
    else:
        classpath = dataset_dir + '/' + kindname
        for filename in os.listdir(classpath):
            #读取每一个类的文件夹中的每一个图片文件的路径信息
            filepath = classpath+'/'+filename
            label = classes.index(kindname)#把label的字符串标签转化成数字标签
            info_array.append([filename,filepath,label])


转换成numpy数组,查看大小维度和内容

info_array = np.array(info_array)
info_array.shape
print(info_array)

输出如下:
通过图片数据集生成标签csv文件_第1张图片
8677张图片,含有3维信息
三个维度分别是文件路径,文件名,类别标签(数值化)
在这里插入图片描述
生成csv文件:

df=pd.DataFrame(info_array,columns=col)
df.to_csv("../Desktop/dataset.csv",encoding='UTF-8')

效果如下:

完整代码黏贴在最后:

import os
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
from torchvision.io import read_image
import torch
import torchvision
import numpy as np
from PIL import Image
import pandas as pd

 
info_array = []#3维的数组,三个维度分别是文件路径,文件名,类别标签(数值化)
col =['filename','filepath','label']

dataset_dir = '/Users/logic/Documents/多模态/数据集/Caltech-101/RandomizedCaltech101'
#读取文件路径下的多个类别(每个类别一个文件夹)
classes = os.listdir(dataset_dir)
print("image classes length:",len(classes))
for kindname in classes:
    #获取每个类别文件夹的路径
    if(kindname.startswith('.')):
        print("pass .DStore file")
    else:
        classpath = dataset_dir + '/' + kindname
        for filename in os.listdir(classpath):
            #读取每一个类的文件夹中的每一个图片文件的路径信息
            filepath = classpath+'/'+filename
            label = classes.index(kindname)#把label的字符串标签转化成数字标签
            info_array.append([filename,filepath,label])

info_array = np.array(info_array)
info_array.shape
# print(info_array)

df=pd.DataFrame(info_array,columns=col)
df.to_csv("../Desktop/dataset.csv",encoding='UTF-8')

你可能感兴趣的:(python,python,深度学习,pytorch)