# Main training loop. NOTE(review): this chunk appears to be the body of a
# train_model(...)-style function whose `def` is above the visible window;
# it relies on names bound there / at module level (opt, model, criterion,
# optimizer, scheduler, dataloaders, dataset_sizes, use_gpu, fp16, amp,
# version, warm_up, warm_iteration, since, y_loss, y_err, ODFA,
# save_network, draw_curve, and the criterion_* metric-learning losses).
for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)
    # Each epoch runs a training pass followed by a validation pass.
    for phase in ['train', 'val']:
        if phase == 'train':
            model.train(True)   # enable dropout / batchnorm updates
        else:
            model.train(False)  # evaluation mode (equivalent to model.eval())
        running_loss = 0.0
        running_corrects = 0.0
        # NOTE(review): `iter` shadows the Python builtin; it is the batch index.
        for iter, data in enumerate(dataloaders[phase]):
            inputs, labels = data
            now_batch_size,c,h,w = inputs.shape
            # Skip the last, smaller batch so batch-size-sensitive losses
            # (and the /now_batch_size normalizations below) stay consistent.
            if now_batch_size<opt.batchsize:
                continue
            if use_gpu:
                # Variable(...) is the legacy pre-0.4 autograd API; kept for
                # the old-PyTorch compatibility path checked via `version` below.
                inputs = Variable(inputs.cuda().detach())
                labels = Variable(labels.cuda().detach())
            else:
                inputs, labels = Variable(inputs), Variable(labels)
            optimizer.zero_grad()
            # Forward pass: gradients disabled during validation.
            if phase == 'val':
                with torch.no_grad():
                    outputs = model(inputs)
            else:
                outputs = model(inputs)
                # Adversarial training: every opt.aiter-th batch, also forward
                # ODFA-generated adversarial inputs (ODFA defined elsewhere —
                # presumably Opposite-Direction Feature Attack; verify upstream).
                if opt.adv>0 and iter%opt.aiter==0:
                    inputs_adv = ODFA(model, inputs)
                    outputs_adv = model(inputs_adv)
            # NOTE(review): these modules are re-created every batch; they are
            # stateless, so this is only a (small) per-iteration overhead.
            sm = nn.Softmax(dim=1)
            log_sm = nn.LogSoftmax(dim=1)
            # When any metric-learning loss is enabled, the model is assumed to
            # return a (logits, feature) pair instead of bare logits.
            return_feature = opt.arcface or opt.cosface or opt.circle or opt.triplet or opt.contrast or opt.instance or opt.lifted or opt.sphere
            if return_feature:
                logits, ff = outputs
                # L2-normalize the feature vector before the metric losses.
                fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
                ff = ff.div(fnorm.expand_as(ff))
                loss = criterion(logits, labels)
                _, preds = torch.max(logits.data, 1)
                if opt.adv>0 and iter%opt.aiter==0:
                    logits_adv, _ = outputs_adv
                    loss += opt.adv * criterion(logits_adv, labels)
                # Each enabled metric loss is added on top of the CE loss;
                # several are normalized by batch size to balance magnitudes.
                if opt.arcface:
                    loss += criterion_arcface(ff, labels)/now_batch_size
                if opt.cosface:
                    loss += criterion_cosface(ff, labels)/now_batch_size
                if opt.circle:
                    loss += criterion_circle(*convert_label_to_similarity( ff, labels))/now_batch_size
                if opt.triplet:
                    # Online hard-pair mining before the triplet loss.
                    hard_pairs = miner(ff, labels)
                    loss += criterion_triplet(ff, labels, hard_pairs)
                if opt.lifted:
                    loss += criterion_lifted(ff, labels)
                if opt.contrast:
                    loss += criterion_contrast(ff, labels)
                if opt.instance:
                    loss += criterion_instance(ff) /now_batch_size
                if opt.sphere:
                    loss += criterion_sphere(ff, labels)/now_batch_size
            elif opt.PCB:
                # Part-based Convolutional Baseline: the model outputs six
                # part-level logit tensors; prediction sums their softmaxes,
                # loss sums the per-part cross-entropies.
                part = {}
                num_part = 6
                for i in range(num_part):
                    part[i] = outputs[i]
                score = sm(part[0]) + sm(part[1]) +sm(part[2]) + sm(part[3]) +sm(part[4]) +sm(part[5])
                _, preds = torch.max(score.data, 1)
                loss = criterion(part[0], labels)
                for i in range(num_part-1):
                    loss += criterion(part[i+1], labels)
            else:
                # Plain single-head classifier.
                _, preds = torch.max(outputs.data, 1)
                loss = criterion(outputs, labels)
                if opt.adv>0 and iter%opt.aiter==0:
                    # NOTE(review): outputs_adv is passed to criterion without
                    # the tuple-unpacking done in the return_feature branch —
                    # assumes the model returns bare logits here; verify.
                    loss += opt.adv * criterion(outputs_adv, labels)
            del inputs  # free the input batch before the optional extra forwards
            # Multi-domain / mutual-learning regularizer: after 10% of epochs,
            # pull a batch of two augmented views from the 'DG' loader and add a
            # symmetric KL term pushing both predictions toward their mean.
            if opt.DG and phase == 'train' and epoch > num_epochs*0.1:
                try:
                    _, batch = DGloader_iter.__next__()
                except StopIteration:
                    # Exhausted: restart the DG iterator.
                    DGloader_iter = enumerate(dataloaders['DG'])
                    _, batch = DGloader_iter.__next__()
                except UnboundLocalError:
                    # First use: the iterator was never created. NOTE(review):
                    # relies on DGloader_iter being a local of the enclosing
                    # function — confirm it is not defined at module scope.
                    DGloader_iter = enumerate(dataloaders['DG'])
                    _, batch = DGloader_iter.__next__()
                inputs1, inputs2, _ = batch
                inputs1 = inputs1.cuda().detach()
                inputs2 = inputs2.cuda().detach()
                outputs1 = model(inputs1)
                # Reduce each view's output to a single logit tensor, matching
                # whichever head configuration is active.
                if return_feature:
                    outputs1, _ = outputs1
                elif opt.PCB:
                    for i in range(num_part):
                        part[i] = outputs1[i]
                    outputs1 = part[0] + part[1] + part[2] + part[3] + part[4] + part[5]
                outputs2 = model(inputs2)
                if return_feature:
                    outputs2, _ = outputs2
                elif opt.PCB:
                    for i in range(num_part):
                        part[i] = outputs2[i]
                    outputs2 = part[0] + part[1] + part[2] + part[3] + part[4] + part[5]
                # NOTE(review): sm(outputs1 + outputs2) softmaxes the *sum* of
                # logits, not the mean of the two softmax distributions —
                # confirm this matches the intended regularizer.
                mean_pred = sm(outputs1 + outputs2)
                # NOTE(review): size_average=False is deprecated in modern
                # PyTorch (use reduction='sum'); kept for the old-version
                # compatibility this file maintains via `version` below.
                kl_loss = nn.KLDivLoss(size_average=False)
                reg= (kl_loss(log_sm(outputs2) , mean_pred) + kl_loss(log_sm(outputs1) , mean_pred))/2
                loss += 0.01*reg
                del inputs1, inputs2
            # Linear learning-rate warm-up: scale the loss up from ~0.1 to 1.0
            # over the first opt.warm_epoch epochs (warm_up / warm_iteration
            # are initialized in the enclosing scope).
            if epoch<opt.warm_epoch and phase == 'train':
                warm_up = min(1.0, warm_up + 0.9 / warm_iteration)
                loss = loss*warm_up
            if phase == 'train':
                if fp16:
                    # Mixed precision via NVIDIA apex: scale the loss before
                    # backward to avoid fp16 gradient underflow.
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                optimizer.step()
            # Statistics. PyTorch >= 0.4 uses .item(); older versions .data[0].
            if int(version[0])>0 or int(version[2]) > 3:
                running_loss += loss.item() * now_batch_size
            else :
                running_loss += loss.data[0] * now_batch_size
            del loss
            running_corrects += float(torch.sum(preds == labels.data))
        # NOTE(review): dropped last batches are still counted in
        # dataset_sizes, so epoch_loss/epoch_acc are slightly underestimated.
        epoch_loss = running_loss / dataset_sizes[phase]
        epoch_acc = running_corrects / dataset_sizes[phase]
        print('{} Loss: {:.4f} Acc: {:.4f}'.format(
            phase, epoch_loss, epoch_acc))
        y_loss[phase].append(epoch_loss)
        y_err[phase].append(1.0-epoch_acc)
        if phase == 'val':
            # Snapshot of the *latest* (not best) weights; every 10th epoch is
            # also written to disk and the loss/error curves are redrawn.
            last_model_wts = model.state_dict()
            if epoch%10 == 9:
                save_network(model, epoch)
            draw_curve(epoch)
        if phase == 'train':
            # LR schedule stepped once per epoch, after optimizer.step() calls
            # (the order required by PyTorch >= 1.1).
            scheduler.step()
        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print()
time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(
    time_elapsed // 60, time_elapsed % 60))
# Restore the weights captured after the final validation phase and persist
# them under the 'last' tag before returning the trained model.
model.load_state_dict(last_model_wts)
save_network(model, 'last')
return model