@@ -10,22 +10,23 @@ import scala.collection.mutable
1010/**
1111 * Created by hollinwilkins on 12/27/16.
1212 */
13- @ SparkCode (uri = " https://github.com/apache/spark/blob/v2.0 .0/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala" )
13+ @ SparkCode (uri = " https://github.com/apache/spark/blob/v3.2 .0/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala" )
1414case class ChiSqSelectorModel (filterIndices : Seq [Int ],
1515 inputSize : Int ) extends Model {
16+ private val sortedFilterIndices = filterIndices.sorted
1617 def apply (features : Vector ): Vector = {
1718 features match {
1819 case SparseVector (size, indices, values) =>
19- val newSize = filterIndices .length
20+ val newSize = sortedFilterIndices .length
2021 val newValues = mutable.ArrayBuilder .make[Double ]
2122 val newIndices = mutable.ArrayBuilder .make[Int ]
2223 var i = 0
2324 var j = 0
2425 var indicesIdx = 0
2526 var filterIndicesIdx = 0
26- while (i < indices.length && j < filterIndices .length) {
27+ while (i < indices.length && j < sortedFilterIndices .length) {
2728 indicesIdx = indices(i)
28- filterIndicesIdx = filterIndices (j)
29+ filterIndicesIdx = sortedFilterIndices (j)
2930 if (indicesIdx == filterIndicesIdx) {
3031 newIndices += j
3132 newValues += values(i)
@@ -43,7 +44,7 @@ case class ChiSqSelectorModel(filterIndices: Seq[Int],
4344 Vectors .sparse(newSize, newIndices.result(), newValues.result())
4445 case DenseVector (values) =>
4546 val values = features.toArray
46- Vectors .dense(filterIndices .map(i => values(i)).toArray)
47+ Vectors .dense(sortedFilterIndices .map(i => values(i)).toArray)
4748 case other =>
4849 throw new UnsupportedOperationException (
4950 s " Only sparse and dense vectors are supported but got ${other.getClass}. " )
@@ -52,5 +53,5 @@ case class ChiSqSelectorModel(filterIndices: Seq[Int],
5253
5354 override def inputSchema : StructType = StructType (" input" -> TensorType .Double (inputSize)).get
5455
55- override def outputSchema : StructType = StructType (" output" -> TensorType .Double (filterIndices .length)).get
56+ override def outputSchema : StructType = StructType (" output" -> TensorType .Double (sortedFilterIndices .length)).get
5657}
0 commit comments