|
1 | 1 | import logging |
2 | 2 | from typing import Collection, Dict, List, Optional, Tuple |
3 | 3 |
|
| 4 | +import psutil |
4 | 5 | import torch |
5 | 6 | import torch.fx.passes.operator_support as ops |
6 | 7 | from torch.fx.node import Target |
@@ -225,13 +226,80 @@ def partition_graph(self) -> torch.fx.GraphModule: |
225 | 226 | # Remove segments smaller than the block size (with exceptions) |
226 | 227 | subgraphs = self.remove_small_acc_subgraphs(subgraphs) |
227 | 228 |
|
| 229 | + num_of_break = self.calculate_num_of_break(subgraphs) |
| 230 | + subgraphs = self.break_subgraphs(subgraphs, num_of_break=num_of_break) |
| 231 | + |
228 | 232 | # Set the number of TRT engines to be generated |
229 | 233 | self.num_trt_accelerated_subgraphs = len([s for s in subgraphs if s.is_acc]) |
230 | 234 |
|
231 | 235 | # Tag the accelerated nodes and split the graph accordingly |
232 | 236 | self.tag(subgraphs) |
233 | 237 | return self.split() |
234 | 238 |
|
| 239 | + def calculate_num_of_break(self, subgraphs: List[Subgraph]) -> int: |
| 240 | + """ |
| 241 | + This function calculates the break period based on the number of subgraphs. |
| 242 | + """ |
| 243 | + rss = psutil.Process().memory_info().rss |
| 244 | + available_rss = psutil.virtual_memory().available |
| 245 | + num_of_graphs = len(subgraphs) |
| 246 | + if rss < available_rss * 0.3: |
| 247 | + num_of_graphs = 1 |
| 248 | + elif rss < available_rss * 0.5: |
| 249 | + num_of_graphs = 2 |
| 250 | + elif rss < available_rss: |
| 251 | + num_of_graphs = 4 |
| 252 | + elif rss < available_rss * 1.5: |
| 253 | + num_of_graphs = 8 |
| 254 | + elif rss < available_rss * 2: |
| 255 | + num_of_graphs = 16 |
| 256 | + else: |
| 257 | + num_of_graphs = 32 |
| 258 | + |
| 259 | + return max( |
| 260 | + 1, num_of_graphs // ((len(subgraphs) + 1) // 2) |
| 261 | + ) # If there are already graph breaks, for each TRT subgraph, we break for a few times. |
| 262 | + |
| 263 | + def break_subgraphs( |
| 264 | + self, subgraphs: List[Subgraph], num_of_break: int = 1 |
| 265 | + ) -> List[Subgraph]: |
| 266 | + """ |
| 267 | + This function breaks the subgraphs into smaller subgraphs at the specified frequency to save CPU memory. |
| 268 | + """ |
| 269 | + |
| 270 | + num_of_sdpa_node = len( |
| 271 | + [node for node in self.acc_nodes if "scaled_dot" in str(node.target)] |
| 272 | + ) |
| 273 | + break_period = num_of_sdpa_node // num_of_break + 1 |
| 274 | + current_break_idx = 0 |
| 275 | + current_num_break = 0 |
| 276 | + new_subgraphs = [] |
| 277 | + for subgraph in subgraphs: |
| 278 | + if subgraph.is_acc: |
| 279 | + for i, node in enumerate(subgraph.nodes): |
| 280 | + if "scaled_dot" in str(node.target): |
| 281 | + current_num_break += 1 |
| 282 | + if current_num_break % break_period != 0: |
| 283 | + continue |
| 284 | + new_subgraphs.append( |
| 285 | + Subgraph( |
| 286 | + is_acc=True, |
| 287 | + nodes=subgraph.nodes[current_break_idx : i + 1], |
| 288 | + device_ordinal=subgraph.device_ordinal, |
| 289 | + ) |
| 290 | + ) |
| 291 | + current_break_idx = i + 1 |
| 292 | + new_subgraphs.append( |
| 293 | + Subgraph( |
| 294 | + is_acc=True, |
| 295 | + nodes=subgraph.nodes[current_break_idx:], |
| 296 | + device_ordinal=subgraph.device_ordinal, |
| 297 | + ) |
| 298 | + ) |
| 299 | + else: |
| 300 | + new_subgraphs.append(subgraph) |
| 301 | + return new_subgraphs |
| 302 | + |
235 | 303 | def starter_nodes(self) -> Tuple[NodeSet, NodeSet]: |
236 | 304 | """Generates starter nodes for partitioning + segmentation""" |
237 | 305 | # Starter accelerated nodes are all callable accelerated ops |
|
0 commit comments