diff --git a/reladiff/table_segment.py b/reladiff/table_segment.py index f64247c..73ce2a2 100644 --- a/reladiff/table_segment.py +++ b/reladiff/table_segment.py @@ -45,9 +45,25 @@ def int_product(nums: List[int]) -> int: return p -def split_compound_key_space(mn: Vector, mx: Vector, count: int) -> List[List[DbKey]]: - """Returns a list of split-points for each key dimension, essentially returning an N-dimensional grid of split points.""" - return [split_key_space(mn_k, mx_k, count) for mn_k, mx_k in safezip(mn, mx)] +def split_compound_key_space(mn: Vector, mx: Vector, max_splits_per_column: int, max_segments: int) -> List[List[DbKey]]: + """Returns a list of split-points for each key dimension, essentially returning an N-dimensional grid of split points. + + Do not split the space into more than `max_segments` segments. + """ + grids = [] + segments_left = max_segments + for mn_k, mx_k in safezip(mn, mx): + if segments_left > 1: + # Note, n splits correspond to n+1 segments. + grid = split_key_space(mn_k, mx_k, min(max_splits_per_column, segments_left-1)) + grids.append(grid) + # The total number of segments in a multidimensional grid is + # the product of the number of segments in each dimension. + segments_left //= (len(grid) - 1) + else: + # Stop splitting, return left and right bounds. 
+ grids.append([mn_k, mx_k]) + return grids def create_mesh_from_points(*values_per_dim: list) -> List[Tuple[Vector, Vector]]: @@ -167,7 +183,7 @@ def with_schema(self, refine: bool = True, allow_empty_table: bool = False) -> " return self._with_raw_schema( self.database.query_table_schema(self.table_path), refine=refine, allow_empty_table=allow_empty_table ) - + def _cast_col_value(self, col, value): """Cast the value to the right type, based on the type of the column @@ -209,15 +225,21 @@ def get_values(self) -> list: select = self.make_select().select(*self._relevant_columns_repr) return self.database.query(select, List[Tuple]) + def choose_checkpoints(self, count: int) -> List[List[DbKey]]: "Suggests a bunch of evenly-spaced checkpoints to split by, including start, end." assert self.is_bounded - # Take Nth root of count, to approximate the appropriate box size - count = int(count ** (1 / len(self.key_columns))) or 1 + # Take Nth root of count, to approximate the appropriate box size. + # This is a rough estimation of splits per dimension, because + # * If the min_key and max_key are too close no splits are needed. + # * If there are too many key-columns then even a single split per column results in + # a very large number of segments, larger than `count`. + # For example 20 key-columns can lead to 2**20 = 1 048 576 segments. 
+ max_split_per_column = int(count ** (1 / len(self.key_columns))) or 1 - return split_compound_key_space(self.min_key, self.max_key, count) + return split_compound_key_space(self.min_key, self.max_key, max_split_per_column, count+1) def segment_by_checkpoints(self, checkpoints: List[List[DbKey]]) -> List["TableSegment"]: "Split the current TableSegment to a bunch of smaller ones, separated by the given checkpoints" diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index d668abd..70f2744 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -1,4 +1,5 @@ from datetime import datetime, timedelta +import random from typing import Callable import uuid import unittest @@ -8,7 +9,7 @@ from reladiff.hashdiff_tables import HashDiffer from reladiff.joindiff_tables import JoinDiffer -from reladiff.table_segment import TableSegment, split_space, Vector +from reladiff.table_segment import TableSegment, split_space, Vector, split_compound_key_space from reladiff import databases as db from .common import str_to_checksum, test_each_database_in_list, DiffTestCase, table_segment @@ -37,6 +38,41 @@ def test_split_space(self): r = split_space(i, j + i + n, n) assert len(r) == n, f"split_space({i}, {j+n}, {n}) = {(r)}" + def test_split_compound_key_space(self): + # Test that the total number of segments does not exceed max number. + random.seed(12345) + # Test 1 to 30 key dimensions. + for n in range(1, 30+1): + # Do 10 random tests. + for _ in range(10): + keys_a = random.choices([1,2,3,4], k=n) + keys_b = random.choices([1,2,3,4], k=n) + split_per_dim = random.randint(1, 32) + max_segments = random.randint(2, 32) + min_keys = [min(a, b) for a, b in zip(keys_a, keys_b)] + # Ensure that the max key is always larger than the min key. 
+ max_keys = [max(a, b)+1 for a, b in zip(keys_a, keys_b)] + grid = split_compound_key_space(Vector(min_keys), Vector(max_keys), split_per_dim, max_segments) + segments = 1 + for dim in grid: + segments *= len(dim)-1 # n points correspond to n-1 segments. + self.assertLessEqual(segments, max_segments) + + # Calculate maximum number of splits possible and ensure that they can be achieved if requested. + for n in range(1, 30+1): + # Do 10 random tests. + for _ in range(10): + splits_per_dim = random.randint(1, 3) + max_segments = (splits_per_dim+1)**n + # Ensure that `splits_per_dim` splits (i.e. `splits_per_dim`+1 segments) are possible. + min_keys = [1]*n + max_keys = [2+splits_per_dim]*n # n splits result in n+1 segments described by n+2 points. + grid = split_compound_key_space(Vector(min_keys), Vector(max_keys), splits_per_dim, max_segments) + segments = 1 + for dim in grid: + segments *= len(dim)-1 # n points correspond to n-1 segments. + self.assertEqual(segments, max_segments) + @test_each_database class TestDates(DiffTestCase):