本教程依赖 NNI 库,请提前安装:pip install nni。
依赖库版本限制:pip install torchmetrics==0.10、pip install pytorch-lightning==1.9.4。
加载 CIFAR-10 数据集。注意:用 nni.trace 包装 CIFAR10,并使用 nni.retiarii.evaluator.pytorch 中的 DataLoader,对于 multi-trial strategies(多试验策略)是必须的。
import nni
import torch
from torchvision import transforms
from torchvision.datasets import CIFAR10
from nni.retiarii.evaluator.pytorch import DataLoader
# Per-channel normalization statistics for CIFAR-10 (one value per colour channel).
CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]

# Evaluation preprocessing: no augmentation, only conversion and normalization.
# ToTensor() must come before Normalize(), which only accepts tensors.
transform_valid = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])

# CIFAR-10 test split (train=False), wrapped with nni.trace.
valid_data = nni.trace(CIFAR10)(root='./data', train=False, download=True, transform=transform_valid)
valid_loader = DataLoader(valid_data, batch_size=256, num_workers=6)
from nni.retiarii.hub.pytorch import DARTS as DartsSpace

# Fetch the published 'darts-v2' architecture from the NNI model hub together with
# its pretrained weights (download=True permits downloading them).
darts_v2_model = DartsSpace.load_searched_model(
    'darts-v2',
    pretrained=True,
    download=True,
)
def evaluate_model(model, cuda=False, loader=None):
    """Compute and print the top-1 accuracy of ``model``.

    Args:
        model: A ``torch.nn.Module`` mapping a batch of inputs to per-class logits.
        cuda: Evaluate on the ``'cuda'`` device when True, otherwise on the CPU.
        loader: Iterable of ``(inputs, targets)`` batches. Defaults to the
            module-level ``valid_loader`` (the CIFAR-10 test split).

    Returns:
        The fraction of samples whose highest-scoring class equals the target.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    if loader is None:
        loader = valid_loader
    device = torch.device('cuda' if cuda else 'cpu')
    # Inputs are moved to ``device`` below, so the model must live there too;
    # otherwise the forward pass fails with a device mismatch.
    model.to(device)
    # Switch layers such as BatchNorm/Dropout to inference behaviour; without
    # this the reported accuracy reflects training-mode behaviour.
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)
            _, predict = torch.max(logits, 1)
            correct += (predict == targets).sum().item()
            total += targets.size(0)
    if total == 0:
        raise ValueError('evaluate_model: the loader yielded no samples')
    accuracy = correct / total
    print('Accuracy:', accuracy)
    return accuracy
# Use the GPU only when one is available, so the script also runs on CPU-only machines.
evaluate_model(darts_v2_model, cuda=torch.cuda.is_available())
from nni.retiarii.hub.pytorch import DARTS as DartsSpace

# Load the pretrained model: the searched 'darts-v2' architecture plus its weights.
darts_v2_model = DartsSpace.load_searched_model(
    'darts-v2', download=True, pretrained=True,
)
# Evaluate the model.
def evaluate_model(model, cuda=False, loader=None):
    """Compute and print the top-1 accuracy of ``model``.

    Args:
        model: A ``torch.nn.Module`` mapping a batch of inputs to per-class logits.
        cuda: Evaluate on the ``'cuda'`` device when True, otherwise on the CPU.
        loader: Iterable of ``(inputs, targets)`` batches. Defaults to the
            module-level ``valid_loader`` (the CIFAR-10 test split).

    Returns:
        The fraction of samples whose highest-scoring class equals the target.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    if loader is None:
        loader = valid_loader
    device = torch.device('cuda' if cuda else 'cpu')
    # Move the model to the same device as the data; otherwise the forward pass
    # fails with a device mismatch.
    model.to(device)
    # Put the model in evaluation mode (inference behaviour for BatchNorm/Dropout).
    model.eval()
    correct = total = 0
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for inputs, targets in loader:
            # Move the batch to the evaluation device.
            inputs, targets = inputs.to(device), targets.to(device)
            # Model output: per-class logits.
            logits = model(inputs)
            _, predict = torch.max(logits, 1)
            # Count correct predictions and the number of samples seen.
            correct += (predict == targets).sum().item()
            total += targets.size(0)
    if total == 0:
        raise ValueError('evaluate_model: the loader yielded no samples')
    # Print the accuracy.
    accuracy = correct / total
    print('Accuracy:', accuracy)
    return accuracy
# Use the GPU only when one is available, so the script also runs on CPU-only machines.
evaluate_model(darts_v2_model, cuda=torch.cuda.is_available())
输出:Accuracy: 0.9737
在 DARTS 模型空间中,完整的模型是通过重复堆叠单个计算单元(称为 cell)来构建的。网络中有两种类型的单元:第一种称为普通单元(normal cell),第二种称为缩减单元(reduction cell)。两者的主要区别在于,缩减单元会对输入特征图进行下采样,从而降低其分辨率。普通单元和缩减单元交替堆叠,如下图所示。下面使用 DartsSpace 构建 DARTS 模型空间
# Debug switch. It is presumably forwarded to the evaluator for a quick smoke-test
# run; TODO confirm where it is consumed.
fast_dev_run = True
# DARTS model space for CIFAR, in a small configuration (width=16, 8 cells).
model_space = DartsSpace(
    width=16,
    num_cells=8,
    dataset='cifar',
)
……用作起点。
import numpy as np
from nni.retiarii.evaluator.pytorch import Classification
from torch.utils.data import SubsetRandomSampler
# Training-time image preprocessing: random-crop augmentation, then conversion to a
# tensor (required before Normalize) and per-channel normalization.
transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])

# Download the CIFAR-10 training split, wrapped with nni.trace.
train_data = nni.trace(CIFAR10)(root='./data', train=True, download=True, transform=transform)

num_samples = len(train_data)
# Shuffle the sample indices once, then split them in half: the first half feeds
# training during the search, the second half feeds validation.
indices = np.random.permutation(num_samples)
split = num_samples // 2

# SubsetRandomSampler samples, without replacement, only from the given index list,
# so the two loaders see disjoint halves of the training set.
search_train_loader = DataLoader(
    train_data, batch_size=64, num_workers=6,
    sampler=SubsetRandomSampler(indices[:split]),
)
# Validation data loader.
search_valid_loader = DataLoader(
    train_data, batch_size=64, num_workers=6,
    sampler=SubsetRandomSampler(indices[split:]),
)

# Evaluator used during the search.
# NOTE(review): the arguments after learning_rate were lost when this file was
# extracted. The loaders and fast_dev_run are the values defined above;
# max_epochs=10 is an assumption taken from the NNI DARTS tutorial. Confirm.
evaluator = Classification(
    learning_rate=1e-3,
    train_dataloaders=search_train_loader,
    val_dataloaders=search_valid_loader,
    max_epochs=10,
    fast_dev_run=fast_dev_run,
)
DARTS(可微分架构搜索)被用作探索模型空间的搜索策略。DARTS 策略将架构搜索与模型训练结合到一次运行中。与多试验(multi-trial)策略相比,one-shot NAS 不需要迭代地产生新的试验(即模型),从而节省了大量模型训练成本。
from nni.retiarii.strategy import DARTS as DartsStrategy
# One-shot DARTS strategy: search and supernet training share a single run.
strategy = DartsStrategy()
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
config = RetiariiExeConfig(execution_engine='oneshot')
experiment = RetiariiExperiment(model_space, evaluator=evaluator, strategy=strategy)
# Launch the search with the one-shot config built above. Without this call the
# config is unused and there is no searched architecture to export.
experiment.run(config)

# The exported model is a dict (called an "architecture dict") describing the
# chosen normal cell and reduction cell.
exported_arch = experiment.export_top_models()[0]
# NOTE(review): this appears to be a sample value of `exported_arch` as printed
# after one search run; search results may differ between runs. Confirm.
# As a bare expression statement, it is evaluated and discarded (nothing is bound).
# For each cell ('normal', 'reduce'), node i in 2..5 and input slot j in 0..1:
#   '<cell>/op_<i>_<j>'    -> name of the operation on that edge
#   '<cell>/input_<i>_<j>' -> index of the source node; plot_single_cell draws
#                             0 as c_{k-2}, 1 as c_{k-1}, and other values as
#                             that intermediate node.
{'normal/op_2_0': 'skip_connect',
'normal/input_2_0': 0,
'normal/op_2_1': 'dil_conv_3x3',
'normal/input_2_1': 1,
'normal/op_3_0': 'sep_conv_3x3',
'normal/input_3_0': 2,
'normal/op_3_1': 'avg_pool_3x3',
'normal/input_3_1': 0,
'normal/op_4_0': 'dil_conv_5x5',
'normal/input_4_0': 0,
'normal/op_4_1': 'dil_conv_5x5',
'normal/input_4_1': 1,
'normal/op_5_0': 'sep_conv_3x3',
'normal/input_5_0': 2,
'normal/op_5_1': 'dil_conv_5x5',
'normal/input_5_1': 0,
'reduce/op_2_0': 'dil_conv_3x3',
'reduce/input_2_0': 1,
'reduce/op_2_1': 'max_pool_3x3',
'reduce/input_2_1': 0,
'reduce/op_3_0': 'sep_conv_3x3',
'reduce/input_3_0': 0,
'reduce/op_3_1': 'sep_conv_3x3',
'reduce/input_3_1': 1,
'reduce/op_4_0': 'dil_conv_3x3',
'reduce/input_4_0': 0,
'reduce/op_4_1': 'dil_conv_5x5',
'reduce/input_4_1': 3,
'reduce/op_5_0': 'sep_conv_5x5',
'reduce/input_5_0': 4,
'reduce/op_5_1': 'sep_conv_3x3',
'reduce/input_5_1': 0}
import io
import graphviz
import matplotlib.pyplot as plt
from PIL import Image
def plot_single_cell(arch_dict, cell_name):
    """Render one cell of an architecture dict as a PIL image.

    Args:
        arch_dict: Architecture dict containing ``'<cell_name>/op_<i>_<j>'``
            (operation name) and ``'<cell_name>/input_<i>_<j>'`` (source node
            index) entries for nodes ``i`` in 2..5 and input slots ``j`` in 0..1.
        cell_name: Which cell to draw, e.g. ``'normal'`` or ``'reduce'``.

    Returns:
        A ``PIL.Image.Image`` of the rendered graph.

    Raises:
        ValueError: If ``arch_dict`` has an odd number of entries.
        KeyError: If an expected ``op``/``input`` key is missing.
    """
    g = graphviz.Digraph(
        node_attr=dict(style='filled', shape='rect', align='center'),
        # graphviz renders PDF by default, which PIL cannot decode; request PNG.
        format='png',
    )
    # Source indices 0 and 1 are drawn as the two cell inputs c_{k-2} and c_{k-1}.
    g.node('c_{k-2}', fillcolor='darkseagreen2')
    g.node('c_{k-1}', fillcolor='darkseagreen2')
    # Entries come in op/input pairs.
    if len(arch_dict) % 2 != 0:
        raise ValueError('arch_dict must contain op/input entries in pairs')

    for i in range(2, 6):
        g.node(str(i), fillcolor='lightblue')

    for i in range(2, 6):
        for j in range(2):
            op = arch_dict[f'{cell_name}/op_{i}_{j}']
            from_ = arch_dict[f'{cell_name}/input_{i}_{j}']
            if from_ == 0:
                u = 'c_{k-2}'
            elif from_ == 1:
                u = 'c_{k-1}'
            else:
                # Any other index refers to an earlier intermediate node.
                u = str(from_)
            g.edge(u, str(i), label=op, fillcolor='gray')

    # Every intermediate node is connected to the cell output c_{k}.
    g.node('c_{k}', fillcolor='palegoldenrod')
    for i in range(2, 6):
        g.edge(str(i), 'c_{k}', fillcolor='gray')

    g.attr(label=f'{cell_name.capitalize()} cell')

    return Image.open(io.BytesIO(g.pipe()))
def plot_double_cells(arch_dict):
    """Display the normal and reduction cells of ``arch_dict`` side by side.

    Args:
        arch_dict: Architecture dict accepted by ``plot_single_cell`` that has
            both ``'normal/...'`` and ``'reduce/...'`` entries.
    """
    image1 = plot_single_cell(arch_dict, 'normal')
    image2 = plot_single_cell(arch_dict, 'reduce')
    # PIL ``size`` is (width, height); size the figure for the taller image.
    height_ratio = max(image1.size[1] / image1.size[0], image2.size[1] / image2.size[0])
    _, axs = plt.subplots(1, 2, figsize=(20, 10 * height_ratio))
    # Draw each rendered cell on its own axis, without tick marks or frames.
    for ax, image in zip(axs, (image1, image2)):
        ax.imshow(image)
        ax.axis('off')
    plt.show()
from nni.retiarii import fixed_arch
# Instantiate the DARTS space with every choice fixed to the exported architecture,
# giving a concrete model.
with fixed_arch(exported_arch):
    final_model = DartsSpace(width=16, num_cells=8, dataset='cifar')

# Use the full original training set (the search used two disjoint halves of it).
# shuffle=True so each epoch visits the samples in a new order; DataLoader does not
# shuffle by default.
train_loader = DataLoader(train_data, batch_size=96, num_workers=6, shuffle=True)

max_epochs = 100
# Keep the previous (truncated) name as an alias for any code that still reads it.
ax_epochs = max_epochs
evaluator = Classification(