diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index d0e326b843..beb638d21d 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -2226,6 +2226,7 @@ def forward( # noqa: C901 op_id=self.uuid, per_sample_weights=per_sample_weights, batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, + feature_table_map=self.feature_table_map, ) if not is_torchdynamo_compiling(): diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config.py index 8298d5f68c..bb6851f30c 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config.py @@ -52,6 +52,8 @@ class TBEDataConfig: Ds: Optional[list[int]] = None # Maximum number of indices max_indices: Optional[int] = None # Maximum number of indices + # Map from feature index to table index [T] + feature_table_map: Optional[list[int]] = None def __post_init__(self) -> None: if isinstance(self.D, list): diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py index 7d145c3136..d8d6d08cf3 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py @@ -142,6 +142,8 @@ class PoolingParams: sigma_L: Optional[int] = None # [Optional] Distribution of embedding sequence lengths (normal, uniform) length_distribution: Optional[str] = "normal" + # [Optional] List of pooling factors per table (average bag size per table) + Ls: Optional[list[float]] = None @classmethod # pyre-ignore [3] diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/stats/bench_params_reporter.py b/fbgemm_gpu/fbgemm_gpu/tbe/stats/bench_params_reporter.py index 13680e2530..f00dc7e233 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/stats/bench_params_reporter.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/stats/bench_params_reporter.py @@ -145,6 +145,7 @@ def extract_params( offsets: torch.Tensor, per_sample_weights: Optional[torch.Tensor] = None, batch_size_per_feature_per_rank: Optional[list[list[int]]] = None, + feature_table_map: Optional[list[int]] = None, ) -> TBEDataConfig: """ Extracts parameters from the embedding operation, input indices, and offsets to create a TBEDataConfig. @@ -156,6 +157,7 @@ def extract_params( offsets (torch.Tensor): The input offsets tensor. per_sample_weights (Optional[torch.Tensor], optional): Weights for each sample. Defaults to None. batch_size_per_feature_per_rank (Optional[List[List[int]]], optional): Batch sizes per feature per rank. Defaults to None. + feature_table_map (Optional[List[int]], optional): Map from feature index to table index. Defaults to None. Returns: TBEDataConfig: The configuration data for TBE benchmarking. @@ -200,6 +202,11 @@ def extract_params( heavy_hitters, q, s, indices.dtype, offsets.dtype ) + # Compute batch sizes per feature (Bs) + Bs: Optional[list[int]] = None + if batch_size_per_feature_per_rank: + Bs = [sum(batch_size_per_feature_per_rank[f]) for f in range(len(Es))] + # Compute batch parameters batch_params = BatchParams( B=int((offsets.numel() - 1) // T), @@ -226,11 +233,33 @@ def extract_params( if batch_size_per_feature_per_rank else None ), + Bs=Bs, ) # Compute pooling parameters bag_sizes = offsets[1:] - offsets[:-1] - mixed_bag_sizes = len(set(bag_sizes)) > 1 + bag_sizes_list = bag_sizes.tolist() + mixed_bag_sizes = len(set(bag_sizes_list)) > 1 + + # Compute per-table pooling factors (Ls) + Ls: list[float] = [] + pointer_counter = 0 + if batch_size_per_feature_per_rank and Bs: + for batch_size in Bs: + current_L = 0 + for _ in range(batch_size): + current_L += bag_sizes_list[pointer_counter] + pointer_counter += 1 + Ls.append(current_L / batch_size if batch_size > 0 else 0.0) + else: + batch_size = int(len(bag_sizes_list) // len(Es)) + for _ in range(len(Es)): + current_L = 0 + for _ in range(batch_size): + current_L += bag_sizes_list[pointer_counter] + pointer_counter += 1 + Ls.append(current_L / batch_size if batch_size > 0 else 0.0) + pooling_params = PoolingParams( L=( int(torch.ceil(torch.mean(bag_sizes.float()))) @@ -243,18 +272,22 @@ def extract_params( else None ), length_distribution=("normal" if mixed_bag_sizes else None), + Ls=Ls, ) return TBEDataConfig( T=T, E=E, D=D, + Es=Es, + Ds=Ds, mixed_dim=mixed_dim, weighted=(per_sample_weights is not None), batch_params=batch_params, indices_params=indices_params, pooling_params=pooling_params, use_cpu=(not torch.cuda.is_available()), + feature_table_map=feature_table_map, ) def report_stats( @@ -267,6 +300,7 @@ def report_stats( op_id: str = "", per_sample_weights: Optional[torch.Tensor] = None, batch_size_per_feature_per_rank: Optional[list[list[int]]] = None, + feature_table_map: Optional[list[int]] = None, ) -> None: """ Reports the configuration of the embedding operation and input data, then writes the TBE configuration to the filestore. @@ -280,6 +314,7 @@ def report_stats( op_id (str, optional): The operation identifier. Defaults to an empty string. per_sample_weights (Optional[torch.Tensor], optional): Weights for each sample. Defaults to None. batch_size_per_feature_per_rank (Optional[List[List[int]]], optional): Batch sizes per feature per rank. Defaults to None. + feature_table_map (Optional[List[int]], optional): Map from feature index to table index. Defaults to None. """ if ( (iteration - self.report_iter_start) % self.report_interval == 0 @@ -299,41 +334,11 @@ def report_stats( offsets=offsets, per_sample_weights=per_sample_weights, batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, + feature_table_map=feature_table_map, ) - # Ad-hoc fix for adding Es and Ds to JSON output - # TODO: Remove this once we moved Es and Ds to be part of TBEDataConfig - adhoc_config = config.dict() - adhoc_config["Es"] = feature_rows.tolist() - adhoc_config["Ds"] = feature_dims.tolist() - if batch_size_per_feature_per_rank: - adhoc_config["Bs"] = [ - sum(batch_size_per_feature_per_rank[f]) - for f in range(len(adhoc_config["Es"])) - ] - - bag_sizes = (offsets[1:] - offsets[:-1]).tolist() - adhoc_config["Ls"] = [] - pointer_counter = 0 - if batch_size_per_feature_per_rank: - for batchs_size in adhoc_config["Bs"]: - current_L = 0 - for _i in range(batchs_size): - current_L += bag_sizes[pointer_counter] - pointer_counter += 1 - adhoc_config["Ls"].append(current_L / batchs_size) - else: - batch_size = int(len(bag_sizes) // len(adhoc_config["Es"])) - - for _j in range(len(adhoc_config["Es"])): - current_L = 0 - for _i in range(batch_size): - current_L += bag_sizes[pointer_counter] - pointer_counter += 1 - adhoc_config["Ls"].append(current_L / batch_size) - # Write the TBE config to FileStore self.filestore.write( f"{self.path_prefix}/tbe-{op_id}-config-estimation-{iteration}.json", - io.BytesIO(json.dumps(adhoc_config, indent=2).encode()), + io.BytesIO(json.dumps(config.dict(), indent=2).encode()), ) diff --git a/fbgemm_gpu/test/tbe/stats/tbe_bench_params_reporter_test.py b/fbgemm_gpu/test/tbe/stats/tbe_bench_params_reporter_test.py index 6e4cf2e1e4..501b2c098d 100644 --- a/fbgemm_gpu/test/tbe/stats/tbe_bench_params_reporter_test.py +++ b/fbgemm_gpu/test/tbe/stats/tbe_bench_params_reporter_test.py @@ -129,14 +129,19 @@ def test_report_stats( # Generate indices and offsets request = generate_requests(tbeconfig, 1)[0] + # Generate feature_table_map (identity mapping for this test) + feature_table_map = list(range(T)) + # Call the extract_params method extracted_config = reporter.extract_params( feature_rows=embedding_op.rows_per_table, feature_dims=embedding_op.feature_dims, indices=request.indices, offsets=request.offsets, + feature_table_map=feature_table_map, ) + # Verify matching config parameters assert ( extracted_config.T == tbeconfig.T and extracted_config.E == tbeconfig.E @@ -145,6 +150,7 @@ def test_report_stats( and extracted_config.batch_params.B == tbeconfig.batch_params.B and extracted_config.mixed_dim == tbeconfig.mixed_dim and extracted_config.weighted == tbeconfig.weighted + and extracted_config.feature_table_map == feature_table_map and extracted_config.indices_params.index_dtype == tbeconfig.indices_params.index_dtype and extracted_config.indices_params.offset_dtype