Skip to content

Commit 6a068e0

Browse files
committed
Add TypedOneHotEncoder
1 parent 012f1a1 commit 6a068e0

File tree

2 files changed

+104
-0
lines changed

2 files changed

+104
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package frameless
2+
package ml
3+
package feature
4+
5+
import frameless.ml.feature.TypedOneHotEncoder.HandleInvalid
6+
import frameless.ml.internals.UnaryInputsChecker
7+
import org.apache.spark.ml.Estimator
8+
import org.apache.spark.ml.feature.{OneHotEncoderEstimator, OneHotEncoderModel}
9+
import org.apache.spark.ml.linalg.Vector
10+
11+
/**
12+
* A one-hot encoder that maps a column of category indices to a column of binary vectors, with
13+
* at most a single one-value per row that indicates the input category index.
14+
*
15+
* @see `TypedStringIndexer` for converting categorical values into category indices
16+
*/
17+
class TypedOneHotEncoder[Inputs] private[ml](oneHotEncoder: OneHotEncoderEstimator, inputCol: String)
18+
extends TypedEstimator[Inputs, TypedOneHotEncoder.Outputs, OneHotEncoderModel] {
19+
20+
override val estimator: Estimator[OneHotEncoderModel] = oneHotEncoder
21+
.setInputCols(Array(inputCol))
22+
.setOutputCols(Array(AppendTransformer.tempColumnName))
23+
24+
def setHandleInvalid(value: HandleInvalid): TypedOneHotEncoder[Inputs] =
25+
copy(oneHotEncoder.setHandleInvalid(value.sparkValue))
26+
27+
def setDropLast(value: Boolean): TypedOneHotEncoder[Inputs] =
28+
copy(oneHotEncoder.setDropLast(value))
29+
30+
private def copy(newOneHotEncoder: OneHotEncoderEstimator): TypedOneHotEncoder[Inputs] =
31+
new TypedOneHotEncoder[Inputs](newOneHotEncoder, inputCol)
32+
}
33+
34+
object TypedOneHotEncoder {
35+
36+
case class Outputs(output: Vector)
37+
38+
sealed abstract class HandleInvalid(val sparkValue: String)
39+
object HandleInvalid {
40+
case object Error extends HandleInvalid("error")
41+
case object Keep extends HandleInvalid("keep")
42+
}
43+
44+
def apply[Inputs](implicit inputsChecker: UnaryInputsChecker[Inputs, Int]): TypedOneHotEncoder[Inputs] = {
45+
new TypedOneHotEncoder[Inputs](new OneHotEncoderEstimator(), inputsChecker.inputCol)
46+
}
47+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package frameless
2+
package ml
3+
package feature
4+
5+
import frameless.ml.feature.TypedOneHotEncoder.HandleInvalid
6+
import org.apache.spark.ml.linalg._
7+
import org.scalacheck.{Arbitrary, Gen}
8+
import org.scalacheck.Prop._
9+
import shapeless.test.illTyped
10+
11+
class TypedOneHotEncoderTests extends FramelessMlSuite {
12+
13+
test(".fit() returns a correct TypedTransformer") {
14+
implicit val arbInt = Arbitrary(Gen.choose(0, 99))
15+
def prop[A: TypedEncoder : Arbitrary] = forAll { (x2: X2[Int, A], dropLast: Boolean) =>
16+
val encoder = TypedOneHotEncoder[X1[Int]].setDropLast(dropLast)
17+
val inputs = 0.to(x2.a).map(i => X2(i, x2.b))
18+
val ds = TypedDataset.create(inputs)
19+
val model = encoder.fit(ds).run()
20+
val resultDs = model.transform(TypedDataset.create(Seq(x2))).as[X3[Int, A, Vector]]
21+
val result = resultDs.collect.run()
22+
if (dropLast) {
23+
result == Seq (X3(x2.a, x2.b,
24+
Vectors.sparse(x2.a, Array.emptyIntArray, Array.emptyDoubleArray)))
25+
} else {
26+
result == Seq (X3(x2.a, x2.b,
27+
Vectors.sparse(x2.a + 1, Array(x2.a), Array(1.0))))
28+
}
29+
}
30+
31+
check(prop[Double])
32+
check(prop[String])
33+
}
34+
35+
test("param setting is retained") {
36+
implicit val arbHandleInvalid: Arbitrary[HandleInvalid] = Arbitrary {
37+
Gen.oneOf(HandleInvalid.Keep, HandleInvalid.Error)
38+
}
39+
40+
val prop = forAll { handleInvalid: HandleInvalid =>
41+
val encoder = TypedOneHotEncoder[X1[Int]]
42+
.setHandleInvalid(handleInvalid)
43+
val ds = TypedDataset.create(Seq(X1(1)))
44+
val model = encoder.fit(ds).run()
45+
46+
model.transformer.getHandleInvalid == handleInvalid.sparkValue
47+
}
48+
49+
check(prop)
50+
}
51+
52+
test("create() compiles only with correct inputs") {
53+
illTyped("TypedOneHotEncoder.create[Double]()")
54+
illTyped("TypedOneHotEncoder.create[X1[Double]]()")
55+
illTyped("TypedOneHotEncoder.create[X2[String, Long]]()")
56+
}
57+
}

0 commit comments

Comments
 (0)