Skip to content

Median overhaul #1122

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Apr 23, 2025
162 changes: 107 additions & 55 deletions core/api/core.api

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ public fun <T : Comparable<T>> DataColumn<T?>.maxOrNull(skipNaN: Boolean = skipN

public inline fun <T, reified R : Comparable<R & Any>?> DataColumn<T>.maxBy(
skipNaN: Boolean = skipNaNDefault,
noinline selector: (T) -> R,
crossinline selector: (T) -> R,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure about crossinline here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

crossline is more efficient than noinline and unfortunately we cannot do it without, so yes.

): T & Any = maxByOrNull(skipNaN, selector).suggestIfNull("maxBy")

public inline fun <T, reified R : Comparable<R & Any>?> DataColumn<T>.maxByOrNull(
skipNaN: Boolean = skipNaNDefault,
noinline selector: (T) -> R,
crossinline selector: (T) -> R,
): T? = Aggregators.max<R>(skipNaN).aggregateByOrNull(this, selector)

public inline fun <T, reified R : Comparable<R & Any>?> DataColumn<T>.maxOf(
Expand All @@ -59,10 +59,10 @@ public inline fun <T, reified R : Comparable<R & Any>?> DataColumn<T>.maxOfOrNul
// region DataRow

@Deprecated(ROW_MAX_OR_NULL, level = DeprecationLevel.ERROR)
public fun AnyRow.rowMaxOrNull(): Any? = error(ROW_MAX_OR_NULL)
public fun AnyRow.rowMaxOrNull(): Nothing? = error(ROW_MAX_OR_NULL)

@Deprecated(ROW_MAX, level = DeprecationLevel.ERROR)
public fun AnyRow.rowMax(): Any = error(ROW_MAX)
public fun AnyRow.rowMax(): Nothing = error(ROW_MAX)

public inline fun <reified T : Comparable<T>> AnyRow.rowMaxOfOrNull(skipNaN: Boolean = skipNaNDefault): T? =
Aggregators.max<T>(skipNaN).aggregateOfRow(this) { colsOf<T?>() }
Expand Down
522 changes: 411 additions & 111 deletions core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/median.kt

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ public fun <T : Comparable<T>> DataColumn<T?>.minOrNull(skipNaN: Boolean = skipN

public inline fun <T, reified R : Comparable<R & Any>?> DataColumn<T>.minBy(
skipNaN: Boolean = skipNaNDefault,
noinline selector: (T) -> R,
crossinline selector: (T) -> R,
): T & Any = minByOrNull(skipNaN, selector).suggestIfNull("minBy")

public inline fun <T, reified R : Comparable<R & Any>?> DataColumn<T>.minByOrNull(
skipNaN: Boolean = skipNaNDefault,
noinline selector: (T) -> R,
crossinline selector: (T) -> R,
): T? = Aggregators.min<R>(skipNaN).aggregateByOrNull(this, selector)

public inline fun <T, reified R : Comparable<R & Any>?> DataColumn<T>.minOf(
Expand All @@ -59,10 +59,10 @@ public inline fun <T, reified R : Comparable<R & Any>?> DataColumn<T>.minOfOrNul
// region DataRow

@Deprecated(ROW_MIN_OR_NULL, level = DeprecationLevel.ERROR)
public fun AnyRow.rowMinOrNull(): Any? = error(ROW_MIN_OR_NULL)
public fun AnyRow.rowMinOrNull(): Nothing? = error(ROW_MIN_OR_NULL)

@Deprecated(ROW_MIN, level = DeprecationLevel.ERROR)
public fun AnyRow.rowMin(): Any = error(ROW_MIN)
public fun AnyRow.rowMin(): Nothing = error(ROW_MIN)

public inline fun <reified T : Comparable<T>> AnyRow.rowMinOfOrNull(skipNaN: Boolean = skipNaNDefault): T? =
Aggregators.min<T>(skipNaN).aggregateOfRow(this) { colsOf<T?>() }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import org.jetbrains.kotlinx.dataframe.impl.columns.toComparableColumns
import org.jetbrains.kotlinx.dataframe.impl.suggestIfNull
import org.jetbrains.kotlinx.dataframe.math.percentile
import kotlin.reflect.KProperty
import kotlin.reflect.typeOf

// region DataColumn

Expand Down Expand Up @@ -52,7 +53,7 @@ public fun AnyRow.rowPercentile(percentile: Double): Any =
rowPercentileOrNull(percentile).suggestIfNull("rowPercentile")

public inline fun <reified T : Comparable<T>> AnyRow.rowPercentileOfOrNull(percentile: Double): T? =
valuesOf<T>().percentile(percentile)
valuesOf<T>().percentile(percentile, typeOf<T>())

public inline fun <reified T : Comparable<T>> AnyRow.rowPercentileOf(percentile: Double): T =
rowPercentileOfOrNull<T>(percentile).suggestIfNull("rowPercentileOf")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators

import org.jetbrains.kotlinx.dataframe.DataColumn
import org.jetbrains.kotlinx.dataframe.api.asSequence
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregationHandlers.SelectingAggregationHandler
import kotlin.reflect.KType

Expand Down Expand Up @@ -28,7 +29,11 @@ internal interface AggregatorAggregationHandler<in Value : Any, out Return : Any
* Aggregates the data in the given column and computes a single resulting value.
* Calls [aggregateSequence].
*/
fun aggregateSingleColumn(column: DataColumn<Value?>): Return
fun aggregateSingleColumn(column: DataColumn<Value?>): Return =
aggregateSequence(
values = column.asSequence(),
valueType = column.type().toValueType(),
)

/**
* Function that can give the return type of [aggregateSequence] as [KType], given the type of the input.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators

import org.jetbrains.kotlinx.dataframe.api.skipNaNDefault
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregationHandlers.HybridAggregationHandler
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregationHandlers.ReducingAggregationHandler
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregationHandlers.SelectingAggregationHandler
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.inputHandlers.AnyInputHandler
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.inputHandlers.NumberInputHandler
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.multipleColumnsHandlers.FlatteningMultipleColumnsHandler
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.multipleColumnsHandlers.TwoStepMultipleColumnsHandler
import org.jetbrains.kotlinx.dataframe.math.indexOfMax
import org.jetbrains.kotlinx.dataframe.math.indexOfMedian
import org.jetbrains.kotlinx.dataframe.math.indexOfMin
import org.jetbrains.kotlinx.dataframe.math.maxOrNull
import org.jetbrains.kotlinx.dataframe.math.maxTypeConversion
import org.jetbrains.kotlinx.dataframe.math.mean
import org.jetbrains.kotlinx.dataframe.math.meanTypeConversion
import org.jetbrains.kotlinx.dataframe.math.median
import org.jetbrains.kotlinx.dataframe.math.medianConversion
import org.jetbrains.kotlinx.dataframe.math.medianOrNull
import org.jetbrains.kotlinx.dataframe.math.minOrNull
import org.jetbrains.kotlinx.dataframe.math.minTypeConversion
import org.jetbrains.kotlinx.dataframe.math.percentile
Expand All @@ -29,13 +33,23 @@ internal object Aggregators {
private fun <Value : Return & Any, Return : Any?> twoStepSelectingForAny(
getReturnType: CalculateReturnType,
indexOfResult: IndexOfResult<Value>,
stepOneReducer: Reducer<Value, Return>,
stepOneSelector: Selector<Value, Return>,
) = Aggregator(
aggregationHandler = SelectingAggregationHandler(stepOneReducer, indexOfResult, getReturnType),
aggregationHandler = SelectingAggregationHandler(stepOneSelector, indexOfResult, getReturnType),
inputHandler = AnyInputHandler(),
multipleColumnsHandler = TwoStepMultipleColumnsHandler(),
)

private fun <Value : Any, Return : Any?> flattenHybridForAny(
getReturnType: CalculateReturnType,
indexOfResult: IndexOfResult<Value>,
reducer: Reducer<Value, Return>,
) = Aggregator(
aggregationHandler = HybridAggregationHandler(reducer, indexOfResult, getReturnType),
inputHandler = AnyInputHandler(),
multipleColumnsHandler = FlatteningMultipleColumnsHandler(),
)

private fun <Value : Any, Return : Any?> twoStepReducingForAny(
getReturnType: CalculateReturnType,
stepOneReducer: Reducer<Value, Return>,
Expand Down Expand Up @@ -101,7 +115,7 @@ internal object Aggregators {
private val min by withOneOption { skipNaN: Boolean ->
twoStepSelectingForAny<Comparable<Any>, Comparable<Any>?>(
getReturnType = minTypeConversion,
stepOneReducer = { type -> minOrNull(type, skipNaN) },
stepOneSelector = { type -> minOrNull(type, skipNaN) },
indexOfResult = { type -> indexOfMin(type, skipNaN) },
)
}
Expand All @@ -113,15 +127,15 @@ internal object Aggregators {
private val max by withOneOption { skipNaN: Boolean ->
twoStepSelectingForAny<Comparable<Any>, Comparable<Any>?>(
getReturnType = maxTypeConversion,
stepOneReducer = { type -> maxOrNull(type, skipNaN) },
stepOneSelector = { type -> maxOrNull(type, skipNaN) },
indexOfResult = { type -> indexOfMax(type, skipNaN) },
)
}

// T: Number? -> Double
val std by withTwoOptions { skipNA: Boolean, ddof: Int ->
val std by withTwoOptions { skipNaN: Boolean, ddof: Int ->
flattenReducingForNumbers(stdTypeConversion) { type ->
std(type, skipNA, ddof)
std(type, skipNaN, ddof)
}
}

Expand All @@ -140,9 +154,31 @@ internal object Aggregators {
}
}

// T: Comparable<T>? -> T
val median by flattenReducingForAny<Comparable<Any?>> { type ->
asIterable().median(type)
// T : primitive Number? -> Double?
// T : Comparable<T & Any>? -> T?
fun <T> medianCommon(skipNaN: Boolean): Aggregator<T & Any, T?>
where T : Comparable<T & Any>? =
median.invoke(skipNaN).cast2()

// T : Comparable<T & Any>? -> T?
fun <T> medianComparables(): Aggregator<T & Any, T?>
where T : Comparable<T & Any>? =
medianCommon<T>(skipNaNDefault).cast2()

// T : primitive Number? -> Double?
fun <T> medianNumbers(
skipNaN: Boolean,
): Aggregator<T & Any, Double?>
where T : Comparable<T & Any>?, T : Number? =
medianCommon<T>(skipNaN).cast2()

@Suppress("UNCHECKED_CAST")
private val median by withOneOption { skipNaN: Boolean ->
flattenHybridForAny<Comparable<Any>, Comparable<Any>?>(
getReturnType = medianConversion,
reducer = { type -> medianOrNull(type, skipNaN) as Comparable<Any>? },
indexOfResult = { type -> indexOfMedian(type, skipNaN) },
)
}

// T: Number -> T
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregationHandlers

import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.AggregatorAggregationHandler
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.CalculateReturnType
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.IndexOfResult
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Reducer
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.ValueType
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregateCalculatingValueType
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.calculateValueType
import kotlin.reflect.KType
import kotlin.reflect.full.withNullability

/**
* Implementation of [AggregatorAggregationHandler] which functions like a selector ánd reducer:
* it takes a sequence of values and returns a single value, which is likely part of the input, but not necessarily.
*
* In practice, this means the handler implements both [indexOfAggregationResultSingleSequence]
* (meaning it can give an index of the result in the input), and [aggregateSequence] with a return type that is
* potentially different from the input.
* The return value of [aggregateSequence] and the value at the index retrieved from [indexOfAggregationResultSingleSequence]
* may differ.
*
* @param reducer This function actually does the selection/reduction.
* Before it is called, nulls are filtered out. The type of the values is passed as [KType] to the selector.
* @param indexOfResult This function must be supplied to give the index of the result in the input values.
* @param getReturnType This function must be supplied to give the return type of [reducer] given some input type and
* whether the input is empty.
* When selecting, the return type is always `typeOf<Value>()` or `typeOf<Value?>()`, when reducing it can be anything.
* @see [ReducingAggregationHandler]
*/
internal class HybridAggregationHandler<in Value : Any, out Return : Any?>(
val reducer: Reducer<Value, Return>,
val indexOfResult: IndexOfResult<Value>,
val getReturnType: CalculateReturnType,
) : AggregatorAggregationHandler<Value, Return> {

/**
* Function that can give the index of the aggregation result in the input [values].
* Calls the supplied [indexOfResult] after preprocessing the input.
*/
@Suppress("UNCHECKED_CAST")
override fun indexOfAggregationResultSingleSequence(values: Sequence<Value?>, valueType: ValueType): Int {
val (values, valueType) = aggregator!!.preprocessAggregation(values, valueType)
return indexOfResult(values, valueType)
}

/**
* Base function of [Aggregator].
*
* Aggregates the given values, taking [valueType] into account,
* filtering nulls (only if [valueType.type.isMarkedNullable][KType.isMarkedNullable]),
* and computes a single resulting value.
*
* When the exact [valueType] is unknown, use [calculateValueType] or [aggregateCalculatingValueType].
*
* Calls the supplied [reducer].
*/
@Suppress("UNCHECKED_CAST")
override fun aggregateSequence(values: Sequence<Value?>, valueType: ValueType): Return {
val (values, valueType) = aggregator!!.preprocessAggregation(values, valueType)
return reducer(
// values =
if (valueType.isMarkedNullable) {
values.filterNotNull()
} else {
values as Sequence<Value>
},
// type =
valueType.withNullability(false),
)
}

/**
* Give the return type of [reducer] given some input type and whether the input is empty.
* Calls the supplied [getReturnType].
*/
override fun calculateReturnType(valueType: KType, emptyInput: Boolean): KType =
getReturnType(valueType.withNullability(false), emptyInput)

override var aggregator: Aggregator<@UnsafeVariance Value, @UnsafeVariance Return>? = null
}
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregationHandlers

import org.jetbrains.kotlinx.dataframe.DataColumn
import org.jetbrains.kotlinx.dataframe.api.asSequence
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.AggregatorAggregationHandler
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.CalculateReturnType
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Reducer
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.ValueType
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregateCalculatingValueType
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.calculateValueType
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.toValueType
import kotlin.reflect.KType
import kotlin.reflect.full.withNullability

Expand Down Expand Up @@ -54,17 +51,6 @@ internal class ReducingAggregationHandler<in Value : Any, out Return : Any?>(
)
}

/**
* Aggregates the data in the given column and computes a single resulting value.
* Calls [aggregateSequence].
*/
@Suppress("UNCHECKED_CAST")
override fun aggregateSingleColumn(column: DataColumn<Value?>): Return =
aggregateSequence(
values = column.asSequence(),
valueType = column.type().toValueType(),
)

/** This function always returns `-1` because the result of a reducer is not in the input values. */
override fun indexOfAggregationResultSingleSequence(values: Sequence<Value?>, valueType: ValueType): Int = -1

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregationHandlers

import org.jetbrains.kotlinx.dataframe.DataColumn
import org.jetbrains.kotlinx.dataframe.api.asSequence
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.AggregatorAggregationHandler
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.CalculateReturnType
Expand All @@ -10,7 +8,6 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Selector
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.ValueType
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregateCalculatingValueType
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.calculateValueType
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.toValueType
import kotlin.reflect.KType
import kotlin.reflect.full.withNullability

Expand Down Expand Up @@ -70,16 +67,6 @@ internal class SelectingAggregationHandler<in Value : Return & Any, out Return :
)
}

/**
* Aggregates the data in the given column and computes a single resulting value.
* Calls [aggregateSequence].
*/
override fun aggregateSingleColumn(column: DataColumn<Value?>): Return =
aggregateSequence(
values = column.asSequence(),
valueType = column.type().toValueType(),
)

/**
* Give the return type of [selector] given some input type and whether the input is empty.
* Calls the supplied [getReturnType].
Expand All @@ -91,7 +78,7 @@ internal class SelectingAggregationHandler<in Value : Return & Any, out Return :
require(it == valueType.withNullability(false) || it == valueType.withNullability(true)) {
"The return type of the selector must be either ${valueType.withNullability(false)} or ${
valueType.withNullability(true)
}"
} but was $it."
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,17 @@ internal inline fun <T> Aggregatable<T>.remainingColumns(
crossinline predicate: (AnyCol) -> Boolean,
): ColumnsSelector<T, Any?> = remainingColumnsSelector().filter { predicate(it.data) }

/**
* Emulates selecting all columns whose values are comparable to each other.
* These are columns of type `R` where `R : Comparable<R>`.
*
* There is no way to denote this generically in types, however,
* hence the _fake_ type `Comparable<Any>` is used.
* (`Comparable<Nothing>` would be more correct, but then the compiler complains)
*/
@Suppress("UNCHECKED_CAST")
internal fun <T> Aggregatable<T>.intraComparableColumns(): ColumnsSelector<T, Comparable<Any?>> =
remainingColumns { it.valuesAreComparable() } as ColumnsSelector<T, Comparable<Any?>>
internal fun <T> Aggregatable<T>.intraComparableColumns(): ColumnsSelector<T, Comparable<Any>?> =
remainingColumns { it.valuesAreComparable() } as ColumnsSelector<T, Comparable<Any>?>

@Suppress("UNCHECKED_CAST")
internal fun <T> Aggregatable<T>.numberColumns(): ColumnsSelector<T, Number?> =
Expand Down
Loading