个人对移动端神经网络开发一直饶有兴致。去年腾讯开源了 NCNN 框架之后,我一直在关注。近期成功利用别人训练好的 MTCNN 和 MobileFaceNet 模型,制作了一个 Swift 版本的 iOS 人脸识别 demo。接下来希望把 Mask R-CNN 移植到 NCNN,在手机端实现一些有趣的应用。因为 U-Net 模型比较简单,干脆就从这个入手。
基本的网络基于keras版本: https://github.com/TianzhongSong/Person-Segmentation-Keras
不过 Keras 模型没办法直接转成 NCNN 模型。我研究过用 ONNX 模型做中间跳板,也试用了一些开源的转换工具,结果也是一堆问题。NCNN 支持从几个神经网络训练框架转换模型:Caffe/MXNet/PyTorch。在 NCNN 的 GitHub 上有一篇 issue,nihui 在其中推荐采用 MXNet,因此 MXNet 成为了我的首选。
利用 Person-Segmentation-Keras 项目的数据集,同时参考 https://github.com/milesial/Pytorch-UNet/tree/master/unet 这个项目,捣鼓了几段代码。训练完成后,用来测试 NCNN 转换基本可用。
转换过程中发现了许多问题。一是调用 ncnn 的 extract 会 crash;经过调查,发现 mxnet2ncnn 工具也有 bug,blob 个数算错了。二是 input 层的 one_blob_only 标志,按我的理解应该是 false,不知道什么原因转换过来的模型里是 true,导致 forward_layer 函数里访问 bottoms 变量时异常。后来一层层 extract 出来,打印输出的 channel/width/height,调查后又发现我把 unet.py 里 name 本应为 pool5 的层写成了 pool4(文章中的代码已经纠正),前面的 crash 说不定也跟这个致命错误有关。只好重新训练模型,又是几个小时的漫长等待,剩下的部分下周再写。部分代码已经更新,请参考: https://github.com/xuduo35/unet_mxnet2ncnn
unetdataiter.py
#!/usr/bin/env python
# coding=utf8
import os
import sys
import random
import cv2
import mxnet as mx
import numpy as np
from mxnet.io import DataIter, DataBatch
sys.path.append('../')
def _letterbox_geometry(src_h, src_w, height, width):
    """Return ``(new_h, new_w, top, left)`` for an aspect-preserving fit.

    The source is scaled so that it fits entirely inside a ``height`` x
    ``width`` canvas and is centred along the axis that has slack.
    """
    # Integer cross-multiplication compares src_h/height with src_w/width
    # without rounding; for a square canvas this is simply src_h >= src_w.
    if src_h * width >= src_w * height:
        new_h = height
        new_w = int(src_w / (src_h / height))
    else:
        new_w = width
        new_h = int(src_h / (src_w / width))
    # Clamp so extreme aspect ratios never produce an empty or oversized resize.
    new_h = min(height, max(1, new_h))
    new_w = min(width, max(1, new_w))
    return new_h, new_w, (height - new_h) // 2, (width - new_w) // 2


def get_batch(items, root_path, nClasses, height, width):
    """Load and preprocess one mini-batch of image/label pairs.

    Args:
        items: Lines of the list file. The first space-separated field is the
            image path and the last field is the label-image path; both are
            appended to ``root_path``.
        root_path: Prefix prepended to every path.
        nClasses: Number of classes; the label is one-hot encoded over
            ``range(nClasses)``.
        height: Target canvas height in pixels.
        width: Target canvas width in pixels.

    Returns:
        A ``(data, label)`` pair of ``mx.nd`` arrays. ``data`` has shape
        ``(len(items), 3, height, width)`` with pixel values scaled to
        ``[-1, 1]``; ``label`` has shape ``(len(items), nClasses,
        height * width)``.

    Raises:
        IOError: If an image or label file cannot be read.
    """
    x = []
    y = []
    class_ids = np.arange(nClasses).reshape(-1, 1)

    for item in items:
        fields = item.split(' ')
        image_path = root_path + fields[0].strip()
        label_path = root_path + fields[-1].strip()

        img = cv2.imread(image_path, 1)
        label_img = cv2.imread(label_path, 1)
        # cv2.imread signals failure by returning None instead of raising.
        if img is None:
            raise IOError("cannot read image: %s" % image_path)
        if label_img is None:
            raise IOError("cannot read label: %s" % label_path)

        # Canvases are (rows, cols) = (height, width). The image is padded
        # with mid-grey (128), the label with 0.
        im = np.full((height, width, 3), 128, dtype='uint8')
        lim = np.zeros((height, width, 3), dtype='uint8')

        new_h, new_w, top, left = _letterbox_geometry(
            img.shape[0], img.shape[1], height, width)

        img = cv2.resize(img, (new_w, new_h))
        # Nearest-neighbour keeps label values as valid class ids; the default
        # bilinear interpolation can blend neighbouring ids into other values.
        label_img = cv2.resize(label_img, (new_w, new_h),
                               interpolation=cv2.INTER_NEAREST)

        im[top:top + new_h, left:left + new_w, :] = img
        lim[top:top + new_h, left:left + new_w, :] = label_img

        # The class id is read from the first channel of the label image.
        flat_ids = lim[:, :, 0].reshape(1, height * width)
        one_hot = (flat_ids == class_ids).astype('float32')  # (nClasses, H*W)

        x.append((np.float32(im) / 127.5 - 1).transpose((2, 0, 1)))
        y.append(one_hot)

    return mx.nd.array(x), mx.nd.array(y)
class UnetDataIter(mx.io.DataIter):
    """Mini-batch iterator over an image/label list file for U-Net training.

    Each non-blank line of ``path_file`` describes one sample (see
    ``get_batch`` for the line format). Samples that do not fill a complete
    batch at the end of an epoch are dropped.

    Args:
        root_path: Prefix prepended to every path in the list file.
        path_file: Text file with one sample per line.
        batch_size: Number of samples per batch.
        n_classes: Number of segmentation classes.
        input_width: Width of the network input in pixels.
        input_height: Height of the network input in pixels.
        train: When true, the sample order is reshuffled on every ``reset``
            and progress is reported as "Training"; otherwise the file order
            is kept and progress is reported as "Validating".
    """

    def __init__(self, root_path, path_file, batch_size, n_classes, input_width, input_height, train=True):
        super(UnetDataIter, self).__init__(batch_size)

        # Blank lines (e.g. a trailing newline) would otherwise become
        # samples with empty paths.
        with open(path_file, 'r') as f:
            self.items = [line for line in f if line.strip()]

        # Batches produced by get_batch are laid out (N, C, H, W).
        self._provide_data = [['data', (batch_size, 3, input_height, input_width)]]
        self._provide_label = [['softmax_label', (batch_size, n_classes, input_width * input_height)]]

        self.root_path = root_path
        self.batch_size = batch_size
        self.num_batches = len(self.items) // batch_size
        self.n_classes = n_classes
        self.input_height = input_height
        self.input_width = input_width
        self.train = train

        self.reset()

    def __iter__(self):
        return self

    def reset(self):
        """Rewind to the first batch, reshuffling the samples when training."""
        self.cur_batch = 0
        self.shuffled_items = list(self.items)
        if self.train:
            random.shuffle(self.shuffled_items)

    def __next__(self):
        return self.next()

    @property
    def provide_data(self):
        return self._provide_data

    @property
    def provide_label(self):
        return self._provide_label

    def next(self):
        """Return the next ``DataBatch``; raise ``StopIteration`` at epoch end."""
        if self.cur_batch == 0:
            print("")

        # "\r" returns to column 0 and "\033[K" (ANSI erase-to-end-of-line)
        # clears the previous progress text; flush because no newline is sent.
        mode = "Training " if self.train else "Validating "
        print("\r\033[K" + mode + str(self.cur_batch) + "/" + str(self.num_batches), end=' ', flush=True)

        if self.cur_batch >= self.num_batches:
            raise StopIteration

        start = self.cur_batch * self.batch_size
        stop = start + self.batch_size
        data, label = get_batch(self.shuffled_items[start:stop], self.root_path,
                                self.n_classes, self.input_height, self.input_width)
        self.cur_batch += 1
        return mx.io.DataBatch([data], [label])
if __name__ == '__main__':
    # Smoke test: walk through one epoch of the training list so that every
    # sample is loaded and preprocessed once.
    root_path = '/datasets/'
    train_file = './data/seg_train.txt'
    batch_size = 16
    n_classes = 2
    img_width = 256
    img_height = 256

    trainiter = UnetDataIter(root_path, train_file, batch_size, n_classes, img_width, img_height, True)

    # A for-loop consumes the StopIteration that marks the end of the epoch;
    # calling next() in a bare `while True` ended with an uncaught exception.
    for _ in trainiter:
        pass
unet.py
import os
os.environ["MXNET_BACKWARD_DO_MIRROR"] = "1"
os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "0"
import mxnet as mx
from mxnet import ndarray as F
from skimage.transform import resize
from skimage.io import imsave
import numpy as np
from unetdataiter import UnetDataIter
import matplotlib.pyplot as plt
def dice_coef(y_true, y_pred):
    """Build the smoothed Dice coefficient symbol of two symbols.

    Both inputs are reduced over axes 1, 2 and 3, and a smoothing constant
    of 1 is added to numerator and denominator.
    """
    reduce_axes = (1, 2, 3)
    overlap = mx.sym.sum(mx.sym.broadcast_mul(y_true, y_pred), axis=reduce_axes)
    numerator = 2. * overlap + 1.
    true_total = mx.sym.sum(y_true, axis=reduce_axes)
    pred_total = mx.sym.sum(y_pred, axis=reduce_axes)
    denominator = true_total + pred_total + 1.
    return mx.sym.broadcast_div(numerator, denominator)
def dice_coef_loss(y_true, y_pred):
    """Build the negated, smoothed Dice ratio used as the training loss.

    Products and sums are reduced over axis 1 only, and a smoothing constant
    of 1 is added to numerator and denominator.
    """
    # NOTE(review): build_unet passes tensors reshaped to
    # (batch, classes, pixels), so axis 1 is the class axis, not the pixel
    # axis - confirm this per-pixel reduction is intended.
    overlap = mx.sym.sum(mx.sym.broadcast_mul(y_true, y_pred), axis=1)
    numerator = 2. * overlap + 1.
    true_total = mx.sym.sum(y_true, axis=1)
    pred_total = mx.sym.sum(y_pred, axis=1)
    denominator = mx.sym.broadcast_add(true_total, pred_total) + 1.
    return -mx.sym.broadcast_div(numerator, denominator)
def _conv_bn_relu(data, num_filter, tag):
    """Append a 3x3 same-padded Convolution -> BatchNorm -> ReLU unit.

    The three layers are named ``conv<tag>``, ``bn<tag>`` and ``relu<tag>``.
    """
    net = mx.sym.Convolution(data, num_filter=num_filter, kernel=(3, 3), pad=(1, 1), name='conv' + tag)
    net = mx.sym.BatchNorm(net, name='bn' + tag)
    return mx.sym.Activation(net, act_type='relu', name='relu' + tag)


def _double_conv(data, num_filter, stage):
    """Append the two conv/bn/relu units ``<stage>_1`` and ``<stage>_2``."""
    net = _conv_bn_relu(data, num_filter, '%d_1' % stage)
    return _conv_bn_relu(net, num_filter, '%d_2' % stage)


def build_unet(batch_size, input_width, input_height, train=True):
    """Build the U-Net symbol for two-class segmentation.

    The encoder has five stages (64, 128, 256, 256, 256 filters), each
    followed by a 2x2 max-pool with stride 2. The decoder has five stages
    (256, 256, 256, 128, 64 filters), each starting with a stride-2 2x2
    transposed convolution whose output is concatenated with the encoder
    feature map of the same resolution.

    Pooling and Deconvolution strides are given explicitly because MXNet
    defaults both to 1, which changes the feature-map side by only one pixel
    per stage instead of halving/doubling it. Layer names and parameter
    shapes are unaffected, but a checkpoint trained with stride-1 stages
    should be retrained.

    Args:
        batch_size: Batch size baked into the loss-branch reshape.
        input_width: Input width in pixels; must be a multiple of 32.
        input_height: Input height in pixels; must be a multiple of 32.
        train: When true, return ``Group([loss, mask])`` where ``loss`` is
            the Dice loss against the ``softmax_label`` variable and ``mask``
            is the gradient-blocked network output. Otherwise return
            ``Group([output])`` only.

    Returns:
        An ``mx.sym.Group`` as described for ``train``.

    Raises:
        ValueError: If a spatial dimension is not a multiple of 32, since the
            skip connections would then not line up after five 2x poolings.
    """
    if input_width % 32 or input_height % 32:
        raise ValueError(
            "input_width and input_height must be multiples of 32, got %sx%s"
            % (input_width, input_height))

    data = mx.sym.Variable(name='data')
    label = mx.sym.Variable(name='softmax_label')

    # encode: keep every pre-pool feature map for the skip connections
    skips = []
    net = data
    for stage, num_filter in enumerate((64, 128, 256, 256, 256), start=1):
        feat = _double_conv(net, num_filter, stage)
        skips.append(feat)
        net = mx.sym.Pooling(feat, kernel=(2, 2), stride=(2, 2), pool_type='max', name='pool%d' % stage)

    # decode: upsample 2x, merge with the matching encoder map (deepest first)
    for stage, num_filter in zip(range(6, 11), (256, 256, 256, 128, 64)):
        up = mx.sym.Deconvolution(net, num_filter=num_filter, kernel=(2, 2), stride=(2, 2), no_bias=True, name='trans_conv%d' % stage)
        merged = mx.sym.concat(up, skips.pop(), dim=1, name='concat%d' % stage)
        net = _double_conv(merged, num_filter, stage)

    # Two-channel 1x1 head. The sigmoid layer keeps the name 'softmax' so
    # that node names stay compatible with existing tooling.
    conv11 = mx.sym.Convolution(net, num_filter=2, kernel=(1, 1), name='conv11_1')
    conv11 = mx.sym.sigmoid(conv11, name='softmax')

    if not train:
        return mx.sym.Group([conv11])

    # Flatten the spatial axes to match the (batch, classes, pixels) label.
    flat = mx.sym.Reshape(conv11, shape=(batch_size, 2, input_width * input_height))
    loss = mx.sym.MakeLoss(dice_coef_loss(label, flat), normalization='batch')
    mask_output = mx.sym.BlockGrad(conv11, name='mask')
    return mx.sym.Group([loss, mask_output])
trainunet.py
import os
os.environ["MXNET_BACKWARD_DO_MIRROR"] = "1"
os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "0"
import mxnet as mx
from mxnet import ndarray as F
from skimage.transform import resize
from skimage.io import imsave
import numpy as np
from unetdataiter import UnetDataIter
import matplotlib.pyplot as plt
from unet import build_unet
def main():
    """Train the U-Net and save a checkpoint after every epoch.

    Training batches come from ``./data/seg_train.txt`` and validation
    batches from ``./data/seg_test.txt``. After each epoch the exponentially
    smoothed training loss and the mean validation loss are printed, and the
    module is saved under the ``./unet_person_segmentation`` prefix with the
    epoch index.
    """
    root_path = '../datasets/'
    train_file = './data/seg_train.txt'
    val_file = './data/seg_test.txt'
    batch_size = 16
    n_classes = 2
    img_width = 96
    img_height = 96

    train_iter = UnetDataIter(root_path, train_file, batch_size, n_classes, img_width, img_height, True)
    val_iter = UnetDataIter(root_path, val_file, batch_size, n_classes, img_width, img_height, False)

    ctx = [mx.gpu(0)]
    unet_sym = build_unet(batch_size, img_width, img_height)
    unet = mx.mod.Module(unet_sym, context=ctx, data_names=('data',), label_names=('softmax_label',))
    # Batches are laid out (N, C, H, W).
    unet.bind(data_shapes=[['data', (batch_size, 3, img_height, img_width)]],
              label_shapes=[['softmax_label', (batch_size, n_classes, img_width * img_height)]])
    unet.init_params(mx.initializer.Xavier(magnitude=6))
    unet.init_optimizer(optimizer='adam',
                        optimizer_params=(
                            ('learning_rate', 1E-4),
                            ('beta1', 0.9),
                            ('beta2', 0.99)
                        ))

    epochs = 20
    smoothing_constant = .01
    curr_losses = []
    moving_losses = []
    moving_loss = None

    for e in range(epochs):
        for batch in train_iter:
            unet.forward_backward(batch)
            loss = unet.get_outputs()[0]
            unet.update()

            curr_loss = F.mean(loss).asscalar()
            curr_losses.append(curr_loss)
            # Exponential moving average, seeded with the very first loss.
            if moving_loss is None:
                moving_loss = curr_loss
            else:
                moving_loss = (1 - smoothing_constant) * moving_loss + smoothing_constant * curr_loss
            moving_losses.append(moving_loss)
        train_iter.reset()

        val_losses = []
        for batch in val_iter:
            # The module is bound for training, so forward() would default to
            # training mode; request inference mode explicitly so BatchNorm
            # uses its running statistics on validation data.
            unet.forward(batch, is_train=False)
            loss = unet.get_outputs()[0]
            val_losses.append(F.mean(loss).asscalar())
        val_iter.reset()

        val_loss = np.mean(val_losses)
        print("\nEpoch %i: Moving Training Loss %0.5f, Validation Loss %0.5f" % (e, moving_loss, val_loss))
        unet.save_checkpoint('./unet_person_segmentation', e)
# Script entry point: start training when this file is executed directly.
if __name__ =='__main__':
    main()
以上是训练代码。
预测代码如下(predict.py):
import os
os.environ["MXNET_BACKWARD_DO_MIRROR"] = "1"
os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "0"
import sys
import cv2
import mxnet as mx
from mxnet import ndarray as F
from skimage.transform import resize
from skimage.io import imsave
import numpy as np
from unetdataiter import UnetDataIter
import matplotlib.pyplot as plt
from unet import build_unet
def post_process_mask(label, img_cols, img_rows, n_classes, p=0.5):
    """Convert per-class network scores into a displayable mask.

    Args:
        label: Scores for one image, holding
            ``n_classes * img_rows * img_cols`` values in class-major,
            row-major order.
        img_cols: Mask width in pixels.
        img_rows: Mask height in pixels.
        n_classes: Number of classes.
        p: Unused; kept for interface compatibility.

    Returns:
        A numpy array of shape ``(img_rows, img_cols)`` holding the winning
        class index per pixel multiplied by 255.
    """
    # Rows come before columns in the (class, row, col) layout. The previous
    # (cols, rows) order was only correct for square masks.
    pr = label.reshape(n_classes, img_rows, img_cols).argmax(axis=0)
    return (pr * 255).asnumpy()
def load_image(img, width, height):
    """Letterbox an image onto a grey canvas and normalise it for the network.

    The image is resized with its aspect ratio preserved so that it fits
    inside ``height`` x ``width``, centred on a canvas filled with 128, then
    scaled to ``[-1, 1]`` and transposed to channel-first order.

    Args:
        img: Image array of shape ``(rows, cols, 3)`` as returned by
            ``cv2.imread``.
        width: Target width in pixels.
        height: Target height in pixels.

    Returns:
        A one-element list holding a float32 array of shape
        ``(3, height, width)``.

    Raises:
        ValueError: If ``img`` is ``None`` (``cv2.imread`` failed).
    """
    if img is None:
        raise ValueError("img is None; the image could not be read")

    src_h, src_w = img.shape[0], img.shape[1]
    # Pick the axis that limits the fit. Integer cross-multiplication avoids
    # float comparison and reduces to src_h >= src_w for a square target.
    if src_h * width >= src_w * height:
        new_h = height
        new_w = int(src_w / (src_h / height))
    else:
        new_w = width
        new_h = int(src_h / (src_w / width))
    # Clamp so extreme aspect ratios never produce an empty or oversized resize.
    new_h = min(height, max(1, new_h))
    new_w = min(width, max(1, new_w))
    top = (height - new_h) // 2
    left = (width - new_w) // 2

    im = np.full((height, width, 3), 128, dtype='uint8')
    im[top:top + new_h, left:left + new_w, :] = cv2.resize(img, (new_w, new_h))

    im = np.float32(im) / 127.5 - 1
    return [im.transpose((2, 0, 1))]
def main():
    """Segment the image given as ``sys.argv[1]`` and display the mask.

    Loads epoch 0 of the ``unet_person_segmentation`` checkpoint, runs the
    image through it on GPU 0, and shows the input and the predicted mask in
    two OpenCV windows until a key is pressed.
    """
    batch_size = 16
    n_classes = 2
    img_width = 96
    img_height = 96

    # Fail early, before the model is loaded: cv2.imread returns None when
    # the file is missing or cannot be decoded.
    testimg = cv2.imread(sys.argv[1], 1)
    if testimg is None:
        sys.exit("cannot read image: %s" % sys.argv[1])

    ctx = [mx.gpu(0)]
    sym, arg_params, aux_params = mx.model.load_checkpoint('unet_person_segmentation', 0)
    unet = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
    # Inputs are laid out (N, C, H, W); inference needs no label shapes.
    unet.bind(for_training=False,
              data_shapes=[['data', (batch_size, 3, img_height, img_width)]],
              label_shapes=None)
    unet.set_params(arg_params, aux_params, allow_missing=True)

    img = load_image(testimg, img_width, img_height)
    unet.predict(mx.io.NDArrayIter(data=[img]))
    outputs = unet.get_outputs()[0]

    cv2.imshow('test', testimg)
    cv2.imshow('mask', post_process_mask(outputs[0], img_width, img_height, n_classes))
    cv2.waitKey()
if __name__ == '__main__':
    # Expects one argument: the path of the image to segment.
    if len(sys.argv) < 2:
        # Report usage errors on stderr with a non-zero status so that calling
        # scripts can detect the failure.
        print("illegal parameters", file=sys.stderr)
        sys.exit(1)
    main()
剥离 softmax 后保存参数,用于 NCNN 模型转换(train2infer.py):
import os
os.environ["MXNET_BACKWARD_DO_MIRROR"] = "1"
os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "0"
import sys
import cv2
import mxnet as mx
from mxnet import ndarray as F
from skimage.transform import resize
from skimage.io import imsave
import numpy as np
from unetdataiter import UnetDataIter
import matplotlib.pyplot as plt
from unet import build_unet
def main():
    """Re-save a trained checkpoint under the inference-only symbol.

    ``sys.argv[1]`` is the checkpoint prefix and ``sys.argv[2]`` the epoch to
    load. The parameters are attached to the graph built with
    ``build_unet(..., False)`` and written back as epoch 0 of the
    ``./unet_person_segmentation`` prefix.
    """
    prefix = sys.argv[1]
    epoch = int(sys.argv[2])

    batch_size = 16
    img_width = 96
    img_height = 96
    input_shape = (batch_size, 3, img_width, img_height)

    # Only the parameters are needed; the stored training symbol is discarded.
    _, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)

    infer_sym = build_unet(batch_size, img_width, img_height, False)
    unet = mx.mod.Module(symbol=infer_sym, context=[mx.gpu(0)], label_names=None)
    unet.bind(for_training=False,
              data_shapes=[['data', input_shape]],
              label_shapes=unet._label_shapes)
    unet.set_params(arg_params, aux_params, allow_missing=True)

    unet.save_checkpoint('./unet_person_segmentation', 0)
if __name__ == '__main__':
    # Expects two arguments: the checkpoint prefix and the epoch number.
    if len(sys.argv) < 3:
        # Report usage errors on stderr with a non-zero status so that calling
        # scripts can detect the failure.
        print("illegal parameters", file=sys.stderr)
        sys.exit(1)
    main()