collate_fn in PyTorch

The DataLoader pulls batch_size samples (each a (feature, label) pair) out of the dataset and hands them to collate_fn as a list. So the outermost level of x is the batch, each x[i] is one (feature, label) pair, and len(x[0]) is 2.
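Conceptually, for a TensorDataset the DataLoader just gathers individual samples into a plain Python list before calling collate_fn. A tiny self-contained sketch of that structure (the names ds and samples are only illustrative, not part of the DataLoader API):

import torch
import torch.utils.data as Data

ds = Data.TensorDataset(torch.arange(12).view(4, 3), torch.arange(4))
samples = [ds[i] for i in range(3)]   # what the DataLoader gathers for one batch
print(len(samples), len(samples[0]))  # 3 samples, each a (feature, label) pair -> prints "3 2"

The fuller test below prints the intermediate steps inside collate_fn itself: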

import torch.utils.data as Data
import numpy as np
import torch

test = np.array([0,1,2,3,4,5,6,7,8,9,10,11])
inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))
#inputing shape is [10,3]
target = torch.tensor(np.array([test[i:i + 1] for i in range(10)]))
#target shape is [10,1]
torch_dataset = Data.TensorDataset(inputing,target)
batch = 3
def collate_fn(x):
    # The outermost level of x is the batch: x is a list of `batch` samples.
    # The next level is (feature, label), so len(x[0]) is 2.
    # x[i][j].unsqueeze(0) adds a leading dimension so torch.cat can join the
    # samples; torch.stack would also work (stack adds a dimension, cat does not).
    for j in range(len(x[0])):
        print('tmp1={}'.format([x[i][j].unsqueeze(0) for i in range(len(x))]))
        print('tmp2={}'.format(torch.cat([x[i][j].unsqueeze(0) for i in range(len(x))])))
    # This version only prints the intermediate tensors for inspection and returns
    # the batch unchanged; the DataLoader below uses an equivalent lambda instead.
    return x

loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=batch,
    collate_fn=lambda x: (
        torch.cat([x[i][j].unsqueeze(0) for i in range(len(x))])
        for j in range(len(x[0]))
    )
)
for features,label in loader:
    # print('features={}'.format(features))
    print('features={} label={}'.format(features,label))
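With batch = 3 and no shuffling, the first iteration should print roughly the following (the exact tensor formatting and line breaks may differ, and the last, incomplete batch contains only the single remaining sample):

features=tensor([[0, 1, 2], [1, 2, 3], [2, 3, 4]]) label=tensor([[0], [1], [2]])

The subsequent batches continue the same pattern with rows 3-5, 6-8, and 9.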

In other words, the default collate_fn is equivalent to the explicit lambda above: the two loaders below yield the same batch tensors (the default returns them in a list rather than a generator, but both unpack the same way in the for loop):

loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=batch
)
loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=batch,
    collate_fn=lambda x: (
        torch.cat([x[i][j].unsqueeze(0) for i in range(len(x))])
        for j in range(len(x[0]))
    )
)
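Since torch.stack adds a new dimension while torch.cat does not, the same batching can also be written with torch.stack and no unsqueeze. A minimal sketch of such a collate_fn (the name stack_collate and the zip(*batch) formulation are just one way to write it, not PyTorch's actual default implementation):

def stack_collate(batch):
    # batch is a list of (feature, label) tuples; zip(*batch) regroups it into
    # (all features), (all labels); torch.stack then adds the new batch dimension.
    return [torch.stack(group) for group in zip(*batch)]

loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=batch,
    collate_fn=stack_collate,
)

For this dataset it yields the same features/label tensors as both loaders above.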
