-
Notifications
You must be signed in to change notification settings - Fork 776
Move torch.cond predicate non-persistent buffer to CPU #16378
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
larryliu0820
wants to merge
27
commits into
main
Choose a base branch
from
gh/larryliu0820/86/head
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
27 commits
Select commit
Hold shift + click to select a range
63a2766
Update
larryliu0820 f02dbe1
Update
larryliu0820 9a7aa91
Update
larryliu0820 bc07a7b
Update
larryliu0820 a97933b
Update
larryliu0820 99ca698
Update
larryliu0820 e1bb6c2
Update
larryliu0820 395ab4f
Update
larryliu0820 2a7a9f0
Update
larryliu0820 a86ab6e
Update
larryliu0820 ca3ac6d
Update
larryliu0820 8b94087
Update
larryliu0820 5f755f9
Update
larryliu0820 690546b
Update
larryliu0820 73efe12
Update
larryliu0820 d96dec8
Update
larryliu0820 eb6a7e6
Update
larryliu0820 d5c53ec
Update
larryliu0820 8b8580d
Update
larryliu0820 ba6fdff
Update
larryliu0820 b103b7f
Update
larryliu0820 a8b20f5
Update
larryliu0820 016adb3
Update
larryliu0820 3fc3117
Update
larryliu0820 e8349a7
Update
larryliu0820 5897ba4
Update
larryliu0820 ab861b9
Update
larryliu0820 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,90 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import torch | ||
| from torch.export import ExportedProgram | ||
|
|
||
|
|
||
| class MoveCondPredicateToCpuPass: | ||
| """ | ||
| A pass that moves the predicate of torch.cond to CPU if the predicate is a constantbuffer. | ||
| This is useful for models that use the predicate as a constant buffer, such as an `initialized` flag for cross attention kv cache. | ||
|
|
||
| This saves ~50us per torch.cond call on RTX 5080. | ||
|
|
||
| Example: | ||
| ``` | ||
| class CrossAttentionWithCache(torch.nn.Module): | ||
| def __init__(self, hidden_size): | ||
| super().__init__() | ||
| self.k_proj = torch.nn.Linear(hidden_size, hidden_size) | ||
| self.v_proj = torch.nn.Linear(hidden_size, hidden_size) | ||
| self.q_proj = torch.nn.Linear(hidden_size, hidden_size) | ||
| self.out_proj = torch.nn.Linear(hidden_size, hidden_size) | ||
| # Buffer used as predicate for torch.cond | ||
| self.register_buffer("initialized", torch.tensor([False]), persistent=False) | ||
| self.register_buffer("k_cache", torch.zeros(1, 10, hidden_size), persistent=False) | ||
| self.register_buffer("v_cache", torch.zeros(1, 10, hidden_size), persistent=False) | ||
|
|
||
| def compute_kv(self, encoder_hidden_states): | ||
| k = self.k_proj(encoder_hidden_states) | ||
| v = self.v_proj(encoder_hidden_states) | ||
| self.k_cache.copy_(k) | ||
| self.v_cache.copy_(v) | ||
| self.initialized.fill_(True) | ||
| return k, v | ||
|
|
||
| def use_cached_kv(self, encoder_hidden_states): | ||
| return self.k_cache.clone(), self.v_cache.clone() | ||
|
|
||
| def forward(self, hidden_states, encoder_hidden_states): | ||
| q = self.q_proj(hidden_states) | ||
| # Use torch.cond with initialized buffer as predicate | ||
| k, v = torch.cond( | ||
| self.initialized, | ||
| self.use_cached_kv, | ||
| self.compute_kv, | ||
| (encoder_hidden_states,), | ||
| ) | ||
| attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v) | ||
| return self.out_proj(attn_output) | ||
| ``` | ||
| In this example if we keep `self.initialized` on GPU, we will need to copy it to CPU for every forward pass. | ||
| We move the predicate to CPU to avoid device to host copies. | ||
| This pass is only applicable to models that use torch.cond and its predicate is a constant buffer. | ||
| """ | ||
|
|
||
| requires_exported_program = True | ||
|
|
||
| def __call__(self, exported_program: ExportedProgram): | ||
larryliu0820 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| graph_module = exported_program.graph_module | ||
|
|
||
| # Map input names to buffer names | ||
| inputs_to_buffers = exported_program.graph_signature.inputs_to_buffers | ||
|
|
||
| for node in graph_module.graph.nodes: | ||
| if ( | ||
| node.op == "call_function" | ||
| and node.target == torch.ops.higher_order.cond | ||
| ): | ||
| pred_node = node.args[0] | ||
| if ( | ||
| pred_node.op == "placeholder" | ||
| and pred_node.name in inputs_to_buffers | ||
| ): | ||
| buffer_name = inputs_to_buffers[pred_node.name] | ||
|
|
||
| if buffer_name in exported_program.constants: | ||
| tensor = exported_program._constants[buffer_name] | ||
| if tensor.device.type != "cpu": | ||
| exported_program._constants[buffer_name] = tensor.to("cpu") | ||
|
|
||
| # Also update the placeholder metadata | ||
| if "val" in pred_node.meta: | ||
| fake_tensor = pred_node.meta["val"] | ||
| if isinstance(fake_tensor, torch.Tensor): | ||
| pred_node.meta["val"] = fake_tensor.to("cpu") | ||
| exported_program.validate() | ||
Empty file.
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.