PyTorch：用生成对抗网络（GAN）生成动漫图像

代码地址：PyTorch 实战，使用生成对抗网络生成动漫图像（原文链接未保留）

dataset

from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import os
import PIL.Image as Image

class trainData(Dataset):
    """Image dataset that yields every regular file found directly in ``root``.

    The directory is listed once, at construction time, in sorted order so
    that the index -> file mapping is reproducible between runs.

    Args:
        root: Directory containing the image files. Subdirectories are
            skipped; every other entry is assumed to be a file PIL can open.
        transform: Callable applied to each PIL image before it is returned.
            ``None`` (the default) means ``transforms.ToTensor()``.
        mode: Optional PIL mode (e.g. ``"RGB"``) every image is converted to
            before ``transform``. ``None`` (the default) keeps each file's
            native mode. Pass ``"RGB"`` when the folder may mix grayscale /
            RGBA / palette files; otherwise items can end up with different
            channel counts and cannot be stacked into one batch.
    """

    def __init__(self, root, transform=None, mode=None):
        # Build the default lazily instead of in the signature, so the
        # transform object is not created once at class-definition time and
        # an explicit ``transform=None`` does not break ``__getitem__``.
        self.transforms = transforms.ToTensor() if transform is None else transform
        self.mode = mode
        # os.listdir returns entries in arbitrary order: sort them for a stable
        # index mapping, and drop subdirectories, which Image.open cannot read.
        candidates = (os.path.join(root, name) for name in sorted(os.listdir(root)))
        self.file_names = [path for path in candidates if os.path.isfile(path)]

    def __len__(self):
        """Return the number of image files discovered in ``root``."""
        return len(self.file_names)

    def __getitem__(self, index):
        """Load the image at ``index`` and return ``self.transforms(image)``."""
        path = self.file_names[index]
        # Image.open(path) reads lazily and keeps its own file handle open.
        # Opening the file ourselves and forcing load() inside the ``with``
        # guarantees the descriptor is released before the image is used.
        with open(path, "rb") as fh:
            img = Image.open(fh)
            img.load()
        if self.mode is not None and img.mode != self.mode:
            img = img.convert(self.mode)
        return self.transforms(img)

if __name__=='__main__':
    # Script entry point: the body below runs only when this file is executed
    # directly, not when trainData is imported.
    # NOTE(review): only these imports are present here; the demo code that
    # presumably used them (plotting sampled images) appears to be missing,
    # so running the script currently just imports the libraries.
    import matplotlib.pyplot as plt
    import numpy as np
    import random

你可能感兴趣的：（深度学习, pytorch, 生成对抗网络）