|
1 | 1 | //! Utilities for data preprocessing. |
2 | 2 |
|
3 | | -use crate::settings::PreProcessing; |
| 3 | +use crate::model::error::ModelError; |
| 4 | +use crate::settings::{PreProcessing, SettingsError}; |
4 | 5 | use crate::utils::features::{FeatureError, interaction_features, polynomial_features}; |
5 | 6 | use smartcore::{ |
6 | 7 | decomposition::{ |
7 | 8 | pca::{PCA, PCAParameters}, |
8 | 9 | svd::{SVD, SVDParameters}, |
9 | 10 | }, |
| 11 | + error::Failed, |
10 | 12 | linalg::{ |
11 | 13 | basic::arrays::{Array, Array2}, |
12 | 14 | traits::{ |
@@ -52,72 +54,103 @@ where |
52 | 54 | } |
53 | 55 | } |
54 | 56 |
|
55 | | - /// Train preprocessing models based on settings. |
56 | | - pub fn train(&mut self, x: &InputArray, settings: &PreProcessing) { |
| 57 | + /// Fit preprocessing state (if required) and return a transformed copy of the |
| 58 | + /// training matrix. |
| 59 | + pub fn fit_transform( |
| 60 | + &mut self, |
| 61 | + x: InputArray, |
| 62 | + settings: &PreProcessing, |
| 63 | + ) -> Result<InputArray, SettingsError> { |
| 64 | + self.pca = None; |
| 65 | + self.svd = None; |
57 | 66 | match settings { |
| 67 | + PreProcessing::None => Ok(x), |
| 68 | + PreProcessing::AddInteractions => { |
| 69 | + interaction_features(x).map_err(Self::feature_error_to_settings) |
| 70 | + } |
| 71 | + PreProcessing::AddPolynomial { order } => { |
| 72 | + polynomial_features(x, *order).map_err(Self::feature_error_to_settings) |
| 73 | + } |
58 | 74 | PreProcessing::ReplaceWithPCA { |
59 | 75 | number_of_components, |
60 | | - } => { |
61 | | - self.train_pca(x, *number_of_components); |
62 | | - } |
| 76 | + } => self.fit_pca(&x, *number_of_components), |
63 | 77 | PreProcessing::ReplaceWithSVD { |
64 | 78 | number_of_components, |
65 | | - } => { |
66 | | - self.train_svd(x, *number_of_components); |
67 | | - } |
68 | | - _ => {} |
| 79 | + } => self.fit_svd(&x, *number_of_components), |
69 | 80 | } |
70 | 81 | } |
71 | 82 |
|
72 | | - /// Apply preprocessing to data. |
| 83 | + /// Apply preprocessing to inference data. |
73 | 84 | pub fn preprocess( |
74 | 85 | &self, |
75 | 86 | x: InputArray, |
76 | 87 | settings: &PreProcessing, |
77 | | - ) -> Result<InputArray, FeatureError> { |
78 | | - Ok(match settings { |
79 | | - PreProcessing::None => x, |
80 | | - PreProcessing::AddInteractions => interaction_features(x)?, |
81 | | - PreProcessing::AddPolynomial { order } => polynomial_features(x, *order)?, |
82 | | - PreProcessing::ReplaceWithPCA { |
83 | | - number_of_components: _, |
84 | | - } => self.pca_features(&x), |
85 | | - PreProcessing::ReplaceWithSVD { |
86 | | - number_of_components: _, |
87 | | - } => self.svd_features(&x), |
88 | | - }) |
| 88 | + ) -> Result<InputArray, ModelError> { |
| 89 | + match settings { |
| 90 | + PreProcessing::None => Ok(x), |
| 91 | + PreProcessing::AddInteractions => { |
| 92 | + interaction_features(x).map_err(Self::feature_error_to_model) |
| 93 | + } |
| 94 | + PreProcessing::AddPolynomial { order } => { |
| 95 | + polynomial_features(x, *order).map_err(Self::feature_error_to_model) |
| 96 | + } |
| 97 | + PreProcessing::ReplaceWithPCA { .. } => self.pca_features(&x), |
| 98 | + PreProcessing::ReplaceWithSVD { .. } => self.svd_features(&x), |
| 99 | + } |
89 | 100 | } |
90 | 101 |
|
91 | | - fn train_pca(&mut self, x: &InputArray, n: usize) { |
| 102 | + fn fit_pca(&mut self, x: &InputArray, n: usize) -> Result<InputArray, SettingsError> { |
92 | 103 | let pca = PCA::fit( |
93 | 104 | x, |
94 | 105 | PCAParameters::default() |
95 | 106 | .with_n_components(n) |
96 | 107 | .with_use_correlation_matrix(true), |
97 | 108 | ) |
98 | | - .expect("Could not train PCA preprocessor"); |
| 109 | + .map_err(|err| Self::failed_to_settings(&err))?; |
| 110 | + let transformed = pca |
| 111 | + .transform(x) |
| 112 | + .map_err(|err| Self::failed_to_settings(&err))?; |
99 | 113 | self.pca = Some(pca); |
| 114 | + Ok(transformed) |
100 | 115 | } |
101 | 116 |
|
102 | | - fn pca_features(&self, x: &InputArray) -> InputArray { |
103 | | - self.pca |
| 117 | + fn pca_features(&self, x: &InputArray) -> Result<InputArray, ModelError> { |
| 118 | + let pca = self |
| 119 | + .pca |
104 | 120 | .as_ref() |
105 | | - .expect("PCA model not trained") |
106 | | - .transform(x) |
107 | | - .expect("Could not transform data using PCA") |
| 121 | + .ok_or_else(|| ModelError::Inference("PCA model not trained".to_string()))?; |
| 122 | + pca.transform(x) |
| 123 | + .map_err(|err| ModelError::Inference(err.to_string())) |
108 | 124 | } |
109 | 125 |
|
110 | | - fn train_svd(&mut self, x: &InputArray, n: usize) { |
| 126 | + fn fit_svd(&mut self, x: &InputArray, n: usize) -> Result<InputArray, SettingsError> { |
111 | 127 | let svd = SVD::fit(x, SVDParameters::default().with_n_components(n)) |
112 | | - .expect("Could not train SVD preprocessor"); |
| 128 | + .map_err(|err| Self::failed_to_settings(&err))?; |
| 129 | + let transformed = svd |
| 130 | + .transform(x) |
| 131 | + .map_err(|err| Self::failed_to_settings(&err))?; |
113 | 132 | self.svd = Some(svd); |
| 133 | + Ok(transformed) |
114 | 134 | } |
115 | 135 |
|
116 | | - fn svd_features(&self, x: &InputArray) -> InputArray { |
117 | | - self.svd |
| 136 | + fn svd_features(&self, x: &InputArray) -> Result<InputArray, ModelError> { |
| 137 | + let svd = self |
| 138 | + .svd |
118 | 139 | .as_ref() |
119 | | - .expect("SVD model not trained") |
120 | | - .transform(x) |
121 | | - .expect("Could not transform data using SVD") |
| 140 | + .ok_or_else(|| ModelError::Inference("SVD model not trained".to_string()))?; |
| 141 | + svd.transform(x) |
| 142 | + .map_err(|err| ModelError::Inference(err.to_string())) |
| 143 | + } |
| 144 | + |
| 145 | + fn feature_error_to_settings(err: FeatureError) -> SettingsError { |
| 146 | + SettingsError::PreProcessingFailed(err.to_string()) |
| 147 | + } |
| 148 | + |
| 149 | + fn feature_error_to_model(err: FeatureError) -> ModelError { |
| 150 | + ModelError::Inference(err.to_string()) |
| 151 | + } |
| 152 | + |
| 153 | + fn failed_to_settings(err: &Failed) -> SettingsError { |
| 154 | + SettingsError::PreProcessingFailed(err.to_string()) |
122 | 155 | } |
123 | 156 | } |
0 commit comments