Comparing network construction in MindSpore and PyTorch
1. Network construction: build the same network once in MindSpore and once in PyTorch:
mindspore:
Define a sub-Cell, ConvBNReLU:
from mindspore import nn

# A reusable Conv -> BatchNorm -> ReLU block
class ConvBNReLU(nn.Cell):
    def __init__(self):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(3, 64, 3)
        self.bn = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
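A quick smoke test of the block (a minimal sketch; the all-ones input and its shape are illustrative assumptions). Note that MindSpore's Conv2d defaults to pad_mode='same', so the spatial dimensions are preserved:

import numpy as np
import mindspore
from mindspore import Tensor

block = ConvBNReLU()
x = Tensor(np.ones([1, 3, 64, 32]), mindspore.float32)
print(block(x).shape)  # (1, 64, 64, 32): pad_mode='same' keeps the spatial size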
Define a simple network containing two build blocks:
build_block1: managed with the nn.SequentialCell container. Besides the ConvBNReLU above, build_block1 also appends a pooling layer. nn.SequentialCell runs its submodules in the order they were added and defines its own construct, so the whole container can be called directly as a single layer; submodules can be added afterwards with append.
build_block2: managed with the nn.CellList container, holding conv, bn, and relu layers. nn.CellList is a plain list-like container (supporting append and insert) with no construct of its own, so its submodules must be iterated over manually in construct.
class MyNet(nn.Cell):
    def __init__(self):
        super(MyNet, self).__init__()
        # SequentialCell: executed as a whole, in insertion order
        layers = [ConvBNReLU()]
        self.build_block1 = nn.SequentialCell(layers)
        self.build_block1.append(nn.MaxPool2d(2))
        # CellList: a plain container; append/insert work like a Python list
        self.build_block2 = nn.CellList([nn.Conv2d(64, 4, 4)])
        self.build_block2.append(nn.ReLU())
        self.build_block2.insert(-1, nn.BatchNorm2d(4))  # inserted before the ReLU

    def construct(self, x):
        output = self.build_block1(x)      # SequentialCell is directly callable
        for layer in self.build_block2:    # CellList must be iterated manually
            output = layer(output)
        return output
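The MindSpore printout shown further below can be reproduced with a small driver mirroring the torch version (a sketch; the all-ones input is an assumption):

import numpy as np
import mindspore
from mindspore import Tensor

net = MyNet()
x = Tensor(np.ones([1, 3, 64, 32]), mindspore.float32)
output = net(x)
print(net)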
torch:
Define a network with the same structure in PyTorch, again with two build blocks:
build_block1: managed with the nn.Sequential container. Besides the ConvBNReLU above, build_block1 also adds a pooling layer via add_module. nn.Sequential, like nn.SequentialCell, runs its submodules in insertion order and is directly callable as a whole.
build_block2: managed with the nn.ModuleList container (the torch counterpart of nn.CellList), holding conv, bn, and relu layers; like CellList, it must be iterated manually in forward.
import torch
import torch.nn as nn

# The same Conv -> BatchNorm -> ReLU block, written as a torch Module
class ConvBNReLU(nn.Module):
    def __init__(self):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(3, 64, 3)
        self.bn = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        # Sequential: executed as a whole, in insertion order
        self.build_block1 = nn.Sequential(ConvBNReLU())
        self.build_block1.add_module("pool", nn.MaxPool2d(2))
        # ModuleList: a plain container; append/insert work like a Python list
        self.build_block2 = nn.ModuleList([nn.Conv2d(64, 4, 4)])
        self.build_block2.append(nn.ReLU())
        self.build_block2.insert(1, nn.BatchNorm2d(4))  # inserted before the ReLU

    def forward(self, x):
        output = self.build_block1(x)      # Sequential is directly callable
        for layer in self.build_block2:    # ModuleList must be iterated manually
            output = layer(output)
        return output
net = MyNet()
x = torch.randn(1, 3, 64, 32)  # random input; torch.FloatTensor(1, 3, 64, 32) would be uninitialized
output = net(x)                # shape: torch.Size([1, 4, 28, 12])
print(net)
Printed network structure (MindSpore):
MyNet<
  (build_block1): SequentialCell<
    (0): ConvBNReLU<
      (conv): Conv2d<input_channels=3, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
      (bn): BatchNorm2d<num_features=64, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=build_block1.0.bn.gamma, shape=(64,), dtype=Float32, requires_grad=True), beta=Parameter (name=build_block1.0.bn.beta, shape=(64,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=build_block1.0.bn.moving_mean, shape=(64,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=build_block1.0.bn.moving_variance, shape=(64,), dtype=Float32, requires_grad=False)>
      (relu): ReLU<>
      >
    (1): MaxPool2d<kernel_size=2, stride=1, pad_mode=VALID>
    >
  (build_block2): CellList<
    (0): Conv2d<input_channels=64, output_channels=4, kernel_size=(4, 4), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
    (1): BatchNorm2d<num_features=4, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=build_block2.1.gamma, shape=(4,), dtype=Float32, requires_grad=True), beta=Parameter (name=build_block2.1.beta, shape=(4,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=build_block2.1.moving_mean, shape=(4,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=build_block2.1.moving_variance, shape=(4,), dtype=Float32, requires_grad=False)>
    (2): ReLU<>
    >
  >
Printed network structure (PyTorch):
MyNet(
  (build_block1): Sequential(
    (0): ConvBNReLU(
      (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
    )
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (build_block2): ModuleList(
    (0): Conv2d(64, 4, kernel_size=(4, 4), stride=(1, 1))
    (1): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
)
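Comparing the two printouts shows that the "identical" networks actually differ in their defaults: MindSpore's Conv2d uses pad_mode=same and has_bias=False, while torch's Conv2d uses no implicit padding and bias=True; MindSpore's MaxPool2d defaults to stride=1, while torch defaults the stride to the kernel size. Below is a minimal sketch of the MindSpore layers with these defaults overridden to match the torch behavior (the exact parameter defaults may vary across MindSpore versions, so verify against your installed release):

from mindspore import nn

# Match torch.nn.Conv2d(3, 64, 3): no implicit padding, bias enabled
conv = nn.Conv2d(3, 64, 3, pad_mode='valid', has_bias=True)
# Match torch.nn.MaxPool2d(2): torch defaults stride to kernel_size
pool = nn.MaxPool2d(kernel_size=2, stride=2)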