import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions._
/**
* @author YuZhansheng
* @ desc
* @ create 2019-03-11 14:58
*/
object TopNStatJob {

  /** Default location of the cleaned access-log Parquet data. */
  private val DefaultInputPath = "file:/root/DataSet/clean"

  /**
   * Entry point: loads the cleaned access logs (Parquet) and prints the most popular videos.
   *
   * @param args optional; `args(0)` overrides the input path (default: file:/root/DataSet/clean)
   */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("TopNStatJob")
      .master("local[2]")
      // Disable partition-column type inference so partition values are read as strings
      // (presumably so `day` keeps comparing against "20170511" as a string — confirm).
      .config("spark.sql.sources.partitionColumnTypeInference.enabled", false)
      .getOrCreate()
    try {
      val inputPath = args.headOption.getOrElse(DefaultInputPath)
      val accessDF = spark.read.format("parquet").load(inputPath)
      // Most popular TopN videos
      videoAccessTopNStat(spark, accessDF)
    } finally {
      // Always release the Spark context, even when loading or the query fails.
      spark.stop()
    }
  }

  /**
   * Most popular TopN videos: counts accesses per (day, cmsId) for day "20170511" and
   * cmsType "video", orders by that count descending, and prints the leading rows via show().
   *
   * @param spark    active session (needed for the `$"col"` implicits)
   * @param accessDF cleaned access logs; must contain `day`, `cmsId` and `cmsType` columns
   */
  def videoAccessTopNStat(spark: SparkSession, accessDF: DataFrame): Unit = {
    import spark.implicits._
    val videoAccessTopNDF = accessDF
      .filter($"day" === "20170511" && $"cmsType" === "video")
      .groupBy("day", "cmsId")
      .agg(count("cmsId").as("times"))
      .orderBy($"times".desc)
    videoAccessTopNDF.show()
  }
}
控制台打印输出:
+--------+-----+------+
| day|cmsId| times|
+--------+-----+------+
|20170511|14540|111027|
|20170511| 4000| 55734|
|20170511|14704| 55701|
|20170511|14390| 55683|
|20170511|14623| 55621|
|20170511| 4600| 55501|
|20170511| 4500| 55366|
|20170511|14322| 55102|
+--------+-----+------+
这是使用DataFrame统计出的结果,我们还可以使用SQL进行统计,如下:
//最受欢迎的TOPn课程
def videoAccessTopNStat(spark:SparkSession,accessDF:DataFrame):Unit = {
//使用DataFrame方式进行统计
//import spark.implicits._
//val videoAccessTopNDF = accessDF.filter($"day" === "20170511" && $"cmsType" === "video")
// .groupBy("day","cmsId").agg(count("cmsId").as("times")).orderBy($"times".desc)
//videoAccessTopNDF.show()
//使用SQL方式进行统计
accessDF.createOrReplaceTempView("access_logs")
val videoAccessTopNDF = spark.sql("select day,cmsId,count(1) as times from access_logs" +
"where day='20170511' and cmsType='video' " +
"group by day, cmsId order by times desc")
videoAccessTopNDF.show(false)
}
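打印的结果和上一种方法是一样的。注意：用 `+` 拼接 SQL 字符串时，每一段之间必须保留空格，否则 `from access_logs` 与 `where` 会被拼成 `access_logswhere`，导致 SQL 解析失败；使用三引号多行字符串加 `stripMargin` 可以避免这个问题。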
新建一个MySQL操作工具类:
import java.sql.{Connection, DriverManager, PreparedStatement}
/**
* @author YuZhansheng
* @desc mysql操作工具类
* @create 2019-03-11 15:55
*/
object MySQLUtils {

  import scala.util.control.NonFatal

  // NOTE(review): credentials are hard-coded in the JDBC URL; consider moving them to configuration.
  private val DefaultUrl =
    "jdbc:mysql://localhost:3306/imooc_project?user=root&password=18739548870yu"

  /**
   * Opens a new JDBC connection to the default `imooc_project` MySQL database.
   *
   * @return a new connection owned by the caller (close it, e.g. via [[release]])
   */
  def getConnection(): Connection = getConnection(DefaultUrl)

  /**
   * Opens a new JDBC connection to the given URL.
   *
   * @param url JDBC URL, including any credentials the driver expects
   * @return a new connection owned by the caller (close it, e.g. via [[release]])
   */
  def getConnection(url: String): Connection = DriverManager.getConnection(url)

  /**
   * Releases JDBC resources on a best-effort basis.
   *
   * Either argument may be null. A failure while closing one resource is printed and does not
   * prevent the other from being closed, nor is it propagated to the caller.
   *
   * @param connection connection to close (may be null)
   * @param pstmt      statement to close first (may be null)
   */
  def release(connection: Connection, pstmt: PreparedStatement): Unit = {
    closeQuietly(pstmt)
    closeQuietly(connection)
  }

  /** Closes `resource` if it is non-null, printing (not rethrowing) any non-fatal failure. */
  private def closeQuietly(resource: AutoCloseable): Unit =
    if (resource != null) {
      try resource.close()
      catch { case NonFatal(e) => e.printStackTrace() }
    }

  /** Smoke test: opens a connection, prints it, then closes it again. */
  def main(args: Array[String]): Unit = {
    val connection = getConnection()
    try println(connection)
    finally release(connection, null)
  }
}
测试打印输出：com.mysql.jdbc.JDBC4Connection@799d4f69，说明能够成功获取数据库连接，MySQL工具类工作正常。
在MySQL里面创建一张表,用来存放统计的结果:
-- Daily TopN video access statistics: one row per (day, cms_id).
-- Integer display widths such as bigint(10) are dropped: they do not affect storage or range
-- and are deprecated since MySQL 8.0.17.
create table day_video_access_topn_stat(
day varchar(8) not null,      -- statistics day, e.g. '20170511' (presumably yyyyMMdd)
cms_id bigint not null,       -- course/video id
times bigint not null,        -- access count for that day
primary key (day,cms_id)
);
创建一个实体类DayVideoAccessStat:
/**
 * Per-day course (video) access count; one row of `day_video_access_topn_stat`.
 *
 * @param day   statistics day, e.g. "20170511"
 * @param cmsId course/video id (stored in the `cms_id` column)
 * @param times number of accesses on that day
 */
final case class DayVideoAccessStat(day: String, cmsId: Long, times: Long)
再创建一个DAO对象StatDAO，负责把统计结果批量写入MySQL：
import java.sql
import java.sql.PreparedStatement
import scala.collection.mutable.ListBuffer
/**
* @author YuZhansheng
* @desc 各个维度统计的DAO操作
* @create 2019-03-11 16:11
*/
object StatDAO {

  import scala.util.control.NonFatal

  /**
   * Batch-inserts DayVideoAccessStat records into `day_video_access_topn_stat` within a single
   * manually committed transaction.
   *
   * A null or empty list is a no-op and does not open a connection. If anything fails, the error
   * is printed, the transaction is rolled back, and the exception is not rethrown.
   *
   * @param list records to insert
   */
  def insertDayVideoAccessTopN(list: ListBuffer[DayVideoAccessStat]): Unit = {
    if (list != null && list.nonEmpty) {
      var connection: sql.Connection = null
      var pstmt: PreparedStatement = null
      try {
        connection = MySQLUtils.getConnection()
        // Manual commit: the whole batch is committed as one transaction.
        connection.setAutoCommit(false)
        val insertSql = "insert into day_video_access_topn_stat(day,cms_id,times) values (?,?,?)"
        pstmt = connection.prepareStatement(insertSql)
        list.foreach { stat =>
          pstmt.setString(1, stat.day)
          pstmt.setLong(2, stat.cmsId)
          pstmt.setLong(3, stat.times)
          pstmt.addBatch() // queue the row for the batch
        }
        pstmt.executeBatch()
        connection.commit()
      } catch {
        case NonFatal(e) =>
          e.printStackTrace()
          // Undo any partially executed batch instead of leaving the transaction open.
          rollbackQuietly(connection)
      } finally {
        MySQLUtils.release(connection, pstmt)
      }
    }
  }

  /** Rolls back the current transaction if a connection was opened; rollback errors are printed. */
  private def rollbackQuietly(connection: sql.Connection): Unit =
    if (connection != null) {
      try connection.rollback()
      catch { case NonFatal(e) => e.printStackTrace() }
    }
}
然后在Spark应用程序TopNStatJob的videoAccessTopNStat方法中（紧接在videoAccessTopNDF计算完成之后），将统计结果写入MySQL数据库（需要 import scala.collection.mutable.ListBuffer）：
//videoAccessTopNDF.show(false)
// Persist the statistics to MySQL: every partition gathers its rows into one list
// and hands it to StatDAO for a single batch insert.
try {
  videoAccessTopNDF.foreachPartition(rows => {
    val batch = new ListBuffer[DayVideoAccessStat]
    batch ++= rows.map(row =>
      DayVideoAccessStat(
        row.getAs[String]("day"),
        row.getAs[Long]("cmsId"),
        row.getAs[Long]("times")))
    StatDAO.insertDayVideoAccessTopN(batch)
  })
} catch {
  case err: Exception => err.printStackTrace()
}
查看数据库是否已有数据: