Skip to content

Commit c4f6999

Browse files
authored
Merge pull request #71 from cmccomb/codex/update-preprocessor-for-fit-and-transform
2 parents 6629f21 + 852599c commit c4f6999

File tree

5 files changed

+152
-49
lines changed

5 files changed

+152
-49
lines changed

src/model/preprocessing.rs

Lines changed: 70 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
//! Utilities for data preprocessing.
22
3-
use crate::settings::PreProcessing;
3+
use crate::model::error::ModelError;
4+
use crate::settings::{PreProcessing, SettingsError};
45
use crate::utils::features::{FeatureError, interaction_features, polynomial_features};
56
use smartcore::{
67
decomposition::{
78
pca::{PCA, PCAParameters},
89
svd::{SVD, SVDParameters},
910
},
11+
error::Failed,
1012
linalg::{
1113
basic::arrays::{Array, Array2},
1214
traits::{
@@ -52,72 +54,103 @@ where
5254
}
5355
}
5456

55-
/// Train preprocessing models based on settings.
56-
pub fn train(&mut self, x: &InputArray, settings: &PreProcessing) {
57+
/// Fit preprocessing state (if required) and return a transformed copy of the
58+
/// training matrix.
59+
pub fn fit_transform(
60+
&mut self,
61+
x: InputArray,
62+
settings: &PreProcessing,
63+
) -> Result<InputArray, SettingsError> {
64+
self.pca = None;
65+
self.svd = None;
5766
match settings {
67+
PreProcessing::None => Ok(x),
68+
PreProcessing::AddInteractions => {
69+
interaction_features(x).map_err(Self::feature_error_to_settings)
70+
}
71+
PreProcessing::AddPolynomial { order } => {
72+
polynomial_features(x, *order).map_err(Self::feature_error_to_settings)
73+
}
5874
PreProcessing::ReplaceWithPCA {
5975
number_of_components,
60-
} => {
61-
self.train_pca(x, *number_of_components);
62-
}
76+
} => self.fit_pca(&x, *number_of_components),
6377
PreProcessing::ReplaceWithSVD {
6478
number_of_components,
65-
} => {
66-
self.train_svd(x, *number_of_components);
67-
}
68-
_ => {}
79+
} => self.fit_svd(&x, *number_of_components),
6980
}
7081
}
7182

72-
/// Apply preprocessing to data.
83+
/// Apply preprocessing to inference data.
7384
pub fn preprocess(
7485
&self,
7586
x: InputArray,
7687
settings: &PreProcessing,
77-
) -> Result<InputArray, FeatureError> {
78-
Ok(match settings {
79-
PreProcessing::None => x,
80-
PreProcessing::AddInteractions => interaction_features(x)?,
81-
PreProcessing::AddPolynomial { order } => polynomial_features(x, *order)?,
82-
PreProcessing::ReplaceWithPCA {
83-
number_of_components: _,
84-
} => self.pca_features(&x),
85-
PreProcessing::ReplaceWithSVD {
86-
number_of_components: _,
87-
} => self.svd_features(&x),
88-
})
88+
) -> Result<InputArray, ModelError> {
89+
match settings {
90+
PreProcessing::None => Ok(x),
91+
PreProcessing::AddInteractions => {
92+
interaction_features(x).map_err(Self::feature_error_to_model)
93+
}
94+
PreProcessing::AddPolynomial { order } => {
95+
polynomial_features(x, *order).map_err(Self::feature_error_to_model)
96+
}
97+
PreProcessing::ReplaceWithPCA { .. } => self.pca_features(&x),
98+
PreProcessing::ReplaceWithSVD { .. } => self.svd_features(&x),
99+
}
89100
}
90101

91-
fn train_pca(&mut self, x: &InputArray, n: usize) {
102+
fn fit_pca(&mut self, x: &InputArray, n: usize) -> Result<InputArray, SettingsError> {
92103
let pca = PCA::fit(
93104
x,
94105
PCAParameters::default()
95106
.with_n_components(n)
96107
.with_use_correlation_matrix(true),
97108
)
98-
.expect("Could not train PCA preprocessor");
109+
.map_err(|err| Self::failed_to_settings(&err))?;
110+
let transformed = pca
111+
.transform(x)
112+
.map_err(|err| Self::failed_to_settings(&err))?;
99113
self.pca = Some(pca);
114+
Ok(transformed)
100115
}
101116

102-
fn pca_features(&self, x: &InputArray) -> InputArray {
103-
self.pca
117+
fn pca_features(&self, x: &InputArray) -> Result<InputArray, ModelError> {
118+
let pca = self
119+
.pca
104120
.as_ref()
105-
.expect("PCA model not trained")
106-
.transform(x)
107-
.expect("Could not transform data using PCA")
121+
.ok_or_else(|| ModelError::Inference("PCA model not trained".to_string()))?;
122+
pca.transform(x)
123+
.map_err(|err| ModelError::Inference(err.to_string()))
108124
}
109125

110-
fn train_svd(&mut self, x: &InputArray, n: usize) {
126+
fn fit_svd(&mut self, x: &InputArray, n: usize) -> Result<InputArray, SettingsError> {
111127
let svd = SVD::fit(x, SVDParameters::default().with_n_components(n))
112-
.expect("Could not train SVD preprocessor");
128+
.map_err(|err| Self::failed_to_settings(&err))?;
129+
let transformed = svd
130+
.transform(x)
131+
.map_err(|err| Self::failed_to_settings(&err))?;
113132
self.svd = Some(svd);
133+
Ok(transformed)
114134
}
115135

116-
fn svd_features(&self, x: &InputArray) -> InputArray {
117-
self.svd
136+
fn svd_features(&self, x: &InputArray) -> Result<InputArray, ModelError> {
137+
let svd = self
138+
.svd
118139
.as_ref()
119-
.expect("SVD model not trained")
120-
.transform(x)
121-
.expect("Could not transform data using SVD")
140+
.ok_or_else(|| ModelError::Inference("SVD model not trained".to_string()))?;
141+
svd.transform(x)
142+
.map_err(|err| ModelError::Inference(err.to_string()))
143+
}
144+
145+
fn feature_error_to_settings(err: FeatureError) -> SettingsError {
146+
SettingsError::PreProcessingFailed(err.to_string())
147+
}
148+
149+
fn feature_error_to_model(err: FeatureError) -> ModelError {
150+
ModelError::Inference(err.to_string())
151+
}
152+
153+
fn failed_to_settings(err: &Failed) -> SettingsError {
154+
SettingsError::PreProcessingFailed(err.to_string())
122155
}
123156
}

src/model/supervised.rs

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@ use crate::model::{
1111
preprocessing::Preprocessor,
1212
};
1313
use crate::settings::{
14-
ClassificationSettings, FinalAlgorithm, Metric, RegressionSettings, SupervisedSettings,
14+
ClassificationSettings, FinalAlgorithm, Metric, RegressionSettings, SettingsError,
15+
SupervisedSettings,
1516
};
1617
use comfy_table::{
1718
Attribute, Cell, Table, modifiers::UTF8_SOLID_INNER_BORDERS, presets::UTF8_FULL,
1819
};
1920
use humantime::format_duration;
20-
use smartcore::error::Failed;
21+
use smartcore::error::{Failed, FailedError};
2122
use smartcore::linalg::{
2223
basic::arrays::{Array, Array1, Array2, MutArrayView1},
2324
traits::{
@@ -111,7 +112,9 @@ where
111112
{
112113
/// Settings for the model.
113114
pub settings: S,
114-
/// Training features.
115+
/// Original training features used to recompute preprocessing steps.
116+
x_train_raw: InputArray,
117+
/// Preprocessed training features fed to algorithms.
115118
x_train: InputArray,
116119
/// Training targets.
117120
y_train: OutputArray,
@@ -136,8 +139,10 @@ where
136139
{
137140
/// Create a new supervised model.
138141
pub fn new(x: InputArray, y: OutputArray, settings: S) -> Self {
142+
let x_train_raw = x.clone();
139143
Self {
140144
settings,
145+
x_train_raw,
141146
x_train: x,
142147
y_train: y,
143148
comparison: Vec::new(),
@@ -152,8 +157,11 @@ where
152157
/// Returns [`Failed`] if cross-validation fails for any algorithm.
153158
pub fn train(&mut self) -> Result<(), Failed> {
154159
let sup = self.settings.supervised();
155-
self.preprocessor
156-
.train(&self.x_train.clone(), &sup.preprocessing);
160+
let raw = self.x_train_raw.clone();
161+
self.x_train = self
162+
.preprocessor
163+
.fit_transform(raw, &sup.preprocessing)
164+
.map_err(|err| Self::preprocessing_failed(&err))?;
157165

158166
for alg in <A>::all_algorithms(&self.settings) {
159167
let trained = alg.cross_validate_model(&self.x_train, &self.y_train, &self.settings)?;
@@ -169,10 +177,7 @@ where
169177
/// Returns [`ModelError::NotTrained`] if no algorithm has been trained or if inference fails.
170178
pub fn predict(&self, x: InputArray) -> ModelResult<OutputArray> {
171179
let sup = self.settings.supervised();
172-
let x = self
173-
.preprocessor
174-
.preprocess(x, &sup.preprocessing)
175-
.map_err(|e| ModelError::Inference(e.to_string()))?;
180+
let x = self.preprocessor.preprocess(x, &sup.preprocessing)?;
176181

177182
match sup.final_model_approach {
178183
FinalAlgorithm::None => Err(ModelError::NotTrained),
@@ -203,6 +208,10 @@ where
203208
self.comparison.reverse();
204209
}
205210
}
211+
212+
fn preprocessing_failed(err: &SettingsError) -> Failed {
213+
Failed::because(FailedError::ParametersError, &err.to_string())
214+
}
206215
}
207216

208217
impl<A, S, InputArray, OutputArray> Display for SupervisedModel<A, S, InputArray, OutputArray>

src/settings/error.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,24 @@ use std::fmt::{Display, Formatter};
55
use super::Metric;
66

77
/// Errors related to model settings.
8-
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
8+
#[derive(Debug, Clone, PartialEq, Eq)]
99
pub enum SettingsError {
1010
/// A required metric was not specified.
1111
MetricNotSet,
1212
/// The provided metric is not supported for the task.
1313
UnsupportedMetric(Metric),
14+
/// Preprocessing configuration failed to run successfully.
15+
PreProcessingFailed(String),
1416
}
1517

1618
impl Display for SettingsError {
1719
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1820
match self {
1921
Self::MetricNotSet => write!(f, "a metric must be set"),
2022
Self::UnsupportedMetric(m) => write!(f, "unsupported metric: {m}"),
23+
Self::PreProcessingFailed(msg) => {
24+
write!(f, "preprocessing configuration failed: {msg}")
25+
}
2126
}
2227
}
2328
}

tests/classification.rs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use automl::algorithms::ClassificationAlgorithm;
55
use automl::model::Algorithm;
66
use automl::settings::{
77
BernoulliNBParameters, CategoricalNBParameters, ClassificationSettings,
8-
MultinomialNBParameters, RandomForestClassifierParameters, SVCParameters,
8+
MultinomialNBParameters, PreProcessing, RandomForestClassifierParameters, SVCParameters,
99
};
1010
use automl::{DenseMatrix, ModelError, SupervisedModel};
1111
use classification_data::{
@@ -217,6 +217,32 @@ fn bernoulli_nb_rejects_non_binary_without_threshold() {
217217
);
218218
}
219219

220+
#[test]
221+
fn classification_pca_preprocessing_predicts() {
222+
type Model = SupervisedModel<
223+
ClassificationAlgorithm<f64, u32, DenseMatrix<f64>, Vec<u32>>,
224+
ClassificationSettings,
225+
DenseMatrix<f64>,
226+
Vec<u32>,
227+
>;
228+
229+
let (x, y) = classification_testing_data();
230+
let settings = ClassificationSettings::default()
231+
.with_svc_settings(SVCParameters::default())
232+
.with_preprocessing(PreProcessing::ReplaceWithPCA {
233+
number_of_components: 2,
234+
});
235+
236+
let mut model: Model = SupervisedModel::new(x, y, settings);
237+
model.train().unwrap();
238+
239+
let predictions = model
240+
.predict(DenseMatrix::from_2d_array(&[&[0.0, 0.0], &[1.0, 1.0]]).unwrap())
241+
.expect("PCA-preprocessed model should predict successfully");
242+
243+
assert_eq!(predictions.len(), 2);
244+
}
245+
220246
#[test]
221247
fn invalid_alpha_returns_error() {
222248
// Arrange

tests/regression.rs

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ mod regression_data;
44
use automl::algorithms::RegressionAlgorithm;
55
use automl::model::Algorithm;
66
use automl::settings::{
7-
Distance, ExtraTreesRegressorParameters, KNNParameters, Kernel, SVRParameters,
7+
Distance, ExtraTreesRegressorParameters, KNNParameters, Kernel, PreProcessing, SVRParameters,
88
XGRegressorParameters,
99
};
1010
use automl::{DenseMatrix, RegressionSettings, SupervisedModel};
@@ -234,6 +234,36 @@ fn test_xgboost_skiplist_controls_algorithms() {
234234
));
235235
}
236236

237+
#[test]
238+
fn regression_polynomial_preprocessing_predicts() {
239+
type Model = SupervisedModel<
240+
RegressionAlgorithm<f64, f64, DenseMatrix<f64>, Vec<f64>>,
241+
RegressionSettings<f64, f64, DenseMatrix<f64>, Vec<f64>>,
242+
DenseMatrix<f64>,
243+
Vec<f64>,
244+
>;
245+
246+
let (x, y) = regression_testing_data();
247+
let settings = RegressionSettings::default()
248+
.with_preprocessing(PreProcessing::AddPolynomial { order: 2 })
249+
.only(&RegressionAlgorithm::default_knn_regressor());
250+
251+
let mut regressor: Model = SupervisedModel::new(x, y, settings);
252+
regressor.train().unwrap();
253+
254+
let predictions = regressor
255+
.predict(
256+
DenseMatrix::from_2d_array(&[
257+
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
258+
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
259+
])
260+
.unwrap(),
261+
)
262+
.expect("Polynomial preprocessing should allow prediction");
263+
264+
assert_eq!(predictions.len(), 2);
265+
}
266+
237267
fn test_from_settings(settings: RegressionSettings<f64, f64, DenseMatrix<f64>, Vec<f64>>) {
238268
// Set up the regressor settings and load data
239269
type Model = SupervisedModel<

0 commit comments

Comments
 (0)