def DataLoader(data_place):
    """
    Load the training arrays stored under ``data_place``.

    This is a template loader: replace it with one suited to your own data.

    ### input
    data_place : directory that contains images.npy, ids.npy and yaws.npy
                 (the command-line default is ./data)

    ### output (a list, in this order)
    images      : contents of images.npy; expected to be a 4 dimension tensor
                  (the number of images x channel x image_height x image_width)
    id_labels   : contents of ids.npy; treated here as integer identity
                  indices (they are NOT one-hot encoded at this point)
    pose_labels : contents of yaws.npy; treated here as integer pose indices
                  (likewise not yet one-hot encoded)
    Nd          : the number of IDs in the data, taken as max(id_labels) + 1
    Np          : the number of discrete poses, taken as max(pose_labels) + 1
    Nz          : size of the noise vector (50, the default in the paper)
    channel_num : number of image channels (fixed to 3 here)
    """
    def _load(stem):
        # Every array lives in "<data_place>/<stem>.npy".
        return np.load('{}/{}.npy'.format(data_place, stem))

    images = _load('images')
    id_labels = _load('ids')
    pose_labels = _load('yaws')

    noise_dim = 50
    n_channels = 3

    # Each identity is assumed to appear under the same fixed set of poses,
    # so an (identity, pose) pair pins down one specific image condition.
    # The labels are plain indices, hence class count = largest index + 1.
    n_poses = int(pose_labels.max() + 1)
    n_ids = int(id_labels.max() + 1)

    return [images, id_labels, pose_labels, n_ids, n_poses, noise_dim, n_channels]
if __name__ == "__main__":
    # Command-line entry point: parse the options, load (or fabricate) the
    # data, build or restore the discriminator/generator pair, then either
    # train the models or generate pose-modified images.
    parser = argparse.ArgumentParser(description='DR_GAN')

    # learning & saving parameters
    parser.add_argument('-lr', type=float, default=0.0002, help='initial learning rate [default: 0.0002]')
    parser.add_argument('-beta1', type=float, default=0.5, help='adam optimizer parameter [default: 0.5]')
    parser.add_argument('-beta2', type=float, default=0.999, help='adam optimizer parameter [default: 0.999]')
    parser.add_argument('-epochs', type=int, default=1000, help='number of epochs for train [default: 1000]')
    parser.add_argument('-batch-size', type=int, default=8, help='batch size for training [default: 8]')
    parser.add_argument('-save-dir', type=str, default='snapshot', help='where to save the snapshot')
    parser.add_argument('-save-freq', type=int, default=1, help='save learned model for every "-save-freq" epoch')
    parser.add_argument('-cuda', action='store_true', default=False, help='enable the gpu')
    # data source
    parser.add_argument('-random', action='store_true', default=False, help='use randomely created data to run program, instead of from data-place')
    parser.add_argument('-data-place', type=str, default='./data', help='prepared data path to run program')
    # model
    parser.add_argument('-multi-DRGAN', action='store_true', default=False, help='use multi image DR_GAN model')
    parser.add_argument('-images-perID', type=int, default=0, help='number of images per person to input to multi image DR_GAN')
    # option
    parser.add_argument('-snapshot', type=str, default=None, help='filename of model snapshot(snapshot/{Single or Multiple}/{date}/{epoch}) [default: None]')
    parser.add_argument('-generate', action='store_true', default=None, help='Generate pose modified image from given image')

    args = parser.parse_args()

    # Snapshots are written to <save-dir>/<Multi|Single>/<timestamp>.
    model_kind = 'Multi' if args.multi_DRGAN else 'Single'
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    args.save_dir = os.path.join(args.save_dir, model_kind, timestamp)
    os.makedirs(args.save_dir)

    # Echo the effective parameters and keep a copy next to the snapshots.
    # The file is opened once instead of once per parameter.
    print("Parameters:")
    with open('{}/Parameters.txt'.format(args.save_dir), 'a') as f:
        for attr, value in sorted(vars(args).items()):
            text = "\t{}={}\n".format(attr.upper(), value)
            print(text)
            f.write(text)

    # input data
    if args.random:
        images, id_labels, pose_labels, Nd, Np, Nz, channel_num = create_randomdata()
    else:
        print('\nLoading data from [%s]...' % args.data_place)
        try:
            images, id_labels, pose_labels, Nd, Np, Nz, channel_num = DataLoader(args.data_place)
        except Exception as err:
            # Nothing below can run without the data, so stop here instead of
            # failing later with an unrelated NameError.
            print("Sorry, failed to load data")
            print(err)
            raise SystemExit(1)

    # model
    if args.snapshot is None:
        if not args.multi_DRGAN:
            D = single_model.Discriminator(Nd, Np, channel_num)
            G = single_model.Generator(Np, Nz, channel_num)
        elif args.images_perID <= 0:
            # The multi-image model needs to know how many images belong to
            # each person, so -images-perID must be given explicitly.
            print("Please specify -images-perID of your data to input to multi_DRGAN")
            raise SystemExit(1)
        else:
            D = multi_model.Discriminator(Nd, Np, channel_num)
            # The images-per-ID count is forwarded to the generator.
            G = multi_model.Generator(Np, Nz, channel_num, args.images_perID)
    else:
        print('\nLoading model from [%s]...' % args.snapshot)
        try:
            D = torch.load('{}_D.pt'.format(args.snapshot))
            G = torch.load('{}_G.pt'.format(args.snapshot))
        except Exception as err:
            print("Sorry, This snapshot doesn't exist.")
            print(err)
            raise SystemExit(1)

    if not args.generate:
        if not args.multi_DRGAN:
            train_single_DRGAN(images, id_labels, pose_labels, Nd, Np, Nz, D, G, args)
        elif args.images_perID > 0 and args.batch_size % args.images_perID == 0:
            # batch_size must be a whole multiple of images_perID; the > 0
            # guard also avoids a ZeroDivisionError when the option is unset.
            train_multiple_DRGAN(images, id_labels, pose_labels, Nd, Np, Nz, D, G, args)
        else:
            print("Please give valid combination of batch_size, images_perID")
            raise SystemExit(1)
    else:
        # pose_code = [] # specify arbitrary pose code for every image
        # By default a random pose code in [-1, 1) is drawn for every image.
        pose_code = np.random.uniform(-1, 1, (images.shape[0], Np))
        features = Generate_Image(images, pose_code, Nz, G, args)
# Something like the above.