-
Notifications
You must be signed in to change notification settings - Fork 4
Proxy: torch.nn.Parameter subclass #134
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR introduces a new ProxyParameter subclass of torch.nn.Parameter to enable parameter tracing that is compatible with the PyTorch library. The implementation provides an alternative to the existing proxy wrapper approach by leveraging PyTorch's subclassing mechanism.
Key changes:
- Creates a ProxyParameter class that extends torch.nn.Parameter with tracing capabilities
- Adds support for a new "proxyparameter" model tracker style alongside existing "proxy" and "sampler" modes
- Updates observer and tracer components to handle ProxyParameter instances
Reviewed Changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
| traincheck/proxy_wrapper/subclass.py | New file implementing ProxyParameter class and proxy_parameter function for parameter tracing |
| traincheck/proxy_wrapper/proxy_observer.py | Updates observer to handle ProxyParameter instances alongside existing Proxy objects |
| traincheck/proxy_wrapper/proxy_config.py | Adds proxy_attribute list for filtering ProxyParameter-specific attributes |
| traincheck/proxy_wrapper/proxy_basics.py | Adds is_proxyparamtetr function to detect ProxyParameter instances |
| traincheck/instrumentor/tracer.py | Updates tracer to recognize and handle ProxyParameter objects in function arguments |
| traincheck/instrumentor/source_file.py | Adds "proxyparameter" mode support to model assignment instrumentation |
| traincheck/instrumentor/dumper.py | Filters out proxy-specific attributes when dumping ProxyParameter objects |
| traincheck/collect_trace.py | Adds "proxyparameter" as a valid choice for model tracker style |
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
|
|
||
| if from_iter: | ||
| phase = "iter" | ||
| # if the object is generated from getattr, then do not dump it |
Copilot
AI
Sep 20, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logic flow is incorrect. If from_call is True on line 109, phase is set to "call", but then if from_iter is also True, it gets overwritten to "iter". The else clause on line 115 should be elif to properly handle the case where neither from_call nor from_iter is True.
| if from_iter: | |
| phase = "iter" | |
| # if the object is generated from getattr, then do not dump it | |
| elif from_iter: | |
| phase = "iter" |
| ) | ||
|
|
||
| else: | ||
| raise ValueError(f"Invalid mode: {mode}. Must be one of ['proxy', 'sampler']") |
Copilot
AI
Sep 20, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Error message is outdated and doesn't include the new "proxyparameter" mode. Should be "Must be one of ['proxy', 'sampler', 'proxyparameter']".
| raise ValueError(f"Invalid mode: {mode}. Must be one of ['proxy', 'sampler']") | |
| raise ValueError(f"Invalid mode: {mode}. Must be one of ['proxy', 'sampler', 'proxyparameter']") |
Use torch.nn.Parameter subclass to proxy the model parameter and trace the changes. It is compatible with the torch library.