diff --git a/src/gfn/gym/__init__.py b/src/gfn/gym/__init__.py index 2d6cf6ad..366500f8 100644 --- a/src/gfn/gym/__init__.py +++ b/src/gfn/gym/__init__.py @@ -2,6 +2,7 @@ from .bitSequence import BitSequence, BitSequencePlus from .box import Box +from .chip_design import ChipDesign from .discrete_ebm import DiscreteEBM from .graph_building import GraphBuilding, GraphBuildingOnEdges from .hypergrid import HyperGrid @@ -20,4 +21,5 @@ "GraphBuildingOnEdges", "PerfectBinaryTree", "SetAddition", + "ChipDesign", ] diff --git a/src/gfn/gym/chip_design.py b/src/gfn/gym/chip_design.py new file mode 100644 index 00000000..9d923890 --- /dev/null +++ b/src/gfn/gym/chip_design.py @@ -0,0 +1,253 @@ +"""GFlowNet environment for chip placement.""" + +from typing import ClassVar, Optional, Sequence, cast + +import torch + +from gfn.actions import Actions +from gfn.env import DiscreteEnv +from gfn.gym.helpers.chip_design import SAMPLE_INIT_PLACEMENT, SAMPLE_NETLIST_FILE +from gfn.gym.helpers.chip_design import utils as placement_util +from gfn.gym.helpers.chip_design.utils import cost_info_function +from gfn.states import DiscreteStates + + +class ChipDesignStates(DiscreteStates): + """A class to represent the states of the chip design environment.""" + + state_shape: ClassVar[tuple[int, ...]] + s0: ClassVar[torch.Tensor] + sf: ClassVar[torch.Tensor] + n_actions: ClassVar[int] + + def __init__( + self, + tensor: torch.Tensor, + forward_masks: Optional[torch.Tensor] = None, + backward_masks: Optional[torch.Tensor] = None, + current_node_idx: Optional[torch.Tensor] = None, + ): + super().__init__( + tensor=tensor, forward_masks=forward_masks, backward_masks=backward_masks + ) + if current_node_idx is None: + is_unplaced = tensor == -1 + is_unplaced_padded = torch.cat( + [ + is_unplaced, + torch.ones_like(is_unplaced[..., :1]), + ], + dim=-1, + ) + current_node_idx = is_unplaced_padded.long().argmax(dim=-1) + + self.current_node_idx = current_node_idx + + def clone(self) -> "ChipDesignStates": + """Creates a copy of the states.""" + return self.__class__( + self.tensor.clone(), + current_node_idx=self.current_node_idx.clone(), + forward_masks=self.forward_masks.clone(), + backward_masks=self.backward_masks.clone(), + ) + + def __getitem__(self, index) -> "ChipDesignStates": + """Gets a subset of the states.""" + return self.__class__( + self.tensor[index], + current_node_idx=self.current_node_idx[index], + forward_masks=self.forward_masks[index], + backward_masks=self.backward_masks[index], + ) + + def __setitem__(self, index, value: "ChipDesignStates") -> None: + """Sets a subset of the states.""" + super().__setitem__(index, value) + self.current_node_idx[index] = value.current_node_idx + + def extend(self, other: "ChipDesignStates") -> None: + """Extends the states with another states.""" + super().extend(other) + self.current_node_idx = torch.cat( + (self.current_node_idx, other.current_node_idx), + dim=len(self.batch_shape) - 1, + ) + + @classmethod + def stack(cls, states: Sequence["ChipDesignStates"]) -> "ChipDesignStates": + """Stacks the states with another states.""" + stacked = super().stack(states) + stacked.current_node_idx = torch.stack( + [s.current_node_idx for s in states], + dim=0, + ) + return cast(ChipDesignStates, stacked) + + +class ChipDesign(DiscreteEnv): + """ + GFlowNet environment for chip placement. + + The state is a vector of length `n_macros`, where `state[i]` is the grid + cell location of the i-th macro to be placed. Unplaced macros have a + location of -1. + + Actions are integers from `0` to `n_grid_cells - 1`, representing the + grid cell to place the current macro on. Action `n_grid_cells` is the + exit action. + """ + + def __init__( + self, + netlist_file: str = SAMPLE_NETLIST_FILE, + init_placement: str = SAMPLE_INIT_PLACEMENT, + std_cell_placer_mode: str = "fd", + wirelength_weight: float = 1.0, + density_weight: float = 1.0, + congestion_weight: float = 0.5, + device: str | None = None, + check_action_validity: bool = True, + ): + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + + self.plc = placement_util.create_placement_cost( + netlist_file=netlist_file, init_placement=init_placement + ) + self.std_cell_placer_mode = std_cell_placer_mode + + self.wirelength_weight = wirelength_weight + self.density_weight = density_weight + self.congestion_weight = congestion_weight + + self._grid_cols, self._grid_rows = self.plc.get_grid_num_columns_rows() + self.n_grid_cells = self._grid_cols * self._grid_rows + + self._sorted_node_indices = placement_util.get_ordered_node_indices( + mode="descending_size_macro_first", plc=self.plc + ) + self._hard_macro_indices = [ + m for m in self._sorted_node_indices if not self.plc.is_node_soft_macro(m) + ] + self.n_macros = len(self._hard_macro_indices) + + s0 = torch.full((self.n_macros,), -1, dtype=torch.long, device=device) + sf = torch.full((self.n_macros,), -2, dtype=torch.long, device=device) + n_actions = self.n_grid_cells + 1 + + super().__init__( + n_actions=n_actions, + s0=s0, + state_shape=(self.n_macros,), + sf=sf, + check_action_validity=check_action_validity, + ) + self.States: type[ChipDesignStates] = self.make_states_class() + + def make_states_class(self) -> type[ChipDesignStates]: + """Creates the ChipDesignStates class.""" + env = self + + class BaseChipDesignStates(ChipDesignStates): + state_shape = env.state_shape + s0 = env.s0 + sf = env.sf + n_actions = env.n_actions + device = env.device + + return BaseChipDesignStates + + def _apply_state_to_plc(self, state_tensor: torch.Tensor): + """Applies a single state tensor to the plc object.""" + assert state_tensor.shape == (self.n_macros,) + + self.plc.unplace_all_nodes() + for i in range(self.n_macros): + loc = state_tensor[i].item() + if loc != -1: + node_index = self._hard_macro_indices[i] + self.plc.place_node(node_index, loc) + + def update_masks(self, states: ChipDesignStates) -> None: + """Updates the forward and backward masks of the states.""" + states.forward_masks.zero_() + states.backward_masks.zero_() + + for i in range(len(states)): + state_tensor = states.tensor[i] + current_node_idx = states.current_node_idx[i].item() + + if current_node_idx >= self.n_macros: # All macros placed + states.forward_masks[i, -1] = True # Only exit is possible + else: + # Apply partial placement to plc to get mask for next node + self._apply_state_to_plc(state_tensor) + node_to_place = self._hard_macro_indices[int(current_node_idx)] + mask = self.plc.get_node_mask(node_to_place) + mask = torch.tensor(mask, dtype=torch.bool, device=self.device) + states.forward_masks[i, : self.n_grid_cells] = mask + states.forward_masks[i, -1] = False # No exit + + if current_node_idx > 0: + last_placed_loc = state_tensor[int(current_node_idx - 1)].item() + assert last_placed_loc != -1, "Last placed location should not be -1" + states.backward_masks[i, int(last_placed_loc)] = True + + def step(self, states: ChipDesignStates, actions: Actions) -> ChipDesignStates: + """Performs a forward step in the environment.""" + new_tensor = states.tensor.clone() + + non_exit_mask = ~actions.is_exit + if torch.any(non_exit_mask): + rows = torch.arange(len(states), device=self.device)[non_exit_mask] + cols = states.current_node_idx[non_exit_mask] + new_tensor[rows, cols] = actions.tensor[non_exit_mask].squeeze(-1) + + if torch.any(actions.is_exit): + new_tensor[actions.is_exit] = self.sf + + new_current_node_idx = states.current_node_idx.clone() + new_current_node_idx[non_exit_mask] += 1 + + return self.States(tensor=new_tensor, current_node_idx=new_current_node_idx) + + def backward_step( + self, states: ChipDesignStates, actions: Actions + ) -> ChipDesignStates: + """Performs a backward step in the environment.""" + new_tensor = states.tensor.clone() + rows = torch.arange(len(states), device=self.device) + cols = states.current_node_idx - 1 + new_tensor[rows, cols] = -1 + + new_current_node_idx = states.current_node_idx - 1 + return self.States(tensor=new_tensor, current_node_idx=new_current_node_idx) + + def analytical_placer(self): + """Places standard cells using an analytical placer.""" + if self.std_cell_placer_mode == "fd": + placement_util.fd_placement_schedule(self.plc) + else: + raise ValueError( + f"{self.std_cell_placer_mode} is not a supported std_cell_placer_mode." + ) + + def log_reward(self, final_states: ChipDesignStates) -> torch.Tensor: + """Computes the log reward of the final states.""" + rewards = torch.zeros(len(final_states), device=self.device) + for i in range(len(final_states)): + state_tensor = final_states.tensor[i] + self._apply_state_to_plc(state_tensor) + + self.analytical_placer() + + cost, _ = cost_info_function( + plc=self.plc, + done=True, + wirelength_weight=self.wirelength_weight, + density_weight=self.density_weight, + congestion_weight=self.congestion_weight, + ) + rewards[i] = -cost + return rewards diff --git a/src/gfn/gym/helpers/chip_design/__init__.py b/src/gfn/gym/helpers/chip_design/__init__.py new file mode 100644 index 00000000..b7946290 --- /dev/null +++ b/src/gfn/gym/helpers/chip_design/__init__.py @@ -0,0 +1,24 @@ +# Code copied from https://github.com/google-research/circuit_training +# @article{mirhoseini2021graph, +# title={A graph placement methodology for fast chip design}, +# author={Mirhoseini*, Azalia and Goldie*, Anna and Yazgan, Mustafa and Jiang, Joe +# Wenjie and Songhori, Ebrahim and Wang, Shen and Lee, Young-Joon and Johnson, +# Eric and Pathak, Omkar and Nazi, Azade and Pak, Jiwoo and Tong, Andy and +# Srinivasa, Kavya and Hang, William and Tuncer, Emre and V. Le, Quoc and +# Laudon, James and Ho, Richard and Carpenter, Roger and Dean, Jeff}, +# journal={Nature}, +# volume={594}, +# number={7862}, +# pages={207--212}, +# year={2021}, +# publisher={Nature Publishing Group} +# } + +import os + +SAMPLE_NETLIST_FILE = os.path.join( + os.path.dirname(__file__), "test_data", "netlist.pb.txt" +) +SAMPLE_INIT_PLACEMENT = os.path.join( + os.path.dirname(__file__), "test_data", "initial.plc" +) diff --git a/src/gfn/gym/helpers/chip_design/plc_client.py b/src/gfn/gym/helpers/chip_design/plc_client.py new file mode 100644 index 00000000..719bcf45 --- /dev/null +++ b/src/gfn/gym/helpers/chip_design/plc_client.py @@ -0,0 +1,129 @@ +# coding=utf-8 +# Copyright 2021 The Circuit Training Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PlacementCost client class.""" + +import json +import logging +import os +import socket +import subprocess +import tempfile +from typing import Any, Text + +logger = logging.getLogger(__name__) + + +class PlacementCost: + """PlacementCost object wrapper.""" + + BUFFER_LEN = 1024 * 1024 + MAX_RETRY = 256 + + def __init__( + self, + netlist_file: Text, + plc_wrapper_main: str = os.path.join( + os.path.dirname(__file__), "plc_wrapper_main" + ), + macro_macro_x_spacing: float = 0.0, + macro_macro_y_spacing: float = 0.0, + ) -> None: + """Creates a PlacementCost client object. + + It creates a subprocess by calling plc_wrapper_main and communicate with + it over an `AF_UNIX` channel. + + Args: + netlist_file: Path to the netlist proto text file. + macro_macro_x_spacing: Macro-to-macro x spacing in microns. + macro_macro_y_spacing: Macro-to-macro y spacing in microns. + """ + if not plc_wrapper_main: + raise ValueError("plc_wrapper_main should be specified.") + + self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + address = tempfile.NamedTemporaryFile().name + self.sock.bind(address) + self.sock.listen(1) + args = [ + plc_wrapper_main, + "--uid=0", + "--gid=0", + f"--pipe_address={address}", + f"--netlist_file={netlist_file}", + f"--macro_macro_x_spacing={macro_macro_x_spacing}", + f"--macro_macro_y_spacing={macro_macro_y_spacing}", + ] + self.process = subprocess.Popen([str(a) for a in args]) + self.conn, _ = self.sock.accept() + + # See circuit_training/environment/plc_client_test.py for the supported APIs. + def __getattr__(self, name) -> Any: + # snake_case to PascalCase. + name = name.replace("_", " ").title().replace(" ", "") + + def f(*args) -> Any: + json_args = json.dumps({"name": name, "args": args}) + self.conn.send(json_args.encode("utf-8")) + json_ret = b"" + retry = 0 + # The stream from the unix socket can be incomplete after a single call + # to `recv` for large (200kb+) return values, e.g. GetMacroAdjacency. The + # loop retries until the returned value is valid json. When the host is + # under load ~10 retries have been needed. Adding a sleep did not seem to + # make a difference only added latency. b/210838186 + while True: + part = self.conn.recv(PlacementCost.BUFFER_LEN) + json_ret += part + if len(part) < PlacementCost.BUFFER_LEN: + json_str = json_ret.decode("utf-8") + try: + output = json.loads(json_str) + break + except json.decoder.JSONDecodeError as e: + logger.warning("JSONDecode Error for %s \n %s", name, e) + if retry < PlacementCost.MAX_RETRY: + logger.info( + "Looking for more data for %s on connection:%s/%s", + name, + retry, + PlacementCost.MAX_RETRY, + ) + retry += 1 + else: + raise e + if isinstance(output, dict): + if "ok" in output and not output["ok"]: # Status::NotOk + raise ValueError( + f"Error in calling {name} with {args}: {output['message']}." + ) + elif "__tuple__" in output: # Tuple + output = tuple(output["items"]) + elif isinstance(output, list): + if ( + len(output) > 0 + and isinstance(output[0], dict) + and "__tuple__" in output[0] + ): # List of tuples + output = [tuple(o["items"]) for o in output] + return output + + return f + + def close(self) -> None: + self.conn.close() + self.process.kill() + self.process.wait() + self.sock.close() diff --git a/src/gfn/gym/helpers/chip_design/plc_wrapper_main b/src/gfn/gym/helpers/chip_design/plc_wrapper_main new file mode 100755 index 00000000..2f0fc3f0 Binary files /dev/null and b/src/gfn/gym/helpers/chip_design/plc_wrapper_main differ diff --git a/src/gfn/gym/helpers/chip_design/test_data/initial.plc b/src/gfn/gym/helpers/chip_design/test_data/initial.plc new file mode 100644 index 00000000..b886e558 --- /dev/null +++ b/src/gfn/gym/helpers/chip_design/test_data/initial.plc @@ -0,0 +1,36 @@ +# Placement file for Circuit Training +# Source input file(s) : circuit_training/environment/test_data/sample_clustered/netlist.pb.txt +# This file : circuit_training/environment/test_data/sample_clustered/initial.plc +# Date : 2022-03-13 09:30:00 +# Columns : 2 Rows : 2 +# Width : 500.000 Height : 500.000 +# Area : 17603.53279986302 +# Wirelength : 0.0 +# Wirelength cost : 0.0 +# Congestion cost : 0.0 +# Density cost : 0.2305 +# Project : circuit_training +# Block : sample_clustered +# Routes per micron, hor : 70.33 ver : 74.51 +# Routes used by macros, hor : 51.79 ver : 51.79 +# Smoothing factor : 2 +# Overlap threshold : 0.004 +# +# +# +# Counts of node types: +# HARD_MACROs : 2 +# HARD_MACRO_PINs : 4 +# MACROs : 3 +# MACRO_PINs : 6 +# PORTs : 2 +# SOFT_MACROs : 1 +# SOFT_MACRO_PINs : 2 +# STDCELLs : 0 +# +# node_index x y orientation fixed +0 0 100 - 1 +1 499 499 - 1 +2 125 375 N 0 +3 375 375 N 0 +8 170 310 N 0 diff --git a/src/gfn/gym/helpers/chip_design/test_data/netlist.pb.txt b/src/gfn/gym/helpers/chip_design/test_data/netlist.pb.txt new file mode 100644 index 00000000..2c6dcfd8 --- /dev/null +++ b/src/gfn/gym/helpers/chip_design/test_data/netlist.pb.txt @@ -0,0 +1,334 @@ +# proto-file: tensorflow/core/framework/graph.proto +# proto-message: tensorflow.GraphDef +node { + name: "P0" + input: "Grp_2/Pinput" + input: "P0_M0" + attr { + key: "side" + value { + placeholder: "LEFT" + } + } + attr { + key: "type" + value { + placeholder: "PORT" + } + } + attr { + key: "x" + value { + f: 0 + } + } + attr { + key: "y" + value { + f: 100 + } + } +} +node { + name: "P1" + attr { + key: "side" + value { + placeholder: "TOP" + } + } + attr { + key: "type" + value { + placeholder: "PORT" + } + } + attr { + key: "x" + value { + f: 499 + } + } + attr { + key: "y" + value { + f: 499 + } + } +} +node { + name: "M0" + attr { + key: "height" + value { + f: 120 + } + } + attr { + key: "orientation" + value { + placeholder: "N" + } + } + attr { + key: "type" + value { + placeholder: "MACRO" + } + } + attr { + key: "width" + value { + f: 120 + } + } +} +node { + name: "M1" + attr { + key: "height" + value { + f: 40 + } + } + attr { + key: "orientation" + value { + placeholder: "N" + } + } + attr { + key: "type" + value { + placeholder: "MACRO" + } + } + attr { + key: "width" + value { + f: 80 + } + } +} +node { + name: "P0_M0" + attr { + key: "macro_name" + value { + placeholder: "M0" + } + } + attr { + key: "type" + value { + placeholder: "MACRO_PIN" + } + } + attr { + key: "x_offset" + value { + f: -60 + } + } + attr { + key: "y_offset" + value { + f: 60 + } + } +} +node { + name: "P1_M0" + input: "Grp_2/Pinput" + attr { + key: "macro_name" + value { + placeholder: "M0" + } + } + attr { + key: "type" + value { + placeholder: "MACRO_PIN" + } + } + attr { + key: "x_offset" + value { + f: 60 + } + } + attr { + key: "y_offset" + value { + f: 60 + } + } +} +node { + name: "P0_M1" + attr { + key: "macro_name" + value { + placeholder: "M1" + } + } + attr { + key: "type" + value { + placeholder: "MACRO_PIN" + } + } + attr { + key: "x_offset" + value { + f: -40 + } + } + attr { + key: "y_offset" + value { + f: 20 + } + } +} +node { + name: "P1_M1" + input: "P1" + attr { + key: "macro_name" + value { + placeholder: "M1" + } + } + attr { + key: "type" + value { + placeholder: "MACRO_PIN" + } + } + attr { + key: "x_offset" + value { + f: 40 + } + } + attr { + key: "y_offset" + value { + f: 20 + } + } +} +node { + name: "Grp_2" + attr { + key: "height" + value { + f: 0.20625865 + } + } + attr { + key: "type" + value { + placeholder: "macro" + } + } + attr { + key: "width" + value { + f: 17.128008 + } + } + attr { + key: "x" + value { + f: 20 + } + } + attr { + key: "y" + value { + f: 45 + } + } +} +node { + name: "Grp_2/Poutput_single_0" + input: "P0_M1" + attr { + key: "macro_name" + value { + placeholder: "Grp_2" + } + } + attr { + key: "type" + value { + placeholder: "macro_pin" + } + } + attr { + key: "x" + value { + f: 20 + } + } + attr { + key: "x_offset" + value { + f: 0 + } + } + attr { + key: "y" + value { + f: 45 + } + } + attr { + key: "y_offset" + value { + f: 0 + } + } +} +node { + name: "Grp_2/Pinput" + attr { + key: "macro_name" + value { + placeholder: "Grp_2" + } + } + attr { + key: "type" + value { + placeholder: "macro_pin" + } + } + attr { + key: "x" + value { + f: 20 + } + } + attr { + key: "x_offset" + value { + f: 0 + } + } + attr { + key: "y" + value { + f: 45 + } + } + attr { + key: "y_offset" + value { + f: 0 + } + } +} diff --git a/src/gfn/gym/helpers/chip_design/utils.py b/src/gfn/gym/helpers/chip_design/utils.py new file mode 100644 index 00000000..c3b19abe --- /dev/null +++ b/src/gfn/gym/helpers/chip_design/utils.py @@ -0,0 +1,783 @@ +# coding=utf-8 +# Copyright 2021 The Circuit Training Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa +# pyright: reportCallIssue=false, reportArgumentType=false + +"""A collection of non-prod utility functions for placement. + +All the dependencies in this files should be non-prod. +""" + +import datetime +import logging +import re +import textwrap +from typing import Dict, Iterator, List, Optional, Tuple, Union + +import numpy as np + +from . import plc_client + +COST_COMPONENTS = ["wirelength", "congestion", "density"] + + +def cost_info_function( + plc: plc_client.PlacementCost, + done: bool, + infeasible_state: bool = False, + wirelength_weight: float = 1.0, + density_weight: float = 1.0, + congestion_weight: float = 0.5, +) -> tuple[float, dict[str, float]]: + """Returns the RL cost and info.""" + del infeasible_state + proxy_cost = 0.0 + info = {cost: -1.0 for cost in COST_COMPONENTS} + + if not done: + return proxy_cost, info + + if wirelength_weight > 0.0: + info["wirelength"] = plc.get_cost() + proxy_cost += wirelength_weight * info["wirelength"] + + if congestion_weight > 0.0: + info["congestion"] = plc.get_congestion_cost() + proxy_cost += congestion_weight * info["congestion"] + + if density_weight > 0.0: + info["density"] = plc.get_density_cost() + proxy_cost += density_weight * info["density"] + + return proxy_cost, info + + +def nodes_of_types(plc: plc_client.PlacementCost, type_list: List[str]) -> Iterator[int]: + """Yields the index of a node of certain types.""" + i = 0 + while True: + node_type = plc.get_node_type(i) + if not node_type: + break + if node_type in type_list: + yield i + i += 1 + + +def get_node_xy_coordinates( + plc: plc_client.PlacementCost, +) -> Dict[int, Tuple[float, float]]: + """Returns all node x,y coordinates (canvas) in a dict.""" + node_coords = dict() + for node_index in nodes_of_types(plc, ["MACRO", "STDCELL", "PORT"]): + if plc.is_node_placed(node_index): + node_coords[node_index] = plc.get_node_location(node_index) + return node_coords + + +def get_macro_orientations(plc: plc_client.PlacementCost) -> Dict[int, int]: + """Returns all macros' orientations in a dict.""" + macro_orientations = dict() + for node_index in nodes_of_types(plc, ["MACRO"]): + macro_orientations[node_index] = plc.get_macro_orientation(node_index) + return macro_orientations + + +def restore_node_xy_coordinates( + plc: plc_client.PlacementCost, node_coords: Dict[int, Tuple[float, float]] +) -> None: + for node_index, coords in node_coords.items(): + if not plc.is_node_fixed(node_index): + plc.update_node_coords(node_index, coords[0], coords[1]) + + +def restore_macro_orientations( + plc: plc_client.PlacementCost, macro_orientations: Dict[int, int] +) -> None: + for node_index, orientation in macro_orientations.items(): + plc.update_macro_orientation(node_index, orientation) + + +def extract_attribute_from_comments( + attribute: str, filenames: List[str] +) -> Optional[str]: + """Parses the files' comments section, tries to extract the attribute.""" + for filename in filenames: + if filename: + f = filename.split(",")[0] + if f: + with open(f, "r") as infile: + for line in infile: + if line.startswith("#"): + match = re.search(rf"{attribute} : ([-\w]+)", line) + if match: + return match.group(1) + else: + break + return None + + +def get_blockages_from_comments( + filenames: Union[str, List[str]], +) -> Optional[List[List[float]]]: + """Returns list of blockages if they exist in the file's comments section.""" + for filename in filenames: + if not filename: + continue + blockages = [] + try: + with open(filename, "r") as infile: + for line in infile: + if line.startswith("# Blockage : "): + blockages.append([float(x) for x in line.split()[3:8]]) + elif not line.startswith("#"): + break + except OSError: + logging.error("could not read file %s.", filename) + if blockages: + return blockages + + +def extract_sizes_from_comments( + filenames: List[str], +) -> Tuple[Optional[float], Optional[float], Optional[int], Optional[int]]: + """Parses the file's comments section, tries to extract canvas/grid sizes.""" + canvas_width, canvas_height = None, None + grid_cols, grid_rows = None, None + for filename in filenames: + if not filename: + continue + with open(filename, "r") as infile: + for line in infile: + if line.startswith("#"): + fp_re = re.search( + r"FP bbox: {([\d\.]+) ([\d\.]+)} {([\d\.]+) ([\d\.]+)}", line + ) + if fp_re: + canvas_width = float(fp_re.group(3)) + canvas_height = float(fp_re.group(4)) + continue + plc_wh = re.search(r"Width : ([\d\.]+) Height : ([\d\.]+)", line) + if plc_wh: + canvas_width = float(plc_wh.group(1)) + canvas_height = float(plc_wh.group(2)) + continue + plc_cr = re.search(r"Columns : ([\d]+) Rows : ([\d]+)", line) + if plc_cr: + grid_cols = int(plc_cr.group(1)) + grid_rows = int(plc_cr.group(2)) + else: + break + return canvas_width, canvas_height, grid_cols, grid_rows + + +def fix_port_coordinates(plc: plc_client.PlacementCost) -> None: + """Find all ports and fix their coordinates.""" + for node in nodes_of_types(plc, ["PORT"]): + plc.fix_node_coord(node) + + +def create_placement_cost( + netlist_file: str, + init_placement: str, + overlap_threshold: float = 4e-3, + congestion_smooth_range: int = 5, + macro_macro_x_spacing: float = 0.1, + macro_macro_y_spacing: float = 0.1, + boundary_check: bool = False, + horizontal_routes_per_micron: float = 70.33, + vertical_routes_per_micron: float = 74.51, + macro_horizontal_routing_allocation: float = 51.79, + macro_vertical_routing_allocation: float = 51.79, + routes_per_congestion_grid: int = 1000, + blockages: Optional[List[List[float]]] = None, + fixed_macro_names_regex: Optional[List[str]] = None, + legacy_congestion_grid: bool = False, +) -> plc_client.PlacementCost: + """Creates a placement_cost object.""" + if not netlist_file: + raise ValueError("netlist_file should be provided.") + + block_name = extract_attribute_from_comments("Block", [init_placement, netlist_file]) + if not block_name: + logging.warning( + "block_name is not set. Please add the block_name in:\n%s\nor in:\n%s", + netlist_file, + init_placement, + ) + + plc = plc_client.PlacementCost( + netlist_file, + macro_macro_x_spacing=macro_macro_x_spacing, + macro_macro_y_spacing=macro_macro_y_spacing, + ) + + plc.make_soft_macros_square() + + blockages = blockages or get_blockages_from_comments([netlist_file, init_placement]) + if blockages: + for blockage in blockages: + plc.create_blockage(*blockage) + + canvas_width, canvas_height, grid_cols, grid_rows = extract_sizes_from_comments( + [netlist_file, init_placement] + ) + if canvas_width and canvas_height: + plc.set_canvas_size(canvas_width, canvas_height) + if grid_cols and grid_rows: + plc.set_placement_grid(grid_cols, grid_rows) + if legacy_congestion_grid: + plc.set_congestion_grid(grid_cols, grid_rows) + + plc.set_project_name("circuit_training") + plc.set_block_name(block_name or "unset_block") + plc.set_routes_per_micron(horizontal_routes_per_micron, vertical_routes_per_micron) + plc.set_macro_routing_allocation( + macro_horizontal_routing_allocation, macro_vertical_routing_allocation + ) + plc.set_congestion_smooth_range(congestion_smooth_range) + + if not legacy_congestion_grid: + congestion_grid_size = ( + 2.0 + * routes_per_congestion_grid + / (horizontal_routes_per_micron + vertical_routes_per_micron) + ) + canvas_width, canvas_height = plc.get_canvas_width_height() + congestion_grid_cols = max(1, int(canvas_width / congestion_grid_size)) + congestion_grid_rows = max(1, int(canvas_height / congestion_grid_size)) + plc.set_congestion_grid(congestion_grid_cols, congestion_grid_rows) + + plc.set_overlap_threshold(overlap_threshold) + plc.set_canvas_boundary_check(boundary_check) + if init_placement: + plc.restore_placement(init_placement) + fix_port_coordinates(plc) + + if fixed_macro_names_regex: + logging.info("Fixing macro locations using regex.") + fix_macros_by_regex(plc, fixed_macro_names_regex) + + return plc + + +def get_node_type_counts(plc: plc_client.PlacementCost) -> Dict[str, int]: + """Returns number of each type of nodes in the netlist.""" + counts = { + "MACRO": 0, + "STDCELL": 0, + "PORT": 0, + "MACRO_PIN": 0, + "SOFT_MACRO": 0, + "HARD_MACRO": 0, + "SOFT_MACRO_PIN": 0, + "HARD_MACRO_PIN": 0, + } + + for node_index in nodes_of_types(plc, ["MACRO", "STDCELL", "PORT", "MACRO_PIN"]): + node_type = plc.get_node_type(node_index) + counts[node_type] += 1 + if node_type == "MACRO": + if plc.is_node_soft_macro(node_index): + counts["SOFT_MACRO"] += 1 + else: + counts["HARD_MACRO"] += 1 + if node_type == "MACRO_PIN": + ref_id = plc.get_ref_node_id(node_index) + if plc.is_node_soft_macro(ref_id): + counts["SOFT_MACRO_PIN"] += 1 + else: + counts["HARD_MACRO_PIN"] += 1 + return counts + + +def make_blockage_text(plc: plc_client.PlacementCost) -> str: + ret = "" + for blockage in plc.get_blockages(): + ret += "Blockage : {}\n".format(" ".join([str(b) for b in blockage])) + return ret + + +def save_placement( + plc: plc_client.PlacementCost, filename: str, user_comments: str = "" +) -> None: + """Saves the placement file with some information in the comments section.""" + cols, rows = plc.get_grid_num_columns_rows() + congestion_cols, congestion_rows = plc.get_congestion_grid_num_columns_rows() + width, height = plc.get_canvas_width_height() + hor_routes, ver_routes = plc.get_routes_per_micron() + hor_macro_alloc, ver_macro_alloc = plc.get_macro_routing_allocation() + smooth = plc.get_congestion_smooth_range() + info = textwrap.dedent( + """\ + Placement file for Circuit Training + Source input file(s) : {src_filename} + This file : {filename} + Date : {date} + Columns : {cols} Rows : {rows} + Congestion Columns : {congestion_cols} Congestion Rows : {congestion_rows} + Width : {width:.3f} Height : {height:.3f} + Area : {area} + Wirelength : {wl:.3f} + Wirelength cost : {wlc:.4f} + Congestion cost : {cong:.4f} + Density cost : {density:.4f} + Project : {project} + Block : {block_name} + Routes per micron, hor : {hor_routes:.3f} ver : {ver_routes:.3f} + Routes used by macros, hor : {hor_macro_alloc:.3f} ver : {ver_macro_alloc:.3f} + Smoothing factor : {smooth} + Overlap threshold : {overlap_threshold} + """.format( + src_filename=plc.get_source_filename(), + filename=filename, + date=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + cols=cols, + rows=rows, + width=width, + congestion_cols=congestion_cols, + congestion_rows=congestion_rows, + height=height, + area=plc.get_area(), + wl=plc.get_wirelength(), + wlc=plc.get_cost(), + cong=plc.get_congestion_cost(), + density=plc.get_density_cost(), + project=plc.get_project_name(), + block_name=plc.get_block_name(), + hor_routes=hor_routes, + ver_routes=ver_routes, + hor_macro_alloc=hor_macro_alloc, + ver_macro_alloc=ver_macro_alloc, + smooth=smooth, + overlap_threshold=plc.get_overlap_threshold(), + ) + ) + + info += "\n" + make_blockage_text(plc) + "\n" + info += "\nCounts of node types:\n" + node_type_counts = get_node_type_counts(plc) + for node_type in sorted(node_type_counts): + info += "{:<15} : {:>9}\n".format(node_type + "s", node_type_counts[node_type]) + if user_comments: + info += "\nUser comments:\n" + user_comments + "\n" + info += "\nnode_index x y orientation fixed" + return plc.save_placement(filename, info) + + +def fd_placement_schedule( + plc: plc_client.PlacementCost, + num_steps: Tuple[int, ...] = (100, 100, 100), + io_factor: float = 1.0, + move_distance_factors: Tuple[float, ...] = (1.0, 1.0, 1.0), + attract_factor: Tuple[float, ...] = (100.0, 1.0e-3, 1.0e-5), + repel_factor: Tuple[float, ...] = (0.0, 1.0e6, 1.0e7), + use_current_loc: bool = False, + move_macros: bool = False, +) -> None: + """A placement schedule that uses force directed method.""" + assert len(num_steps) == len(move_distance_factors) + assert len(num_steps) == len(repel_factor) + assert len(num_steps) == len(attract_factor) + canvas_size = max(plc.get_canvas_width_height()) + max_move_distance = [ + f * canvas_size / s for s, f in zip(num_steps, move_distance_factors) + ] + move_stdcells = True + log_scale_conns = False + use_sizes = False + plc.optimize_stdcells( + use_current_loc, + move_stdcells, + move_macros, + log_scale_conns, + use_sizes, + io_factor, + num_steps, + max_move_distance, + attract_factor, + repel_factor, + ) + + +def read_node_order_file( + plc: plc_client.PlacementCost, node_order_file: str +) -> List[int]: + """Reads the node order from a file.""" + with open(node_order_file, "r") as f: + node_order = [plc.get_node_index(line.strip()) for line in f.readlines()] + return node_order + + +def save_node_order_file( + plc: plc_client.PlacementCost, + node_order: List[int], + node_order_file: str, +) -> None: + """Saves the node order to a file.""" + with open(node_order_file, "w") as f: + for node_index in node_order: + if not plc.is_node_soft_macro(node_index): + f.write(plc.get_node_name(node_index) + "\n") + + +def get_ordered_node_indices( + mode: str, + plc: plc_client.PlacementCost, + seed: int = 111, + node_order_file: str = "", + exclude_fixed_nodes: bool = True, +) -> List[int]: + """Returns an ordering of node indices according to the specified mode.""" + rng = np.random.default_rng(seed=seed) + macro_indices = plc.get_macro_indices() + hard_macro_indices = [m for m in macro_indices if not plc.is_node_soft_macro(m)] + soft_macro_indices = [m for m in macro_indices if plc.is_node_soft_macro(m)] + + def macro_area(idx): + if idx not in hard_macro_indices: + return 0.0 + w, h = plc.get_node_width_height(idx) + return w * h + + canvas_width, canvas_height = plc.get_canvas_width_height() + + def distance_to_edge(idx): + x, y = plc.get_node_location(idx) + return min( + x, y, canvas_width - x - canvas_width, canvas_height - y - canvas_height + ) + + logging.info("node_order: %s", mode) + if mode == "legalization_order": + ordered_indices = sorted( + hard_macro_indices, + key=distance_to_edge, + ) + sorted( + soft_macro_indices, + key=macro_area, + reverse=True, + ) + elif mode == "descending_size_macro_first": + ordered_indices = sorted( + hard_macro_indices, + key=macro_area, + reverse=True, + ) + sorted( + soft_macro_indices, + key=macro_area, + reverse=True, + ) + elif mode == "random": + rng.shuffle(macro_indices) + ordered_indices = macro_indices + elif mode == "random_macro_first": + rng.shuffle(hard_macro_indices) + logging.info("ordered hard macros: %s", hard_macro_indices) + ordered_indices = hard_macro_indices + soft_macro_indices + elif mode == "fake_net_topological": + fake_net_adj = {} + fake_nets = plc.get_fake_nets() + nodes = ( + set([nm[0] for _, nm in fake_nets]) + .union(set([nm[1] for _, nm in fake_nets])) + .union(set(hard_macro_indices)) + ) + is_port = {n: n not in hard_macro_indices for n in nodes} + macro_with_fake_net = {n: False for n in nodes} + for fake_net in fake_nets: + weight = fake_net[0] + if weight <= 0: + continue + node_0 = fake_net[1][0] + node_1 = fake_net[1][1] + fake_net_adj[(node_0, node_1)] = weight + fake_net_adj[(node_1, node_0)] = weight + if node_0 in hard_macro_indices: + macro_with_fake_net[node_0] = True + if node_1 in hard_macro_indices: + macro_with_fake_net[node_1] = True + + closeness = {n: 0.0 for n in nodes} + + source = max(nodes, key=lambda n: (is_port[n], macro_area(n))) + visited_nodes = [source] + last_node = source + del closeness[last_node] + + while len(visited_nodes) < len(nodes): + for node in nodes: + if node in visited_nodes: + continue + if (node, last_node) in fake_net_adj: + closeness[node] += fake_net_adj[(node, last_node)] + + last_node = max( + closeness, + key=lambda n: (closeness[n], macro_with_fake_net[n], macro_area(n)), + ) + visited_nodes.append(last_node) + del closeness[last_node] + + ordered_indices = [n for n in visited_nodes if n in hard_macro_indices] + sorted( + soft_macro_indices, key=macro_area + )[::-1] + elif mode == "file": + ordered_indices = read_node_order_file(plc, node_order_file) + else: + raise ValueError("{} is an unsupported node placement mode.".format(mode)) + + if exclude_fixed_nodes: + ordered_indices = [m for m in ordered_indices if not plc.is_node_fixed(m)] + return ordered_indices + + +def extract_blockages_from_file( + filename: str, canvas_width: float, canvas_height: float +) -> Optional[List[List[float]]]: + """Reads blockage information from a given file.""" + blockages = [] + try: + with open(filename, "r") as infile: + for line in infile: + if line.startswith("#"): + continue + items = line.split() + if len(items) != 4: + raise ValueError( + "Blockage file does not meet expected format" + "Expected format " + ) + llx = float(items[0]) + lly = float(items[1]) + urx = float(items[2]) + ury = float(items[3]) + if llx >= urx: + raise ValueError(f"Illegal blockage llx {llx} >= urx {urx}") + if lly >= ury: + raise ValueError(f"Illegal blockage lly {lly} >= ury {ury}") + if llx < 0: + raise ValueError(f"Illegal blockage llx {llx} < 0") + if urx > canvas_width: + raise ValueError( + f"Illegal blockage urx {urx} > canvas width {canvas_width}" + ) + if lly < 0: + raise ValueError(f"Illegal blockage lly {lly} < 0") + if ury > canvas_height: + raise ValueError( + f"Illegal blockage ury {ury} > canvas height {canvas_height}" + ) + blockages.append([llx, lly, urx, ury, 0.99]) + except IOError: + logging.error("Could not read file %s", filename) + return blockages + + +def get_node_locations(plc: plc_client.PlacementCost) -> Dict[int, int]: + """Returns all node grid locations (macros and stdcells) in a dict.""" + node_locations = dict() + for i in nodes_of_types(plc, ["MACRO", "STDCELL"]): + node_locations[i] = plc.get_grid_cell_of_node(i) + return node_locations + + +def get_node_ordering_by_size(plc: plc_client.PlacementCost) -> List[int]: + """Returns the list of nodes (macros and stdcells) ordered by area.""" + node_areas = dict() + for i in nodes_of_types(plc, ["MACRO", "STDCELL"]): + if plc.is_node_fixed(i): + continue + w, h = plc.get_node_width_height(i) + node_areas[i] = w * h + return sorted(node_areas, key=node_areas.get, reverse=True) + + +def grid_locations_near( + plc: plc_client.PlacementCost, start_grid_index: int +) -> Iterator[int]: + """Yields node indices closest to the start_grid_index.""" + cols, rows = plc.get_grid_num_columns_rows() + start_col, start_row = start_grid_index % cols, int(start_grid_index / cols) + for distance in range(cols + rows): + for row_offset in range(-distance, distance + 1): + for col_offset in range(-distance, distance + 1): + if abs(row_offset) + abs(col_offset) != distance: + continue + new_col = start_col + col_offset + new_row = start_row + row_offset + if new_col < 0 or new_row < 0 or new_col >= cols or new_row >= rows: + continue + yield int(new_col + new_row * cols) + + +def place_near(plc: plc_client.PlacementCost, node_index: int, location: int) -> bool: + """Places a node (legally) closest to the given location.""" + for loc in grid_locations_near(plc, location): + if plc.can_place_node(node_index, loc): + plc.place_node(node_index, loc) + return True + return False + + +def disconnect_high_fanout_nets( + plc: plc_client.PlacementCost, max_allowed_fanouts: int = 500 +) -> None: + high_fanout_nets = [] + for i in nodes_of_types(plc, ["PORT", "STDCELL", "MACRO_PIN"]): + num_fanouts = len(plc.get_fan_outs_of_node(i)) + if num_fanouts > max_allowed_fanouts: + print( + "Disconnecting node: {} with {} fanouts.".format( + plc.get_node_name(i), num_fanouts + ) + ) + high_fanout_nets.append(i) + plc.disconnect_nets(high_fanout_nets) + + +def legalize_placement(plc: plc_client.PlacementCost) -> bool: + """Places the nodes to legal positions snapping to grid cells.""" + fix_port_coordinates(plc) + node_locations = get_node_locations(plc) + previous_xy_coords = get_node_xy_coordinates(plc) + total_macro_displacement = 0 + total_macros = 0 + plc.unplace_all_nodes() + ordered_nodes = get_node_ordering_by_size(plc) + for node in ordered_nodes: + if not place_near(plc, node, node_locations[node]): + print("Could not place node") + return False + if node in previous_xy_coords and not plc.is_node_soft_macro(node): + x, y = plc.get_node_location(node) + px, py = previous_xy_coords[node] + print( + "x/y displacement: dx = {}, dy = {}, macro: {}".format( + x - px, y - py, plc.get_node_name(node) + ) + ) + total_macro_displacement += abs(x - px) + abs(y - py) + total_macros += 1 + print( + "Total macro displacement: {}, avg: {}".format( + total_macro_displacement, total_macro_displacement / total_macros + ) + ) + return True + + +def fix_macros_by_regex(plc: plc_client.PlacementCost, macro_regex_str_list: List[str]): + """Fix macro locations given a list of macro name regex strings.""" + regexs = [] + for regex_str in macro_regex_str_list: + regexs.append(re.compile(regex_str)) + + hard_macros = [] + for m in plc.get_macro_indices(): + if plc.is_node_soft_macro(m): + continue + hard_macros.append(m) + + total = 0 + for m in plc.get_macro_indices(): + if plc.is_node_soft_macro(m): + continue + macro_name = plc.get_node_name(m) + for regex in regexs: + if regex.fullmatch(macro_name): + plc.fix_node_coord(m) + total += 1 + logging.info("Fixed macro: %s", macro_name) + continue + logging.info("Total number of fixed macros: %d", total) + + +def create_blockages_by_spacing_constraints( + canvas_width: float, + canvas_height: float, + macro_boundary_x_spacing: float = 0, + macro_boundary_y_spacing: float = 0, + rectilinear_blockages: Optional[List[List[float]]] = None, +) -> List[List[float]]: + """Create blockages using macro-to-boundary spacing constraints.""" + blockages = [] + blockage_rate = 0.1 + if macro_boundary_x_spacing: + assert 0 < macro_boundary_x_spacing <= canvas_width + blockages.append([0, 0, macro_boundary_x_spacing, canvas_height, blockage_rate]) + blockages.append( + [ + canvas_width - macro_boundary_x_spacing, + 0, + canvas_width, + canvas_height, + blockage_rate, + ] + ) + if macro_boundary_y_spacing: + assert 0 < macro_boundary_y_spacing <= canvas_height + blockages.append([0, 0, canvas_width, macro_boundary_y_spacing, blockage_rate]) + blockages.append( + [ + 0, + canvas_height - macro_boundary_y_spacing, + canvas_width, + canvas_height, + blockage_rate, + ] + ) + for rectilinear_blockage in rectilinear_blockages or []: + minx, miny, maxx, maxy, _ = rectilinear_blockage + if macro_boundary_x_spacing: + blockages.append( + [ + max(minx - macro_boundary_x_spacing, 0), + max(miny - macro_boundary_y_spacing, 0), + minx, + min(maxy + macro_boundary_y_spacing, canvas_height), + blockage_rate, + ] + ) + blockages.append( + [ + maxx, + max(miny - macro_boundary_y_spacing, 0), + min(maxx + macro_boundary_x_spacing, canvas_width), + min(maxy + macro_boundary_y_spacing, canvas_height), + blockage_rate, + ] + ) + if macro_boundary_y_spacing: + blockages.append( + [ + max(minx - macro_boundary_x_spacing, 0), + max(miny - macro_boundary_y_spacing, 0), + min(maxx + macro_boundary_x_spacing, canvas_width), + miny, + blockage_rate, + ] + ) + blockages.append( + [ + max(minx - macro_boundary_x_spacing, 0), + maxy, + min(maxx + macro_boundary_x_spacing, canvas_width), + min(maxy + macro_boundary_y_spacing, canvas_height), + blockage_rate, + ] + ) + return blockages diff --git a/src/gfn/states.py b/src/gfn/states.py index a8bcc178..ffb9cf38 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -457,13 +457,11 @@ def __init__( forward_masks = torch.ones( (*self.batch_shape, self.__class__.n_actions), dtype=torch.bool, - device=self.__class__.device, ) if backward_masks is None: backward_masks = torch.ones( (*self.batch_shape, self.__class__.n_actions - 1), dtype=torch.bool, - device=self.__class__.device, ) self.forward_masks: torch.Tensor = forward_masks diff --git a/testing/test_environments.py b/testing/test_environments.py index 6df9d1fd..91946ab2 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -7,7 +7,8 @@ from gfn.actions import GraphActions, GraphActionType from gfn.env import NonValidActionsError -from gfn.gym import Box, DiscreteEBM, HyperGrid +from gfn.gym import Box, ChipDesign, DiscreteEBM, HyperGrid +from gfn.gym.chip_design import ChipDesignStates from gfn.gym.graph_building import GraphBuilding from gfn.gym.perfect_tree import PerfectBinaryTree from gfn.gym.set_addition import SetAddition @@ -779,3 +780,30 @@ def test_perfect_binary_tree_bwd_step(): expected_states = torch.tensor([[0], [0]], dtype=torch.long) assert torch.equal(states.tensor, expected_states) assert torch.all(states.is_initial_state) + + +def test_chip_design(): + BATCH_SIZE = 2 + + env = ChipDesign() + states = env.reset(batch_shape=BATCH_SIZE) + assert states.tensor.shape == (BATCH_SIZE, env.n_macros) + assert torch.all(states.tensor == -1) + + # Place macros + for i in range(env.n_macros): + actions = env.actions_from_tensor(format_tensor([i] * BATCH_SIZE)) + expected_tensor = states.tensor.clone() + states = env._step(states, actions) + expected_tensor[..., i] = i + assert torch.equal(states.tensor, expected_tensor) + + # Exit action (valid) + actions = env.actions_from_tensor(format_tensor([env.n_actions - 1] * BATCH_SIZE)) + final_states = env._step(states, actions) + assert torch.all(final_states.is_sink_state) + + # Check rewards + assert isinstance(final_states, ChipDesignStates) + rewards = env.log_reward(final_states) + assert torch.all(rewards == rewards[0]) diff --git a/tutorials/examples/test_scripts.py b/tutorials/examples/test_scripts.py index a77a48b6..4979bfad 100644 --- a/tutorials/examples/test_scripts.py +++ b/tutorials/examples/test_scripts.py @@ -24,6 +24,7 @@ ) from tutorials.examples.train_bit_sequences import main as train_bitsequence_main from tutorials.examples.train_box import main as train_box_main +from tutorials.examples.train_chip_design import main as train_chip_design_main from tutorials.examples.train_conditional import main as train_conditional_main from tutorials.examples.train_discreteebm import main as train_discreteebm_main from tutorials.examples.train_graph_ring import main as train_graph_ring_main @@ -258,6 +259,16 @@ class ConditionalArgs(CommonArgs): no_cuda: bool = True # Disable CUDA for tests +@dataclass +class ChipDesignArgs(CommonArgs): + n_iterations: int = 10 + embedding_dim: int = 32 + batch_size: int = 16 + seed: int = 4444 + lr: float = 1e-3 + no_cuda: bool = True # Disable CUDA for tests + + @pytest.mark.parametrize("ndim", [2, 4]) @pytest.mark.parametrize("height", [8, 16]) @pytest.mark.parametrize("replay_buffer_size", [0, 1000]) @@ -741,5 +752,9 @@ def test_hypergrid_exploration_smoke(): train_hypergrid_exploration_main(namespace_args) # Runs without errors. -if __name__ == "__main__": - test_graph_triangle_smoke() +def test_chip_design_smoke(): + """Smoke test for the chip design training script.""" + args = ChipDesignArgs() + args_dict = asdict(args) + namespace_args = Namespace(**args_dict) + train_chip_design_main(namespace_args) # Runs without errors. diff --git a/tutorials/examples/train_chip_design.py b/tutorials/examples/train_chip_design.py new file mode 100644 index 00000000..14988403 --- /dev/null +++ b/tutorials/examples/train_chip_design.py @@ -0,0 +1,106 @@ +import argparse + +import torch +import torch.nn as nn +from tqdm import tqdm + +from gfn.estimators import DiscretePolicyEstimator +from gfn.gflownet import TBGFlowNet +from gfn.gym.chip_design import ChipDesign, ChipDesignStates +from gfn.preprocessors import Preprocessor +from gfn.utils.modules import MLP + + +class ChipDesignPreprocessor(Preprocessor): + def __init__(self, env, embedding_dim=64): + super().__init__(output_dim=env.n_macros * embedding_dim) + self.embedding = nn.Embedding( + env.n_grid_cells + 2, embedding_dim + ) # +2 for -1 and -2 + self.n_macros = env.n_macros + self.embedding_dim = embedding_dim + + def preprocess(self, states): + # states.tensor is (batch_size, n_macros) with values from -2 to n_grid_cells-1 + # We add 2 to make them non-negative for embedding. + preprocessed_states = states.tensor + 2 + embedded = self.embedding(preprocessed_states) + # embedded shape: (batch_size, n_macros, embedding_dim) + # flatten it + return embedded.view(-1, self.n_macros * self.embedding_dim) + + +def main(args): + device = torch.device( + "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" + ) + env = ChipDesign(device=str(device)) + + preprocessor = ChipDesignPreprocessor(env, embedding_dim=args.embedding_dim) + + output_dim = preprocessor.output_dim + assert output_dim is not None + module_pf = MLP( + input_dim=output_dim, + output_dim=env.n_actions, + hidden_dim=args.hidden_dim, + n_hidden_layers=args.n_hidden, + ) + module_pb = MLP( + input_dim=output_dim, + output_dim=env.n_actions - 1, + hidden_dim=args.hidden_dim, + n_hidden_layers=args.n_hidden, + trunk=module_pf.trunk, + ) + + pf_estimator = DiscretePolicyEstimator( + module_pf, env.n_actions, preprocessor=preprocessor + ) + pb_estimator = DiscretePolicyEstimator( + module_pb, env.n_actions, preprocessor=preprocessor, is_backward=True + ) + + gflownet = TBGFlowNet(pf=pf_estimator, pb=pb_estimator, init_logZ=0.0).to(device) + optimizer = torch.optim.Adam(gflownet.parameters(), lr=args.lr) + + print("Sampling initial states...") + # Sample some final states and print them + final_states = gflownet.sample_terminating_states(env, n=5) + final_rewards = torch.exp(env.log_reward(final_states)) + print("Sampled final placements (macro locations):") + for i in range(len(final_states)): + print(final_states.tensor[i], " with reward ", final_rewards[i].item()) + + for i in tqdm(range(args.n_iterations)): + trajectories = gflownet.sample_trajectories(env, n=args.batch_size) + training_samples = gflownet.to_training_samples(trajectories) + optimizer.zero_grad() + loss = gflownet.loss(env, training_samples) + loss.backward() + optimizer.step() + + if (i + 1) % 100 == 0: + print(f"Iteration {i+1}, Loss: {loss.item()}") + + print("Training finished.") + # Sample some final states and print them + final_states = gflownet.sample_terminating_states(env, n=5) + assert isinstance(final_states, ChipDesignStates) + final_rewards = torch.exp(env.log_reward(final_states)) + print("Sampled final placements (macro locations):") + for i in range(len(final_states)): + print(final_states.tensor[i], " with reward ", final_rewards[i].item()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--no_cuda", action="store_true") + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument("--n_iterations", type=int, default=1000) + parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument("--embedding_dim", type=int, default=32) + parser.add_argument("--hidden_dim", type=int, default=64) + parser.add_argument("--n_hidden", type=int, default=2) + args = parser.parse_args() + main(args)