一、概述
MindSpore的Cell类是构建所有网络的基类,也是网络的基本单元。当用户需要自定义网络时,需要继承Cell类,并重写__init__方法和construct方法。
损失函数、优化器和模型层等本质上也属于网络结构,也需要继承Cell类才能实现功能,同样用户也可以根据业务需求自定义这部分内容。
本节内容介绍Cell类的关键成员函数,“构建网络”中将介绍基于Cell实现的MindSpore内置损失函数、优化器和模型层及使用方法,以及通过实例介绍如何利用Cell类构建自定义网络。
_init_:在该函数中定义网络所需要的层或者变量等信息
construct:该函数实现网络的执行流程
二、ops构建网络
这里首先我们使用了ops中的Conv2D算子,然后又使用了bias_add算子,同时定义下模型的权重,之后在construct函数中定义网络的执行流程。
代码样例如下:
class Net(nn.Cell):
def __init__(self,in_channels=10,out_channels=20,kernel_size=3):
super(Net,self).__init__()
self.conv2d=ops.Conv2D(out_channels,kernel_size)
self.bias_add=ops.BiasAdd()
self.weight=Parameter(initializer('normal',[out_channels,in_channels,kernel_size,kernel_size]),name='conv.weight')
def construct(self,x):
output=self.conv2d(x,self.weight)
output=self.bias_add(output,self.bias)
return output
三、nn构建网络
对于nn模块构建网络,非常的方便,它是mindSpore封装的高阶API,简单调用。
代码样例如下:
class Net(nn.Cell):
def __init__(self,in_channels=10,out_channels=20,kernel_size=3):
super(Net,self).__init__()
self.conv2d=nn.Conv2d(in_channels,out_channels,kernel_size,has_bias=True,weight_init=Normal(0.02))
def construct(self,x):
output=self.conv2d(x)
return output
四、nn模块和ops的关系
MindSpore的nn模块是Python实现的模型组件,是对低阶API的封装,主要包括各种模型层、损失函数、优化器等。
同时nn也提供了部分与Primitive算子同名的接口,主要作用是对Primitive算子进行进一步封装,为用户提供更友好的API。
重新分析上文介绍construct方法的用例,此用例是MindSpore的nn.Conv2d源码简化内容,内部会调用ops.Conv2D。nn.Conv2d卷积API增加输入参数校验功能并判断是否bias等,是一个高级封装的模型层。
五、网络的常用方法
1.parameters_dict()
该方法会以字典的形式返回网络的所有参数,键为参数的名称,值为对应的参数
class Net(nn.Cell):
def __init__(self,in_channels=10,out_channels=20,kernel_size=3):
super(Net,self).__init__()
self.conv2d=ops.Conv2D(out_channels,kernel_size)
self.bias_add=ops.BiasAdd()
self.weight=Parameter(initializer('normal',[out_channels,in_channels,kernel_size,kernel_size]),name='conv.weight')
def construct(self,x):
output=self.conv2d(x,self.weight)
output=self.bias_add(output,self.bias)
return output
net=Net()
net.parameters_dict()
>>>OrderedDict([('conv.weight',
Parameter (name=conv.weight, shape=(20, 10, 3, 3), dtype=Float32, requires_grad=True))])

2.get_parameters()
该方法返回一个迭代器,返回的是模型的参数,就是返回上个方法的所以值
iter=net.get_parameters()
next(iter)
>>>Parameter (name=conv2d.weight, shape=(20, 10, 3, 3), dtype=Float32, requires_grad=True)
3.name_cells()
返回网络中所有单元格的迭代器
net.name_cells()
>>>OrderedDict([('conv2d',
Conv2d
4.cells_and_names()
返回网络中所有单元格的迭代器,包括单元格的名称和它本身
注意第一个返回的是整个网络,键对应着空
names=[]
for m in net.cells_and_names():
print(m)
names.append(m[0]) if m[0] else None
print('-------names-------')
print(names)
('', Net<
(conv2d): Conv2d
>)
('conv2d', Conv2d
-------names-------
['conv2d']
5.cells()
返回对直接单元格的迭代器
net.cells()
>>>odict_values([Conv2d