-
Notifications
You must be signed in to change notification settings - Fork 1
Fpr chi square2 #3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
2adebe8
04053ca
7623563
3d6aecb
026ac85
b522c5a
3431a7a
89e2dd5
ab96c06
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,22 +27,27 @@ import org.apache.spark.annotation.Since | |
| import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} | ||
| import org.apache.spark.mllib.regression.LabeledPoint | ||
| import org.apache.spark.mllib.stat.Statistics | ||
| import org.apache.spark.mllib.stat.test.ChiSqTestResult | ||
| import org.apache.spark.mllib.util.{Loader, Saveable} | ||
| import org.apache.spark.rdd.RDD | ||
| import org.apache.spark.SparkContext | ||
| import org.apache.spark.sql.{Row, SparkSession} | ||
|
|
||
| @Since("2.1.0") | ||
| object ChiSqSelectorType extends Enumeration { | ||
| type SelectorType = Value | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Review comment: add a `@Since` version annotation. |
||
| val KBest, Percentile, Fpr = Value | ||
| } | ||
|
|
||
| /** | ||
| * Chi Squared selector model. | ||
| * | ||
| * @param selectedFeatures list of indices to select (filter). Must be ordered asc | ||
| * @param selectedFeatures list of indices to select (filter). | ||
| */ | ||
| @Since("1.3.0") | ||
| class ChiSqSelectorModel @Since("1.3.0") ( | ||
| @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable { | ||
|
|
||
| require(isSorted(selectedFeatures), "Array has to be sorted asc") | ||
|
|
||
| protected def isSorted(array: Array[Int]): Boolean = { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Review comment: Can we remove this paragraph? |
||
| var i = 1 | ||
| val len = array.length | ||
|
|
@@ -69,21 +74,22 @@ class ChiSqSelectorModel @Since("1.3.0") ( | |
| * Preserves the order of filtered features the same as their indices are stored. | ||
| * Might be moved to Vector as .slice | ||
| * @param features vector | ||
| * @param filterIndices indices of features to filter, must be ordered asc | ||
| * @param filterIndices indices of features to filter | ||
| */ | ||
| private def compress(features: Vector, filterIndices: Array[Int]): Vector = { | ||
| val orderedIndices = filterIndices.sorted | ||
| features match { | ||
| case SparseVector(size, indices, values) => | ||
| val newSize = filterIndices.length | ||
| val newSize = orderedIndices.length | ||
| val newValues = new ArrayBuilder.ofDouble | ||
| val newIndices = new ArrayBuilder.ofInt | ||
| var i = 0 | ||
| var j = 0 | ||
| var indicesIdx = 0 | ||
| var filterIndicesIdx = 0 | ||
| while (i < indices.length && j < filterIndices.length) { | ||
| while (i < indices.length && j < orderedIndices.length) { | ||
| indicesIdx = indices(i) | ||
| filterIndicesIdx = filterIndices(j) | ||
| filterIndicesIdx = orderedIndices(j) | ||
| if (indicesIdx == filterIndicesIdx) { | ||
| newIndices += j | ||
| newValues += values(i) | ||
|
|
@@ -101,7 +107,7 @@ class ChiSqSelectorModel @Since("1.3.0") ( | |
| Vectors.sparse(newSize, newIndices.result(), newValues.result()) | ||
| case DenseVector(values) => | ||
| val values = features.toArray | ||
| Vectors.dense(filterIndices.map(i => values(i))) | ||
| Vectors.dense(orderedIndices.map(i => values(i))) | ||
| case other => | ||
| throw new UnsupportedOperationException( | ||
| s"Only sparse and dense vectors are supported but got ${other.getClass}.") | ||
|
|
@@ -171,14 +177,47 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { | |
|
|
||
| /** | ||
| * Creates a ChiSquared feature selector. | ||
| * @param numTopFeatures number of features that selector will select | ||
| * (ordered by statistic value descending) | ||
| * Note that if the number of features is less than numTopFeatures, | ||
| * then this will select all features. | ||
| */ | ||
| @Since("1.3.0") | ||
| class ChiSqSelector @Since("1.3.0") ( | ||
| @Since("1.3.0") val numTopFeatures: Int) extends Serializable { | ||
| @Since("2.1.0") | ||
| class ChiSqSelector @Since("2.1.0") () extends Serializable { | ||
| private var numTopFeatures: Int = 50 | ||
| private var percentile: Double = 10.0 | ||
| private var alpha: Double = 0.05 | ||
| private var selectorType = ChiSqSelectorType.KBest | ||
| private var chiSqTestResult: Array[ChiSqTestResult] = _ | ||
|
|
||
| @Since("1.3.0") | ||
| def this(numTopFeatures: Int) { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Review comment: add `@Since("1.3.0")` here. |
||
| this() | ||
| this.numTopFeatures = numTopFeatures | ||
| } | ||
|
|
||
| /** | ||
| * Stores the number of features to keep and switches this selector to the KBest strategy. | ||
| * | ||
| * @param value number of top features to keep | ||
| * @return this selector, so that calls can be chained | ||
| */ | ||
| @Since("2.1.0") | ||
| def setNumTopFeatures(value: Int): this.type = { | ||
| // Setting this parameter also selects the matching strategy. | ||
| selectorType = ChiSqSelectorType.KBest | ||
| numTopFeatures = value | ||
| this | ||
| } | ||
|
|
||
| /** | ||
| * Stores the percentile to keep and switches this selector to the Percentile strategy. | ||
| * | ||
| * @param value percentile of features to keep | ||
| * @return this selector, so that calls can be chained | ||
| */ | ||
| @Since("2.1.0") | ||
| def setPercentile(value: Double): this.type = { | ||
| // Setting this parameter also selects the matching strategy. | ||
| selectorType = ChiSqSelectorType.Percentile | ||
| percentile = value | ||
| this | ||
| } | ||
|
|
||
| /** | ||
| * Stores the p-value threshold and switches this selector to the Fpr strategy. | ||
| * | ||
| * @param value threshold that a feature's p-value is compared against | ||
| * @return this selector, so that calls can be chained | ||
| */ | ||
| @Since("2.1.0") | ||
| def setAlpha(value: Double): this.type = { | ||
| // Setting this parameter also selects the matching strategy. | ||
| selectorType = ChiSqSelectorType.Fpr | ||
| alpha = value | ||
| this | ||
| } | ||
|
|
||
| /** | ||
| * Chooses the selection strategy directly, leaving the stored parameters untouched. | ||
| * | ||
| * @param value one of the values of [[ChiSqSelectorType]] | ||
| * @return this selector, so that calls can be chained | ||
| */ | ||
| @Since("2.1.0") | ||
| def setChiSqSelectorType(value: ChiSqSelectorType.Value): this.type = { | ||
| this.selectorType = value | ||
| this | ||
| } | ||
|
|
||
| /** | ||
| * Returns a ChiSquared feature selector. | ||
|
|
@@ -189,11 +228,35 @@ class ChiSqSelector @Since("1.3.0") ( | |
| */ | ||
| @Since("1.3.0") | ||
| def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = { | ||
| val indices = Statistics.chiSqTest(data) | ||
| .zipWithIndex.sortBy { case (res, _) => -res.statistic } | ||
| .take(numTopFeatures) | ||
| .map { case (_, indices) => indices } | ||
| .sorted | ||
| chiSqTestResult = Statistics.chiSqTest(data) | ||
| selectorType match { | ||
| case ChiSqSelectorType.KBest => selectKBest(numTopFeatures) | ||
| case ChiSqSelectorType.Percentile => selectPercentile(percentile) | ||
| case ChiSqSelectorType.Fpr => selectFpr(alpha) | ||
| case _ => throw new Exception("Unknown ChiSqSelector Type") | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Builds a model that keeps the `value` features with the largest chi-squared statistic. | ||
| * | ||
| * Reads the test results stored by `fit`, so `fit` must have been called on this selector first. | ||
| * | ||
| * @param value number of top-ranked features to keep; all features are kept if fewer exist | ||
| * @return a model holding the selected feature indices in ascending order | ||
| */ | ||
| @Since("2.1.0") | ||
| def selectKBest(value: Int): ChiSqSelectorModel = { | ||
| // Rank features by statistic, largest first; the zipped index is the feature position. | ||
| val ranked = chiSqTestResult.zipWithIndex.sortBy { case (res, _) => -res.statistic } | ||
| // Use the argument (it was previously ignored in favour of the numTopFeatures field) and | ||
| // sort ascending, because ChiSqSelectorModel requires its indices to be sorted. | ||
| val selected = ranked.take(value).map { case (_, index) => index }.sorted | ||
| new ChiSqSelectorModel(selected) | ||
| } | ||
|
|
||
| /** | ||
| * Builds a model that keeps the top `value` percent of features by chi-squared statistic. | ||
| * | ||
| * Reads the test results stored by `fit`, so `fit` must have been called on this selector first. | ||
| * The number of features kept is `numFeatures * value / 100`, truncated to an integer, so a | ||
| * value of 0 produces a model that selects no features. | ||
| * | ||
| * @param value percentage of features to keep, on a 0-100 scale | ||
| * @return a model holding the selected feature indices in ascending order | ||
| */ | ||
| @Since("2.1.0") | ||
| def selectPercentile(value: Double): ChiSqSelectorModel = { | ||
| // Use the argument (it was previously ignored in favour of the percentile field). | ||
| val numToKeep = (chiSqTestResult.length * value / 100).toInt | ||
| // Rank features by statistic, largest first; the zipped index is the feature position. | ||
| val ranked = chiSqTestResult.zipWithIndex.sortBy { case (res, _) => -res.statistic } | ||
| // Sort ascending, because ChiSqSelectorModel requires its indices to be sorted. | ||
| val selected = ranked.take(numToKeep).map { case (_, index) => index }.sorted | ||
| new ChiSqSelectorModel(selected) | ||
| } | ||
|
|
||
| /** | ||
| * Builds a model that keeps every feature whose chi-squared p-value is strictly below `value`. | ||
| * | ||
| * Reads the test results stored by `fit`, so `fit` must have been called on this selector first. | ||
| * | ||
| * @param value p-value threshold; features with `pValue < value` are kept | ||
| * @return a model holding the selected feature indices in ascending order | ||
| */ | ||
| @Since("2.1.0") | ||
| def selectFpr(value: Double): ChiSqSelectorModel = { | ||
| // Use the argument (it was previously ignored in favour of the alpha field). Filtering | ||
| // keeps the zipWithIndex order, so the resulting indices are already ascending. | ||
| val selected = chiSqTestResult.zipWithIndex | ||
| .filter { case (res, _) => res.pValue < value } | ||
| .map { case (_, index) => index } | ||
| new ChiSqSelectorModel(selected) | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it still okay when percentile is 0?