Skip to content

Latest commit

 

History

History
57 lines (46 loc) · 2.3 KB

File metadata and controls

57 lines (46 loc) · 2.3 KB

RFormula

  RFormula通过一个R model formula选择一个特定的列。 目前我们支持R算子的一个受限的子集,包括~,.,:,+,-。这些基本的算子是:

  • ~ 分开targetterms
  • + 连接term,+ 0表示删除截距(intercept)
  • - 删除term,- 1表示删除截距
  • : 交集
  • . 除了target之外的所有列

  假设abdouble列,我们用下面简单的例子来证明RFormula的有效性。

  • y ~ a + b 表示模型 y ~ w0 + w1 * a + w2 * b,其中w0是截距,w1w2是系数
  • y ~ a + b + a:b - 1表示模型y ~ w1 * a + w2 * b + w3 * a * b,其中w1,w2,w3是系数

  RFormula产生一个特征向量列和一个doublestring类型的标签列。比如在线性回归中使用R中的公式时, 字符串输入列是one-hot编码,数值列强制转换为double类型。如果标签列是字符串类型,它将使用StringIndexer转换为double 类型。如果DataFrame中不存在标签列,输出的标签列将通过公式中指定的返回变量来创建。

例子

  假设我们有一个DataFrame,它的列名是id, country, hourclicked

id | country | hour | clicked
---|---------|------|---------
 7 | "US"    | 18   | 1.0
 8 | "CA"    | 12   | 0.0
 9 | "NZ"    | 15   | 0.0

  如果我们用clicked ~ country + hour(基于countryhour来预测clicked)来作用于RFormula,将会得到下面的结果。

id | country | hour | clicked | features         | label
---|---------|------|---------|------------------|-------
 7 | "US"    | 18   | 1.0     | [0.0, 0.0, 18.0] | 1.0
 8 | "CA"    | 12   | 0.0     | [0.0, 1.0, 12.0] | 0.0
 9 | "NZ"    | 15   | 0.0     | [1.0, 0.0, 15.0] | 0.0

  下面是代码调用的例子。

import org.apache.spark.ml.feature.RFormula

val dataset = spark.createDataFrame(Seq(
  (7, "US", 18, 1.0),
  (8, "CA", 12, 0.0),
  (9, "NZ", 15, 0.0)
)).toDF("id", "country", "hour", "clicked")
val formula = new RFormula()
  .setFormula("clicked ~ country + hour")
  .setFeaturesCol("features")
  .setLabelCol("label")
val output = formula.fit(dataset).transform(dataset)
output.select("features", "label").show()