Code for VeLO 2: Training Versatile Learned Optimizers by Scaling Up

上一篇文章介绍了如何训练一个 MLP 网络，这篇文章将介绍如何用 VeLO 训练 ResNet。

这篇文章基于 Colab 笔记本 https://colab.research.google.com/drive/1-ms12IypE-EdDSNjhFMdRdBbMnH94zpH#scrollTo=RQBACAPQZyB- ，介绍如何使用 VeLO 系列中的学习型优化器（learned optimizer in the VeLO family）完成：

  • 一个简单的图像识别任务
  • ResNet 网络的训练
#@title imports, configuration, and model classes

from absl import app
from datetime import datetime

from functools import partial
from typing import Any, Callable, Sequence, Tuple

from flax import linen as nn

import jax
import jax.numpy as jnp

from jaxopt import loss
from jaxopt import OptaxSolver
from jaxopt import tree_util

import optax

import tensorflow_datasets as tfds
import tensorflow as tf


# Datasets that can be selected through DATASET below (TFDS dataset names).
dataset_names = [
    "mnist",
    "kmnist",
    "emnist",
    "fashion_mnist",
    "cifar10",
    "cifar100",
]


# Data / model selection (Colab form fields).
DATASET = 'cifar100' #@param [ "mnist", "kmnist", "emnist", "fashion_mnist", "cifar10", "cifar100"]
MODEL = 'resnet18' #@param ["resnet1", "resnet18", "resnet34"]

# Training schedule and batch sizes.
EPOCHS = 10
TRAIN_BATCH_SIZE = 256
TEST_BATCH_SIZE = 1024

# L2 regularisation coefficient used in the training loss.
L2REG = 1e-4
# Hyper-parameters of the optional SGD baseline; they are not passed to VeLO.
LEARNING_RATE = 0.2
MOMENTUM = 0.9


# Dataset loading.
def load_dataset(split, *, is_training, batch_size):
  """Build an endless iterator of batches over one split of DATASET.

  Args:
    split: TFDS split name, e.g. "train" or "test".
    is_training: when True, examples are shuffled (buffer of
      10 * batch_size, fixed seed 0).
    batch_size: number of examples per batch.

  Returns:
    A pair (iterator over NumPy (image, label) batches, TFDS dataset info).
  """
  major_version = 3
  dataset, info = tfds.load(
      f"{DATASET}:{major_version}.*.*",
      split=split,
      as_supervised=True,  # yield (image, label) tuples instead of feature dicts
      with_info=True,
  )
  # Cache decoded examples, then repeat forever; callers decide how many
  # batches to draw.
  dataset = dataset.cache().repeat()
  dataset = dataset.shuffle(10 * batch_size, seed=0) if is_training else dataset
  batched = dataset.batch(batch_size)
  return iter(tfds.as_numpy(batched)), info


class ResNetBlock(nn.Module):
  """Basic residual block: two 3x3 convolutions plus a shortcut connection.

  Attributes:
    filters: output channels of both 3x3 convolutions.
    conv: convolution module constructor, called as conv(features, kernel[, strides]).
    norm: normalisation module constructor.
    act: activation applied after the first norm and after the residual sum.
    strides: strides of the first 3x3 convolution and of the projection shortcut.
  """
  filters: int
  conv: Any
  norm: Any
  act: Callable
  strides: Tuple[int, int] = (1, 1)

  @nn.compact
  def __call__(self, x):
    # Main path. Submodules are created in a fixed order
    # (conv, norm, conv, norm); Flax derives their default names from it.
    out = self.conv(self.filters, (3, 3), self.strides)(x)
    out = self.act(self.norm()(out))
    out = self.conv(self.filters, (3, 3))(out)
    # The second norm starts with a zero scale, so the main path initially
    # contributes only the norm's bias term.
    out = self.norm(scale_init=nn.initializers.zeros)(out)

    # Shortcut: identity unless the shapes differ (stride or channel change),
    # in which case a 1x1 projection followed by a norm is used.
    shortcut = x
    if x.shape != out.shape:
      shortcut = self.conv(self.filters, (1, 1),
                           self.strides, name='conv_proj')(x)
      shortcut = self.norm(name='norm_proj')(shortcut)

    return self.act(shortcut + out)


class ResNet(nn.Module):
  """ResNet v1 classifier: conv stem, stacked residual stages, linear head.

  Attributes:
    stage_sizes: number of blocks per stage. Stage i uses
      ``num_filters * 2**i`` filters and, for i > 0, its first block has
      stride 2.
    block_cls: residual block class instantiated for every block.
    num_classes: number of output logits.
    num_filters: filters in the stem and in the first stage.
    dtype: dtype passed to the conv, norm and dense layers.
    act: activation used in the stem and passed to every residual block.
  """
  stage_sizes: Sequence[int]
  block_cls: Any
  num_classes: int
  num_filters: int = 64
  dtype: Any = jnp.float32
  act: Callable = nn.relu

  @nn.compact
  def __call__(self, x, train: bool = True):
    """Return logits of shape (batch, num_classes).

    Args:
      x: image batch. The final global pool averages axes 1 and 2, so a
        (batch, height, width, channels) layout is assumed.
      train: if True, BatchNorm normalises with batch statistics
        (use_running_average=False); if False it uses the running averages.
    """
    conv = partial(nn.Conv, use_bias=False, dtype=self.dtype)
    norm = partial(nn.BatchNorm,
                   use_running_average=not train,
                   momentum=0.9,
                   epsilon=1e-5,
                   dtype=self.dtype)

    # Stem: 7x7 stride-2 conv, norm, activation, 3x3 stride-2 max-pool.
    x = conv(self.num_filters, (7, 7), (2, 2),
             padding=[(3, 3), (3, 3)],
             name='conv_init')(x)
    x = norm(name='bn_init')(x)
    # Use the configured activation (nn.relu by default) so the stem is
    # consistent with the residual blocks below.
    x = self.act(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

    for i, block_size in enumerate(self.stage_sizes):
      for j in range(block_size):
        # Downsample at the first block of every stage except the first.
        strides = (2, 2) if i > 0 and j == 0 else (1, 1)
        x = self.block_cls(self.num_filters * 2 ** i,
                           strides=strides,
                           conv=conv,
                           norm=norm,
                           act=self.act)(x)

    # Head: global average pool over the spatial axes, then a dense layer.
    x = jnp.mean(x, axis=(1, 2))
    x = nn.Dense(self.num_classes, dtype=self.dtype)(x)
    x = jnp.asarray(x, self.dtype)
    return x


# Flax `linen` Modules are dataclasses: their annotated class attributes
# (stage_sizes, block_cls, num_classes, ...) become constructor keyword
# arguments, so ResNet needs no hand-written __init__. Each partial below fixes
# the architecture; callers still supply num_classes (and may override others).
ResNet1 = partial(ResNet, block_cls=ResNetBlock, stage_sizes=[1])
ResNet18 = partial(ResNet, block_cls=ResNetBlock, stage_sizes=[2, 2, 2, 2])
ResNet34 = partial(ResNet, block_cls=ResNetBlock, stage_sizes=[3, 4, 6, 3])



#@title training loop definition (run this cell to launch training)
import functools
from typing import Any
from typing import Callable
from typing import NamedTuple
from typing import Optional

from dataclasses import dataclass

import jax
import jax.numpy as jnp

from jaxopt._src import base
from jaxopt._src import tree_util


# Solver state produced by lopt_update after every step.
class OptaxState(NamedTuple):
  """Immutable record of the optimizer state between iterations."""
  iter_num: int                # number of update steps taken so far
  value: float                 # objective value computed in the latest step
  error: float                 # L2 norm of the latest gradient
  internal_state: NamedTuple   # state of the wrapped optax optimizer
  aux: Any                     # auxiliary objective output (batch_stats here)

# Replacement for jaxopt's OptaxSolver.update that also hands the loss value to
# the optimizer, which VeLO expects.
def lopt_update(self, params: Any, state: NamedTuple, *args, **kwargs) -> base.OptStep:
  """Perform one solver iteration, forwarding the loss to the optimizer.

  Apart from the ``extra_args={"loss": ...}`` passed to ``self.opt.update``,
  this follows the usual OptaxSolver step. Optimizers whose ``update`` does
  not accept ``extra_args`` need that argument removed.

  Args:
    self: the OptaxSolver instance (its ``pre_update``,
      ``_value_and_grad_fun``, ``opt`` and ``_apply_updates`` are used).
    params: pytree of model parameters.
    state: solver state; ``iter_num`` and ``internal_state`` are read.
    *args: extra positional arguments for the objective (and pre_update).
    **kwargs: extra keyword arguments for the objective (and pre_update).

  Returns:
    ``base.OptStep`` holding the updated parameters and a new OptaxState.
  """
  if self.pre_update:
    params, state = self.pre_update(params, state, *args, **kwargs)

  (loss_value, new_aux), grads = self._value_and_grad_fun(params, *args, **kwargs)
  # The gradient norm reuses this evaluation, i.e. it is measured before the step.
  grad_norm = tree_util.tree_l2_norm(grads)

  updates, next_opt_state = self.opt.update(
      grads, state.internal_state, params, extra_args={"loss": loss_value})
  next_params = self._apply_updates(params, updates)

  next_state = OptaxState(
      iter_num=state.iter_num + 1,
      value=loss_value,
      error=grad_norm,
      internal_state=next_opt_state,
      aux=new_aux,
  )
  return base.OptStep(params=next_params, state=next_state)




def train():
  """Train the configured ResNet with VeLO and print per-epoch metrics.

  Uses the module-level settings MODEL, EPOCHS, L2REG, TRAIN_BATCH_SIZE and
  TEST_BATCH_SIZE (and DATASET through load_dataset). At the last step of
  every epoch, before that step's update, accuracy and loss are averaged over
  one pass of the test split and one pass of the train split and printed.

  Raises:
    ValueError: if MODEL is not "resnet1", "resnet18" or "resnet34".
  """
  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  # tf.config.experimental.set_visible_devices([], 'GPU')

  # Data: endless batch iterators plus dataset metadata.
  train_ds, ds_info = load_dataset("train", is_training=True,
                                   batch_size=TRAIN_BATCH_SIZE)
  test_ds, _ = load_dataset("test", is_training=False,
                            batch_size=TEST_BATCH_SIZE)
  input_shape = (1,) + ds_info.features["image"].shape
  num_classes = ds_info.features["label"].num_classes
  iter_per_epoch_train = ds_info.splits['train'].num_examples // TRAIN_BATCH_SIZE
  iter_per_epoch_test = ds_info.splits['test'].num_examples // TEST_BATCH_SIZE

  # Model.
  model_factories = {
      "resnet1": ResNet1,
      "resnet18": ResNet18,
      "resnet34": ResNet34,
  }
  if MODEL not in model_factories:
    raise ValueError("Unknown model.")
  net = model_factories[MODEL](num_classes=num_classes)

  def predict(params, inputs, aux, train=False):
    """Scale inputs by 1/255 and apply the network."""
    x = inputs.astype(jnp.float32) / 255.
    all_params = {"params": params, "batch_stats": aux}
    if train:
      # Returns logits and net_state (which contains the key "batch_stats").
      return net.apply(all_params, x, train=True, mutable=["batch_stats"])
    # Returns logits only.
    return net.apply(all_params, x, train=False)

  logistic_loss = jax.vmap(loss.multiclass_logistic_loss)

  def loss_from_logits(params, l2reg, logits, labels):
    """Mean multiclass logistic loss plus 0.5 * l2reg * ||params||^2."""
    mean_loss = jnp.mean(logistic_loss(labels, logits))
    sqnorm = tree_util.tree_l2_norm(params, squared=True)
    return mean_loss + 0.5 * l2reg * sqnorm

  def accuracy_and_loss(params, l2reg, data, aux):
    """Eval-mode (accuracy, regularised loss) on one (inputs, labels) batch."""
    inputs, labels = data
    logits = predict(params, inputs, aux)
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
    batch_loss = loss_from_logits(params, l2reg, logits, labels)
    return accuracy, batch_loss

  # Compile evaluation once and reuse it for every batch. load_dataset repeats
  # before batching, so all batches of a split share one shape.
  eval_step = jax.jit(accuracy_and_loss)

  def evaluate(params, batch_stats, ds, num_batches):
    """Average (accuracy, loss) over `num_batches` batches drawn from `ds`."""
    acc, avg_loss = 0., 0.
    for _ in range(num_batches):
      batch_acc, batch_loss = eval_step(params, L2REG, next(ds), batch_stats)
      acc += batch_acc / num_batches
      avg_loss += batch_loss / num_batches
    return jax.device_get(acc), jax.device_get(avg_loss)

  def loss_fun(params, l2reg, data, aux):
    """Training-mode loss; also returns the updated batch_stats as aux."""
    inputs, labels = data
    logits, net_state = predict(params, inputs, aux, train=True)
    batch_loss = loss_from_logits(params, l2reg, logits, labels)
    # batch_stats will be stored in state.aux
    return batch_loss, net_state["batch_stats"]

  # The default optimizer used by jaxopt is commented out here
  # opt = optax.sgd(learning_rate=LEARNING_RATE,
  #                 momentum=MOMENTUM,
  #                 nesterov=True)

  num_steps = EPOCHS * iter_per_epoch_train
  # The VeLO optimizer is constructed with the total number of training steps.
  # NOTE(review): `prefab` is not imported anywhere in the visible source; in
  # the upstream VeLO notebook it comes from
  # `learned_optimization.research.general_lopt` -- confirm it is in scope.
  opt = prefab.optax_lopt(num_steps)

  # value_and_grad=True: `fun` already returns ((loss, aux), grad).
  # has_aux=True: loss_fun returns batch_stats alongside the loss.
  solver = OptaxSolver(opt=opt,
                       fun=jax.value_and_grad(loss_fun, has_aux=True),
                       maxiter=num_steps,
                       has_aux=True,
                       value_and_grad=True)

  # Initialise the network. net.init returns variable collections: "params"
  # holds the trainable weights and "batch_stats" the BatchNorm statistics.
  rng = jax.random.PRNGKey(0)
  init_vars = net.init(rng, jnp.zeros(input_shape), train=True)
  params = init_vars["params"]
  batch_stats = init_vars["batch_stats"]
  start = datetime.now().replace(microsecond=0)

  # The solver state is initialised from one test batch (as in the original
  # script); this consumes that batch from test_ds.
  state = solver.init_state(params, L2REG, next(test_ds), batch_stats)
  # Bind the solver as `self` and compile the whole update step with XLA.
  jitted_update = jax.jit(functools.partial(lopt_update, self=solver))
  print(f'Iterations: {solver.maxiter}')

  for step in range(solver.maxiter):
    train_minibatch = next(train_ds)

    # `step` tracks state.iter_num (assuming init_state starts it at 0; each
    # update adds 1) but is a Python int, so checking it does not wait for
    # the previous asynchronously dispatched update to finish.
    if step % iter_per_epoch_train == iter_per_epoch_train - 1:
      # Once per epoch evaluate the model on the test and train sets.
      test_acc, test_loss = evaluate(params, batch_stats, test_ds,
                                     iter_per_epoch_test)
      train_acc, train_loss = evaluate(params, batch_stats, train_ds,
                                       iter_per_epoch_train)
      # time elapsed without microseconds
      time_elapsed = datetime.now().replace(microsecond=0) - start
      # Steps 0..T-1 form epoch 1, T..2T-1 epoch 2, ... (T = iter_per_epoch_train).
      epoch = (step + 1) // iter_per_epoch_train

      print(f"[Epoch {epoch}/{EPOCHS}] "
            f"Train acc: {train_acc:.3f}, train loss: {train_loss:.3f}. "
            f"Test acc: {test_acc:.3f}, test loss: {test_loss:.3f}. "
            f"Time elapsed: {time_elapsed}")

    params, state = jitted_update(params=params,
                                  state=state,
                                  l2reg=L2REG,
                                  data=train_minibatch,
                                  aux=batch_stats)
    batch_stats = state.aux

# Launch training only when run as a script or notebook cell, not on import,
# so the model and data helpers above can be reused without side effects.
if __name__ == "__main__":
  train()

与 baseline 优化器的结果进行比较

以下为预先记录的静态数据（baseline 与 VeLO）

#@title Comparing a baseline optimizer vs. VeLO on resnet18 cifar100.
# `plt` was previously used without being imported (NameError).
import matplotlib.pyplot as plt

# Hard-coded per-epoch accuracies of ResNet18 on CIFAR-100 (static results,
# not computed by this script); one value per epoch.
baseline_train_acc = [0.235, 0.333, 0.428, 0.430, 0.480,
                      0.528, 0.591, 0.617, 0.661, 0.709]

baseline_test_acc = [0.216, 0.298, 0.362, 0.343, 0.359,
                     0.371, 0.375, 0.377, 0.379, 0.399]

velo_train_acc = [0.170, 0.270, 0.346, 0.331, 0.466,
                  0.477, 0.551, 0.749, 0.848, 0.955]

velo_test_acc = [0.163, 0.255, 0.310, 0.290, 0.377,
                 0.369, 0.385, 0.458, 0.464, 0.492]

# Each value is measured at the end of an epoch, so epochs are numbered from 1;
# the axis length follows the data instead of a hard-coded 10.
acc_epochs = range(1, len(baseline_train_acc) + 1)

plt.figure(figsize=(8, 6), dpi=80)

plt.plot(acc_epochs, baseline_train_acc, label="Baseline Train Accuracy", c='b', linestyle='dashed')
plt.plot(acc_epochs, baseline_test_acc, label="Baseline Test Accuracy", c='b')
plt.plot(acc_epochs, velo_train_acc, label="VeLO Train Accuracy", c='r', linestyle='dashed')
plt.plot(acc_epochs, velo_test_acc, label="VeLO Test Accuracy", c='r')
plt.xlabel("Training Epochs")
plt.ylabel("Accuracy")
plt.title("Training Accuracy Curves for Resnet18 on Cifar100")
plt.legend()
plt.show()






#@title Comparing a baseline optimizer vs. VeLO on resnet18 cifar100.
# `plt` was previously used without being imported (NameError).
import matplotlib.pyplot as plt

# Hard-coded per-epoch losses of ResNet18 on CIFAR-100 (static results, not
# computed by this script); one value per epoch.
baseline_train_loss = [3.470, 2.979, 2.535, 2.567, 2.351,
                       2.183, 1.970, 1.925, 1.781, 1.644]

baseline_test_loss = [3.571, 3.206, 2.899, 3.064, 3.055,
                      3.107, 3.170, 3.447, 3.530, 3.597]

velo_train_loss = [3.701, 3.071, 2.771, 2.948, 2.294,
                   2.287, 2.059, 1.268, 0.948, 0.645]

velo_test_loss = [3.739, 3.188, 2.974, 3.266, 2.797,
                  2.925, 3.062, 2.769, 2.950, 2.882]

# Each value is measured at the end of an epoch, so epochs are numbered from 1;
# the axis length follows the data instead of a hard-coded 10.
loss_epochs = range(1, len(baseline_train_loss) + 1)

plt.figure(figsize=(8, 6), dpi=80)

plt.plot(loss_epochs, baseline_train_loss, label="Baseline Train Loss", c='b', linestyle='dashed')
plt.plot(loss_epochs, baseline_test_loss, label="Baseline Test Loss", c='b')
plt.plot(loss_epochs, velo_train_loss, label="VeLO Train Loss", c='r', linestyle='dashed')
plt.plot(loss_epochs, velo_test_loss, label="VeLO Test Loss", c='r')
plt.xlabel("Training Epochs")
plt.ylabel("Loss")
plt.title("Training Loss Curves for Resnet18 on Cifar100 ")
plt.legend()
plt.show()

你可能感兴趣的:(我的科研之路~,深度学习,python,人工智能)