References:
1. https://github.com/keras-team/keras/issues/2397
I ran into this error today (the "Tensor ... is not an element of this graph" ValueError) while turning Mask RCNN into a ROS service. My code is based on this ROS subscriber version of Mask RCNN, which runs perfectly fine as-is, but after I converted it into a service node the error appeared. It took me a whole evening to realize that, as someone pointed out in reference [1]:
I had this problem when doing inference in a different thread than where I loaded my model. Here’s how I fixed the problem:
In my service version, the model is loaded when the class wrapping Mask RCNN is instantiated, while inference only happens once a request arrives, so the two clearly do not run in the same thread. In the subscriber version they do: each subscribed image is written into a class member variable, a Python threading lock is used as a mutex so the variable is never read and written at the same time, and the model then runs inference on the current image in the main loop. The code is as follows (the image callback itself is sketched after the block):
class MaskRCNNNode(object):
    def __init__(self):
        self._cv_bridge = CvBridge()

        config = InferenceConfig()
        config.display()

        self._visualization = rospy.get_param('~visualization', True)

        # Create model object in inference mode.
        self._model = modellib.MaskRCNN(mode="inference", model_dir="",
                                        config=config)
        # Load weights trained on MS-COCO
        model_path = rospy.get_param('~model_path', COCO_MODEL_PATH)
        # Download COCO trained weights from Releases if needed
        if model_path == COCO_MODEL_PATH and not os.path.exists(COCO_MODEL_PATH):
            utils.download_trained_weights(COCO_MODEL_PATH)
        self._model.load_weights(model_path, by_name=True)

        self._class_names = rospy.get_param('~class_names', CLASS_NAMES)

        self._last_msg = None
        self._msg_lock = threading.Lock()

        self._class_colors = visualize.random_colors(len(CLASS_NAMES))
        self._publish_rate = rospy.get_param('~publish_rate', 100)

    def run(self):
        self._result_pub = rospy.Publisher('~result', Result, queue_size=1)
        vis_pub = rospy.Publisher('~visualization', Image, queue_size=1)
        sub = rospy.Subscriber('~input', Image,
                               self._image_callback, queue_size=1)

        rate = rospy.Rate(self._publish_rate)
        while not rospy.is_shutdown():
            if self._msg_lock.acquire(False):
                msg = self._last_msg
                self._last_msg = None
                self._msg_lock.release()
            else:
                rate.sleep()
                continue

            if msg is not None:
                np_image = self._cv_bridge.imgmsg_to_cv2(msg, 'bgr8')

                # Run detection
                results = self._model.detect([np_image], verbose=0)
                result = results[0]
                result_msg = self._build_result_msg(msg, result)
                self._result_pub.publish(result_msg)

                # Visualize results
                if self._visualization:
                    vis_image = self._visualize(result, np_image)
                    cv_result = np.zeros(shape=vis_image.shape, dtype=np.uint8)
                    cv2.convertScaleAbs(vis_image, cv_result)
                    image_msg = self._cv_bridge.cv2_to_imgmsg(cv_result, 'bgr8')
                    vis_pub.publish(image_msg)

            rate.sleep()
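For completeness, the _image_callback referenced above is not shown in the snippet; it is just the lock-protected write into self._last_msg described earlier. A minimal sketch, assuming the names used in the class above:

    def _image_callback(self, msg):
        # Store the newest image under the lock; if the main loop currently
        # holds the lock, simply drop this frame and wait for the next one.
        if self._msg_lock.acquire(False):
            self._last_msg = msg
            self._msg_lock.release()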
After my change, however, the two no longer run in the same thread, so I applied the fix from reference [1]:
# Right after loading or constructing your model, save the TensorFlow graph:
graph = tf.get_default_graph()
# In the other thread (or perhaps in an asynchronous event handler), do:
global graph
with graph.as_default():
    (... do inference here ...)
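To see the pattern outside of ROS, here is a minimal, self-contained sketch of the same fix (assuming TensorFlow 1.x with standalone Keras; the tiny Dense model is just a stand-in). Whether the bare predict call in the worker thread actually raises the error depends on how the graph and session were created in your setup, but wrapping it in the captured graph's context is safe either way and is exactly what the snippet above describes:

import threading

import numpy as np
import tensorflow as tf
from keras.layers import Dense
from keras.models import Sequential

# Build and warm up the model in the main thread.
model = Sequential([Dense(1, input_shape=(4,))])
model.predict(np.zeros((1, 4)))

# Right after loading/constructing the model, capture the graph that owns its tensors.
graph = tf.get_default_graph()

def worker():
    # In the other thread, switch back to the captured graph before inference.
    with graph.as_default():
        print(model.predict(np.ones((1, 4))))

t = threading.Thread(target=worker)
t.start()
t.join()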
I modified my code accordingly, adding the self.graph member, as shown below:
class MaskRCNNNode(object):
    def __init__(self):
        self._cv_bridge = CvBridge()

        config = InferenceConfig()
        config.display()

        self._visualization = rospy.get_param('~visualization', True)

        # Create model object in inference mode.
        self._model = modellib.MaskRCNN(mode="inference", model_dir="",
                                        config=config)
        # Load weights trained on MS-COCO
        rospack = rospkg.RosPack()
        model_path = rospack.get_path('mask_rcnn_ros') + '/models/mask_rcnn_coco.h5'
        self._model.load_weights(model_path, by_name=True)

        # Save the graph the model was loaded into, so the service handler
        # (which runs in a different thread) can switch back to it later.
        self.graph = tf.get_default_graph()

        self._class_names = rospy.get_param('~class_names', CLASS_NAMES)
        self._class_colors = visualize.random_colors(len(CLASS_NAMES))

        self.vis_pub = rospy.Publisher('~visualization', Image, queue_size=1)
        self.server = rospy.Service('instance_segmentation', InstanceSegmentation,
                                    self.handle_instance_segmentation)
        rospy.loginfo("Waiting for request!")

    def handle_instance_segmentation(self, req):
        rospy.loginfo("Request received!")
        np_image = self._cv_bridge.imgmsg_to_cv2(req.color_image, "bgr8")

        # Run detection inside the saved graph's context.
        with self.graph.as_default():
            results = self._model.detect([np_image], verbose=0)
        print("got result!")
        result = results[0]

        resp = InstanceSegmentationResponse()
        resp.segmentation_result = self._build_result_msg(req.color_image, result)

        # Visualize results
        if self._visualization:
            vis_image = self._visualize(result, np_image)
            cv_result = np.zeros(shape=vis_image.shape, dtype=np.uint8)
            cv2.convertScaleAbs(vis_image, cv_result)
            image_msg = self._cv_bridge.cv2_to_imgmsg(cv_result, 'bgr8')
            self.vis_pub.publish(image_msg)

        return resp
With that, the problem was solved.
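For reference, any other node can then call the service through a plain service proxy. A minimal client sketch: the service name and the color_image / segmentation_result fields follow the server code above, while the camera topic and the srv import path are assumptions for illustration:

import rospy
from sensor_msgs.msg import Image
from mask_rcnn_ros.srv import InstanceSegmentation  # assumed location of the service definition

rospy.init_node('instance_segmentation_client')
rospy.wait_for_service('instance_segmentation')
segment = rospy.ServiceProxy('instance_segmentation', InstanceSegmentation)

# Grab a single frame from an (assumed) camera topic and send it as the request.
color_image = rospy.wait_for_message('/camera/color/image_raw', Image)
resp = segment(color_image)
# resp.segmentation_result now holds the segmentation result message.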