本文记录spark开发过程中遇到的小知识点,使用pyspark开发,由于使用大多数场景为DataFrame,介绍也多为DataFrame。本文比较长,在学习过程中摘了一些博客和资料,如果有描述的不对的地方请指出。
Spark是分布式内存计算,能够依据各类操作创建一个计算DAG图,数据通过DAG处理后生成结果。
对spark的数据操作分为两类,一类是转换(transformation)操作,比如Filter、map、flatMap、reduce等,但是这些操作是懒转换,只在action的时候才真正的对数据做处理;另一类是action为操作,比如collect、show、count、first等,它们能够触发数据的计算得到结果。
数据集分为RDD、DataFrame、DataSet,其中DataFrame可以看做带格式的RDD,(因为格式确定,所以处理计算效率高于RDD),对DataFrame的操作可以视为对一张数据表操作,由于数据集的不可变特性,不能够修改原有DataFrame,而只能创建新的DataFrame,如增加列生成新的DF,删减列生成新的DF等。由于对DataFrame类似于表,spark提供了SQL的方式进行计算和操作,很多计算可以直接通过SQL的方式解决掉。
注意,虽然PySpark中也叫DataFrame,但是其与Pandas的DataFrame的很多操作不一样,虽然有很多类似之处。
from pyspark.sql import Row
from pyspark.sql.functions import col, isnan, isnull
from pyspark.sql import SparkSession # SparkConf、SparkContext 和 SQLContext 都已经被封装在 SparkSession
from pyspark.sql.types import *
# 创建spark 如果是pypsark的话,直接用内置的spark变量
spark = SparkSession.builder.appName('test pyspark').getOrCreate()
# 通过读取数据集来创建DataFrame
# 1. 参考本文档读取操作
# 2. 通过RDD
client_rdd = spark.sparkContext.parallelize([
('20180701', '1111', 0.1),
('20180801', "1111", 0.2),
('20180901', "1111", 0.3),
])
client_schema = StructType([
StructField("date", StringType(), True),
StructField("client_id", StringType(), True),
StructField("cash", DoubleType(), True)
])
client_df = spark.createDataFrame(client_rdd, client_schema)
# pandas DataFrame 与 pyspark的DataFrame相互转化
pandas_df = spark_df.toPandas() # spark df转pandas df
spark_df = spark.createDataFrame(pandas_df.values.tolist(), list(pandas_df.columns)) # pandas df转spark df
spark_df = spark.createDataFrame(pandas_df) # pandas df转spark df
# 查看DataFrame
df.show(5, False)
df.first()
df.head(5)
df.take(5)
df.collect() # 将df整体以list的形式返回,不要在大数据集的情况用这种方法
df.collectAsMap() # ?
df.count() # 统计df的行数
df.schema # df的结构
df.printSchema() # 树的形式打印df的结构
df.columns # 查看列名
# 缓存df,对于多次调用一些小的数据集,如果不缓存,则在计算的时候会多次加载,缓存能提高效率
df.cache()
df.persist()
def essure_asset_type(stock_code):
'''随机赋予证券代码的类别属性'''
return random.choice(["equity", "fixincome", "cash"])
asset_type = udf(essure_asset_type, StringType()) # 创建用户自定义函数
df = df.withColumn('asset_type', asset_type(df.symbol)) # 增加新的列,处理的函数必须是处理列的
df = df.withColumn("differ", df.col1 - df.col2) # 增加新的列
df = df.withColumnRenamed('age', 'age2') # 更改列名
df = df.drop(df.col1) # 删除列
df = df.drop("col1")
df.rdd # df 转 rdd
df.rdd.getNumPartitions() # 查看分区数
## DataFrame的类SQL操作 DataFrame通过select、where、groupBy、sum等实现了类SQL的操作
client_df.groupBy(client_df.date, client_df.client_id).sum('cash')
# 两个DF的join操作,不得不说join的存在是在太方便了
join_df = client_df.join(bench_df, client_df.date == bench_df.date, 'left')
# DataFrame的列选择操作
df = df.select("orders", "traders") # 通过选择列生成新的df
df = df.select(df.orders, df.traders)
# DataFrame的过滤操作
df = df.filter("trader=1111")
df = df.where("trader=1111")
df = df.filter(df.trader == '1111')
df = df.where(df.trader == '1111')
df = df.filter(col('trader').like('%1111%'))
df = df.filter(isnull("trader"))
# DataFrame的SQL操作,支持常用的sql操作,如select、where、group by、count、like等
df.registerTempTable("df") # 注册了一个名叫 df 的表
df = spark.sql("select orders from df where trader='1111'")
df = spark.sql("select count(orders) from df where trader='1111'")
# 读取json
df = spark.read.json(json_file_path)
# 读取csv
df = spark.read.csv(csv_file_path, header=True, inferSchema=True)
# 写入csv
df.write.csv(path=csv_file_path, header=True, sep=",", mode='overwrite')
# 读取MySQL
df = spark.read.format('jdbc').options(url='jdbc:oracle:thin:@ip:port:database',
dbtable='table_name or select sql as table', user='user_name',
password='password').load()
# 读取Oracle, 需要在提交时指定JDBC
df = spark.read.format('jdbc').options(url='jdbc:oracle:thin:@ip:port:database',
dbtable='table_name or select sql as table', user='user_name',
password='password').load()
# 读取Hive,注:该方法我没试过
from pyspark.sql import SparkSession
spark = SparkSession.builder.enableHiveSupport().master("172.31.100.170:7077").appName("my_first_app_name").getOrCreate()
df=spark.sql("select * from hive_tb_name")
# 读取HDFS文件,如parquet
df = spark.read.parquet(parquet_file_path)
# 写入parquet文件
df.write.parquet(path=parquet_file_path, mode='overwrite')
# 读取Impala
from impala.dbapi import connect
conn = connect(host='ip', port=port, user='user_name', password='password')
cur = conn.cursor(user='user_name')
cur.execute("select sql;")
rdd = spark.sparkContext.parallelize(cur.fetchall())
# 写入MySQL,需要在提交时指定JDBC
df.write.mode("append").format("jdbc").options(url='jdbc:mysql://ip:port/database',
user='user_name', password='password',
dbtable='table', batchsize="1000").save()
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('test pyspark').getOrCreate()
lines_df = spark.read.text("/user/spark/test/test.txt") # z这个是一个DataFrame
lines = lines_df.rdd.map(lambda x: x[0]) # df.rdd 可以将数据由DataFrame类型转化为RDD类型
counts = lines.flatMap(lambda x: x.split(' ')).map(lambda x: (x, 1)).reduceByKey(lambda a, b: a + b)
print(counts.collect())
counts.saveAsTextFile("/user/spark/test/count")
def compute_contribs(urls, rank):
'''
计算每个节点新的pr分值
'''
num_urls = len(urls)
for url in urls:
yield (url, rank * 1.0 / num_urls) # (节点,节点权重*占比)
# 初始化,"1 2"表示节点1有一条路径(连接)到2
lines = sc.parallelize(["1 2", "1 3", "1 4", "2 4", "2 1", "3 1", "4 3", "4 2"])
# links表示每个节点连接的节点列表[('2', ['4', '1']), ('1', ['2', '3', '4']), ('4', ['3', '2']), ('3', ['1'])]
links = lines.map(lambda line: line.split()).groupByKey().mapValues(lambda x: list(x)).cache()
# 每个节点初始化pr值为1,[('2', 1), ('1', 1), ('4', 1), ('3', 1)]
ranks = links.keys().map(lambda x: (x, 1))
for i in range(10):
# 将各个节点初始的分数分发给相邻的每个节点
contribs = links.join(ranks).flatMap(lambda r: compute_contribs(r[1][0], r[1][1]))
# 将每个节点的分数值汇总
ranks = contribs.reduceByKey(lambda x, y: x + y).mapValues(lambda x: x)
for link, rank in ranks.collect():
print("%s has rank %s." % (link, rank))
# 提交pyspark程序,可以指定运行配置,注:无论是pyspark shell模式,还是提交程序到yarn上运行,如果需要用到相关JDBC等jar包,需要指定
spark2-submit spark_learn.py
spark2-submit --master local[2] spark_learn.py # 本地模式,2个节点并行
spark2-submit --master yarn spark_learn.py # 提交到yarn上去执行
spark2-submit \
--master yarn \
--deploy-mode cluster \
--jars thirdparty/jars/ojdbc6.jar,thirdparty/jars/mysql-connector-java-5.1.42-bin.jar,thirdparty/jars/hive-metastore-1.1.0-cdh5.13.0.jar,thirdparty/jars/hive-service-1.1.0-cdh5.13.0.jar,thirdparty/jars/ImpalaJDBC41.jar \
--driver-class-path ojdbc6.jar:mysql-connector-java-5.1.42-bin.jar:hive-metastore-1.1.0-cdh5.13.0.jar:hive-service-1.1.0-cdh5.13.0.jar:ImpalaJDBC41.jar \
--executor-memory 14G \
--driver-memory 6G \
--conf spark.app.name='test' \
--conf spark.default.parallelism=50 \
--conf spark.memory.fraction=0.85 \
--conf spark.memory.storageFraction=0.5 \
--conf spark.yarn.executor.memoryOverhead=2048 \
--conf spark.yarn.driver.memoryOverhead=1024 \
--conf spark.serializer=org.apache.spark.serializer.KryoSerializer \
--conf spark.yarn.maxAppAttempts=1 \
spark_learn.py \
pyspark2 \
--jars thirdparty/jars/ojdbc6.jar,thirdparty/jars/mysql-connector-java-5.1.42-bin.jar \
--driver-class-path thirdparty/jars/ojdbc6.jar:thirdparty/jars/mysql-connector-java-5.1.42-bin.jar