I trained a DecisionTreeModel and then tried to validate it on an RDD:
dtModel = DecisionTree.trainClassifier(data, 2, {}, impurity="entropy", maxDepth=maxTreeDepth)
predictions = dtModel.predict(data.map(lambda lp: lp.features))

def GetDtLabel(x):
    return 1 if dtModel.predict(x.features) > 0.5 else 0

dtTotalCorrect = data.map(lambda point: 1 if GetDtLabel(point) == point.label else 0).sum()
This raised the following error:
Exception: It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transformation. SparkContext can only be used on the driver, not in code that it run on workers. For more information, see SPARK-5063.
The equivalent Scala code looked fine, so I assumed dtModel simply needed to be broadcast, but the error persisted:
dtModelBroadcast = sc.broadcast(dtModel)
It was only after reading the Stack Overflow posts below that I realized this is a PySpark-specific limitation:
http://stackoverflow.com/questions/31684842/how-to-use-java-scala-function-from-an-action-or-a-transformation
http://stackoverflow.com/questions/36838024/combining-spark-streaming-mllib
The source of DecisionTreeModel's predict method in PySpark states:
“In Python, predict cannot currently be used within an RDD transformation or action.
Call predict directly on the RDD instead.”
def predict(self, x):
    """
    Predict the label of one or more examples.

    Note: In Python, predict cannot currently be used
          within an RDD transformation or action.
          Call predict directly on the RDD instead.

    :param x:
      Data point (feature vector), or an RDD of data points (feature vectors).
    """
    if isinstance(x, RDD):
        return self.call("predict", x.map(_convert_to_vector))
    else:
        return self.call("predict", _convert_to_vector(x))
class JavaModelWrapper(object):
    """
    Wrapper for the model in JVM
    """
    def __init__(self, java_model):
        self._sc = SparkContext.getOrCreate()
        self._java_model = java_model

    def __del__(self):
        self._sc._gateway.detach(self._java_model)

    def call(self, name, *a):
        """Call method of java_model"""
        return callJavaFunc(self._sc, getattr(self._java_model, name), *a)