Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,29 @@ private[feature] trait ChiSqSelectorParams extends Params

/** @group getParam */
def getNumTopFeatures: Int = $(numTopFeatures)

/**
 * Percentile (0-100) of features that the selector will keep, ranked by
 * chi-squared statistic descending (used by the "Percentile" selector type).
 * Bounded to [0, 100] because the selector computes `numFeatures * percentile / 100`.
 * @group param
 */
final val percentile = new DoubleParam(this, "percentile",
  "Percentile of features that selector will select, ordered by statistics value descending.",
  ParamValidators.inRange(0, 100))

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it still okay when percentile is 0?

setDefault(percentile -> 10)

/** @group getParam */
def getPercentile: Double = $(percentile)

/**
 * The highest p-value for a feature to be kept (used by the "Fpr" selector
 * type). A p-value threshold is only meaningful in [0, 1], so the validator
 * enforces that range rather than merely >= 0. Default: 0.05.
 * @group param
 */
final val alpha = new DoubleParam(this, "alpha",
  "The highest p-value for features to be kept.",
  ParamValidators.inRange(0, 1))
setDefault(alpha -> 0.05)

/** @group getParam */
def getAlpha: Double = $(alpha)

/**
 * The feature-selection strategy: one of "KBest", "Percentile", "Fpr".
 * Validated at set-time via `inArray` so an unsupported value fails
 * immediately instead of throwing later during `fit`. Default: "KBest".
 * @group param
 */
final val selectorType = new Param[String](this, "selectorType",
  "ChiSqSelector Type: KBest, Percentile, Fpr",
  ParamValidators.inArray[String](Array("KBest", "Percentile", "Fpr")))
setDefault(selectorType -> "KBest")

/** @group getParam */
def getChiSqSelectorType: String = $(selectorType)
}

/**
Expand All @@ -67,9 +90,27 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str
@Since("1.6.0")
def this() = this(Identifiable.randomUID("chiSqSelector"))

// NOTE(review): publicly visible mutable state; stays null until fit() builds
// the underlying mllib selector, and the select* helpers require() on it.
// Consider Option[...] / private visibility — kept as-is for interface safety.
@Since("2.1.0")
var chiSqSelector: feature.ChiSqSelector = null

/**
 * Sets the number of top features to keep and switches the selector type
 * to "KBest". (The span previously contained two conflicting definitions of
 * this method — stale diff residue; only the current one is kept.)
 * @group setParam
 */
@Since("1.6.0")
def setNumTopFeatures(value: Int): this.type = {
  set(selectorType, "KBest")
  set(numTopFeatures, value)
}

/**
 * Sets the percentile of features to keep and switches the selector type
 * to "Percentile".
 * @group setParam
 */
@Since("2.1.0")
def setPercentile(value: Double): this.type = {
  set(percentile, value)
  set(selectorType, "Percentile")
  this
}

/**
 * Sets the p-value threshold and switches the selector type to "Fpr"
 * (false-positive-rate based selection).
 * @group setParam
 */
@Since("2.1.0")
def setAlpha(value: Double): this.type = {
  set(alpha, value)
  set(selectorType, "Fpr")
  this
}

/** @group setParam */
@Since("1.6.0")
Expand All @@ -91,8 +132,38 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str
case Row(label: Double, features: Vector) =>
OldLabeledPoint(label, OldVectors.fromML(features))
}
val chiSqSelector = new feature.ChiSqSelector($(numTopFeatures)).fit(input)
copyValues(new ChiSqSelectorModel(uid, chiSqSelector).setParent(this))
$(selectorType) match {
case "KBest" =>
chiSqSelector = new feature.ChiSqSelector().setNumTopFeatures($(numTopFeatures))
case "Percentile" =>
chiSqSelector = new feature.ChiSqSelector().setPercentile($(percentile))
case "Fpr" =>
chiSqSelector = new feature.ChiSqSelector().setAlpha($(alpha))
case _ => throw new Exception("Unknown ChiSqSelector Type.")
}
val model = chiSqSelector.fit(input)
copyValues(new ChiSqSelectorModel(uid, model).setParent(this))
}

@Since("2.1.0")
def selectKBest(value: Int): ChiSqSelectorModel = {
require(chiSqSelector != null, "ChiSqSelector has not been created.")
val model = chiSqSelector.selectKBest(value)
copyValues(new ChiSqSelectorModel(uid, model).setParent(this))
}

/**
 * Builds a model keeping the given percentile of features, reusing the
 * statistics computed by the most recent call to `fit` (which must have run
 * first).
 */
@Since("2.1.0")
def selectPercentile(value: Double): ChiSqSelectorModel = {
  require(chiSqSelector != null, "ChiSqSelector has not been created.")
  val selected = chiSqSelector.selectPercentile(value)
  copyValues(new ChiSqSelectorModel(uid, selected).setParent(this))
}

/**
 * Builds a model keeping features whose p-value is below `value`, reusing
 * the statistics computed by the most recent call to `fit` (which must have
 * run first).
 */
@Since("2.1.0")
def selectFpr(value: Double): ChiSqSelectorModel = {
  require(chiSqSelector != null, "ChiSqSelector has not been created.")
  val selected = chiSqSelector.selectFpr(value)
  copyValues(new ChiSqSelectorModel(uid, selected).setParent(this))
}

@Since("1.6.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,27 @@ import org.apache.spark.annotation.Since
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.stat.Statistics
import org.apache.spark.mllib.stat.test.ChiSqTestResult
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext
import org.apache.spark.sql.{Row, SparkSession}

@Since("2.1.0")
object ChiSqSelectorType extends Enumeration {
type SelectorType = Value

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add since version

val KBest, Percentile, Fpr = Value
}

/**
* Chi Squared selector model.
*
* @param selectedFeatures list of indices to select (filter). Must be ordered asc
* @param selectedFeatures list of indices to select (filter).
*/
@Since("1.3.0")
class ChiSqSelectorModel @Since("1.3.0") (
@Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable {

require(isSorted(selectedFeatures), "Array has to be sorted asc")

protected def isSorted(array: Array[Int]): Boolean = {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we remove this paragraph?

var i = 1
val len = array.length
Expand All @@ -69,21 +74,22 @@ class ChiSqSelectorModel @Since("1.3.0") (
* Preserves the order of filtered features the same as their indices are stored.
* Might be moved to Vector as .slice
* @param features vector
* @param filterIndices indices of features to filter, must be ordered asc
* @param filterIndices indices of features to filter
*/
private def compress(features: Vector, filterIndices: Array[Int]): Vector = {
val orderedIndices = filterIndices.sorted
features match {
case SparseVector(size, indices, values) =>
val newSize = filterIndices.length
val newSize = orderedIndices.length
val newValues = new ArrayBuilder.ofDouble
val newIndices = new ArrayBuilder.ofInt
var i = 0
var j = 0
var indicesIdx = 0
var filterIndicesIdx = 0
while (i < indices.length && j < filterIndices.length) {
while (i < indices.length && j < orderedIndices.length) {
indicesIdx = indices(i)
filterIndicesIdx = filterIndices(j)
filterIndicesIdx = orderedIndices(j)
if (indicesIdx == filterIndicesIdx) {
newIndices += j
newValues += values(i)
Expand All @@ -101,7 +107,7 @@ class ChiSqSelectorModel @Since("1.3.0") (
Vectors.sparse(newSize, newIndices.result(), newValues.result())
case DenseVector(values) =>
val values = features.toArray
Vectors.dense(filterIndices.map(i => values(i)))
Vectors.dense(orderedIndices.map(i => values(i)))
case other =>
throw new UnsupportedOperationException(
s"Only sparse and dense vectors are supported but got ${other.getClass}.")
Expand Down Expand Up @@ -171,14 +177,47 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {

/**
* Creates a ChiSquared feature selector.
* @param numTopFeatures number of features that selector will select
* (ordered by statistic value descending)
* Note that if the number of features is less than numTopFeatures,
* then this will select all features.
*/
@Since("1.3.0")
class ChiSqSelector @Since("1.3.0") (
@Since("1.3.0") val numTopFeatures: Int) extends Serializable {
@Since("2.1.0")
class ChiSqSelector @Since("2.1.0") () extends Serializable {
private var numTopFeatures: Int = 50
private var percentile: Double = 10.0
private var alpha: Double = 0.05
private var selectorType = ChiSqSelectorType.KBest
private var chiSqTestResult: Array[ChiSqTestResult] = _

@Since("1.3.0")
def this(numTopFeatures: Int) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add "Since @1.3.0" here

this()
this.numTopFeatures = numTopFeatures
}

/**
 * Sets the number of top features to keep and switches this selector into
 * KBest mode. Returns `this` for chaining.
 */
@Since("2.1.0")
def setNumTopFeatures(value: Int): this.type = {
  selectorType = ChiSqSelectorType.KBest
  numTopFeatures = value
  this
}

/**
 * Sets the percentile of features to keep and switches this selector into
 * Percentile mode. Returns `this` for chaining.
 */
@Since("2.1.0")
def setPercentile(value: Double): this.type = {
  selectorType = ChiSqSelectorType.Percentile
  percentile = value
  this
}

/**
 * Sets the p-value threshold and switches this selector into Fpr mode.
 * Returns `this` for chaining.
 */
@Since("2.1.0")
def setAlpha(value: Double): this.type = {
  selectorType = ChiSqSelectorType.Fpr
  alpha = value
  this
}

// Explicitly selects the strategy (KBest / Percentile / Fpr) without
// modifying the corresponding threshold field.
@Since("2.1.0")
def setChiSqSelectorType(value: ChiSqSelectorType.Value): this.type = {
  selectorType = value
  this
}

/**
* Returns a ChiSquared feature selector.
Expand All @@ -189,11 +228,35 @@ class ChiSqSelector @Since("1.3.0") (
*/
/**
 * Runs the chi-squared test on the input once, caches the per-feature
 * results, then dispatches on the configured selector type. (The span
 * previously still contained the removed pre-refactor implementation as
 * diff residue; it is dropped here.)
 *
 * @param data labeled points with categorical features
 * @return a fitted [[ChiSqSelectorModel]]
 */
@Since("1.3.0")
def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = {
  // Cache test results so the select* helpers can reuse them without
  // re-running the (distributed) chi-squared test.
  chiSqTestResult = Statistics.chiSqTest(data)
  selectorType match {
    case ChiSqSelectorType.KBest => selectKBest(numTopFeatures)
    case ChiSqSelectorType.Percentile => selectPercentile(percentile)
    case ChiSqSelectorType.Fpr => selectFpr(alpha)
    // IllegalArgumentException (a subclass of Exception) keeps existing
    // catchers working while being the idiomatic type for a bad argument.
    case _ => throw new IllegalArgumentException("Unknown ChiSqSelector Type")
  }
}

/**
 * Keeps the `value` features with the largest chi-squared statistics.
 * Requires `fit` to have populated `chiSqTestResult` first.
 *
 * Bug fix: the previous body ignored the `value` parameter and used the
 * `numTopFeatures` field, so explicit calls with a different k were wrong.
 * Note: indices are emitted in statistic order; the model's `compress`
 * sorts them before filtering.
 */
@Since("2.1.0")
def selectKBest(value: Int): ChiSqSelectorModel = {
  val indices = chiSqTestResult.zipWithIndex
    .sortBy { case (res, _) => -res.statistic }
    .take(value)
    .map { case (_, index) => index }
  new ChiSqSelectorModel(indices)
}

/**
 * Keeps the top `value` percent (0-100) of features ranked by descending
 * chi-squared statistic. Requires `fit` to have populated `chiSqTestResult`.
 *
 * Bug fix: the previous body ignored the `value` parameter and used the
 * `percentile` field, so explicit calls with a different percentile were wrong.
 */
@Since("2.1.0")
def selectPercentile(value: Double): ChiSqSelectorModel = {
  val indices = chiSqTestResult.zipWithIndex
    .sortBy { case (res, _) => -res.statistic }
    .take((chiSqTestResult.length * value / 100).toInt)
    .map { case (_, index) => index }
  new ChiSqSelectorModel(indices)
}

/**
 * Keeps every feature whose chi-squared p-value is strictly below `value`.
 * Requires `fit` to have populated `chiSqTestResult` first.
 *
 * Bug fix: the previous body ignored the `value` parameter and used the
 * `alpha` field, so explicit calls with a different threshold were wrong.
 */
@Since("2.1.0")
def selectFpr(value: Double): ChiSqSelectorModel = {
  val indices = chiSqTestResult.zipWithIndex
    .filter { case (res, _) => res.pValue < value }
    .map { case (_, index) => index }
  new ChiSqSelectorModel(indices)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,23 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
.map(x => (x._1.label, x._1.features, x._2))
.toDF("label", "data", "preFilteredData")

val model = new ChiSqSelector()
val selector = new ChiSqSelector()
.setNumTopFeatures(1)
.setFeaturesCol("data")
.setLabelCol("label")
.setOutputCol("filtered")

model.fit(df).transform(df).select("filtered", "preFilteredData").collect().foreach {
selector.fit(df).transform(df).select("filtered", "preFilteredData").collect().foreach {
case Row(vec1: Vector, vec2: Vector) =>
assert(vec1 ~== vec2 absTol 1e-1)
}

selector.selectPercentile(34).transform(df)
.select("filtered", "preFilteredData").collect().foreach {
case Row(vec1: Vector, vec2: Vector) =>
assert(vec1 ~== vec2 absTol 1e-1)
}

}

test("ChiSqSelector read/write") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,24 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(filteredData == preFilteredData)
}

test("ChiSqSelector by FPR transform test (sparse & dense vector)") {
  // Four labeled points over four features, mixing sparse and dense encodings.
  val input = sc.parallelize(
    Seq(LabeledPoint(0.0, Vectors.sparse(4, Array((0, 8.0), (1, 7.0)))),
      LabeledPoint(1.0, Vectors.sparse(4, Array((1, 9.0), (2, 6.0), (3, 4.0)))),
      LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0, 4.0))),
      LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0, 9.0)))), 2)
  // Expected survivors with alpha = 0.1: the kept column is feature 3
  // (values 0.0, 4.0, 4.0, 9.0) — presumably the only p-value below the
  // threshold on this data; verify if the fixture changes.
  val expected =
    Set(LabeledPoint(0.0, Vectors.dense(Array(0.0))),
      LabeledPoint(1.0, Vectors.dense(Array(4.0))),
      LabeledPoint(1.0, Vectors.dense(Array(4.0))),
      LabeledPoint(2.0, Vectors.dense(Array(9.0))))
  val model = new ChiSqSelector().setAlpha(0.1).fit(input)
  val actual = input.map { lp =>
    LabeledPoint(lp.label, model.transform(lp.features))
  }.collect().toSet
  assert(actual == expected)
}

test("model load / save") {
val model = ChiSqSelectorSuite.createModel()
val tempDir = Utils.createTempDir()
Expand Down