diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt index a7cb2a8b14..a8cf1922ef 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt @@ -8,6 +8,7 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.inputHandler import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.multipleColumnsHandlers.FlatteningMultipleColumnsHandler import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.multipleColumnsHandlers.NoMultipleColumnsHandler import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.multipleColumnsHandlers.TwoStepMultipleColumnsHandler +import org.jetbrains.kotlinx.dataframe.impl.columns.ParameterValue import kotlin.reflect.KType import kotlin.reflect.full.withNullability @@ -42,6 +43,7 @@ public class Aggregator( public val inputHandler: AggregatorInputHandler, public val multipleColumnsHandler: AggregatorMultipleColumnsHandler, public val name: String, + public val statisticsParameters: Map, ) : AggregatorInputHandler by inputHandler, AggregatorMultipleColumnsHandler by multipleColumnsHandler, AggregatorAggregationHandler by aggregationHandler { @@ -75,6 +77,7 @@ public class Aggregator( aggregationHandler: AggregatorAggregationHandler, inputHandler: AggregatorInputHandler, multipleColumnsHandler: AggregatorMultipleColumnsHandler, + statisticsParameters: Map, ): AggregatorProvider = AggregatorProvider { name -> Aggregator( @@ -82,6 +85,22 @@ public class Aggregator( inputHandler = inputHandler, multipleColumnsHandler = multipleColumnsHandler, name = name, + statisticsParameters = statisticsParameters, + ) + } + + internal operator fun invoke( + aggregationHandler: AggregatorAggregationHandler, + inputHandler: AggregatorInputHandler, + multipleColumnsHandler: AggregatorMultipleColumnsHandler, + ): AggregatorProvider = + AggregatorProvider { name -> + Aggregator( + aggregationHandler = aggregationHandler, + inputHandler = inputHandler, + multipleColumnsHandler = multipleColumnsHandler, + name = name, + emptyMap(), ) } } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorAggregationHandler.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorAggregationHandler.kt index 7b1b0357eb..a7f65152c3 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorAggregationHandler.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorAggregationHandler.kt @@ -3,6 +3,9 @@ package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators import org.jetbrains.kotlinx.dataframe.DataColumn import org.jetbrains.kotlinx.dataframe.api.asSequence import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregationHandlers.SelectingAggregationHandler +import org.jetbrains.kotlinx.dataframe.impl.columns.ParameterValue +import org.jetbrains.kotlinx.dataframe.impl.columns.StatisticResult +import org.jetbrains.kotlinx.dataframe.impl.columns.ValueColumnInternal import kotlin.reflect.KType /** @@ -26,13 +29,34 @@ public interface AggregatorAggregationHandler /** * Aggregates the data in the given column and computes a single resulting value. - * Calls [aggregateSequence]. + * Calls [aggregateSequence]. It tries to exploit a cache for statistics which is proper of + * [ValueColumnInternal] */ - public fun aggregateSingleColumn(column: DataColumn): Return = - aggregateSequence( + public fun aggregateSingleColumn(column: DataColumn): Return { + if (column is ValueColumnInternal<*>) { + // cache check, cache is dynamically created + val aggregator = this.aggregator ?: throw IllegalStateException("Aggregator is required") + val desiredStatisticNotConsideringParameters = column.statistics.getOrPut(aggregator.name) { + mutableMapOf, StatisticResult>() + } + // can't compare maps whose Values are Any? -> ParameterValue instead + val desiredStatistic = desiredStatisticNotConsideringParameters[aggregator.statisticsParameters] + // if desiredStatistic is null, statistic was never calculated + if (desiredStatistic != null) { + return desiredStatistic.value as Return + } + val statistic = aggregateSequence( + values = column.asSequence(), + valueType = column.type().toValueType(), + ) + desiredStatisticNotConsideringParameters[aggregator.statisticsParameters] = StatisticResult(statistic) + return aggregateSingleColumn(column) + } + return aggregateSequence( values = column.asSequence(), valueType = column.type().toValueType(), ) + } /** * Function that can give the return type of [aggregateSequence] as [KType], given the type of the input. diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt index 9648fed3ad..25660dd3d1 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt @@ -8,6 +8,7 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.inputHandler import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.inputHandlers.NumberInputHandler import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.multipleColumnsHandlers.FlatteningMultipleColumnsHandler import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.multipleColumnsHandlers.TwoStepMultipleColumnsHandler +import org.jetbrains.kotlinx.dataframe.impl.columns.ParameterValue import org.jetbrains.kotlinx.dataframe.math.indexOfMax import org.jetbrains.kotlinx.dataframe.math.indexOfMedian import org.jetbrains.kotlinx.dataframe.math.indexOfMin @@ -35,10 +36,12 @@ public object Aggregators { getReturnType: CalculateReturnType, indexOfResult: IndexOfResult, stepOneSelector: Selector, + statisticsParameters: Map, ) = Aggregator( aggregationHandler = SelectingAggregationHandler(stepOneSelector, indexOfResult, getReturnType), inputHandler = AnyInputHandler(), multipleColumnsHandler = TwoStepMultipleColumnsHandler(), + statisticsParameters = statisticsParameters, ) private fun flattenHybridForAny( @@ -117,8 +120,9 @@ public object Aggregators { by withOneOption { skipNaN: Boolean -> twoStepSelectingForAny, Comparable?>( getReturnType = minTypeConversion, - stepOneSelector = { type -> minOrNull(type, skipNaN) }, indexOfResult = { type -> indexOfMin(type, skipNaN) }, + stepOneSelector = { type -> minOrNull(type, skipNaN) }, + statisticsParameters = mapOf(Pair("skipNaN", ParameterValue(skipNaN))), ) } @@ -132,6 +136,7 @@ public object Aggregators { getReturnType = maxTypeConversion, stepOneSelector = { type -> maxOrNull(type, skipNaN) }, indexOfResult = { type -> indexOfMax(type, skipNaN) }, + statisticsParameters = mapOf(Pair("skipNaN", ParameterValue(skipNaN))), ) } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/columns/ValueColumnImpl.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/columns/ValueColumnImpl.kt index f758360d1f..071801a5d8 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/columns/ValueColumnImpl.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/columns/ValueColumnImpl.kt @@ -8,6 +8,32 @@ import org.jetbrains.kotlinx.dataframe.columns.ValueColumn import kotlin.reflect.KType import kotlin.reflect.full.withNullability +@JvmInline +internal value class StatisticResult(val value: Any?) + +public class ParameterValue(public val parameter: Any?) { + + override fun equals(other: Any?): Boolean { + val otherAsParameterValue = other as ParameterValue? + val that = otherAsParameterValue?.parameter + if (parameter is Boolean && that is Boolean) { + return this.parameter == that + } + return super.equals(other) + } + + override fun hashCode(): Int { + if (parameter is Boolean?) { + return this.parameter.hashCode() + } + return super.hashCode() + } +} + +internal interface ValueColumnInternal : ValueColumn { + val statistics: MutableMap, StatisticResult>> +} + internal open class ValueColumnImpl( values: List, name: String, @@ -15,7 +41,8 @@ internal open class ValueColumnImpl( val defaultValue: T? = null, distinct: Lazy>? = null, ) : DataColumnImpl(values, name, type, distinct), - ValueColumn { + ValueColumn, + ValueColumnInternal { override fun distinct() = ValueColumnImpl(toSet().toList(), name, type, defaultValue, distinct) @@ -48,10 +75,13 @@ internal open class ValueColumnImpl( override fun defaultValue() = defaultValue override fun forceResolve() = ResolvingValueColumn(this) + + override val statistics = mutableMapOf, StatisticResult>>() } internal class ResolvingValueColumn(override val source: ValueColumn) : ValueColumn by source, + ValueColumnInternal, ForceResolvedColumn { override fun resolve(context: ColumnResolutionContext) = super.resolve(context) @@ -70,4 +100,6 @@ internal class ResolvingValueColumn(override val source: ValueColumn) : override fun equals(other: Any?) = source.checkEquals(other) override fun hashCode(): Int = source.hashCode() + + override val statistics = mutableMapOf, StatisticResult>>() }