尽管是用for loop语句很容易处理Dataset,但有时希望以on-the-fly manner或者在线程中读取数据,这时候Dataset Iterator就是个好的选择
注意,iterators是用于特殊情况的,一般情况下还是使用Dataset比较好
Iteartor 的两个主要方法:
* run() 返回一个Lua 迭代器,也可以使用()操作符,因为iterator源码中定义了__call事件
* exec(funcname,...) 在指定的dataset上执行funcname方法,funcname是dataset自己的方法,比如size
tnt.DatasetIterator(self,dataset[,perm][,filter][,transform])
The default dataset iterator
perm(idx), 实现shuffle功能,即对idx进行变换,更复杂的变换可以使用ShuffleDataset
filter(sample), 闭包函数,筛选样本是否用于迭代,返回bool值
transform(sample),闭包函数,实现对样本的变换,更复杂的变换可以结合TransformDataset和transform.compose等实现
tnt.ParallelDatasetIterator(self[,init],closure,nthread[,perm][,filter][,transform][,ordered])
这个才是迭代器的重点,用于以多线程方式迭代数据。
The purpose of this class is to have a zero pre-processing cose. when reading datasets on the fly from disk(not loading thenm fully in memory), or performing complex pre-processing this canbe of interest.
nthreads 指定了线程的个数
init(threadid) 闭包函数,指定了线程threadid的初始化工作,如果啥都不做可以省略
closure(threadid) 每个线程的job,返回的必须时tnt.Dataset的一个实例
perm(idx) 用于shuffle
filter(sample) 闭包函数,指定哪些样本不用于迭代
transform(sample) 对样本进行变换,在filter之前执行
order 线程之间数据的处理是否有序,主要是为了程序的可重现性,当order=true时,多次执行程序,顺序是相同的
尤其需要注意的是,closure中的所有upvalues都必须是可序列化的,最好是避免使用upvalues,并保证closure中使用的package都在init中require
在网络训练的过程中,都是计算前向误差,误差反传,更新权重这些过程,只是模型,数据和评价函数不同而已,所以Engine给训练过程提供了一个模板,该模板建立了model,DatasetIterator,Criterion和Meter之间的联系
engine=tnt.Engine()包含两个主要方法
* engine:train() 在数据集上训练数据
* engine:test() 评估模型,可选
Engine不仅实现了训练和评估的一般模板,还提供了许多接口,用于控制训练过程
tnt.SGDEngine
SGDEngine 模块在train过程中使用Stochastic Gradient Descent方法训练,模块包含数据采样,前向传递,反向传递,参数更新等,还有一些钩子函数
hooks = {
['onStart'] = function() end, --用于训练开始前的设置和初始化
['onStartEpoch'] = function() end, -- 每一个epoch前的操作
['onSample'] = function() end, -- 每次采样一个样本之后的操作
['onForward'] = function() end, -- 在model:forward()之后的操作
['onForwardCriterion'] = function() end, -- 前向计算损失函数之后的操作
['onBackwardCriterion'] = function() end, -- 反向计算损失误差之后的操作
['onBackward'] = function() end, -- 反向传递误差之后的操作
['onUpdate'] = function() end, -- 权重参数更新之后的操作
['onEndEpoch'] = function() end, -- 每一个epoch结束时的操作
['onEnd'] = function() end, -- 整个训练过程结束后的收拾现场
}
可以发现Engine给的hook函数还是很全面的,几乎训练过程的每一个节点都允许用户制定操作,使用hook函数
一般而言,训练过程最少应该知道训练模型,损失函数,数据和学习率,这里学习方法已经知道了SGD,Engine用到的数据是tnt.DatasetIterator类型的。 评估过程只需要数据和模型就可以了
外部可以通过state变量与Engine训练过程交互
state = {
['network'] = network, --设置了model
['criterion'] = criterion, -- 设置损失函数
['iterator'] = iterator, -- 数据迭代器
['lr'] = lr, -- 学习率
['lrcriterion'] = lrcriterion, --
['maxepoch'] = maxepoch, --最大epoch数
['sample'] = {}, -- 当前采集的样本,可以在onSample中通过该阈值查看采样样本
['epoch'] = 0 , -- 当前的epoch
['t'] = 0, -- 已经训练样本的个数
['training'] = true -- 训练过程
}
评估时需要指定:
state = {
['netwrok'] = network
['iterator'] = iterator
['criterion'] = criterion
}
tnt.OptimEngine
这个方法和SGDEngine的最大的区别在于封装了optim中的多种优化方法。在训练开始的时候,engine会通过getParameters获取model的参数
train需要附加两个量:
optimMethod 优化方法,比如optim.sgd
config 优化方法对应的参数
Example:
和Engine配合使用,用于measure the model.
几乎所有的meters都会有3个方法:
* add() 给待统计的meter添加一个观测值,其输入参数一般形式为(output,value),output为model的输出,target为真实值
* value() 获得待统计的meter的当前值
* reset() 重新计数
Meter的使用示例:
tnt.APMeter(self)
评估每一类的平均正确率
APMeter的操作对象是一个的Tensor,表示N个样本对应在K类中的值,另外可选的一个的 Tensor表示每个样本的权重
tnt.AverageValueMeter(self)
用于统计任意添加的变量的方差和均值,可以用来测量平均损失等
add()的输入必须时number类型,另外在add的时候可以有一个可选的参数n,表示对应值的权重
tnt.AUCMeter(self)
对于二分类问题计算Area Under Curve (AUC).
AUCMeter操作的变量是1D的tensor
tnt.ConfusionMeter(self,k[,nirmalized])
多类之间的混淆矩阵,注意不是多类多标签问题,多标签是指一个类的实例可能分配多个标签,这类问题参见tnt.MultiLabelConfusionMeter
初始化的时候,需要指定类别数k,normalized指定是否将confuse matrix 归一化,归一化之后输出的是百分比,否则是数值
add(output,target) 输入都是的tensor,这里为什么每次都是N个样本一起输入呢?这是因为往往训练模型都是Batch模式处理的,target可以是N的tensor,每个值表示对应类别标号,也可以时NK的tensor表示类别的one-hot vector
value()返回KK的混淆矩阵行表示groundtruth,列表示predicted targets
tnt.mAPMeter(self)
统计所有类别之间的平均正确率,和APMeter参数完全一致,不同的时value()返回的是多个类别总的正确率
tnt.MovingAverageValueMeter(self,windowsize)
该meter和AverageValueMeter非常类似,输入的也是number,不同在于他统计的不是所有的number的均值和方差,而是往前windowsize时间窗内的numbers的均值和方差,windowsize在初始化时需要指定
tnt.MultiLabelConfusionMeter(self,k[,normalized])
多类多标签混淆矩阵,这个没接触过,不知道理解对不对,先放这吧,需要的时候再看
The tnt.MultiLabelConfusionMeter constructs a confusion matrix for multi- label, multi-class classification problems. In constructing the confusion matrix, the number of positive predictions is assumed to be equal to the number of positive labels in the ground-truth. Correct predictions (that is, labels in the prediction set that are also in the ground-truth set) are added to the diagonal of the confusion matrix. Incorrect predictions (that is, labels in the prediction set that are not in the ground-truth set) are equally divided over all non-predicted labels in the ground-truth set.
At initialization time, the k parameter that indicates the number of classes in the classification problem under consideration must be specified. Additionally, an optional parameter normalized (default = false) may be specified that determines whether or not the confusion matrix is normalized (that is, it contains percentages) or not (that is, it contains counts).
The add(output, target) method takes as input an NxK tensor output that contains the output scores obtained from the model for N examples and K classes, and a corresponding NxK-tensor target that provides the targets for the N examples using one-hot vectors (that is, vectors that contain only zeros and a single one at the location of the target value to be encoded).
tnt.ClassErrorMeter(self[,topk][,accuracy])
参数: topk = table
accuracy = boolean
该meter用于统计分类误差,topk是一个table指定分别统计前k类预测误差,如ImageNet Competition中的Top5类误差,accuracy表示返回的是正确了还是错误率,accuracy=true,返回的就是1-error
add(output,target),output是一个的tensor,target可以使一个N的tensor也可以是一个的tensor,参考之前的AUCMeter
value()返回的时topk误差,value(k)返回的是第topk类误差
tnt.TimeMeter(self[,unit])
这个Meter用于统计events之间的时间,也可以用来统计batch数据的平均处理数据。她很特别!
unit在初始的时候给定,是一个布尔值,默认false,当设置为true时,返回值将会被incUnit()值平均,计算平均时间消耗。
tnt.TimeMeter提供的方法有:
reset() 重置timer,unit counter
stop() stop the timer
resume() 唤醒timer
incUnit() uint+1
value() 返回从reset()到现在的时间消耗
tnt.PrecisionAtKMeter(self[,topk][,dim][,online])
待补充
tnt.RecallMeter(self[,threshold][,preclass])
统计threshold下的召回率,threshold是一个table类型,每个元素是一个阈值,默认值为0.5. perclass是一个布尔值,表示是单独统计每一类的召回率还是统计整个召回率,默认值是false
add(output,target) output是N*K的概率矩阵,行和为1;target是NK的二值矩阵,不一定行和为1,如{0,1,0,1}
value()返回的是table值,对应的是threshold table中指定阈值下的召回率,如果perclass = true,那么table的每个元素就是一个table
tnt.PrecisionMeter(self[,threshold][,perclass])
参考RecallMeter,这里计算的是正确率
tnt.NDCGMeter(self[,K])
计算normalized discounted cumulative gain,没使用过。。。。
Log是一个由sting key索引的table,这些keys必须在构造函数中指定,有一个特殊的键 __status__可以在log:status()函数中设置用于记录一些基本的messages
Log中提供的一些closures以及对应attached events
* onSet(log,key,value) 对应着给键赋值 log:set{}
* onGet(log,key) 对应着读取key对应的值 log:get()
* onFlush(log) 对应着清空log log:flush()
* onClose(log) 对应log:close() 关闭log
示例:
原文地址:https://www.cnblogs.com/YiXiaoZhou/p/6774806.html