Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions reladiff/table_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,25 @@ def int_product(nums: List[int]) -> int:
return p


def split_compound_key_space(
    mn: Vector, mx: Vector, max_splits_per_column: int, max_segments: int
) -> List[List[DbKey]]:
    """Return the split points of every key dimension, i.e. an N-dimensional grid.

    Dimensions are processed in order. Each one is split via ``split_key_space``
    for as long as more than one segment of the budget is left; after that, the
    remaining dimensions are returned as their bare ``[min, max]`` bounds.
    The space is never split into more than ``max_segments`` segments.

    Args:
        mn: Lower bound of the key space, one component per key column.
        mx: Upper bound of the key space, one component per key column.
        max_splits_per_column: Upper limit on the number of splits requested
            for a single dimension.
        max_segments: Upper limit on the total number of segments in the grid.

    Returns:
        One list of points per dimension, in the order of ``mn`` / ``mx``.
    """
    grids = []
    segments_left = max_segments
    for mn_k, mx_k in safezip(mn, mx):
        if segments_left > 1:
            # k splits produce k + 1 segments, so at most `segments_left - 1`
            # splits may be requested for this dimension.
            grid = split_key_space(mn_k, mx_k, min(max_splits_per_column, segments_left - 1))
            grids.append(grid)
            # The total number of segments in a multidimensional grid is the
            # product of the segment counts of each dimension (p points delimit
            # p - 1 segments), so divide the remaining budget accordingly.
            segments_left //= len(grid) - 1
        else:
            # Budget exhausted: stop splitting, return left and right bounds only.
            grids.append([mn_k, mx_k])
    return grids


def create_mesh_from_points(*values_per_dim: list) -> List[Tuple[Vector, Vector]]:
Expand Down Expand Up @@ -167,7 +183,7 @@ def with_schema(self, refine: bool = True, allow_empty_table: bool = False) -> "
return self._with_raw_schema(
self.database.query_table_schema(self.table_path), refine=refine, allow_empty_table=allow_empty_table
)

def _cast_col_value(self, col, value):
"""Cast the value to the right type, based on the type of the column

Expand Down Expand Up @@ -209,15 +225,21 @@ def get_values(self) -> list:
select = self.make_select().select(*self._relevant_columns_repr)
return self.database.query(select, List[Tuple])


def choose_checkpoints(self, count: int) -> List[List[DbKey]]:
    """Suggest a bunch of evenly-spaced checkpoints to split by, including start and end.

    Args:
        count: Requested number of checkpoints; the resulting grid is limited
            to ``count + 1`` segments in total.

    Returns:
        The per-key-column lists of points produced by ``split_compound_key_space``.
    """
    assert self.is_bounded

    # Take the Nth root of count to approximate the appropriate box size.
    # This is only a rough estimate of the splits per dimension, because:
    #   * If min_key and max_key are too close, no splits are needed.
    #   * With many key columns, even a single split per column results in a
    #     huge number of segments, larger than `count`. For example, 20 key
    #     columns can lead to 2**20 = 1 048 576 segments.
    # `or 1` keeps at least one split per column when the root truncates to 0.
    max_splits_per_column = int(count ** (1 / len(self.key_columns))) or 1

    # `count` checkpoints correspond to `count + 1` segments.
    return split_compound_key_space(self.min_key, self.max_key, max_splits_per_column, count + 1)

def segment_by_checkpoints(self, checkpoints: List[List[DbKey]]) -> List["TableSegment"]:
"Split the current TableSegment to a bunch of smaller ones, separated by the given checkpoints"
Expand Down
38 changes: 37 additions & 1 deletion tests/test_diff_tables.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime, timedelta
import random
from typing import Callable
import uuid
import unittest
Expand All @@ -8,7 +9,7 @@

from reladiff.hashdiff_tables import HashDiffer
from reladiff.joindiff_tables import JoinDiffer
from reladiff.table_segment import TableSegment, split_space, Vector
from reladiff.table_segment import TableSegment, split_space, Vector, split_compound_key_space
from reladiff import databases as db

from .common import str_to_checksum, test_each_database_in_list, DiffTestCase, table_segment
Expand Down Expand Up @@ -37,6 +38,41 @@ def test_split_space(self):
r = split_space(i, j + i + n, n)
assert len(r) == n, f"split_space({i}, {j+n}, {n}) = {(r)}"

def test_split_compound_key_space(self):
    # Checks two properties of split_compound_key_space:
    #   1. the grid never holds more segments than the requested maximum;
    #   2. the maximum is actually reached when the key range allows it.

    def total_segments(points_per_dim):
        # p points along one dimension delimit p - 1 segments; the grid as a
        # whole holds the product of the per-dimension segment counts.
        total = 1
        for points in points_per_dim:
            total *= len(points) - 1
        return total

    random.seed(12345)
    key_pool = [1, 2, 3, 4]

    # Part 1: the total number of segments does not exceed the maximum.
    # Covers 1 to 30 key dimensions, 10 random cases each.
    for n in range(1, 31):
        for _ in range(10):
            keys_a = random.choices(key_pool, k=n)
            keys_b = random.choices(key_pool, k=n)
            split_per_dim = random.randint(1, 32)
            max_segments = random.randint(2, 32)
            pairs = list(zip(keys_a, keys_b))
            min_keys = [min(pair) for pair in pairs]
            # The +1 ensures every max key is strictly larger than its min key.
            max_keys = [max(pair) + 1 for pair in pairs]
            grid = split_compound_key_space(Vector(min_keys), Vector(max_keys), split_per_dim, max_segments)
            self.assertLessEqual(total_segments(grid), max_segments)

    # Part 2: compute the maximum number of segments possible and ensure
    # that it is achieved when requested.
    for n in range(1, 31):
        for _ in range(10):
            splits_per_dim = random.randint(1, 3)
            max_segments = (splits_per_dim + 1) ** n
            # Make the key range wide enough for `splits_per_dim` splits:
            # s splits result in s + 1 segments, described by s + 2 points.
            min_keys = [1] * n
            max_keys = [2 + splits_per_dim] * n
            grid = split_compound_key_space(Vector(min_keys), Vector(max_keys), splits_per_dim, max_segments)
            self.assertEqual(total_segments(grid), max_segments)


@test_each_database
class TestDates(DiffTestCase):
Expand Down