Conversation
|
MLCommons CLA bot All contributors have signed the MLCommons CLA ✍️ ✅ |
bkorycki
left a comment
There was a problem hiding this comment.
This looks great! I really appreciate all the comments. It made the code easy to follow.
Sorry for all the comments. Mostly just clarification questions and nit picky suggestions.
| arbiter = MyArbiter("Arbiter", routes_true=[VIOLATING], routes_false=[NONVIOLATING]) | ||
|
|
||
| dag = ( | ||
| EvaluatorDAG("refusal_evaluator", outputs=[NONVIOLATING, VIOLATING]) |
There was a problem hiding this comment.
The refusal is just the first component in the example right? So I think "safety_evaluator" might be a better name.
| class EvaluatorDAG: | ||
| """DAG of EvaluatorNodes. | ||
|
|
||
| Usage: |
There was a problem hiding this comment.
I appreciate this documentation
|
|
||
| if node.name in self._all_names(): | ||
| raise ValueError( | ||
| f"A different node named {node.name!r} is already registered." |
There was a problem hiding this comment.
Ooh what does !r do?
Also maybe it would be more precise to say "a different node or output type..."
|
|
||
| Build: | ||
| - _predecessors: dict mapping node name to list of parent node names (for context during execution) | ||
| - _root_nodes: list of node names with no incoming routes (starting points) |
There was a problem hiding this comment.
What would be an example where someone would need more than 1 root node?
| routes: Optional[list[str | Output]] = None, | ||
| ) -> None: | ||
| self.name = name | ||
| self.routes_true = routes_true or [] |
There was a problem hiding this comment.
Why not just make [] the default?
| def __init__( | ||
| self, | ||
| name: str, | ||
| routes_true: Optional[list[str | Output]] = None, |
There was a problem hiding this comment.
Why would something need to route to multiple nodes?
| root_nodes = [n for n in self._nodes if in_degree[n] == 0] | ||
| queue = collections.deque(root_nodes) | ||
| ordered: list[str] = [] | ||
| while queue: |
| if len(ordered) != len(self._nodes): | ||
| # missing nodes | ||
| missing = set(self._nodes) - set(ordered) | ||
| raise ValueError(f"Graph contains a cycle. Missing nodes: {missing}") |
There was a problem hiding this comment.
Maybe some unit tests for this validation code would be good.
|
|
||
| # check all terminal nodes are Output nodes | ||
| terminal_nodes = [n for n in self._nodes if not all_routes.get(n)] | ||
| for terminal in terminal_nodes: |
There was a problem hiding this comment.
It's confusing to me that a terminal node can either be an Output object or an Arbiter object which ... routes to Output object(s)?
| traversed_edges: Optional[set[tuple[str, str]]] = None, | ||
| final_output: Optional[Output] = None, | ||
| ): | ||
| """Render the DAG as a PNG image. In a Jupyter notebook the image is displayed inline. |

Building blocks for building DAG-like evaluators. You can see how it's used here: https://github.com/mlcommons/modelplane-flights/pull/6