Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.inputHandler
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.multipleColumnsHandlers.FlatteningMultipleColumnsHandler
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.multipleColumnsHandlers.NoMultipleColumnsHandler
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.multipleColumnsHandlers.TwoStepMultipleColumnsHandler
import org.jetbrains.kotlinx.dataframe.impl.columns.ParameterValue
import kotlin.reflect.KType
import kotlin.reflect.full.withNullability

Expand Down Expand Up @@ -51,6 +52,7 @@ public class Aggregator<in Value : Any, out Return : Any?>(
public val inputHandler: AggregatorInputHandler<Value, Return>,
public val multipleColumnsHandler: AggregatorMultipleColumnsHandler<Value, Return>,
public val name: String,
public val statisticsParameters: Map<String, ParameterValue?>,
) : AggregatorInputHandler<Value, Return> by inputHandler,
AggregatorMultipleColumnsHandler<Value, Return> by multipleColumnsHandler,
AggregatorAggregationHandler<Value, Return> by aggregationHandler {
Expand Down Expand Up @@ -84,13 +86,31 @@ public class Aggregator<in Value : Any, out Return : Any?>(
aggregationHandler: AggregatorAggregationHandler<Value, Return>,
inputHandler: AggregatorInputHandler<Value, Return>,
multipleColumnsHandler: AggregatorMultipleColumnsHandler<Value, Return>,
statisticsParameters: Map<String, ParameterValue?>,
): AggregatorProvider<Value, Return> =
AggregatorProvider { name ->
Aggregator(
aggregationHandler = aggregationHandler,
inputHandler = inputHandler,
multipleColumnsHandler = multipleColumnsHandler,
name = name,
statisticsParameters = statisticsParameters,
)
}

/**
 * Overload of [invoke] that takes no statistics parameters: every [Aggregator] created by the
 * returned provider receives an empty `statisticsParameters` map.
 *
 * NOTE(review): the previous comment on this overload described it as a stopgap that exists
 * only so that callers not yet passing `statisticsParameters` keep compiling.
 */
internal operator fun <Value : Any, Return : Any?> invoke(
    aggregationHandler: AggregatorAggregationHandler<Value, Return>,
    inputHandler: AggregatorInputHandler<Value, Return>,
    multipleColumnsHandler: AggregatorMultipleColumnsHandler<Value, Return>,
): AggregatorProvider<Value, Return> {
    val noParameters = emptyMap<String, ParameterValue?>()
    return AggregatorProvider { aggregatorName ->
        Aggregator(
            aggregationHandler = aggregationHandler,
            inputHandler = inputHandler,
            multipleColumnsHandler = multipleColumnsHandler,
            name = aggregatorName,
            statisticsParameters = noParameters,
        )
    }
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators
import org.jetbrains.kotlinx.dataframe.DataColumn
import org.jetbrains.kotlinx.dataframe.api.asSequence
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregationHandlers.SelectingAggregationHandler
import org.jetbrains.kotlinx.dataframe.impl.columns.ParameterValue
import org.jetbrains.kotlinx.dataframe.impl.columns.StatisticResult
import org.jetbrains.kotlinx.dataframe.impl.columns.ValueColumnInternal
import kotlin.reflect.KType

/**
Expand All @@ -26,13 +29,37 @@ public interface AggregatorAggregationHandler<in Value : Any, out Return : Any?>

/**
 * Aggregates the data in the given column and computes a single resulting value.
 * Calls [aggregateSequence].
 *
 * When [column] is a [ValueColumnInternal], its `statistics` cache is consulted first. The cache
 * is keyed by the aggregator's name and then by the aggregator's `statisticsParameters`; on a
 * miss the freshly computed value is stored there before it is returned. Any other column is
 * aggregated directly, without caching.
 *
 * @param column the column whose values are aggregated.
 * @return the aggregated value, either taken from the cache or newly computed.
 * @throws IllegalStateException if [column] is a [ValueColumnInternal] and no aggregator is set.
 */
@Suppress("UNCHECKED_CAST")
public fun aggregateSingleColumn(column: DataColumn<Value?>): Return {
    if (column is ValueColumnInternal<*>) {
        val aggregator = this.aggregator ?: throw IllegalStateException("Aggregator is required")
        // First cache level: one entry per aggregator name, created on demand.
        val resultsByParameters = column.statistics.getOrPut(aggregator.name) {
            mutableMapOf<Map<String, ParameterValue?>, StatisticResult>()
        }
        // Second cache level: the parameters the statistic was computed with.
        // A missing entry (null) means the statistic was never calculated; a cached `null`
        // result is still a hit because it is wrapped in a StatisticResult.
        val cached = resultsByParameters[aggregator.statisticsParameters]
        if (cached != null) {
            return cached.value as Return
        }
        val statistic = aggregateSequence(
            values = column.asSequence(),
            valueType = column.type().toValueType(),
        )
        resultsByParameters[aggregator.statisticsParameters] = StatisticResult(statistic)
        // Return the computed value directly; no second lookup is needed.
        return statistic
    }
    return aggregateSequence(
        values = column.asSequence(),
        valueType = column.type().toValueType(),
    )
}

/**
* Function that can give the return type of [aggregateSequence] as [KType], given the type of the input.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.inputHandler
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.inputHandlers.NumberInputHandler
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.multipleColumnsHandlers.FlatteningMultipleColumnsHandler
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.multipleColumnsHandlers.TwoStepMultipleColumnsHandler
import org.jetbrains.kotlinx.dataframe.impl.columns.ParameterValue
import org.jetbrains.kotlinx.dataframe.math.indexOfMax
import org.jetbrains.kotlinx.dataframe.math.indexOfMedian
import org.jetbrains.kotlinx.dataframe.math.indexOfMin
Expand Down Expand Up @@ -35,10 +36,12 @@ public object Aggregators {
getReturnType: CalculateReturnType,
indexOfResult: IndexOfResult<Value>,
stepOneSelector: Selector<Value, Return>,
statisticsParameters: Map<String, ParameterValue?>,
) = Aggregator(
aggregationHandler = SelectingAggregationHandler(stepOneSelector, indexOfResult, getReturnType),
inputHandler = AnyInputHandler(),
multipleColumnsHandler = TwoStepMultipleColumnsHandler(),
statisticsParameters = statisticsParameters,
)

private fun <Value : Any, Return : Any?> flattenHybridForAny(
Expand Down Expand Up @@ -123,8 +126,9 @@ public object Aggregators {
by withOneOption { skipNaN: Boolean ->
twoStepSelectingForAny<Comparable<Any>, Comparable<Any>?>(
getReturnType = minTypeConversion,
stepOneSelector = { type -> minOrNull(type, skipNaN) },
indexOfResult = { type -> indexOfMin(type, skipNaN) },
stepOneSelector = { type -> minOrNull(type, skipNaN) },
statisticsParameters = mapOf<String, ParameterValue?>(Pair("skipNaN", ParameterValue(skipNaN))),
)
}

Expand All @@ -134,10 +138,13 @@ public object Aggregators {

public val max: AggregatorOptionSwitch1<Boolean, Comparable<Any>, Comparable<Any>?>
    by withOneOption { skipNaN: Boolean ->
        // NOTE(review): per the earlier comment here, this lambda is what the option switch uses
        // as its `getAggregator`, i.e. it builds the aggregator for a given `skipNaN` value.
        // `skipNaN` is forwarded to both selectors and recorded in `statisticsParameters`.
        val parameters: Map<String, ParameterValue?> = mapOf("skipNaN" to ParameterValue(skipNaN))
        twoStepSelectingForAny<Comparable<Any>, Comparable<Any>?>(
            getReturnType = maxTypeConversion,
            indexOfResult = { type -> indexOfMax(type, skipNaN) },
            stepOneSelector = { type -> maxOrNull(type, skipNaN) },
            statisticsParameters = parameters,
        )
    }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,52 @@ import org.jetbrains.kotlinx.dataframe.columns.ValueColumn
import kotlin.reflect.KType
import kotlin.reflect.full.withNullability

/**
 * Wrapper around a computed statistic [value]; it is the value type stored in the
 * `statistics` cache of [ValueColumnInternal].
 *
 * Since the wrapper itself is never `null`, a cached `null` result can be told apart from
 * a statistic that has not been stored yet.
 */
@JvmInline
public value class StatisticResult(public val value: Any?)

/**
 * Wrapper around a single aggregator parameter value (for example a `skipNaN` flag) that
 * gives the value structural equality, so maps of parameter names to [ParameterValue]s can be
 * compared and used as keys.
 *
 * Two instances are equal exactly when their wrapped [parameter]s are equal, and [hashCode]
 * is derived from [parameter] accordingly. A [ParameterValue] is never equal to a raw,
 * unwrapped value, which keeps `equals` symmetric.
 *
 * @property parameter the wrapped parameter value; may be `null`.
 */
public class ParameterValue(public val parameter: Any?) {

    /**
     * Returns `true` when [other] is a [ParameterValue] wrapping an equal [parameter].
     */
    override fun equals(other: Any?): Boolean =
        this === other || (other is ParameterValue && parameter == other.parameter)

    /**
     * Hash code of the wrapped [parameter] (`0` for `null`), consistent with [equals].
     */
    override fun hashCode(): Int = parameter.hashCode()
}

/**
 * A [ValueColumn] that additionally exposes a mutable cache of computed statistics.
 */
internal interface ValueColumnInternal<T> : ValueColumn<T> {
/**
 * Cache of statistics computed for this column.
 *
 * The outer key is a `String` (`aggregateSingleColumn` uses the aggregator's name). The inner
 * key is the map of parameter names to [ParameterValue]s the statistic was computed with, and
 * the inner value is the wrapped result. An earlier draft of this declaration used `Any?` for
 * the parameter values; [ParameterValue] is used instead.
 */
val statistics: MutableMap<String, MutableMap<Map<String, ParameterValue?>, StatisticResult>>
}

internal open class ValueColumnImpl<T>(
values: List<T>,
name: String,
type: KType,
val defaultValue: T? = null,
distinct: Lazy<Set<T>>? = null,
) : DataColumnImpl<T>(values, name, type, distinct),
ValueColumn<T> {
ValueColumn<T>,
ValueColumnInternal<T> {

override fun distinct() = ValueColumnImpl(toSet().toList(), name, type, defaultValue, distinct)

Expand Down Expand Up @@ -48,10 +86,13 @@ internal open class ValueColumnImpl<T>(
override fun defaultValue() = defaultValue

override fun forceResolve() = ResolvingValueColumn(this)

override val statistics = mutableMapOf<String, MutableMap<Map<String, ParameterValue?>, StatisticResult>>()
}

internal class ResolvingValueColumn<T>(override val source: ValueColumn<T>) :
ValueColumn<T> by source,
ValueColumnInternal<T>,
ForceResolvedColumn<T> {

override fun resolve(context: ColumnResolutionContext) = super<ValueColumn>.resolve(context)
Expand All @@ -70,4 +111,6 @@ internal class ResolvingValueColumn<T>(override val source: ValueColumn<T>) :
override fun equals(other: Any?) = source.checkEquals(other)

override fun hashCode(): Int = source.hashCode()

override val statistics = mutableMapOf<String, MutableMap<Map<String, ParameterValue?>, StatisticResult>>()
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class MaxTests {
fun `max with regular values`() {
    val column = columnOf(5, 2, 8, 1, 9)
    // Ask for the maximum twice on the same column; both calls must yield the same value.
    repeat(2) {
        column.max() shouldBe 9
    }
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.inputHandler
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.multipleColumnsHandlers.FlatteningMultipleColumnsHandler
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.multipleColumnsHandlers.NoMultipleColumnsHandler
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.multipleColumnsHandlers.TwoStepMultipleColumnsHandler
import org.jetbrains.kotlinx.dataframe.impl.columns.ParameterValue
import kotlin.reflect.KType
import kotlin.reflect.full.withNullability

Expand Down Expand Up @@ -42,6 +43,7 @@ public class Aggregator<in Value : Any, out Return : Any?>(
public val inputHandler: AggregatorInputHandler<Value, Return>,
public val multipleColumnsHandler: AggregatorMultipleColumnsHandler<Value, Return>,
public val name: String,
public val statisticsParameters: Map<String, ParameterValue?>,
) : AggregatorInputHandler<Value, Return> by inputHandler,
AggregatorMultipleColumnsHandler<Value, Return> by multipleColumnsHandler,
AggregatorAggregationHandler<Value, Return> by aggregationHandler {
Expand Down Expand Up @@ -75,13 +77,30 @@ public class Aggregator<in Value : Any, out Return : Any?>(
aggregationHandler: AggregatorAggregationHandler<Value, Return>,
inputHandler: AggregatorInputHandler<Value, Return>,
multipleColumnsHandler: AggregatorMultipleColumnsHandler<Value, Return>,
statisticsParameters: Map<String, ParameterValue?>,
): AggregatorProvider<Value, Return> =
AggregatorProvider { name ->
Aggregator(
aggregationHandler = aggregationHandler,
inputHandler = inputHandler,
multipleColumnsHandler = multipleColumnsHandler,
name = name,
statisticsParameters = statisticsParameters,
)
}

/**
 * Overload of [invoke] that takes no statistics parameters: every [Aggregator] created by the
 * returned provider receives an empty `statisticsParameters` map.
 */
internal operator fun <Value : Any, Return : Any?> invoke(
    aggregationHandler: AggregatorAggregationHandler<Value, Return>,
    inputHandler: AggregatorInputHandler<Value, Return>,
    multipleColumnsHandler: AggregatorMultipleColumnsHandler<Value, Return>,
): AggregatorProvider<Value, Return> {
    val noParameters = emptyMap<String, ParameterValue?>()
    return AggregatorProvider { aggregatorName ->
        Aggregator(
            aggregationHandler = aggregationHandler,
            inputHandler = inputHandler,
            multipleColumnsHandler = multipleColumnsHandler,
            name = aggregatorName,
            statisticsParameters = noParameters,
        )
    }
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators
import org.jetbrains.kotlinx.dataframe.DataColumn
import org.jetbrains.kotlinx.dataframe.api.asSequence
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregationHandlers.SelectingAggregationHandler
import org.jetbrains.kotlinx.dataframe.impl.columns.ParameterValue
import org.jetbrains.kotlinx.dataframe.impl.columns.StatisticResult
import org.jetbrains.kotlinx.dataframe.impl.columns.ValueColumnInternal
import kotlin.reflect.KType

/**
Expand All @@ -26,13 +29,34 @@ public interface AggregatorAggregationHandler<in Value : Any, out Return : Any?>

/**
 * Aggregates the data in the given column and computes a single resulting value.
 * Calls [aggregateSequence].
 *
 * When [column] is a [ValueColumnInternal], its `statistics` cache is consulted first. The cache
 * is keyed by the aggregator's name and then by the aggregator's `statisticsParameters`; on a
 * miss the freshly computed value is stored there before it is returned. Any other column is
 * aggregated directly, without caching.
 *
 * @param column the column whose values are aggregated.
 * @return the aggregated value, either taken from the cache or newly computed.
 * @throws IllegalStateException if [column] is a [ValueColumnInternal] and no aggregator is set.
 */
@Suppress("UNCHECKED_CAST")
public fun aggregateSingleColumn(column: DataColumn<Value?>): Return {
    if (column is ValueColumnInternal<*>) {
        val aggregator = this.aggregator ?: throw IllegalStateException("Aggregator is required")
        // First cache level: one entry per aggregator name, created on demand.
        val resultsByParameters = column.statistics.getOrPut(aggregator.name) {
            mutableMapOf<Map<String, ParameterValue?>, StatisticResult>()
        }
        // Second cache level: the parameters the statistic was computed with.
        // A missing entry (null) means the statistic was never calculated; a cached `null`
        // result is still a hit because it is wrapped in a StatisticResult.
        val cached = resultsByParameters[aggregator.statisticsParameters]
        if (cached != null) {
            return cached.value as Return
        }
        val statistic = aggregateSequence(
            values = column.asSequence(),
            valueType = column.type().toValueType(),
        )
        resultsByParameters[aggregator.statisticsParameters] = StatisticResult(statistic)
        // Return the computed value directly; no second lookup is needed.
        return statistic
    }
    return aggregateSequence(
        values = column.asSequence(),
        valueType = column.type().toValueType(),
    )
}

/**
* Function that can give the return type of [aggregateSequence] as [KType], given the type of the input.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.inputHandler
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.inputHandlers.NumberInputHandler
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.multipleColumnsHandlers.FlatteningMultipleColumnsHandler
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.multipleColumnsHandlers.TwoStepMultipleColumnsHandler
import org.jetbrains.kotlinx.dataframe.impl.columns.ParameterValue
import org.jetbrains.kotlinx.dataframe.math.indexOfMax
import org.jetbrains.kotlinx.dataframe.math.indexOfMedian
import org.jetbrains.kotlinx.dataframe.math.indexOfMin
Expand Down Expand Up @@ -35,10 +36,12 @@ public object Aggregators {
getReturnType: CalculateReturnType,
indexOfResult: IndexOfResult<Value>,
stepOneSelector: Selector<Value, Return>,
statisticsParameters: Map<String, ParameterValue?>,
) = Aggregator(
aggregationHandler = SelectingAggregationHandler(stepOneSelector, indexOfResult, getReturnType),
inputHandler = AnyInputHandler(),
multipleColumnsHandler = TwoStepMultipleColumnsHandler(),
statisticsParameters = statisticsParameters,
)

private fun <Value : Any, Return : Any?> flattenHybridForAny(
Expand Down Expand Up @@ -117,8 +120,9 @@ public object Aggregators {
by withOneOption { skipNaN: Boolean ->
twoStepSelectingForAny<Comparable<Any>, Comparable<Any>?>(
getReturnType = minTypeConversion,
stepOneSelector = { type -> minOrNull(type, skipNaN) },
indexOfResult = { type -> indexOfMin(type, skipNaN) },
stepOneSelector = { type -> minOrNull(type, skipNaN) },
statisticsParameters = mapOf<String, ParameterValue?>(Pair("skipNaN", ParameterValue(skipNaN))),
)
}

Expand All @@ -132,6 +136,7 @@ public object Aggregators {
getReturnType = maxTypeConversion,
stepOneSelector = { type -> maxOrNull(type, skipNaN) },
indexOfResult = { type -> indexOfMax(type, skipNaN) },
statisticsParameters = mapOf<String, ParameterValue?>(Pair("skipNaN", ParameterValue(skipNaN))),
)
}

Expand Down
Loading
Loading