|
| 1 | +#![allow(clippy::needless_doctest_main)] |
| 2 | +//! Real-world breast cancer classification example. |
| 3 | +//! |
| 4 | +//! This example trains a multi-model comparison on the Wisconsin Diagnostic |
| 5 | +//! Breast Cancer dataset bundled with the repository. The workflow shows how to |
| 6 | +//! load a CSV file, wire up a preprocessing pipeline, and customize the |
| 7 | +//! algorithms that participate in the comparison. |
| 8 | +//! |
| 9 | +//! Run with: |
| 10 | +//! |
| 11 | +//! ```bash |
| 12 | +//! cargo run --example breast_cancer_csv |
| 13 | +//! ``` |
| 14 | +
|
| 15 | +#[path = "../tests/fixtures/breast_cancer_dataset.rs"] |
| 16 | +mod breast_cancer_dataset; |
| 17 | + |
| 18 | +use std::error::Error; |
| 19 | + |
| 20 | +use automl::settings::{ |
| 21 | + ClassificationSettings, FinalAlgorithm, PreprocessingPipeline, PreprocessingStep, |
| 22 | + RandomForestClassifierParameters, StandardizeParams, |
| 23 | +}; |
| 24 | +use automl::{ClassificationModel, DenseMatrix}; |
| 25 | +use breast_cancer_dataset::load_breast_cancer_dataset; |
| 26 | + |
| 27 | +fn main() -> Result<(), Box<dyn Error>> { |
| 28 | + let (features, targets) = load_breast_cancer_dataset()?; |
| 29 | + |
| 30 | + let preprocessing = PreprocessingPipeline::new() |
| 31 | + .add_step(PreprocessingStep::Standardize(StandardizeParams::default())); |
| 32 | + |
| 33 | + let settings = ClassificationSettings::default() |
| 34 | + .with_number_of_folds(5) |
| 35 | + .shuffle_data(true) |
| 36 | + .with_final_model(FinalAlgorithm::Best) |
| 37 | + .with_preprocessing(preprocessing) |
| 38 | + .with_random_forest_classifier_settings( |
| 39 | + RandomForestClassifierParameters::default() |
| 40 | + .with_n_trees(200) |
| 41 | + .with_max_depth(8) |
| 42 | + .with_min_samples_split(4) |
| 43 | + .with_min_samples_leaf(2), |
| 44 | + ); |
| 45 | + |
| 46 | + let mut model = ClassificationModel::new(features, targets, settings); |
| 47 | + model.train()?; |
| 48 | + |
| 49 | + println!("{model}"); |
| 50 | + |
| 51 | + let example_patient = DenseMatrix::from_2d_vec(&vec![vec![ |
| 52 | + 13.540, 14.360, 87.460, 566.300, 0.097, 0.052, 0.024, 0.015, 0.153, 0.055, 0.284, 0.915, |
| 53 | + 2.376, 23.420, 0.005, 0.013, 0.010, 0.005, 0.018, 0.002, 14.230, 17.730, 91.760, 618.800, |
| 54 | + 0.118, 0.115, 0.068, 0.025, 0.210, 0.062, |
| 55 | + ]])?; |
| 56 | + let predictions = model.predict(example_patient)?; |
| 57 | + println!("Predicted class for the evaluation patient: {predictions:?}"); |
| 58 | + |
| 59 | + Ok(()) |
| 60 | +} |
0 commit comments