【记录一下数据和模型测试代码】

这个是模型测试代码,自己设计完模型可以用下面代码进行测试

if __name__ == '__main__':
    input = torch.randn(16, 3, 256, 256)#图片大小可以改
    net = ViTResNet([3,3,3],BasicBlock)#给定自己设计的模型的参数,如果class model(in_channel,out_channel),就可以写 model(3,1)#3可以换成自己输入数据的通道数,1换成自己输出数据的通道数
    output = net(input)
    print(output.shape)

这个是数据读取测试的代码,自己设计玩自己数据集读取的代码,一般情况就不用改了,测试代码如下:

if __name__ == "__main__":
    for i in range(1,11):
        test_dataset = ImageFolder(#imageFolder是自己编写的数据集的代码,下面是自己的参数
             root_path='/data/image',
            mask_path='/data/mask',
            class_csv='/data//label.csv',
            fold_json='/data//3D_patient_folds.json',
            fold_num=i,
            csv_file_path='/data/survival.csv',
            mode='train')

        test_dataloader = get_loader(
            dataset=test_dataset,
            batch_size=1,
            shuffle=True
        )
        num=0
        for i,j,n,m,s in test_dataloader:#定义了数据集有几个输出就for几个变量,如果顺利运行就非常可以
            num=num+1
            print(i.shape)
            print(j.shape)
            print(n.shape)
            print(m)
            print(s)
            print(num)
            continue

你可能感兴趣的:(python,深度学习,数据读取)