Recursively joining a list of DataFrames in Spark

The idea: split the list of DataFrames into chunks, join each chunk in parallel with a thread pool, then recurse on the partial results until a single DataFrame remains. The code is as follows:

from functools import reduce
from multiprocessing.pool import ThreadPool

import numpy as np


def join_dfs(df_list, key=['id']):
    # Sequentially join a list of DataFrames on the given key columns.
    if len(df_list) == 1:
        return df_list[0]

    def join_df(df1, df2):
        # Close over the outer `key` so the caller's key is actually used
        # (a separate default parameter here would silently ignore it).
        return df1.join(df2, key)

    # reduce performs a left fold: ((df0 join df1) join df2) join ...
    return reduce(join_df, df_list)
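Since reduce is a left fold, join_dfs simply chains the joins from left to right; with three DataFrames (the names df_a, df_b, df_c below are illustrative) it is equivalent to:

result = join_dfs([df_a, df_b, df_c], key=['id'])
# which expands to:
result = df_a.join(df_b, ['id']).join(df_c, ['id'])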

def join_df_recursive(df_list, key=['id']):
    len_df = len(df_list)
    if len_df == 0:
        return df_list
    if len_df == 1:
        return df_list[0]
    else:
        # Split the list into chunks; each chunk is joined sequentially.
        chunk_size = min(int(len_df / 2), 8) if len_df > 4 else 2
        chunk_number = int(len_df / chunk_size)
        chunks = np.array_split(df_list, chunk_number)
        # Join the chunks in parallel on up to 8 threads.
        para = min(8, chunk_number)
        pool = ThreadPool(para)
        df_list = pool.map(lambda dfs: join_dfs(list(dfs), key), chunks)
        pool.close()
        pool.join()
        # Recurse on the partial results until a single DataFrame remains.
        return join_df_recursive(df_list, key)
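For reference, here is a minimal usage sketch, assuming a local SparkSession; the sample data and column names are made up for illustration:

from pyspark.sql import SparkSession

spark = SparkSession.builder.master('local[*]').getOrCreate()

# Three small DataFrames sharing the join key column 'id'.
df1 = spark.createDataFrame([(1, 'a'), (2, 'b')], ['id', 'col1'])
df2 = spark.createDataFrame([(1, 10), (2, 20)], ['id', 'col2'])
df3 = spark.createDataFrame([(1, True), (2, False)], ['id', 'col3'])

joined = join_df_recursive([df1, df2, df3], key=['id'])
joined.show()  # one row per id, with col1, col2 and col3 side by side

Note that df1.join(df2, key) is a lazy transformation, so the thread pool mainly overlaps driver-side work (plan construction and analysis); the actual join computation runs on the cluster only when an action such as show() is called.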
