我们在看别人写的代码时,在类中经常会看到__getitem__方法,这个方法的作用是,可以将类中的数据 像数组一样读出,以下进行代码演示:
在类中创建__getitem__方法,并使用数组形式读取类中的数据:
class Test():
def __init__(self):
self.a=[1,2,3,4,5]
def __getitem__(self,idx):
return(self.a[idx])
data=Test()
print(data)
print(data[0])
输出结果为:
<__main__.Test object at 0x000002BBB3256DF0>
1
在类中不创建__getitem__方法,并使用数组形式读取类中的数据:
class Test():
def __init__(self):
self.a=[1,2,3,4,5]
# def __getitem__(self,idx):
# return(self.a[idx])
data=Test()
print(data)
print(data[0])
输出结果报错:
<__main__.Test object at 0x000002BBB45C2C70>
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Input In [20], in ()
6 data=Test()
7 print(data)
----> 8 print(data[0])
TypeError: 'Test' object is not subscriptable |
在类中不创建__getitem__方法,并不使用数组形式读取类中的数据:
class Test():
def __init__(self):
self.a=[1,2,3,4,5]
# def __getitem__(self,idx):
# return(self.a[idx])
data=Test()
print(data)
输出结果:
<__main__.Test object at 0x000002BBB43AFC70>
总结:
类中的__getitem__方法是为了将类中的数据可以用数组的形式读出,如果不使用数组的方法读类中的数据,那么就不需要在类中创建__getitem__方法
如何将__getitem__与dataloader结合使用
import torch
import numpy as np
from torch.utils.data import Dataset
# 创建MyDataset类
class MyDataset(Dataset):
def __init__(self, x, y):
self.data = torch.from_numpy(x).float()
self.label = torch.LongTensor(y)
def __getitem__(self, idx):
return self.data[idx], self.label[idx], idx
def __len__(self):
return len(self.data)
Train_data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
Train_label = np.array([10, 11, 12, 13])
TrainDataset = MyDataset(Train_data, Train_label) # 创建实例对象
print('len:', len(TrainDataset))
# 创建DataLoader
loader = torch.utils.data.DataLoader(
dataset=TrainDataset,
batch_size=2,
shuffle=False,
num_workers=0,
drop_last=False)
# 按batchsize打印数据
for batch_idx, (data, label, index) in enumerate(loader):
print('batch_idx:',batch_idx, '\ndata:',data, '\nlabel:',label, '\nindex:',index)
print('---------')
输出结果:
len: 4
batch_idx: 0
data: tensor([[1., 2., 3.],
[4., 5., 6.]])
label: tensor([10, 11])
index: tensor([0, 1])
---------
batch_idx: 1
data: tensor([[ 7., 8., 9.],
[10., 11., 12.]])
label: tensor([12, 13])
index: tensor([2, 3])
---------
https://blog.csdn.net/weixin_43863869/article/details/125602643