老司机请无视 方便其他框架转过来的兄弟快速上手
首先是mxnet的安装
安装地址
其次是gluoncv的安装(gluoncv就是mxnet框架的model_zoo)
pip install gluoncv
在mxnet中的Tensor都是ndarray类型,由mxnet.nd这个库进行创建、加减乘除等等一大堆的操作
例如:
form mxnet import nd
from mxnet import gpu
# 生成shape为[1,3,224,224]的0矩阵(默认是在cpu上)
zeros = nd.zeros([1, 3, 224, 224])
# 从cpu转gpu
zeros.as_in_context(gpu())
```<br>
更多操作查询api文档<br>
## 2. mxnet.numpy (1.6版本以上才有的接口)
到本文发布为止,官网api中未提供numpy接口详细说明...1.6刚出来的功能<br>
用法和numpy一样(牛逼的地方是可以用gpu计算),就是为了完善mxnet.ndarray的某些功能,可以用mxnet.numpy操作ndarray数据<br>
例如:
```python
from mxnet import nd
from mxnet import numpy as np
a = nd.ones([3,3])
# 用np扩展维度
a = a[np.newaxis, :]
print(a.shape)
建议用mxnet.nd创建数据 mxnet.numpy操作数据(加减乘除等等)
from mxnet import nd
from gluoncv import model_zoo
net = model_zoo.get_model("yolo3_darknet53_coco", pretrained=False)
net.initialize() # 初始化网络权重等等参数 很重要!
net.hybridize() # 由动态图转静态图 转静态图速度会快点
data_shape = (1, 3, 416, 416)
input_data = nd.random.uniform(-1, 1, data_shape)
out = net(input_data)
print(out)
mxnet.gluon这个接口主要就是mxnet的动态图的接口(静态图接口是mxnet.sym好像 没怎么用它)
常用gluon.nn来创建网络,比如Dense Conv2d 等等
gluon.nn中有这俩东西: nn.Sequential 和 nn.HybridSequential
都是用来构建动态图的,区别就是带Hybrid字眼的能通过hybridize()函数转化为静态图,所以基本使用nn.HybridSequential
同样创建网络时候所继承的Block也有对应的HybridBlock,效果一致
from mxnet.gluon import HybridBlock, nn
# 例子1
net = nn.HybridSequential()
# use net's name_scope to give child Blocks appropriate names.
with net.name_scope():
net.add(nn.Dense(10, activation='relu'))
net.add(nn.Dense(20))
net.hybridize()
# 例子2
class Model(HybridBlock):
def __init__(self, **kwargs):
# 网络各个layer必须在__init__中初始化 不可在hybrid_forward中初始化
super(Model, self).__init__(**kwargs)
# use name_scope to give child Blocks appropriate names.
with self.name_scope():
self.dense0 = nn.Dense(20)
self.dense1 = nn.Dense(20)
def hybrid_forward(self, F, x): # F指mxnet.ndarray 会调用ndarray中的方法操作数据
x = F.relu(self.dense0(x))
return F.relu(self.dense1(x))
model = Model()
model.initialize(ctx=mx.cpu(0))
model.hybridize()
model(mx.nd.zeros((10, 10), ctx=mx.cpu(0)))
[这块和pytorch一样 由dataset类和dataloader类构成}(http://mxnet.incubator.apache.org/api/python/docs/api/gluon/data/index.html)
from mxnet.gluon.data import Dataset
from mxnet.gluon.data import DataLoader
class DatasetBase(Dataset):
"""
只需要重写这三个函数
"""
def __init__(self, data_root, transform=None, is_train=True):
"""
指定数据集地址等等初始化操作
"""
super(DatasetBase, self).__init__()
self.transform = transform
self.all = [] # 假设其存放所有的数据信息(例如图片路径和label)
def __getitem__(self, idx):
"""
返回整个数据集中第idx个数据
returns the i-th element
"""
data = self.all[idx]
if self.transform is not None:
return self.transform(data)
return data
def __len__(self):
"""
which returns the total number elements.
"""
return len(self.all)
dataset = DatasetBase("path_to_img")
dataloader = DataLoader(
dataset=dataset,
batch_size=256,
shuffle=True,
num_workers=1
)
for data in dataloder:
# 得到数据
pass
想到啥再补充