目录
函数-dir()、help()
Dataset类
dir()
函数,打开工具箱(例如PyTorch,进一步打开某一些分隔区)
help()
函数,查看工具包中某一个工具函数的用法
(1) 查看torch工具包有哪些分割区
dir(torch)
# ['AVG', 'AggregationType', 'AnyType', 'Argument', 'ArgumentSpec', 'BFloat16Storage', 'BFloat16Tensor',...]
(2) 查看torch.cuda有哪些分隔区
dir(torch.cuda)
# ['Any', 'BFloat16Storage', 'BFloat16Tensor', 'BoolStorage', 'BoolTensor', 'ByteStorage', ...]
(3) 查看torch.cuda.is_available()有哪些分隔区
dir(torch.cuda.is_available()) # 函数后面的()去掉,效果一样
# ['__abs__', '__add__', '__and__', '__bool__', '__ceil__', '__class__', ...]
此时发现前后都是带有两个下划线的:__
这说明是规定好不可更改的,也就说明是torch.cuda.is_available
不再是一个分隔区而是一个函数,因此可调用help()
来查看该函数的基本作用。
help(torch.cuda.is_available) # 注意这后面不能跟有()
# 打印结果,该函数会返回一个bool值
# Help on function is_available in module torch.cuda:
# is_available() -> bool
# Returns a bool indicating if CUDA is currently available.
手写加载数据集的类:MyData。主要是要重写__init__()、__getitem__()、__len()__这3个类
get到一个小技巧,可以直接用+对两个Data类进行拼接(可用于数据集不足时,直接将两个数据集这样加起来一起使用)
new_path = os.path.join(path1,path2,...)将所有路径联合起来,返回一个整合路径(str)
file_name_list = os.listdir(path)读取path路径中的所有文件名称,返回一个名称列表(list)
from torch.utils.data import Dataset
from PIL import Image
import os
# 构造一个子文件夹数据集类MyData
class MyData(Dataset):
def __init__(self, root_dir, label_dir): # root_dir是指整个数据集的根目录,label_dir是指具体某一个类的子目录
# 在init初始化函数中,定义一些类中的全局变量,即跟在self.后的变量们
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir, self.label_dir)
self.img_list = os.listdir(self.path)
def __getitem__(self, index): # 传入下标获取元素
img_name = self.img_list[index]
img_item_path = os.path.join(self.path, img_name)
img = Image.open(img_item_path)
label = self.label_dir
return img, label[:-6] # 返回的是一个元组
# 这里进行了截取,因为我不想要label_dir最后面的'_image'这6个元素
def __len__(self):
return len(self.img_list)
# --------------实例化ants_data和bees_data------------- #
root_dir = 'hymenoptera_data/train'
ants_dir = 'ants'
bees_dir = 'bees'
ants_data = Mydata(root_dir, ants_dir)
bees_data = Mydata(root_dir, bees_dir)
# ---------------------------------------------------- #
# -------------返回一个元组,分别赋值给img和label------- #
img, label = ants_data[0]
img.show()
# ----------------------------------------------------- #
# ---因为是元组,所以可用[0]、[1]直接提取出img、label---- #
print(label == ants_data[0][1]) # true
# ----------------------------------------------------- #
# ----------将ants_data和bees_data相加起来使用---------- #
sum = ants_data + bees_data
len_ants = len(ants_data) # 124
len_bees = len(bees_data) # 121
len_sum = len(sum) # 245
print(len_sum == len_ants+len_bees) # True
print(sum[123][1]) # ants
print(sum[124][1]) # bees