pytorch复习笔记--python类中getitem的用法

目录

1-- 类中__getitem__的作用

2-- 实例

3-- 结合pytorch封装并读取batch数据

4-- 参考


1-- 类中__getitem__的作用

当一个python类中定义了__getitem__函数,则其实例对象能够通过下标来进行索引数据

2-- 实例

代码:

import numpy as np

# 创建类
class Example():
    def __getitem__(self, index):
        data = np.array([[1,2,3], [4,5,6], [7,8,9]])
        return data[index]

# 使用Example类实例对象example1
example1 = Example()

# 索引访问数据
print('example1[0][0]:', example1[0][0])
print('example1[0]:', example1[0])

# 切片访问数据
print('example1[0:2]:\n', example1[0:2])

输出:

example1[0][0]: 1

example1[0]: [1 2 3]

example1[0:2]:
 [[1 2 3]
 [4 5 6]]

3-- 结合pytorch封装并读取batch数据

代码:

import torch
import numpy as np
from torch.utils.data import Dataset

# 创建MyDataset类
class MyDataset(Dataset):
    def __init__(self, x, y):
        self.data = torch.from_numpy(x).float()
        self.label = torch.LongTensor(y)

    def __getitem__(self, idx):
        return self.data[idx], self.label[idx], idx

    def __len__(self):
        return len(self.data)

Train_data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
Train_label = np.array([10, 11, 12, 13])
TrainDataset = MyDataset(Train_data, Train_label) # 创建实例对象
print('len:', len(TrainDataset))

# 创建DataLoader
loader = torch.utils.data.DataLoader(
    dataset=TrainDataset,
    batch_size=2,
    shuffle=False,
    num_workers=0,
    drop_last=False)

# 按batchsize打印数据
for batch_idx, (data, label, index) in enumerate(loader):
    print('batch_idx:',batch_idx, '\ndata:',data, '\nlabel:',label, '\nindex:',index)
    print('---------')

输出:

len: 4

batch_idx: 0 
data: tensor([[1., 2., 3.],
        [4., 5., 6.]]) 
label: tensor([10, 11]) 
index: tensor([0, 1])
---------
batch_idx: 1 
data: tensor([[ 7.,  8.,  9.],
        [10., 11., 12.]]) 
label: tensor([12, 13]) 
index: tensor([2, 3])
---------

4-- 参考

参考链接1

你可能感兴趣的:(Pytorch学习笔记,深度学习,python,pytorch)