2 changes: 2 additions & 0 deletions src/gfn/gym/__init__.py
@@ -2,6 +2,7 @@

from .bitSequence import BitSequence, BitSequencePlus
from .box import Box
from .chip_design import ChipDesign
from .discrete_ebm import DiscreteEBM
from .graph_building import GraphBuilding, GraphBuildingOnEdges
from .hypergrid import HyperGrid
@@ -20,4 +21,5 @@
"GraphBuildingOnEdges",
"PerfectBinaryTree",
"SetAddition",
"ChipDesign",
]
253 changes: 253 additions & 0 deletions src/gfn/gym/chip_design.py
@@ -0,0 +1,253 @@
"""GFlowNet environment for chip placement."""

from typing import ClassVar, Optional, Sequence, cast

import torch

from gfn.actions import Actions
from gfn.env import DiscreteEnv
from gfn.gym.helpers.chip_design import SAMPLE_INIT_PLACEMENT, SAMPLE_NETLIST_FILE
from gfn.gym.helpers.chip_design import utils as placement_util
from gfn.gym.helpers.chip_design.utils import cost_info_function
from gfn.states import DiscreteStates


class ChipDesignStates(DiscreteStates):
"""A class to represent the states of the chip design environment."""

state_shape: ClassVar[tuple[int, ...]]
s0: ClassVar[torch.Tensor]
sf: ClassVar[torch.Tensor]
n_actions: ClassVar[int]

def __init__(
self,
tensor: torch.Tensor,
forward_masks: Optional[torch.Tensor] = None,
backward_masks: Optional[torch.Tensor] = None,
current_node_idx: Optional[torch.Tensor] = None,
):
super().__init__(
tensor=tensor, forward_masks=forward_masks, backward_masks=backward_masks
)
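        # current_node_idx tracks, per state, the index of the next macro to
        # place. If not supplied, derive it as the position of the first -1
        # entry; the trailing all-ones padding column makes fully placed
        # states map to n_macros instead of 0.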
if current_node_idx is None:
is_unplaced = tensor == -1
is_unplaced_padded = torch.cat(
[
is_unplaced,
torch.ones_like(is_unplaced[..., :1]),
],
dim=-1,
)
current_node_idx = is_unplaced_padded.long().argmax(dim=-1)

self.current_node_idx = current_node_idx

def clone(self) -> "ChipDesignStates":
"""Creates a copy of the states."""
return self.__class__(
self.tensor.clone(),
current_node_idx=self.current_node_idx.clone(),
forward_masks=self.forward_masks.clone(),
backward_masks=self.backward_masks.clone(),
)

def __getitem__(self, index) -> "ChipDesignStates":
"""Gets a subset of the states."""
return self.__class__(
self.tensor[index],
current_node_idx=self.current_node_idx[index],
forward_masks=self.forward_masks[index],
backward_masks=self.backward_masks[index],
)

def __setitem__(self, index, value: "ChipDesignStates") -> None:
"""Sets a subset of the states."""
super().__setitem__(index, value)
self.current_node_idx[index] = value.current_node_idx

def extend(self, other: "ChipDesignStates") -> None:
"""Extends the states with another states."""
super().extend(other)
self.current_node_idx = torch.cat(
(self.current_node_idx, other.current_node_idx),
dim=len(self.batch_shape) - 1,
)

@classmethod
def stack(cls, states: Sequence["ChipDesignStates"]) -> "ChipDesignStates":
"""Stacks the states with another states."""
stacked = super().stack(states)
stacked.current_node_idx = torch.stack(
[s.current_node_idx for s in states],
dim=0,
)
return cast(ChipDesignStates, stacked)


class ChipDesign(DiscreteEnv):
"""
GFlowNet environment for chip placement.

The state is a vector of length `n_macros`, where `state[i]` is the grid
cell location of the i-th macro to be placed. Unplaced macros have a
location of -1.

Actions are integers from `0` to `n_grid_cells - 1`, representing the
grid cell to place the current macro on. Action `n_grid_cells` is the
exit action.
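
    For example, with `n_macros = 3` on a 2x2 grid (`n_grid_cells = 4`), one
    complete trajectory is [-1, -1, -1] -> [2, -1, -1] -> [2, 0, -1] ->
    [2, 0, 3]: the macros land on cells 2, 0, and 3, after which only the
    exit action (4) is valid.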
"""

def __init__(
self,
netlist_file: str = SAMPLE_NETLIST_FILE,
init_placement: str = SAMPLE_INIT_PLACEMENT,
std_cell_placer_mode: str = "fd",
wirelength_weight: float = 1.0,
density_weight: float = 1.0,
congestion_weight: float = 0.5,
device: str | None = None,
check_action_validity: bool = True,
):
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"

self.plc = placement_util.create_placement_cost(
netlist_file=netlist_file, init_placement=init_placement
)
self.std_cell_placer_mode = std_cell_placer_mode

self.wirelength_weight = wirelength_weight
self.density_weight = density_weight
self.congestion_weight = congestion_weight

self._grid_cols, self._grid_rows = self.plc.get_grid_num_columns_rows()
self.n_grid_cells = self._grid_cols * self._grid_rows

self._sorted_node_indices = placement_util.get_ordered_node_indices(
mode="descending_size_macro_first", plc=self.plc
)
self._hard_macro_indices = [
m for m in self._sorted_node_indices if not self.plc.is_node_soft_macro(m)
]
self.n_macros = len(self._hard_macro_indices)

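        # s0 marks every macro as unplaced (-1); sf uses a distinct sentinel
        # (-2) for the sink state. There is one action per grid cell, plus a
        # final exit action.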
s0 = torch.full((self.n_macros,), -1, dtype=torch.long, device=device)
sf = torch.full((self.n_macros,), -2, dtype=torch.long, device=device)
n_actions = self.n_grid_cells + 1

super().__init__(
n_actions=n_actions,
s0=s0,
state_shape=(self.n_macros,),
sf=sf,
check_action_validity=check_action_validity,
)
self.States: type[ChipDesignStates] = self.make_states_class()

def make_states_class(self) -> type[ChipDesignStates]:
"""Creates the ChipDesignStates class."""
env = self

class BaseChipDesignStates(ChipDesignStates):
state_shape = env.state_shape
s0 = env.s0
sf = env.sf
n_actions = env.n_actions
device = env.device

return BaseChipDesignStates

def _apply_state_to_plc(self, state_tensor: torch.Tensor):
"""Applies a single state tensor to the plc object."""
assert state_tensor.shape == (self.n_macros,)

self.plc.unplace_all_nodes()
for i in range(self.n_macros):
loc = state_tensor[i].item()
if loc != -1:
node_index = self._hard_macro_indices[i]
self.plc.place_node(node_index, loc)

def update_masks(self, states: ChipDesignStates) -> None:
"""Updates the forward and backward masks of the states."""
states.forward_masks.zero_()
states.backward_masks.zero_()

for i in range(len(states)):
state_tensor = states.tensor[i]
current_node_idx = states.current_node_idx[i].item()

if current_node_idx >= self.n_macros: # All macros placed
states.forward_masks[i, -1] = True # Only exit is possible
else:
# Apply partial placement to plc to get mask for next node
self._apply_state_to_plc(state_tensor)
node_to_place = self._hard_macro_indices[int(current_node_idx)]
mask = self.plc.get_node_mask(node_to_place)
mask = torch.tensor(mask, dtype=torch.bool, device=self.device)
states.forward_masks[i, : self.n_grid_cells] = mask
states.forward_masks[i, -1] = False # No exit

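            # Backward: the only valid undo is removing the most recently
            # placed macro, so only its grid cell is unmasked.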
if current_node_idx > 0:
last_placed_loc = state_tensor[int(current_node_idx - 1)].item()
assert last_placed_loc != -1, "Last placed location should not be -1"
states.backward_masks[i, int(last_placed_loc)] = True

def step(self, states: ChipDesignStates, actions: Actions) -> ChipDesignStates:
"""Performs a forward step in the environment."""
new_tensor = states.tensor.clone()

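        # Write each non-exiting action's grid cell into the slot of the
        # macro currently being placed.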
non_exit_mask = ~actions.is_exit
if torch.any(non_exit_mask):
rows = torch.arange(len(states), device=self.device)[non_exit_mask]
cols = states.current_node_idx[non_exit_mask]
new_tensor[rows, cols] = actions.tensor[non_exit_mask].squeeze(-1)

if torch.any(actions.is_exit):
new_tensor[actions.is_exit] = self.sf

new_current_node_idx = states.current_node_idx.clone()
new_current_node_idx[non_exit_mask] += 1

return self.States(tensor=new_tensor, current_node_idx=new_current_node_idx)

def backward_step(
self, states: ChipDesignStates, actions: Actions
) -> ChipDesignStates:
"""Performs a backward step in the environment."""
new_tensor = states.tensor.clone()
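        # Unplace the most recently placed macro; its slot is
        # current_node_idx - 1.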
rows = torch.arange(len(states), device=self.device)
cols = states.current_node_idx - 1
new_tensor[rows, cols] = -1

new_current_node_idx = states.current_node_idx - 1
return self.States(tensor=new_tensor, current_node_idx=new_current_node_idx)

def analytical_placer(self):
"""Places standard cells using an analytical placer."""
if self.std_cell_placer_mode == "fd":
placement_util.fd_placement_schedule(self.plc)
else:
raise ValueError(
f"{self.std_cell_placer_mode} is not a supported std_cell_placer_mode."
)

def log_reward(self, final_states: ChipDesignStates) -> torch.Tensor:
"""Computes the log reward of the final states."""
rewards = torch.zeros(len(final_states), device=self.device)
for i in range(len(final_states)):
state_tensor = final_states.tensor[i]
self._apply_state_to_plc(state_tensor)

self.analytical_placer()

cost, _ = cost_info_function(
plc=self.plc,
done=True,
wirelength_weight=self.wirelength_weight,
density_weight=self.density_weight,
congestion_weight=self.congestion_weight,
)
rewards[i] = -cost
return rewards
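
A minimal usage sketch of the new environment (illustrative only; it assumes
the bundled sample netlist parses on the target machine):

from gfn.gym import ChipDesign

env = ChipDesign()  # defaults to SAMPLE_NETLIST_FILE / SAMPLE_INIT_PLACEMENT
states = env.reset(batch_shape=(1,))  # s0: every macro unplaced (-1)
print(states.tensor.shape)            # (1, n_macros)
print(states.forward_masks[0, :-1])   # valid grid cells for the first macro
# A trajectory sampler (e.g. gfn.samplers.Sampler with a discrete policy
# estimator) then picks one grid cell per step; once every macro is placed,
# only the exit action remains, and env.log_reward(final_states) returns the
# negative weighted cost of the finished layout.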
24 changes: 24 additions & 0 deletions src/gfn/gym/helpers/chip_design/__init__.py
@@ -0,0 +1,24 @@
# Code copied from https://github.com/google-research/circuit_training
# @article{mirhoseini2021graph,
# title={A graph placement methodology for fast chip design},
# author={Mirhoseini*, Azalia and Goldie*, Anna and Yazgan, Mustafa and Jiang, Joe
# Wenjie and Songhori, Ebrahim and Wang, Shen and Lee, Young-Joon and Johnson,
# Eric and Pathak, Omkar and Nazi, Azade and Pak, Jiwoo and Tong, Andy and
# Srinivasa, Kavya and Hang, William and Tuncer, Emre and V. Le, Quoc and
# Laudon, James and Ho, Richard and Carpenter, Roger and Dean, Jeff},
# journal={Nature},
# volume={594},
# number={7862},
# pages={207--212},
# year={2021},
# publisher={Nature Publishing Group}
# }

import os

SAMPLE_NETLIST_FILE = os.path.join(
os.path.dirname(__file__), "test_data", "netlist.pb.txt"
)
SAMPLE_INIT_PLACEMENT = os.path.join(
os.path.dirname(__file__), "test_data", "initial.plc"
)