Skip to content

Commit 5ce228d

Browse files
committed
Upgrade ChiSqSelectorModel to spark 3.2.0 compatable design
sort filterIndiecs before using it
1 parent d47d0b4 commit 5ce228d

File tree

2 files changed

+19
-7
lines changed

2 files changed

+19
-7
lines changed

mleap-core/src/main/scala/ml/combust/mleap/core/feature/ChiSqSelectorModel.scala

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,23 @@ import scala.collection.mutable
1010
/**
1111
* Created by hollinwilkins on 12/27/16.
1212
*/
13-
@SparkCode(uri = "https://github.com/apache/spark/blob/v2.0.0/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala")
13+
@SparkCode(uri = "https://github.com/apache/spark/blob/v3.2.0/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala")
1414
case class ChiSqSelectorModel(filterIndices: Seq[Int],
1515
inputSize: Int) extends Model {
16+
private val sortedFilterIndices = filterIndices.sorted
1617
def apply(features: Vector): Vector = {
1718
features match {
1819
case SparseVector(size, indices, values) =>
19-
val newSize = filterIndices.length
20+
val newSize = sortedFilterIndices.length
2021
val newValues = mutable.ArrayBuilder.make[Double]
2122
val newIndices = mutable.ArrayBuilder.make[Int]
2223
var i = 0
2324
var j = 0
2425
var indicesIdx = 0
2526
var filterIndicesIdx = 0
26-
while (i < indices.length && j < filterIndices.length) {
27+
while (i < indices.length && j < sortedFilterIndices.length) {
2728
indicesIdx = indices(i)
28-
filterIndicesIdx = filterIndices(j)
29+
filterIndicesIdx = sortedFilterIndices(j)
2930
if (indicesIdx == filterIndicesIdx) {
3031
newIndices += j
3132
newValues += values(i)
@@ -43,7 +44,7 @@ case class ChiSqSelectorModel(filterIndices: Seq[Int],
4344
Vectors.sparse(newSize, newIndices.result(), newValues.result())
4445
case DenseVector(values) =>
4546
val values = features.toArray
46-
Vectors.dense(filterIndices.map(i => values(i)).toArray)
47+
Vectors.dense(sortedFilterIndices.map(i => values(i)).toArray)
4748
case other =>
4849
throw new UnsupportedOperationException(
4950
s"Only sparse and dense vectors are supported but got ${other.getClass}.")
@@ -52,5 +53,5 @@ case class ChiSqSelectorModel(filterIndices: Seq[Int],
5253

5354
override def inputSchema: StructType = StructType("input" -> TensorType.Double(inputSize)).get
5455

55-
override def outputSchema: StructType = StructType("output" -> TensorType.Double(filterIndices.length)).get
56+
override def outputSchema: StructType = StructType("output" -> TensorType.Double(sortedFilterIndices.length)).get
5657
}

mleap-core/src/test/scala/ml/combust/mleap/core/feature/ChiSqSelectorModelSpec.scala

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,22 @@ package ml.combust.mleap.core.feature
22

33
import ml.combust.mleap.core.types.{StructField, TensorType}
44
import org.scalatest.FunSpec
5+
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
56

67
class ChiSqSelectorModelSpec extends FunSpec {
78

89
describe("input/output schema"){
9-
val model = new ChiSqSelectorModel(Seq(1,2,3), 3)
10+
val model = new ChiSqSelectorModel(Seq(2,3, 1), 3)
11+
12+
it("Dense vector work with unsorted indices") {
13+
val vector = Vectors.dense(1.0,2.0,3.0,4.0)
14+
assert(model(vector) == Vectors.dense(2.0, 3.0, 4.0))
15+
}
16+
17+
it("Sparse vector work with unsorted indices") {
18+
val vector = Vectors.sparse(size = 4, indices=Array(0,1,2,3), values = Array(1.0,2.0,3.0,4.0))
19+
assert(model(vector) == Vectors.sparse(size=3, indices=Array(0,1,2), values=Array(2.0,3.0,4.0)))
20+
}
1021

1122
it("Has the right input schema") {
1223
assert(model.inputSchema.fields ==

0 commit comments

Comments
 (0)