The previous article covered how to train an MLP; this one shows how to train ResNets with VeLO.
The post is based on https://colab.research.google.com/drive/1-ms12IypE-EdDSNjhFMdRdBbMnH94zpH#scrollTo=RQBACAPQZyB-, and walks through using a learned optimizer from the VeLO family:
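As a quick preview: VeLO ships in the learned_optimization package and is exposed through prefab.optax_lopt as an optax-style optimizer, so it can slot in wherever optax.sgd or optax.adam would. A minimal sketch (assuming learned_optimization is installed; the step count here is arbitrary):

from learned_optimization.research.general_lopt import prefab

opt = prefab.optax_lopt(1000)  # VeLO is told its total step budget up front

The one wrinkle, covered below, is that VeLO's update step also expects the current loss value.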
#@title imports, configuration, and model classes
from datetime import datetime
from functools import partial
from typing import Any, Callable, Sequence, Tuple
from flax import linen as nn
import jax
import jax.numpy as jnp
from jaxopt import loss
from jaxopt import OptaxSolver
from jaxopt import tree_util
from learned_optimization.research.general_lopt import prefab  # provides VeLO
import optax
import tensorflow_datasets as tfds
import tensorflow as tf
# Datasets this notebook can use
dataset_names = [
"mnist", "kmnist", "emnist", "fashion_mnist", "cifar10", "cifar100"
]
L2REG = 1e-4
LEARNING_RATE = .2
EPOCHS = 10
MOMENTUM = .9
DATASET = 'cifar100' #@param [ "mnist", "kmnist", "emnist", "fashion_mnist", "cifar10", "cifar100"]
MODEL = 'resnet18' #@param ["resnet1", "resnet18", "resnet34"]
TRAIN_BATCH_SIZE = 256
TEST_BATCH_SIZE = 1024
# Load the dataset
def load_dataset(split, *, is_training, batch_size):
version = 3
ds, ds_info = tfds.load(
f"{DATASET}:{version}.*.*",
      as_supervised=True,  # yield (image, label) tuples instead of feature dicts
split=split,
with_info=True)
ds = ds.cache().repeat()
if is_training:
ds = ds.shuffle(10 * batch_size, seed=0)
ds = ds.batch(batch_size)
return iter(tfds.as_numpy(ds)), ds_info
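# Hypothetical usage sketch (the real calls appear in train() below): each call
# returns an infinite numpy iterator plus the tfds metadata, e.g.
#   train_ds, info = load_dataset("train", is_training=True, batch_size=256)
#   images, labels = next(train_ds)  # images: uint8 (256, H, W, C); labels: (256,)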
class ResNetBlock(nn.Module):
"""ResNet block."""
filters: int
conv: Any
norm: Any
act: Callable
strides: Tuple[int, int] = (1, 1)
@nn.compact
  def __call__(self, x):
residual = x
y = self.conv(self.filters, (3, 3), self.strides)(x)
y = self.norm()(y)
y = self.act(y)
y = self.conv(self.filters, (3, 3))(y)
y = self.norm(scale_init=nn.initializers.zeros)(y)
if residual.shape != y.shape:
residual = self.conv(self.filters, (1, 1),
self.strides, name='conv_proj')(residual)
residual = self.norm(name='norm_proj')(residual)
return self.act(residual + y)
class ResNet(nn.Module):
"""ResNetV1."""
stage_sizes: Sequence[int]
block_cls: Any
num_classes: int
num_filters: int = 64
dtype: Any = jnp.float32
act: Callable = nn.relu
@nn.compact
def __call__(self, x, train: bool = True):
conv = partial(nn.Conv, use_bias=False, dtype=self.dtype)
norm = partial(nn.BatchNorm,
use_running_average=not train,
momentum=0.9,
epsilon=1e-5,
dtype=self.dtype)
x = conv(self.num_filters, (7, 7), (2, 2),
padding=[(3, 3), (3, 3)],
name='conv_init')(x)
x = norm(name='bn_init')(x)
x = nn.relu(x)
x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
for i, block_size in enumerate(self.stage_sizes):
for j in range(block_size):
strides = (2, 2) if i > 0 and j == 0 else (1, 1)
x = self.block_cls(self.num_filters * 2 ** i,
strides=strides,
conv=conv,
norm=norm,
act=self.act)(x)
x = jnp.mean(x, axis=(1, 2))
x = nn.Dense(self.num_classes, dtype=self.dtype)(x)
x = jnp.asarray(x, self.dtype)
return x
# ResNet defines no __init__ because flax.linen.Module subclasses are Python
# dataclasses: the annotated class attributes above (stage_sizes, block_cls, ...)
# automatically become constructor arguments, and __call__ can read them via
# self. (A minimal illustration follows the ResNet aliases below.)
ResNet1 = partial(ResNet, stage_sizes=[1], block_cls=ResNetBlock)
ResNet18 = partial(ResNet, stage_sizes=[2, 2, 2, 2], block_cls=ResNetBlock)
ResNet34 = partial(ResNet, stage_sizes=[3, 4, 6, 3], block_cls=ResNetBlock)
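As promised above, here is a minimal sketch of why no __init__ is needed. The Tiny module below is hypothetical and not used by the training code; it only illustrates that flax.linen.Module subclasses are dataclasses, so annotated class attributes become constructor arguments:

class Tiny(nn.Module):
  features: int            # becomes a constructor argument automatically
  act: Callable = nn.relu  # defaults work as in any dataclass

  @nn.compact
  def __call__(self, x):
    return self.act(nn.Dense(self.features)(x))

tiny = Tiny(features=8)  # valid even though Tiny defines no __init__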
#@title training loop definition (run this cell to launch training)
import functools
from typing import Any
from typing import NamedTuple
import jax
import jax.numpy as jnp
from jaxopt._src import base
from jaxopt._src import tree_util
# This named tuple just carries the solver's state between iterations.
class OptaxState(NamedTuple):
"""Named tuple containing state information."""
  iter_num: int  # number of iterations performed so far
  value: float  # current value of the loss
  error: float  # l2 norm of the gradient (optimality error)
internal_state: NamedTuple
aux: Any
# We reimplement jaxopt's OptaxSolver.update method (as lopt_update) so that the
# current loss value can be passed to VeLO, which expects it via extra_args.
def lopt_update(self,
params: Any,
state: NamedTuple,
*args,
**kwargs) -> base.OptStep:
"""Performs one iteration of the optax solver.
Args:
    params: pytree containing the model parameters.
state: named tuple containing the solver state.
*args: additional positional arguments to be passed to ``fun``.
**kwargs: additional keyword arguments to be passed to ``fun``.
Returns:
(params, state)
"""
if self.pre_update:
params, state = self.pre_update(params, state, *args, **kwargs)
(value, aux), grad = self._value_and_grad_fun(params, *args, **kwargs)
  # Note: the only difference between this function and the stock
  # jaxopt.OptaxSolver.update is that `extra_args` is passed to opt.update.
  # If you would like to use a different optimizer, you will likely need to
  # remove these extra_args.
delta, opt_state = self.opt.update(grad, state.internal_state, params, extra_args={"loss": value})
params = self._apply_updates(params, delta)
# Computes optimality error before update to re-use grad evaluation.
new_state = OptaxState(iter_num=state.iter_num + 1,
error=tree_util.tree_l2_norm(grad),
value=value,
aux=aux,
internal_state=opt_state)
return base.OptStep(params=params, state=new_state)
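# For comparison, the stock jaxopt update step would end with
#   delta, opt_state = self.opt.update(grad, state.internal_state, params)
# i.e. without extra_args; VeLO additionally conditions on the current loss,
# hence the extra_args={"loss": value} above.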
def train():
# Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
# it unavailable to JAX.
# tf.config.experimental.set_visible_devices([], 'GPU')
# typical data loading and iterator setup
train_ds, ds_info = load_dataset("train", is_training=True,
batch_size=TRAIN_BATCH_SIZE)
test_ds, _ = load_dataset("test", is_training=False,
batch_size=TEST_BATCH_SIZE)
input_shape = (1,) + ds_info.features["image"].shape
num_classes = ds_info.features["label"].num_classes
iter_per_epoch_train = ds_info.splits['train'].num_examples // TRAIN_BATCH_SIZE
iter_per_epoch_test = ds_info.splits['test'].num_examples // TEST_BATCH_SIZE
# Set up model.
if MODEL == "resnet1":
net = ResNet1(num_classes=num_classes)
elif MODEL == "resnet18":
net = ResNet18(num_classes=num_classes)
elif MODEL == "resnet34":
net = ResNet34(num_classes=num_classes)
else:
raise ValueError("Unknown model.")
def predict(params, inputs, aux, train=False):
x = inputs.astype(jnp.float32) / 255.
all_params = {"params": params, "batch_stats": aux}
if train:
# Returns logits and net_state (which contains the key "batch_stats").
return net.apply(all_params, x, train=True, mutable=["batch_stats"])
else:
# Returns logits only.
return net.apply(all_params, x, train=False)
logistic_loss = jax.vmap(loss.multiclass_logistic_loss)
def loss_from_logits(params, l2reg, logits, labels):
mean_loss = jnp.mean(logistic_loss(labels, logits))
sqnorm = tree_util.tree_l2_norm(params, squared=True)
return mean_loss + 0.5 * l2reg * sqnorm
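  # In symbols, the regularized objective is
  #   L(params) = (1/n) * sum_i logistic_loss(y_i, logits_i) + (l2reg / 2) * ||params||^2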
def accuracy_and_loss(params, l2reg, data, aux):
inputs, labels = data
logits = predict(params, inputs, aux)
accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
loss = loss_from_logits(params, l2reg, logits, labels)
return accuracy, loss
def loss_fun(params, l2reg, data, aux):
inputs, labels = data
logits, net_state = predict(params, inputs, aux, train=True)
loss = loss_from_logits(params, l2reg, logits, labels)
# batch_stats will be stored in state.aux
return loss, net_state["batch_stats"]
  # For reference, jaxopt's default SGD-with-momentum baseline would be:
# opt = optax.sgd(learning_rate=LEARNING_RATE,
# momentum=MOMENTUM,
# nesterov=True)
NUM_STEPS = EPOCHS * iter_per_epoch_train
opt = prefab.optax_lopt(NUM_STEPS)
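  # VeLO conditions on training progress, so the total step budget must be given
  # at construction time; changing NUM_STEPS changes the learned schedule.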
# We need has_aux=True because loss_fun returns batch_stats.
solver = OptaxSolver(opt=opt,
fun=jax.value_and_grad(loss_fun, has_aux=True),
                       maxiter=NUM_STEPS,
has_aux=True,
value_and_grad=True)
  # Initialize parameters.
rng = jax.random.PRNGKey(0)
  # net.init returns a variable dict: init_vars["params"] holds the trainable
  # weights and init_vars["batch_stats"] the BatchNorm running statistics.
  init_vars = net.init(rng, jnp.zeros(input_shape), train=True)
params = init_vars["params"]
batch_stats = init_vars["batch_stats"]
start = datetime.now().replace(microsecond=0)
  # Run training loop.
  state = solver.init_state(params, L2REG, next(test_ds), batch_stats)  # initialize the solver state
  # jax.jit compiles the update step with XLA so repeated calls are fast;
  # functools.partial binds self=solver so the standalone lopt_update above
  # can be jitted like a plain function.
  jitted_update = jax.jit(functools.partial(lopt_update, self=solver))
print(f'Iterations: {solver.maxiter}')
  for _ in range(solver.maxiter):  # run for the solver's maximum number of iterations
train_minibatch = next(train_ds)
if state.iter_num % iter_per_epoch_train == iter_per_epoch_train - 1:
# Once per epoch evaluate the model on the train and test sets.
test_acc, test_loss = 0., 0.
# make a pass over test set to compute test accuracy
for _ in range(iter_per_epoch_test):
tmp = accuracy_and_loss(params, L2REG, next(test_ds), batch_stats)
test_acc += tmp[0] / iter_per_epoch_test
test_loss += tmp[1] / iter_per_epoch_test
train_acc, train_loss = 0., 0.
# make a pass over train set to compute train accuracy
for _ in range(iter_per_epoch_train):
tmp = accuracy_and_loss(params, L2REG, next(train_ds), batch_stats)
train_acc += tmp[0] / iter_per_epoch_train
train_loss += tmp[1] / iter_per_epoch_train
train_acc = jax.device_get(train_acc)
train_loss = jax.device_get(train_loss)
test_acc = jax.device_get(test_acc)
test_loss = jax.device_get(test_loss)
# time elapsed without microseconds
time_elapsed = (datetime.now().replace(microsecond=0) - start)
print(f"[Epoch {(state.iter_num+1) // (iter_per_epoch_train+1)}/{EPOCHS}] "
f"Train acc: {train_acc:.3f}, train loss: {train_loss:.3f}. "
f"Test acc: {test_acc:.3f}, test loss: {test_loss:.3f}. "
f"Time elapsed: {time_elapsed}")
params, state = jitted_update(params=params,
state=state,
l2reg=L2REG,
data=train_minibatch,
aux=batch_stats)
batch_stats = state.aux
train()
Below are hard-coded accuracy and loss numbers from the baseline and VeLO runs:
#@title Comparing baseline vs. VeLO accuracy on resnet18 cifar100.
baseline_train_acc = [0.235,
0.333,
0.428,
0.430,
0.480,
0.528,
0.591,
0.617,
0.661,
0.709,]
baseline_test_acc = [0.216,
0.298,
0.362,
0.343,
0.359,
0.371,
0.375,
0.377,
0.379,
0.399,]
velo_train_acc = [0.170,
0.270,
0.346,
0.331,
0.466,
0.477,
0.551,
0.749,
0.848,
0.955,]
velo_test_acc = [0.163,
0.255,
0.310,
0.290,
0.377,
0.369,
0.385,
0.458,
0.464,
0.492,]
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
figure(figsize=(8, 6), dpi=80)
plt.plot(range(10), baseline_train_acc, label="Baseline Train Accuracy", c='b', linestyle='dashed')
plt.plot(range(10), baseline_test_acc, label="Baseline Test Accuracy", c='b')
plt.plot(range(10), velo_train_acc, label="VeLO Train Accuracy", c='r', linestyle='dashed')
plt.plot(range(10), velo_test_acc, label="VeLO Test Accuracy", c='r')
plt.xlabel("Training Epochs")
plt.ylabel("Accuracy")
plt.title("Training Accuracy Curves for Resnet18 on Cifar100")
plt.legend()
plt.show()
#@title Comparing baseline vs. VeLO loss on resnet18 cifar100.
baseline_train_loss = [3.470,
2.979,
2.535,
2.567,
2.351,
2.183,
1.970,
1.925,
1.781,
1.644,]
baseline_test_loss = [3.571,
3.206,
2.899,
3.064,
3.055,
3.107,
3.170,
3.447,
3.530,
3.597,]
velo_train_loss = [3.701,
3.071,
2.771,
2.948,
2.294,
2.287,
2.059,
1.268,
0.948,
0.645,]
velo_test_loss = [3.739,
3.188,
2.974,
3.266,
2.797,
2.925,
3.062,
2.769,
2.950,
2.882]
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
figure(figsize=(8, 6), dpi=80)
plt.plot(range(10), baseline_train_loss, label="Baseline Train Loss", c='b', linestyle='dashed')
plt.plot(range(10), baseline_test_loss, label="Baseline Test Loss", c='b')
plt.plot(range(10), velo_train_loss, label="VeLO Train Loss", c='r', linestyle='dashed')
plt.plot(range(10), velo_test_loss, label="VeLO Test Loss", c='r')
plt.xlabel("Training Epochs")
plt.ylabel("Loss")
plt.title("Training Loss Curves for Resnet18 on Cifar100 ")
plt.legend()
plt.show()