hybridize原理
什么是符号式编程
举个沐神的例子
def add_str():
    """Return the source text of a two-argument ``add`` function."""
    return '''
def add(A, B):
    return A + B
'''


def fancy_func_str():
    """Return the source text of ``fancy_func``, which is built on ``add``."""
    return '''
def fancy_func(A, B, C, D):
    E = add(A, B)
    F = add(C, D)
    G = add(E, F)
    return G
'''


def evoke_str():
    """Return the whole program: both definitions plus a call that prints the result."""
    return add_str() + fancy_func_str() + '''
print(fancy_func(1,2,3,4))
'''


# Step 1: define the computation as program text.
prog = evoke_str()
# Step 2: compile the text into an executable code object.
y = compile(prog, '', 'exec')
# Step 3: run the compiled program; it calls fancy_func(1,2,3,4) and prints the result.
exec(y)
上面代码对应3个过程:
- 定义计算流程
- 编译成可执行的程序
- 给定输入调用编译好的程序
hybrid_forward(F, x)
中的F
net(x)
->__call__
->forward
->hybrid_forward
HybridBlock中forward方法的源码(hybrid_forward就是在这里被调用的):
def forward(self, x, *args):
    """Defines the forward computation. Arguments can be either
    :py:class:`NDArray` or :py:class:`Symbol`."""
    if isinstance(x, NDArray):  # NDArray input: work inside x's context
        with x.context as ctx:
            try:
                if self._active:  # _active is set by hybridize(); use the cached op
                    return self._call_cached_op(x, *args)
                params = {i: j.data(ctx) for i, j in self._reg_params.items()}
            except DeferredInitializationError:  # presumably parameters not yet initialized -- TODO confirm
                self._finish_deferred_init(self._active, x, *args)
            if self._active:  # reached after the try block, including after deferred init
                return self._call_cached_op(x, *args)
            params = {i: j.data(ctx) for i, j in self._reg_params.items()}
            return self.hybrid_forward(ndarray, x, *args, **params)  # F = ndarray
    assert isinstance(x, Symbol), \
        "HybridBlock requires the first argument to forward be either " \
        "Symbol or NDArray, but got %s"%type(x)
    params = {i: j.var() for i, j in self._reg_params.items()}  # j.var() here instead of j.data(ctx)
    with self.name_scope():
        return self.hybrid_forward(symbol, x, *args, **params)  # F = symbol
第15行和最后一行分别把ndarray和symbol作为F传给了hybrid_forward函数,这样用户代码里的F具体是哪个对象,就取决于当前的运行模式:输入是NDArray时F为ndarray,输入是Symbol时F为symbol
hybridize过程
以下面一段代码为例:
class HybridNet(nn.HybridBlock):
    """Example block made of two Dense layers, used to trace hybridize()."""

    def __init__(self, **kwargs):
        super(HybridNet, self).__init__(**kwargs)
        # Assigning the layers as attributes inside name_scope() creates the
        # two children fc1 and fc2.
        with self.name_scope():
            self.fc1 = nn.Dense(10)
            self.fc2 = nn.Dense(2)

    def hybrid_forward(self, F, x):
        """Apply fc1, then F.relu, then fc2, and return the result.

        F is supplied by the caller; relu is looked up on it.
        """
        x = F.relu(self.fc1(x))
        return self.fc2(x)
if __name__ == '__main__':
    net = HybridNet()
    net.initialize()
    x = nd.random.normal(shape=(1, 4))
    # Switch the network to hybrid mode before the first forward pass.
    net.hybridize()
    # print(...) with one argument behaves the same on Python 2 and Python 3.
    print(net)
    y = net(x)
上述代码__init__
会执行block源码的__setattr__
方法注册HybridNet两个子block:
def __setattr__(self, name, value):
    """Registers parameters."""
    # Runs on every attribute assignment; Block values are tracked in
    # self._children before the attribute itself is stored.
    if hasattr(self, name):
        existing = getattr(self, name)
        # An existing Parameter/Block may only be replaced by a value that is
        # an instance of the existing value's type.
        if isinstance(existing, (Parameter, Block)) and not isinstance(value, type(existing)):
            # NOTE(review): the two literals below are concatenated without a
            # separating space, so the message reads "...to <type>is not allowed."
            raise TypeError('Changing attribute type for {name} from {type1} to {type2}' \
                            'is not allowed.'.format(name=name,
                                                     type1=type(existing),
                                                     type2=type(value)))
        if isinstance(existing, Block):
            # Replace the old child in place so its position in _children is kept.
            for i, c in enumerate(self._children):
                if c is existing:
                    self._children[i] = value
        elif isinstance(value, Block):
            # The attribute existed but was not a Block: register the new child.
            self.register_child(value)
    elif isinstance(value, Block):
        # Brand-new attribute holding a Block: register it as a child.
        self.register_child(value)
    super(Block, self).__setattr__(name, value)
register_child方法注册子block:
def register_child(self, block):
    """Registers block as a child of self. :py:class:`Block` s assigned to self as
    attributes will be registered automatically."""
    # Children are kept in assignment order in the _children list.
    self._children.append(block)
hybridize方法先把本block和所有的子block实例的_active
置为true:
def hybridize(self, active=True):
    """Set this block's ``_active`` flag, then defer to the parent class's hybridize.

    Parameters
    ----------
    active : bool, default True
        Value stored in ``self._active`` and passed on to the parent class.
    """
    self._active = active
    # Debug trace showing which block is being hybridized.
    print ('HybridBlock', self.name)
    super(HybridBlock, self).hybridize(active)
def hybridize(self, active=True):
    """Activates or deactivates :py:class:`HybridBlock` s recursively. Has no effect on
    non-hybrid children.

    Parameters
    ----------
    active : bool, default True
        Whether to turn hybrid on or off.
    """
    # Forward the flag to every registered child.
    for cld in self._children:
        cld.hybridize(active)
执行过hybridize的net(计算图)再执行前向计算net(x)
->__call__
->forward
def forward(self, x, *args):
    """Defines the forward computation. Arguments can be either
    :py:class:`NDArray` or :py:class:`Symbol`."""
    if isinstance(x, NDArray):  # NDArray input: work inside x's context
        with x.context as ctx:
            try:
                if self._active:  # _active is set by hybridize(); use the cached op
                    return self._call_cached_op(x, *args)
                params = {i: j.data(ctx) for i, j in self._reg_params.items()}
            except DeferredInitializationError:  # presumably parameters not yet initialized -- TODO confirm
                self._finish_deferred_init(self._active, x, *args)
            if self._active:  # reached after the try block, including after deferred init
                return self._call_cached_op(x, *args)
            params = {i: j.data(ctx) for i, j in self._reg_params.items()}
            return self.hybrid_forward(ndarray, x, *args, **params)  # F = ndarray
    assert isinstance(x, Symbol), \
        "HybridBlock requires the first argument to forward be either " \
        "Symbol or NDArray, but got %s"%type(x)
    params = {i: j.var() for i, j in self._reg_params.items()}  # j.var() here instead of j.data(ctx)
    with self.name_scope():
        return self.hybrid_forward(symbol, x, *args, **params)  # F = symbol
从方法里面可以看出,如果传进来的是NDArray,并且self._active为True(即已经调用过hybridize),就会调用_call_cached_op,获取计算图并执行计算
打印图:
客户端代码:
if __name__ == '__main__':
    net = HybridNet()
    net.initialize()
    x = nd.random.normal(shape=(1, 4))
    # Toggle: feed a symbolic variable instead of an NDArray.
    #x = mx.sym.Variable('data')
    # Toggle: hybridize the network before the forward pass.
    #net.hybridize()
    #print(net)
    y = net(x)
    # print(...) with one argument behaves the same on Python 2 and Python 3.
    print(y)
—symbolic
:
hybridize
:
虽然是图,但是还是直接执行结果:
[[ -2.95562204e-05 3.18562193e-03]]
imperative
:
[[ -2.95562204e-05 3.18562193e-03]]
参考:
https://github.com/mli/gluon-tutorials-zh/blob/master/chapter_gluon-advances/hybridize.md