The Compose class in transforms

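import torchvision
from torch.utils import data
from torchvision import transforms

def load_data(batch_size, resize=None):
    # ToTensor converts a PIL image to a float tensor scaled to [0, 1]
    trans = [transforms.ToTensor()]
    if resize:
        # Insert Resize at the front so it runs before ToTensor
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=False)  # expects data already in ../data
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=False)
    # Shuffle the training set each epoch; keep the test set in a fixed order
    return (data.DataLoader(mnist_train, batch_size, shuffle=True),
            data.DataLoader(mnist_test, batch_size, shuffle=False))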

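In short, Compose chains several transforms into a single callable, much like nn.Sequential(). The transforms are applied in list order, and each one's output becomes the next one's input.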
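The following minimal pure-Python sketch makes the nn.Sequential analogy concrete. It shows how a Compose-style wrapper applies its functions one after another. MiniCompose, double and add_one are hypothetical names used only for illustration and are not part of torchvision. The sketch only mirrors the in-order application that transforms.Compose performs.

```python
class MiniCompose:
    """Hypothetical stand-in illustrating the idea behind transforms.Compose."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, x):
        # Apply each transform in order; each output feeds the next one
        for t in self.transforms:
            x = t(x)
        return x


def double(x):
    return x * 2


def add_one(x):
    return x + 1


pipeline = MiniCompose([double, add_one])
print(pipeline(3))  # 7, computed as (3 * 2) + 1

reversed_pipeline = MiniCompose([add_one, double])
print(reversed_pipeline(3))  # 8, computed as (3 + 1) * 2
```

The two results differ because order matters. This is also why load_data inserts Resize at index 0 rather than appending it after ToTensor.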
