接下来将按照顺序讲解每一个文件的作用
ab_mixmatch.py
这段代码定义了额外的标志来测试MixMatch方法实现的不同部分。MixMatch算法是一种半监督学习方法,利用标记和未标记的数据来训练模型。
import functools
import os
from absl import app
from absl import flags
from easydict import EasyDict
from libml import layers, utils, models
from libml.data_pair import DATASETS
from libml.layers import MixMode
import tensorflow as tf
这段代码是一个 Python 代码文件的一部分,它使用了一些常用的 Python 库和自定义库来实现深度学习的数据处理和模型训练。
下面是这段代码的主要作用和功能:
1、导入必要的Python库:
functools
:用于高阶函数编程。os
:用于与操作系统进行交互,例如获取环境变量和文件路径等。absl
:一个用于 Python 应用程序的命令行参数解析器。easydict
:提供了一种更加方便的字典方式来访问字典对象中的元素。2.导入自定义库:
libml
:这是一个自定义的 Python 库,包含了一些用于深度学习的数据处理和模型训练的模块。在这段代码中,我们使用了 layers
、utils
和 models
模块。libml.data_pair
:这是一个自定义的 Python 模块,它包含了一些用于深度学习的数据处理的方法。3.定义一个 MixMode
枚举变量,用于表示数据集混合的模式。
4.使用 TensorFlow 2.x 版本的 API 构建深度学习模型。
FLAGS = flags.FLAGS
这一行代码定义了一个全局变量 FLAGS
,它是 absl.flags.FLAGS
对象的一个实例。这个实例用于存储和管理命令行参数,以便在 Python 应用程序中使用这些参数。
在使用 absl.flags
库时,首先需要创建一个 FLAGS
对象实例,然后可以使用 DEFINE_xxx()
方法来定义命令行参数。在程序中引用这些参数时,可以通过 FLAGS.xxx
的方式来访问它们的值。
class AblationMixMatch(models.MultiModel):
def augment(self, x, l, beta, **kwargs):
assert 0, 'Do not call.'
def guess_label(self, y, classifier, T, getter, **kwargs):
del kwargs
logits_y = [classifier(yi, training=True, getter=getter) for yi in y]
logits_y = tf.concat(logits_y, 0)
# Compute predicted probability distribution py.
p_model_y = tf.reshape(tf.nn.softmax(logits_y), [len(y), -1, self.nclass])
p_model_y = tf.reduce_mean(p_model_y, axis=0)
# Compute the target distribution.
p_target = tf.pow(p_model_y, 1. / T)
p_target /= tf.reduce_sum(p_target, axis=1, keep_dims=True)
return EasyDict(p_target=p_target, p_model=p_model_y)
def model(self, nu, w_match, warmup_kimg, batch, lr, wd, ema, beta, mixmode, use_ema_guess, **kwargs):
hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
y_in = tf.placeholder(tf.float32, [None, nu] + hwc, 'y')
l_in = tf.placeholder(tf.int32, [None], 'labels')
wd *= lr
w_match *= tf.clip_by_value(tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
augment = MixMode(mixmode)
classifier = functools.partial(self.classifier, **kwargs)
classifier(x_in, training=True) # Instantiate network.
ema = tf.train.ExponentialMovingAverage(decay=ema)
ema_op = ema.apply(utils.model_vars())
ema_getter = functools.partial(utils.getter_ema, ema)
y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)
guess = self.guess_label(tf.split(y, nu), classifier,
getter=ema_getter if use_ema_guess else None, **kwargs)
ly = tf.stop_gradient(guess.p_target)
lx = tf.one_hot(l_in, self.nclass)
xy, labels_xy = augment([x_in] + tf.split(y, nu), [lx] + [ly] * nu, [beta, beta])
x, y = xy[0], xy[1:]
labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0)
del xy, labels_xy
batches = layers.interleave([x] + y, batch)
skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
logits = [classifier(batches[0], training=True)]
post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops]
for batchi in batches[1:]:
logits.append(classifier(batchi, training=True))
logits = layers.interleave(logits, batch)
logits_x = logits[0]
logits_y = tf.concat(logits[1:], 0)
loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x, logits=logits_x)
loss_xe = tf.reduce_mean(loss_xe)
loss_l2u = tf.square(labels_y - tf.nn.softmax(logits_y))
loss_l2u = tf.reduce_mean(loss_l2u)
tf.summary.scalar('losses/xe', loss_xe)
tf.summary.scalar('losses/l2u', loss_l2u)
post_ops.append(ema_op)
post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])
train_op = tf.train.AdamOptimizer(lr).minimize(loss_xe + w_match * loss_l2u, colocate_gradients_with_ops=True)
with tf.control_dependencies([train_op]):
train_op = tf.group(*post_ops)
# Tuning op: only retrain batch norm.
skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
classifier(batches[0], training=True)
train_bn = tf.group(*[v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
if v not in skip_ops])
return EasyDict(
x=x_in, y=y_in, label=l_in, train_op=train_op, tune_op=train_bn,
classify_raw=tf.nn.softmax(classifier(x_in, training=False)),
classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
这段代码实现了一个名为AblationMixMatch的模型类,继承自MultiModel类,用于训练一个混合标签的半监督学习模型。其中的函数和参数的作用如下:
在model()函数中,首先定义了输入placeholder的形状,然后对分类器进行了实例化,同时对模型进行了初始化和平均操作。接下来,对数据进行了数据增强和混合,然后将增强后的数据分别送入分类器中进行训练,并计算交叉熵和L2损失。最后定义了训练和调参操作,以及输出分类器的原始输出和指数平均后的输出。
具体来说,就是
def augment(self, x, l, beta, **kwargs):
assert 0, 'Do not call.'
这个方法是一个占位符,代码中没有实际使用到。它被定义在 AblationMixMatch 类中作为一个抽象方法。如果这个方法被调用,代码会抛出一个异常,提示不应该直接调用它。这种设计方式通常是为了让子类必须实现这个方法,而不是使用父类的默认实现。在本例中,它的目的可能是为了强制子类实现一个数据增强的方法。
def guess_label(self, y, classifier, T, getter, **kwargs):
del kwargs
logits_y = [classifier(yi, training=True, getter=getter) for yi in y]
logits_y = tf.concat(logits_y, 0)
# Compute predicted probability distribution py.
p_model_y = tf.reshape(tf.nn.softmax(logits_y), [len(y), -1, self.nclass])
p_model_y = tf.reduce_mean(p_model_y, axis=0)
# Compute the target distribution.
p_target = tf.pow(p_model_y, 1. / T)
p_target /= tf.reduce_sum(p_target, axis=1, keep_dims=True)
return EasyDict(p_target=p_target, p_model=p_model_y)
这段代码是实现了一个模型对给定的标签 y 进行预测,其使用了一个分类器对标签进行推断,同时还有一个温度参数 T,用于控制预测的概率分布平滑程度。
首先,对每个标签 y,使用分类器classifier得到一个输出 logits_y,将所有的logits_y在第0个维度上进原始的logits_y是一个列表,得到一个新的张量。假设原始的logits_y每个元素都是形状为[batch_size, num_classes]的张量,那么拼接后的张量形状为[(len(logits_y) * batch_size), num_classes],其中len(logits_y)表示logits_y列表的长度。(这一步得到的是所有logits_y的值,)
然后,使用 softmax 函数将其转换为概率分布 p_model_y。
接着,将所有标签的概率分布 p_model_y 求平均得到整个数据集的概率分布(这里的平均就是按列先求和再平均)。
最后,使用温度参数 T 对整个数据集的概率分布进行平滑,得到目标分布 p_target。该函数的返回值包含了目标分布 p_target 和整个数据集的概率分布 p_model_y。
是对张量p_target在第1个维度(即num_classes维度)上进行归一化,得到一个新的张量p_target。具体地说,如果原始的p_target是一个形状为[batch_size, num_classes]的张量,那么经过reduce_sum操作后,得到的是一个形状为[batch_size, 1]的张量,其中每个元素是原始张量在该维度上的和。接着,使用除法操作将原始张量中的每个元素除以对应的和,从而得到新张量p_target。
def model(self, nu, w_match, warmup_kimg, batch, lr, wd, ema, beta, mixmode, use_ema_guess, **kwargs):
hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
y_in = tf.placeholder(tf.float32, [None, nu] + hwc, 'y')
l_in = tf.placeholder(tf.int32, [None], 'labels')
wd *= lr
w_match *= tf.clip_by_value(tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
augment = MixMode(mixmode)
classifier = functools.partial(self.classifier, **kwargs)
classifier(x_in, training=True) # Instantiate network.
ema = tf.train.ExponentialMovingAverage(decay=ema)
ema_op = ema.apply(utils.model_vars())
ema_getter = functools.partial(utils.getter_ema, ema)
y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)
guess = self.guess_label(tf.split(y, nu), classifier,
getter=ema_getter if use_ema_guess else None, **kwargs)
ly = tf.stop_gradient(guess.p_target)
lx = tf.one_hot(l_in, self.nclass)
xy, labels_xy = augment([x_in] + tf.split(y, nu), [lx] + [ly] * nu, [beta, beta])
x, y = xy[0], xy[1:]
labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0)
del xy, labels_xy
batches = layers.interleave([x] + y, batch)
skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
logits = [classifier(batches[0], training=True)]
post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops]
for batchi in batches[1:]:
logits.append(classifier(batchi, training=True))
logits = layers.interleave(logits, batch)
logits_x = logits[0]
logits_y = tf.concat(logits[1:], 0)
loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x, logits=logits_x)
loss_xe = tf.reduce_mean(loss_xe)
loss_l2u = tf.square(labels_y - tf.nn.softmax(logits_y))
loss_l2u = tf.reduce_mean(loss_l2u)
tf.summary.scalar('losses/xe', loss_xe)
tf.summary.scalar('losses/l2u', loss_l2u)
post_ops.append(ema_op)
post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])
train_op = tf.train.AdamOptimizer(lr).minimize(loss_xe + w_match * loss_l2u, colocate_gradients_with_ops=True)
with tf.control_dependencies([train_op]):
train_op = tf.group(*post_ops)
# Tuning op: only retrain batch norm.
skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
classifier(batches[0], training=True)
train_bn = tf.group(*[v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
if v not in skip_ops])
return EasyDict(
x=x_in, y=y_in, label=l_in, train_op=train_op, tune_op=train_bn,
classify_raw=tf.nn.softmax(classifier(x_in, training=False)),
classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
该段代码是一个方法model,包含了训练模型的整个过程。该方法接受一些参数,如nu、w_match、warmup_kimg、batch、lr、wd、ema、beta、mixmode、use_ema_guess等,其中x_in表示输入图像,y_in表示与x_in对应的标签图像,l_in表示标签的类别。该方法的目标是在给定标签样本的情况下,使用半监督学习算法来训练分类模型。
该方法的大致流程如下:
定义输入placeholder:x_in、y_in、l_in。
对于给定的参数,进行预处理:计算wd * lr、w_match、augment、classifier等。
对输入图像x_in进行一次前向传播,以便实例化网络。同时,定义一个ExponentialMovingAverage对象ema,并应用于模型变量。
将标签图像y_in展开成一维张量,并根据guess_label方法和classifier对其进行预测。其中,guess_label方法会使用模型对标签图像进行猜测,并返回猜测后的标签,即p_target。使用tf.stop_gradient方法对p_target进行梯度截断,以防止误差反向传播。
对标签图像进行one-hot编码,得到labels_x,将x_in和y用MixMode方法进行数据增强,并将labels_x和p_target合并成labels_y。
将增强后的数据集拆分成batch,并使用分类器对每个batch进行前向传播,得到对应的logits。将logits_x和logits_y分别提取出来。
计算交叉熵损失loss_xe和l2正则化损失loss_l2u,并计算它们的平均值。
进行优化操作。首先,使用Adam优化器对loss_xe和w_match * loss_l2u进行优化。然后,将ema_op和model_vars()中所有名称带有kernel的变量进行指数滑动平均操作,再将它们乘以(1-wd)进行权重衰减。最后,将所有操作合并成train_op。
对于调参,只重新训练batch norm,即将所有除skip_ops外的其他更新操作合并为train_bn。
返回一个EasyDict对象,包含了x_in、y_in、l_in、train_op、tune_op、classify_raw、classify_op等。其中,classify_raw表示在没有应用ema时,分类模型对x_in进行前向传播的结果,classify_op表示在应用ema之后,分类模型对x_in进行前向传播的结果。
展开来讲:
def model(self, nu, w_match, warmup_kimg, batch, lr, wd, ema, beta, mixmode, use_ema_guess, **kwargs):
hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
y_in = tf.placeholder(tf.float32, [None, nu] + hwc, 'y')
l_in = tf.placeholder(tf.int32, [None], 'labels')
wd *= lr
w_match *= tf.clip_by_value(tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
augment = MixMode(mixmode)
classifier = functools.partial(self.classifier, **kwargs)
这是一个模型定义函数,输入参数包括nu(无标签数据集的数量)、w_match(匹配权重)、warmup_kimg(预热步长)、batch(批大小)、lr(学习率)、wd(权重衰减)、ema(指数滑动平均系数)、beta(数据增强的beta参数)、mixmode(数据增强模式)和其他可选参数。该函数返回一个分类器。
该函数首先根据数据集的高、宽和通道数创建一个输入占位符x_in和一个标签占位符y_in。
然后,将权重衰减乘以学习率,并将匹配权重乘以一个warmup_kimg参数,以在前几个迭代中逐渐增加该权重。
接着,使用给定的数据增强模式创建一个数据增强器augment。
最后,函数返回一个分类器,该分类器使用self.classifier函数作为主要分类器,其中的参数使用了kwargs,该函数是一个偏函数,其中已经部分确定了一些参数。
classifier(x_in, training=True) # Instantiate network.
ema = tf.train.ExponentialMovingAverage(decay=ema)
ema_op = ema.apply(utils.model_vars())
ema_getter = functools.partial(utils.getter_ema, ema)
这段代码用于实例化一个神经网络分类器(classifier),并使用指数移动平均(Exponential Moving Average)对其参数进行平滑处理。
首先,通过调用classifier(x_in, training=True)来实例化网络,其中x_in是输入数据,training=True表示在训练模式下运行网络。
然后,使用指数移动平均(Exponential Moving Average,简称EMA)对网络的参数进行平滑处理。具体来说,通过调用tf.train.ExponentialMovingAverage(decay=ema)来创建一个指数移动平均器,其中decay参数指定了平均的衰减率。
接着,通过调用ema.apply(utils.model_vars())来将指数移动平均器应用于网络的所有参数。这将为每个参数创建一个EMA副本,并更新其值。
最后,使用functools.partial将utils.getter_ema和EMA副本绑定在一起,创建一个ema_getter函数,用于在测试模式下获取网络参数的EMA副本。这将确保在测试模式下,网络参数将始终是平滑的EMA副本,而不是训练模式下的原始参数。
y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)
guess = self.guess_label(tf.split(y, nu), classifier,
getter=ema_getter if use_ema_guess else None, **kwargs)
ly = tf.stop_gradient(guess.p_target)
lx = tf.one_hot(l_in, self.nclass)
xy, labels_xy = augment([x_in] + tf.split(y, nu), [lx] + [ly] * nu, [beta, beta])
x, y = xy[0], xy[1:]
labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0)
del xy, labels_xy
这段代码用于对输入数据进行一些操作,包括将输入标签(y_in)进行转置和重塑,使用分类器对标签进行预测,然后进行数据增强和标签拼接。
首先,使用tf.transpose将输入标签y_in进行转置,以便在后面进行reshape操作。然后,使用tf.reshape将转置后的标签y_in重塑为[-1] + hwc的形状,其中hwc表示标签y_in的高度、宽度和通道数。这将y_in从一个5D张量转换为一个2D张量。num_views
表示每个样本所包含的视角数量。
(详细解释)tf.transpose(y_in, [1,0,2,3,4]),将y_in的维度从(batch_size, num_views, height, width, channels)变为(num_views, batch_size, height, width, channels)。使用tf.reshape函数将转置后的标签y_in重塑为一个新的形状,即[-1] + hwc。将y_in从一个5D张量转换为一个2D张量,其中第一维度表示了所有样本和视角的总数。具体来说,将第二到第五维度平坦化,即(batch_size * num_views, height, width, channels)转换为(batch_size * num_views * height * width * channels)。这种重塑操作可以将标签变为一个长向量,方便后续操作。
接下来,使用分类器对标签进行预测。具体来说,使用self.guess_label函数对重塑后的标签y进行预测,其中guess.p_target是预测的概率,可以用于计算分类器的损失函数。如果use_ema_guess为True,则使用ema_getter获取分类器参数的EMA副本进行预测。
(详细解释)y形状为(batch_size * num_views * height * width * channels),tf.split(y, nu)
将 y
按照 nu
的值在第一维度上进行分割,得到一个包含 nu
个张量的列表。classifier
是用于分类的网络模型,它将每个标签数据映射为一个类别,并同时输出每个类别的置信度。getter
是一个函数,用于获取模型中的参数,这里使用 ema_getter
函数来获取使用指数移动平均法(Exponential Moving Average,EMA)计算的模型参数,以提高模型的鲁棒性。**kwargs
表示其他可选的参数,这些参数会传递给 guess_label
方法。
生成标签的独热编码和停止梯度的标签分布。具体来说,l_in
是一个形状为 (batch_size, )
的张量,表示输入的真实标签。self.nclass
是一个标量,表示标签的类别数量。因此tf.one_hot(l_in, self.nclass)
会将真实标签 l_in
编码为一个形状为 (batch_size, self.nclass)
的独热编码张量,其中每一行表示一个标签的独热编码。
guess
是通过 guess_label
方法生成的一组伪标签。在该方法中,伪标签的生成是通过预测标签的分布来实现的,即 guess.p_target
表示标签数据的估计分布。为了避免在训练时反向传播误差到伪标签,导致网络训练不稳定,这里使用 tf.stop_gradient
函数将 guess.p_target
停止梯度,生成一个形状与其相同的新张量 ly
。
最终,lx
和 ly
分别表示真实标签和伪标签的独热编码,它们会被用于训练网络。
然后,进行数据增强和标签拼接。具体来说,将输入数据x_in和预测的标签guess.p_target(使用tf.split对预测的标签进行分割)传递给augment函数,对它们进行数据增强(augmentation)。augment函数返回增强后的数据和标签。最后,将增强后的数据x和增强后的标签y拆分为单独的张量,并将输入标签l_in进行one-hot编码,得到labels_x和labels_y。最后,删除xy和labels_xy以释放内存。
(详细解释)augment
是一个数据增强的函数,它接受三个参数:data
、labels
和 params
,分别表示原始数据、标签和数据增强的参数。在这里,[x_in] + tf.split(y, nu)
表示将原始数据 x_in
和伪标签数据 y
按照视角数 nu
进行拆分,拼接成一个列表传递给 augment
函数。类似地,[lx] + [ly] * nu
表示将真实标签的独热编码 lx
和伪标签的独热编码 ly
按照视角数 nu
进行拆分,并使用列表推导式生成一个长度为 nu
的列表,最后将这两个列表拼接起来。params
参数中包含了两个值,都是标量 beta
。它们用于控制数据增强时两种操作的强度,具体操作是随机剪裁和随机翻转。augment
函数的返回值是一个元组,包含增强后的数据和标签。在这里,xy
和 labels_xy
分别表示增强后的数据和标签。其中,xy[0]
表示增强后的原始数据,xy[1:]
表示增强后的伪标签数据;labels_xy[0]
表示增强后的真实标签独热编码,labels_xy[1:]
表示增强后的伪标签独热编码。最后,通过将 xy
拆分为 x
和 y
,将 labels_xy
拆分为 labels_x
和 labels_y
,分别表示增强后的原始数据、增强后的伪标签数据、增强后的真实标签独热编码和增强后的伪标签独热编码,用于训练网络。最后通过 del xy, labels_xy
删除不再需要的变量,释放内存。
batches = layers.interleave([x] + y, batch)
skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
logits = [classifier(batches[0], training=True)]
post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops]
for batchi in batches[1:]:
logits.append(classifier(batchi, training=True))
logits = layers.interleave(logits, batch)
logits_x = logits[0]
logits_y = tf.concat(logits[1:], 0)
这段代码主要是为了计算训练数据和伪标签数据的logits(分类器输出的未经softmax处理的概率),并计算损失函数。
首先通过调用layers.interleave
函数将训练数据和伪标签数据交错分组,形成一个新的batch列表,其中第一个元素是训练数据,其余元素是伪标签数据。
然后通过循环遍历每个batch,调用分类器函数classifier
计算每个batch的logits。在计算logits的过程中,通过设置training=True
来启用训练模式,以便在BN层中记录训练过程中的均值和方差,并在测试过程中使用它们进行归一化。
计算完logits后,通过调用layers.interleave
函数将它们重新交错分组,然后将第一个元素赋给logits_x
变量,其余元素赋给logits_y
变量。
loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x, logits=logits_x)
loss_xe = tf.reduce_mean(loss_xe)
loss_l2u = tf.square(labels_y - tf.nn.softmax(logits_y))
loss_l2u = tf.reduce_mean(loss_l2u)
tf.summary.scalar('losses/xe', loss_xe)
tf.summary.scalar('losses/l2u', loss_l2u)
这段代码计算了两个损失函数。第一个是 softmax 交叉熵损失函数,用来计算有标签数据的分类误差,它被赋值给了变量 loss_xe
。第二个是 L2 损失函数,用于衡量无标签数据的预测结果与其平滑后的伪标签之间的差异,它被赋值给了变量 loss_l2u
。这两个损失函数分别使用了 TensorFlow 中的 tf.nn.softmax_cross_entropy_with_logits_v2()
和 tf.square()
函数进行计算,并用 tf.reduce_mean()
函数求取了它们的平均值。在这里,tf.summary.scalar()
函数被用来在 TensorBoard 中记录损失函数的值。
post_ops.append(ema_op)
post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])
train_op = tf.train.AdamOptimizer(lr).minimize(loss_xe + w_match * loss_l2u, colocate_gradients_with_ops=True)
with tf.control_dependencies([train_op]):
train_op = tf.group(*post_ops)
# Tuning op: only retrain batch norm.
skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
classifier(batches[0], training=True)
train_bn = tf.group(*[v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
if v not in skip_ops])
这段代码是在训练模型的过程中,定义了一些后处理操作(post_ops
),然后使用Adam优化器最小化交叉熵损失(loss_xe
)和L2正则化损失(loss_l2u
)的和。其中,L2正则化损失用于匹配有标签样本和无标签样本的特征分布,以实现半监督学习的目的。tf.summary.scalar
用于记录损失的变化情况。with tf.control_dependencies([train_op])
语句确保在进行后续操作之前,train_op
操作先被执行。另外,还定义了一个操作train_bn
,用于只重新训练BN层。
return EasyDict(
x=x_in, y=y_in, label=l_in, train_op=train_op, tune_op=train_bn,
classify_raw=tf.nn.softmax(classifier(x_in, training=False)),
classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
这段代码返回了一个包含各种操作和张量的EasyDict对象。它包括输入张量x_in,y_in和l_in,两个分类器的softmax输出,训练操作train_op和调整操作train_bn,以及其他一些操作。这个对象的目的是使训练和测试代码更加简洁和易于理解。
def main(argv):
del argv # Unused.
dataset = DATASETS[FLAGS.dataset]()
log_width = utils.ilog2(dataset.width)
model = AblationMixMatch(
os.path.join(FLAGS.train_dir, dataset.name),
dataset,
lr=FLAGS.lr,
wd=FLAGS.wd,
arch=FLAGS.arch,
batch=FLAGS.batch,
nclass=dataset.nclass,
ema=FLAGS.ema,
beta=FLAGS.beta,
use_ema_guess=FLAGS.use_ema_guess,
T=FLAGS.T,
mixmode=FLAGS.mixmode,
nu=FLAGS.nu,
w_match=FLAGS.w_match,
warmup_kimg=FLAGS.warmup_kimg,
scales=FLAGS.scales or (log_width - 2),
filters=FLAGS.filters,
repeat=FLAGS.repeat)
model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
这段代码是一个调用AblationMixMatch类实例化一个模型,并训练的函数。首先会根据FLAGS中的数据集名称选择对应的数据集,然后通过AblationMixMatch类构造模型。FLAGS中包含了训练需要用到的超参数,例如学习率、权重衰减、卷积神经网络结构等。最后,调用模型的train方法,训练模型并输出训练结果。其中,FLAGS.train_kimg和FLAGS.report_kimg是指训练步数和结果输出步数,都需要左移10位,因为模型使用的是Mini-batch SGD,每一步的batch size是2的整数次幂
if __name__ == '__main__':
utils.setup_tf()
flags.DEFINE_float('wd', 0.02, 'Weight decay.')
flags.DEFINE_float('ema', 0.999, 'Exponential moving average of params.')
flags.DEFINE_float('beta', 0.5, 'Mixup beta distribution.')
flags.DEFINE_bool('use_ema_guess', False, 'Whether to use EMA parameters when guessing labels.')
flags.DEFINE_float('T', 0.5, 'Softmax sharpening temperature.')
flags.DEFINE_enum('mixmode', 'xxy.yxy', MixMode.MODES, 'Mixup mode')
flags.DEFINE_float('w_match', 100, 'Weight for distribution matching loss.')
flags.DEFINE_integer('warmup_kimg', 128, 'Warmup in kimg for the matching loss.')
flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.')
flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.')
flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.')
FLAGS.set_default('dataset', 'cifar10.3@250-5000')
FLAGS.set_default('batch', 64)
FLAGS.set_default('lr', 0.002)
FLAGS.set_default('train_kimg', 1 << 16)
app.run(main)
这段代码是一个 Python 脚本的主函数,会在运行时执行。该脚本提供了许多可选的命令行参数,用于指定不同的超参数设置。在此之后,脚本调用了 utils.setup_tf()
函数,该函数是一个工具函数,用于设置 TensorFlow 运行时的 GPU 环境等配置。最后,脚本调用了 app.run(main)
函数来运行 main
函数。main
函数主要是构建了一个 AblationMixMatch
模型,并调用 train
函数来训练模型。