Notes on a Spark MLlib StackOverflowError pitfall

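In the past, my ALS-related work mostly relied on in-house company tools. Today I used Spark's MLlib for the first time and gave it a quick trial run on a small dataset of a few MB.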

I ran it and, to my surprise, it died with a StackOverflowError.


The source code:

from pyspark.mllib.recommendation import ALS
from numpy import array
from pyspark import SparkContext

if __name__ == '__main__':
	# sc = SparkSession\
	# 	.builder\
	# 	.appName("PythonWordCount")\
	# 	.getOrCreate()
	sc = SparkContext(appName="PythonWordCount")

	data = sc.textFile("CollaborativeFiltering.txt", 20)
	ratings = data.map(lambda line: [float(x) for x in line.split(' ')]).persist()
	rank = 10
	n = 30

	model = ALS.train(ratings, rank, n)
	testdata = ratings.map(lambda r: (int(r[0]), int(r[1])))
	predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2]))

	ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions).persist()
	MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).reduce(lambda x, y: x + y)/ratesAndPreds.count()

	print("Mean Squared Error = " + str(MSE))
	ratesAndPreds.unpersist()

The error message:

2017-11-24 17:15:23 [INFO] ShuffleMapStage 66 (flatMap at ALS.scala:1272) failed in Unknown s due to Job aborted due to stage failure: Task serialization failed: java.lang.StackOverflowError
java.lang.StackOverflowError
    at java.io.ObjectOutputStream$BlockDataOutputStream.write(ObjectOutputStream.java:1841)
    at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1534)
    at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1509)
    at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1432)
    at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1178)
    at java.io.ObjectOutputStream.writeObject(ObjectOutputStream.java:348)
    at scala.collection.immutable.$colon$colon.writeObject(List.scala:379)
    at sun.reflect.GeneratedMethodAccessor15.invoke(Unknown Source)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at java.io.ObjectStreamClass.invokeWriteObject(ObjectStreamClass.java:1028)
    at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1496)
    at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1432)
    at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1178)
    at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1548)
    at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1509)
    at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1432)
    at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1178)
    at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1548)


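At first I suspected the join and spent a long time investigating, with no result. Then I suspected the JVM settings, which obviously went nowhere either.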

Tears, tears, and more tears.

Later I began to suspect that the lineage had grown too long, so I googled it and asked some experts.

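That turned out to be the cause. ALS is iterative, and every iteration extends the RDD lineage. The stack trace above shows where this bites: the failure happens during task serialization, where Java's ObjectOutputStream recurses through the nested object graph. As the lineage grows, that recursion needs more and more stack space until it finally overflows.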

Problems of this kind can be solved as follows:

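Add sc.setCheckpointDir(path) to the code to set the checkpoint directory explicitly, and the problem goes away. With a checkpoint directory set, ALS periodically checkpoints its intermediate RDDs, which truncates the lineage. This does introduce a new problem: as the data grows, the disk I/O from checkpointing becomes the bottleneck. I have not solved that part yet; if anyone has a better solution, please get in touch.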
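To make the mechanism concrete, here is a toy model in plain Python. It is not Spark code and does not reflect how Spark is implemented; Step, serialize, run_iterations and try_serialize are made-up names for illustration only. Each iteration adds a step that points at its parent, a naive recursive serializer walks the whole chain, and a "checkpoint" replaces the chain with a single parentless step.

```python
import sys


class Step:
    """One transformation in a toy lineage; `parent` is the step it depends on."""

    def __init__(self, name, parent=None):
        self.name = name
        self.parent = parent


def serialize(step):
    """Recursively walk the parents and return the number of steps visited."""
    if step is None:
        return 0
    return 1 + serialize(step.parent)


def run_iterations(n, checkpoint_every=None):
    """Chain n steps; optionally cut the chain every `checkpoint_every` steps."""
    step = Step("input")
    for i in range(1, n + 1):
        step = Step("iter-%d" % i, parent=step)
        if checkpoint_every and i % checkpoint_every == 0:
            # A checkpoint keeps the data but forgets how it was computed.
            step = Step("checkpoint-%d" % i)
    return step


def try_serialize(step):
    """Return the chain length, or None if the recursion ran out of stack."""
    try:
        return serialize(step)
    except RecursionError:
        return None


# Deep enough to exhaust the Python call stack when the chain is never cut.
depth = sys.getrecursionlimit() * 2
no_checkpoint = try_serialize(run_iterations(depth))
with_checkpoint = try_serialize(run_iterations(depth, checkpoint_every=10))
print("without checkpoint:", no_checkpoint)
print("with checkpoint:", with_checkpoint)
```

Without the cut, the serializer fails once the chain outgrows the stack; with a cut every 10 steps, the chain it has to walk never exceeds 10 steps. The trade-off matches the one described above: every cut corresponds to writing data out, so a shorter interval means more I/O.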

The modified code:

from pyspark.mllib.recommendation import ALS
from numpy import array
from pyspark import SparkContext

if __name__ == '__main__':
	# sc = SparkSession\
	# 	.builder\
	# 	.appName("PythonWordCount")\
	# 	.getOrCreate()
	sc = SparkContext(appName="PythonWordCount")
	sc.setCheckpointDir('checkpoint')
	data = sc.textFile("CollaborativeFiltering.txt", 20)
	ratings = data.map(lambda line: [float(x) for x in line.split(' ')]).persist()
	rank = 10
	n = 30
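	# Builder-style setters (setCheckpointInterval, setMaxIter, setRank, setAlpha) belong to
	# pyspark.ml.recommendation.ALS; the RDD-based ALS.train below takes its parameters directly.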
	model = ALS.train(ratings, rank, n)
	testdata = ratings.map(lambda r: (int(r[0]), int(r[1])))
	predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2]))

	ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions).persist()
	MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).reduce(lambda x, y: x + y)/ratesAndPreds.count()

	print("Mean Squared Error = " + str(MSE))
	ratesAndPreds.unpersist()
