Skip to content

Commit 7351143

Browse files
authored
Merge pull request #72 from cmccomb/codex/introduce-preprocessing-settings-module
2 parents c4f6999 + 68fa953 commit 7351143

File tree

12 files changed

+576
-133
lines changed

12 files changed

+576
-133
lines changed

README.md

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,40 @@ will perform a comparison of classifier models using cross-validation. Printing
6060

6161
You can then perform inference using the best model with the `predict` method.
6262

63+
## Preprocessing pipelines
64+
65+
`automl` now supports composable preprocessing pipelines so you can build
66+
feature engineering recipes similar to `AutoGluon` or `caret`. Pipelines are
67+
defined with the [`PreprocessingStep`](https://docs.rs/automl/latest/automl/settings/enum.PreprocessingStep.html)
68+
enum and attached either with the `add_step` builder or by passing a full
69+
[`PreprocessingPipeline`](https://docs.rs/automl/latest/automl/settings/struct.PreprocessingPipeline.html).
70+
71+
```rust
72+
use automl::settings::{
73+
ClassificationSettings, PreprocessingPipeline, PreprocessingStep, RegressionSettings,
74+
StandardizeParams,
75+
};
76+
use automl::DenseMatrix;
77+
78+
let regression = RegressionSettings::<f64, f64, DenseMatrix<f64>, Vec<f64>>::default()
79+
.add_step(PreprocessingStep::Standardize(StandardizeParams::default()))
80+
.add_step(PreprocessingStep::ReplaceWithPCA {
81+
number_of_components: 5,
82+
});
83+
84+
let classification = ClassificationSettings::default().with_preprocessing(
85+
PreprocessingPipeline::new()
86+
.add_step(PreprocessingStep::AddInteractions)
87+
.add_step(PreprocessingStep::ReplaceWithSVD {
88+
number_of_components: 4,
89+
}),
90+
);
91+
```
92+
93+
Pipelines preserve the order of steps. Stateful steps such as PCA, SVD, or
94+
standardization are fitted automatically during training and reuse the same fitted
95+
state when you call `predict`.
96+
6397
## Features
6498

6599
This crate has several features that add some additional methods.

examples/maximal_regression.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ use automl::{
2121
settings::{
2222
DecisionTreeRegressorParameters, Distance, ElasticNetParameters, FinalAlgorithm,
2323
KNNAlgorithmName, KNNParameters, KNNWeightFunction, Kernel, LassoParameters,
24-
LinearRegressionParameters, LinearRegressionSolverName, Metric,
24+
LinearRegressionParameters, LinearRegressionSolverName, Metric, PreprocessingStep,
2525
RandomForestRegressorParameters, RidgeRegressionParameters, RidgeRegressionSolverName,
26-
SVRParameters, XGRegressorParameters,
26+
SVRParameters, StandardizeParams, XGRegressorParameters,
2727
},
2828
};
2929
use regression_data::regression_testing_data;
@@ -41,7 +41,7 @@ fn main() -> Result<(), Failed> {
4141
.with_final_model(FinalAlgorithm::Best)
4242
.skip(RegressionAlgorithm::default_random_forest())
4343
.sorted_by(Metric::RSquared)
44-
// .with_preprocessing(PreProcessing::AddInteractions)
44+
.add_step(PreprocessingStep::Standardize(StandardizeParams::default()))
4545
.with_linear_settings(
4646
LinearRegressionParameters::default().with_solver(LinearRegressionSolverName::QR),
4747
)

examples/print_settings.rs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@ use automl::settings::{
44
DecisionTreeRegressorParameters, Distance, ElasticNetParameters, ExtraTreesRegressorParameters,
55
FinalAlgorithm, GaussianNBParameters, KNNAlgorithmName, KNNParameters, KNNWeightFunction,
66
Kernel, LassoParameters, LinearRegressionParameters, LinearRegressionSolverName,
7-
LogisticRegressionParameters, Metric, MultinomialNBParameters, Objective, PreProcessing,
8-
RandomForestClassifierParameters, RandomForestRegressorParameters, RegressionSettings,
9-
RidgeRegressionParameters, RidgeRegressionSolverName, SVCParameters, SVRParameters,
7+
LogisticRegressionParameters, Metric, MultinomialNBParameters, Objective,
8+
PreprocessingPipeline, PreprocessingStep, RandomForestClassifierParameters,
9+
RandomForestRegressorParameters, RegressionSettings, RidgeRegressionParameters,
10+
RidgeRegressionSolverName, SVCParameters, SVRParameters, StandardizeParams,
1011
XGRegressorParameters,
1112
};
1213
use serde_json::to_string_pretty;
@@ -20,7 +21,8 @@ fn build_regression_settings() -> RegressionConfig {
2021
.shuffle_data(true)
2122
.verbose(true)
2223
.sorted_by(Metric::RSquared)
23-
.with_preprocessing(PreProcessing::AddInteractions)
24+
.add_step(PreprocessingStep::Standardize(StandardizeParams::default()))
25+
.add_step(PreprocessingStep::AddInteractions)
2426
.with_linear_settings(
2527
LinearRegressionParameters::default().with_solver(LinearRegressionSolverName::QR),
2628
)
@@ -99,12 +101,16 @@ fn build_regression_settings() -> RegressionConfig {
99101
}
100102

101103
fn build_classification_settings() -> ClassificationSettings {
104+
let pipeline = PreprocessingPipeline::new()
105+
.add_step(PreprocessingStep::Standardize(StandardizeParams::default()))
106+
.add_step(PreprocessingStep::AddInteractions);
107+
102108
ClassificationSettings::default()
103109
.with_number_of_folds(6)
104110
.shuffle_data(true)
105111
.verbose(true)
106112
.sorted_by(Metric::Accuracy)
107-
.with_preprocessing(PreProcessing::AddInteractions)
113+
.with_preprocessing(pipeline)
108114
.with_final_model(FinalAlgorithm::Best)
109115
.with_knn_classifier_settings(
110116
KNNParameters::default()

0 commit comments

Comments
 (0)