diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index 1482eb3d1f7a6..d6b847a7770b0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -54,6 +54,29 @@ private[feature] trait ChiSqSelectorParams extends Params /** @group getParam */ def getNumTopFeatures: Int = $(numTopFeatures) + + final val percentile = new DoubleParam(this, "percentile", + "Percentile of features that selector will select, ordered by statistics value descending.", + ParamValidators.gtEq(0)) + setDefault(percentile -> 10) + + /** @group getParam */ + def getPercentile: Double = $(percentile) + + final val alpha = new DoubleParam(this, "alpha", + "The highest p-value for features to be kept.", + ParamValidators.gtEq(0)) + setDefault(alpha -> 0.05) + + /** @group getParam */ + def getAlpha: Double = $(alpha) + + final val selectorType = new Param[String](this, "selectorType", + "ChiSqSelector Type: KBest, Percentile, Fpr") + setDefault(selectorType -> "KBest") + + /** @group getParam */ + def getChiSqSelectorType: String = $(selectorType) } /** @@ -67,9 +90,27 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str @Since("1.6.0") def this() = this(Identifiable.randomUID("chiSqSelector")) + @Since("2.1.0") + var chiSqSelector: feature.ChiSqSelector = null + /** @group setParam */ - @Since("1.6.0") - def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value) + @Since("2.1.0") + def setNumTopFeatures(value: Int): this.type = { + set(selectorType, "KBest") + set(numTopFeatures, value) + } + + @Since("2.1.0") + def setPercentile(value: Double): this.type = { + set(selectorType, "Percentile") + set(percentile, value) + } + + @Since("2.1.0") + def setAlpha(value: Double): this.type = { + set(selectorType, "Fpr") + set(alpha, value) + } /** @group 
setParam */ @Since("1.6.0") @@ -91,8 +132,38 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str case Row(label: Double, features: Vector) => OldLabeledPoint(label, OldVectors.fromML(features)) } - val chiSqSelector = new feature.ChiSqSelector($(numTopFeatures)).fit(input) - copyValues(new ChiSqSelectorModel(uid, chiSqSelector).setParent(this)) + $(selectorType) match { + case "KBest" => + chiSqSelector = new feature.ChiSqSelector().setNumTopFeatures($(numTopFeatures)) + case "Percentile" => + chiSqSelector = new feature.ChiSqSelector().setPercentile($(percentile)) + case "Fpr" => + chiSqSelector = new feature.ChiSqSelector().setAlpha($(alpha)) + case _ => throw new Exception("Unknown ChiSqSelector Type.") + } + val model = chiSqSelector.fit(input) + copyValues(new ChiSqSelectorModel(uid, model).setParent(this)) + } + + @Since("2.1.0") + def selectKBest(value: Int): ChiSqSelectorModel = { + require(chiSqSelector != null, "ChiSqSelector has not been created.") + val model = chiSqSelector.selectKBest(value) + copyValues(new ChiSqSelectorModel(uid, model).setParent(this)) + } + + @Since("2.1.0") + def selectPercentile(value: Double): ChiSqSelectorModel = { + require(chiSqSelector != null, "ChiSqSelector has not been created.") + val model = chiSqSelector.selectPercentile(value) + copyValues(new ChiSqSelectorModel(uid, model).setParent(this)) + } + + @Since("2.1.0") + def selectFpr(value: Double): ChiSqSelectorModel = { + require(chiSqSelector != null, "ChiSqSelector has not been created.") + val model = chiSqSelector.selectFpr(value) + copyValues(new ChiSqSelectorModel(uid, model).setParent(this)) } @Since("1.6.0") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index 56fb2d33c2ca0..1c3b49a04b843 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ 
b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -27,22 +27,27 @@ import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.mllib.stat.test.ChiSqTestResult import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext import org.apache.spark.sql.{Row, SparkSession} +@Since("2.1.0") +object ChiSqSelectorType extends Enumeration { + type SelectorType = Value + val KBest, Percentile, Fpr = Value +} + /** * Chi Squared selector model. * - * @param selectedFeatures list of indices to select (filter). Must be ordered asc + * @param selectedFeatures list of indices to select (filter). */ @Since("1.3.0") class ChiSqSelectorModel @Since("1.3.0") ( @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable { - require(isSorted(selectedFeatures), "Array has to be sorted asc") - protected def isSorted(array: Array[Int]): Boolean = { var i = 1 val len = array.length @@ -69,21 +74,22 @@ class ChiSqSelectorModel @Since("1.3.0") ( * Preserves the order of filtered features the same as their indices are stored. 
* Might be moved to Vector as .slice * @param features vector - * @param filterIndices indices of features to filter, must be ordered asc + * @param filterIndices indices of features to filter */ private def compress(features: Vector, filterIndices: Array[Int]): Vector = { + val orderedIndices = filterIndices.sorted features match { case SparseVector(size, indices, values) => - val newSize = filterIndices.length + val newSize = orderedIndices.length val newValues = new ArrayBuilder.ofDouble val newIndices = new ArrayBuilder.ofInt var i = 0 var j = 0 var indicesIdx = 0 var filterIndicesIdx = 0 - while (i < indices.length && j < filterIndices.length) { + while (i < indices.length && j < orderedIndices.length) { indicesIdx = indices(i) - filterIndicesIdx = filterIndices(j) + filterIndicesIdx = orderedIndices(j) if (indicesIdx == filterIndicesIdx) { newIndices += j newValues += values(i) @@ -101,7 +107,7 @@ class ChiSqSelectorModel @Since("1.3.0") ( Vectors.sparse(newSize, newIndices.result(), newValues.result()) case DenseVector(values) => val values = features.toArray - Vectors.dense(filterIndices.map(i => values(i))) + Vectors.dense(orderedIndices.map(i => values(i))) case other => throw new UnsupportedOperationException( s"Only sparse and dense vectors are supported but got ${other.getClass}.") @@ -171,14 +177,47 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { /** * Creates a ChiSquared feature selector. - * @param numTopFeatures number of features that selector will select - * (ordered by statistic value descending) - * Note that if the number of features is less than numTopFeatures, - * then this will select all features. 
*/ -@Since("1.3.0") -class ChiSqSelector @Since("1.3.0") ( - @Since("1.3.0") val numTopFeatures: Int) extends Serializable { +@Since("2.1.0") +class ChiSqSelector @Since("2.1.0") () extends Serializable { + private var numTopFeatures: Int = 50 + private var percentile: Double = 10.0 + private var alpha: Double = 0.05 + private var selectorType = ChiSqSelectorType.KBest + private var chiSqTestResult: Array[ChiSqTestResult] = _ + + @Since("1.3.0") + def this(numTopFeatures: Int) { + this() + this.numTopFeatures = numTopFeatures + } + + @Since("2.1.0") + def setNumTopFeatures(value: Int): this.type = { + numTopFeatures = value + selectorType = ChiSqSelectorType.KBest + this + } + + @Since("2.1.0") + def setPercentile(value: Double): this.type = { + percentile = value + selectorType = ChiSqSelectorType.Percentile + this + } + + @Since("2.1.0") + def setAlpha(value: Double): this.type = { + alpha = value + selectorType = ChiSqSelectorType.Fpr + this + } + + @Since("2.1.0") + def setChiSqSelectorType(value: ChiSqSelectorType.Value): this.type = { + selectorType = value + this + } /** * Returns a ChiSquared feature selector. 
@@ -189,11 +228,35 @@ class ChiSqSelector @Since("1.3.0") ( */ @Since("1.3.0") def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = { - val indices = Statistics.chiSqTest(data) - .zipWithIndex.sortBy { case (res, _) => -res.statistic } - .take(numTopFeatures) - .map { case (_, indices) => indices } - .sorted + chiSqTestResult = Statistics.chiSqTest(data) + selectorType match { + case ChiSqSelectorType.KBest => selectKBest(numTopFeatures) + case ChiSqSelectorType.Percentile => selectPercentile(percentile) + case ChiSqSelectorType.Fpr => selectFpr(alpha) + case _ => throw new IllegalArgumentException("Unknown ChiSqSelector Type") + } + } + + @Since("2.1.0") + def selectKBest(value: Int): ChiSqSelectorModel = { + val indices = chiSqTestResult.zipWithIndex.sortBy { case (res, _) => -res.statistic } + .take(value) + .map { case (_, indices) => indices } + new ChiSqSelectorModel(indices) + } + + @Since("2.1.0") + def selectPercentile(value: Double): ChiSqSelectorModel = { + val indices = chiSqTestResult.zipWithIndex.sortBy { case (res, _) => -res.statistic } + .take((chiSqTestResult.length * value / 100).toInt) + .map { case (_, indices) => indices } + new ChiSqSelectorModel(indices) + } + + @Since("2.1.0") + def selectFpr(value: Double): ChiSqSelectorModel = { + val indices = chiSqTestResult.zipWithIndex.filter{ case (res, _) => res.pValue < value } + .map { case (_, indices) => indices } new ChiSqSelectorModel(indices) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index 3558290b23ae0..a29ff83ae0cce 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -49,16 +49,23 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext .map(x => (x._1.label, x._1.features, x._2)) .toDF("label", "data", "preFilteredData") - val
model = new ChiSqSelector() + val selector = new ChiSqSelector() .setNumTopFeatures(1) .setFeaturesCol("data") .setLabelCol("label") .setOutputCol("filtered") - model.fit(df).transform(df).select("filtered", "preFilteredData").collect().foreach { + selector.fit(df).transform(df).select("filtered", "preFilteredData").collect().foreach { case Row(vec1: Vector, vec2: Vector) => assert(vec1 ~== vec2 absTol 1e-1) } + + selector.selectPercentile(34).transform(df) + .select("filtered", "preFilteredData").collect().foreach { + case Row(vec1: Vector, vec2: Vector) => + assert(vec1 ~== vec2 absTol 1e-1) + } + } test("ChiSqSelector read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala index 734800a9afad6..e181a544f7159 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala @@ -65,6 +65,24 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { assert(filteredData == preFilteredData) } + test("ChiSqSelector by FPR transform test (sparse & dense vector)") { + val labeledDiscreteData = sc.parallelize( + Seq(LabeledPoint(0.0, Vectors.sparse(4, Array((0, 8.0), (1, 7.0)))), + LabeledPoint(1.0, Vectors.sparse(4, Array((1, 9.0), (2, 6.0), (3, 4.0)))), + LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0, 4.0))), + LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0, 9.0)))), 2) + val preFilteredData = + Set(LabeledPoint(0.0, Vectors.dense(Array(0.0))), + LabeledPoint(1.0, Vectors.dense(Array(4.0))), + LabeledPoint(1.0, Vectors.dense(Array(4.0))), + LabeledPoint(2.0, Vectors.dense(Array(9.0)))) + val model = new ChiSqSelector().setAlpha(0.1).fit(labeledDiscreteData) + val filteredData = labeledDiscreteData.map { lp => + LabeledPoint(lp.label, model.transform(lp.features)) + }.collect().toSet + assert(filteredData == 
preFilteredData) + } + test("model load / save") { val model = ChiSqSelectorSuite.createModel() val tempDir = Utils.createTempDir()