python 版本 torchnet 简单使用文档

最近复现prototypical net发现作者源代码使用了torchnet中的meter,还写了一个engine。所以我用一天时间看了下这个库。其实不用这个也是完全没有任何问题的。它只是方便复用的一个框架

 

torchnet 是用于 torch 的代码复用和模块化编程的框架

文档:https://tnt.readthedocs.io/en/latest/

github地址:https://github.com/pytorch/tnt

主要包含 4 个部分:

Dataset : 各种不同的方式处理数据

Engine: 各种机器学习算法

Meter: 性能度量评估

Log:

模块详细分为如下部分:

  • Datasets:
    • BatchDataset
    • ListDataset
    • ResampleDataset
    • ShuffleDataset
    • TensorDataset [new]
    • TransformDataset
  • Meters:
    • APMeter
    • mAPMeter
    • AverageValueMeter
    • AUCMeter
    • ClassErrorMeter
    • ConfusionMeter
    • MovingAverageValueMeter
    • MSEMeter
    • TimeMeter
  • Engines:
    • Engine
  • Logger
    • Logger
    • VisdomLogger
    • MeterLogger [new, easy to plot multi-meter via Visdom]

 

主要用于可视化,数据处理和存取,日志管理

 

python 版本 torchnet 简单使用文档_第1张图片

原本是基于lua-torch的一个库,迁移到python中来,变成pytorch中的一部分。

 

安装

先保证pytorch已经安装,接着

pip install torchnet

或者 pip install git+https://github.com/pytorch/tnt.git@master

 

当前主要用的是Dataset和meter模块


Dataset部分

抽象类classtorchnet.dataset.dataset.Dataset

传进的参数dataset是一个可迭代对象即可,不过BatchDataset必须让每个元素为一个dict

 

产生batch形式数据

torchnet.dataset.BatchDataset(dataset, batchsize, perm=>, merge=None, policy='include-last', filter=>)

Parameters: 

  • dataset (Dataset) :数据集,这个数据集的每个元素必须是一个dict(应该是用于multi-task,每个样本由多个类,这时候,存储为item = {'data':data , 'class':class}会比较方便)
  • batchsize (int) : batchsize
  • perm (function, optional):洗牌函数(数据随机打乱)
  • merge (function, optional) :控制产生 batch 的行为。 在transform.makebatch源码中使用这个函数. Default is None.它的作用是合并数据,直接从 Dataset 得到的一个 batch 是一个 dict 的列表,makebatch 默认行为是将 dict 中的 key 合并的数据合并,合并后使用 merge 进一步处理,默认行为(merge = None)就是将这个dict中的数据按照第一维度拼成一个Tensor(如果可以拼接的话)返回。应该是和torch.utils.data.dataloader中的collate_fn 如果需要pin_memory的话,需要把数据按锁页方式存储
  • policy (str, optional) :处理最后一个batch的策略

           include-last 包含最后一个,无论剩几个

           skip-last 最后一个小于 batchsize 的时候丢掉

           divisible-only 数据不能整除batchsize时报错

  • filter (function, optional) :在产生 batch 之前过滤。 filter(sample) 返回 True 则包含这个数据,False 表示过滤掉,默认 True

拼接datasets

torchnet.dataset.ConcatDataset(datasets)

Parameters:

  •     datasets (iterable) 一个dataset列表

产生List形式数据

torchnet.dataset.ListDataset(elem_list, load=>, path=None)

Parameters:

  • elem_list (iterable/str) :用于load数据的参数列表(可以是文件名列表或者数据本身等等,根据load制定)
  • load (function, optional) :一个load数据的函数,第i个样本由load(elem_list[i])得到。默认就是恒等映射 i.e, lambda x: x
  • path (str, optional) : Defaults to None. 表示数据的目录,如果这个被提供,则elem_list[i]在传给load时,会将这个作为前缀

返回一个采样数据集

torchnet.dataset.ResampleDataset(dataset, sampler=>, size=None)

Parameters:

  • dataset (Dataset)
  • sampler (function, optional) :采样函数,返回的是下标。第idx个样本由dataset[sampler(dataset, idx)]返回,默认是恒等映射。
  • size (int, optional):目标数据大小,默认和原来的一样

也是采样,不过是均匀分布

torchnet.dataset.ShuffleDataset(dataset, size=None, replacement=False)

Parameters:

  • dataset (Dataset)
  • size (int, optional): 目标数据大小。如果replacement为False且它大于原数据大小,则报错
  • replacement (bool, optional): 均匀分布放回抽样

函数resample(seed=None):对数据重新采样,默认不需要传一个随机seed


数据集分割

torchnet.dataset.SplitDataset(dataset, partitions, initial_partition=None)

Parameters:

  • dataset (Dataset)
  • partitions (dict): 分割dict,key是分割的自定义名称,val是权重(值在0和1之间)或者size大小指定每个部分的样本数
  • initial_partition (str, optional) : 初始化选择的分割

函数

  • select(partition):partition是上面dict中key的一个,指定使用哪个部分

             partition(str)


方便把一个已经存在内存中的数据变成标准的结构(再套上一层)

torchnet.dataset.TensorDataset(data)

           Dataset from a tensor or array or list or dict.

           data的形式

           tensor or numpy array

                   idx`th sample is `data[idx]

           dict of tensors or numpy arrays

                      idx`th sample is `{k: v[idx] for k, v in data.items()}

            list of tensors or numpy arrays

                      idx`th sample is `[v[idx] for v in data]

 


 

得到一个变换数据集

torchnet.dataset.TransformDataset(dataset, transforms)

  • dataset (Dataset)
  • transforms (function/dict) :一个函数(可以是compose的)或者dict(值是函数),用于样本的变换

 

 


torchnet.transform部分

主要用的是torchnet.transform.compose函数,将transform拼接在一起

它接收一个transform列表,每个transform是一个函数,接收上一次的输出作为输入。和TransformDataset搭配使用。

例如 TransformDataset(ListDataset(class_names), compose([transforms1,transform2]))

 


torchnet.engine

它将训练过程和测试过程进行包装,抽象成一个类,提供train和test方法和一个hooks.(这部分文档是问题的)

文档中的描述应该是torch.tensor中的hook,原理一致,只不过tensor中的hook是在变量forward或者backward的时候执行(两种hook)

hooks包括on_start, on_sample, on_forward, on_update,  on_end_epoch,  on_end,可以自己制定函数,在开始,load数据,forward,更新还有epoch结束以及训练结束时执行。一般是用开查看和保存模型训练过程的一些结果。


torch.logger

用于记录一些评估结果和可视化(用visdom)

文档还没完善,先略过

classtorchnet.logger.MeterLogger(server='localhost', env='main', port=8097, title='DNN', nclass=21, plotstylecombined=True)

Parameters:

server – The uri of the Visdom server

env – Visdom environment to log to.

port – Port of the visdom server.

title – The title of the MeterLogger. This will be used as a prefix for all plots.

nclass – 对于分类问题,传入类别数

plotstylecombined – Whether to plot train/test curves in the same window.

方法

peek_meter():Returns a dict of all meters and their values.

print_meter(mode, iepoch, ibatch=1, totalbatch=1, meterlist=None)

reset_meter(iepoch, mode='Train')[

update_loss(loss, meter='loss')

update_meter(output, target, meters={'accuracy'})

 

class torchnet.logger.visdomlogger.BaseVisdomLogger(fields=None, win=None, env=None, opts={}, port=8097, server='localhost')


 

torch.Meter

分类meter

classtorchnet.meter.APMeter

计算每个类平均准确率AP

方法

add(output, target, weight=None)

  • output (Tensor) – NxK tensor表示N个样本,分别属于K个类的概率
  • target (Tensor) – binary NxK tensort 表示样本是否属于某个类 (eg: a row [0, 1, 0, 1]意味着样本属于classes 2 and 4)
  • weight (optional, Tensor) – Nx1 tensor样本权重(weight>0)

reset()

value() 返回一个1xK FloatTensor 表示每个类的AP

 

classtorchnet.meter.mAPMeter

函数和参数跟上面一样,返回的是mAp

 

classtorchnet.meter.ClassErrorMeter(topk=[1], accuracy=False)

add(self, output, target)

output : 应该是一个NxK tensor,N个样本,分别属于K个类的概率

target :真实类别长度为N的一维Tensor

reset()

value(k=-1): 第top k的值,-1表示最后

维护一个混淆矩阵conf,大小为K * K,每行表示真实类别被和其他类的混淆值

classtorchnet.meter.ConfusionMeter(k, normalized=False)

Parameters:

  • k (int) – 类别数
  • normalized (boolean) – 混淆矩阵归一化(行归一化)

方法

add(predicted, target)

Parameters:

        predicted (tensor) – N x K tensor 或者一个 N-tensor (值0到k-1),为predictor的输出

        target (tensor) – N x K tensor(one hot) 或者一个 N-tensor(值0到k-1) 为真实类别

value(): 返回混淆矩阵

回归损失meter

计算平均值

classtorchnet.meter.AverageValueMeter

方法

  • add(self, value, n=1)
  • value是要记录的值,n是记录次数
  • reset()
  • value(): 返回平均值和标准差

计算AUC

classtorchnet.meter.AUCMeter

  • add(output, target):
  • reset():
  • value():返回的是 (area, tpr, fpr)

 

计算MovingAverage

classtorchnet.meter.MovingAverageValueMeter(windowsize)

  • windowsize:窗口大小
  • add(value): 记录value
  • reset()
  • value() : 返回MA和标准差

classtorchnet.meter.MSEMeter(root=False)

  • add(output, target) :output和target大小一致
  • reset()
  • value() :返回MSE

Miscellaneous Meters(其他)

classtorchnet.meter.TimeMeter(unit)

  • unit:单位,暂时没用到
  • 方法
  • value(): 返回已经使用的时间

 


torchnet.utils

一些工具,一部分用在其他的模块中

class torchnet.utils.MultiTaskDataLoader(datasets, batch_size=1, use_all=False, **loading_kwargs)

记录结果的

class torchnet.utils.ResultsWriter(filepath, overwrite=False)

 

下面两个函数在makebatch 中使用

判断tensor能不能合并

torchnet.utils.table.canmergetensor(tbl)

合并tensor

torchnet.utils.table.mergetensor(tbl)

 


 

你可能感兴趣的:(pytorch,torch,pytorch,torchnet)