sparklyr Tutorial

With sparklyr you can connect to Spark from R and use familiar R tools to process the data stored in Spark.

Calling Spark from R

  1. Connect to Spark
  2. Copy data into Spark
  3. Manipulate the data with tidyverse/dplyr verbs
  4. Build models
  5. Disconnect (see the minimal sketch after this list)
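
A minimal end-to-end sketch of these five steps, assuming a local Spark installation and the dplyr package:

library(sparklyr)
library(dplyr)

sc <- spark_connect(master = "local")               # 1. connect to Spark
mtcars_tbl <- copy_to(sc, mtcars)                   # 2. copy data into Spark
mtcars_tbl %>% filter(mpg > 20) %>% count()         # 3. manipulate with dplyr verbs
fit <- ml_linear_regression(mtcars_tbl, mpg ~ wt)   # 4. fit a model
spark_disconnect(sc)                                # 5. disconnect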

Load sparklyr and install a local copy of Spark:

library(sparklyr)
spark_install(version = "2.3.1")

Connect to the local Spark instance:

sc <- spark_connect(master = "local")

Load the data-manipulation package:

library(dplyr)

Copy data into Spark:

iris_tbl <- copy_to(sc, iris)
flights_tbl <- copy_to(sc, nycflights13::flights, "flights")
batting_tbl <- copy_to(sc, Lahman::Batting, "batting")

List the tables available in Spark:

src_tbls(sc)

Now that you are connected to Spark and the data has been copied in, you can work with it directly using tidyverse verbs:

> flights_tbl %>% filter(dep_delay == 2)
# Source: spark [?? x 19]
    year month   day dep_time sched_dep_time dep_delay arr_time
 * <int> <int> <int>    <int>          <int>     <dbl>    <int>
 1  2013     1     1      517            515      2.00      830
 2  2013     1     1      542            540      2.00      923
 3  2013     1     1      702            700      2.00     1058
 4  2013     1     1      715            713      2.00      911
 5  2013     1     1      752            750      2.00     1025
 6  2013     1     1      917            915      2.00     1206
 7  2013     1     1      932            930      2.00     1219
 8  2013     1     1     1028           1026      2.00     1350
 9  2013     1     1     1042           1040      2.00     1325
10  2013     1     1     1231           1229      2.00     1523
# ... with more rows, and 12 more variables: sched_arr_time <int>,
#   arr_delay <dbl>, carrier <chr>, flight <int>, tailnum <chr>,
#   origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>, hour <dbl>,
#   minute <dbl>, time_hour <dttm>

Plotting flight delays

Aggregate the data inside Spark, then collect the result into a local data frame:

delay <- flights_tbl %>% 
  group_by(tailnum) %>%
  summarise(count = n(), dist = mean(distance), delay = mean(arr_delay)) %>%
  filter(count > 20, dist < 2000, !is.na(delay)) %>%
  collect()

Plot the result:

library(ggplot2)
ggplot(delay, aes(dist, delay)) +
  geom_point(aes(size = count), alpha = 1/2) +
  geom_smooth() +
  scale_size_area(max_size = 2)

Window functions

batting_tbl %>%
  select(playerID, yearID, teamID, G, AB:H) %>%
  arrange(playerID, yearID, teamID) %>%
  group_by(playerID) %>%
  filter(min_rank(desc(H)) <= 2 & H > 0)
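
These dplyr verbs are translated to Spark SQL and executed lazily in the cluster. If you want to inspect the generated query, a minimal sketch using show_query() (from dplyr/dbplyr) is:

# print the Spark SQL generated for the pipeline above, without executing it
batting_tbl %>%
  select(playerID, yearID, teamID, G, AB:H) %>%
  arrange(playerID, yearID, teamID) %>%
  group_by(playerID) %>%
  filter(min_rank(desc(H)) <= 2 & H > 0) %>%
  show_query()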

Using SQL
You can run SQL queries directly against tables in the Spark cluster. spark_connect() provides a DBI interface for Spark, so you can use dbGetQuery() to execute SQL and return the result as an R data frame:

library(DBI)
iris_preview <- dbGetQuery(sc, "SELECT * FROM iris LIMIT 10")
> iris_preview
   Sepal_Length Sepal_Width Petal_Length Petal_Width Species
1           5.1         3.5          1.4         0.2  setosa
2           4.9         3.0          1.4         0.2  setosa
3           4.7         3.2          1.3         0.2  setosa
4           4.6         3.1          1.5         0.2  setosa
5           5.0         3.6          1.4         0.2  setosa
6           5.4         3.9          1.7         0.4  setosa
7           4.6         3.4          1.4         0.3  setosa
8           5.0         3.4          1.5         0.2  setosa
9           4.4         2.9          1.4         0.2  setosa
10          4.9         3.1          1.5         0.1  setosa
  
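
The same query can also be expressed with dplyr against the registered table; a small sketch equivalent to the SQL above (copy_to() registered the data under the name "iris"):

# dplyr equivalent of the SQL query: first 10 rows of the iris table in Spark
tbl(sc, "iris") %>%
  head(10) %>%
  collect()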

Machine learning
You can orchestrate machine learning algorithms in a Spark cluster through sparklyr's machine learning functions. These functions connect to a set of high-level APIs built on top of DataFrames that help you create and tune machine learning workflows.

We use ml_linear_regression() to fit a linear regression model on the built-in mtcars dataset, predicting a car's fuel consumption (mpg) from its weight (wt) and number of cylinders (cyl). First, copy the data into Spark:

mtcars_tbl <- copy_to(sc, mtcars)

Do some simple preprocessing and split the data into training and test sets:

partitions <- mtcars_tbl %>%
  filter(hp >= 100) %>%
  mutate(cyl8 = cyl == 8) %>%
  sdf_partition(training = 0.5, test = 0.5, seed = 1099)

Fit the model:

fit <- partitions$training %>%
  ml_linear_regression(response = "mpg", features = c("wt", "cyl"))
> fit
Formula: mpg ~ wt + cyl

Coefficients:
(Intercept)          wt         cyl 
  33.795576   -1.596247   -1.580360 
> summary(fit)
Deviance Residuals:
    Min      1Q  Median      3Q     Max 
-2.0947 -1.2747 -0.1129  1.0876  2.2185 

Coefficients:
(Intercept)          wt         cyl 
  33.795576   -1.596247   -1.580360 

R-Squared: 0.8267
Root Mean Squared Error: 1.437
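
To check how the model does on the held-out half, you can score the test partition. A minimal sketch using ml_predict() (older sparklyr versions used sdf_predict(); the prediction column is named "prediction" by default):

# score the test partition with the fitted model
pred <- ml_predict(fit, partitions$test)

# root mean squared error on the test set
pred %>%
  mutate(residual = mpg - prediction) %>%
  summarise(rmse = sqrt(mean(residual ^ 2)))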

Reading and writing data
You can read and write data in CSV, JSON, and Parquet formats. The files can live on HDFS, S3, or the local file system of the cluster nodes.

Step 1: create temporary file paths:

temp_csv <- tempfile(fileext = ".csv")
temp_parquet <- tempfile(fileext = ".parquet")
temp_json <- tempfile(fileext = ".json")

Step 2: read and write the data as CSV:

spark_write_csv(iris_tbl, temp_csv)
iris_csv_tbl <- spark_read_csv(sc, "iris_csv", temp_csv)

Read and write Parquet:

spark_write_parquet(iris_tbl, temp_parquet)
iris_parquet_tbl <- spark_read_parquet(sc, "iris_parquet", temp_parquet)

Read and write JSON:

spark_write_json(iris_tbl, temp_json)
iris_json_tbl <- spark_read_json(sc, "iris_json", temp_json)

src_tbls(sc)

Distributed R
You can execute arbitrary R code across the cluster with spark_apply(), so distributed R computation in sparklyr is built on spark_apply():

> spark_apply(iris_tbl, function(data) {
+   data[1:4] + rgamma(1,2)
+ })
# Source: spark [?? x 4]
   Sepal_Length Sepal_Width Petal_Length Petal_Width
 *        <dbl>       <dbl>        <dbl>       <dbl>
 1         7.54        5.94         3.84        2.64
 2         7.34        5.44         3.84        2.64
 3         7.14        5.64         3.74        2.64
 4         7.04        5.54         3.94        2.64
 5         7.44        6.04         3.84        2.64
 6         7.84        6.34         4.14        2.84
 7         7.04        5.84         3.84        2.74
 8         7.44        5.84         3.94        2.64
 9         6.84        5.34         3.84        2.64
10         7.34        5.54         3.94        2.54
# ... with more rows

You can also group the data and apply a function to each group:

> spark_apply(
+     iris_tbl,
+     function(e) broom::tidy(lm(Petal_Width ~ Petal_Length, e)),
+     names = c("term", "estimate", "std.error", "statistic", "p.value"),
+     group_by = "Species"
+ )
# Source: spark [?? x 6]
  Species    term         estimate std.error statistic         p.value
* <chr>      <chr>           <dbl>     <dbl>     <dbl>           <dbl>
1 versicolor (Intercept)   -0.0843    0.161     -0.525 0.602          
2 versicolor Petal_Length   0.331     0.0375     8.83  0.0000000000127
3 virginica  (Intercept)    1.14      0.379      2.99  0.00434        
4 virginica  Petal_Length   0.160     0.0680     2.36  0.0225         
5 setosa     (Intercept)   -0.0482    0.122     -0.396 0.694          
6 setosa     Petal_Length   0.201     0.0826     2.44  0.0186  

Utilities

Cache a table in memory:

tbl_cache(sc, "batting")

Unload it from memory with:

tbl_uncache(sc, "batting")

Connection utilities
You can view the Spark web console with:

spark_web(sc)

Show the Spark log:

> spark_log(sc, n = 10)
18/11/23 23:32:29 INFO BlockManager: Found block rdd_1046_3 locally
18/11/23 23:32:29 INFO Executor: Finished task 2.0 in stage 307.0 (TID 344). 2077 bytes result sent to driver
18/11/23 23:32:29 INFO Executor: Finished task 1.0 in stage 307.0 (TID 343). 2259 bytes result sent to driver
18/11/23 23:32:29 INFO Executor: Finished task 0.0 in stage 307.0 (TID 342). 2263 bytes result sent to driver
18/11/23 23:32:29 INFO TaskSetManager: Finished task 2.0 in stage 307.0 (TID 344) in 6 ms on localhost (executor driver) (1/3)
18/11/23 23:32:29 INFO TaskSetManager: Finished task 0.0 in stage 307.0 (TID 342) in 6 ms on localhost (executor driver) (2/3)
18/11/23 23:32:29 INFO TaskSetManager: Finished task 1.0 in stage 307.0 (TID 343) in 6 ms on localhost (executor driver) (3/3)
18/11/23 23:32:29 INFO TaskSchedulerImpl: Removed TaskSet 307.0, whose tasks have all completed, from pool 
18/11/23 23:32:29 INFO DAGScheduler: ResultStage 307 (collect at utils.scala:200) finished in 0.012 s
18/11/23 23:32:29 INFO DAGScheduler: Job 243 finished: collect at utils.scala:200, took 0.013866 s
> 

Disconnect:

spark_disconnect(sc)

Using H2O
rsparkling is a package from H2O that extends sparklyr to provide an interface to Sparkling Water. The following example installs, configures, and runs h2o.glm():

options(rsparkling.sparklingwater.version = "2.1.14")

Load the required packages:

library(rsparkling)
library(sparklyr)
library(dplyr)
library(h2o)

Connect to the Spark cluster:

sc <- spark_connect(master = "local")

Copy the data into Spark:

mtcars_tbl <- copy_to(sc, mtcars, "mtcars")

Convert the data to an H2O frame:

mtcars_h2o <- as_h2o_frame(sc, mtcars_tbl, strict_version_check = FALSE)

If an error occurs at this step (typically an H2O / Sparkling Water version mismatch), remove the installed h2o package and reinstall the matching release:

detach("package:rsparkling", unload = TRUE)
if ("package:h2o" %in% search()) { detach("package:h2o", unload = TRUE) }
if (isNamespaceLoaded("h2o")){ unloadNamespace("h2o") }
remove.packages("h2o")
install.packages("h2o", type = "source", repos = "https://h2o-release.s3.amazonaws.com/h2o/rel-wright/7/R")

Fit a regression model:

mtcars_glm <- h2o.glm(x = c("wt", "cyl"), 
                      y = "mpg",
                      training_frame = mtcars_h2o,
                      lambda_search = TRUE)

Inspect the model:

mtcars_glm
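
Once the model is fitted, you can generate predictions with H2O's standard predict function; a minimal sketch scoring the same frame, just for illustration:

# predicted mpg for each row of the H2O frame
mtcars_pred <- h2o.predict(mtcars_glm, newdata = mtcars_h2o)
head(mtcars_pred)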

Disconnect:

spark_disconnect(sc)
