|
| 1 | +package frameless |
| 2 | +package ml |
| 3 | +package feature |
| 4 | + |
| 5 | +import frameless.ml.feature.TypedOneHotEncoder.HandleInvalid |
| 6 | +import org.apache.spark.ml.linalg._ |
| 7 | +import org.scalacheck.{Arbitrary, Gen} |
| 8 | +import org.scalacheck.Prop._ |
| 9 | +import shapeless.test.illTyped |
| 10 | + |
/**
 * Tests for [[TypedOneHotEncoder]]: the vectors produced by a fitted model,
 * propagation of the `handleInvalid` parameter, and compile-time rejection of
 * unsupported input types.
 */
class TypedOneHotEncoderTests extends FramelessMlSuite {

  test(".fit() returns a correct TypedTransformer") {
    // Restrict generated Ints to 0..99: each sample trains on (x2.a + 1) rows.
    implicit val arbInt: Arbitrary[Int] = Arbitrary(Gen.choose(0, 99))

    def prop[A: TypedEncoder : Arbitrary] = forAll { (x2: X2[Int, A], dropLast: Boolean) =>
      val encoder = TypedOneHotEncoder[X1[Int]].setDropLast(dropLast)

      // The training data contains every index from 0 to x2.a inclusive, so x2.a
      // is the highest category index present when the model is fitted.
      val inputs = 0.to(x2.a).map(i => X2(i, x2.b))
      val ds = TypedDataset.create(inputs)
      val model = encoder.fit(ds).run()

      val resultDs = model.transform(TypedDataset.create(Seq(x2))).as[X3[Int, A, Vector]]
      val result = resultDs.collect.run()

      // With dropLast the highest index is expected to encode to an all-zero
      // vector of size x2.a; otherwise to a vector of size x2.a + 1 with a single
      // 1.0 at position x2.a.
      val expectedVector =
        if (dropLast) Vectors.sparse(x2.a, Array.emptyIntArray, Array.emptyDoubleArray)
        else Vectors.sparse(x2.a + 1, Array(x2.a), Array(1.0))

      result == Seq(X3(x2.a, x2.b, expectedVector))
    }

    check(prop[Double])
    check(prop[String])
  }

  test("param setting is retained") {
    implicit val arbHandleInvalid: Arbitrary[HandleInvalid] = Arbitrary {
      Gen.oneOf(HandleInvalid.Keep, HandleInvalid.Error)
    }

    val prop = forAll { handleInvalid: HandleInvalid =>
      val encoder = TypedOneHotEncoder[X1[Int]]
        .setHandleInvalid(handleInvalid)
      val ds = TypedDataset.create(Seq(X1(1)))
      val model = encoder.fit(ds).run()

      // The fitted model's underlying transformer must report the value that was
      // configured on the estimator.
      model.transformer.getHandleInvalid == handleInvalid.sparkValue
    }

    check(prop)
  }

  test("create() compiles only with correct inputs") {
    illTyped("TypedOneHotEncoder.create[Double]()")
    illTyped("TypedOneHotEncoder.create[X1[Double]]()")
    illTyped("TypedOneHotEncoder.create[X2[String, Long]]()")

    // illTyped succeeds on ANY compile error, including an unresolved member.
    // The other tests in this suite construct the encoder as
    // TypedOneHotEncoder[...], so assert on that form too; this guarantees the
    // rejection is caused by the input type rather than by the call shape.
    illTyped("TypedOneHotEncoder[Double]")
    illTyped("TypedOneHotEncoder[X1[Double]]")
    illTyped("TypedOneHotEncoder[X2[String, Long]]")
  }
}
0 commit comments