我们自己写的每个网络都需要继承自mindspore提供的Cell这个类.
from mindspore import nn
from mindspore import ops
class MyNet(nn.Cell):
def __init__(self, in_chanel, out_chanel):
super().__init__()
self.dense1 = nn.Dense(in_chanel, out_chanel)
self.dense2 = nn.Dense(out_chanel, out_chanel * 3)
self.dense3 = nn.Dense(out_chanel * 3, out_chanel)
def construct(self, input):
output = self.dense1(input)
output = self.dense2(output)
output = self.dense3(output)
return output
这样我们在运行这个网络的时候, 会把这个对象当做函数来直接调用, 这一点跟pytorch是类似的, 就像这样:
input = mindspore.Tensor(np.ones((2, 4)), dtype=mindspore.float32)
net = MyNet(4, 32)
res = net(input)
print(res)
这个其实是python的语法, 会调用到MyNet这个class的call方法上, 也就是其父类Cell的call方法, 我们看一下这个call方法做了什么:
可以看到, 如果是静态图模式的话, 会走到Cell的compile_and_run方法, 继续看:
可以看到是调用了_executor这个对象(注意这不是一个方法, 还是在把python对象当做函数来调用!), 那么这个对象的call方法又是什么呢?
可以看到这个类是调用了自身的run方法
run方法里又调用了_executor对象的call方法, 注意622行这个这个_executor可不是自身, 而是一个pybind, 是python和c++实现互调用的方法, 可以看到这个_executor对应的是这个Executor的instance,
而这个Executor是从c_expression中来的:
我们可以ccsrc(mindspore的c++代码都在这个目录)目录下搜索这个类名, 就能找到相应的代码:
可以看到, get_instance对应的是c++代码中ExecutorPy这个类的GetInstance方法, 而call方法对应的是ExecutorPy这个类的Run方法.
我们在c++代码中可以看到, 这个类的位置:
在mindspore/ccsrc/pipeline/jit/pipeline.h这个文件中, 而这个类的run方法的实现则在pipeline.cc文件中:
总结一下就是:
通过利用python语言提供的call方法, 我们可以把对象当成函数来调用, 在静态图模式下, Cell的call方法会调用compile_and_run, 其中run方法会通过层层封装, 最后用pybind来调用到c++侧, 这与c++的run方法的到底在run什么, 我们下一篇再讲