PySpark E-commerce Ad Recommendation (4)

Offline Recommendation

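Goal: here we use the ALS model trained earlier for collaborative-filtering recall. Note, however, that our ALS model recalls the categories a user is most interested in, while what we need is the set of ads the user may be interested in. We therefore still have to match the recalled categories to their ads. So, in addition to the trained ALS model, we also need a mapping between ads and categories.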
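The two-stage idea (score categories, then expand them into ads) can be sketched in plain Python, where each ad simply inherits the predicted score of its category. This is an illustration, not the method used later in this post: `recall_ads` is a hypothetical helper, the ad/category pairs are the first rows of `ad_feature.csv` shown further down, and the category scores are made-up values rather than output of the trained model.

```python
from math import isnan


def recall_ads(cate_scores, ad_to_cate, top_n):
    """Rank ads by the predicted score of their category (toy two-stage recall).

    cate_scores: {cateId: predicted score}; NaN means the category was unseen in training.
    ad_to_cate:  {adgroupId: cateId}
    """
    scored = [
        (cate_scores[cate], ad)
        for ad, cate in ad_to_cate.items()
        if cate in cate_scores and not isnan(cate_scores[cate])
    ]
    # Highest category score first; ties broken by adgroupId for a stable order
    scored.sort(key=lambda t: (-t[0], t[1]))
    return [ad for _, ad in scored[:top_n]]


# Hypothetical scores; the ad -> category pairs come from ad_feature.csv
cate_scores = {6406: 0.31, 392: 0.12, 7211: float("nan")}
ad_to_cate = {63133: 6406, 313401: 6406, 248909: 392, 208458: 392, 110847: 7211}
print(recall_ads(cate_scores, ad_to_cate, top_n=3))  # [63133, 313401, 208458]
```

Ad 110847 is never returned because its category has a NaN score, which mirrors the `dropna()` step applied to the ALS predictions.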

Building the ad-to-category table

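from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, FloatType

# Reuses the SparkSession from the earlier parts if one is already running
spark = SparkSession.builder.getOrCreate()

# Load the basic ad information data
df = spark.read.csv("data/ad_feature.csv", header=True)

# Note: this dataset contains the literal string "NULL", so we cannot simply apply a schema;
# we first replace the NULL values and then cast the column types

# Replace the "NULL" strings with "-1"
df = df.replace("NULL", "-1")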

# Restructure df: cast the column types and rename the columns
ad_feature_df = df.\
    withColumn("adgroup_id", df.adgroup_id.cast(IntegerType())).withColumnRenamed("adgroup_id", "adgroupId").\
    withColumn("cate_id", df.cate_id.cast(IntegerType())).withColumnRenamed("cate_id", "cateId").\
    withColumn("campaign_id", df.campaign_id.cast(IntegerType())).withColumnRenamed("campaign_id", "campaignId").\
    withColumn("customer", df.customer.cast(IntegerType())).withColumnRenamed("customer", "customerId").\
    withColumn("brand", df.brand.cast(IntegerType())).withColumnRenamed("brand", "brandId").\
    withColumn("price", df.price.cast(FloatType()))

# Here we only need adgroupId and cateId
ad_cate_df = ad_feature_df.select("adgroupId", "cateId")
# This table is small, so we convert it to a pandas DataFrame and work on it in memory
pdf = ad_cate_df.toPandas()
pdf.head()
   adgroupId  cateId
0      63133    6406
1     313401    6406
2     248909     392
3     208458     392
4     110847    7211
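Every later step asks "which ads belong to category X?", so it can pay off to turn the (adgroupId, cateId) pairs into an inverted index once, instead of filtering the DataFrame on every lookup. A minimal plain-Python sketch; `build_cate_index` is a hypothetical helper name and the input pairs are the five rows shown above.

```python
from collections import defaultdict


def build_cate_index(pairs):
    """Map each cateId to the list of adgroupIds that belong to it."""
    index = defaultdict(list)
    for adgroup_id, cate_id in pairs:
        index[cate_id].append(adgroup_id)
    return dict(index)


pairs = [(63133, 6406), (313401, 6406), (248909, 392), (208458, 392), (110847, 7211)]
cate_index = build_cate_index(pairs)
print(cate_index)
# {6406: [63133, 313401], 392: [248909, 208458], 7211: [110847]}
```

With pandas, `pdf.groupby("cateId").adgroupId.apply(list).to_dict()` builds the same kind of mapping directly from `pdf`.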

Recommending ads

import numpy as np
import pandas as pd
from pyspark.ml.recommendation import ALSModel

als_model = ALSModel.load("data/ALS_Cate_model.obj/")
print("Item/user columns:", (als_model.getItemCol(), als_model.getUserCol()))

# Score every ad category for user 427
cateId_df = pd.DataFrame(pdf.cateId.unique(), columns=["cateId"])
cateId_df.insert(0, "userId", 427)
result = als_model.transform(spark.createDataFrame(cateId_df))  # predict
# Categories not seen during training get NaN predictions, so drop them
result.sort("prediction", ascending=False).dropna().show()
Item/user columns: ('cateId', 'userId')
+------+------+----------+
|userId|cateId|prediction|
+------+------+----------+
|   427|  6978|0.31386238|
|   427|  7422|0.28090417|
|   427|  4433|0.20575814|
|   427|  3015|0.19531126|
|   427|  5358|0.19412701|
|   427| 11752|0.19253506|
|   427|  4589|0.18606065|
|   427|   820|0.18226622|
|   427|  5375|0.17778711|
|   427|   474|0.17610353|
|   427|  6294|0.17403731|
|   427|  4815| 0.1676827|
|   427|  5751|0.16453266|
|   427|  4798|0.16450714|
|   427| 11486|0.16103375|
|   427|  8799|0.16067055|
|   427|  4432|0.15562129|
|   427|     7|0.15313935|
|   427|  5523|0.15289858|
|   427|  9383|0.15233861|
+------+------+----------+
only showing top 20 rows
# User latent factors
als_model.userFactors.show(10, truncate=False)
+----+--------------------------------------------------------------------------------------------------------------------------------------+
|id  |features                                                                                                                              |
+----+--------------------------------------------------------------------------------------------------------------------------------------+
|427 |[-0.08862011, 0.024677375, 0.12652709, 0.03105228, -0.015887666, -0.076577045, -0.10301663, 0.020497091, 0.014156798, -0.093687914]   |
|437 |[0.08121957, -0.07566206, -0.08685034, -0.16820733, -0.21639597, 0.10768882, -0.13673659, -0.09585375, 0.082970485, 0.14766619]       |
|747 |[0.031304367, 0.09682198, 0.1347598, 0.034188993, 0.13986087, 0.12803291, -0.16679932, -0.104109496, 0.13648869, 0.23043962]          |
|1247|[0.22964337, -0.26408425, 1.6565194, -0.19833353, 0.52432746, 0.08974882, 0.33008733, 1.3145455, -0.34737256, 0.54083633]             |
|2177|[-0.006012235, 0.11970336, 0.21558432, -0.15280657, 0.0086795185, -0.067469634, -0.038556375, 0.008515552, -0.07740968, 0.05015077]   |
|2247|[-0.06761185, -0.40897173, -0.14313649, 0.20925727, 0.028727109, -0.113159336, 0.15228641, -0.058656577, 0.12894748, -0.011355329]    |
|2527|[0.08582949, -0.07457629, -0.06486175, 0.02869886, -0.06862812, 0.029123373, -0.07809314, 0.018493466, 0.06554801, 0.01668032]        |
|2827|[1.377811, -0.69850916, 2.0206037, -0.5936584, -0.69637454, -0.030831939, 1.3373528, 0.076208815, -0.36765805, 0.19733378]            |
|2877|[-0.05168787, 0.056770574, -0.0032576998, -0.04482137, -0.046533227, -0.014454543, -0.1390726, 0.11489638, -0.096733324, -0.038567316]|
|2947|[0.75573343, -0.43991256, 0.15535285, -0.13316123, 0.656812, -0.10289777, -0.29343438, 0.47792152, -0.069517024, -0.66402215]         |
+----+--------------------------------------------------------------------------------------------------------------------------------------+
only showing top 10 rows
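In Spark's ALS, the `prediction` column is the dot product of a user's factor vector (the `userFactors` rows above) and an item's factor vector (`itemFactors`). When either id was not seen during training, the default cold-start strategy (`"nan"`) yields NaN, which is why the predictions above go through `dropna()`. A plain-Python sketch of that scoring rule; `als_predict` is a hypothetical helper, and the rank-3 vectors are toy values (the real model uses the 10-dimensional vectors shown above).

```python
import math


def als_predict(user_factors, item_factors, user_id, item_id):
    """Score = dot(user vector, item vector); NaN if either id is unknown (cold start)."""
    u = user_factors.get(user_id)
    v = item_factors.get(item_id)
    if u is None or v is None:
        return math.nan
    return sum(a * b for a, b in zip(u, v))


# Toy factors (hypothetical, rank 3)
user_factors = {427: [0.5, -1.0, 2.0]}
item_factors = {10: [1.0, 0.5, 0.25], 20: [0.0, -1.0, 1.0]}

print(als_predict(user_factors, item_factors, 427, 10))     # 0.5
print(als_predict(user_factors, item_factors, 427, 20))     # 3.0
print(als_predict(user_factors, item_factors, 427, 99999))  # nan
```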
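# Recommend ads for all users
# Built once, outside the loop: the unique categories and a cateId -> [adgroupId, ...] lookup
all_cates = pdf.cateId.unique()
cate_to_ads = pdf.groupby("cateId").adgroupId.apply(list).to_dict()

for r in als_model.userFactors.select("id").collect():
    userId = r.id
    cateId_df = pd.DataFrame({"userId": userId, "cateId": all_cates})

    ret = set()
    # Score every category for this user; unseen categories (NaN) are dropped
    cateId_list = als_model.transform(spark.createDataFrame(cateId_df)) \
        .dropna().sort("prediction", ascending=False)

    # Draw candidate ads from the top 20 categories until we have 500
    for row in cateId_list.head(20):
        need = 500 - len(ret)  # how many ads are still missing
        ads = cate_to_ads.get(row.cateId, [])
        # Sample without replacement, never more ads than the category holds
        picked = np.random.choice(ads, min(need, len(ads)), replace=False)
        ret.update(int(a) for a in picked)
        if len(ret) >= 500:
            break
    # Test only: stop after the first user instead of computing recommendations for everyone
    break

(userId, *ret)[:10]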
(427, 2051, 367108, 101892, 18950, 716295, 63502, 12309, 12310, 2074)
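The candidate-filling step can be reduced to a small pure-Python function, which makes the quota logic easy to test in isolation: rank categories by score, skip NaN scores, keep the top `top_k`, and sample ads without replacement until `quota` is reached. A sketch under assumptions: `fill_candidates` and its parameters are hypothetical names, `random.sample` plays the role of `np.random.choice` without replacement, and each ad is assumed to belong to exactly one category.

```python
import math
import random


def fill_candidates(cate_scores, cate_index, top_k=20, quota=500, seed=None):
    """Collect up to `quota` ad ids from the `top_k` highest-scoring categories.

    cate_scores: {cateId: predicted score}, NaN for categories unseen in training
    cate_index:  {cateId: [adgroupId, ...]}
    """
    rng = random.Random(seed)
    ranked = sorted(
        (c for c, s in cate_scores.items() if not math.isnan(s)),
        key=lambda c: cate_scores[c],
        reverse=True,
    )[:top_k]
    chosen = []
    for cate in ranked:
        need = quota - len(chosen)
        if need <= 0:
            break
        ads = cate_index.get(cate, [])
        # Sample without replacement, never more ads than the category holds
        chosen.extend(rng.sample(ads, min(need, len(ads))))
    return chosen


# Toy data (hypothetical)
cate_scores = {1: 0.9, 2: 0.5, 3: float("nan"), 4: 0.1}
cate_index = {1: [11, 12], 2: [21, 22, 23], 3: [31], 4: [41]}
ads = fill_candidates(cate_scores, cate_index, top_k=2, quota=4, seed=0)
print(len(ads))  # 4
```

Keeping a list rather than a set preserves the order of the category ranking, so the first ads in the result come from the highest-scoring category.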

