PDARTS, short for Progressive Differentiable Architecture Search: Bridging the Depth Gap between Search and Evaluation, is an improvement on DARTS. DARTS consumes too much memory to search with a deep network; P-DARTS therefore splits the search into three stages and proceeds progressively, increasing the network depth while shrinking the set of candidate operations. Building the network three times lengthens the search, as illustrated in the figure below:
Beyond that, the algorithm also adds finer control over how operations are pruned. chenxin061/pdarts is adapted from quark0/darts, and its main-function logic is fairly involved.
start_time = time.time()
main()
end_time = time.time()
duration = end_time - start_time
logging.info('Total searching time: %ds', duration)
if not torch.cuda.is_available():
logging.info('No GPU device available')
sys.exit(1)
np.random.seed(args.seed)
torch.cuda.set_device(args.gpu)
cudnn.benchmark = True
torch.manual_seed(args.seed)
cudnn.enabled=True
torch.cuda.manual_seed(args.seed)
logging.info('GPU device = %d' % args.gpu)
logging.info("args = %s", args)
The per-stage processing is not factored into functions, which makes the flow harder to follow.
_data_transforms_cifar100 applies random cropping, horizontal flipping, normalization, and (optionally) Cutout.
torchvision's CIFAR100 dataset class is a subclass of CIFAR10.
torch.utils.data.sampler.SubsetRandomSampler draws samples randomly from a given list of indices, without replacement.
# prepare dataset
if args.cifar100:
train_transform, valid_transform = utils._data_transforms_cifar100(args)
else:
train_transform, valid_transform = utils._data_transforms_cifar10(args)
if args.cifar100:
train_data = dset.CIFAR100(root=args.tmp_data_dir, train=True, download=True, transform=train_transform)
else:
train_data = dset.CIFAR10(root=args.tmp_data_dir, train=True, download=True, transform=train_transform)
num_train = len(train_data)
indices = list(range(num_train))
split = int(np.floor(args.train_portion * num_train))
train_queue = torch.utils.data.DataLoader(
train_data, batch_size=args.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
pin_memory=True, num_workers=args.workers)
valid_queue = torch.utils.data.DataLoader(
train_data, batch_size=args.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
pin_memory=True, num_workers=args.workers)
PRIMITIVES defines the candidate operations available to the network, 8 in total. After three rounds of dropping num_to_drop operations, a single operation (or none) remains at each position. switches_normal and switches_reduce are per-connection switch lists over the candidate operations (one boolean per operation); a cell contains 14 connections.
# build Network
criterion = nn.CrossEntropyLoss()
criterion = criterion.cuda()
switches = []
for i in range(14):
switches.append([True for j in range(len(PRIMITIVES))])
switches_normal = copy.deepcopy(switches)
switches_reduce = copy.deepcopy(switches)
# To be moved to args
num_to_keep = [5, 3, 1]
num_to_drop = [3, 2, 2]
if len(args.add_width) == 3:
add_width = args.add_width
else:
add_width = [0, 0, 0]
if len(args.add_layers) == 3:
add_layers = args.add_layers
else:
add_layers = [0, 6, 12]
if len(args.dropout_rate) ==3:
drop_rate = args.dropout_rate
else:
drop_rate = [0.0, 0.0, 0.0]
eps_no_archs = [10, 10, 10]
Networks for the three stages are built and trained one after another; sp stands for search phase.
The P-DARTS search network grows from 5 to 11 to 17 cells, versus 8 cells in DARTS.
Network builds the model.
count_parameters_in_MB reports the model size.
train is handed two optimizers: Adam for the architecture parameters and SGD for the network weights.
In the last 5 epochs, infer evaluates the model on the validation set.
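Plugging the defaults above into the stage loop gives the depth/width progression; a small sketch, assuming init_channels=16 and layers=5 (the repository's default arguments):

# Per-stage network size, assuming init_channels=16 and layers=5 (repository defaults)
init_channels, layers = 16, 5
add_width  = [0, 0, 0]
add_layers = [0, 6, 12]
drop_rate  = [0.0, 0.0, 0.0]

for sp in range(3):
    C = init_channels + add_width[sp]
    depth = layers + add_layers[sp]
    print('stage %d: %d cells, %d initial channels, dropout %.1f'
          % (sp + 1, depth, C, drop_rate[sp]))
# stage 1: 5 cells, stage 2: 11 cells, stage 3: 17 cells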
for sp in range(len(num_to_keep)):
model = Network(args.init_channels + int(add_width[sp]), CIFAR_CLASSES, args.layers + int(add_layers[sp]), criterion, switches_normal=switches_normal, switches_reduce=switches_reduce, p=float(drop_rate[sp]))
model = model.cuda()
logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
network_params = []
for k, v in model.named_parameters():
if not (k.endswith('alphas_normal') or k.endswith('alphas_reduce')):
network_params.append(v)
optimizer = torch.optim.SGD(
network_params,
args.learning_rate,
momentum=args.momentum,
weight_decay=args.weight_decay)
optimizer_a = torch.optim.Adam(model.arch_parameters(),
lr=args.arch_learning_rate, betas=(0.5, 0.999), weight_decay=args.arch_weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, float(args.epochs), eta_min=args.learning_rate_min)
sm_dim = -1
epochs = args.epochs
eps_no_arch = eps_no_archs[sp]
scale_factor = 0.2
for epoch in range(epochs):
scheduler.step()
lr = scheduler.get_lr()[0]
logging.info('Epoch: %d lr: %e', epoch, lr)
epoch_start = time.time()
# training
if epoch < eps_no_arch:
model.p = float(drop_rate[sp]) * (epochs - epoch - 1) / epochs
model.update_p()
train_acc, train_obj = train(train_queue, valid_queue, model, network_params, criterion, optimizer, optimizer_a, lr, train_arch=False)
else:
model.p = float(drop_rate[sp]) * np.exp(-(epoch - eps_no_arch) * scale_factor)
model.update_p()
train_acc, train_obj = train(train_queue, valid_queue, model, network_params, criterion, optimizer, optimizer_a, lr, train_arch=True)
logging.info('Train_acc %f', train_acc)
epoch_duration = time.time() - epoch_start
logging.info('Epoch time: %ds', epoch_duration)
# validation
if epochs - epoch < 5:
valid_acc, valid_obj = infer(valid_queue, model, criterion)
logging.info('Valid_acc %f', valid_acc)
utils.save stores the result of each stage's training. The problem is that the filename is the same for every stage, so later stages overwrite earlier ones.
switches_normal_2 and switches_reduce_2 are snapshots of the switch lists taken at the start of the final stage's pruning, kept for the later skip-connect refinement.
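A possible workaround for the overwrite (hypothetical, not part of the repository) is to tag the checkpoint with the stage index:

# Hypothetical variant: include the search phase in the filename so stages do not overwrite each other
utils.save(model, os.path.join(args.save, 'weights_sp%d.pt' % sp))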
utils.save(model, os.path.join(args.save, 'weights.pt'))
print('------Dropping %d paths------' % num_to_drop[sp])
# Save switches info for s-c refinement.
if sp == len(num_to_keep) - 1:
switches_normal_2 = copy.deepcopy(switches_normal)
switches_reduce_2 = copy.deepcopy(switches_reduce)
arch_parameters returns $(\alpha_{normal}, \alpha_{reduce})$.
normal_prob is computed as
$$\frac{\exp(\alpha_o^{(i,j)})}{\sum_{o'\in\mathcal{O}}\exp(\alpha_{o'}^{(i,j)})}$$
idxs records the type indices of the operations that are still active.
get_min_k returns the indices of the num_to_drop[sp] smallest entries.
get_min_k_no_zero first checks whether idxs contains 0 (the Zero op).
In the last stage the Zero operation is always dropped; otherwise the specified number of lowest-weight operations is dropped.
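The two helpers can be read roughly as follows; this is a simplified reimplementation for illustration, not the repository's exact code:

import numpy as np

def get_min_k_sketch(probs, k):
    """Indices of the k smallest entries of the probability row."""
    return list(np.argsort(probs)[:k])

def get_min_k_no_zero_sketch(probs, idxs, k):
    """Like get_min_k, but if the Zero op (PRIMITIVES index 0) is still active,
    drop it unconditionally and then drop the k-1 smallest remaining entries."""
    if 0 in idxs:
        pos0 = idxs.index(0)              # column of the Zero op in this row
        rest = [i for i in np.argsort(probs) if i != pos0]
        return [pos0] + rest[:k - 1]
    return list(np.argsort(probs)[:k])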
# drop operations with low architecture weights
arch_param = model.arch_parameters()
normal_prob = F.softmax(arch_param[0], dim=sm_dim).data.cpu().numpy()
for i in range(14):
idxs = []
for j in range(len(PRIMITIVES)):
if switches_normal[i][j]:
idxs.append(j)
if sp == len(num_to_keep) - 1:
# for the last stage, drop all Zero operations
drop = get_min_k_no_zero(normal_prob[i, :], idxs, num_to_drop[sp])
else:
drop = get_min_k(normal_prob[i, :], num_to_drop[sp])
for idx in drop:
switches_normal[i][idxs[idx]] = False
The reduction cell is handled in the same way.
reduce_prob = F.softmax(arch_param[1], dim=-1).data.cpu().numpy()
for i in range(14):
idxs = []
for j in range(len(PRIMITIVES)):
if switches_reduce[i][j]:
idxs.append(j)
if sp == len(num_to_keep) - 1:
drop = get_min_k_no_zero(reduce_prob[i, :], idxs, num_to_drop[sp])
else:
drop = get_min_k(reduce_prob[i, :], num_to_drop[sp])
for idx in drop:
switches_reduce[i][idxs[idx]] = False
logging.info('switches_normal = %s', switches_normal)
logging_switches(switches_normal)
logging.info('switches_reduce = %s', switches_reduce)
logging_switches(switches_reduce)
At the end of the last stage the architecture parameters are read out.
normal_final and reduce_final record, for each connection, the highest probability among its non-Zero operations.
if sp == len(num_to_keep) - 1:
arch_param = model.arch_parameters()
normal_prob = F.softmax(arch_param[0], dim=sm_dim).data.cpu().numpy()
reduce_prob = F.softmax(arch_param[1], dim=sm_dim).data.cpu().numpy()
normal_final = [0 for idx in range(14)]
reduce_final = [0 for idx in range(14)]
# remove all Zero operations
for i in range(14):
if switches_normal_2[i][0] == True:
normal_prob[i][0] = 0
normal_final[i] = max(normal_prob[i])
if switches_reduce_2[i][0] == True:
reduce_prob[i][0] = 0
reduce_final[i] = max(reduce_prob[i])
The first intermediate node of a cell has two incoming connections, which are kept directly, so start = 2 skips them; the remaining connections group as 2-4, 5-8 and 9-13.
tbsn and tbsr hold the candidate positions of the current node for the normal and reduction cell, ranked by operation probability. keep_normal and keep_reduce record the indices of the connections to keep.
Filtering then yields the final switches_normal and switches_reduce, with two connections (one operation each) per node.
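For orientation, the 14 connection indices group by destination node as follows (a small illustrative computation):

# Connections grouped by destination node: node 0 has 2 incoming edges,
# node 1 has 3, node 2 has 4 and node 3 has 5 (2 + 3 + 4 + 5 = 14).
start, groups = 0, []
for n in range(2, 6):
    groups.append(list(range(start, start + n)))
    start += n
print(groups)  # [[0, 1], [2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12, 13]]

Node 0's two connections are kept unconditionally; for every later node the two connections with the largest normal_final / reduce_final values are kept.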
# Generate Architecture, similar to DARTS
keep_normal = [0, 1]
keep_reduce = [0, 1]
n = 3
start = 2
for i in range(3):
end = start + n
tbsn = normal_final[start:end]
tbsr = reduce_final[start:end]
edge_n = sorted(range(n), key=lambda x: tbsn[x])
keep_normal.append(edge_n[-1] + start)
keep_normal.append(edge_n[-2] + start)
edge_r = sorted(range(n), key=lambda x: tbsr[x])
keep_reduce.append(edge_r[-1] + start)
keep_reduce.append(edge_r[-2] + start)
start = end
n = n + 1
# set switches according the ranking of arch parameters
for i in range(14):
if not i in keep_normal:
for j in range(len(PRIMITIVES)):
switches_normal[i][j] = False
if not i in keep_reduce:
for j in range(len(PRIMITIVES)):
switches_reduce[i][j] = False
parse_network parses the switch lists into the network genotype.
check_sk_number counts the skip_connect operations in the normal cell, i.e. index 3 of PRIMITIVES.
delete_min_sk_prob zeroes out the probability of the lowest-weight skip connection.
keep_1_on drops two more operations per connection, leaving one.
keep_2_branches prunes the connections so that each node keeps only two.
The number of skip_connect operations in the normal cell is reduced step by step, and each resulting genotype is logged.
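check_sk_number amounts to counting the skip_connect column; a sketch (illustrative, not the repository's exact code):

def check_sk_number_sketch(switches):
    # Count how many of the 14 connections still have skip_connect (PRIMITIVES index 3) enabled
    return sum(1 for row in switches if row[3])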
# translate switches into genotype
genotype = parse_network(switches_normal, switches_reduce)
logging.info(genotype)
## restrict skipconnect (normal cell only)
logging.info('Restricting skipconnect...')
# generating genotypes with different numbers of skip-connect operations
for sks in range(0, 9):
max_sk = 8 - sks
num_sk = check_sk_number(switches_normal)
if not num_sk > max_sk:
continue
while num_sk > max_sk:
normal_prob = delete_min_sk_prob(switches_normal, switches_normal_2, normal_prob)
switches_normal = keep_1_on(switches_normal_2, normal_prob)
switches_normal = keep_2_branches(switches_normal, normal_prob)
num_sk = check_sk_number(switches_normal)
logging.info('Number of skip-connect: %d', max_sk)
genotype = parse_network(switches_normal, switches_reduce)
logging.info(genotype)
train starts by initializing three running meters.
objs = utils.AvgrageMeter()
top1 = utils.AvgrageMeter()
top5 = utils.AvgrageMeter()
If the architecture is being trained, a batch is first drawn from valid_queue and used to update the architecture parameters.
for step, (input, target) in enumerate(train_queue):
model.train()
n = input.size(0)
input = input.cuda()
target = target.cuda(non_blocking=True)
if train_arch:
# In the original implementation of DARTS, it is input_search, target_search = next(iter(valid_queue), which slows down
# the training when using PyTorch 0.4 and above.
try:
input_search, target_search = next(valid_queue_iter)
except:
valid_queue_iter = iter(valid_queue)
input_search, target_search = next(valid_queue_iter)
input_search = input_search.cuda()
target_search = target_search.cuda(non_blocking=True)
optimizer_a.zero_grad()
logits = model(input_search)
loss_a = criterion(logits, target_search)
loss_a.backward()
nn.utils.clip_grad_norm_(model.arch_parameters(), args.grad_clip)
optimizer_a.step()
The network weights are then trained on the training batch.
optimizer.zero_grad()
logits = model(input)
loss = criterion(logits, target)
loss.backward()
nn.utils.clip_grad_norm_(network_params, args.grad_clip)
optimizer.step()
utils.accuracy computes the accuracy on the training batch.
prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
objs.update(loss.data.item(), n)
top1.update(prec1.data.item(), n)
top5.update(prec5.data.item(), n)
if step % args.report_freq == 0:
logging.info('TRAIN Step: %03d Objs: %e R1: %f R5: %f', step, objs.avg, top1.avg, top5.avg)
return top1.avg, objs.avg
objs = utils.AvgrageMeter()
top1 = utils.AvgrageMeter()
top5 = utils.AvgrageMeter()
model.eval()
for step, (input, target) in enumerate(valid_queue):
input = input.cuda()
target = target.cuda(non_blocking=True)
with torch.no_grad():
logits = model(input)
loss = criterion(logits, target)
prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
n = input.size(0)
objs.update(loss.data.item(), n)
top1.update(prec1.data.item(), n)
top5.update(prec5.data.item(), n)
if step % args.report_freq == 0:
logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
return top1.avg, objs.avg
Compared with the usual CIFAR transforms, Cutout is added (when enabled).
CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]
train_transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
if args.cutout:
train_transform.transforms.append(Cutout(args.cutout_length))
valid_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
return train_transform, valid_transform
def __init__(self, length):
self.length = length
def __call__(self, img):
h, w = img.size(1), img.size(2)
mask = np.ones((h, w), np.float32)
y = np.random.randint(h)
x = np.random.randint(w)
y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)
mask[y1: y2, x1: x2] = 0.
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
img *= mask
return img
return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name)/1e6
The helper _parse_switches is defined as a nested function. It parses both cell types, recording each kept operation together with the index of its input node, and returns a Genotype tuple.
def _parse_switches(switches):
n = 2
start = 0
gene = []
step = 4
for i in range(step):
end = start + n
for j in range(start, end):
for k in range(len(switches[j])):
if switches[j][k]:
gene.append((PRIMITIVES[k], j - start))
start = end
n = n + 1
return gene
gene_normal = _parse_switches(switches_normal)
gene_reduce = _parse_switches(switches_reduce)
concat = range(2, 6)
genotype = Genotype(
normal=gene_normal, normal_concat=concat,
reduce=gene_reduce, reduce_concat=concat
)
return genotype
C is the number of initial channels, layers the number of cells, steps the number of intermediate nodes inside a cell, multiplier the output-channel multiplier, and stem_multiplier the stem-channel multiplier.
switch_ons records how many candidate operations are still selectable at each position; self.switch_on simply takes the count of the first position.
def __init__(self, C, num_classes, layers, criterion, steps=4, multiplier=4, stem_multiplier=3, switches_normal=[], switches_reduce=[], p=0.0):
super(Network, self).__init__()
self._C = C
self._num_classes = num_classes
self._layers = layers
self._criterion = criterion
self._steps = steps
self._multiplier = multiplier
self.p = p
self.switches_normal = switches_normal
switch_ons = []
for i in range(len(switches_normal)):
ons = 0
for j in range(len(switches_normal[i])):
if switches_normal[i][j]:
ons = ons + 1
switch_ons.append(ons)
ons = 0
self.switch_on = switch_ons[0]
The network does not downsample at the stem; reduction cells are inserted at 1/3 and 2/3 of the depth.
C_curr = stem_multiplier*C
self.stem = nn.Sequential(
nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
nn.BatchNorm2d(C_curr)
)
C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
self.cells = nn.ModuleList()
reduction_prev = False
for i in range(layers):
if i in [layers//3, 2*layers//3]:
C_curr *= 2
reduction = True
cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, switches_reduce, self.p)
else:
reduction = False
cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, switches_normal, self.p)
# cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, switches)
reduction_prev = reduction
self.cells += [cell]
C_prev_prev, C_prev = C_prev, multiplier*C_curr
_initialize_alphas initializes the architecture parameters as Variable objects rather than torch.nn.Parameter.
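The practical consequence of that choice can be illustrated with a toy module (illustrative only): an nn.Parameter is registered and visible to parameters(), while a plain tensor attribute is not, which is why the architecture parameters are kept in their own list and trained by a separate Adam optimizer.

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super(Toy, self).__init__()
        self.fc = nn.Linear(4, 4)
        # Registered: appears in parameters() and in the state_dict
        self.alpha_param = nn.Parameter(1e-3 * torch.randn(2, 3))
        # Not registered: a plain tensor attribute is invisible to parameters()
        self.alpha_plain = torch.randn(2, 3, requires_grad=True)

toy = Toy()
print([name for name, _ in toy.named_parameters()])
# ['fc.weight', 'fc.bias', 'alpha_param']  -- alpha_plain is absent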
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self._initialize_alphas()
Different cells of the same type share the architecture parameters.
s0 = s1 = self.stem(input)
for i, cell in enumerate(self.cells):
if cell.reduction:
if self.alphas_reduce.size(1) == 1:
weights = F.softmax(self.alphas_reduce, dim=0)
else:
weights = F.softmax(self.alphas_reduce, dim=-1)
else:
if self.alphas_normal.size(1) == 1:
weights = F.softmax(self.alphas_normal, dim=0)
else:
weights = F.softmax(self.alphas_normal, dim=-1)
s0, s1 = s1, cell(s0, s1, weights)
out = self.global_pooling(s1)
logits = self.classifier(out.view(out.size(0),-1))
return logits
update_p is what makes data parallelism troublesome.
for cell in self.cells:
cell.p = self.p
cell.update_p()
The following loss helper is never used.
logits = self(input)
return self._criterion(logits, target)
k is the number of MixedOps in a cell; self.switch_on is the number of candidate operations inside each MixedOp.
k = sum(1 for i in range(self._steps) for n in range(2+i))
num_ops = self.switch_on
self.alphas_normal = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)
self.alphas_reduce = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)
self._arch_parameters = [
self.alphas_normal,
self.alphas_reduce,
]
return self._arch_parameters
FactorizedReduce uses two convolutions applied to spatially interleaved positions.
As in NASNet, AmoebaNet and PNAS, convolutions follow the ReLU-Conv-BN order (ReLUConvBN).
Weights are not initialized manually.
With steps=4, a Cell contains 2+3+4+5=14 MixedOps, i.e. len(self.cell_ops)=14; every intermediate node has two extra incoming edges that take the two cell inputs.
def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev, switches, p):
super(Cell, self).__init__()
self.reduction = reduction
self.p = p
if reduction_prev:
self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False)
else:
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
self._steps = steps
self._multiplier = multiplier
self.cell_ops = nn.ModuleList()
switch_count = 0
for i in range(self._steps):
for j in range(2+i):
stride = 2 if reduction and j < 2 else 1
op = MixedOp(C, stride, switch=switches[switch_count], p=self.p)
self.cell_ops.append(op)
switch_count = switch_count + 1
for op in self.cell_ops:
op.p = self.p
op.update_p()
Each intermediate node is computed from all of its predecessors:
$$x^{(j)} = \sum_{i<j} o^{(i,j)}(x^{(i)})$$
A special $\mathit{zero}$ operation is also included to indicate the absence of a connection between two nodes. The task of learning the cell therefore reduces to learning the operations on its edges.
At each step, the outputs of all incoming edges are summed. offset keeps growing by the number of states, which is why self.cell_ops holds 2+3+4+5=14 MixedOps.
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
offset = 0
for i in range(self._steps):
s = sum(self.cell_ops[offset+j](h, weights[offset+j]) for j, h in enumerate(states))
offset += len(states)
states.append(s)
return torch.cat(states[-self._multiplier:], dim=1)
OPS is the dictionary of operations. affine=False disables the learnable parameters of nn.BatchNorm2d, making it equivalent to the BN layer in Caffe.
DARTS appendix A.1.1 points out that, since the architecture keeps changing during the search, batch normalization always uses batch-specific statistics rather than the global moving averages, and the learnable affine parameters of every batch normalization are disabled during the search to avoid rescaling the outputs of the candidate operations. The code, however, does not set track_running_stats=False.
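For reference, a sketch of the two flags on the standard torch.nn.BatchNorm2d:

import torch.nn as nn

# affine=False: no learnable gamma/beta, but running mean/var are still tracked,
# so eval() would use the moving averages.
bn_search = nn.BatchNorm2d(16, affine=False)

# To always normalize with batch statistics, as DARTS A.1.1 describes,
# track_running_stats=False would also be needed.
bn_batch_stats = nn.BatchNorm2d(16, affine=False, track_running_stats=False)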
switch is the mask over operations, with len(switch)=len(PRIMITIVES). PRIMITIVES contains 8 operations; the enabled ones are instantiated and stored in self.m_ops.
def __init__(self, C, stride, switch, p):
super(MixedOp, self).__init__()
self.m_ops = nn.ModuleList()
self.p = p
for i in range(len(switch)):
if switch[i]:
primitive = PRIMITIVES[i]
op = OPS[primitive](C, stride, False)
if 'pool' in primitive:
op = nn.Sequential(op, nn.BatchNorm2d(C, affine=False))
if isinstance(op, Identity) and p > 0:
op = nn.Sequential(op, nn.Dropout(self.p))
self.m_ops.append(op)
update_p: if an op is a Sequential whose first module is Identity, the dropout probability of the module that follows it is updated.
for op in self.m_ops:
if isinstance(op, nn.Sequential):
if isinstance(op[0], Identity):
op[1].p = self.p
Let $\mathcal{O}$ be a set of candidate operations (e.g. convolution, max pooling, $\mathit{zero}$), where each operation is a function $o(\cdot)$ applied to $x^{(i)}$.
To make the search space continuous, DARTS relaxes the categorical choice of a particular operation to a softmax over all possible operations:
$$\bar{o}^{(i,j)}(x) = \sum_{o \in \mathcal{O}} \frac{\exp(\alpha_o^{(i,j)})}{\sum_{o' \in \mathcal{O}} \exp(\alpha_{o'}^{(i,j)})}\, o(x)$$
where the operation mixing weights for a pair of nodes $(i,j)$ are parameterized by a vector $\alpha^{(i,j)}$ of dimension $|\mathcal{O}|$.
The task of architecture search then reduces to learning a set of continuous variables $\alpha = \{\alpha^{(i,j)}\}$. At the end of the search, a discrete architecture is obtained by replacing each mixed operation $\bar{o}^{(i,j)}$ with the most likely operation, i.e. $o^{(i,j)} = \mathrm{argmax}_{o \in \mathcal{O}} \, \alpha^{(i,j)}_o$.
return sum(w * op(x) for w, op in zip(weights, self.m_ops))
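The discretization described by the argmax rule can be sketched in a few lines (illustrative; the repository performs this step through parse_network over the switch lists rather than a direct argmax):

import torch
import torch.nn.functional as F

# Toy architecture parameters: 14 connections x 8 candidate ops (column 0 is the Zero op)
alphas = 1e-3 * torch.randn(14, 8)
probs = F.softmax(alphas, dim=-1)

# For each connection keep the most likely non-Zero operation,
# as DARTS does when deriving the discrete genotype.
best_op = probs[:, 1:].argmax(dim=-1) + 1   # +1 restores the index into PRIMITIVES
print(best_op)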
Because the model defines and relies on methods other than forward, it cannot be used with torch.nn.DataParallel as-is.
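A sketch of the issue, assuming the Network and switch lists defined earlier in this walkthrough: once the model is wrapped, custom methods are only reachable through .module.

import torch.nn as nn

model = Network(16, 10, 5, criterion,
                switches_normal=switches_normal, switches_reduce=switches_reduce)
parallel_model = nn.DataParallel(model).cuda()

# forward() is dispatched across GPUs as usual:
#   logits = parallel_model(input)
# but methods such as update_p() or arch_parameters() live on the wrapped module:
parallel_model.module.update_p()
alphas = parallel_model.module.arch_parameters()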
The helper _get_sk_idx is defined as a nested function. If the given row of switches_in has no skip connection it returns -1; otherwise it returns the skip connection's index within the active operations of the original list switches_bk.
def _get_sk_idx(switches_in, switches_bk, k):
if not switches_in[k][3]:
idx = -1
else:
idx = 0
for i in range(3):
if switches_bk[k][i]:
idx = idx + 1
return idx
To avoid modifying the input, the probabilities are deep-copied. sk_prob records the skip-connection weight at each position; the smallest one is then set to 0.
probs_out = copy.deepcopy(probs_in)
sk_prob = [1.0 for i in range(len(switches_bk))]
for i in range(len(switches_in)):
idx = _get_sk_idx(switches_in, switches_bk, i)
if not idx == -1:
sk_prob[i] = probs_out[i][idx]
d_idx = np.argmin(sk_prob)
idx = _get_sk_idx(switches_in, switches_bk, d_idx)
probs_out[d_idx][idx] = 0.0
return probs_out
For each position, idxs records the indices of the operations that are still selectable. get_min_k_no_zero picks the two lowest-probability operations (dropping the Zero op first if it is still present) and disables them, leaving a single operation.
switches = copy.deepcopy(switches_in)
for i in range(len(switches)):
idxs = []
for j in range(len(PRIMITIVES)):
if switches[i][j]:
idxs.append(j)
drop = get_min_k_no_zero(probs[i, :], idxs, 2)
for idx in drop:
switches[i][idxs[idx]] = False
return switches
final_prob holds the maximum operation probability at each position.
switches = copy.deepcopy(switches_in)
final_prob = [0.0 for i in range(len(switches))]
for i in range(len(switches)):
final_prob[i] = max(probs[i])
The first node has only two incoming positions, so both are kept directly.
For each of the following three nodes, the maximum probabilities are sorted and the two largest positions are kept.
keep = [0, 1]
n = 3
start = 2
for i in range(3):
end = start + n
tb = final_prob[start:end]
edge = sorted(range(n), key=lambda x: tb[x])
keep.append(edge[-1] + start)
keep.append(edge[-2] + start)
start = end
n = n + 1
Iterate over the positions and, in switches, disable every position that was not selected.
for i in range(len(switches)):
if not i in keep:
for j in range(len(PRIMITIVES)):
switches[i][j] = False
return switches
for i in range(len(switches)):
ops = []
for j in range(len(switches[i])):
if switches[i][j]:
ops.append(PRIMITIVES[j])
logging.info(ops)