
YankaiGroup/RADDT

 
 

Differentiable Decision Tree via "ReLU+Argmin" Reformulation

RADDT ("ReLU+Argmin"-based Differentiable Decision Tree) is designed to learn oblique decision trees by optimizing the entire tree structure with gradient-based optimization. The algorithm introduces a novel unconstrained reformulation of decision tree training that uses a "ReLU+Argmin" formulation, followed by a softmin approximation, to overcome the inherent non-differentiability of decision trees. It supports both regression and classification tasks. For details on the algorithm, please refer to our Spotlight 🏴 paper published at NeurIPS 2025, available at https://neurips.cc/virtual/2025/poster/119074.
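The sketch below illustrates the "ReLU+Argmin" idea in plain Python under our own simplifying assumptions; it is not the implementation in RADDT.py. We assume each branch node t holds an oblique split (a_t, b_t) and that a sample goes left when a_t·x - b_t <= 0. For every leaf, a score sums the ReLU of each split violated along that leaf's root-to-leaf path, so the leaf the sample actually reaches scores 0 and an argmin over leaves recovers hard routing. Replacing the argmin with a softmin at temperature `tau` yields smooth leaf weights that gradients can flow through. The depth-2 layout and the names `PATHS`, `leaf_scores`, `hard_leaf` and `softmin` are hypothetical.

```python
import math


def relu(z):
    return max(0.0, z)


def dot(a, x):
    return sum(ai * xi for ai, xi in zip(a, x))


# Hypothetical depth-2 tree, breadth-first numbering: node 0 is the root,
# nodes 1 and 2 its children. Each leaf lists (ancestor node, goes_left) pairs.
PATHS = [
    [(0, True), (1, True)],    # leaf 0
    [(0, True), (1, False)],   # leaf 1
    [(0, False), (2, True)],   # leaf 2
    [(0, False), (2, False)],  # leaf 3
]


def leaf_scores(x, A, b, paths=PATHS):
    """Sum of ReLU split violations along each root-to-leaf path."""
    scores = []
    for path in paths:
        s = 0.0
        for node, goes_left in path:
            margin = dot(A[node], x) - b[node]  # left branch requires margin <= 0
            s += relu(margin) if goes_left else relu(-margin)
        scores.append(s)
    return scores


def hard_leaf(x, A, b):
    """Hard routing as an argmin over leaf scores (ties go to the lower index)."""
    scores = leaf_scores(x, A, b)
    return min(range(len(scores)), key=scores.__getitem__)


def softmin(scores, tau):
    """Differentiable surrogate of argmin: weights proportional to exp(-s / tau)."""
    m = min(scores)  # shift by the minimum for numerical stability
    w = [math.exp(-(s - m) / tau) for s in scores]
    total = sum(w)
    return [wi / total for wi in w]
```

For example, with `A = [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]`, `b = [0.5, 0.5, 0.5]` and `x = [0.2, 0.9]`, the scores are roughly `[0.4, 0.0, 0.7, 0.3]`, so `hard_leaf` returns leaf 1 and `softmin(scores, 0.01)` places nearly all of its weight on that leaf.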


Repository Overview

This repository contains the source code for the RADDT algorithm. To support various use cases and hardware configurations, the codebase is organized into the following branches:

  • regression: Designed for regression tasks. Supports single-GPU or CPU execution.

  • classification: Designed for classification tasks. Supports single-GPU or CPU execution.

  • multi-gpu: Designed for regression tasks, adapted for distributed multi-GPU training using the Distributed Data Parallel (DDP) strategy.

Current Branch: regression

You are currently on the regression branch.

  • Supported Task: Regression
  • Hardware: Single-GPU or CPU

We are currently working on wrapping the code into a more organized and user-friendly interface, such as a pip package. Stay tuned for updates!

If you need specific features not included in the current release, please feel free to contact us for access to additional code :)

Requirements

  • PyTorch 2.0.1
  • Python 3.9.6
  • scikit-learn 1.0.2
  • numpy 1.21.6
  • pandas 1.3.5
  • h5py 3.8.0
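To compare a local environment against the versions listed above, a small standard-library script can query installed distributions through `importlib.metadata`. This is a convenience sketch, not part of the repository; `installed_versions`, `compare`, `REQUIRED` and `REQUIRED_PYTHON` are hypothetical names, and it treats only an exact version match (ignoring local tags such as "+cu117") as matching.

```python
import sys
from importlib import metadata

# Versions from the Requirements list, keyed by PyPI distribution name.
REQUIRED = {
    "torch": "2.0.1",
    "scikit-learn": "1.0.2",
    "numpy": "1.21.6",
    "pandas": "1.3.5",
    "h5py": "3.8.0",
}
REQUIRED_PYTHON = (3, 9, 6)


def installed_versions(names):
    """Return {name: version} for the distributions that are installed."""
    found = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            pass  # not installed; simply omitted from the result
    return found


def compare(installed, required=REQUIRED):
    """Map each required package to (installed version or None, exact match?)."""
    report = {}
    for name, wanted in required.items():
        have = installed.get(name)
        base = have.split("+")[0] if have else None  # drop local tags like "+cu117"
        report[name] = (have, base == wanted)
    return report


if __name__ == "__main__":
    py = tuple(sys.version_info[:3])
    print(f"Python {'.'.join(map(str, py))} (listed 3.9.6, match={py == REQUIRED_PYTHON})")
    for name, (have, ok) in compare(installed_versions(REQUIRED)).items():
        print(f"{name}: installed={have}, listed={REQUIRED[name]}, match={ok}")
```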

Script Description

  • src folder contains the scripts for the "ReLU+Argmin"-based Differentiable Decision Tree optimization of oblique trees, termed RADDT.

    • ancestorTF_File subfolder contains the deterministic tree path routing (i.e., the root-to-leaf paths used to assign samples to leaves) in HDF5 (.h5) format. These files can be generated with the treePathCalculation function in the treeFunc.py script.
    • treeFunc.py includes utility functions.
    • dataset.py loads the datasets.
    • warmStart.py generates the warm-start initialization based on the CART method.
    • modifiedScheduler.py implements a learning rate scheduler with an initial linear warmup. We acknowledge the original contributor of this type of scheduler, as posted in PyTorch issue #80308.
    • RADDT.py includes the main functions.
  • test folder contains the script for running RADDT.

    • test_RADDT.py tests the RADDT method.
  • data folder contains the datasets. These publicly available datasets can be obtained from the UCI Machine Learning Repository and OpenML. Each dataset is shuffled, split into training, validation and testing sets, and saved in *.csv format.

  • sh_narval_MultiGPU folder contains a job submission script for running experiments on the "Narval" cluster of Compute Canada. This version uses the Distributed Data Parallel (DDP) strategy with multiple GPUs (e.g., eight GPUs in the example).
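The ancestorTF_File bullet above describes precomputed tree path routing. As a rough sketch of what such routing data could contain (the actual layout produced by treePathCalculation may differ), the function below builds two boolean matrices for a complete binary tree of a given depth: entry [l][t] marks whether branch node t lies on leaf l's path through its left or its right branch, with breadth-first node numbering. `ancestor_matrices`, `save_h5` and the dataset names "left"/"right" are hypothetical; `save_h5` relies on h5py and numpy from the Requirements list and imports them only when called.

```python
def ancestor_matrices(depth):
    """Return (left, right) boolean matrices of shape (n_leaves, n_branch_nodes).

    left[l][t] is True when branch node t is an ancestor of leaf l and the path
    to l takes t's left branch; right[l][t] likewise for the right branch.
    Nodes use breadth-first numbering: the children of t are 2t+1 and 2t+2.
    """
    n_branch = 2 ** depth - 1
    n_leaves = 2 ** depth
    left = [[False] * n_branch for _ in range(n_leaves)]
    right = [[False] * n_branch for _ in range(n_leaves)]
    for leaf in range(n_leaves):
        node = n_branch + leaf  # position of this leaf in the full tree
        while node > 0:  # walk up to the root, recording each turn
            parent = (node - 1) // 2
            if node == 2 * parent + 1:
                left[leaf][parent] = True
            else:
                right[leaf][parent] = True
            node = parent
    return left, right


def save_h5(path, depth):
    """Store both matrices in an HDF5 file (hypothetical dataset names)."""
    import h5py  # optional dependencies, imported only when saving
    import numpy as np

    left, right = ancestor_matrices(depth)
    with h5py.File(path, "w") as f:
        f.create_dataset("left", data=np.array(left, dtype=bool))
        f.create_dataset("right", data=np.array(right, dtype=bool))
```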
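For warmStart.py, one common way to turn a CART tree into an oblique warm start, assumed here rather than taken from the repository, is to encode each axis-aligned split "x[f] <= threshold" as an oblique split with a one-hot weight vector a = e_f and bias b = threshold, placed into a complete tree in breadth-first order. The sketch consumes the arrays scikit-learn exposes on a fitted tree (`children_left`, `children_right`, `feature`, `threshold`, where leaves have `children_left == -1`). Branch nodes that CART never split keep a = 0 and b = 0, so a·x <= b always holds and samples pass to the left child. `cart_to_oblique` is a hypothetical name.

```python
TREE_LEAF = -1  # scikit-learn marks leaf nodes with children_left == -1


def cart_to_oblique(children_left, children_right, feature, threshold, n_features, depth):
    """Map a fitted CART tree onto a complete depth-`depth` tree of oblique splits.

    Returns (A, b): A[t] is the weight vector and b[t] the bias of branch node t,
    with breadth-first numbering (children of t are 2t+1 and 2t+2) and the rule
    "go left when A[t].x <= b[t]", matching CART's "x[f] <= threshold".
    """
    n_branch = 2 ** depth - 1
    A = [[0.0] * n_features for _ in range(n_branch)]  # default a = 0 ...
    b = [0.0] * n_branch                               # ... and b = 0: always left

    def visit(cart_node, bfs_node):
        # Stop past the target depth or at a CART leaf (descendants keep defaults).
        if bfs_node >= n_branch or children_left[cart_node] == TREE_LEAF:
            return
        A[bfs_node][feature[cart_node]] = 1.0  # one-hot weight on the split feature
        b[bfs_node] = float(threshold[cart_node])
        visit(children_left[cart_node], 2 * bfs_node + 1)
        visit(children_right[cart_node], 2 * bfs_node + 2)

    visit(0, 0)
    return A, b
```

With scikit-learn, the inputs would come from `t = DecisionTreeRegressor(max_depth=D).fit(X, y).tree_`, passing `t.children_left, t.children_right, t.feature, t.threshold, X.shape[1], D`.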
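modifiedScheduler.py is described as a learning rate scheduler with an initial linear warmup. A minimal sketch of such a warmup multiplier follows; what the repository's scheduler does after the warmup is not described here, so this version simply holds the base rate constant afterwards (an assumption). `warmup_factor` and `lr_schedule` are hypothetical helpers.

```python
def warmup_factor(step, warmup_steps):
    """Multiplier on the base learning rate: linear ramp to 1, then constant."""
    if warmup_steps <= 0:
        return 1.0
    if step < warmup_steps:
        return (step + 1) / warmup_steps  # nonzero from the very first step
    return 1.0


def lr_schedule(base_lr, warmup_steps, n_steps):
    """Learning rate at each of the first n_steps optimizer steps."""
    return [base_lr * warmup_factor(s, warmup_steps) for s in range(n_steps)]
```

In PyTorch, such a multiplier can be passed to `torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: warmup_factor(step, 100))`, calling `scheduler.step()` once per optimizer step.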
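The data bullet above states that each dataset is shuffled and split into training, validation and testing sets saved as CSV files. A small standard-library sketch of that preprocessing follows; the split fractions, seed and file names are placeholders, not the proportions used for the released datasets, and `shuffle_split` and `write_csv` are hypothetical helpers.

```python
import csv
import random


def shuffle_split(rows, val_frac, test_frac, seed=0):
    """Shuffle rows reproducibly and split them into (train, val, test)."""
    rows = list(rows)
    random.Random(seed).shuffle(rows)  # seeded, so the split is repeatable
    n_test = int(round(len(rows) * test_frac))
    n_val = int(round(len(rows) * val_frac))
    test = rows[:n_test]
    val = rows[n_test:n_test + n_val]
    train = rows[n_test + n_val:]
    return train, val, test


def write_csv(path, header, rows):
    """Write a header line followed by the rows to a CSV file."""
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        writer.writerows(rows)
```

For instance, `train, val, test = shuffle_split(rows, val_frac=0.25, test_frac=0.25, seed=0)` followed by `write_csv("train.csv", header, train)` and likewise for the other two sets.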

We release the source code in several branches. The scripts are nearly identical, except that the multi-gpu branch is adapted for multi-GPU training using the Distributed Data Parallel (DDP) strategy. Scripts modified for this purpose are distinguished by a "_DDP" suffix.
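Under DDP, each process typically trains on its own shard of the data. The sketch reproduces, to our understanding, how PyTorch's `torch.utils.data.DistributedSampler` partitions indices with `shuffle=False` and `drop_last=False`: the index list is padded by wrapping around to a multiple of the world size, then rank r takes every world_size-th index starting at r. Whether the *_DDP scripts use DistributedSampler is an assumption, and `shard_indices` is a hypothetical helper.

```python
import math


def shard_indices(n_samples, world_size, rank):
    """Indices seen by one rank; padding gives every rank the same count."""
    per_rank = math.ceil(n_samples / world_size)
    total = per_rank * world_size
    indices = list(range(n_samples))
    # Pad by repeating indices from the start until the length reaches `total`.
    while len(indices) < total:
        indices += indices[: total - len(indices)]
    return indices[rank:total:world_size]
```

With 10 samples on 4 ranks, each rank gets 3 indices, and two samples are seen twice per epoch because of the padding.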

Usage Example

Examples for single-GPU or CPU computing can be run via:

# test the RADDT method (on Linux/macOS, use test/test_RADDT.py as the path)
python .\test\test_RADDT.py 3 3 1 1 2 3000 "cuda" 10 5
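The command above uses Windows path separators and requests "cuda". A small cross-platform launcher sketch follows. It keeps the example's positional arguments unchanged (their meaning is defined in test/test_RADDT.py and is not restated here) and only swaps the device slot. It assumes the script accepts "cpu" as a device string for CPU runs, which we infer from the stated CPU support rather than from the script itself. `pick_device`, `build_command` and `run_example` are hypothetical.

```python
import os
import subprocess
import sys

# Positional arguments copied from the example; None marks the device slot.
EXAMPLE_ARGS = ["3", "3", "1", "1", "2", "3000", None, "10", "5"]


def pick_device():
    """Return "cuda" when PyTorch sees a GPU, otherwise "cpu" (assumed accepted)."""
    try:
        import torch
    except ImportError:
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


def build_command(device):
    """Build the argv list with a platform-neutral script path."""
    args = [device if a is None else a for a in EXAMPLE_ARGS]
    return [sys.executable, os.path.join("test", "test_RADDT.py"), *args]


def run_example(device=None):
    """Run the example from the repository root; raises if the script fails."""
    subprocess.run(build_command(device or pick_device()), check=True)
```

Calling `run_example()` from the repository root picks the device automatically, while `run_example("cuda")` forces the original GPU setting.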

For distributed multi-GPU computing, please refer to the job submission script in the sh_narval_MultiGPU folder of the multi-gpu branch for an example of running on the "Narval" cluster of Compute Canada.

Others

If you encounter any errors or notice unexpected tree performance, please don't hesitate to contact us (maoq@student.ubc.ca).

License

This repository is published under the terms of the GNU General Public License v3.0.

About

"ReLU+Argmin"-based Differentiable Decision Tree (Spotlight Paper at NeurIPS 2025)
