From 80f4efb531158820f518152b42e583f16bce4069 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 6 Mar 2026 15:55:50 -0800 Subject: [PATCH 1/2] Add V1MappedParallelComputationGraph. --- .../v1/v1_mapped_operator_task_group.dtg.toml | 19 +++++ .../v1/v1_mapped_operator_task_group.h | 13 ++++ ...mapped_parallel_computation_graph.dtg.toml | 29 ++++++++ .../v1/v1_mapped_parallel_computation_graph.h | 13 ++++ .../parallel_layer_guid_t.dtg.toml | 1 + .../v1/v1_mapped_operator_task_group.cc | 9 +++ .../v1_mapped_parallel_computation_graph.cc | 17 +++++ .../v1_mapped_parallel_computation_graph.cc | 69 +++++++++++++++++++ 8 files changed, 170 insertions(+) create mode 100644 lib/pcg/include/pcg/file_format/v1/v1_mapped_operator_task_group.dtg.toml create mode 100644 lib/pcg/include/pcg/file_format/v1/v1_mapped_operator_task_group.h create mode 100644 lib/pcg/include/pcg/file_format/v1/v1_mapped_parallel_computation_graph.dtg.toml create mode 100644 lib/pcg/include/pcg/file_format/v1/v1_mapped_parallel_computation_graph.h create mode 100644 lib/pcg/src/pcg/file_format/v1/v1_mapped_operator_task_group.cc create mode 100644 lib/pcg/src/pcg/file_format/v1/v1_mapped_parallel_computation_graph.cc create mode 100644 lib/pcg/test/src/pcg/file_format/v1/v1_mapped_parallel_computation_graph.cc diff --git a/lib/pcg/include/pcg/file_format/v1/v1_mapped_operator_task_group.dtg.toml b/lib/pcg/include/pcg/file_format/v1/v1_mapped_operator_task_group.dtg.toml new file mode 100644 index 0000000000..2e4300745d --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_mapped_operator_task_group.dtg.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "V1MappedOperatorTaskGroup" +type = "struct" +features = [ + "eq", + "hash", + "fmt", + "json", +] + +includes = [ + "pcg/machine_space_coordinate.dtg.h", + "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.dtg.h", + "utils/bidict/bidict.h", +] + +[[fields]] +name = "shard_bindings" +type = "::FlexFlow::bidict<::FlexFlow::MachineSpaceCoordinate, ::FlexFlow::OperatorAtomicTaskShardBinding>" diff --git a/lib/pcg/include/pcg/file_format/v1/v1_mapped_operator_task_group.h b/lib/pcg/include/pcg/file_format/v1/v1_mapped_operator_task_group.h new file mode 100644 index 0000000000..7c1788b9d0 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_mapped_operator_task_group.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_MAPPED_OPERATOR_TASK_GROUP_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_MAPPED_OPERATOR_TASK_GROUP_H + +#include "pcg/file_format/v1/v1_mapped_operator_task_group.dtg.h" +#include "pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h" + +namespace FlexFlow { + +V1MappedOperatorTaskGroup to_v1(MappedOperatorTaskGroup const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/v1_mapped_parallel_computation_graph.dtg.toml b/lib/pcg/include/pcg/file_format/v1/v1_mapped_parallel_computation_graph.dtg.toml new file mode 100644 index 0000000000..8dc336e4ea --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_mapped_parallel_computation_graph.dtg.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "V1MappedParallelComputationGraph" +type = "struct" +features = [ + "eq", + "hash", + "fmt", + "json", +] + +includes = [ + "", + "pcg/file_format/v1/v1_parallel_computation_graph.dtg.h", + "pcg/file_format/v1/v1_mapped_operator_task_group.dtg.h", + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/fmt/unordered_map.h", +] + +[[fields]] +name = "pcg" +type = "::FlexFlow::V1ParallelComputationGraph" + +[[fields]] +name = "mapped_tasks" +type = "std::unordered_map<::FlexFlow::parallel_layer_guid_t, ::FlexFlow::V1MappedOperatorTaskGroup>" diff --git a/lib/pcg/include/pcg/file_format/v1/v1_mapped_parallel_computation_graph.h b/lib/pcg/include/pcg/file_format/v1/v1_mapped_parallel_computation_graph.h new file mode 100644 index 0000000000..5b9d18ccc4 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_mapped_parallel_computation_graph.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_MAPPED_PARALLEL_COMPUTATION_GRAPH_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_MAPPED_PARALLEL_COMPUTATION_GRAPH_H + +#include "pcg/file_format/v1/v1_mapped_parallel_computation_graph.dtg.h" +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.dtg.h" + +namespace FlexFlow { + +V1MappedParallelComputationGraph to_v1(MappedParallelComputationGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.toml index 618bcb0dc4..292b361fc8 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.toml +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.toml @@ -6,6 +6,7 @@ features = [ "ord", "hash", "fmt", + "json", ] includes = [ diff --git a/lib/pcg/src/pcg/file_format/v1/v1_mapped_operator_task_group.cc b/lib/pcg/src/pcg/file_format/v1/v1_mapped_operator_task_group.cc new file mode 100644 index 0000000000..480ea7197a --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/v1_mapped_operator_task_group.cc @@ -0,0 +1,9 @@ +#include "pcg/file_format/v1/v1_mapped_operator_task_group.h" + +namespace FlexFlow { + +V1MappedOperatorTaskGroup to_v1(MappedOperatorTaskGroup const &g) { + return V1MappedOperatorTaskGroup{g.get_shard_bindings()}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/v1_mapped_parallel_computation_graph.cc b/lib/pcg/src/pcg/file_format/v1/v1_mapped_parallel_computation_graph.cc new file mode 100644 index 0000000000..96429e06a6 --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/v1_mapped_parallel_computation_graph.cc @@ -0,0 +1,17 @@ +#include "pcg/file_format/v1/v1_mapped_parallel_computation_graph.h" +#include "pcg/file_format/v1/v1_mapped_operator_task_group.h" +#include "pcg/file_format/v1/v1_parallel_computation_graph.h" +#include "utils/containers/map_values.h" + +namespace FlexFlow { + +V1MappedParallelComputationGraph + to_v1(MappedParallelComputationGraph const &mpcg) { + return V1MappedParallelComputationGraph{ + to_v1(mpcg.pcg), + map_values(mpcg.mapped_tasks, + [](MappedOperatorTaskGroup const &g) { return to_v1(g); }), + }; +} + +} // namespace FlexFlow diff --git a/lib/pcg/test/src/pcg/file_format/v1/v1_mapped_parallel_computation_graph.cc b/lib/pcg/test/src/pcg/file_format/v1/v1_mapped_parallel_computation_graph.cc new file mode 100644 index 0000000000..dd2cdb35ed --- /dev/null +++ b/lib/pcg/test/src/pcg/file_format/v1/v1_mapped_parallel_computation_graph.cc @@ -0,0 +1,69 @@ +#include "pcg/file_format/v1/v1_mapped_parallel_computation_graph.h" +#include "op-attrs/parallel_tensor_space_coordinate.dtg.h" +#include "op-attrs/tensor_slot_name.dtg.h" +#include "pcg/device_type.dtg.h" +#include "pcg/machine_space_coordinate.dtg.h" +#include "pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h" +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.dtg.h" +#include "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_layer_added_result.dtg.h" +#include "utils/bidict/bidict.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("V1MappedParallelComputationGraph") { + MappedParallelComputationGraph mpcg = [] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + TensorShape input_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 12_p, + 16_p, + }, + }, + DataType::FLOAT, + }; + + ParallelLayerAddedResult result = pcg_add_input_layer(pcg, input_shape); + parallel_layer_guid_t layer = result.parallel_layer; + + MachineSpaceCoordinate coord = MachineSpaceCoordinate{ + /*node_idx=*/0_n, + /*device_idx=*/0_n, + /*device_type=*/DeviceType::GPU, + }; + + OperatorAtomicTaskShardBinding binding = OperatorAtomicTaskShardBinding{ + /*tensor_coords=*/{ + { + TensorSlotName::OUTPUT, + ParallelTensorSpaceCoordinate{ + /*sum_component=*/0_n, + /*discard_copy_component=*/0_n, + /*shard_components=*/FFOrdered{0_n, 0_n}, + }, + }, + }, + }; + + MappedOperatorTaskGroup task_group = MappedOperatorTaskGroup{ + bidict{ + {coord, binding}, + }, + }; + + return MappedParallelComputationGraph{ + /*pcg=*/pcg, + /*mapped_tasks=*/{{layer, task_group}}, + }; + }(); + + V1MappedParallelComputationGraph v1_mpcg = to_v1(mpcg); + nlohmann::json j = v1_mpcg; + } +} From 541e8faa63fc2b04768328ec86945afac7ab9c26 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 6 Mar 2026 16:44:43 -0800 Subject: [PATCH 2/2] Add from_v1. --- .../graphs/v1_labelled_kwarg_dataflow_graph.h | 57 +++++++++++ .../pcg/file_format/v1/v1_computation_graph.h | 1 + .../v1/v1_mapped_operator_task_group.h | 1 + .../v1/v1_mapped_parallel_computation_graph.h | 2 + .../v1/v1_parallel_computation_graph.h | 1 + .../v1_labelled_kwarg_dataflow_graph.cc | 3 + .../file_format/v1/v1_computation_graph.cc | 6 ++ .../v1/v1_mapped_operator_task_group.cc | 4 + .../v1_mapped_parallel_computation_graph.cc | 9 ++ .../v1/v1_parallel_computation_graph.cc | 6 ++ .../file_format/v1/v1_computation_graph.cc | 12 ++- .../v1_mapped_parallel_computation_graph.cc | 94 +++++++++++-------- .../v1/v1_parallel_computation_graph.cc | 12 ++- 13 files changed, 165 insertions(+), 43 deletions(-) diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.h index dbe660c3a6..1c1bc70c88 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.h +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.h @@ -6,7 +6,14 @@ #include "utils/bidict/algorithms/bidict_from_enumerating.h" #include "utils/containers/map_values.h" #include "utils/containers/transform.h" +#include "utils/graph/digraph/algorithms/get_topological_ordering.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/digraph/directed_edge.dtg.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_outputs_for_node.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_node_added_result.dtg.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph.h" #include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph_view.h" #include "utils/graph/node/algorithms.h" @@ -50,6 +57,56 @@ V1LabelledKwargDataflowGraph to_v1( return to_v1_including_node_numbering(g).first; } +template +LabelledKwargDataflowGraph from_v1( + V1LabelledKwargDataflowGraph const &v1) { + // Build incoming-edge map + std::unordered_map>> + incoming; + for (nonnegative_int const &n : v1.graph.nodes) { + incoming[n] = {}; + } + for (V1GraphEdge const &e : v1.graph.edges) { + incoming[e.dstNode].push_back(e); + } + + // Build a DiGraph with V1 indices as Node raw_uids to get topological order + DiGraph dg = DiGraph::create(); + for (nonnegative_int const &n : v1.graph.nodes) { + dg.add_node_unsafe(Node{static_cast(n.unwrap_nonnegative())}); + } + for (V1GraphEdge const &e : v1.graph.edges) { + dg.add_edge(DirectedEdge{ + Node{static_cast(e.srcNode.unwrap_nonnegative())}, + Node{static_cast(e.dstNode.unwrap_nonnegative())}}); + } + + auto g = LabelledKwargDataflowGraph:: + template create>(); + + std::unordered_map node_map; + for (Node const &topo_node : get_topological_ordering(dg)) { + nonnegative_int v1_idx{static_cast(topo_node.raw_uid)}; + + std::unordered_map> inputs; + for (V1GraphEdge const &e : incoming.at(v1_idx)) { + inputs.emplace( + e.dstSlot, + KwargDataflowOutput{node_map.at(e.srcNode), e.srcSlot}); + } + + KwargNodeAddedResult result = g.add_node( + v1.node_labels.at(v1_idx), inputs, v1.output_labels.at(v1_idx)); + + node_map.emplace(v1_idx, result.node); + } + + return g; +} + } // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.h b/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.h index c0e9966425..8b6128d603 100644 --- a/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.h +++ b/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.h @@ -8,6 +8,7 @@ namespace FlexFlow { V1ComputationGraph to_v1(ComputationGraph const &); +ComputationGraph from_v1(V1ComputationGraph const &); std::pair> to_v1_including_node_numbering(ComputationGraph const &); diff --git a/lib/pcg/include/pcg/file_format/v1/v1_mapped_operator_task_group.h b/lib/pcg/include/pcg/file_format/v1/v1_mapped_operator_task_group.h index 7c1788b9d0..8e386e156f 100644 --- a/lib/pcg/include/pcg/file_format/v1/v1_mapped_operator_task_group.h +++ b/lib/pcg/include/pcg/file_format/v1/v1_mapped_operator_task_group.h @@ -7,6 +7,7 @@ namespace FlexFlow { V1MappedOperatorTaskGroup to_v1(MappedOperatorTaskGroup const &); +MappedOperatorTaskGroup from_v1(V1MappedOperatorTaskGroup const &); } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/file_format/v1/v1_mapped_parallel_computation_graph.h b/lib/pcg/include/pcg/file_format/v1/v1_mapped_parallel_computation_graph.h index 5b9d18ccc4..f78efc4591 100644 --- a/lib/pcg/include/pcg/file_format/v1/v1_mapped_parallel_computation_graph.h +++ b/lib/pcg/include/pcg/file_format/v1/v1_mapped_parallel_computation_graph.h @@ -7,6 +7,8 @@ namespace FlexFlow { V1MappedParallelComputationGraph to_v1(MappedParallelComputationGraph const &); +MappedParallelComputationGraph + from_v1(V1MappedParallelComputationGraph const &); } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.h b/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.h index aceb59f5af..d481096d49 100644 --- a/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.h +++ b/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.h @@ -7,6 +7,7 @@ namespace FlexFlow { V1ParallelComputationGraph to_v1(ParallelComputationGraph const &); +ParallelComputationGraph from_v1(V1ParallelComputationGraph const &); } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.cc index 4e7b9b651f..a2953e8fb3 100644 --- a/lib/pcg/src/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.cc +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.cc @@ -18,4 +18,7 @@ template std::pair< template V1LabelledKwargDataflowGraph to_v1( LabelledKwargDataflowGraphView const &); +template LabelledKwargDataflowGraph from_v1( + V1LabelledKwargDataflowGraph const &); + } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc b/lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc index 852ca73a36..71fc105711 100644 --- a/lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc +++ b/lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc @@ -10,6 +10,12 @@ V1ComputationGraph to_v1(ComputationGraph const &g) { }; } +ComputationGraph from_v1(V1ComputationGraph const &v1) { + return ComputationGraph{ + from_v1(v1.raw_graph), + }; +} + std::pair> to_v1_including_node_numbering(ComputationGraph const &cg) { std::pair< diff --git a/lib/pcg/src/pcg/file_format/v1/v1_mapped_operator_task_group.cc b/lib/pcg/src/pcg/file_format/v1/v1_mapped_operator_task_group.cc index 480ea7197a..465dd01fb6 100644 --- a/lib/pcg/src/pcg/file_format/v1/v1_mapped_operator_task_group.cc +++ b/lib/pcg/src/pcg/file_format/v1/v1_mapped_operator_task_group.cc @@ -6,4 +6,8 @@ V1MappedOperatorTaskGroup to_v1(MappedOperatorTaskGroup const &g) { return V1MappedOperatorTaskGroup{g.get_shard_bindings()}; } +MappedOperatorTaskGroup from_v1(V1MappedOperatorTaskGroup const &v1) { + return MappedOperatorTaskGroup{v1.shard_bindings}; +} + } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/v1_mapped_parallel_computation_graph.cc b/lib/pcg/src/pcg/file_format/v1/v1_mapped_parallel_computation_graph.cc index 96429e06a6..0236a8834c 100644 --- a/lib/pcg/src/pcg/file_format/v1/v1_mapped_parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/file_format/v1/v1_mapped_parallel_computation_graph.cc @@ -14,4 +14,13 @@ V1MappedParallelComputationGraph }; } +MappedParallelComputationGraph + from_v1(V1MappedParallelComputationGraph const &v1) { + return MappedParallelComputationGraph{ + from_v1(v1.pcg), + map_values(v1.mapped_tasks, + [](V1MappedOperatorTaskGroup const &g) { return from_v1(g); }), + }; +} + } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/v1_parallel_computation_graph.cc b/lib/pcg/src/pcg/file_format/v1/v1_parallel_computation_graph.cc index e14d15d66a..a169abe4c1 100644 --- a/lib/pcg/src/pcg/file_format/v1/v1_parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/file_format/v1/v1_parallel_computation_graph.cc @@ -10,4 +10,10 @@ V1ParallelComputationGraph to_v1(ParallelComputationGraph const &g) { }; } +ParallelComputationGraph from_v1(V1ParallelComputationGraph const &v1) { + return ParallelComputationGraph{ + from_v1(v1.raw_graph), + }; +} + } // namespace FlexFlow diff --git a/lib/pcg/test/src/pcg/file_format/v1/v1_computation_graph.cc b/lib/pcg/test/src/pcg/file_format/v1/v1_computation_graph.cc index 7af3f648d9..2ae643bd0f 100644 --- a/lib/pcg/test/src/pcg/file_format/v1/v1_computation_graph.cc +++ b/lib/pcg/test/src/pcg/file_format/v1/v1_computation_graph.cc @@ -1,6 +1,8 @@ #include "pcg/file_format/v1/v1_computation_graph.h" +#include "pcg/computation_graph.h" #include "pcg/computation_graph_builder.h" #include +#include using namespace ::FlexFlow; @@ -25,6 +27,14 @@ TEST_SUITE(FF_TEST_SUITE) { }(); V1ComputationGraph v1_cg = to_v1(cg); - nlohmann::json j = v1_cg; + + SUBCASE("serializes to JSON") { + nlohmann::json j = v1_cg; + } + + SUBCASE("round-trips via from_v1") { + ComputationGraph result = from_v1(v1_cg); + CHECK(computation_graphs_are_isomorphic(cg, result)); + } } } diff --git a/lib/pcg/test/src/pcg/file_format/v1/v1_mapped_parallel_computation_graph.cc b/lib/pcg/test/src/pcg/file_format/v1/v1_mapped_parallel_computation_graph.cc index dd2cdb35ed..78da5430b7 100644 --- a/lib/pcg/test/src/pcg/file_format/v1/v1_mapped_parallel_computation_graph.cc +++ b/lib/pcg/test/src/pcg/file_format/v1/v1_mapped_parallel_computation_graph.cc @@ -2,6 +2,7 @@ #include "op-attrs/parallel_tensor_space_coordinate.dtg.h" #include "op-attrs/tensor_slot_name.dtg.h" #include "pcg/device_type.dtg.h" +#include "pcg/file_format/v1/v1_mapped_operator_task_group.h" #include "pcg/machine_space_coordinate.dtg.h" #include "pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h" #include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.dtg.h" @@ -16,54 +17,65 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("V1MappedParallelComputationGraph") { - MappedParallelComputationGraph mpcg = [] { - ParallelComputationGraph pcg = empty_parallel_computation_graph(); + ParallelComputationGraph pcg = empty_parallel_computation_graph(); - TensorShape input_shape = TensorShape{ - TensorDims{ - FFOrdered{ - 12_p, - 16_p, - }, - }, - DataType::FLOAT, - }; + TensorShape input_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 12_p, + 16_p, + }, + }, + DataType::FLOAT, + }; - ParallelLayerAddedResult result = pcg_add_input_layer(pcg, input_shape); - parallel_layer_guid_t layer = result.parallel_layer; + ParallelLayerAddedResult result = pcg_add_input_layer(pcg, input_shape); + parallel_layer_guid_t layer = result.parallel_layer; - MachineSpaceCoordinate coord = MachineSpaceCoordinate{ - /*node_idx=*/0_n, - /*device_idx=*/0_n, - /*device_type=*/DeviceType::GPU, - }; + MachineSpaceCoordinate coord = MachineSpaceCoordinate{ + /*node_idx=*/0_n, + /*device_idx=*/0_n, + /*device_type=*/DeviceType::GPU, + }; - OperatorAtomicTaskShardBinding binding = OperatorAtomicTaskShardBinding{ - /*tensor_coords=*/{ - { - TensorSlotName::OUTPUT, - ParallelTensorSpaceCoordinate{ - /*sum_component=*/0_n, - /*discard_copy_component=*/0_n, - /*shard_components=*/FFOrdered{0_n, 0_n}, - }, - }, - }, - }; + OperatorAtomicTaskShardBinding binding = OperatorAtomicTaskShardBinding{ + /*tensor_coords=*/{ + { + TensorSlotName::OUTPUT, + ParallelTensorSpaceCoordinate{ + /*sum_component=*/0_n, + /*discard_copy_component=*/0_n, + /*shard_components=*/FFOrdered{0_n, 0_n}, + }, + }, + }, + }; - MappedOperatorTaskGroup task_group = MappedOperatorTaskGroup{ - bidict{ - {coord, binding}, - }, - }; + MappedOperatorTaskGroup task_group = MappedOperatorTaskGroup{ + bidict{ + {coord, binding}, + }, + }; - return MappedParallelComputationGraph{ - /*pcg=*/pcg, - /*mapped_tasks=*/{{layer, task_group}}, - }; - }(); + MappedParallelComputationGraph mpcg = MappedParallelComputationGraph{ + /*pcg=*/pcg, + /*mapped_tasks=*/{{layer, task_group}}, + }; V1MappedParallelComputationGraph v1_mpcg = to_v1(mpcg); - nlohmann::json j = v1_mpcg; + + SUBCASE("serializes to JSON") { + nlohmann::json j = v1_mpcg; + } + + SUBCASE("MappedOperatorTaskGroup round-trips via from_v1") { + MappedOperatorTaskGroup result = from_v1(to_v1(task_group)); + CHECK(result == task_group); + } + + SUBCASE("MappedParallelComputationGraph round-trips via from_v1") { + MappedParallelComputationGraph result = from_v1(v1_mpcg); + CHECK(pcgs_are_isomorphic(mpcg.pcg, result.pcg)); + } } } diff --git a/lib/pcg/test/src/pcg/file_format/v1/v1_parallel_computation_graph.cc b/lib/pcg/test/src/pcg/file_format/v1/v1_parallel_computation_graph.cc index ec6a4ab006..033626ab5c 100644 --- a/lib/pcg/test/src/pcg/file_format/v1/v1_parallel_computation_graph.cc +++ b/lib/pcg/test/src/pcg/file_format/v1/v1_parallel_computation_graph.cc @@ -1,6 +1,8 @@ #include "pcg/file_format/v1/v1_parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" #include +#include using namespace ::FlexFlow; @@ -29,6 +31,14 @@ TEST_SUITE(FF_TEST_SUITE) { }(); V1ParallelComputationGraph v1_pcg = to_v1(pcg); - nlohmann::json j = v1_pcg; + + SUBCASE("serializes to JSON") { + nlohmann::json j = v1_pcg; + } + + SUBCASE("round-trips via from_v1") { + ParallelComputationGraph result = from_v1(v1_pcg); + CHECK(pcgs_are_isomorphic(pcg, result)); + } } }