最近复现prototypical net发现作者源代码使用了torchnet中的meter,还写了一个engine。所以我用一天时间看了下这个库。其实不用这个也是完全没有任何问题的。它只是方便复用的一个框架
torchnet 是用于 torch 的代码复用和模块化编程的框架
文档:https://tnt.readthedocs.io/en/latest/
github地址:https://github.com/pytorch/tnt
主要包含 4 个部分:
Dataset : 各种不同的方式处理数据
Engine: 各种机器学习算法
Meter: 性能度量评估
Log:
模块详细分为如下部分:
主要用于可视化,数据处理和存取,日志管理
原本是基于lua-torch的一个库,迁移到python中来,变成pytorch中的一部分。
安装
先保证pytorch已经安装,接着
pip install torchnet
或者 pip install git+https://github.com/pytorch/tnt.git@master
当前主要用的是Dataset和meter模块
抽象类classtorchnet.dataset.dataset.Dataset
传进的参数dataset是一个可迭代对象即可,不过BatchDataset必须让每个元素为一个dict
产生batch形式数据
torchnet.dataset.BatchDataset(dataset, batchsize, perm=
Parameters:
- dataset (Dataset) :数据集,这个数据集的每个元素必须是一个dict(应该是用于multi-task,每个样本由多个类,这时候,存储为item = {'data':data , 'class':class}会比较方便)
- batchsize (int) : batchsize
- perm (function, optional):洗牌函数(数据随机打乱)
- merge (function, optional) :控制产生 batch 的行为。 在transform.makebatch源码中使用这个函数. Default is None.它的作用是合并数据,直接从 Dataset 得到的一个 batch 是一个 dict 的列表,makebatch 默认行为是将 dict 中的 key 合并的数据合并,合并后使用 merge 进一步处理,默认行为(merge = None)就是将这个dict中的数据按照第一维度拼成一个Tensor(如果可以拼接的话)返回。应该是和torch.utils.data.dataloader中的collate_fn 如果需要pin_memory的话,需要把数据按锁页方式存储
- policy (str, optional) :处理最后一个batch的策略
include-last 包含最后一个,无论剩几个
skip-last 最后一个小于 batchsize 的时候丢掉
divisible-only 数据不能整除batchsize时报错
- filter (function, optional) :在产生 batch 之前过滤。 filter(sample) 返回 True 则包含这个数据,False 表示过滤掉,默认 True
拼接datasets
torchnet.dataset.ConcatDataset(datasets)
Parameters:
- datasets (iterable) 一个dataset列表
产生List形式数据
torchnet.dataset.ListDataset(elem_list, load=
Parameters:
- elem_list (iterable/str) :用于load数据的参数列表(可以是文件名列表或者数据本身等等,根据load制定)
- load (function, optional) :一个load数据的函数,第i个样本由load(elem_list[i])得到。默认就是恒等映射 i.e, lambda x: x
- path (str, optional) : Defaults to None. 表示数据的目录,如果这个被提供,则elem_list[i]在传给load时,会将这个作为前缀
返回一个采样数据集
torchnet.dataset.ResampleDataset(dataset, sampler=
Parameters:
- dataset (Dataset)
- sampler (function, optional) :采样函数,返回的是下标。第idx个样本由dataset[sampler(dataset, idx)]返回,默认是恒等映射。
- size (int, optional):目标数据大小,默认和原来的一样
也是采样,不过是均匀分布
torchnet.dataset.ShuffleDataset(dataset, size=None, replacement=False)
Parameters:
- dataset (Dataset)
- size (int, optional): 目标数据大小。如果replacement为False且它大于原数据大小,则报错
- replacement (bool, optional): 均匀分布放回抽样
函数resample(seed=None):对数据重新采样,默认不需要传一个随机seed
数据集分割
torchnet.dataset.SplitDataset(dataset, partitions, initial_partition=None)
Parameters:
- dataset (Dataset)
- partitions (dict): 分割dict,key是分割的自定义名称,val是权重(值在0和1之间)或者size大小指定每个部分的样本数
- initial_partition (str, optional) : 初始化选择的分割
函数
- select(partition):partition是上面dict中key的一个,指定使用哪个部分
partition(str)
方便把一个已经存在内存中的数据变成标准的结构(再套上一层)
torchnet.dataset.TensorDataset(data)
Dataset from a tensor or array or list or dict.
data的形式
tensor or numpy array
idx`th sample is `data[idx]
dict of tensors or numpy arrays
idx`th sample is `{k: v[idx] for k, v in data.items()}
list of tensors or numpy arrays
idx`th sample is `[v[idx] for v in data]
得到一个变换数据集
torchnet.dataset.TransformDataset(dataset, transforms)
- dataset (Dataset)
- transforms (function/dict) :一个函数(可以是compose的)或者dict(值是函数),用于样本的变换
torchnet.transform部分
主要用的是torchnet.transform.compose函数,将transform拼接在一起
它接收一个transform列表,每个transform是一个函数,接收上一次的输出作为输入。和TransformDataset搭配使用。
例如 TransformDataset(ListDataset(class_names), compose([transforms1,transform2]))
它将训练过程和测试过程进行包装,抽象成一个类,提供train和test方法和一个hooks.(这部分文档是问题的)
文档中的描述应该是torch.tensor中的hook,原理一致,只不过tensor中的hook是在变量forward或者backward的时候执行(两种hook)
hooks包括on_start, on_sample, on_forward, on_update, on_end_epoch, on_end,可以自己制定函数,在开始,load数据,forward,更新还有epoch结束以及训练结束时执行。一般是用开查看和保存模型训练过程的一些结果。
用于记录一些评估结果和可视化(用visdom)
文档还没完善,先略过
classtorchnet.logger.MeterLogger(server='localhost', env='main', port=8097, title='DNN', nclass=21, plotstylecombined=True)
Parameters:
server – The uri of the Visdom server
env – Visdom environment to log to.
port – Port of the visdom server.
title – The title of the MeterLogger. This will be used as a prefix for all plots.
nclass – 对于分类问题,传入类别数
plotstylecombined – Whether to plot train/test curves in the same window.
方法
peek_meter():Returns a dict of all meters and their values.
print_meter(mode, iepoch, ibatch=1, totalbatch=1, meterlist=None)
reset_meter(iepoch, mode='Train')[
update_loss(loss, meter='loss')
update_meter(output, target, meters={'accuracy'})
class torchnet.logger.visdomlogger.BaseVisdomLogger(fields=None, win=None, env=None, opts={}, port=8097, server='localhost')
分类meter
classtorchnet.meter.APMeter
计算每个类平均准确率AP
方法
add(output, target, weight=None)
- output (Tensor) – NxK tensor表示N个样本,分别属于K个类的概率
- target (Tensor) – binary NxK tensort 表示样本是否属于某个类 (eg: a row [0, 1, 0, 1]意味着样本属于classes 2 and 4)
- weight (optional, Tensor) – Nx1 tensor样本权重(weight>0)
reset()
value() 返回一个1xK FloatTensor 表示每个类的AP
classtorchnet.meter.mAPMeter
函数和参数跟上面一样,返回的是mAp
classtorchnet.meter.ClassErrorMeter(topk=[1], accuracy=False)
add(self, output, target)
output : 应该是一个NxK tensor,N个样本,分别属于K个类的概率
target :真实类别长度为N的一维Tensor
reset()
value(k=-1): 第top k的值,-1表示最后
维护一个混淆矩阵conf,大小为K * K,每行表示真实类别被和其他类的混淆值
classtorchnet.meter.ConfusionMeter(k, normalized=False)
Parameters:
- k (int) – 类别数
- normalized (boolean) – 混淆矩阵归一化(行归一化)
方法
add(predicted, target)
Parameters:
predicted (tensor) – N x K tensor 或者一个 N-tensor (值0到k-1),为predictor的输出
target (tensor) – N x K tensor(one hot) 或者一个 N-tensor(值0到k-1) 为真实类别
value(): 返回混淆矩阵
回归损失meter
计算平均值
classtorchnet.meter.AverageValueMeter
方法
- add(self, value, n=1)
- value是要记录的值,n是记录次数
- reset()
- value(): 返回平均值和标准差
计算AUC
classtorchnet.meter.AUCMeter
- add(output, target):
- reset():
- value():返回的是 (area, tpr, fpr)
计算MovingAverage
classtorchnet.meter.MovingAverageValueMeter(windowsize)
- windowsize:窗口大小
- add(value): 记录value
- reset()
- value() : 返回MA和标准差
classtorchnet.meter.MSEMeter(root=False)
- add(output, target) :output和target大小一致
- reset()
- value() :返回MSE
Miscellaneous Meters(其他)
classtorchnet.meter.TimeMeter(unit)
- unit:单位,暂时没用到
- 方法
- value(): 返回已经使用的时间
一些工具,一部分用在其他的模块中
class torchnet.utils.MultiTaskDataLoader(datasets, batch_size=1, use_all=False, **loading_kwargs)
记录结果的
class torchnet.utils.ResultsWriter(filepath, overwrite=False)
下面两个函数在makebatch 中使用
判断tensor能不能合并
torchnet.utils.table.canmergetensor(tbl)
合并tensor
torchnet.utils.table.mergetensor(tbl)