diff --git a/include/expression_dag.hh b/include/expression_dag.hh index da08bb5..43a194f 100644 --- a/include/expression_dag.hh +++ b/include/expression_dag.hh @@ -15,6 +15,7 @@ #include "config.hh" #include "scalar_node.hh" #include "assert.hh" +#include "array.hh" namespace bparser { @@ -40,6 +41,8 @@ private: /// Result nodes, given as input. NodeVec results; + typedef std::pair InvDotNameAndIndices; + typedef std::map InvDotMap; /** * Used in the setup_result_storage to note number of unclosed nodes @@ -102,6 +105,7 @@ public: /** * Print ScalarExpression graph in the dot format. + * Useful for debugging */ void print_in_dot() { std::map i_node; @@ -131,8 +135,217 @@ public: std::cout << "Node: " << node->op_name_ << "_" << node->result_idx_ << " " << node->result_storage << std::endl; } + /** + * Print ScalarExpression graph in the common dot format. + * Useful for understanding the DAG. + */ + void print_in_dot2() { + print_in_dot2(InvDotMap()); + } + + /** + * Print ScalarExpression graph in the common dot format. + * Useful for understanding the DAG. Using the parser's map of var. Name -> Array find the inverse ScalarNodePtr -> var. Name + */ + void print_in_dot2(const std::map& symbols) { + print_in_dot2(create_inverse_map(symbols)); + } + + /** + * Print ScalarExpression graph in the common dot format. + * Useful for understanding the DAG. Using the map of ScalarNodePtr -> variableName + */ + void print_in_dot2(const InvDotMap& names) { + + sort_nodes(); + + std::cout << "\n" << "----- begin cut here -----" << "\n"; + std::cout << "digraph Expr {" << "\n"; + + std::cout << "/* definitions */" << "\n"; + + std::cout << "edge [dir=back]" << "\n"; + for (uint i = 0; i < sorted.size(); ++i) { + _print_dot_node_definition(sorted[i],names); + } + std::cout << "/* end of definitions */" << "\n"; + + for (uint i = 0; i < sorted.size(); ++i) { + for (uint in = 0; in < sorted[i]->n_inputs_; ++in) { + std::cout << " "; + _print_dot_node_id(sorted[i]); + std::cout << "\n -> "; + _print_dot_node_id(sorted[i]->inputs_[in]); + std::cout << "\n\n"; + } + } + std::cout << "}" << "\n"; + std::cout << "----- end cut here -----" << "\n"; + std::cout.flush(); + } + + //Create a map of ScalarNodePtr -> (variable name, is_scalar) + InvDotMap create_inverse_map(const std::map& symbols) const { + InvDotMap inv_map; + if (symbols.empty()) return inv_map; + for (const auto& s : symbols) + { + for (MultiIdx idx(s.second.range()); idx.valid();idx.inc_src()) { + inv_map[s.second[idx]] = InvDotNameAndIndices(s.first, idx.indices()); + } + /*for (const auto& n : s.second.elements()) { + inv_map[n] = InvDotNameAndIndices(s.first, s.second.shape().empty()); + }*/ + } + return inv_map; + } + private: + //Print the vertice identifier for dot + void _print_dot_node_id(const ScalarNodePtr& node) const { + std::cout << node->op_name_ << "_" << (uintptr_t)node.get() << "__" << node->result_storage;// << std::endl; + } + + //Print how the vertice should look in dot + void _print_dot_node_definition(const ScalarNodePtr& node, const InvDotMap& invmap) const { + _print_dot_node_id(node); + std::cout << ' '; + + switch (node->result_storage) { + case ResultStorage::constant: { // Constant + std::cout << "[shape=circle,"; + + try { //If the constant has a name + std::string name(invmap.at(node).first); + const MultiIdx::VecUint indices(invmap.at(node).second); + bool scalar(indices.empty()); + if (scalar) { + std::cout << "label=\"" << name << '=' << *node->values_ << '\"'; + } + else { + MultiIdx::VecUint::size_type size(indices.size()); + std::cout << "label=\"" << name << "["; + for (MultiIdx::VecUint::size_type i = 0; i < size; i++) { + std::cout << indices.at(i); + if (i != size - 1) { + std::cout << ','; + } + } + std::cout << "]"; + std::cout << '=' << *node->values_ << '\"'; + } + std::cout << ", group = \"" << name << '\"'; + } + catch (const std::out_of_range&) { //No name + std::cout << "label=\"" << "const " << *node->values_ << '"'; + } + std::cout << "]" << std::endl; + break; + } + + case ResultStorage::constant_bool: { //Constant bool + std::cout << "[shape=circle,"; + + try { //If the constant has a name + std::string name(invmap.at(node).first); + const MultiIdx::VecUint indices(invmap.at(node).second); + bool scalar(indices.empty()); + if (scalar) { + std::cout << "label=\"" << name << '=' << *node->values_ << '\"'; + } + else { + MultiIdx::VecUint::size_type size(indices.size()); + std::cout << "label=\"" << name << "["; + for (MultiIdx::VecUint::size_type i = 0; i < size; i++) { + std::cout << indices.at(i); + if (i != size - 1) { + std::cout << ','; + } + } + std::cout << "]"; + std::cout << '=' << *node->values_ << '\"'; + } + std::cout << ", group = \"" << name << '\"'; + } + catch (const std::out_of_range&) { //No name + std::cout << "label=\"" << "const " << *node->values_ << '"'; + } + std::cout << "]" << std::endl; + break; + } + + case ResultStorage::expr_result: { //Result + std::cout << "[shape=box,label=\"" << node->op_name_ << " [" << node->result_idx_ << "]" << "\"]" << std::endl; + break; + } + + case ResultStorage::value: { // Value + std::cout << "[shape=circle,"; + try { + std::string name(invmap.at(node).first); + const MultiIdx::VecUint indices(invmap.at(node).second); + bool scalar(indices.empty()); + if (scalar) { + std::cout << "label=\"" << name << '"'; + } + else { + MultiIdx::VecUint::size_type size(indices.size()); + std::cout << "label=\"" << name << "["; + for (MultiIdx::VecUint::size_type i = 0; i < size; i++) { + std::cout << indices.at(i); + if (i != size - 1) { + std::cout << ','; + } + } + std::cout << "]\""; + } + std::cout << ",group=\"" << name << '"'; + } + catch (const std::out_of_range&) { + std::cout << "label=<var>"; + } + + std::cout << "]" << std::endl; + break; + } + + case ResultStorage::value_copy: { //Value copy + std::cout << "[shape=circle,"; + try { + std::string name(invmap.at(node).first); + MultiIdx::VecUint indices(invmap.at(node).second); + bool scalar(indices.empty()); + if (scalar) { + std::cout << "label=\"" << name << '"'; + } + else { + MultiIdx::VecUint::size_type size(indices.size()); + std::cout << "label=\"" << name << "["; + for (MultiIdx::VecUint::size_type i = 0; i < size; i++) { + std::cout << indices.at(i); + if (i != size - 1) { + std::cout << ','; + } + } + std::cout << "]\""; + } + std::cout << ",group=\"" << name << '"'; + } + catch (const std::out_of_range&) { + std::cout << "label=<var_cp>"; + } + std::cout << "]" << std::endl; + break; + } + + default: { //Temporary & other + std::cout << "[label=\"" << node->op_name_ << "\"]" << std::endl; + break; + } + } //switch + } + void _print_i_node(uint i) { std::cout << sorted[i]->op_name_ << "_" << i << "_"<< sorted[i]->result_idx_; } diff --git a/include/parser.hh b/include/parser.hh index ca8bde1..781dbc8 100644 --- a/include/parser.hh +++ b/include/parser.hh @@ -190,6 +190,8 @@ public: details::ExpressionDAG se(result_array_.elements()); //se.print_in_dot(); + //se.print_in_dot2(); + //se.print_in_dot2(symbols_); processor = ProcessorBase::create_processor(se, max_vec_size, simd_size, arena); }