Skip to content

Commit cae42f5

Browse files
committed
Field TypedEncoder
1 parent 765cad9 commit cae42f5

File tree

5 files changed

+57
-6
lines changed

5 files changed

+57
-6
lines changed

dataset/src/main/scala/frameless/RecordEncoder.scala

+12-1
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,18 @@ final class RecordFieldEncoder[T](
187187
private[frameless] val jvmRepr: DataType,
188188
private[frameless] val fromCatalyst: Expression => Expression,
189189
private[frameless] val toCatalyst: Expression => Expression
190-
) extends Serializable
190+
) extends Serializable { self =>
191+
private[frameless] def toTypedEncoder = new TypedEncoder[T]()(encoder.classTag) {
192+
def nullable: Boolean = encoder.nullable
193+
194+
def jvmRepr: DataType = self.jvmRepr
195+
def catalystRepr: DataType = encoder.catalystRepr
196+
197+
def fromCatalyst(path: Expression): Expression = self.fromCatalyst(path)
198+
199+
def toCatalyst(path: Expression): Expression = self.toCatalyst(path)
200+
}
201+
}
191202

192203
object RecordFieldEncoder extends RecordFieldEncoderLowPriority {
193204

dataset/src/main/scala/frameless/TypedEncoder.scala

+9
Original file line numberDiff line numberDiff line change
@@ -756,5 +756,14 @@ object TypedEncoder {
756756
}
757757
}
758758

759+
/**
760+
* In case a type `T` encoding is supported in derivation (as a struct field),
761+
* then this allows to resolve the corresponding `TypedEncoder[T]`,
762+
* so the field can be handled invidiually.
763+
*/
764+
def usingFieldEncoder[T](
765+
implicit fieldEncoder: shapeless.Lazy[RecordFieldEncoder[T]]):
766+
TypedEncoder[T] = fieldEncoder.value.toTypedEncoder
767+
759768
object injections extends InjectionEnum
760769
}

dataset/src/test/scala/frameless/ColumnTests.scala

+27
Original file line numberDiff line numberDiff line change
@@ -615,4 +615,31 @@ final class ColumnTests extends TypedDatasetSuite with Matchers {
615615
// we should be able to block the following as well...
616616
"ds.col(_.a.toInt)" shouldNot typeCheck
617617
}
618+
619+
test("col through record encoder (for Value class)") {
620+
import RecordEncoderTests.{ Name, Person }
621+
622+
val bar = new Name("bar")
623+
val foo = new Name("foo")
624+
625+
val ds: TypedDataset[Person] =
626+
TypedDataset.create(Seq(Person(bar, 23), Person(foo, 11)))
627+
628+
a[org.apache.spark.sql.AnalysisException] should be thrownBy {
629+
// TypedEncoder[Name] is resolved using case class derivation,
630+
// which is not compatible to the way such Value class
631+
// is encoded as a field in another class,
632+
// which leads to encoding/analysis error.
633+
634+
ds.select(ds.col[Name](Symbol("name"))).
635+
collect.run().toSeq shouldEqual Seq[Name](bar, foo)
636+
}
637+
638+
{
639+
implicit def enc: TypedEncoder[Name] = TypedEncoder.usingFieldEncoder[Name]
640+
641+
ds.select(ds.col[Name](Symbol("name"))).
642+
collect.run().toSeq shouldEqual Seq[Name](bar, foo)
643+
}
644+
}
618645
}

dataset/src/test/scala/frameless/CreateTests.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ class CreateTests extends TypedDatasetSuite with Matchers {
130130
val v = Vector(X2(1,2))
131131
val df = TypedDataset.create(v).dataset.toDF()
132132

133-
a [IllegalStateException] should be thrownBy {
133+
a[IllegalStateException] should be thrownBy {
134134
TypedDataset.createUnsafe[X1[Int]](df).show().run()
135135
}
136136
}

dataset/src/test/scala/frameless/TypedDatasetSuite.scala

+8-4
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,17 @@ class TypedDatasetSuite extends AnyFunSuite with Checkers with BeforeAndAfterAll
7979
val da = numeric.toDouble(a)
8080
val db = numeric.toDouble(b)
8181
val epsilon = 1E-6
82+
8283
// Spark has a weird behaviour concerning expressions that should return Inf
8384
// Most of the time they return NaN instead, for instance stddev of Seq(-7.827553978923477E227, -5.009124275715786E153)
84-
if((da.isNaN || da.isInfinity) && (db.isNaN || db.isInfinity)) proved
85-
else if (
85+
if ((da.isNaN || da.isInfinity) && (db.isNaN || db.isInfinity)) {
86+
proved
87+
} else if (
8688
(da - db).abs < epsilon ||
87-
(da - db).abs < da.abs / 100)
89+
(da - db).abs < da.abs / 100) {
8890
proved
89-
else falsified :| s"Expected $a but got $b, which is more than 1% off and greater than epsilon = $epsilon."
91+
} else {
92+
falsified :| s"Expected $a but got $b, which is more than 1% off and greater than epsilon = $epsilon."
93+
}
9094
}
9195
}

0 commit comments

Comments
 (0)