diff --git a/lib/compiler/include/compiler/algorithm_config.dtg.toml b/lib/compiler/include/compiler/algorithm_config.dtg.toml
index df08841384..b67ca2b02f 100644
--- a/lib/compiler/include/compiler/algorithm_config.dtg.toml
+++ b/lib/compiler/include/compiler/algorithm_config.dtg.toml
@@ -10,6 +10,7 @@ features = [
 includes = [
   "compiler/data_parallelism/data_parallelism_config.dtg.h",
   "compiler/unity_algorithm/unity_search_config.dtg.h",
+  "compiler/mcmc/mcmc_over_mapped_pcg_config.dtg.h",
 ]
 
 [[values]]
@@ -17,3 +18,6 @@ type = "::FlexFlow::DataParallelismConfig"
 
 [[values]]
 type = "::FlexFlow::UnitySearchConfig"
+
+[[values]]
+type = "::FlexFlow::MCMCOverMappedPCGConfig"
diff --git a/lib/compiler/include/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.h
new file mode 100644
index 0000000000..b08ca57851
--- /dev/null
+++ b/lib/compiler/include/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.h
@@ -0,0 +1,32 @@
+#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_APPLY_SUBSTITUTION_AND_UPDATE_MACHINE_MAPPING_H
+#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_APPLY_SUBSTITUTION_AND_UPDATE_MACHINE_MAPPING_H
+
+#include "compiler/search_result.dtg.h"
+#include "substitutions/pcg_pattern_match.dtg.h"
+#include "substitutions/sub_parallel_computation_graph.dtg.h"
+#include "substitutions/substitution.dtg.h"
+
+namespace FlexFlow {
+/**
+ * @brief Applies \p sub to \p mapped_pcg at the location specified by
+ * \p match, returning the resulting SearchResult (mapped pcg)
+ *
+ * @param mapped_pcg The pcg and machine mapping to rewrite
+ * @param sub The substitution to apply
+ * @param match The location at which to apply \p sub. This location in the
+ * pcg of \p mapped_pcg should match the PCGPattern of \p sub. Likely created
+ * by running FlexFlow::find_pattern_matches(PCGPattern const &,
+ * SubParallelComputationGraph const &).
+ * @return SearchResult A mapped pcg similar to \p mapped_pcg, but with
+ * the subgraph of the pcg specified by \p match replaced with the result of
+ * the output expression of \p sub and the machine mapping updated to account
+ * for the new output
+ */
+SearchResult apply_substitution_and_update_machine_mapping(
+    SearchResult const &mapped_pcg,
+    Substitution const &sub,
+    PCGPatternMatch const &match);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_mutation_set.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_mutation_set.h
new file mode 100644
index 0000000000..d05e3fab7c
--- /dev/null
+++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_mutation_set.h
@@ -0,0 +1,36 @@
+#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_MUTATION_SET_H
+#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_MUTATION_SET_H
+
+#include "compiler/machine_mapping/machine_mapping.h"
+#include "compiler/search_result.dtg.h"
+#include <optional>
+
+namespace FlexFlow {
+
+/**
+ * @brief Builds a machine mapping by selecting a random allowed machine view
+ * for every layer of \p pcg.
+ *
+ * @return The mapping, or std::nullopt if some layer has no allowed machine
+ * view.
+ */
+std::optional<MachineMapping>
+    get_random_mapping(ParallelComputationGraph const &pcg,
+                       MachineComputeSpecification const &resources,
+                       DeviceType const &device_type);
+
+/**
+ * @brief Proposes a new machine mapping for \p mapped_pcg in which one randomly
+ * selected layer is assigned a randomly selected allowed machine view.
+ *
+ * @return The mutated mapping, or std::nullopt if no mutation can be proposed
+ * (e.g., the pcg has no layers).
+ */
+std::optional<MachineMapping>
+    get_random_mutation(SearchResult const &mapped_pcg,
+                        MachineComputeSpecification const &resources,
+                        DeviceType const &device_type);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/compiler/include/compiler/mcmc/generic_mcmc_algorithm.h b/lib/compiler/include/compiler/mcmc/generic_mcmc_algorithm.h
new file mode 100644
index 0000000000..a3baa251e3
--- /dev/null
+++ b/lib/compiler/include/compiler/mcmc/generic_mcmc_algorithm.h
@@ -0,0 +1,62 @@
+#ifndef _FLEXFLOW_COMPILER_MCMC_GENERIC_MCMC_ALGORITHM_H
+#define _FLEXFLOW_COMPILER_MCMC_GENERIC_MCMC_ALGORITHM_H
+
+#include "compiler/mcmc/generic_mcmc_config.dtg.h"
+#include "utils/containers/transform.h"
+#include "utils/nonnegative_int/nonnegative_range.h"
+#include "utils/optional.h"
+#include "utils/random_utils.h"
+#include <cmath>
+
+namespace FlexFlow {
+
+// SamplingFn : State -> std::optional<State>
+// CostFn : State -> float
+
+/**
+ * @brief Stochastic search over states that returns the lowest-cost state it
+ * accepted, or \p starting_state if no accepted state improved on it.
+ *
+ * Every iteration asks \p sampler for a candidate derived from the current
+ * state; std::nullopt keeps the current state. A candidate is accepted when
+ * randf() < exp(-delta / temperature), where delta is the candidate's cost
+ * minus the cost of the best state found so far.
+ *
+ * @param starting_state Initial current (and best) state
+ * @param sampler Proposes a candidate state, or std::nullopt for none
+ * @param cost Cost to minimize. The cost of the best state is cached, so
+ * \p cost is assumed to be deterministic
+ * @param search_config Temperature and number of iterations
+ */
+template <typename State, typename SamplingFn, typename CostFn>
+State run_mcmc(State const &starting_state,
+               SamplingFn const &sampler,
+               CostFn const &cost,
+               GenericMCMCConfig const &search_config) {
+  State best_state = starting_state;
+  float best_cost = cost(best_state);
+  State current_state = best_state;
+  for (nonnegative_int i : nonnegative_range(search_config.num_iterations)) {
+    std::optional<State> maybe_new_state =
+        transform(sampler(current_state), [&](State const &s) {
+          float new_cost = cost(s);
+          // NOTE(review): delta is relative to the best state rather than
+          // the current state (textbook Metropolis) -- confirm intended
+          float delta = new_cost - best_cost;
+          if (randf() < std::exp(-delta / search_config.temperature)) {
+            if (delta < 0) {
+              best_state = s;
+              best_cost = new_cost;
+            }
+            return s;
+          }
+          return current_state;
+        });
+    current_state = or_else(maybe_new_state, [&]() { return current_state; });
+  }
+  return best_state;
+}
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/compiler/include/compiler/mcmc/generic_mcmc_config.dtg.toml b/lib/compiler/include/compiler/mcmc/generic_mcmc_config.dtg.toml
new file mode 100644
index 0000000000..bd55983d92
--- /dev/null
+++ b/lib/compiler/include/compiler/mcmc/generic_mcmc_config.dtg.toml
@@ -0,0 +1,20 @@
+namespace = "FlexFlow"
+name = "GenericMCMCConfig"
+type = "struct"
+features = [
+  "eq",
+  "hash",
+  "fmt",
+]
+
+includes = [
+  "utils/nonnegative_int/nonnegative_int.h"
+]
+
+[[fields]]
+name = "temperature"
+type = "float"
+
+[[fields]]
+name = "num_iterations"
+type = "::FlexFlow::nonnegative_int"
diff --git a/lib/compiler/include/compiler/mcmc/mcmc_over_mapped_pcg.h b/lib/compiler/include/compiler/mcmc/mcmc_over_mapped_pcg.h
new file mode 100644
index 0000000000..08c9c3470d
--- /dev/null
+++ b/lib/compiler/include/compiler/mcmc/mcmc_over_mapped_pcg.h
@@ -0,0 +1,23 @@
+#ifndef _FLEXFLOW_COMPILER_MCMC_OVER_MAPPED_PCG_H
+#define _FLEXFLOW_COMPILER_MCMC_OVER_MAPPED_PCG_H
+
+#include "compiler/cost_estimator/runtime_only_cost_estimator.h"
+#include "compiler/mcmc/mcmc_over_mapped_pcg_config.dtg.h"
+#include
"compiler/search_result.dtg.h" +#include "pcg/computation_graph.h" +#include "pcg/machine_specification.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "substitutions/sub_parallel_computation_graph.h" +#include "substitutions/substitution.h" + +namespace FlexFlow { + +SearchResult + mcmc_over_mapped_pcg(ParallelComputationGraph const &pcg, + RuntimeOnlyCostEstimator const &cost_estimator, + MachineSpecification const &machine_spec, + MCMCOverMappedPCGConfig const &search_config); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/mcmc/mcmc_over_mapped_pcg_config.dtg.toml b/lib/compiler/include/compiler/mcmc/mcmc_over_mapped_pcg_config.dtg.toml new file mode 100644 index 0000000000..99320f735e --- /dev/null +++ b/lib/compiler/include/compiler/mcmc/mcmc_over_mapped_pcg_config.dtg.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "MCMCOverMappedPCGConfig" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/device_type.dtg.h", + "utils/nonnegative_int/nonnegative_int.h" +] + +[[fields]] +name = "temperature" +type = "float" + +[[fields]] +name = "num_iterations" +type = "::FlexFlow::nonnegative_int" + +[[fields]] +name = "substitution_frequency" +type = "float" + +[[fields]] +name = "device_type" +type = "::FlexFlow::DeviceType" \ No newline at end of file diff --git a/lib/compiler/include/compiler/search_result.struct.toml b/lib/compiler/include/compiler/search_result.struct.toml new file mode 100644 index 0000000000..7e7e59d7c9 --- /dev/null +++ b/lib/compiler/include/compiler/search_result.struct.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "SearchResult" +features = [ +] + +includes = [ + "pcg/parallel_computation_graph/parallel_computation_graph.h", + "compiler/machine_mapping/machine_mapping.h", +] + +[[fields]] +name = "pcg" +type = "::FlexFlow::ParallelComputationGraph" + +[[fields]] +name = "machine_mapping" +type = 
"::FlexFlow::MachineMapping" \ No newline at end of file diff --git a/lib/compiler/src/compiler/compiler.cc b/lib/compiler/src/compiler/compiler.cc index 714cda3f86..a48a84fbd1 100644 --- a/lib/compiler/src/compiler/compiler.cc +++ b/lib/compiler/src/compiler/compiler.cc @@ -1,5 +1,6 @@ #include "compiler/compiler.h" #include "compiler/cost_estimator/runtime_only_cost_estimator_from_cost_estimator.h" +#include "compiler/mcmc/mcmc_over_mapped_pcg.h" #include "compiler/unity_algorithm/unity_algorithm.h" #include "pcg/pcg_from_computation_graph.h" #include "utils/overload.h" @@ -24,6 +25,15 @@ SearchResult optimize(ComputationGraph const &computation_graph, machine_specification.compute_specification, config); }, + [&](MCMCOverMappedPCGConfig const &config) { + ParallelComputationGraph pcg = + pcg_from_computation_graph(computation_graph); + return mcmc_over_mapped_pcg( + pcg, + runtime_only_cost_estimator_from_cost_estimator(cost_estimator), + machine_specification, + config); + }, }); } diff --git a/lib/compiler/src/compiler/machine_mapping/allowed_machine_views.cc b/lib/compiler/src/compiler/machine_mapping/allowed_machine_views.cc index ec369f2f03..c3d9ae7bfb 100644 --- a/lib/compiler/src/compiler/machine_mapping/allowed_machine_views.cc +++ b/lib/compiler/src/compiler/machine_mapping/allowed_machine_views.cc @@ -59,6 +59,8 @@ static std::unordered_set product(transform(tensor_dims, [](positive_int num_devices) { return nonnegative_int{num_devices.int_from_positive_int() - 1}; })); + min_num_devices_with_full_stride_volume = + std::max(min_num_devices_with_full_stride_volume, 1_n); return ceildiv(total_devices, positive_int{min_num_devices_with_full_stride_volume}); }; diff --git a/lib/compiler/src/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.cc new file mode 100644 index 0000000000..7ccab2fac9 --- /dev/null +++ 
b/lib/compiler/src/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.cc
@@ -0,0 +1,88 @@
+#include "compiler/machine_mapping/apply_substitution_and_update_machine_mapping.h"
+#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h"
+#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h"
+#include "substitutions/apply_substitution/apply_substitution.h"
+#include "substitutions/apply_substitution/evaluate_substitution_output.h"
+#include "substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.h"
+#include "substitutions/open_parallel_tensor_guid_t.h"
+#include "substitutions/pcg_pattern_match.h"
+#include "substitutions/sub_parallel_computation_graph.h"
+#include "substitutions/sub_parallel_computation_graph_data.dtg.h"
+#include "substitutions/sub_parallel_computation_graph_edge.h"
+#include "utils/containers/binary_merge_disjoint_maps.h"
+#include "utils/containers/filter.h"
+#include "utils/containers/is_subseteq_of.h"
+#include "utils/containers/keys.h"
+#include "utils/containers/restrict_keys.h"
+#include "utils/containers/set_minus.h"
+#include "utils/containers/values.h"
+#include "utils/containers/vector_of.h"
+#include "utils/random_utils.h"
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace FlexFlow {
+
+SearchResult apply_substitution_and_update_machine_mapping(
+    SearchResult const &mapped_pcg,
+    Substitution const &sub,
+    PCGPatternMatch const &match) {
+  SubParallelComputationGraph spcg = sub_pcg_from_full_pcg(mapped_pcg.pcg);
+
+  std::pair<SubParallelComputationGraph, OutputExprToResultSubPCGMapping>
+      substitution_output_result =
+          evaluate_substitution_output(spcg, sub, match);
+
+  SubParallelComputationGraph post_substitution_graph =
+      apply_substitution_from_output_result(
+          substitution_output_result, spcg, sub, match);
+
+  std::unordered_map<parallel_layer_guid_t, ParallelLayerAttrs> post_node_data =
+      get_sub_pcg_data(post_substitution_graph).node_data;
+
+  std::unordered_set<parallel_layer_guid_t>
+      substitution_output_parallel_layers =
+          get_parallel_layers(substitution_output_result.first);
+
+  std::unordered_map<parallel_layer_guid_t, MachineView> machine_views =
+      mapped_pcg.machine_mapping.machine_views;
+
+  std::unordered_set<parallel_layer_guid_t> matched_nodes =
+      unordered_set_of(values(match.node_assignment));
+
+  std::vector<MachineView> substituted_machine_views = vector_of(
+      transform(matched_nodes, [&](parallel_layer_guid_t const &node) {
+        return machine_views.at(node);
+      }));
+
+  // select_random needs at least one candidate machine view to draw from
+  ASSERT(substitution_output_parallel_layers.empty() ||
+         !substituted_machine_views.empty());
+
+  // every layer created by the substitution takes the machine view of a
+  // randomly selected layer that the substitution replaced
+  for (parallel_layer_guid_t layer : substitution_output_parallel_layers) {
+    machine_views.insert_or_assign(layer,
+                                   select_random(substituted_machine_views));
+  }
+
+  ASSERT(is_subseteq_of(keys(post_node_data), keys(machine_views)));
+
+  // drop the entries of layers that no longer exist after the substitution
+  std::unordered_map<parallel_layer_guid_t, MachineView>
+      post_node_machine_views =
+          filter(machine_views,
+                 [&](std::pair<parallel_layer_guid_t, MachineView> const &p) {
+                   return post_node_data.count(p.first);
+                 });
+
+  ASSERT(keys(post_node_data) == keys(post_node_machine_views));
+
+  return SearchResult{
+      pcg_from_sub_pcg_by_dropping_inputs(post_substitution_graph),
+      MachineMapping{post_node_machine_views}};
+}
+
+} // namespace FlexFlow
diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_mutation_set.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_mutation_set.cc
new file mode 100644
index 0000000000..47639ff88a
--- /dev/null
+++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_mutation_set.cc
@@ -0,0 +1,76 @@
+#include "compiler/machine_mapping/machine_mapping_mutation_set.h"
+#include "compiler/machine_mapping/allowed_machine_views.h"
+#include "compiler/machine_mapping/machine_view.h"
+#include "op-attrs/operator_task_space.h"
+#include "pcg/machine_compute_resource_slice.h"
+#include "utils/containers/vector_of.h"
+#include "utils/nonnegative_int/nonnegative_range.h"
+#include "utils/random_utils.h"
+#include <optional>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+namespace FlexFlow {
+
+/**
+ * @brief Builds a machine mapping by selecting, for every layer of \p pcg, a
+ * random machine view among those allowed on \p resources for \p device_type.
+ *
+ * @return The mapping, or std::nullopt if some layer has no allowed machine
+ * view.
+ */
+std::optional<MachineMapping>
+    get_random_mapping(ParallelComputationGraph const &pcg,
+                       MachineComputeSpecification const &resources,
+                       DeviceType const &device_type) {
+  auto const resource_slice = compute_slice_from_specification(resources);
+  std::vector<parallel_layer_guid_t> layers = topological_ordering(pcg);
+  std::unordered_map<parallel_layer_guid_t, MachineView> machine_views;
+  for (parallel_layer_guid_t layer : layers) {
+    OperatorTaskSpace task = get_operator_task_space(pcg, layer);
+    // use the requested device type rather than always assuming GPUs
+    std::unordered_set<MachineView> allowed_machine_views =
+        get_allowed_machine_views(resource_slice, task, device_type);
+    if (allowed_machine_views.empty()) {
+      return std::nullopt;
+    }
+    machine_views.insert(
+        {layer, select_random(vector_of(allowed_machine_views))});
+  }
+  return MachineMapping{machine_views};
+}
+
+/**
+ * @brief Proposes a machine mapping that differs from that of \p mapped_pcg in
+ * (at most) the machine view of one randomly selected layer.
+ *
+ * @return The mutated mapping, or std::nullopt if the pcg has no layers or the
+ * selected layer has no allowed machine view.
+ */
+std::optional<MachineMapping>
+    get_random_mutation(SearchResult const &mapped_pcg,
+                        MachineComputeSpecification const &resources,
+                        DeviceType const &device_type) {
+  ParallelComputationGraph const &pcg = mapped_pcg.pcg;
+  std::vector<parallel_layer_guid_t> layers = topological_ordering(pcg);
+  if (layers.empty()) {
+    return std::nullopt;
+  }
+  parallel_layer_guid_t random_layer = select_random(layers);
+  OperatorTaskSpace task = get_operator_task_space(pcg, random_layer);
+
+  std::vector<MachineView> allowed_machine_views =
+      vector_of(get_allowed_machine_views(
+          compute_slice_from_specification(resources), task, device_type));
+  if (allowed_machine_views.empty()) {
+    return std::nullopt;
+  }
+
+  MachineMapping machine_mapping = mapped_pcg.machine_mapping;
+  machine_mapping.machine_views.at(random_layer) =
+      select_random(allowed_machine_views);
+  return machine_mapping;
+}
+
+} // namespace FlexFlow
diff --git a/lib/compiler/src/compiler/mcmc/generic_mcmc_algorithm.cc b/lib/compiler/src/compiler/mcmc/generic_mcmc_algorithm.cc
new file mode 100644
index 0000000000..2c8fcea86d
--- /dev/null
+++ b/lib/compiler/src/compiler/mcmc/generic_mcmc_algorithm.cc
@@ -0,0 +1,18 @@
+#include "compiler/mcmc/generic_mcmc_algorithm.h"
+#include "utils/archetypes/value_type.h"
+#include <functional>
+#include <optional>
+
+namespace FlexFlow {
+
+// explicit instantiation of run_mcmc on archetype types
+using State = value_type<0>;
+using SamplingFn = std::function<std::optional<State>(State)>;
+using CostFn = std::function<float(State)>;
+
+template State run_mcmc(State const &starting_state,
+                        SamplingFn const &sampler,
+                        CostFn const &cost,
+                        GenericMCMCConfig const &search_config);
+
+} // namespace FlexFlow
diff
--git a/lib/compiler/src/compiler/mcmc/mcmc_over_mapped_pcg.cc b/lib/compiler/src/compiler/mcmc/mcmc_over_mapped_pcg.cc
new file mode 100644
index 0000000000..583a60b1ad
--- /dev/null
+++ b/lib/compiler/src/compiler/mcmc/mcmc_over_mapped_pcg.cc
@@ -0,0 +1,86 @@
+#include "compiler/mcmc/mcmc_over_mapped_pcg.h"
+#include "compiler/machine_mapping/apply_substitution_and_update_machine_mapping.h"
+#include "compiler/machine_mapping/machine_mapping_mutation_set.h"
+#include "compiler/mcmc/generic_mcmc_algorithm.h"
+#include "compiler/search_result.h"
+#include "compiler/task_graph_simulator/task_simulator.h"
+#include "pcg/machine_compute_resource_slice.h"
+#include "substitutions/pcg_pattern.h"
+#include "substitutions/pcg_pattern_match.h"
+#include "substitutions/unity_substitution_set.h"
+#include "utils/optional.h"
+#include "utils/random_utils.h"
+#include <optional>
+#include <vector>
+
+namespace FlexFlow {
+
+/**
+ * @brief Searches jointly over substitutions and machine mappings of \p pcg
+ * using run_mcmc, with the simulated forward pass time as the cost.
+ *
+ * The search starts from \p pcg with a random machine mapping. Each step
+ * either applies a random substitution at a random match (with probability
+ * search_config.substitution_frequency) or mutates the machine mapping.
+ *
+ * @return The mapped pcg returned by run_mcmc
+ */
+SearchResult
+    mcmc_over_mapped_pcg(ParallelComputationGraph const &pcg,
+                         RuntimeOnlyCostEstimator const &cost_estimator,
+                         MachineSpecification const &machine_spec,
+                         MCMCOverMappedPCGConfig const &search_config) {
+  // validate once up front instead of on every sample
+  ASSERT(search_config.substitution_frequency >= 0 &&
+         search_config.substitution_frequency <= 1);
+
+  MachineComputeSpecification compute_spec = machine_spec.compute_specification;
+  std::vector<Substitution> substitutions = get_substitution_set(compute_spec);
+  MachineMapping random_mapping = assert_unwrap(
+      get_random_mapping(pcg, compute_spec, search_config.device_type));
+  SearchResult starting_state = SearchResult{pcg, random_mapping};
+
+  auto sampler =
+      [&](SearchResult const &mapped_pcg) -> std::optional<SearchResult> {
+    // applies substitution with substitution_frequency probability
+    // applies machine mapping mutation with (1 - substitution_frequency)
+    // probability
+    if (randf() < search_config.substitution_frequency) {
+      // draw from the substitution set computed above rather than
+      // regenerating the whole set for every sample
+      ASSERT(!substitutions.empty());
+      Substitution random_substitution = select_random(substitutions);
+      std::optional<PCGPatternMatch> maybe_pattern_match =
+          get_random_pattern_match(random_substitution.pcg_pattern,
+                                   sub_pcg_from_full_pcg(mapped_pcg.pcg));
+      return transform(maybe_pattern_match, [&](PCGPatternMatch const &match) {
+        return apply_substitution_and_update_machine_mapping(
+            mapped_pcg, random_substitution, match);
+      });
+    } else {
+      // no available mutation means no candidate for this iteration
+      return transform(
+          get_random_mutation(
+              mapped_pcg, compute_spec, search_config.device_type),
+          [&](MachineMapping const &new_machine_mapping) {
+            return SearchResult{mapped_pcg.pcg, new_machine_mapping};
+          });
+    }
+  };
+
+  auto cost = [&](SearchResult const &mapped_pcg) -> float {
+    return task_simulator_estimate_forward_pass_time(mapped_pcg.pcg,
+                                                     cost_estimator,
+                                                     mapped_pcg.machine_mapping,
+                                                     machine_spec)
+        .unwrap_milliseconds();
+  };
+
+  GenericMCMCConfig config =
+      GenericMCMCConfig{/*temperature=*/search_config.temperature,
+                        /*num_iterations=*/search_config.num_iterations};
+
+  return run_mcmc(starting_state, sampler, cost, config);
+}
+
+} // namespace FlexFlow
diff --git a/lib/compiler/test/src/compiler/mcmc/generic_mcmc_algorithm.cc b/lib/compiler/test/src/compiler/mcmc/generic_mcmc_algorithm.cc
new file mode 100644
index 0000000000..b21ee4333f
--- /dev/null
+++ b/lib/compiler/test/src/compiler/mcmc/generic_mcmc_algorithm.cc
@@ -0,0 +1,29 @@
+#include "compiler/mcmc/generic_mcmc_algorithm.h"
+#include "doctest/doctest.h"
+
+using namespace FlexFlow;
+
+TEST_SUITE(FF_TEST_SUITE) {
+  TEST_CASE("generic_mcmc_algorithm") {
+    float starting_state = 0.1;
+    auto sampler = [](float x) -> std::optional<float> {
+      float new_x = x + (randf() - 0.5);
+      if (new_x < 0) {
+        return std::nullopt;
+      }
+      if (new_x > 1) {
+        return std::nullopt;
+      }
+      return new_x;
+    };
+    auto cost = [](float x) { return (x - 0.5) * (x - 0.5); };
+    GenericMCMCConfig config = GenericMCMCConfig{/*temperature=*/1.0,
+                                                 /*num_iterations=*/100_n};
+    float answer = run_mcmc(starting_state, sampler, cost, config);
+    float error = cost(answer);
+    CHECK(answer > 0.47);
+    CHECK(answer < 0.53);
+    CHECK(error >= 0);
+
CHECK(error < 0.001); + } +} diff --git a/lib/compiler/test/src/compiler/mcmc/mcmc_over_mapped_pcg.cc b/lib/compiler/test/src/compiler/mcmc/mcmc_over_mapped_pcg.cc new file mode 100644 index 0000000000..2584f6b3a6 --- /dev/null +++ b/lib/compiler/test/src/compiler/mcmc/mcmc_over_mapped_pcg.cc @@ -0,0 +1,97 @@ +#include "compiler/mcmc/mcmc_over_mapped_pcg.h" +#include "compiler/task_graph_simulator/task_simulator.h" +#include "doctest/doctest.h" +#include "internal/runtime_only_cost_estimator_for_test.h" +#include "op-attrs/parallel_tensor_dims.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/replica_type.dtg.h" +#include "op-attrs/shard_parallel_dim.h" +#include "pcg/computation_graph_builder.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "pcg/pcg_from_computation_graph.h" +#include "utils/integer_conversions.h" + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("mcmc_over_mapped_pcg") { + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + TensorShape input_tensor_shape = TensorShape{ + TensorDims{ + FFOrdered{32_p, 64_p}, + }, + DataType::FLOAT, + }; + tensor_guid_t t = b.create_input(input_tensor_shape, CreateGrad::YES); + t = b.dense(t, + /*outDim=*/16_p, + /*activation=*/std::nullopt); + t = b.gelu(t); + t = b.dense(t, + /*outDim=*/12_p, + /*activation=*/std::nullopt, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*kernel_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt); + t = b.relu(t); + t = b.dense(t, + /*outDim=*/8_p, + /*activation=*/Activation::RELU); + return b.computation_graph; + }(); + + ParallelComputationGraph pcg = pcg_from_computation_graph(cg); + + RuntimeOnlyCostEstimator cost_estimator = + make_fake_constant_runtime_only_cost_estimator( + /*forward_op_cost=*/10_ms, + /*backward_op_cost=*/10_ms, + /*comm_cost=*/1_ms); + + MachineSpecification full_machine_spec = MachineSpecification{ + MachineComputeSpecification{ + 
/*num_nodes=*/3_p, + /*num_cpus_per_node=*/3_p, + /*num_gpus_per_node=*/3_p, + }, + MachineInterconnectSpecification{ + /*inter_node_bandwidth=*/bytes_per_second_t{1.0f}, + /*intra_node_bandwidth=*/bytes_per_second_t{1.0f}, + }, + }; + + MCMCOverMappedPCGConfig no_search = + MCMCOverMappedPCGConfig{/*temperature=*/1.0, + /*num_iterations=*/1_n, + /*substitution_frequency=*/0.2, + /*device_type=*/DeviceType::GPU}; + + SearchResult base_result = + mcmc_over_mapped_pcg(pcg, cost_estimator, full_machine_spec, no_search); + float base_runtime = + task_simulator_estimate_forward_pass_time(base_result.pcg, + cost_estimator, + base_result.machine_mapping, + full_machine_spec) + .unwrap_milliseconds(); + + MCMCOverMappedPCGConfig search_config = + MCMCOverMappedPCGConfig{/*temperature=*/1.0, + /*num_iterations=*/100_n, + /*substitution_frequency=*/0.2, + /*device_type=*/DeviceType::GPU}; + + SearchResult result = mcmc_over_mapped_pcg( + pcg, cost_estimator, full_machine_spec, search_config); + float runtime = + task_simulator_estimate_forward_pass_time(result.pcg, + cost_estimator, + result.machine_mapping, + full_machine_spec) + .unwrap_milliseconds(); + + CHECK(runtime < base_runtime * 0.8); + } +} diff --git a/lib/substitutions/include/substitutions/apply_substitution/apply_substitution.h b/lib/substitutions/include/substitutions/apply_substitution/apply_substitution.h index 92f7bb1c03..84ce670f0a 100644 --- a/lib/substitutions/include/substitutions/apply_substitution/apply_substitution.h +++ b/lib/substitutions/include/substitutions/apply_substitution/apply_substitution.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_APPLY_SUBSTITUTION_APPLY_SUBSTITUTION_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_APPLY_SUBSTITUTION_APPLY_SUBSTITUTION_H +#include "substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.h" #include "substitutions/pcg_pattern_match.dtg.h" #include 
"substitutions/sub_parallel_computation_graph.dtg.h" #include "substitutions/substitution.dtg.h" @@ -26,6 +27,14 @@ SubParallelComputationGraph Substitution const &substitution, PCGPatternMatch const &match); +SubParallelComputationGraph apply_substitution_from_output_result( + std::pair const + &substitution_output_result, + SubParallelComputationGraph const &spcg, + Substitution const &sub, + PCGPatternMatch const &match); + } // namespace FlexFlow #endif diff --git a/lib/substitutions/include/substitutions/pcg_pattern.h b/lib/substitutions/include/substitutions/pcg_pattern.h index d39fab0f7b..8a4266fc5e 100644 --- a/lib/substitutions/include/substitutions/pcg_pattern.h +++ b/lib/substitutions/include/substitutions/pcg_pattern.h @@ -12,6 +12,10 @@ namespace FlexFlow { std::unordered_set get_nodes(PCGPattern const &); +std::optional + get_random_pattern_match(PCGPattern const &pattern, + SubParallelComputationGraph const &pcg); + /** * @brief Find all locations in \p pcg that match \p pattern */ diff --git a/lib/substitutions/include/substitutions/unity_substitution_set.h b/lib/substitutions/include/substitutions/unity_substitution_set.h index 074d41dc71..e03b4c94f6 100644 --- a/lib/substitutions/include/substitutions/unity_substitution_set.h +++ b/lib/substitutions/include/substitutions/unity_substitution_set.h @@ -7,6 +7,9 @@ namespace FlexFlow { +std::optional + get_random_substitution(MachineComputeSpecification const &resources); + std::vector get_substitution_set(MachineComputeSpecification const &resources); diff --git a/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc b/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc index 4c355acb4b..6ed2ef563e 100644 --- a/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc +++ b/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc @@ -23,9 +23,20 @@ SubParallelComputationGraph PCGPatternMatch const &match) { 
assert_pcg_pattern_match_is_valid_for_pattern_and_subpcg( match, sub.pcg_pattern, spcg); + std::pair + substitution_output_result = + evaluate_substitution_output(spcg, sub, match); + return apply_substitution_from_output_result( + substitution_output_result, spcg, sub, match); +} - auto substitution_output_result = - evaluate_substitution_output(spcg, sub, match); +SubParallelComputationGraph apply_substitution_from_output_result( + std::pair const + &substitution_output_result, + SubParallelComputationGraph const &spcg, + Substitution const &sub, + PCGPatternMatch const &match) { SubParallelComputationGraph substitution_output_graph = substitution_output_result.first; OutputExprToResultSubPCGMapping output_expr_to_result_sub_pcg_mapping = diff --git a/lib/substitutions/src/substitutions/pcg_pattern.cc b/lib/substitutions/src/substitutions/pcg_pattern.cc index 15bce488ea..b578383352 100644 --- a/lib/substitutions/src/substitutions/pcg_pattern.cc +++ b/lib/substitutions/src/substitutions/pcg_pattern.cc @@ -11,6 +11,7 @@ #include "utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_outputs_for_node.h" #include "utils/graph/node/algorithms.h" #include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_graph_inputs.h" +#include "utils/random_utils.h" namespace FlexFlow { @@ -20,6 +21,17 @@ std::unordered_set get_nodes(PCGPattern const &p) { return transform(raw_nodes, [](Node const &n) { return PatternNode{n}; }); } +std::optional + get_random_pattern_match(PCGPattern const &pattern, + SubParallelComputationGraph const &pcg) { + std::vector pattern_matches = + find_pattern_matches(pattern, pcg); + if (pattern_matches.empty()) { + return std::nullopt; + } + return select_random(pattern_matches); +} + static MatchAdditionalCriterion pcg_pattern_criteria(PCGPattern const &pattern, SubParallelComputationGraph const &pcg) { diff --git a/lib/substitutions/src/substitutions/unity_substitution_set.cc 
b/lib/substitutions/src/substitutions/unity_substitution_set.cc index f1dabd9554..25d714e825 100644 --- a/lib/substitutions/src/substitutions/unity_substitution_set.cc +++ b/lib/substitutions/src/substitutions/unity_substitution_set.cc @@ -8,9 +8,19 @@ #include "utils/nonnegative_int/nonnegative_int.h" #include "utils/nonnegative_int/nonnegative_range.h" #include "utils/positive_int/positive_range.h" +#include "utils/random_utils.h" namespace FlexFlow { +std::optional + get_random_substitution(MachineComputeSpecification const &resources) { + std::vector substitutions = get_substitution_set(resources); + if (substitutions.empty()) { + return std::nullopt; + } + return select_random(substitutions); +} + std::vector get_substitution_set(MachineComputeSpecification const &resources) { std::vector substitutions; diff --git a/lib/utils/include/utils/optional.h b/lib/utils/include/utils/optional.h index 81b81fbb45..34a7752ac3 100644 --- a/lib/utils/include/utils/optional.h +++ b/lib/utils/include/utils/optional.h @@ -3,6 +3,7 @@ #include "utils/exception.h" #include "utils/fmt/optional.h" +#include #include namespace FlexFlow {