diff --git a/CHANGES.md b/CHANGES.md index afaf6a896cc4..e2dcf6e0f2ca 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -69,7 +69,6 @@ ## New Features / Improvements * (Python) Added exception chaining to preserve error context in CloudSQLEnrichmentHandler, processes utilities, and core transforms ([#37422](https://github.com/apache/beam/issues/37422)). -* (Python) Added `take(n)` convenience for PCollection: `beam.take(n)` and `pcoll.take(n)` to get the first N elements deterministically without Top.Of + FlatMap ([#X](https://github.com/apache/beam/issues/37429)). * X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). ## Breaking Changes diff --git a/sdks/python/apache_beam/pvalue.py b/sdks/python/apache_beam/pvalue.py index 6621d96127d4..ca9a662d399e 100644 --- a/sdks/python/apache_beam/pvalue.py +++ b/sdks/python/apache_beam/pvalue.py @@ -176,25 +176,6 @@ def from_(pcoll: PValue, is_bounded: Optional[bool] = None) -> 'PCollection': is_bounded = pcoll.is_bounded return PCollection(pcoll.pipeline, is_bounded=is_bounded) - def take(self, n: int) -> 'PCollection[T]': - """Takes the first N elements from this PCollection. - - This is a convenience method that returns a new PCollection containing - at most N elements from this PCollection. The elements are taken - deterministically (not randomly sampled). - - Args: - n: Number of elements to take. Must be a positive integer. - - Returns: - A new PCollection containing at most N elements. - - Example:: - first_10 = pcoll.take(10) - """ - from apache_beam.transforms import util - return self | util.take(n) - def to_runner_api( self, context: 'PipelineContext') -> beam_runner_api_pb2.PCollection: return beam_runner_api_pb2.PCollection( diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index dd14bd8f57bd..fbaab6b4ebbb 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -54,7 +54,6 @@ from apache_beam.pvalue import PCollection from apache_beam.transforms import window from apache_beam.transforms.combiners import CountCombineFn -from apache_beam.transforms.combiners import Top from apache_beam.transforms.core import CombinePerKey from apache_beam.transforms.core import Create from apache_beam.transforms.core import DoFn @@ -106,13 +105,11 @@ 'Reshuffle', 'Secret', 'ToString', - 'Take', 'Tee', 'Values', 'WithKeys', 'GroupIntoBatches', - 'WaitOn', - 'take', + 'WaitOn' ] K = TypeVar('K') @@ -1970,75 +1967,6 @@ def expand(self, input): )) -@typehints.with_input_types(T) -@typehints.with_output_types(T) -class Take(PTransform): - """Takes the first N elements from a PCollection. - - This transform returns a PCollection containing at most N elements from the - input PCollection. The elements are taken deterministically (not randomly - sampled). - - Args: - n: Number of elements to take. Must be a positive integer. - - Returns: - A PCollection containing at most N elements. - - Example:: - # Take first 10 elements - first_10 = pcoll | beam.take(10) - - # Or as a method - first_10 = pcoll.take(10) - """ - def __init__(self, n): - """Initializes Take transform. - - Args: - n: Number of elements to take. Must be positive. - """ - if n <= 0: - raise ValueError('n must be positive, got %d' % n) - self._n = n - - def expand(self, pcoll): - """Expands the Take transform. - - Args: - pcoll: Input PCollection. - - Returns: - A PCollection containing at most N elements. - """ - # Use Top.Of with a constant key to get first N elements deterministically. - # Top.Of returns a list, so we flatten it to get individual elements. - return ( - pcoll - | Top.Of(self._n, key=lambda x: 0).without_defaults() - | FlatMap(lambda elements: elements)) - - def default_label(self): - return 'Take(%d)' % self._n - - -def take(n): - """Convenience function for Take transform. - - Takes the first N elements from a PCollection. - - Args: - n: Number of elements to take. Must be positive. - - Returns: - A Take transform instance. - - Example:: - first_10 = pcoll | beam.take(10) - """ - return Take(n) - - class Reify(object): """PTransforms for converting between explicit and implicit form of various Beam values.""" diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py index 448ba8a7ad9d..7389568691cd 100644 --- a/sdks/python/apache_beam/transforms/util_test.py +++ b/sdks/python/apache_beam/transforms/util_test.py @@ -1934,45 +1934,6 @@ def test_tostring_kvs_empty_delimeter(self): assert_that(result, equal_to(["one1", "two2"])) -class TakeTest(unittest.TestCase): - def test_take_function_syntax(self): - with TestPipeline() as p: - result = p | beam.Create([1, 2, 3, 4, 5]) | util.take(3) - assert_that(result, equal_to([1, 2, 3])) - - def test_take_method_syntax(self): - with TestPipeline() as p: - pcoll = p | beam.Create([10, 20, 30, 40, 50]) - result = pcoll.take(2) - assert_that(result, equal_to([10, 20])) - - def test_take_more_than_available(self): - with TestPipeline() as p: - result = p | beam.Create([1, 2, 3]) | util.take(10) - assert_that(result, equal_to([1, 2, 3])) - - def test_take_single_element(self): - with TestPipeline() as p: - result = p | beam.Create([100, 200, 300]) | util.take(1) - assert_that(result, equal_to([100])) - - def test_take_all_elements(self): - with TestPipeline() as p: - data = [1, 2, 3, 4, 5] - result = p | beam.Create(data) | util.take(len(data)) - assert_that(result, equal_to(data)) - - def test_take_invalid_n_zero(self): - with self.assertRaises(ValueError) as ctx: - util.Take(0) - self.assertIn('n must be positive', str(ctx.exception)) - - def test_take_invalid_n_negative(self): - with self.assertRaises(ValueError) as ctx: - util.Take(-1) - self.assertIn('n must be positive', str(ctx.exception)) - - class LogElementsTest(unittest.TestCase): @pytest.fixture(scope="function") def _capture_stdout_log(request, capsys):