partitionBy(self, numPartitions, partitionFunc=portable_hash): this method takes two main parameters. The first is numPartitions, the number of partitions, which is self-explanatory.
The second is partitionFunc, the partitioning function, which defaults to portable_hash. We can also supply our own:
# Build a pair RDD keyed by the string itself (no collect() here,
# otherwise we would end up with a plain Python list instead of an RDD).
data = sc.parallelize(['1', '2', '3']).map(lambda x: (x, x))
# One partition per element, each key bucketed by its integer value.
wp = data.partitionBy(data.count(), lambda k: int(k))
print(wp.map(lambda t: t[0]).glom().collect())
# Key k lands in partition int(k) % 3, so this prints [['3'], ['1'], ['2']]
The custom function here is the simplest possible one, lambda k: int(k), which places each record according to the integer value of its key. We can define other partition functions as needed; one alternative is sketched below.
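A partition function only has to map a key to an integer; Spark then takes that value modulo numPartitions to pick the bucket. As a minimal sketch (the data and the choice of three partitions are made up for illustration), we could group string keys by their first character:

# Hypothetical example: keys whose first characters are congruent
# modulo numPartitions end up in the same partition.
def first_char_partitioner(key):
    return ord(key[0])

users = sc.parallelize([('alice', 1), ('bob', 2), ('amy', 3), ('carl', 4)])
parts = users.partitionBy(3, first_char_partitioner)
print(parts.glom().collect())
# 'alice' and 'amy' share a partition (ord('a') % 3), 'bob' and 'carl' get their own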
The source code of partitionBy is given below:
def partitionBy(self, numPartitions, partitionFunc=portable_hash):
    """
    Return a copy of the RDD partitioned using the specified partitioner.

    >>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x))
    >>> sets = pairs.partitionBy(2).glom().collect()
    >>> set(sets[0]).intersection(set(sets[1]))
    set([])
    """
    if numPartitions is None:
        numPartitions = self._defaultReducePartitions()

    # Transferring O(n) objects to Java is too expensive.
    # Instead, we'll form the hash buckets in Python,
    # transferring O(numPartitions) objects to Java.
    # Each object is a (splitNumber, [objects]) pair.
    # In order to avoid too huge objects, the objects are
    # grouped into chunks.
    outputSerializer = self.ctx._unbatched_serializer

    limit = (_parse_memory(self.ctx._conf.get(
        "spark.python.worker.memory", "512m")) / 2)

    def add_shuffle_key(split, iterator):

        buckets = defaultdict(list)
        c, batch = 0, min(10 * numPartitions, 1000)

        for (k, v) in iterator:
            buckets[partitionFunc(k) % numPartitions].append((k, v))
            c += 1

            # check used memory and avg size of chunk of objects
            if (c % 1000 == 0 and get_used_memory() > limit
                    or c > batch):
                n, size = len(buckets), 0
                for split in buckets.keys():
                    yield pack_long(split)
                    d = outputSerializer.dumps(buckets[split])
                    del buckets[split]
                    yield d
                    size += len(d)

                avg = (size / n) >> 20
                # let 1M < avg < 10M
                if avg < 1:
                    batch *= 1.5
                elif avg > 10:
                    batch = max(batch / 1.5, 1)
                c = 0

        for (split, items) in buckets.iteritems():
            yield pack_long(split)
            yield outputSerializer.dumps(items)

    keyed = self.mapPartitionsWithIndex(add_shuffle_key)
    keyed._bypass_serializer = True

    with _JavaStackTrace(self.context) as st:
        pairRDD = self.ctx._jvm.PairwiseRDD(
            keyed._jrdd.rdd()).asJavaPairRDD()
        partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
                                                      id(partitionFunc))
    jrdd = pairRDD.partitionBy(partitioner).values()
    rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))

    # This is required so that id(partitionFunc) remains unique,
    # even if partitionFunc is a lambda:
    rdd._partitionFunc = partitionFunc
    return rdd
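Reading add_shuffle_key, the placement decision happens entirely on the Python side: each record is appended to bucket partitionFunc(k) % numPartitions, and the JVM-side PythonPartitioner only routes the already-numbered chunks. A small sketch of a sanity check, assuming an identity partition function (the sorting is only there to make the output deterministic):

# Because the bucket is partitionFunc(k) % numPartitions, we can predict
# which partition each key ends up in.
rdd = sc.parallelize(range(8)).map(lambda x: (x, x))
parted = rdd.partitionBy(4, lambda k: k)   # identity partitionFunc
print(parted.glom().map(lambda p: sorted(k for k, _ in p)).collect())
# Key k should land in partition k % 4: [[0, 4], [1, 5], [2, 6], [3, 7]]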