Add Permutation Importance #202
This is the current basic call:

import numpy as np
import polars as pl
from sklearn.linear_model import LinearRegression
from model_diagnostics.xai import plot_permutation_importance

rng = np.random.default_rng(1)
n = 1000
X = pl.DataFrame(
    {
        "area": rng.uniform(30, 120, n),
        "rooms": rng.choice([2.5, 3.5, 4.5], n),
        "age": rng.uniform(0, 100, n),
    }
)
y = X["area"] + 20 * X["rooms"] + rng.normal(0, 1, n)
model = LinearRegression()
model.fit(X, y)
_ = plot_permutation_importance(
    predict_function=model.predict,
    X=X,
    y=y,
)

The extended feature API allows permuting groups of features like this:

_ = plot_permutation_importance(
    predict_function=model.predict,
    features={"size": ["area", "rooms"], "age": "age"},
    X=X,
    y=y,
)
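For readers unfamiliar with the technique, here is a minimal NumPy-only sketch of what grouped permutation importance computes. It is not the implementation from this PR: the function name `permutation_importance_sketch`, the dict-of-arrays container, the squared-error score, the `n_repeats` and `seed` parameters, and the choice to shuffle all columns of a group with the same permutation are all assumptions made for illustration.

```python
import numpy as np


def permutation_importance_sketch(predict, X, y, features, n_repeats=5, seed=0):
    """Mean increase in squared error after permuting each feature group.

    X is a dict mapping column name -> 1D array (a hypothetical container).
    features maps a display name to one column name or a list of column names.
    """
    rng = np.random.default_rng(seed)
    n = len(y)
    base_score = np.mean((y - predict(X)) ** 2)
    result = {}
    for name, cols in features.items():
        cols = [cols] if isinstance(cols, str) else list(cols)
        increases = []
        for _ in range(n_repeats):
            perm = rng.permutation(n)
            X_perm = dict(X)  # shallow copy; untouched columns are shared
            for col in cols:
                # Assumption: one permutation per group, applied to all its columns.
                X_perm[col] = X[col][perm]
            score = np.mean((y - predict(X_perm)) ** 2)
            increases.append(score - base_score)
        result[name] = float(np.mean(increases))
    return result


# Same synthetic data as in the call above, but stored as plain arrays.
rng = np.random.default_rng(1)
n = 1000
X = {
    "area": rng.uniform(30, 120, n),
    "rooms": rng.choice([2.5, 3.5, 4.5], n),
    "age": rng.uniform(0, 100, n),
}
y = X["area"] + 20 * X["rooms"] + rng.normal(0, 1, n)


def predict(X):
    # Stand-in for a fitted model; it ignores "age" on purpose.
    return X["area"] + 20 * X["rooms"]


importance = permutation_importance_sketch(
    predict, X, y, features={"size": ["area", "rooms"], "age": "age"}
)
```

Because the stand-in `predict` never reads `age`, permuting that column leaves the score unchanged, while permuting the `size` group degrades it substantially.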
from model_diagnostics.scoring import SquaredError

def safe_copy(X):
Might be good to put safe_copy and safe_column_names into _utils.array and add tests for them. I think they cause the current CI failure.
My local tests are failing for the Python 3.9 environment only (pandas and pyarrow). I will move the functions to _utils.array, draft some unit tests, and rename safe_column_names() to get_column_names().
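As a rough idea of what such container-agnostic helpers could look like, here is a duck-typed sketch. The bodies below are guesses, not the code from this PR, and the `_sketch` suffix marks the names as hypothetical. The attribute checks rely on pandas and polars DataFrames exposing `columns`, pyarrow Tables exposing `column_names`, and polars offering `clone()`.

```python
import copy

import numpy as np


def get_column_names_sketch(X):
    """Return column names, or positional indices for plain 2D arrays."""
    if hasattr(X, "columns"):  # pandas and polars DataFrames
        return list(X.columns)
    if hasattr(X, "column_names"):  # pyarrow Table
        return list(X.column_names)
    return list(range(np.asarray(X).shape[1]))


def safe_copy_sketch(X):
    """Return a copy of X so that later permutations do not touch the input."""
    if hasattr(X, "clone"):  # polars
        return X.clone()
    if hasattr(X, "copy"):  # numpy arrays, pandas DataFrames
        return X.copy()
    return copy.deepcopy(X)  # generic fallback


arr = np.zeros((3, 2))
names = get_column_names_sketch(arr)
arr_copy = safe_copy_sketch(arr)
arr_copy[0, 0] = 1.0  # the original array stays unchanged
```

Unit tests for helpers like these would typically run the same assertions once per supported container and skip the containers that are not installed.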
|
This will be a great addition! Thanks @mayer79 |
|
The failing test is in the python 3.9 env with Could you check if increasing one of the versions fixes the problem, e.g. polars version? |
The following changes in the 3.9 env would be necessary. I don't know how much it would hurt to abandon pandas 1
I have added some additional unit tests and moved |
|
fyi, CI will fail due to new versions of polars and numpy. I am working on a fix. |
|
Fix in #203, you need to sync (e.g. merge) with the main branch (and maybe |
Review of array utils.
# if not x.index.is_unique:
# Pandas might error with:
# cannot reindex on an axis with duplicate labels
# Try reindexing ourselves.
x = x.reset_index(drop=True)
# if not pd_values.index.is_unique:
pd_values = pd_values.reset_index(drop=True)
Can you find a condition (an if statement) such that this is executed only when strictly required?
Could you add a test case (to an existing test) or a new test that fails without this change?
Not sure. I think this test would fail with pandas 1.5, because the index of the assigned values does not match:

def test_safe_assign_column_works_for_pandas_with_inconsistent_index():
    """Test that safe_assign_column works for pandas dfs regarding indices."""
    df = pd_DataFrame({"a": [0, 1, 2]}, index=[0, 1, 2])
    if isinstance(df, SkipContainer):
        pytest.skip("Module for data container not imported.")
    df = safe_assign_column(
        df, values=pd_Series([10, 20, 30], index=[1, 1, 0]), column_index=0
    )
    expected = pd_DataFrame({"a": [10, 20, 30]})
    assert_array_equal(df, expected)

But I don't know how to test this, because such a test is skipped.
Now, when we remove pandas <2 support, safe_assign_column() does not need to be as strict as in the committed version.
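One possible condition, sketched here under assumptions: only bypass pandas' index alignment when the index of the values differs from that of the frame. The function `assign_column_sketch` is a hypothetical stand-in, not `safe_assign_column()` from this PR, and it swaps in positional assignment via `to_numpy()` instead of the `reset_index(drop=True)` shown in the diff, so that it also works when the frame itself has a non-default index.

```python
import pandas as pd


def assign_column_sketch(df, values, column_index):
    """Assign values to one column by position, returning a new DataFrame.

    Hypothetical helper: index alignment is skipped only when the indices differ.
    """
    if isinstance(values, pd.Series) and not values.index.equals(df.index):
        # Mismatched (possibly duplicated) labels: assign by position instead.
        values = values.to_numpy()
    out = df.copy()
    out[out.columns[column_index]] = values
    return out


df = pd.DataFrame({"a": [0, 1, 2]}, index=[0, 1, 2])
new_values = pd.Series([10, 20, 30], index=[1, 1, 0])
result = assign_column_sketch(df, new_values, column_index=0)
```

With this guard, a Series whose index already matches the frame is assigned unchanged, so the extra conversion only runs in the mismatched case.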
|
fyi: I am preparing to bump the minimum versions of python to 3.11 and numpy to 2. This implies polars 1.1.0, pandas >= 2.2.2 and pyarrow >= 16, see #206. |
|
I have modified these aspects in the main functionality:
|


Implements #201