TensorFlowOnSpark, PySpark
Wide & Deep model error
ValueError: Items of feature_columns must be a _FeatureColumn. Given (type
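Root cause analysis
The same code runs fine as a single-machine model, but running it through TFCluster.run raises the type error above.
The main cause is a serialization problem that occurs when Spark sends the function to the executors for execution: the _FeatureColumn types from the TensorFlow project are not serialized correctly, so the type check in feature_column.py fails (the print line is added for debugging):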
for column in feature_columns:
    print("===" + str(type(column)))
    if not isinstance(column, _FeatureColumn):
        raise ValueError('Items of feature_columns must be a _FeatureColumn. '
                         'Given (type {}): {}.'.format(type(column), column))
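Why the isinstance check can fail after serialization can be illustrated with a small, self-contained sketch. It rests on an assumption that the original note does not spell out: a patched namedtuple pickles its instances by class name and field list only, so the rebuilt object is a plain namedtuple and loses any extra base class. This is a simplified imitation, not the actual PySpark or TensorFlow code; FeatureColumnBase, NumericColumn, _restore, _reduce and roundtrip are hypothetical names.

```python
import collections


class FeatureColumnBase(object):
    """Hypothetical stand-in for TensorFlow's _FeatureColumn base class."""


def _restore(name, fields, values):
    # Rebuild a plain namedtuple class from the name and fields only.
    # Any additional base class of the original type is lost at this point.
    cls = collections.namedtuple(name, fields)
    return cls(*values)


def _reduce(self):
    # Assumed pickling recipe: "call _restore(name, fields, values) on load".
    return (_restore, (type(self).__name__, self._fields, tuple(self)))


# A namedtuple whose instances are serialized through the recipe above.
_NumericTuple = collections.namedtuple("NumericColumn", ["key", "shape"])
_NumericTuple.__reduce__ = _reduce


class NumericColumn(FeatureColumnBase, _NumericTuple):
    """Hypothetical column type: a namedtuple with an extra marker base class."""


def roundtrip(obj):
    # Replay what pickle does: ask the object for its recipe, then run it.
    func, args = obj.__reduce__()
    return func(*args)


column = NumericColumn(key="age", shape=(1,))
restored = roundtrip(column)

print(isinstance(column, FeatureColumnBase))    # True: original object
print(isinstance(restored, FeatureColumnBase))  # False: rebuilt object
```

If this is what happens in the TFCluster.run case, the rebuilt column still carries the same field values but no longer passes the isinstance check, which would match the error above.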
Fix: add the following code to the Python file of the driver program:
import collections
collections.namedtuple.__hijack = 1
Once __hijack has a value, the body of the following function in serializers.py of the pyspark project is no longer executed:
def _hijack_namedtuple():
    """ Hack namedtuple() to make it picklable """
    # hijack only one time
    if hasattr(collections.namedtuple, "__hijack"):
        return

    global _old_namedtuple  # or it will put in closure
    global _old_namedtuple_kwdefaults  # or it will put in closure too

    def _copy_func(f):
        return types.FunctionType(f.__code__, f.__globals__, f.__name__,
                                  f.__defaults__, f.__closure__)

    def _kwdefaults(f):
        # __kwdefaults__ contains the default values of keyword-only arguments which are
        # introduced from Python 3. The possible cases for __kwdefaults__ in namedtuple
        # are as below:
        #
        # - Does not exist in Python 2.
        # - Returns None in <= Python 3.5.x.
        # - Returns a dictionary containing the default values to the keys from Python 3.6.x
        #   (See https://bugs.python.org/issue25628).
        kargs = getattr(f, "__kwdefaults__", None)
        if kargs is None:
            return {}
        else:
            return kargs

    _old_namedtuple = _copy_func(collections.namedtuple)
    _old_namedtuple_kwdefaults = _kwdefaults(collections.namedtuple)

    def namedtuple(*args, **kwargs):
        for k, v in _old_namedtuple_kwdefaults.items():
            kwargs[k] = kwargs.get(k, v)
        cls = _old_namedtuple(*args, **kwargs)
        return _hack_namedtuple(cls)

    # replace namedtuple with the new one
    collections.namedtuple.__globals__["_old_namedtuple_kwdefaults"] = _old_namedtuple_kwdefaults
    collections.namedtuple.__globals__["_old_namedtuple"] = _old_namedtuple
    collections.namedtuple.__globals__["_hack_namedtuple"] = _hack_namedtuple
    collections.namedtuple.__code__ = namedtuple.__code__
    collections.namedtuple.__hijack = 1

    # hack the cls already generated by namedtuple.
    # Those created in other modules can be pickled as normal,
    # so only hack those in __main__ module
    for n, o in sys.modules["__main__"].__dict__.items():
        if (type(o) is type and o.__base__ is tuple
                and hasattr(o, "_fields")
                and "__reduce__" not in o.__dict__):
            _hack_namedtuple(o)  # hack inplace
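The effect of the guard at the top of _hijack_namedtuple() can be tried out in isolation. The following is a minimal sketch, not PySpark code: hijack_once and fresh_factory are hypothetical helpers, and a throwaway function stands in for collections.namedtuple so that the real one is left untouched.

```python
import collections


def hijack_once(factory, calls):
    """Hypothetical, simplified imitation of the guard in _hijack_namedtuple()."""
    # Same kind of check as in the PySpark function: return early if flagged.
    if hasattr(factory, "__hijack"):
        return False
    calls.append("patched")          # stands in for the real monkey-patching
    setattr(factory, "__hijack", 1)  # mark as done, as the original does
    return True


def fresh_factory():
    """Stand-in for collections.namedtuple (a new function object each time)."""
    def factory(*args, **kwargs):
        return collections.namedtuple(*args, **kwargs)
    return factory


calls = []
untouched = fresh_factory()
flagged = fresh_factory()
# Equivalent of the workaround's "collections.namedtuple.__hijack = 1" line.
setattr(flagged, "__hijack", 1)

first = hijack_once(untouched, calls)   # no flag yet: the patch is applied
second = hijack_once(untouched, calls)  # flag set by the first call: skipped
third = hijack_once(flagged, calls)     # flag set up front: skipped

print(first, second, third, calls)
```

Assuming the guard is evaluated when pyspark's serializers module is first imported, the two workaround lines would need to run in the driver before any pyspark import; otherwise the patch has presumably already been applied by the time the flag is set.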