1. These are my notes on learning neural ordinary differential equations, written mainly to practice picking up new material and reading papers that lean heavily on mathematical derivations.
2. Neural ODEs can be used for time-series modeling, dynamical-systems modeling, and more, but this article focuses on the classification setting, where the model is a ResNet variant and is relatively easy to understand.
The adjoint method is a way of training the network that avoids the scalability (memory) problems of backpropagating through every internal step of the solver. The forward pass is computed with an ordinary differential equation (ODE) solver; the backward pass uses the adjoint sensitivity method, which is itself just another ODE solve run backwards in time. To update the parameters of the derivative function, the adjoint sensitivity method provides the gradient of the loss with respect to the dynamics function's parameters. The final algorithm sets up an augmented state, packs the required quantities into it, and calls the ODE solver backwards in time to obtain the gradient with respect to theta, which is then used to update the CNN encoder and the derivative-function parameters.
The algorithm flow of the adjoint method is as follows (the paper also gives a more complete version):
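As a compact sketch of what that algorithm computes (notation as in the Neural ODE paper, with $z(t)$ the hidden state, $f$ the dynamics, and $\theta$ its parameters), define the adjoint state $a(t) = \partial L / \partial z(t)$. Then

$$
\frac{dz(t)}{dt} = f(z(t), t, \theta), \qquad
\frac{da(t)}{dt} = -\,a(t)^{\top}\frac{\partial f(z(t), t, \theta)}{\partial z}, \qquad
\frac{dL}{d\theta} = -\int_{t_1}^{t_0} a(t)^{\top}\frac{\partial f(z(t), t, \theta)}{\partial \theta}\, dt.
$$

The state $z$, the adjoint $a$, and the accumulated parameter gradient are therefore integrated backwards in time together as one augmented ODE; this is what `augmented_dynamics` in the code below computes.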
The code implementing the adjoint sensitivity method is fairly involved, but its logic mirrors the algorithm steps one to one, so reading them side by side makes it easy to follow; in essence, the gradient computation is reduced to solving a differential equation:
**In engineering terms, `OdeintAdjointMethod` is implemented by subclassing `torch.autograd.Function` and overriding its `forward` and `backward` methods: both passes are replaced with calls to an ODE solver, rather than relying on `torch.autograd.Function`'s usual chain-rule gradient computation.** `odeint` is the ODE solver used in this article.
The solvers provided by the repository are listed below (a minimal usage sketch follows the list):
SOLVERS = {
'dopri8': Dopri8Solver,
'dopri5': Dopri5Solver,
'bosh3': Bosh3Solver,
'fehlberg2': Fehlberg2,
'adaptive_heun': AdaptiveHeunSolver,
'euler': Euler,
'midpoint': Midpoint,
'rk4': RK4,
'explicit_adams': AdamsBashforth,
'implicit_adams': AdamsBashforthMoulton,
# Backward compatibility: use the same name as before
'fixed_adams': AdamsBashforthMoulton,
# ~Backwards compatibility
'scipy_solver': ScipyWrapperODESolver,
}
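The `method` string passed to `odeint` selects one of the solvers above. A minimal sketch of how that looks (the decay dynamics here are a made-up toy example, not part of the repository):

```python
import torch
from torchdiffeq import odeint

# Toy dynamics dy/dt = -y (exponential decay), just to illustrate the call signature.
def decay(t, y):
    return -y

y0 = torch.tensor([1.0])
t = torch.linspace(0., 1., 5)

# 'dopri5' (adaptive Runge-Kutta) is the default; any key of SOLVERS can be passed as `method`.
ys_adaptive = odeint(decay, y0, t, rtol=1e-3, atol=1e-3, method='dopri5')
ys_fixed = odeint(decay, y0, t, method='rk4')   # fixed-step solver, stepping on the grid t
print(ys_adaptive[-1], ys_fixed[-1])            # both close to exp(-1) ≈ 0.3679
```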
The full source code is at rtqichen/torchdiffeq: Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation (github.com). The core parts, the forward and backward passes, are reproduced below with comments; a minimal usage sketch follows the code.
class OdeintAdjointMethod(torch.autograd.Function):
@staticmethod
def forward(ctx, shapes, func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol, adjoint_method,
adjoint_options, t_requires_grad, *adjoint_params):
ctx.shapes = shapes
ctx.func = func
ctx.adjoint_rtol = adjoint_rtol
ctx.adjoint_atol = adjoint_atol
ctx.adjoint_method = adjoint_method
ctx.adjoint_options = adjoint_options
ctx.t_requires_grad = t_requires_grad
ctx.event_mode = event_fn is not None
with torch.no_grad():
ans = odeint(func, y0, t, rtol=rtol, atol=atol, method=method, options=options, event_fn=event_fn)
if event_fn is None:
y = ans
ctx.save_for_backward(t, y, *adjoint_params)
else:
event_t, y = ans
ctx.save_for_backward(t, y, event_t, *adjoint_params)
return ans
@staticmethod
def backward(ctx, *grad_y):
with torch.no_grad():
func = ctx.func
adjoint_rtol = ctx.adjoint_rtol
adjoint_atol = ctx.adjoint_atol
adjoint_method = ctx.adjoint_method
adjoint_options = ctx.adjoint_options
t_requires_grad = ctx.t_requires_grad
# Backprop as if integrating up to event time.
# Does NOT backpropagate through the event time.
event_mode = ctx.event_mode
if event_mode:
t, y, event_t, *adjoint_params = ctx.saved_tensors
_t = t
t = torch.cat([t[0].reshape(-1), event_t.reshape(-1)])
grad_y = grad_y[1]
else:
t, y, *adjoint_params = ctx.saved_tensors
grad_y = grad_y[0]
adjoint_params = tuple(adjoint_params)
            ##################################
            #      Set up initial state      #
            ##################################
# [-1] because y and grad_y are both of shape (len(t), *y0.shape)
aug_state = [torch.zeros((), dtype=y.dtype, device=y.device), y[-1], grad_y[-1]] # vjp_t, y, vjp_y
aug_state.extend([torch.zeros_like(param) for param in adjoint_params]) # vjp_params
            ##################################
            #    Set up backward ODE func    #
            ##################################
# TODO: use a nn.Module and call odeint_adjoint to implement higher order derivatives.
def augmented_dynamics(t, y_aug):
# Dynamics of the original system augmented with
# the adjoint wrt y, and an integrator wrt t and args.
y = y_aug[1]
adj_y = y_aug[2]
# ignore gradients wrt time and parameters
with torch.enable_grad():
t_ = t.detach()
t = t_.requires_grad_(True)
y = y.detach().requires_grad_(True)
# If using an adaptive solver we don't want to waste time resolving dL/dt unless we need it (which
# doesn't necessarily even exist if there is piecewise structure in time), so turning off gradients
# wrt t here means we won't compute that if we don't need it.
func_eval = func(t if t_requires_grad else t_, y)
# Workaround for PyTorch bug #39784
_t = torch.as_strided(t, (), ()) # noqa
_y = torch.as_strided(y, (), ()) # noqa
_params = tuple(torch.as_strided(param, (), ()) for param in adjoint_params) # noqa
vjp_t, vjp_y, *vjp_params = torch.autograd.grad(
func_eval, (t, y) + adjoint_params, -adj_y,
allow_unused=True, retain_graph=True
)
# autograd.grad returns None if no gradient, set to zero.
vjp_t = torch.zeros_like(t) if vjp_t is None else vjp_t
vjp_y = torch.zeros_like(y) if vjp_y is None else vjp_y
vjp_params = [torch.zeros_like(param) if vjp_param is None else vjp_param
for param, vjp_param in zip(adjoint_params, vjp_params)]
return (vjp_t, func_eval, vjp_y, *vjp_params)
            ##################################
            #       Solve adjoint ODE        #
            ##################################
if t_requires_grad:
time_vjps = torch.empty(len(t), dtype=t.dtype, device=t.device)
else:
time_vjps = None
for i in range(len(t) - 1, 0, -1):
if t_requires_grad:
# Compute the effect of moving the current time measurement point.
# We don't compute this unless we need to, to save some computation.
func_eval = func(t[i], y[i])
dLd_cur_t = func_eval.reshape(-1).dot(grad_y[i].reshape(-1))
aug_state[0] -= dLd_cur_t
time_vjps[i] = dLd_cur_t
# Run the augmented system backwards in time.
aug_state = odeint(
augmented_dynamics, tuple(aug_state),
t[i - 1:i + 1].flip(0),
rtol=adjoint_rtol, atol=adjoint_atol, method=adjoint_method, options=adjoint_options
)
aug_state = [a[1] for a in aug_state] # extract just the t[i - 1] value
aug_state[1] = y[i - 1] # update to use our forward-pass estimate of the state
aug_state[2] += grad_y[i - 1] # update any gradients wrt state at this time point
if t_requires_grad:
time_vjps[0] = aug_state[0]
            # Assemble the gradients to return.
# Only compute gradient wrt initial time when in event handling mode.
if event_mode and t_requires_grad:
time_vjps = torch.cat([time_vjps[0].reshape(-1), torch.zeros_like(_t[1:])])
adj_y = aug_state[2]
adj_params = aug_state[3:]
return (None, None, adj_y, time_vjps, None, None, None, None, None, None, None, None, None, None, *adj_params)
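In practice one does not call `OdeintAdjointMethod` directly; `odeint_adjoint` wraps it. The dynamics function must be an `nn.Module` so that its parameters can be collected as `adjoint_params`. A minimal sketch of the call (the `Decay` module is a toy example, not from the repository):

```python
import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint

# Toy dynamics with one learnable parameter: dy/dt = -a * y.
class Decay(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Parameter(torch.tensor(0.5))

    def forward(self, t, y):
        return -self.a * y

func = Decay()
y0 = torch.tensor([1.0])
t = torch.linspace(0., 1., 10)

ys = odeint_adjoint(func, y0, t, rtol=1e-3, atol=1e-3, method='dopri5')
loss = ys[-1].pow(2).sum()
loss.backward()        # runs OdeintAdjointMethod.backward, i.e. the adjoint ODE solve
print(func.a.grad)     # dL/da, obtained with O(1) memory in the number of solver steps
```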
Starting from a small residual net, the 6 standard residual blocks that follow the downsampling layers are replaced by a single ODESolve block, which gives the ODE-Net:
Sequential(
(0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
(1): GroupNorm(32, 64, eps=1e-05, affine=True)
(2): ReLU(inplace=True)
(3): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(4): GroupNorm(32, 64, eps=1e-05, affine=True)
(5): ReLU(inplace=True)
(6): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(7): ODEBlock(
(odefunc): ODEfunc(
(norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
(relu): ReLU(inplace=True)
(conv1): ConcatConv2d(
(_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(conv2): ConcatConv2d(
(_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(norm3): GroupNorm(32, 64, eps=1e-05, affine=True)
)
)
(8): GroupNorm(32, 64, eps=1e-05, affine=True)
(9): ReLU(inplace=True)
(10): AdaptiveAvgPool2d(output_size=(1, 1))
(11): Flatten()
(12): Linear(in_features=64, out_features=10, bias=True)
)
The results: the ODE-Net's parameter count and memory footprint are better than the ResNet's, but wall-clock time is not reported directly; the paper only gives time complexities.
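For reference, Table 1 of the Neural ODE paper reports approximately the following on MNIST (consult the paper for the exact figures):

| Model | Test error | # Params | Memory | Time |
| --- | --- | --- | --- | --- |
| ResNet | 0.41% | 0.60M | O(L) | O(L) |
| RK-Net | 0.47% | 0.22M | O(L̃) | O(L̃) |
| ODE-Net | 0.42% | 0.22M | O(1) | O(L̃) |

Here L is the number of residual layers and L̃ is the number of function evaluations made by the ODE solver.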
import os
import argparse
import logging
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
parser = argparse.ArgumentParser()
parser.add_argument('--network', type=str, choices=['resnet', 'odenet'], default='odenet')
parser.add_argument('--tol', type=float, default=1e-3)
parser.add_argument('--adjoint', type=eval, default=False, choices=[True, False])
parser.add_argument('--downsampling-method', type=str, default='conv', choices=['conv', 'res'])
parser.add_argument('--nepochs', type=int, default=160)
parser.add_argument('--data_aug', type=eval, default=True, choices=[True, False])
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--test_batch_size', type=int, default=1000)
parser.add_argument('--save', type=str, default='./experiment1')
parser.add_argument('--debug', action='store_true')
parser.add_argument('--gpu', type=int, default=0)
args = parser.parse_args()
if args.adjoint:
from torchdiffeq import odeint_adjoint as odeint
else:
from torchdiffeq import odeint
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
def norm(dim):
return nn.GroupNorm(min(32, dim), dim)
class ResBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(ResBlock, self).__init__()
self.norm1 = norm(inplanes)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.conv1 = conv3x3(inplanes, planes, stride)
self.norm2 = norm(planes)
self.conv2 = conv3x3(planes, planes)
def forward(self, x):
shortcut = x
out = self.relu(self.norm1(x))
if self.downsample is not None:
shortcut = self.downsample(out)
out = self.conv1(out)
out = self.norm2(out)
out = self.relu(out)
out = self.conv2(out)
return out + shortcut
class ConcatConv2d(nn.Module):
def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
super(ConcatConv2d, self).__init__()
module = nn.ConvTranspose2d if transpose else nn.Conv2d
self._layer = module(
dim_in + 1, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
bias=bias
)
def forward(self, t, x):
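        # Broadcast the scalar time t into an extra feature-map channel and concatenate it with x,
        # which is why the convolution above takes dim_in + 1 input channels.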
tt = torch.ones_like(x[:, :1, :, :]) * t
ttx = torch.cat([tt, x], 1)
return self._layer(ttx)
class ODEfunc(nn.Module):
def __init__(self, dim):
super(ODEfunc, self).__init__()
self.norm1 = norm(dim)
self.relu = nn.ReLU(inplace=True)
self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)
self.norm2 = norm(dim)
self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)
self.norm3 = norm(dim)
self.nfe = 0
def forward(self, t, x):
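        # nfe counts the number of function evaluations made by the ODE solver (logged below as NFE-F / NFE-B).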
self.nfe += 1
out = self.norm1(x)
out = self.relu(out)
out = self.conv1(t, out)
out = self.norm2(out)
out = self.relu(out)
out = self.conv2(t, out)
out = self.norm3(out)
return out
class ODEBlock(nn.Module):
def __init__(self, odefunc):
super(ODEBlock, self).__init__()
self.odefunc = odefunc
self.integration_time = torch.tensor([0, 1]).float()
def forward(self, x):
self.integration_time = self.integration_time.type_as(x)
out = odeint(self.odefunc, x, self.integration_time, rtol=args.tol, atol=args.tol)
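        # odeint returns the solution at every time in integration_time; index 1 is the state at t = 1.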
return out[1]
@property
def nfe(self):
return self.odefunc.nfe
@nfe.setter
def nfe(self, value):
self.odefunc.nfe = value
class Flatten(nn.Module):
def __init__(self):
super(Flatten, self).__init__()
def forward(self, x):
shape = torch.prod(torch.tensor(x.shape[1:])).item()
return x.view(-1, shape)
class RunningAverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self, momentum=0.99):
self.momentum = momentum
self.reset()
def reset(self):
self.val = None
self.avg = 0
def update(self, val):
if self.val is None:
self.avg = val
else:
self.avg = self.avg * self.momentum + val * (1 - self.momentum)
self.val = val
def get_mnist_loaders(data_aug=False, batch_size=128, test_batch_size=1000, perc=1.0):
if data_aug:
transform_train = transforms.Compose([
transforms.RandomCrop(28, padding=4),
transforms.ToTensor(),
])
else:
transform_train = transforms.Compose([
transforms.ToTensor(),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
])
train_loader = DataLoader(
datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_train), batch_size=batch_size,
shuffle=True, num_workers=2, drop_last=True
)
train_eval_loader = DataLoader(
datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_test),
batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True
)
test_loader = DataLoader(
datasets.MNIST(root='.data/mnist', train=False, download=True, transform=transform_test),
batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True
)
return train_loader, test_loader, train_eval_loader
def inf_generator(iterable):
"""Allows training with DataLoaders in a single infinite loop:
for i, (x, y) in enumerate(inf_generator(train_loader)):
"""
iterator = iterable.__iter__()
while True:
try:
yield iterator.__next__()
except StopIteration:
iterator = iterable.__iter__()
def learning_rate_with_decay(batch_size, batch_denom, batches_per_epoch, boundary_epochs, decay_rates):
initial_learning_rate = args.lr * batch_size / batch_denom
boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
vals = [initial_learning_rate * decay for decay in decay_rates]
def learning_rate_fn(itr):
lt = [itr < b for b in boundaries] + [True]
i = np.argmax(lt)
return vals[i]
return learning_rate_fn
def one_hot(x, K):
return np.array(x[:, None] == np.arange(K)[None, :], dtype=int)
def accuracy(model, dataset_loader):
total_correct = 0
for x, y in dataset_loader:
x = x.to(device)
y = one_hot(np.array(y.numpy()), 10)
target_class = np.argmax(y, axis=1)
predicted_class = np.argmax(model(x).cpu().detach().numpy(), axis=1)
total_correct += np.sum(predicted_class == target_class)
return total_correct / len(dataset_loader.dataset)
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def makedirs(dirname):
if not os.path.exists(dirname):
os.makedirs(dirname)
def get_logger(logpath, filepath, package_files=[], displaying=True, saving=True, debug=False):
logger = logging.getLogger()
if debug:
level = logging.DEBUG
else:
level = logging.INFO
logger.setLevel(level)
if saving:
info_file_handler = logging.FileHandler(logpath, mode="a")
info_file_handler.setLevel(level)
logger.addHandler(info_file_handler)
if displaying:
console_handler = logging.StreamHandler()
console_handler.setLevel(level)
logger.addHandler(console_handler)
logger.info(filepath)
with open(filepath, "r") as f:
logger.info(f.read())
for f in package_files:
logger.info(f)
with open(f, "r") as package_f:
logger.info(package_f.read())
return logger
if __name__ == '__main__':
makedirs(args.save)
logger = get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__))
logger.info(args)
device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')
is_odenet = args.network == 'odenet'
if args.downsampling_method == 'conv':
downsampling_layers = [
nn.Conv2d(1, 64, 3, 1),
norm(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, 4, 2, 1),
norm(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, 4, 2, 1),
]
elif args.downsampling_method == 'res':
downsampling_layers = [
nn.Conv2d(1, 64, 3, 1),
ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
]
feature_layers = [ODEBlock(ODEfunc(64))] if is_odenet else [ResBlock(64, 64) for _ in range(6)]
fc_layers = [norm(64), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), Flatten(), nn.Linear(64, 10)]
model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(device)
logger.info(model)
logger.info('Number of parameters: {}'.format(count_parameters(model)))
criterion = nn.CrossEntropyLoss().to(device)
train_loader, test_loader, train_eval_loader = get_mnist_loaders(
args.data_aug, args.batch_size, args.test_batch_size
)
data_gen = inf_generator(train_loader)
batches_per_epoch = len(train_loader)
lr_fn = learning_rate_with_decay(
args.batch_size, batch_denom=128, batches_per_epoch=batches_per_epoch, boundary_epochs=[60, 100, 140],
decay_rates=[1, 0.1, 0.01, 0.001]
)
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
best_acc = 0
batch_time_meter = RunningAverageMeter()
f_nfe_meter = RunningAverageMeter()
b_nfe_meter = RunningAverageMeter()
end = time.time()
for itr in range(args.nepochs * batches_per_epoch):
for param_group in optimizer.param_groups:
param_group['lr'] = lr_fn(itr)
optimizer.zero_grad()
x, y = data_gen.__next__()
x = x.to(device)
y = y.to(device)
logits = model(x)
loss = criterion(logits, y)
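        # Read and reset the ODE block's NFE counter before backward(), so the next read isolates the
        # function evaluations made during the backward pass (nonzero only with --adjoint True).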
if is_odenet:
nfe_forward = feature_layers[0].nfe
feature_layers[0].nfe = 0
loss.backward()
optimizer.step()
if is_odenet:
nfe_backward = feature_layers[0].nfe
feature_layers[0].nfe = 0
batch_time_meter.update(time.time() - end)
if is_odenet:
f_nfe_meter.update(nfe_forward)
b_nfe_meter.update(nfe_backward)
end = time.time()
if itr % batches_per_epoch == 0:
with torch.no_grad():
train_acc = accuracy(model, train_eval_loader)
val_acc = accuracy(model, test_loader)
if val_acc > best_acc:
torch.save({'state_dict': model.state_dict(), 'args': args}, os.path.join(args.save, 'model.pth'))
best_acc = val_acc
logger.info(
"Epoch {:04d} | Time {:.3f} ({:.3f}) | NFE-F {:.1f} | NFE-B {:.1f} | "
"Train Acc {:.4f} | Test Acc {:.4f}".format(
itr // batches_per_epoch, batch_time_meter.val, batch_time_meter.avg, f_nfe_meter.avg,
b_nfe_meter.avg, train_acc, val_acc
)
)
Training log:
Epoch 0000 | Time 3.425 (3.425) | NFE-F 32.0 | NFE-B 0.0 | Train Acc 0.0987 | Test Acc 0.0958
Epoch 0001 | Time 3.279 (0.840) | NFE-F 20.3 | NFE-B 0.0 | Train Acc 0.9755 | Test Acc 0.9779
Epoch 0002 | Time 3.500 (0.839) | NFE-F 20.1 | NFE-B 0.0 | Train Acc 0.9858 | Test Acc 0.9875
Epoch 0003 | Time 3.403 (0.828) | NFE-F 20.1 | NFE-B 0.0 | Train Acc 0.9884 | Test Acc 0.9879
Epoch 0004 | Time 3.303 (0.807) | NFE-F 20.1 | NFE-B 0.0 | Train Acc 0.9926 | Test Acc 0.9921
Epoch 0005 | Time 3.308 (0.801) | NFE-F 20.1 | NFE-B 0.0 | Train Acc 0.9940 | Test Acc 0.9930
Epoch 0006 | Time 3.255 (0.804) | NFE-F 20.1 | NFE-B 0.0 | Train Acc 0.9917 | Test Acc 0.9894
Epoch 0007 | Time 3.376 (0.808) | NFE-F 20.2 | NFE-B 0.0 | Train Acc 0.9948 | Test Acc 0.9929
Epoch 0008 | Time 3.260 (0.806) | NFE-F 20.1 | NFE-B 0.0 | Train Acc 0.9935 | Test Acc 0.9934
Epoch 0009 | Time 3.248 (0.832) | NFE-F 20.4 | NFE-B 0.0 | Train Acc 0.9948 | Test Acc 0.9909
Epoch 0010 | Time 3.286 (0.817) | NFE-F 20.4 | NFE-B 0.0 | Train Acc 0.9959 | Test Acc 0.9947
Epoch 0011 | Time 3.281 (0.827) | NFE-F 20.8 | NFE-B 0.0 | Train Acc 0.9967 | Test Acc 0.9951
Epoch 0012 | Time 3.382 (0.825) | NFE-F 20.9 | NFE-B 0.0 | Train Acc 0.9949 | Test Acc 0.9929
Epoch 0013 | Time 3.299 (0.862) | NFE-F 22.0 | NFE-B 0.0 | Train Acc 0.9976 | Test Acc 0.9949
Epoch 0014 | Time 3.326 (0.824) | NFE-F 20.8 | NFE-B 0.0 | Train Acc 0.9947 | Test Acc 0.9936
Epoch 0015 | Time 3.291 (0.839) | NFE-F 21.3 | NFE-B 0.0 | Train Acc 0.9974 | Test Acc 0.9948
Epoch 0016 | Time 3.467 (0.935) | NFE-F 24.4 | NFE-B 0.0 | Train Acc 0.9977 | Test Acc 0.9941
Epoch 0017 | Time 3.483 (0.900) | NFE-F 23.2 | NFE-B 0.0 | Train Acc 0.9970 | Test Acc 0.9939
Epoch 0018 | Time 3.309 (0.872) | NFE-F 22.2 | NFE-B 0.0 | Train Acc 0.9961 | Test Acc 0.9932
Epoch 0019 | Time 3.294 (0.913) | NFE-F 23.6 | NFE-B 0.0 | Train Acc 0.9974 | Test Acc 0.9954
Epoch 0020 | Time 3.504 (0.984) | NFE-F 25.9 | NFE-B 0.0 | Train Acc 0.9983 | Test Acc 0.9951
Epoch 0021 | Time 3.589 (0.966) | NFE-F 25.2 | NFE-B 0.0 | Train Acc 0.9966 | Test Acc 0.9929
Epoch 0022 | Time 3.503 (0.994) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9977 | Test Acc 0.9949
Epoch 0023 | Time 3.457 (0.995) | NFE-F 26.1 | NFE-B 0.0 | Train Acc 0.9978 | Test Acc 0.9939
Epoch 0024 | Time 3.529 (0.985) | NFE-F 26.0 | NFE-B 0.0 | Train Acc 0.9985 | Test Acc 0.9958
Epoch 0025 | Time 3.459 (0.988) | NFE-F 26.0 | NFE-B 0.0 | Train Acc 0.9973 | Test Acc 0.9947
Epoch 0026 | Time 3.541 (0.988) | NFE-F 26.0 | NFE-B 0.0 | Train Acc 0.9979 | Test Acc 0.9946
Epoch 0027 | Time 3.513 (0.993) | NFE-F 26.1 | NFE-B 0.0 | Train Acc 0.9986 | Test Acc 0.9959
Epoch 0028 | Time 3.505 (0.996) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9982 | Test Acc 0.9953
Epoch 0029 | Time 3.501 (0.990) | NFE-F 26.1 | NFE-B 0.0 | Train Acc 0.9985 | Test Acc 0.9953
Epoch 0030 | Time 3.475 (0.992) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9983 | Test Acc 0.9954
Epoch 0031 | Time 3.506 (0.993) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9986 | Test Acc 0.9947
Epoch 0032 | Time 3.527 (0.995) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9981 | Test Acc 0.9954
Epoch 0033 | Time 3.529 (0.996) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9976 | Test Acc 0.9945
Epoch 0034 | Time 3.545 (0.996) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9988 | Test Acc 0.9959
Epoch 0035 | Time 3.479 (0.995) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9990 | Test Acc 0.9953
Epoch 0036 | Time 3.479 (0.997) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9989 | Test Acc 0.9963
Epoch 0037 | Time 3.540 (0.998) | NFE-F 26.2 | NFE-B 0.0 | Train Acc 0.9988 | Test Acc 0.9957
References:
1. 神经常微分方程 (Neural ODE): 入门教程 [an introductory tutorial] - 知乎 (zhihu.com)
2. rtqichen/torchdiffeq: Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation (github.com)