redis.clients
jedis
3.2.0
禁用Linux的防火墙:Linux(CentOS7)里执行命令
systemctl stop/disable firewalld.service
redis.conf中注释掉
bind 127.0.0.1
然后将安全模式关闭
protected-mode no
或者不关闭安全模式,设置密码在配置文件中的requirepass中修改
然后执行命令前需要先输入密码
AUTH 密码 之后再执行命令
新建工具类 user_profile_manager_0224\src\main\java\com\atguigu\userprofile\utils\RedisUtil.java
public class RedisUtil {
public static void main(String[] args) {
Jedis jedis = new Jedis("hadoop101", 6379);
System.out.println(jedis.ping()); // 输出PONG
}
}
public static void main(String[] args) {
Jedis jedis = new Jedis("hadoop101", 6379);
System.out.println(jedis.ping()); // 输出PONG
jedis.set("k1000","v1000");
jedis.set("k2000","v2000");
jedis.set("k3000","v3000");
Set<String> keys = jedis.keys("*");
System.out.println(keys.size());
for(String key:keys){
System.out.println(key);
}
System.out.println(jedis.exists("k3000"));
System.out.println(jedis.ttl("k2000"));
System.out.println(jedis.get("k1000"));
}
public static void main(String[] args) {
Jedis jedis = new Jedis("hadoop101", 6379);
System.out.println(jedis.ping()); // 输出PONG
jedis.mset("str1", "v1", "str2", "v2", "str3", "v3");
System.out.println(jedis.mget("str1", "str2", "str3"));
}
public static void main(String[] args) {
Jedis jedis = new Jedis("hadoop101", 6379);
System.out.println(jedis.ping()); // 输出PONG
jedis.lpush("mylist","v1","v2","v3");
List<String> list = jedis.lrange("mylist", 0, -1);
for(String element : list){
System.out.println(element);
}
}
public static void main(String[] args) {
Jedis jedis = new Jedis("hadoop101", 6379);
System.out.println(jedis.ping()); // 输出PONG
jedis.sadd("sets1","set01","set02","set03","set04");
jedis.sadd("sets2","set02","set03","set04","set05");
Set<String> smembers = jedis.smembers("sets1");
for (String set : smembers) {
System.out.println(set);
}
System.out.println("===================");
jedis.srem("sets1","set02");
System.out.println(jedis.scard("sets1"));
Set<String> sinter = jedis.sinter("sets1", "sets2");
for (String s : sinter) {
System.out.println(s);
}
System.out.println("===================");
Set<String> sunion = jedis.sunion("sets1", "sets2");
for (String s : sunion) {
System.out.println(s);
}
System.out.println("===================");
Set<String> sdiff = jedis.sdiff("sets1", "sets2");
for (String s : sdiff) {
System.out.println(s);
}
}
public static void main(String[] args) {
Jedis jedis = new Jedis("hadoop101", 6379);
System.out.println(jedis.ping()); // 输出PONG
jedis.hset("hash1","userName","zhangsan");
System.out.println(jedis.hget("hash1", "userName"));
HashMap<String, String> map = new HashMap<>();
map.put("userName","lisi");
map.put("age","20");
map.put("gender","nv");
jedis.hmset("hash2",map);
List<String> res = jedis.hmget("hash2", "userName", "age", "gender");
for (String re : res) {
System.out.println(re);
}
}
public static void main(String[] args) {
Jedis jedis = new Jedis("hadoop101", 6379);
System.out.println(jedis.ping()); // 输出PONG
jedis.zadd("zset01", 100d, "z3");
jedis.zadd("zset01", 90d, "l4");
jedis.zadd("zset01", 80d, "w5");
jedis.zadd("zset01", 70d, "z6");
Set<Tuple> zrange = jedis.zrangeWithScores("zset01", 0, -1);
for (Tuple tuple : zrange) {
System.out.println(tuple);
}
}
为了节省每次连接redis服务带来的消耗,把连接好的实例反复利用。
通过参数管理连接的行为
代码如下
public class RedisUtil {
public static void main(String[] args) {
//Jedis jedis = new Jedis("hadoop101", 6379);
Jedis jedis = RedisUtil.getJedisFromPool();
System.out.println(jedis.ping()); // 输出PONG
jedis.close();
}
private static JedisPool jedisPool=null;
public static Jedis getJedisFromPool(){
if(jedisPool==null){
JedisPoolConfig jedisPoolConfig =new JedisPoolConfig();
jedisPoolConfig.setMaxTotal(200); //最大可用连接数
jedisPoolConfig.setMaxIdle(30); //最大闲置连接数
jedisPoolConfig.setMinIdle(10); //最小闲置连接数
jedisPoolConfig.setBlockWhenExhausted(true); //连接耗尽是否等待
jedisPoolConfig.setMaxWaitMillis(2000); //等待时间
jedisPoolConfig.setTestOnBorrow(true); //取连接的时候进行一下测试 ping pong
jedisPool=new JedisPool(jedisPoolConfig,"hadoop101", 6379 );
return jedisPool.getResource();
}else{
return jedisPool.getResource();
}
}
}
链接池参数说明
添加方法
/**
* 数组无法存入到List中,mybatis进行封装,想要封装到List中,需要变为一行一行的值
* 数组是一行值,不同uid间以逗号分隔
* 将数组变为很多行,在ClickHouse中可以使用arrayJoin函数将数组炸开
* @param userGroupId
* @return
*/
@Select("select arrayJoin( bitmapToArray(us) ) as us from user_group where user_group_id=#{userGroupId}")
@DS("clickhouse")
public List<String> userGroupUidList(@Param("userGroupId") String userGroupId);
添加代码
// 3 人群包(包含所有uid)以应对高QPS访问
// redis(bitmap/set)
/**
* - 查询出人群包 uids的集合
*
* - 写入redis
* - type:set(不需要有序,排除zset;需要单值排除hash;list中很多不是幂等操作,最终选择set)
* - key:user_group: 101(user_group:user_group_id)
* - value:uid ...
* - field score:无
* - 写api:sadd
* - 读api:smembers
* - 失效:不是临时值,不设失效
*/
List<String> uidList = super.baseMapper.userGroupUidList(userGroup.getId().toString());
Jedis jedis = RedisUtil.getJedisFromPool();
String key = "user_group:" + userGroup.getId();
String[] uidArr = uidList.toArray(new String[]{});
jedis.sadd(key,uidArr);
jedis.close();
在网页创建分群,然后在redis中查看是否存在数据
keys *
smembers user_group:id(会有具体的数字)
挖掘类标签需要用算法挖掘用户的相关特征,比如:性别预测、年龄预测、 用户流失预测、风险欺诈预测。
相比统计、规则类这些通过专业人员制定明确规则的标签,挖掘类的标签完全是另一套处理思路。
获得挖掘标签过程:
整个挖掘的过程的核心就是建立、完善模型的过程。
一个模型完善的过程是个没有尽头的迭代。
主要是对数据的初步的清洗加工,这个过程一般可以在数仓中完成,然后在数仓中稍微的添加一些操作。
主要是特征的选择和提取。比如想预测用户的流失,那就要选择哪些指标字段会和用户的流失有比较强的相关性。要从数仓中,把这些指标提取出来并进一步加工。
除了获得特征,还需要“参考答案”,比如抽选出来的这些用户特征,那这些用户到底是不是流失的,要标记出来,用于机器学习。
特征的选取往往不能一蹴而就,需要反复的迭代尝试。
目前机器学习的算法种类繁多,比如分类算法领域中:决策树、随机森林、逻辑回归、GBDT、XGBoost。
回归算法领域中:线性回归、多项式回归、岭回归、Lasso回归、弹性回归。
在画像领域中,主要使用分类算法。但具体使用哪种分类算法,也是需要不断尝试验证的,没有一定的标准。
通过代码实现“数据 + 算法 = 模型”,可以使用scala调用sparkMLlib工具包实现机器学习训练,将模型存储在hdfs。
一般会把数据进行分组,训练组和验证组,然后对模型组进行准率的评估。
根据准确率,对模型进行优化:
优化一般主要是三个方面:
把模型投放到实际的标签生产中去观察,比如预测流失的用户,一段时间是否真的会流失。
或者进行A/B测试,对预测的一部分用户采取某种措施,另一部分用户不作处理。观察两组人的变化效果。
通过实际生产中的预测效果,不断的反复调整模型、算法。
机器学习【决策树算法1】
机器学习【决策树算法2】
使用决策树需要解决的问题:
训练 + 预测的完成过程如下图:
create table student
( uid bigint ,
hair string,
height bigint ,
skirt string,
age string ,
gender string
)
insert overwrite table student
values
( 1,'长发' ,155,'是', '80后','女' ),
( 2,'短发' ,156,'否', '90后','女' ),
( 3,'长发' ,157,'是', '00后','女' ),
( 4,'短发' ,158,'否', '80后','女' ),
( 5,'长发' ,159,'是', '90后','女' ),
( 6,'短发' ,160,'否', '00后','女' ),
( 7,'长发' ,161,'否', '80后','女' ),
( 8,'短发' ,162,'否', '90后','女' ),
( 9,'长发' ,163,'是', '00后','女' ),
( 10,'短发' ,164,'否', '80后','女' ),
( 11,'长发' ,165,'是', '90后','女' ),
( 12,'短发' ,166,'否', '00后','女' ),
( 13,'长发' ,167,'是', '80后','女' ),
( 14,'短发' ,168,'否', '90后','女' ),
( 15,'板寸' ,169,'是', '00后','女' ),
( 16,'短发' ,160,'否', '80后','女' ),
( 17,'长发' ,171,'是', '90后','女' ),
( 18,'短发' ,162,'否', '00后','女' ),
( 19,'长发' ,173,'是', '80后','女' ),
( 20,'短发' ,174,'否', '90后','女' ),
( 21,'长发' ,175,'是', '00后','女' ),
( 22,'短发' ,155,'否', '80后','女' ),
( 23,'长发' ,156,'否', '90后','女' ),
( 24,'短发' ,157,'否', '00后','女' ),
( 25,'长发' ,158,'否', '80后','女' ),
( 26,'短发' ,159,'否', '90后','女' ),
( 27,'长发' ,160,'是', '00后','女' ),
( 28,'短发' ,161,'否', '00后','女' ),
( 29,'长发' ,162,'是', '80后','女' ),
( 30,'短发' ,163,'否', '00后','女' ),
( 31,'长发' ,164,'是', '80后','女' ),
( 32,'短发' ,165,'否', '00后','女' ),
( 33,'长发' ,166,'是', '00后','女' ),
( 34,'短发' ,167,'否', '80后','女' ),
( 35,'长发' ,169,'是', '90后','女' ),
( 36,'短发' ,170,'否', '00后','女' ),
( 37,'长发' ,171,'是', '80后','女' ),
( 38,'短发' ,172,'是', '90后','女' ),
( 39,'长发' ,173,'否', '00后','女' ),
( 40,'长发' ,174,'否', '80后','女' ),
( 41,'短发' ,175,'是', '90后','女' ),
( 42,'短发' ,165,'否', '00后','女' ),
( 43,'短发' ,166,'是', '80后','女' ),
( 44,'长发' ,167,'否', '90后','女' ),
( 45,'短发' ,168,'是', '00后','女' ),
( 46,'短发' ,169,'否', '80后','女' ),
( 47,'长发' ,170,'是', '90后','女' ),
( 48,'短发' ,171,'否', '00后','女' ),
( 49,'长发' ,172,'是', '80后','女' ),
( 50,'短发' ,173,'否', '90后','女' ),
( 51,'短发' ,165,'否', '80后','男' ),
( 52,'板寸' ,166,'否', '90后','男' ),
( 51,'短发' ,167,'否', '00后','男' ),
( 52,'板寸' ,168,'否', '80后','男' ),
( 53,'短发' ,169,'否', '90后','男' ),
( 54,'短发' ,170,'否', '00后','男' ),
( 55,'短发' ,171,'否', '80后','男' ),
( 56,'板寸' ,172,'否', '90后','男' ),
( 57,'短发' ,173,'否', '00后','男' ),
( 58,'短发' ,174,'否', '80后','男' ),
( 59,'短发' ,175,'否', '90后','男' ),
( 60,'短发' ,176,'否', '00后','男' ),
( 61,'短发' ,177,'否', '80后','男' ),
( 62,'短发' ,178,'否', '90后','男' ),
( 63,'短发' ,179,'否', '00后','男' ),
( 64,'板寸' ,180,'否', '80后','男' ),
( 65,'短发' ,181,'否', '90后','男' ),
( 66,'短发' ,182,'否', '80后','男' ),
( 67,'短发' ,183,'否', '80后','男' ),
( 68,'短发' ,184,'否', '90后','男' ),
( 69,'短发' ,185,'否', '80后','男' ),
( 70,'短发' ,166,'否', '80后','男' ),
( 71,'短发' ,167,'否', '90后','男' ),
( 72,'板寸' ,168,'否', '00后','男' ),
( 73,'短发' ,169,'否', '80后','男' ),
( 74,'短发' ,170,'否', '90后','男' ),
( 75,'短发' ,171,'否', '00后','男' ),
( 76,'板寸' ,172,'否', '80后','男' ),
( 77,'短发' ,173,'否', '90后','男' ),
( 78,'短发' ,174,'否', '00后','男' ),
( 79,'短发' ,175,'否', '80后','男' ),
( 80,'板寸' ,176,'否', '90后','男' ),
( 81,'短发' ,177,'否', '00后','男' ),
( 82,'短发' ,178,'否', '80后','男' ),
( 83,'短发' ,179,'否', '90后','男' ),
( 84,'短发' ,180,'否', '80后','男' ),
( 85,'短发' ,181,'否', '80后','男' ),
( 86,'板寸' ,182,'否', '90后','男' ),
( 87,'短发' ,183,'否', '00后','男' ),
( 88,'短发' ,184,'否', '80后','男' ),
( 89,'短发' ,185,'否', '90后','男' ),
( 90,'板寸' ,184,'否', '00后','男' ),
( 91,'短发' ,171,'否', '80后','男' ),
( 92,'短发' ,172,'否', '90后','男' ),
( 93,'短发' ,173,'否', '00后','男' ),
( 94,'短发' ,174,'否', '80后','男' ),
( 95,'短发' ,175,'否', '90后','男' ),
( 96,'板寸' ,176,'否', '00后','男' ),
( 97,'短发' ,177,'否', '80后','男' ),
( 98,'板寸' ,178,'否', '90后','男' ),
( 99,'板寸' ,179,'否', '00后','男' ),
( 100,'长发' ,180,'否', '80后','男' ) ,
( 101,'长发' ,155,'是', '80后','女' ),
( 102,'短发' ,156,'否', '90后','女' ),
( 103,'长发' ,157,'是', '00后','女' ),
( 104,'短发' ,158,'否', '80后','女' ),
( 105,'长发' ,159,'是', '90后','女' ),
( 106,'短发' ,160,'否', '00后','女' ),
( 107,'长发' ,161,'否', '80后','女' ),
( 108,'短发' ,162,'否', '90后','女' ),
( 109,'长发' ,163,'是', '00后','女' ),
( 110,'短发' ,164,'否', '80后','女' )
将数据存放到hive中。
在user-profile-task1016下创建task-ml,如下图:
在pom.xml引入依赖
<dependencies>
<dependency>
<groupId>com.hzy.userprofilegroupId>
<artifactId>task-commonartifactId>
<version>1.0-SNAPSHOTversion>
dependency>
<dependency>
<groupId>org.apache.sparkgroupId>
<artifactId>spark-mllib_2.12artifactId>
<version>3.0.0version>
<scope>providedscope>
dependency>
dependencies>
<build>
<plugins>
<plugin>
<groupId>net.alchim31.mavengroupId>
<artifactId>scala-maven-pluginartifactId>
<version>3.4.6version>
<executions>
<execution>
<goals>
<goal>compilegoal>
<goal>testCompilegoal>
goals>
execution>
executions>
plugin>
<plugin>
<groupId>org.apache.maven.pluginsgroupId>
<artifactId>maven-assembly-pluginartifactId>
<version>3.0.0version>
<configuration>
<descriptorRefs>
<descriptorRef>jar-with-dependenciesdescriptorRef>
descriptorRefs>
configuration>
<executions>
<execution>
<id>make-assemblyid>
<phase>packagephase>
<goals>
<goal>singlegoal>
goals>
execution>
executions>
plugin>
plugins>
build>
流水线PipeLine实际上就是执行一些预处理工作,其中
标签索引:参考答案,数据集中的最后一列,将标签值转换为矢量值,也就是将男,女
转换为0,1
。
按照出现概率的大小次序排序,概率越大,矢量越小。
特征聚合:在原始数据中选择特征列,并集合成一列。
特征索引:将特征集合中的原值转换为矢量值,转换规则同标签索引。
需要识别哪些是连续值特征,哪些是离散值特征,具体判断标准:底层会设置一个阈值,高于阈值判断为连续值,否则为离散值,即小于等于。
class MyPipeline {
// 5 用于接收此阶段最终的结果
var pipeline:Pipeline = null
//最大分类树(用于识别连续值特征和分类特征),用于3创建特征索引列
private var maxCategories=5
// 最大分支数
private var maxBins=5
// 最大树深度
private var maxDepth=5
//最小分支包含数据条数
private var minInstancesPerNode=1
//最小分支信息增益
private var minInfoGain=0.0
}
// 用于1 标签索引
var labelColName: String = null
// 从外部注入
def setLabelColName(labelColName: String) : MyPipeline = {
this.labelColName = labelColName
this
}
// 1 创建标签索引
def createLabelIndexer():StringIndexer = {
// 输入的原始数据 结构为DF
val indexer = new StringIndexer()
// 设置输入列和输出列
// 输入列为数据的最后一列,通过外部传递进来
// 输出列与外部数据没有关系,直接固定下来即可
// 最终会在DF中增加一列,名称可以自己设置
indexer.setInputCol(labelColName).setOutputCol("label_index")
indexer
}
// 用于2 特征集合
var featureColNames:Array[String] = null
// 从外部注入
def setFeatureColNames(featureColNames: Array[String]) : MyPipeline = {
this.featureColNames = featureColNames
this
}
// 2 创建特征集合列
def createFeatureAssemble():VectorAssembler = {
val assembler = new VectorAssembler()
// 可以将多个列设置为特征,也可以称为维度,输出列只有一个
assembler.setInputCols(featureColNames).setOutputCol("feature_assemble")
assembler
}
// 3 创建特征索引列
def createFeatureIndexer():VectorIndexer = {
val indexer = new VectorIndexer()
// 特征集合的输出就是特征索引的输入
// 此外还需要设置阈值,用于判断是线性值还是离散值
indexer.setInputCol("feature_assemble").setOutputCol("feature_index").setMaxCategories(maxCategories)
indexer
}
// 4 创建分类器
def createClassifier():DecisionTreeClassifier ={
val classifier = new DecisionTreeClassifier()
// 设置标签列(1),设置特征列(3),设置预测列(自己起名)
classifier.setLabelCol("label_index").setFeaturesCol("feature_index").setPredictionCol("prediction_col")
classifier
}
def init():MyPipeline = {
// StringIndexer、VectorAssembler、VectorIndexer、DecisionTreeClassifier
// 以上四者的父类都是PipelineStage,可以理解为是流水线上的一个环节
// 以上前三者都是这个环节中的工人,最后一个是这三个人的师傅
// 执行此方法,师徒四人就要上岗干活了!
pipeline = new Pipeline().setStages( Array(
createLabelIndexer,
createFeatureAssemble,
createFeatureIndexer,
createClassifier
))
this
}
// 6 训练,得到模型
def train(dataFrame:DataFrame):Unit ={
pipelineModel = pipeline.fit(dataFrame)
}
// 7 预测
def predict(dataFrame: DataFrame):DataFrame ={
val predictedDataFrame1: DataFrame = pipelineModel.transform(dataFrame)
predictedDataFrame1
}
新建类StudentGenderTrain,添加配置文件,如下图
源码如下:
package com.hzy.userprofile.ml.train
import com.hzy.userprofile.ml.pipeline.MyPipeline
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
object StudentGenderTrain {
def main(args: Array[String]): Unit = {
val sparkConf: SparkConf = new SparkConf().setAppName("student_gender_train.app").setMaster("local[*]")
val sparkSession: SparkSession = SparkSession.builder().config(sparkConf).enableHiveSupport().getOrCreate()
println("查询数据")
// 1 查询数据
val sql =
s"""
| select
| uid,
| case hair when '长发' then 101 when '短发' then 102 when '板寸' then 103 end as hair,
| height,
| case skirt when '是' then 111 when '否' then 222 end as skirt,
| case age when '00后' then 100 when '90后' then 90 when '80后' then 80 end as age,
| gender
| from
| default.student
|""".stripMargin
println(sql)
val dataFrame: DataFrame = sparkSession.sql(sql)
println("切分数据")
// 2 切分数据:训练集和测试集(82 或 73)
val Array(trainDF,testDF) = dataFrame.randomSplit(Array(0.8,0.2))
println("创建myPipeine")
// 3 创建myPipeine
val myPipeline: MyPipeline = new MyPipeline()
.setLabelColName("gender")
.setFeatureColNames(Array("hair","height","skirt","age"))
.init()
println("进行训练")
// 4 进行训练
myPipeline.train(trainDF)
println("进行预测")
// 5 进行预测
val predictedDataFrame: DataFrame = myPipeline.predict(testDF)
println("打印预测结果")
// 6 打印预测结果
predictedDataFrame.show(100,false)
}
}
运行之前需要配置hadoop用户名,集体结果分析如下:
package com.hzy.userprofile.ml.pipeline
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler, VectorIndexer}
import org.apache.spark.sql.DataFrame
class MyPipeline {
// 5 用于接收此阶段最终的结果
var pipeline:Pipeline = null
def init():MyPipeline = {
// StringIndexer、VectorAssembler、VectorIndexer、DecisionTreeClassifier
// 以上四者的父类都是PipelineStage,可以理解为是流水线上的一个环节
// 以上前三者都是这个环节中的工人,最后一个是这三个人的师傅
// 执行此方法,师徒四人就要上岗干活了!
pipeline = new Pipeline().setStages( Array(
createLabelIndexer(),
createFeatureAssemble(),
createFeatureIndexer(),
createClassifier()
))
this
}
// 模型:通过训练得来
var pipelineModel:PipelineModel = null
//最大分类树(用于识别连续值特征和分类特征),用于3创建特征索引列
private var maxCategories=5
// 最大分支数
private var maxBins=5
// 最大树深度
private var maxDepth=5
//最小分支包含数据条数
private var minInstancesPerNode=1
//最小分支信息增益
private var minInfoGain=0.0
// 用于1 标签索引
var labelColName: String = null
// 用于2 特征集合
var featureColNames:Array[String] = null
// 从外部注入
def setLabelColName(labelColName: String) : MyPipeline = {
this.labelColName = labelColName
this
}
// 从外部注入
def setFeatureColNames(featureColNames: Array[String]) : MyPipeline = {
this.featureColNames = featureColNames
this
}
// 1 创建标签索引
def createLabelIndexer():StringIndexer = {
// 输入的原始数据 结构为DF
val indexer = new StringIndexer()
// 设置输入列和输出列
// 输入列为数据的最后一列,通过外部传递进来
// 输出列与外部数据没有关系,直接固定下来即可
// 最终会在DF中增加一列,名称可以自己设置
indexer.setInputCol(labelColName).setOutputCol("label_index")
indexer
}
// 2 创建特征集合列
def createFeatureAssemble():VectorAssembler = {
val assembler = new VectorAssembler()
// 可以将多个列设置为特征,也可以称为维度,输出列只有一个
assembler.setInputCols(featureColNames).setOutputCol("feature_assemble")
assembler
}
// 3 创建特征索引列
def createFeatureIndexer():VectorIndexer = {
val indexer = new VectorIndexer()
// 特征集合的输出就是特征索引的输入
// 此外还需要设置阈值,用于判断是线性值还是离散值
indexer.setInputCol("feature_assemble").setOutputCol("feature_index").setMaxCategories(maxCategories)
indexer
}
// 4 创建分类器
def createClassifier():DecisionTreeClassifier ={
val classifier = new DecisionTreeClassifier()
// 设置标签列(1),设置特征列(3),设置预测列(自己起名)
classifier.setLabelCol("label_index").setFeaturesCol("feature_index").setPredictionCol("prediction_col")
classifier
}
// 6 训练,得到模型
def train(dataFrame:DataFrame):Unit ={
pipelineModel = pipeline.fit(dataFrame)
}
// 7 预测
def predict(dataFrame: DataFrame):DataFrame ={
val predictedDataFrame1: DataFrame = pipelineModel.transform(dataFrame)
predictedDataFrame1
}
}
package com.hzy.userprofile.ml.train
import com.hzy.userprofile.ml.pipeline.MyPipeline
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
object StudentGenderTrain {
def main(args: Array[String]): Unit = {
val sparkConf: SparkConf = new SparkConf().setAppName("student_gender_train.app").setMaster("local[*]")
val sparkSession: SparkSession = SparkSession.builder().config(sparkConf).enableHiveSupport().getOrCreate()
println("查询数据")
// 1 查询数据
val sql =
s"""
| select
| uid,
| case hair when '长发' then 101 when '短发' then 102 when '板寸' then 103 end as hair,
| height,
| case skirt when '是' then 111 when '否' then 222 end as skirt,
| case age when '00后' then 100 when '90后' then 90 when '80后' then 80 end as age,
| gender
| from
| default.student
|""".stripMargin
println(sql)
val dataFrame: DataFrame = sparkSession.sql(sql)
println("切分数据")
// 2 切分数据:训练集和测试集(82 或 73)
val Array(trainDF,testDF) = dataFrame.randomSplit(Array(0.8,0.2))
println("创建myPipeine")
// 3 创建myPipeine
val myPipeline: MyPipeline = new MyPipeline()
.setLabelColName("gender")
.setFeatureColNames(Array("hair","height","skirt","age"))
.init()
println("进行训练")
// 4 进行训练
myPipeline.train(trainDF)
println("进行预测")
// 5 进行预测
val predictedDataFrame: DataFrame = myPipeline.predict(testDF)
println("打印预测结果")
// 6 打印预测结果
predictedDataFrame.show(100,false)
}
}