Forward and backward in Torch

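Criterions have forward and backward functions. criterion:forward(input, target) returns the loss, and criterion:backward(input, target) returns the gradient of that loss with respect to input:
https://github.com/torch/nn/blob/master/doc/criterion.md
Modules also have forward and backward functions:
https://github.com/torch/nn/blob/master/doc/module.md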
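To make the Criterion half concrete, here is a minimal pure-Lua sketch (no torch required) of what nn.MSECriterion computes, assuming its default sizeAverage = true. forward returns the mean of the squared differences, and backward returns the gradient of that loss with respect to each prediction, 2(y_i - t_i)/n. The names mse_forward and mse_backward are hypothetical, chosen only for illustration; they are not part of the nn API.

```lua
-- Hypothetical pure-Lua stand-ins for nn.MSECriterion (assumes sizeAverage = true).
-- predictions and targets are plain Lua arrays of equal length n.

-- forward: (prediction, target) -> scalar loss = (1/n) * sum (y_i - t_i)^2
local function mse_forward(predictions, targets)
  local n, sum = #predictions, 0
  for i = 1, n do
    local d = predictions[i] - targets[i]
    sum = sum + d * d
  end
  return sum / n
end

-- backward: (prediction, target) -> dLoss/dy_i = 2 * (y_i - t_i) / n
local function mse_backward(predictions, targets)
  local n, grad = #predictions, {}
  for i = 1, n do
    grad[i] = 2 * (predictions[i] - targets[i]) / n
  end
  return grad
end

-- Usage: predictions {1, 2, 3} against targets {1, 2, 5}
local loss = mse_forward({1, 2, 3}, {1, 2, 5})   -- (0 + 0 + 4) / 3
local grad = mse_backward({1, 2, 3}, {1, 2, 5})  -- {0, 0, -4/3}
print(loss, grad[1], grad[2], grad[3])
```

Note that backward only needs the prediction and the target, not the loss value returned by forward.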

A Module's forward is the simplest part: module:forward(input) computes the output from the input and returns it.
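As an illustration, nn.Linear(1, 1) holds a 1x1 weight w and a bias b, and forward on an N x 1 batch returns w * x_i + b for each row. The pure-Lua sketch below mimics that computation for scalar inputs; linear_forward is a hypothetical helper, not torch code.

```lua
-- Hypothetical stand-in for nn.Linear(1, 1):forward on a batch of scalar inputs.
-- w and b play the roles of the module's weight and bias (assumed scalars here).
local function linear_forward(w, b, inputs)
  local outputs = {}
  for i = 1, #inputs do
    outputs[i] = w * inputs[i] + b  -- input -> output, nothing else is computed
  end
  return outputs
end

local out = linear_forward(2, 1, {1, 2, 3})
print(out[1], out[2], out[3])  -- y = 2x + 1 gives 3, 5, 7
```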

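A Module's backward is easiest to understand through an example. module:backward(input, gradOutput) takes the gradient of the loss with respect to the module's output, accumulates the gradients with respect to the module's parameters (gradWeight and gradBias for nn.Linear), and returns the gradient with respect to the input. The following linear regression, which fits price against month, shows where each call goes in a training loop: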

require 'torch'
require 'nn'
require 'gnuplot'

-- month index (1..10) and the observed price for each month
month = torch.range(1,10)
price = torch.Tensor{28993,29110,29436,30791,33384,36762,39900,39972,40230,40146}

model = nn.Linear(1, 1)         -- one input, one output: price = w * month + b
criterion = nn.MSECriterion()   -- mean squared error loss

-- nn.Linear takes a batch as an N x inputSize matrix, so turn the vectors into 10x1 columns
month_train = month:reshape(10,1)
price_train = price:reshape(10,1)

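for i=1,1000 do
   price_predict = model:forward(month_train) -- input -> output
   err = criterion:forward(price_predict, price_train) -- (prediction, target) -> loss value
   print(i, err)
   model:zeroGradParameters() -- backward accumulates gradients, so clear them first
   gradient = criterion:backward(price_predict, price_train) -- (prediction, target) -> dLoss/dOutput
   model:backward(month_train, gradient) -- (input, dLoss/dOutput) -> fills gradWeight and gradBias
   model:updateParameters(0.01) -- parameters = parameters - 0.01 * gradParameters
end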

-- predict months 1..12; months 11 and 12 extrapolate beyond the training data
month_predict = torch.range(1,12)
local price_predict = model:forward(month_predict:reshape(12,1))
print(price_predict)

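-- plot the observed prices and the fitted line to plot.png
gnuplot.pngfigure('plot.png')
-- flatten the 12x1 prediction into a 12-element vector so it matches month_predict
gnuplot.plot({month, price}, {month_predict, price_predict:view(12)})
gnuplot.plotflush()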
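To see what each call in the loop computes, here is a pure-Lua sketch (no torch) of the same training step for a single-input linear model with MSE loss. It assumes MSECriterion's default sizeAverage = true and starts from w = 0, b = 0 instead of nn.Linear's random initialization, so its numbers will not match the torch run exactly. Each step is annotated with the torch call it stands in for; the function train and the toy data are hypothetical.

```lua
-- Pure-Lua sketch of the training loop above (assumptions: scalar weight/bias,
-- w = b = 0 at start instead of nn.Linear's random init, MSE averaged over n).
local function train(xs, ts, lr, iters)
  local w, b, n = 0, 0, #xs
  for _ = 1, iters do
    -- model:forward(month_train): y_i = w * x_i + b
    local ys = {}
    for i = 1, n do ys[i] = w * xs[i] + b end

    -- model:zeroGradParameters(): gradients accumulate, so reset them
    local gradW, gradB = 0, 0

    for i = 1, n do
      -- criterion:backward(prediction, target): dL/dy_i = 2 * (y_i - t_i) / n
      local gradOut = 2 * (ys[i] - ts[i]) / n
      -- model:backward(input, gradOutput): chain rule into the parameters
      gradW = gradW + gradOut * xs[i]  -- dL/dw = sum_i dL/dy_i * x_i
      gradB = gradB + gradOut          -- dL/db = sum_i dL/dy_i
    end

    -- model:updateParameters(lr): p = p - lr * dL/dp
    w = w - lr * gradW
    b = b - lr * gradB
  end
  return w, b
end

-- Usage: noise-free data t = 3x + 2 should be recovered.
local w, b = train({1, 2, 3, 4}, {5, 8, 11, 14}, 0.05, 5000)
print(w, b)  -- approximately 3 and 2
```

The sketch makes the division of labor visible: the criterion turns (prediction, target) into dL/dy, and the module's backward multiplies that by the input (for w) or sums it (for b), which is exactly the chain rule for y = w * x + b.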
