Skip to content

Commit bde9c2b

Browse files
committed
Add real-world cookbook examples
1 parent 153c2d9 commit bde9c2b

File tree

7 files changed

+309
-0
lines changed

7 files changed

+309
-0
lines changed

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,17 @@ will perform a comparison of classifier models using cross-validation. Printing
6060

6161
You can then perform inference using the best model with the `predict` method.
6262

63+
## Cookbook
64+
65+
Explore the `automl::cookbook` module for copy-pastable examples that mirror
66+
real-world workflows:
67+
68+
- `cargo run --example breast_cancer_csv` – load the Wisconsin Diagnostic
69+
Breast Cancer dataset from CSV, standardize features, and compare tuned
70+
classifiers.
71+
- `cargo run --example diabetes_regression` – impute, scale, and train
72+
regression models for the diabetes progression dataset.
73+
6374
## Preprocessing pipelines
6475

6576
`automl` now supports composable preprocessing pipelines so you can build

examples/breast_cancer_csv.rs

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#![allow(clippy::needless_doctest_main)]
2+
//! Real-world breast cancer classification example.
3+
//!
4+
//! This example trains a multi-model comparison on the Wisconsin Diagnostic
5+
//! Breast Cancer dataset bundled with the repository. The workflow shows how to
6+
//! load a CSV file, wire up a preprocessing pipeline, and customize the
7+
//! algorithms that participate in the comparison.
8+
//!
9+
//! Run with:
10+
//!
11+
//! ```bash
12+
//! cargo run --example breast_cancer_csv
13+
//! ```
14+
15+
#[path = "../tests/fixtures/breast_cancer_dataset.rs"]
16+
mod breast_cancer_dataset;
17+
18+
use std::error::Error;
19+
20+
use automl::settings::{
21+
ClassificationSettings, FinalAlgorithm, PreprocessingPipeline, PreprocessingStep,
22+
RandomForestClassifierParameters, StandardizeParams,
23+
};
24+
use automl::{ClassificationModel, DenseMatrix};
25+
use breast_cancer_dataset::load_breast_cancer_dataset;
26+
27+
fn main() -> Result<(), Box<dyn Error>> {
28+
let (features, targets) = load_breast_cancer_dataset()?;
29+
30+
let preprocessing = PreprocessingPipeline::new()
31+
.add_step(PreprocessingStep::Standardize(StandardizeParams::default()));
32+
33+
let settings = ClassificationSettings::default()
34+
.with_number_of_folds(5)
35+
.shuffle_data(true)
36+
.with_final_model(FinalAlgorithm::Best)
37+
.with_preprocessing(preprocessing)
38+
.with_random_forest_classifier_settings(
39+
RandomForestClassifierParameters::default()
40+
.with_n_trees(200)
41+
.with_max_depth(8)
42+
.with_min_samples_split(4)
43+
.with_min_samples_leaf(2),
44+
);
45+
46+
let mut model = ClassificationModel::new(features, targets, settings);
47+
model.train()?;
48+
49+
println!("{model}");
50+
51+
let example_patient = DenseMatrix::from_2d_vec(&vec![vec![
52+
13.540, 14.360, 87.460, 566.300, 0.097, 0.052, 0.024, 0.015, 0.153, 0.055, 0.284, 0.915,
53+
2.376, 23.420, 0.005, 0.013, 0.010, 0.005, 0.018, 0.002, 14.230, 17.730, 91.760, 618.800,
54+
0.118, 0.115, 0.068, 0.025, 0.210, 0.062,
55+
]])?;
56+
let predictions = model.predict(example_patient)?;
57+
println!("Predicted class for the evaluation patient: {predictions:?}");
58+
59+
Ok(())
60+
}

examples/diabetes_regression.rs

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
#![allow(clippy::needless_doctest_main)]
2+
//! Real-world diabetes progression regression example.
3+
//!
4+
//! The diabetes dataset includes 10 physiological measurements for 442
5+
//! individuals. This example demonstrates how to configure a preprocessing
6+
//! pipeline, tighten algorithm hyperparameters, and evaluate the models via
7+
//! cross-validation before using the best regressor for inference.
8+
//!
9+
//! Run with:
10+
//!
11+
//! ```bash
12+
//! cargo run --example diabetes_regression
13+
//! ```
14+
15+
#[path = "../tests/fixtures/diabetes_dataset.rs"]
16+
mod diabetes_dataset;
17+
18+
use std::error::Error;
19+
20+
use automl::settings::{
21+
ColumnSelector, FinalAlgorithm, ImputeParams, ImputeStrategy, Kernel, PreprocessingPipeline,
22+
PreprocessingStep, RandomForestRegressorParameters, RegressionSettings, SVRParameters,
23+
ScaleParams, ScaleStrategy, StandardizeParams,
24+
};
25+
use automl::{DenseMatrix, RegressionModel};
26+
use diabetes_dataset::load_diabetes_dataset;
27+
28+
fn main() -> Result<(), Box<dyn Error>> {
29+
let (features, targets) = load_diabetes_dataset()?;
30+
31+
let preprocessing = PreprocessingPipeline::new()
32+
.add_step(PreprocessingStep::Impute(ImputeParams {
33+
strategy: ImputeStrategy::Median,
34+
selector: ColumnSelector::All,
35+
}))
36+
.add_step(PreprocessingStep::Scale(ScaleParams {
37+
selector: ColumnSelector::All,
38+
strategy: ScaleStrategy::Standard(StandardizeParams::default()),
39+
}));
40+
41+
let settings = RegressionSettings::default()
42+
.with_number_of_folds(8)
43+
.shuffle_data(true)
44+
.with_final_model(FinalAlgorithm::Best)
45+
.with_preprocessing(preprocessing)
46+
.with_random_forest_regressor_settings(
47+
RandomForestRegressorParameters::default()
48+
.with_n_trees(250)
49+
.with_max_depth(6)
50+
.with_min_samples_leaf(2)
51+
.with_min_samples_split(4),
52+
)
53+
.with_svr_settings(
54+
SVRParameters::default()
55+
.with_c(12.5)
56+
.with_eps(0.05)
57+
.with_kernel(Kernel::RBF(0.35)),
58+
);
59+
60+
let mut model = RegressionModel::new(features, targets, settings);
61+
model.train()?;
62+
63+
println!("{model}");
64+
65+
let evaluation_visit = DenseMatrix::from_2d_vec(&vec![vec![
66+
0.038_075_906,
67+
0.050_680_119,
68+
0.061_696_207,
69+
0.021_872_355,
70+
-0.044_223_498,
71+
-0.034_820_763,
72+
-0.043_400_846,
73+
-0.002_592_262,
74+
0.019_908_421,
75+
-0.017_646_125,
76+
]])?;
77+
let predicted_progression = model.predict(evaluation_visit)?;
78+
println!("Predicted disease progression: {predicted_progression:?}");
79+
80+
Ok(())
81+
}

src/cookbook.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,20 @@
1313
//! ```rust,ignore
1414
#![doc = include_str!("../examples/maximal_regression.rs")]
1515
//! ```
16+
//!
17+
//! ## Wisconsin Breast Cancer Classification
18+
//!
19+
//! Demonstrates loading data from `data/breast_cancer.csv`, standardizing every
20+
//! feature, and customizing the random forest hyperparameters before running the
21+
//! leaderboard comparison.
22+
//! ```rust,ignore
23+
#![doc = include_str!("../examples/breast_cancer_csv.rs")]
24+
//! ```
25+
//!
26+
//! ## Diabetes Progression Regression
27+
//!
28+
//! Shows how to impute, standardize, and tune regression algorithms on the
29+
//! diabetes dataset that ships with the repository.
30+
//! ```rust,ignore
31+
#![doc = include_str!("../examples/diabetes_regression.rs")]
32+
//! ```
tests/fixtures/breast_cancer_dataset.rs

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
use std::error::Error;
2+
use std::path::Path;
3+
4+
use csv::ReaderBuilder;
5+
use smartcore::linalg::basic::matrix::DenseMatrix;
6+
7+
type CsvRows = (Vec<Vec<f64>>, Vec<String>);
8+
type CsvResult = Result<CsvRows, Box<dyn Error>>;
9+
10+
/// Read a CSV file with a header row and split each record into numeric
/// features and a raw target string.
///
/// Every column except the last is parsed as a feature value; the last column
/// is returned unparsed so the caller can decide how to interpret it.
///
/// # Errors
///
/// Returns an error if the file cannot be opened, a record cannot be read, a
/// record has fewer than two columns, or a feature value fails to parse.
fn load_feature_rows<P: AsRef<Path>>(path: P) -> CsvResult {
    let mut csv_reader = ReaderBuilder::new().has_headers(true).from_path(path)?;
    let mut feature_rows = Vec::new();
    let mut raw_targets = Vec::new();

    for entry in csv_reader.records() {
        let entry = entry?;
        if entry.len() < 2 {
            return Err("dataset requires at least one feature and a target column".into());
        }
        // The target lives in the final column; everything before it is a feature.
        let target_index = entry.len() - 1;

        let parsed_features = entry
            .iter()
            .take(target_index)
            .map(str::parse::<f64>)
            .collect::<Result<Vec<_>, _>>()?;
        let raw_target = entry
            .get(target_index)
            .ok_or("dataset missing target column")?;

        feature_rows.push(parsed_features);
        raw_targets.push(raw_target.to_string());
    }

    Ok((feature_rows, raw_targets))
}
35+
36+
/// Convert a raw target cell into a binary class label.
///
/// Surrounding whitespace is ignored. The remaining text is parsed as `f64`
/// and mapped to `1` when it equals one, or `0` when it equals zero (both
/// within `f64::EPSILON`).
///
/// # Errors
///
/// Returns an error naming the offending cell when it is not a number or when
/// the number is neither zero nor one.
fn parse_label(raw: &str) -> Result<u32, Box<dyn Error>> {
    // CSV cells are not trimmed by default, so tolerate padded values like " 1".
    let numeric = raw
        .trim()
        .parse::<f64>()
        .map_err(|err| format!("invalid label {raw:?}: {err}"))?;

    if (numeric - 1.0).abs() < f64::EPSILON {
        Ok(1)
    } else if numeric.abs() < f64::EPSILON {
        Ok(0)
    } else {
        // Includes NaN, which compares false against both thresholds above.
        Err(format!("unexpected label {raw:?}: expected 0 or 1").into())
    }
}
46+
47+
/// Load the Wisconsin Diagnostic Breast Cancer dataset from `data/breast_cancer.csv`.
48+
///
49+
/// # Errors
50+
///
51+
/// Returns an error if the CSV file cannot be read or parsed into numeric data.
52+
pub fn load_breast_cancer_dataset() -> Result<(DenseMatrix<f64>, Vec<u32>), Box<dyn Error>> {
53+
let (feature_rows, raw_targets) = load_feature_rows("data/breast_cancer.csv")?;
54+
let features = DenseMatrix::from_2d_vec(&feature_rows)?;
55+
let targets = raw_targets
56+
.into_iter()
57+
.map(|value| parse_label(&value))
58+
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
59+
60+
Ok((features, targets))
61+
}

tests/fixtures/diabetes_dataset.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
use std::error::Error;
2+
use std::path::Path;
3+
4+
use csv::ReaderBuilder;
5+
use smartcore::linalg::basic::matrix::DenseMatrix;
6+
7+
type CsvRows = (Vec<Vec<f64>>, Vec<String>);
8+
type CsvResult = Result<CsvRows, Box<dyn Error>>;
9+
10+
/// Read a CSV file with a header row and split each record into numeric
/// features and a raw target string.
///
/// Every column except the last is parsed as a feature value; the last column
/// is returned unparsed so the caller can decide how to interpret it.
///
/// # Errors
///
/// Returns an error if the file cannot be opened, a record cannot be read, a
/// record has fewer than two columns, or a feature value fails to parse.
fn load_feature_rows<P: AsRef<Path>>(path: P) -> CsvResult {
    let mut csv_reader = ReaderBuilder::new().has_headers(true).from_path(path)?;
    let mut feature_rows = Vec::new();
    let mut raw_targets = Vec::new();

    for entry in csv_reader.records() {
        let entry = entry?;
        if entry.len() < 2 {
            return Err("dataset requires at least one feature and a target column".into());
        }
        // The target lives in the final column; everything before it is a feature.
        let target_index = entry.len() - 1;

        let parsed_features = entry
            .iter()
            .take(target_index)
            .map(str::parse::<f64>)
            .collect::<Result<Vec<_>, _>>()?;
        let raw_target = entry
            .get(target_index)
            .ok_or("dataset missing target column")?;

        feature_rows.push(parsed_features);
        raw_targets.push(raw_target.to_string());
    }

    Ok((feature_rows, raw_targets))
}
35+
36+
/// Load the diabetes progression dataset from `data/diabetes.csv`.
37+
///
38+
/// # Errors
39+
///
40+
/// Returns an error if the CSV file cannot be read or parsed into numeric data.
41+
pub fn load_diabetes_dataset() -> Result<(DenseMatrix<f64>, Vec<f64>), Box<dyn Error>> {
42+
let (feature_rows, raw_targets) = load_feature_rows("data/diabetes.csv")?;
43+
let features = DenseMatrix::from_2d_vec(&feature_rows)?;
44+
let targets = raw_targets
45+
.into_iter()
46+
.map(|value| -> Result<f64, Box<dyn Error>> { Ok(value.parse()?) })
47+
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
48+
49+
Ok((features, targets))
50+
}

tests/real_world_datasets.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#[path = "fixtures/breast_cancer_dataset.rs"]
2+
mod breast_cancer_dataset;
3+
#[path = "fixtures/diabetes_dataset.rs"]
4+
mod diabetes_dataset;
5+
6+
use breast_cancer_dataset::load_breast_cancer_dataset;
7+
use diabetes_dataset::load_diabetes_dataset;
8+
use smartcore::linalg::basic::arrays::Array;
9+
10+
/// The breast cancer fixture loads as 569 rows by 30 columns, with one label
/// per row and 212 labels equal to `1`.
#[test]
fn breast_cancer_dataset_has_expected_shape() {
    let (features, labels) = load_breast_cancer_dataset().expect("dataset should load");

    let (row_count, column_count) = features.shape();
    assert_eq!((row_count, column_count), (569, 30));
    assert_eq!(labels.len(), row_count);

    let positive_count = labels.iter().filter(|&&label| label == 1).count();
    assert_eq!(positive_count, 212);
}
20+
21+
/// The diabetes fixture loads as 442 rows by 10 columns, with one finite
/// target value per row.
#[test]
fn diabetes_dataset_has_expected_shape() {
    let (features, targets) = load_diabetes_dataset().expect("dataset should load");

    let (row_count, column_count) = features.shape();
    assert_eq!((row_count, column_count), (442, 10));
    assert_eq!(targets.len(), row_count);

    assert!(targets.iter().copied().all(f64::is_finite));
}

0 commit comments

Comments
 (0)