From 044b973270cd214e757fc6fb0460c6c06d2ef13e Mon Sep 17 00:00:00 2001 From: Ruslan Krenzler Date: Wed, 6 Aug 2025 09:41:34 +0200 Subject: [PATCH] Fix Issue#83, limit number of segments when splitting high-dimensional key space Reduce the number of segments returned by choose_checkpoints. Note, even if each key dimension is split only once, a high number of key dimensions can still result in an excessive number of segments. This significantly lowers the performance of reladiff. For example, 20 key columns may lead to 1.048.576 segments. --- reladiff/table_segment.py | 36 +++++++++++++++++++++++++++++------- tests/test_diff_tables.py | 38 +++++++++++++++++++++++++++++++++++++- 2 files changed, 66 insertions(+), 8 deletions(-) diff --git a/reladiff/table_segment.py b/reladiff/table_segment.py index f64247c..73ce2a2 100644 --- a/reladiff/table_segment.py +++ b/reladiff/table_segment.py @@ -45,9 +45,25 @@ def int_product(nums: List[int]) -> int: return p -def split_compound_key_space(mn: Vector, mx: Vector, count: int) -> List[List[DbKey]]: - """Returns a list of split-points for each key dimension, essentially returning an N-dimensional grid of split points.""" - return [split_key_space(mn_k, mx_k, count) for mn_k, mx_k in safezip(mn, mx)] +def split_compound_key_space(mn: Vector, mx: Vector, max_splits_per_column: int, max_segments: int) -> List[List[DbKey]]: + """Returns a list of split-points for each key dimension, essentially returning an N-dimensional grid of split points. + + Do not split the space into more than `max_segments` segments. + """ + grids = [] + segments_left = max_segments + for mn_k, mx_k in safezip(mn, mx): + if segments_left > 1: + # Note, n splits correspond to n+1 segments. + grid = split_key_space(mn_k, mx_k, min(max_splits_per_column, segments_left-1)) + grids.append(grid) + # The total number of segments in a multidimensional grid is + # the product of the number of segments in each dimension. 
+ segments_left //= (len(grid) - 1) + else: + # Stop splitting, return left and right bounds. + grids.append([mn_k, mx_k]) + return grids def create_mesh_from_points(*values_per_dim: list) -> List[Tuple[Vector, Vector]]: @@ -167,7 +183,7 @@ def with_schema(self, refine: bool = True, allow_empty_table: bool = False) -> " return self._with_raw_schema( self.database.query_table_schema(self.table_path), refine=refine, allow_empty_table=allow_empty_table ) - + def _cast_col_value(self, col, value): """Cast the value to the right type, based on the type of the column @@ -209,15 +225,21 @@ def get_values(self) -> list: select = self.make_select().select(*self._relevant_columns_repr) return self.database.query(select, List[Tuple]) + def choose_checkpoints(self, count: int) -> List[List[DbKey]]: "Suggests a bunch of evenly-spaced checkpoints to split by, including start, end." assert self.is_bounded - # Take Nth root of count, to approximate the appropriate box size - count = int(count ** (1 / len(self.key_columns))) or 1 + # Take Nth root of count, to approximate the appropriate box size. + # This is a rough estimation of splits per dimension, because + # * If the min_key and max_key are too close, no splits are needed. + # * If there are too many key-columns, then even a single split per column results in + # a very large number of segments, larger than `count`. + # For example, 20 key-columns can lead to 2**20 = 1 048 576 segments. 
+ max_split_per_column = int(count ** (1 / len(self.key_columns))) or 1 - return split_compound_key_space(self.min_key, self.max_key, count) + return split_compound_key_space(self.min_key, self.max_key, max_split_per_column, count+1) def segment_by_checkpoints(self, checkpoints: List[List[DbKey]]) -> List["TableSegment"]: "Split the current TableSegment to a bunch of smaller ones, separated by the given checkpoints" diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index d668abd..70f2744 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -1,4 +1,5 @@ from datetime import datetime, timedelta +import random from typing import Callable import uuid import unittest @@ -8,7 +9,7 @@ from reladiff.hashdiff_tables import HashDiffer from reladiff.joindiff_tables import JoinDiffer -from reladiff.table_segment import TableSegment, split_space, Vector +from reladiff.table_segment import TableSegment, split_space, Vector, split_compound_key_space from reladiff import databases as db from .common import str_to_checksum, test_each_database_in_list, DiffTestCase, table_segment @@ -37,6 +38,41 @@ def test_split_space(self): r = split_space(i, j + i + n, n) assert len(r) == n, f"split_space({i}, {j+n}, {n}) = {(r)}" + def test_split_compound_key_space(self): + # Test that the total number of segments does not exceed the maximum. + random.seed(12345) + # Test 1 to 30 key dimensions. + for n in range(1, 30+1): + # Do 10 random tests. + for _ in range(10): + keys_a = random.choices([1,2,3,4], k=n) + keys_b = random.choices([1,2,3,4], k=n) + split_per_dim = random.randint(1, 32) + max_segments = random.randint(2, 32) + min_keys = [min(a, b) for a, b in zip(keys_a, keys_b)] + # Ensure that the max key is always larger than the min key. 
+ max_keys = [max(a, b)+1 for a, b in zip(keys_a, keys_b)] + grid = split_compound_key_space(Vector(min_keys), Vector(max_keys), split_per_dim, max_segments) + segments = 1 + for dim in grid: + segments *= len(dim)-1 # n points correspond to n-1 segments. + self.assertLessEqual(segments, max_segments) + + # Calculate the maximum number of segments possible and ensure that it can be achieved if requested. + for n in range(1, 30+1): + # Do 10 random tests. + for _ in range(10): + splits_per_dim = random.randint(1, 3) + max_segments = (splits_per_dim+1)**n + # Ensure that `splits_per_dim` splits per dimension are possible. + min_keys = [1]*n + max_keys = [2+splits_per_dim]*n # n splits result in n+1 segments described by n+2 points. + grid = split_compound_key_space(Vector(min_keys), Vector(max_keys), splits_per_dim, max_segments) + segments = 1 + for dim in grid: + segments *= len(dim)-1 # n points correspond to n-1 segments. + self.assertEqual(segments, max_segments) + @test_each_database class TestDates(DiffTestCase):