spark ml解决数据不均衡的简单方法

来源:
https://stackoverflow.com/questions/33372838/dealing-with-unbalanced-datasets-in-spark-mllib

数据不均衡有很多种解决方法,这里给一个比较简单的。
给dataframe增加一列权重列。
使用spark lr模型的时候设置参数setWeightCol填入权重列。

def balanceDataset(dataset: DataFrame): DataFrame = {

    // Re-balancing (weighting) of records to be used in the logistic loss objective function
    val numNegatives = dataset.filter(dataset("label") === 0).count
    val datasetSize = dataset.count
    val balancingRatio = (datasetSize - numNegatives).toDouble / datasetSize

    val calculateWeights = udf { d: Double =>
      if (d == 0.0) {
        1 * balancingRatio
      }
      else {
        (1 * (1.0 - balancingRatio))
      }
    }

    val weightedDataset = dataset.withColumn("classWeightCol", calculateWeights(dataset("label")))
    weightedDataset
  }

  val df_weighted = balanceDataset(df)
  val lr = new LogisticRegression().setLabelCol(labelCol).setWeightCol("classWeightCol")

你可能感兴趣的:(spark ml解决数据不均衡的简单方法)