torch和lua学习常见问题(重点是nn模块编译和torch编译)

1、点号和冒号在函数调用上的区别
http://blog.csdn.net/wangbin_jxust/article/details/12170233,根据这篇博客的描述,可以发现冒号只是省略了传递的第一个参数self的作用,对于点号需要具体的去传递一个对象的实参,下面做一个实验,用torch

x = torch.Tensor(5):zero()

这个时候输出长度为5的全部为0的Tensor

x = torch.Tensor(5).zero()

[string "x = torch.Tensor(5).zero()"]:1: invalid arguments: no arguments provided
expected arguments: *DoubleTensor*
stack traceback:
        [C]: in function 'zero'
        [string "x = torch.Tensor(5).zero()"]:1: in main chunk
        [C]: in function 'xpcall'
        /home/lzhou/torch/install/share/lua/5.1/trepl/init.lua:679: in function 'repl'
        ...zhou/torch/install/lib/luarocks/rocks/trepl/scm-1/bin/th:204: in main chunk
        [C]: at 0x004064f0

这个时候程序报错了,原因是因为点号调用的时候必须要传入实参对象
将代码改为

x = torch.Tensor(5)
x = x.zero(x)

这个时候就可以了!点号调用必须要传入实参

y = torch.Tensor(5)
z = torch.Tensor(5)
z = z.zero(y)

输出的z和y都是0

2、torch中的c接口问题
从官网的介绍来吧,在MSECriterion.lua文件里有

   input.THNN.MSECriterion_updateOutput(
      input:cdata(),
      target:cdata(),
      self.output_tensor:cdata(),
      self.sizeAverage
   )

input.THNN.MSECriterion_updateOutput的语法很奇怪,来看官网的例子了解一下。
官网以“>http://torch.ch/docs/developer-docs.html# 中的threashold.lua为例进行讲解的。
这里写图片描述
这里的input.nn.threshold_updataOutput函数同样是一个很奇怪的语法,来看看官网的解释。
torch和lua学习常见问题(重点是nn模块编译和torch编译)_第1张图片
这句话实际调用的是threshold.c中的这个函数u,但是为什么回去调用这个函数呢?因为在接下来的几行中,会将这个函数注册在input.nn.的table里面,所以就会调用这个函数
这里写图片描述
这就允许我们为任意的Tensor写任意的函数,不需要很复杂的函数动态的分配了
最后看他们包含在哪些文件里面吧
init.lua
这里写图片描述
init.c
这里写图片描述
这里写图片描述
这里写图片描述

http://blog.hanschen.site/2016/09/07/lua-and-c.html
但是我们自习观察新版的torch,貌似不是官网给出的这个样子的,为什么呢?官网给出的链接是很久之前的torch的链接,在新版的torch中已经不再使用这种Lua C API的形式来进行lua和c之间的交互了,当然这仅仅对于nn这一个模块,最基础的torch部分其实还是由Lua C API 来实现的。,那么在新版的torch里面我们将以MSECriterion为例进行详细的介绍。
首先在/extra/nn/中存在MSECriterion.lua文件
torch和lua学习常见问题(重点是nn模块编译和torch编译)_第2张图片
在整个文件里面再也没有类似之间luaT_pushmetatable这一类的函数了,也就是不存在Lua API C这样的接口了,取而代之的就是很简单的函数,那是什么取代了传统的这种交流方法呢,是LuaJIT的FFI模块,他用来实现和c之间的通信。
其中上面的代码里有

input.THNN.MSECriterion_updateOutput(...)

上面的这句话很模糊,到底是怎么回事,怎么调用的呢?来看一下nn模块的编译过程吧
1)extra/nn目录下面是一系列的lua函数,这是外部接口部分
2)extra/nn/lib/THNN/generic下面是各个函数对应的c的实现
对应的各个目录下面都会有相应的CMakelist.txt文件,用来编译对应目录下的c文件,这样就会生成对应的LibTHNN.so文件,对应的路径在torch/install/lib/lua/5.1.里面

经过了C编译,在底层封装好的函数都放在了动态链接库LibTHNN.so里面,接下来的事情就是lua接口如何去掉用这些C函数。如果用官网的老版教程,肯定是采用Lua C API 的方式来进行,但是这种方法实在是太繁琐了。FFI的出现改变了这一个状态,关于FFI的介绍,详看FFI介绍
FFI能够简化LUA和c之间的通信,那么torch是怎么做的呢?
来到torch/extra/nn文件夹里面,有很多的lua,这些lua是怎么安排的呢?
在这个文件夹下面,首先是init.lua,定义了nn的global这样的table。然后是THNN.lua和THNN_H.lua两个函数,怎么实现的呢?具体看一下THNN.lua

local ffi = require 'ffi'

local THNN = {}


local generic_THNN_h = require 'nn.THNN_h'
-- strip all lines starting with #
-- to remove preprocessor directives originally present
-- in THNN.h
generic_THNN_h = generic_THNN_h:gsub("\n#[^\n]*", "")
generic_THNN_h = generic_THNN_h:gsub("^#[^\n]*\n", "")

通过cpath指定的动态函数库去寻找对应的c模块。

THNN.C = ffi.load(package.searchpath('libTHNN', package.cpath))

首先是加在ffi模块,定义THNN,加在THNN_h,THNN_h里面是一系列C函数在lua中的声明,方便FFI调用。
接下来会有两段话是ffi调用c的话

ffi.cdef(base_declarations)

定义c的两个数据结构

for i=1,#replacements do
   local r = replacements[i]
   local s = preprocessed
   for k,v in pairs(r) do
      s = string.gsub(s, k, v)
   end
   ffi.cdef(s)
end

这样的话所有的c函数定义都用ffi进行了加在,也就是cdef这个函数,他是lua和c之间通信非常关键的一个函数
到目前为止,已经知道lua是怎么和c进行通信的了,但是一开始input.THNN.function()的形式怎么解释呢?看下面的代码

local function_names = extract_function_names(generic_THNN_h)

THNN.kernels = {}
THNN.kernels['torch.FloatTensor'] = THNN.bind(THNN.C, function_names, 'Float', THNN.getState)
THNN.kernels['torch.DoubleTensor'] = THNN.bind(THNN.C, function_names, 'Double', THNN.getState)

torch.getmetatable('torch.FloatTensor').THNN = THNN.kernels['torch.FloatTensor']
torch.getmetatable('torch.DoubleTensor').THNN = THNN.kernels['torch.DoubleTensor']

torch.getmetatable(‘torch.FloatTensor’).THNN = THNN.kernels[‘torch.FloatTensor’]加载FloatTensor的元表,将THNN这一函数加进去,THNN里面存着诸多之前定义的c函数,就这样input这样的tensor就可以轻松点通过THNN来获得相应的函数了。具体细节都在代码里面,这样就可以解释为什么input.THNN.function是合法的了。对于cunn也是类似的实现套路,其他模块也类似。

不过通过上述的代码,发现必须事先装好torch才可以,否则torch.getmetatable是不可行的,torch是在哪里安装的呢?
现在转到/torch/pkg/torch文件夹下面,这个torch文件夹下面主要是实现一些timer,file和Tensor等等类似的基本的操作,来看看是怎么组织的吧。
切入口还是文件夹下面的init.c

#include "general.h"
#include "utils.h"

extern void torch_utils_init(lua_State *L);
extern void torch_random_init(lua_State *L);
extern void torch_File_init(lua_State *L);
extern void torch_DiskFile_init(lua_State *L);
extern void torch_MemoryFile_init(lua_State *L);
extern void torch_PipeFile_init(lua_State *L);
extern void torch_Timer_init(lua_State *L);

extern void torch_ByteStorage_init(lua_State *L);
extern void torch_CharStorage_init(lua_State *L);
extern void torch_ShortStorage_init(lua_State *L)
...
...

一开始首先声明了一系列函数的

LUA_EXTERNC DLL_EXPORT int luaopen_libtorch(lua_State *L);

int luaopen_libtorch(lua_State *L)
{

  lua_newtable(L);
  lua_pushvalue(L, -1);
  lua_setglobal(L, "torch");

  torch_utils_init(L);
  torch_File_init(L);

  torch_ByteStorage_init(L);
  torch_CharStorage_init(L);
  torch_ShortStorage_init(L);
  torch_IntStorage_init(L);
  torch_LongStorage_init(L);
  torch_FloatStorage_init(L);
  torch_DoubleStorage_init(L);
  torch_HalfStorage_init(L);

  torch_ByteTensor_init(L);
  ...
  ...
  luaT_newmetatable(L, "torch.Allocator", NULL, NULL, NULL, NULL);
  return 1;

首先会设置torch这样的全局table,一次往里面的元表加入方法,这里采用标准的LUA C API的方法,利用虚拟栈的形式,而不是FFI的形式
首先看第一个

torch_utils_init(L);

这个函数到当前目录下的utils.c去寻找,里面最关键的一部分代码是

static int torch_updateerrorhandlers(lua_State *L)
{
  THSetErrorHandler(luaTorchErrorHandlerFunction, L);
  THSetArgErrorHandler(luaTorchArgErrorHandlerFunction, L);
  return 0;
}

static const struct luaL_Reg torch_utils__ [] = {
  {"getdefaulttensortype", torch_lua_getdefaulttensortype},
  {"isatty", torch_isatty},
  {"tic", torch_lua_tic},
  {"toc", torch_lua_toc},
  {"setnumthreads", torch_setnumthreads},
  {"getnumthreads", torch_getnumthreads},
  {"getnumcores", torch_getnumcores},
  {"factory", luaT_lua_factory},
  {"getconstructortable", luaT_lua_getconstructortable},
  {"typename", luaT_lua_typename},
  {"isequal", luaT_lua_isequal},
  {"getenv", luaT_lua_getenv},
  {"setenv", luaT_lua_setenv},
  {"newmetatable", luaT_lua_newmetatable},
  {"setmetatable", luaT_lua_setmetatable},
  {"getmetatable", luaT_lua_getmetatable},
  {"metatype", luaT_lua_metatype},
  {"pushudata", luaT_lua_pushudata},
  {"version", luaT_lua_version},
  {"pointer", luaT_lua_pointer},
  {"setheaptracking", torch_setheaptracking},
  {"updateerrorhandlers", torch_updateerrorhandlers},
  {NULL, NULL}
};

void torch_utils_init(lua_State *L)
{
  torch_updateerrorhandlers(L);
  luaT_setfuncs(L, torch_utils__, 0);
}

往torch表里面塞函数,所有上面的函数都可以通过torch.function()进行调用
再来看

extern void torch_Timer_init(lua_State *L);
static const struct luaL_Reg torch_Timer__ [] = {
  {"reset", torch_Timer_reset},
  {"stop", torch_Timer_stop},
  {"resume", torch_Timer_resume},
  {"time", torch_Timer_time},
  {"__tostring__", torch_Timer___tostring__},
  {NULL, NULL}
};

void torch_Timer_init(lua_State *L)
{
  luaT_newmetatable(L, "torch.Timer", NULL, torch_Timer_new, torch_Timer_free, NULL);
  luaT_setfuncs(L, torch_Timer__, 0);
  lua_pop(L, 1);
}

这里首先是往torch.Timer的元表里塞函数,所以这部分调用的时候
torch和lua学习常见问题(重点是nn模块编译和torch编译)_第3张图片
和之前略有不同的,两个辅助功能讲完了,其他的辅助功能差不多,但是Tensor部分的压栈等操作要去generic/Tensor.c里面去寻找

static const struct luaL_Reg torch_Tensor_(_) [] = {
  {"retain", torch_Tensor_(retain)},
  {"free", torch_Tensor_(free)},
  {"contiguous", torch_Tensor_(contiguous)},
  {"size", torch_Tensor_(size)},
  {"elementSize", torch_Tensor_(elementSize)},
  {"__len__", torch_Tensor_(size)},
  {"stride", torch_Tensor_(stride)},
  {"dim", torch_Tensor_(nDimension)},
  {"nDimension", torch_Tensor_(nDimension)},
  {"set", torch_Tensor_(set)},
  {"storage", torch_Tensor_(storage)},
  {"storageOffset", torch_Tensor_(storageOffset)},
  {"clone", torch_Tensor_(clone)},
  {"contiguous", torch_Tensor_(contiguous)},
  {"resizeAs", torch_Tensor_(resizeAs)},
  {"resize", torch_Tensor_(resize)},
  {"narrow", torch_Tensor_(narrow)},
  {"sub", torch_Tensor_(sub)},
  {"select", torch_Tensor_(select)},
#ifndef TH_REAL_IS_HALF
  {"index", torch_Tensor_(indexSelect)},
  {"indexCopy", torch_Tensor_(indexCopy)},
  {"indexAdd", torch_Tensor_(indexAdd)},
  {"indexFill", torch_Tensor_(indexFill)},
  {"maskedSelect", torch_Tensor_(maskedSelect)},
  {"maskedCopy", torch_Tensor_(maskedCopy)},
  {"maskedFill", torch_Tensor_(maskedFill)},
#endif
  {"transpose", torch_Tensor_(transpose)},
  {"t", torch_Tensor_(t)},
  {"unfold", torch_Tensor_(unfold)},
  {"isContiguous", torch_Tensor_(isContiguous)},
  {"isSameSizeAs", torch_Tensor_(isSameSizeAs)},
  {"isSetTo", torch_Tensor_(isSetTo)},
  {"isSize", torch_Tensor_(isSize)},
  {"nElement", torch_Tensor_(nElement)},
  {"copy", torch_Tensor_(copy)},
#ifndef TH_REAL_IS_HALF
  {"apply", torch_Tensor_(apply)},
  {"map", torch_Tensor_(map)},
  {"map2", torch_Tensor_(map2)},
#endif
  {"read", torch_Tensor_(read)},
  {"write", torch_Tensor_(write)},
  {"__index__", torch_Tensor_(__index__)},
  {"__newindex__", torch_Tensor_(__newindex__)},
  {NULL, NULL}
};

void torch_Tensor_(init)(lua_State *L)
{
  luaT_newmetatable(L, torch_Tensor, NULL,
                    torch_Tensor_(new), torch_Tensor_(free), torch_Tensor_(factory));
  luaT_setfuncs(L, torch_Tensor_(_), 0);
  lua_pop(L, 1);
#ifndef TH_REAL_IS_HALF
  THVector_(vectorDispatchInit)();
#endif
}

上面的代码是不是很熟了,对,这些都是Tensor基本操作,torch7就是通过在文件夹下面编译对应的c和lua获得torch的全局掌控的,接下来就是编译的过程了。
在init.lua里面会首先加在libtorch,获得全局torch这样的概念,接下来就是利用torch来进行各种操作,常见的torch.等等

总结:torch7的torch部分采用LUA C API的形式,而torch有自己的实现,也即LuaT,具体内容看luaT
而对于nn,cunn等附加包采用ffi的形式来进行加载的

你可能感兴趣的:(torch和lua学习常见问题(重点是nn模块编译和torch编译))