MedMNIST:医学领域中的MNIST数据集

MedMNIST由上海交通大学(倪冰冰团队)提供,共包含十个医学图像分类数据集(图像分辨率均为28*28)。由于我对眼底图片相对熟悉一些,所以先看了一下其中眼底图片数据集的情况。

首先是可视化这28*28的Diabetic Retinopathy(DR)图片

  • 数据来源是ISBI2020 challenge(The 2nd diabetic retinopathy – grading and image quality estimation challenge)
  • 名称:DeepDR Diabetic Retinopathy Image Dataset (DeepDRiD)
  • 数据集下载链接为https://isbi.deepdr.org/data.html

原始图片大小为1736*1824,而MedMNIST中的图片只有28*28,分辨率之低可想而知。
MedMNIST:医学领域中的MNIST数据集_第1张图片
code:

from medmnist import environ
import os
import numpy as np
from torch.utils.data import Dataset
from PIL import Image


class MedMNIST(Dataset):
    """Map-style dataset over one split of a MedMNIST ``.npz`` archive.

    The archive is read from ``<environ.dataroot>/<flag>.npz`` and must
    contain ``<split>_images`` and ``<split>_labels`` arrays. Subclasses pick
    a different archive by overriding ``flag``.

    Args:
        split: Split to load; one of ``'train'``, ``'val'`` or ``'test'``.
        transform: Optional callable applied to each PIL image.
        target_transform: Optional callable applied to each label.

    Raises:
        ValueError: If ``split`` is not one of the supported split names.
    """

    # Base name of the .npz archive to load.
    flag = "retinamnist"

    # Split names for which the archive is expected to provide arrays.
    _SPLITS = ("train", "val", "test")

    def __init__(self, split='train', transform=None, target_transform=None):
        # Fail fast, before any disk access: an unknown split used to leave
        # ``img``/``label`` unset and only blew up later with AttributeError.
        if split not in self._SPLITS:
            raise ValueError(
                "split must be one of {}, got {!r}".format(self._SPLITS, split)
            )

        self.split = split
        self.transform = transform
        self.target_transform = target_transform

        archive_path = os.path.join(environ.dataroot, "{}.npz".format(self.flag))
        # The object returned by np.load for .npz files keeps the archive
        # open; read both arrays inside the ``with`` so the handle is closed.
        with np.load(archive_path) as npz_file:
            self.img = npz_file['{}_images'.format(split)]
            self.label = npz_file['{}_labels'.format(split)]

    def __getitem__(self, index):
        """Return ``(image, target)`` for the sample at ``index``.

        The image is converted to a PIL image before ``transform`` is applied.
        A label holding a single value is returned as a Python ``int``; a
        label row with several values is returned as an integer array,
        because ``int()`` cannot convert it.
        """
        img = Image.fromarray(np.uint8(self.img[index]))

        label = np.asarray(self.label[index])
        # ``.item()`` avoids the deprecated implicit conversion of an
        # ndim > 0 array to a scalar that ``int(array)`` relies on.
        target = int(label.item()) if label.size == 1 else label.astype(int)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return self.img.shape[0]


class PathMNIST(MedMNIST):
    """MedMNIST variant whose ``flag`` selects the ``pathmnist.npz`` archive."""
    flag = "pathmnist"


class OCTMNIST(MedMNIST):
    """MedMNIST variant whose ``flag`` selects the ``octmnist.npz`` archive."""
    flag = "octmnist"


class PneumoniaMNIST(MedMNIST):
    """MedMNIST variant whose ``flag`` selects the ``pneumoniamnist.npz`` archive."""
    flag = "pneumoniamnist"


class ChestMNIST(MedMNIST):
    """MedMNIST variant whose ``flag`` selects the ``chestmnist.npz`` archive."""
    flag = "chestmnist"


class DermaMNIST(MedMNIST):
    """MedMNIST variant whose ``flag`` selects the ``dermamnist.npz`` archive."""
    flag = "dermamnist"


class RetinaMNIST(MedMNIST):
    """MedMNIST variant whose ``flag`` selects the ``retinamnist.npz`` archive."""
    flag = "retinamnist"


class BreastMNIST(MedMNIST):
    """MedMNIST variant whose ``flag`` selects the ``breastmnist.npz`` archive."""
    flag = "breastmnist"


class OrganMNIST_Axial(MedMNIST):
    """MedMNIST variant whose ``flag`` selects the ``organmnist_axial.npz`` archive."""
    flag = "organmnist_axial"


class OrganMNIST_Coronal(MedMNIST):
    """MedMNIST variant whose ``flag`` selects the ``organmnist_coronal.npz`` archive."""
    flag = "organmnist_coronal"


class OrganMNIST_Sagittal(MedMNIST):
    """MedMNIST variant whose ``flag`` selects the ``organmnist_sagittal.npz`` archive."""
    flag = "organmnist_sagittal"

if __name__ == '__main__':
    # Demo: iterate over the training split of one dataset and display every
    # image together with its label, one sample at a time.
    import torch.utils.data as data
    from torchvision import transforms
    import matplotlib.pyplot as plt

    train_transform = transforms.Compose([
        transforms.ToTensor()
    ])
    # Lookup table from dataset flag to the Dataset class that loads it.
    dataclass = {
        "pathmnist": PathMNIST,
        "chestmnist": ChestMNIST,
        "dermamnist": DermaMNIST,
        "octmnist": OCTMNIST,
        "pneumoniamnist": PneumoniaMNIST,
        "retinamnist": RetinaMNIST,
        "breastmnist": BreastMNIST,
        "organmnist_axial": OrganMNIST_Axial,
        "organmnist_coronal": OrganMNIST_Coronal,
        "organmnist_sagittal": OrganMNIST_Sagittal,
    }
    print(dataclass["retinamnist"])

    train_dataset = dataclass["retinamnist"](split='train', transform=train_transform)
    train_loader = data.DataLoader(dataset=train_dataset, batch_size=1, shuffle=False)
    for img, target in train_loader:
        # Remove only the batch axis (batch_size=1), then reorder
        # (C, H, W) -> (H, W, C) for matplotlib. A bare squeeze() would also
        # drop a size-1 channel axis and make the 3-axis permute fail.
        img = img.squeeze(0).permute(1, 2, 0)
        if img.shape[-1] == 1:
            # imshow expects a 2-D array for single-channel images.
            img = img.squeeze(-1)
        plt.imshow(img, cmap="gray" if img.dim() == 2 else None)
        plt.show()
        print(target)

你可能感兴趣的:(眼底图像,深度学习)