Softmax backprop is not yet implemented
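This error is what motivated my earlier "Source Code Analysis" post.
The inspiration came from the official TensorFlow demo, which implements cross-entropy by hand with basic math operations. Since deeplearn.js does not yet implement backprop for softmax, I can likewise build softmax out of basic math operations whose backprop is implemented. Testing confirmed that this works.
GitHub project: https://github.com/knimet/deeplearn.js-softmax-can-backprop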
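For context, this is what a softmax backward pass has to compute. With s = softmax(x) and upstream gradient g = dL/ds, the standard result is dL/dx_i = s_i * (g_i - sum_j g_j * s_j). The following is a minimal plain-JavaScript sketch, independent of deeplearn.js. It implements that formula and checks it against central finite differences. The helper names `softmax`, `softmaxBackward` and `numericGrad` are hypothetical.

```javascript
// Numerically stable softmax: subtracting max(x) does not change the result.
function softmax(x) {
  const m = Math.max(...x);
  const e = x.map(v => Math.exp(v - m));
  const sum = e.reduce((a, b) => a + b, 0);
  return e.map(v => v / sum);
}

// Backward pass (vector-Jacobian product) of softmax:
// given s = softmax(x) and upstream = dL/ds, return dL/dx_i = s_i * (upstream_i - sum_j upstream_j * s_j).
function softmaxBackward(s, upstream) {
  const dot = s.reduce((acc, si, j) => acc + si * upstream[j], 0);
  return s.map((si, i) => si * (upstream[i] - dot));
}

// Central finite differences of L(x) = sum_i upstream_i * softmax(x)_i, for checking.
function numericGrad(x, upstream, h = 1e-6) {
  const loss = v => softmax(v).reduce((acc, si, i) => acc + si * upstream[i], 0);
  return x.map((_, i) => {
    const xp = x.slice(); xp[i] += h;
    const xm = x.slice(); xm[i] -= h;
    return (loss(xp) - loss(xm)) / (2 * h);
  });
}

const x = [1, 2, 3, 4, 5];
const upstream = [0.1, -0.2, 0.3, 0.0, 0.5]; // arbitrary upstream gradient
const analytic = softmaxBackward(softmax(x), upstream);
const numeric = numericGrad(x, upstream);
console.log(analytic);
console.log(numeric);
```

If softmax is built from primitive ops that all have backprop, the framework's automatic differentiation should reproduce this same gradient by chaining the simpler derivatives.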
// Basic idea
// softmax by definition:
powX = pow(Math.E , X);
sum_powX = sum(powX);
sfX = divide(powX,sum_powX)
// the same thing rewritten with sigmoid (first attempt):
sigX = sigmoid(X)
sig1 = 1/sigX
sig2 = sig1-1
sum_sig = sum(sig2)
sfX = sig2 / sum_sig
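This is the first pseudocode I wrote, and it has a trap! Since sigmoid(x) = 1/(1 + e^(-x)), we get 1/sigmoid(x) - 1 = e^(-x), not e^x, so this actually computes softmax(-X).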
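The trap is easy to confirm numerically. Here is a small plain-JavaScript sketch that runs the first-attempt pseudocode on [1,2,3,4,5]. It does not use deeplearn.js, and `softmax` and `firstAttempt` are hypothetical helper names.

```javascript
const sigmoid = v => 1 / (1 + Math.exp(-v));

// Reference softmax by definition: e^x_i / sum_j e^x_j
function softmax(x) {
  const e = x.map(v => Math.exp(v));
  const sum = e.reduce((a, b) => a + b, 0);
  return e.map(v => v / sum);
}

// First attempt from the pseudocode: normalize 1/sigmoid(x_i) - 1.
// Since 1/sigmoid(v) - 1 = e^(-v), this is really softmax(-x).
function firstAttempt(x) {
  const sig2 = x.map(v => 1 / sigmoid(v) - 1);
  const sum = sig2.reduce((a, b) => a + b, 0);
  return sig2.map(v => v / sum);
}

const X = [1, 2, 3, 4, 5];
console.log(softmax(X));               // reference
console.log(firstAttempt(X));          // same values in reverse order (up to rounding)
console.log(softmax(X.map(v => -v)));  // matches firstAttempt(X)
```

Softmax is unchanged when the same constant is added to every input. So softmax(-[1,2,3,4,5]) equals softmax([5,4,3,2,1]), which is why the wrong result is exactly the correct one reversed.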
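var dl = require('deeplearn');
var math = new dl.NDArrayMathCPU();
var target = dl.Array1D.new([1,2,3,4,5])
var sf1 = math.softmax(target)   // built-in softmax, for comparison
sf1.getValues()
var _tsg = math.sigmoid(target)   // sigmoid(x)
_tsg.getValues()
var _t1 = math.divide(dl.Array1D.new([1]) , _tsg)   // 1 / sigmoid(x)
_t1.getValues()
var _t2 = math.subtract(_t1,dl.Array1D.new([1]))   // 1 / sigmoid(x) - 1, which is e^(-x)
_t2.getValues()
var sf2 = math.divide(_t2,math.sum(_t2))   // normalize by the sum
sf2.getValues()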
The output is as follows:
// Corrected version: negate X before the sigmoid
mX = X * (-1)
sigX = sigmoid(mX)
sig1 = 1/sigX
sig2 = sig1-1
sum_sig = sum(sig2)
sfX = sig2 / sum_sig
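Multiply each element x_i of X by -1 first. Then 1/sigmoid(-x_i) - 1 = (1 + e^(x_i)) - 1 = e^(x_i), which is exactly the e^x that softmax needs.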
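The corrected formulation can be checked the same way. This is again a plain-JavaScript sketch rather than deeplearn.js code, with hypothetical helper names `softmax` and `sigmoidSoftmax`. The reference `softmax` subtracts max(x) first, which does not change the result but keeps exp from overflowing. The second half shows a limitation that an input like [1,2,3,4,5] does not expose.

```javascript
const sigmoid = v => 1 / (1 + Math.exp(-v));

// Reference softmax, shifted by max(x) for numerical stability
function softmax(x) {
  const m = Math.max(...x);
  const e = x.map(v => Math.exp(v - m));
  const sum = e.reduce((a, b) => a + b, 0);
  return e.map(v => v / sum);
}

// Corrected sigmoid formulation: 1/sigmoid(-x_i) - 1 = e^x_i, then normalize
function sigmoidSoftmax(x) {
  const ex = x.map(v => 1 / sigmoid(-v) - 1);
  const sum = ex.reduce((a, b) => a + b, 0);
  return ex.map(v => v / sum);
}

console.log(softmax([1, 2, 3, 4, 5]));
console.log(sigmoidSoftmax([1, 2, 3, 4, 5])); // agrees with the line above up to rounding

// Caveat: for very negative inputs, sigmoid(-v) rounds to exactly 1 in float64,
// so 1/sigmoid(-v) - 1 becomes 0 and the normalization is 0/0.
console.log(softmax([-40, -40]));        // [0.5, 0.5]
console.log(sigmoidSoftmax([-40, -40])); // [NaN, NaN]
```

JavaScript numbers are 64-bit, while deeplearn.js works in 32-bit floats, so the same breakdown should appear there at much smaller input magnitudes. A stable softmax normally shifts the inputs by their maximum first. The 1/sigmoid(-x) - 1 form as written has no such shift.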
var target = dl.Array1D.new([1,2,3,4,5])
var sf1 = math.softmax(target)   // built-in softmax, for comparison
sf1.getValues()
var _t1 = math.multiply(dl.Scalar.new(-1),target)   // -x
_t1.getValues()
var _tsg = math.sigmoid(_t1)   // sigmoid(-x)
_tsg.getValues()
var _t2 = math.divide(dl.Array1D.new([1]) , _tsg)   // 1 / sigmoid(-x) = 1 + e^x
_t2.getValues()
var _t3 = math.subtract(_t2,dl.Array1D.new([1]))   // e^x
_t3.getValues()
var sf2 = math.divide(_t3,math.sum(_t3))   // e^x / sum(e^x), should match sf1
sf2.getValues()
The test results are as follows:
var dl = require('deeplearn');
var math = new dl.NDArrayMathCPU();
// The graph must exist before the session that runs it is created
var g = new dl.Graph();
const session = new dl.Session(g, math);
// One input example: the vector [1,2,3,4,5]
var target = [dl.Array1D.new([1,2,3,4,5])];
var X = g.placeholder('X',[5]);
// Built-in softmax: evaluates fine, but its backprop is not implemented
var sf1 = g.softmax(X);
// Softmax rebuilt from ops that do support backprop:
// 1/sigmoid(-x) - 1 = e^x, then divide by the sum
var _t1 = g.multiply(X,g.constant(-1));
var _tsg = g.sigmoid(_t1);
var _t2 = g.divide(g.constant(1),_tsg);
var _t3 = g.subtract(_t2,g.constant(1));
var sf2 = g.divide(_t3,g.reduceSum(_t3));
const shuffledInputProviderBuilder = new dl.InCPUMemoryShuffledInputProviderBuilder([target]);
const inputX = shuffledInputProviderBuilder.getInputProviders()[0];
const feedEntries = [{
  tensor: X,
  data: inputX
}];
console.log(session.eval(sf1,feedEntries ).getValues());
console.log(session.eval(sf2,feedEntries ).getValues());
The output is as follows: