dcgan.torch/main.lua

require 'torch'
require 'nn'
require 'optim'
   
opt = {
   dataset = 'lsun',     -- imagenet / lsun / folder
   batchSize = 64,
   loadSize = 96,
   fineSize = 64,
   nz = 100,             -- # of dim for Z
   ngf = 64,             -- # of gen filters in first conv layer
   ndf = 64,             -- # of discrim filters in first conv layer
   nThreads = 4,         -- # of data loading threads to use
   niter = 25,           -- # of iterations at the starting learning rate
   lr = 0.0002,          -- initial learning rate for adam
   beta1 = 0.5,          -- momentum term of adam
   ntrain = math.huge,   -- # of examples per epoch. math.huge for full dataset
   display = 1,          -- display samples while training. 0 = false
   display_id = 10,      -- display window id
   gpu = 1,              -- gpu = 0 is CPU mode. gpu = X is GPU mode on GPU X
   name = 'experiment1',
   noise = 'normal',     -- uniform / normal
}

-- one-line argument parser. parses environment variables to override the defaults
for k,v in pairs(opt) do opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k] end
print(opt)
if opt.display == 0 then opt.display = false end
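-- Example invocation overriding a few defaults from the shell (the values
-- here are hypothetical):
--   dataset=folder batchSize=128 gpu=2 th main.lua
-- tonumber() is tried first, so numeric options arrive as numbers, while
-- string options such as dataset and name pass through unchanged.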
   
opt.manualSeed = torch.random(1, 10000) -- fix seed
print("Random Seed: " .. opt.manualSeed)
torch.manualSeed(opt.manualSeed)
torch.setnumthreads(1)
torch.setdefaulttensortype('torch.FloatTensor')
   
-- create data loader
local DataLoader = paths.dofile('data/data.lua')
local data = DataLoader.new(opt.nThreads, opt.dataset, opt)
print("Dataset: " .. opt.dataset, " Size: ", data:size())
----------------------------------------------------------------------------
local function weights_init(m)
   local name = torch.type(m)
   if name:find('Convolution') then
      m.weight:normal(0.0, 0.02)
      m:noBias()
   elseif name:find('BatchNormalization') then
      if m.weight then m.weight:normal(1.0, 0.02) end
      if m.bias then m.bias:fill(0) end
   end
end
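-- DCGAN-style initialization: conv weights drawn from N(0, 0.02), batch-norm
-- gains from N(1, 0.02) with zero shift. noBias() drops the conv biases,
-- which are redundant wherever a batch-norm layer follows.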
   
local nc = 3
local nz = opt.nz
local ndf = opt.ndf
local ngf = opt.ngf
local real_label = 1
local fake_label = 0

local SpatialBatchNormalization = nn.SpatialBatchNormalization
local SpatialConvolution = nn.SpatialConvolution
local SpatialFullConvolution = nn.SpatialFullConvolution
   
local netG = nn.Sequential()
-- input is Z, going into a convolution
netG:add(SpatialFullConvolution(nz, ngf * 8, 4, 4))
netG:add(SpatialBatchNormalization(ngf * 8)):add(nn.ReLU(true))
-- state size: (ngf*8) x 4 x 4
netG:add(SpatialFullConvolution(ngf * 8, ngf * 4, 4, 4, 2, 2, 1, 1))
netG:add(SpatialBatchNormalization(ngf * 4)):add(nn.ReLU(true))
-- state size: (ngf*4) x 8 x 8
netG:add(SpatialFullConvolution(ngf * 4, ngf * 2, 4, 4, 2, 2, 1, 1))
netG:add(SpatialBatchNormalization(ngf * 2)):add(nn.ReLU(true))
-- state size: (ngf*2) x 16 x 16
netG:add(SpatialFullConvolution(ngf * 2, ngf, 4, 4, 2, 2, 1, 1))
netG:add(SpatialBatchNormalization(ngf)):add(nn.ReLU(true))
-- state size: (ngf) x 32 x 32
netG:add(SpatialFullConvolution(ngf, nc, 4, 4, 2, 2, 1, 1))
netG:add(nn.Tanh())
-- state size: (nc) x 64 x 64
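
-- A quick check of the "state size" arithmetic above. This is a minimal
-- sketch (the helper name is ours, not part of the repo); SpatialFullConvolution
-- produces out = (in - 1)*dW - 2*padW + kW.
local function fullConvSize(inSize, kW, dW, padW)
   return (inSize - 1) * dW - 2 * padW + kW
end
assert(fullConvSize(1, 4, 1, 0) == 4)   -- Z (1 x 1) -> (ngf*8) x 4 x 4
assert(fullConvSize(4, 4, 2, 1) == 8)   -- 4 x 4 -> (ngf*4) x 8 x 8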
   
netG:apply(weights_init)
   
local netD = nn.Sequential()

-- input is (nc) x 64 x 64
netD:add(SpatialConvolution(nc, ndf, 4, 4, 2, 2, 1, 1))
netD:add(nn.LeakyReLU(0.2, true))
-- state size: (ndf) x 32 x 32
netD:add(SpatialConvolution(ndf, ndf * 2, 4, 4, 2, 2, 1, 1))
netD:add(SpatialBatchNormalization(ndf * 2)):add(nn.LeakyReLU(0.2, true))
-- state size: (ndf*2) x 16 x 16
netD:add(SpatialConvolution(ndf * 2, ndf * 4, 4, 4, 2, 2, 1, 1))
netD:add(SpatialBatchNormalization(ndf * 4)):add(nn.LeakyReLU(0.2, true))
-- state size: (ndf*4) x 8 x 8
netD:add(SpatialConvolution(ndf * 4, ndf * 8, 4, 4, 2, 2, 1, 1))
netD:add(SpatialBatchNormalization(ndf * 8)):add(nn.LeakyReLU(0.2, true))
-- state size: (ndf*8) x 4 x 4
netD:add(SpatialConvolution(ndf * 8, 1, 4, 4))
netD:add(nn.Sigmoid())
-- state size: 1 x 1 x 1
netD:add(nn.View(1):setNumInputDims(3))
-- state size: 1
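
-- The mirror check for the discriminator (again a sketch with our own helper
-- name); SpatialConvolution produces out = floor((in + 2*padW - kW)/dW) + 1.
local function convSize(inSize, kW, dW, padW)
   return math.floor((inSize + 2 * padW - kW) / dW) + 1
end
assert(convSize(64, 4, 2, 1) == 32)   -- (nc) x 64 x 64 -> (ndf) x 32 x 32
assert(convSize(4, 4, 1, 0) == 1)     -- (ndf*8) x 4 x 4 -> 1 x 1 x 1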
   
netD:apply(weights_init)

local criterion = nn.BCECriterion()
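-- BCE on a predicted probability o with target t is -(t*log(o) + (1-t)*log(1-o)).
-- With real_label = 1 and fake_label = 0 this gives the standard GAN
-- discriminator loss; fGx below reuses it with flipped labels for the
-- non-saturating generator loss log(D(G(z))).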
---------------------------------------------------------------------------
optimStateG = {
   learningRate = opt.lr,
   beta1 = opt.beta1,
}
optimStateD = {
   learningRate = opt.lr,
   beta1 = opt.beta1,
}
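-- lr = 0.0002 and beta1 = 0.5 are the DCGAN paper's recommended Adam settings;
-- beta2 and epsilon are left at optim.adam's defaults.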
----------------------------------------------------------------------------
local input = torch.Tensor(opt.batchSize, 3, opt.fineSize, opt.fineSize)
local noise = torch.Tensor(opt.batchSize, nz, 1, 1)
local label = torch.Tensor(opt.batchSize)
local errD, errG
local epoch_tm = torch.Timer()
local tm = torch.Timer()
local data_tm = torch.Timer()
----------------------------------------------------------------------------
if opt.gpu > 0 then
   require 'cunn'
   cutorch.setDevice(opt.gpu)
   input = input:cuda(); noise = noise:cuda(); label = label:cuda()

   if pcall(require, 'cudnn') then
      require 'cudnn'
      cudnn.benchmark = true
      cudnn.convert(netG, cudnn)
      cudnn.convert(netD, cudnn)
   end
   netD:cuda(); netG:cuda(); criterion:cuda()
end
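-- pcall makes cudnn optional: if the package is not installed the models stay
-- on plain cunn layers; if it is, cudnn.convert swaps in cuDNN-backed
-- equivalents and benchmark mode picks the fastest algorithm per layer shape.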
   
local parametersD, gradParametersD = netD:getParameters()
local parametersG, gradParametersG = netG:getParameters()
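-- getParameters() flattens each network's weights and gradients into a single
-- contiguous tensor pair, which is the form optim.adam operates on. It must be
-- called again whenever the network's storage changes (see the checkpointing
-- code at the end of each epoch).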
   
if opt.display then disp = require 'display' end

noise_vis = noise:clone()
if opt.noise == 'uniform' then
   noise_vis:uniform(-1, 1)
elseif opt.noise == 'normal' then
   noise_vis:normal(0, 1)
end
   
-- create closure to evaluate f(X) and df/dX of discriminator
local fDx = function(x)
   gradParametersD:zero()

   -- train with real
   data_tm:reset(); data_tm:resume()
   local real = data:getBatch()
   data_tm:stop()
   input:copy(real)
   label:fill(real_label)

   local output = netD:forward(input)
   local errD_real = criterion:forward(output, label)
   local df_do = criterion:backward(output, label)
   netD:backward(input, df_do)

   -- train with fake
   if opt.noise == 'uniform' then -- regenerate random noise
      noise:uniform(-1, 1)
   elseif opt.noise == 'normal' then
      noise:normal(0, 1)
   end
   local fake = netG:forward(noise)
   input:copy(fake)
   label:fill(fake_label)

   local output = netD:forward(input)
   local errD_fake = criterion:forward(output, label)
   local df_do = criterion:backward(output, label)
   netD:backward(input, df_do)

   errD = errD_real + errD_fake

   return errD, gradParametersD
end
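-- gradParametersD is zeroed once while netD:backward() runs twice, so the real
-- and fake gradients accumulate and a single adam step descends
-- BCE(D(x), 1) + BCE(D(G(z)), 0).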
   
-- create closure to evaluate f(X) and df/dX of generator
local fGx = function(x)
   gradParametersG:zero()

   --[[ the three lines below were already executed in fDx, so save computation
   noise:uniform(-1, 1) -- regenerate random noise
   local fake = netG:forward(noise)
   input:copy(fake) ]]--
   label:fill(real_label) -- fake labels are real for generator cost

   local output = netD.output -- netD:forward(input) was already executed in fDx, so save computation
   errG = criterion:forward(output, label)
   local df_do = criterion:backward(output, label)
   local df_dg = netD:updateGradInput(input, df_do)

   netG:backward(noise, df_dg)
   return errG, gradParametersG
end
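-- updateGradInput() backpropagates through D without calling
-- accGradParameters(), so D's own gradient buffers are untouched; only the
-- gradient w.r.t. the fake batch is needed, and it flows on into netG.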
   
-- train
for epoch = 1, opt.niter do
   epoch_tm:reset()
   local counter = 0
   for i = 1, math.min(data:size(), opt.ntrain), opt.batchSize do
      tm:reset()
      -- (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
      optim.adam(fDx, parametersD, optimStateD)

      -- (2) Update G network: maximize log(D(G(z)))
      optim.adam(fGx, parametersG, optimStateG)

      -- display
      counter = counter + 1
      if counter % 10 == 0 and opt.display then
         local fake = netG:forward(noise_vis)
         local real = data:getBatch()
         disp.image(fake, {win=opt.display_id, title=opt.name})
         disp.image(real, {win=opt.display_id * 3, title=opt.name})
      end

      -- logging
      if ((i-1) / opt.batchSize) % 1 == 0 then
         print(('Epoch: [%d][%8d / %8d]\t Time: %.3f DataTime: %.3f '
                .. ' Err_G: %.4f Err_D: %.4f'):format(
                epoch, ((i-1) / opt.batchSize),
                math.floor(math.min(data:size(), opt.ntrain) / opt.batchSize),
                tm:time().real, data_tm:time().real,
                errG and errG or -1, errD and errD or -1))
      end
   end
   paths.mkdir('checkpoints')
   parametersD, gradParametersD = nil, nil -- nil them to avoid spiking memory
   parametersG, gradParametersG = nil, nil
   torch.save('checkpoints/' .. opt.name .. '_' .. epoch .. '_net_G.t7', netG:clearState())
   torch.save('checkpoints/' .. opt.name .. '_' .. epoch .. '_net_D.t7', netD:clearState())
   parametersD, gradParametersD = netD:getParameters() -- reflatten the params and get them
   parametersG, gradParametersG = netG:getParameters()
   print(('End of epoch %d / %d \t Time Taken: %.3f'):format(
      epoch, opt.niter, epoch_tm:time().real))
end
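-- On the checkpoint dance above: the flat tensors from getParameters() alias
-- every module's weights, so they are nil'ed before clearState() + torch.save()
-- to avoid spiking memory during serialization, then recreated with a fresh
-- getParameters() call for the next epoch.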
