We propose LoRA-MCL, a training scheme that extends next-token prediction in language models with a method designed to decode diverse, plausible sentence continuations at inference time. Traditional language modeling is an intrinsically ill-posed problem: given a context, multiple futures may be equally plausible. Our approach leverages Multiple Choice Learning (MCL) and the Winner-Takes-All loss to efficiently handle ambiguity through Low-Rank Adaptation. We provide a theoretical interpretation of applying MCL to language modeling, assuming the data is generated from a mixture of distributions. To illustrate the proposed approach, we use data sampled from mixtures of Markov chains. We then demonstrate with experiments on visual and audio captioning, as well as machine translation, that our method achieves high diversity and relevance in generated outputs.
Overview of LoRA-MCL. A linear layer with LoRA enabled is shown. Frozen base weights are in blue; trainable LoRA adapters are in light red. The forward pass (in gray) is computed independently for each hypothesis. Gradients (purple arrows) are stronger for the winning hypothesis compared to the others.
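To make the objective concrete, the Winner-Takes-All mechanics can be sketched in a few lines. The snippet below shows one common formulation of the relaxed-WTA weighting (the winner receives weight 1 - epsilon, the remaining K - 1 hypotheses share epsilon) and of the annealed temperature schedule; both are illustrative sketches and may differ in detail from the loss implemented in this repository.

```python
import torch

def relaxed_wta_loss(hyp_losses: torch.Tensor, epsilon: float = 0.05) -> torch.Tensor:
    """Relaxed Winner-Takes-All over per-hypothesis losses.

    hyp_losses has shape [batch, K], holding e.g. the sequence NLL of each
    of the K hypotheses. The winner (lowest loss) receives weight
    1 - epsilon; the other K - 1 hypotheses share epsilon uniformly.
    epsilon = 0 recovers hard WTA.
    """
    batch, K = hyp_losses.shape
    winners = hyp_losses.argmin(dim=1)                        # [batch]
    weights = torch.full_like(hyp_losses, epsilon / (K - 1))
    weights[torch.arange(batch), winners] = 1.0 - epsilon
    return (weights * hyp_losses).sum(dim=1).mean()

def annealed_temperature(t: int, ini_temp: float = 1.0,
                         decay_rate: float = 0.999, fin_temp: float = 1e-6) -> float:
    """Annealed-WTA schedule: temperature(t) = temperature(0) * rho**t.

    Below fin_temp the mode switches back to hard WTA, represented
    here by a temperature of 0.0.
    """
    temp = ini_temp * decay_rate ** t
    return temp if temp > fin_temp else 0.0

losses = torch.tensor([[1.0, 2.0, 3.0]])  # one sample, K = 3 hypotheses
relaxed_wta_loss(losses, epsilon=0.05)    # 0.95*1 + 0.025*2 + 0.025*3 ≈ 1.075
annealed_temperature(20000)               # 0.0 (back to hard WTA)
```

With epsilon > 0, every hypothesis keeps a small gradient signal, which mitigates hypothesis collapse while the winner still dominates.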
The repository is organized as follows:

- General-Purpose LoRA-MCL Wrapper. The peft_mcl/ directory provides a general wrapper for applying LoRA-MCL to any Hugging Face model with minimal configuration.
- Synthetic Data Experiments. The toy/ directory contains experiments using Mixtures of Markov Chains for synthetic data evaluation.
- Audio Captioning with Qwen2-Audio. The Qwen2-Audio/ directory includes all code related to the audio captioning experiments based on Qwen2-Audio.
- Image Captioning with LLaVA. To be released soon.
- Diverse Machine Translation with ALMA. The ALMA/ directory provides code for diverse machine translation experiments with ALMA.

The tests/ directory is used for running tests (not exhaustive).
We provide a general module named peft_mcl that is designed to be integrated into any model from the transformers library.
⚠️ Important: You need a Hugging Face access token for most models. Set it with export HF_TOKEN=your_token_here before running the examples.
The peft_mcl module should work with any Hugging Face model from the transformers library with minimal configuration changes.
To use the peft_mcl package, first clone the repository and create a conda environment:
git clone https://github.com/Victorletzelter/LoRA-MCL.git
cd LoRA-MCL
conda create -y -n testenv python=3.10.15
conda activate testenv
pip install -e .

See the example below to get started, or run example_usage.py for a demo with dummy data.
from peft_mcl import get_peft_mcl, MCLTrainer, patch_peft_for_mcl
from peft import LoraConfig
import torch
patch_peft_for_mcl(enable=True) # Must be called before using peft_mcl, so that the PEFT library is properly patched.
# Standard LoRA configuration
lora_r = 16 # LoRA Rank
lora_alpha = 16 # LoRA Alpha
lora_dropout = 0.1 # LoRA Dropout (during training)
target_modules = ["q_proj", "k_proj", "v_proj", "down_proj", "up_proj"] # Modules where LoRA is enabled
lora_config = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    target_modules=target_modules,
    lora_dropout=lora_dropout,
    task_type="CAUSAL_LM",
)
mcl_params = { # LoRA-MCL related parameters.
    'num_hyps': 3, # Number of hypotheses (K)
    'wta_training_mode': "relaxed-wta", # "wta", "relaxed-wta" or "annealed-wta"
    'use_group_lora': False, # Set to True for faster training: Group LoRA parallelizes over the hypotheses
    'wta_params_epsilon': 0.05, # epsilon, used when wta_training_mode == 'relaxed-wta'
    # Parameters only used when wta_training_mode == 'annealed-wta':
    'wta_params_ini_temp': 1.0, # Initial temperature
    'wta_params_fin_temp': 1e-6, # Final temperature (below it, the mode switches back to hard WTA)
    'wta_params_decay_rate': 0.999, # Decay rate rho, where temperature(t) = temperature(0) * rho**t
    'wta_params_schedule_mode': "global_step", # "global_step" if t is the training step, "epoch_number" if t is the epoch number
}
loading_kwargs = { # Loading kwargs passed to the from_pretrained method.
    "low_cpu_mem_usage": True,
    "torch_dtype": torch.bfloat16,
}
model_name = "gpt2" # Any HuggingFace model
# MCL Configuration
model = get_peft_mcl(
    model_name_or_path=model_name, # Any Hugging Face model
    lora_config=lora_config,
    **mcl_params
)

Training can be performed as:
from transformers import TrainingArguments, Trainer, AutoTokenizer
train_config = TrainingArguments(...) # Define your training arguments
tokenizer = AutoTokenizer.from_pretrained(model_name) # Auto Tokenizer of the pretrained model.
train_dataset = ... # Your training dataset
eval_dataset = ... # Your validation dataset
data_collator = ... # Your data collator
trainer = MCLTrainer( # MCLTrainer is required if wta_training_mode == 'annealed-wta'; otherwise the vanilla HF Trainer can be used
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    args=train_config,
    data_collator=data_collator,
)
trainer.train() # Launch training

Generation can be done with the generate method by specifying the hypothesis to evaluate:
import torch
input_ids = ... # Input (tokenized) prompt
hypothesis_idx = 0 # Hypothesis chosen for generation (should be in {0,...,num_hypotheses - 1})
out = model.generate(inputs=input_ids, hypothesis_idx=hypothesis_idx)

To reproduce the synthetic data experiments, first go to the toy directory with cd toy.
Create a virtual environment with conda create -y -n synthenv python=3.10.15. Activate with conda activate synthenv and install the required packages with pip install -r toy/requirements.txt.
📝 Note: LaTeX is required for plot rendering. Install it with:
sudo apt-get install -y dvipng texlive-latex-extra texlive-fonts-recommended cm-super.
To reproduce Figure 1 and Figure 4, please run the training with the following commands:
python train.py experiment=transition1
python train.py experiment=transition2

Please check transition1.yaml to see the overridden parameters. You may want to investigate the following parameters in this experiment:

- batch_size: the batch size used.
- seq_len: the sequence length.
- N_it: the number of training iterations.
- window_size: the size of the window in sliding-window attention (used to improve the convergence of transformers on Markov data).
- seed: the random seed.
The results will be logged in results/pkl_files and results/runs. More precisely, the runs in the pkl_files will be organized as follows:
toy
└── results
    ├── pkl_files
    │   └── <run_folder_name> # By default <start_run_time>_<wta_training_mode>_seqlen_<seq_len>_Nit_<number of iterations>_N_hyps_<number of hypotheses min>_<number of hypotheses max>_seed_<seed>_name_<experiment name>_window_<window_size>_div_rank_<div rank boolean>
    │       ├── config.yaml # Config associated with the run
    │       ├── results.pkl # pkl file containing the model weights
    │       ├── training_loss_vs_training_steps.png # (if do_save_plot is True) training loss against training steps (displayed with LaTeX rendering); corresponds to the curves in Figures 1 and 4 (left)
    │       ├── transition_matrices_comparison_grid_weighted.png # (if do_save_plot is True) 2x3 grid containing the predicted and target matrices at the end of training; corresponds to the right subplot in Figures 1 and 4
    │       ├── transition_matrices_comparison_n_hyps_1_weighted # Folder containing the evolution of the transition matrices (predicted and target, for the 1-hypothesis model) during training
    │       └── transition_matrices_comparison_n_hyps_2 # Folder containing the evolution of the transition matrices (predicted and target, for the 2-hypothesis model) during training
    └── runs
        └── <run_folder_name> # By default <date>_<start_run_time>_<machine name>, a folder which contains the training tensorboard logs

The generated figure should be as follows (see Figures 1 and 4 from the main paper).
Comparison of LoRA-MCL with vanilla maximum likelihood estimation (MLE) (Figure 1)
This section is intended for reproducing the results on Audio Captioning. To follow the instructions, please go to the Qwen2-Audio directory with cd Qwen2-Audio. This code was tested on Ubuntu 20.04.6.
If you use conda, you can set up a virtual environment with

bash setup_env.sh

An environment named qwen_env will be set up in Qwen2-Audio/env. You can then activate it with PROJECT_DIR="$(pwd)"; source activate $PROJECT_DIR/env.
Note that the implementation of LoRA-MCL in Qwen2-Audio was done without using peft_mcl, but instead with a local copy of the transformers library in Qwen2-Audio/local_transformers. Although it is not fully tested yet, you can try using peft_mcl by running pip install transformers==4.51.0 and setting use_local_transformers=False in conf/model/default.yaml.
For the purpose of metrics computation, download the metrics-related cache data (~7.6GB) with

python download/download_cache.py

and set your Hugging Face token with export HF_TOKEN=....
The preprocessing on AudioCaps and Clotho was done following the pipeline in the Conette repository. Note that, unlike Conette (which considers a setup where the audio encoder is frozen and both the spectrogram and embeddings are computed offline, using a CNN in their case), our setup only preprocesses the raw audio, with only the following steps:
- Resampling the audio from 44.1 kHz (Clotho) and 32 kHz (AudioCaps) to 16 kHz. We used 16kHz as target sampling rate to be consistent with the Whisper implementation used in Qwen Audio.
- If the data contains multiple channels, averaging across channels to obtain a single-channel audio signal, as done in the Conette implementation.
- Packing the data into HDF5 files.
This is done through the get_resample_mean_raw function of Qwen2-Audio/utils/conette/full_conette/conette-audio-captioning/src/conette/transforms/get.py. We had to implement this function in our cloned Conette repository because, in the Qwen2-Audio pipeline, the spectrogram computation is done online.
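For illustration, the two preprocessing steps above (channel averaging, then resampling) can be sketched with scipy; this is an assumption-laden mirror, not the repository's get_resample_mean_raw implementation, and the helper name resample_mean is ours:

```python
from math import gcd

import numpy as np
from scipy.signal import resample_poly

def resample_mean(waveform: np.ndarray, src_sr: int, target_sr: int = 16000) -> np.ndarray:
    """Sketch of the preprocessing: average channels to mono, then
    resample to target_sr via polyphase filtering (hypothetical helper;
    the repository packs the result into HDF5 files afterwards)."""
    if waveform.ndim == 2:                 # [channels, samples] -> mono [samples]
        waveform = waveform.mean(axis=0)
    g = gcd(target_sr, src_sr)
    return resample_poly(waveform, target_sr // g, src_sr // g)

clotho_clip = np.zeros((2, 44100))         # 1 s of stereo audio at 44.1 kHz
resample_mean(clotho_clip, src_sr=44100).shape  # (16000,)
```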
To download the (preprocessed) Audio Captioning datasets (AudioCaps ~32GB and Clotho-V2 ~11GB), please run:
python download/download_data.py

The datasets will be placed in Qwen2-Audio/data.
💡 Tip: We recommend using our preprocessed data rather than reproducing the preprocessing steps, unless you specifically need to modify the preprocessing pipeline.
If you instead want to download your own version of AudioCaps and Clotho and reproduce the preprocessing, first install the required packages with sudo apt install -y openjdk-11-jdk ffmpeg zip && python3 -m pip install -U yt-dlp[default]. You can then run:
CONETTE_DIR=LoRA-MCL/Qwen2-Audio/utils/conette/full_conette/conette-audio-captioning
common_args="data.download=true pack_to_hdf=true audio_t=resample_mean_raw post_hdf_name=none pretag=resample_mean_raw data.n_workers=0"
python ${CONETTE_DIR}/src/conette/prepare.py data=audiocaps audio_t.src_sr=32000 ${common_args} # AudioCaps source sampling rate is 32kHz
python ${CONETTE_DIR}/src/conette/prepare.py data=clotho audio_t.src_sr=44100 ${common_args} # Clotho source sampling rate is 44.1kHz

These commands will first download the datasets as .flac files in ${CONETTE_DIR}/src/conette/data/{DATASET}. Then, the datasets will be packed into hdf files in ${CONETTE_DIR}/src/conette/data/HDF. Once the hdf files are created, move them to LoRA-MCL/Qwen2-Audio/data/ with:
mv ${CONETTE_DIR}/src/conette/data/HDF/* data/

To train all LoRA-MLE and LoRA-MCL models (with 5 hypotheses) on audio captioning, run:
bash train_loop.sh

Individual training scripts can be launched with commands of the form:
bash train_scripts.sh dataset_name num_hyps wta_mode epsilon rank seed use_slurm

where:

- dataset_name: dataset name (audiocaps or clotho)
- num_hyps: number of hypotheses
- wta_mode: WTA training mode (wta, relaxed-wta or annealed-wta)
- epsilon: ε parameter when wta_mode=relaxed-wta
- rank: LoRA rank
- seed: random seed
- use_slurm: whether the scripts should be submitted using slurm
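As a concrete illustration, a single run on Clotho with 5 hypotheses, relaxed WTA (ε = 0.05), LoRA rank 16, seed 0, and no slurm submission might look like the following; the argument values (and the boolean format for use_slurm) are assumptions, so check the script before running:

```shell
bash train_scripts.sh clotho 5 relaxed-wta 0.05 16 0 false
```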
By default, fine-tuning runs for 10 epochs on Clotho and 1 epoch on AudioCaps, with maximum audio lengths of 30 seconds and 10 seconds, respectively. See train_script.sh and the scripts/ folder for more details.
When launching the trainings, the logs will be saved in Qwen2-Audio/logs following the Hydra template, organized as follows:
tsExperiments
└── logs
    └── <dataset_name>
        └── <run_folder_name> # By default: <start_run_time>_<wta_training_mode>_epsilon-<epsilon_value>_<num_hyps>-hyp_rank-<rank>_dataset_name, where start_run_time is in the form %Y-%m-%d_%H-%M-%S
            ├── .hydra # Folder with the config yaml files associated with the run
            ├── adapter_model # Folder where the adapters are saved; these correspond to the best adapters on the validation set. By default, it contains the folders lora{k} for k in {0,...,num_hyps - 1}. Note that when using a Group LoRA architecture (which parallelizes over the hypotheses), the parameters of all adapters are stored in lora0.
            └── events.out.tf.events.*** # A tensorboard event file

By default, MLFlow is enabled and will save experiment files in Qwen2-Audio/logs/mlflow.
To perform inference, first extract the checkpoint paths with
python extract/extract_ckpts.py --log_dir logs

A json file ckpts.json containing the checkpoint paths will then be created in Qwen2-Audio/.
Note that if you want to run inference and evaluation with our own checkpoints, you can download them (~3.45GB) in Qwen2-Audio/ckpts with
python download/download_ckpts.py

In this case, run the path extraction script with the correct directory:

python extract/extract_ckpts.py --log_dir ckpts

You can then create environment variables with the checkpoint paths:
python extract/generate_checkpoint_variables.py --input ckpts.json --output checkpoint_vars.sh --format export
. checkpoint_vars.sh

Please refer to the generated checkpoint_vars.sh for an overview of the environment variables created in the current shell.
You can now run the evaluation script by defining:
N_HYPS=5 # number of hypotheses to use for evaluation
DATASET=audiocaps # can be clotho or audiocaps

and running:

bash eval_scripts.sh $N_HYPS $DATASET

This script submits a slurm job for each configuration by running an sbatch command through the scripts/qwen_eval.slurm slurm script.
The inference includes the following decoding methods for each training scheme:
- LoRA-MLE with Beam Search (BS) and Diverse Beam Search (DBS) decoding, with λ = 0.5, 0.8, 1.0 and beam sizes B = 5, 10 and 25 (if lora_mle=True in the script).
- LoRA-MLE with Test-time augmentation (TTA) and BS decoding with B = 1, 2 and 5 (if tta=True).
- LoRA-MoE with BS and DBS decoding, with λ = 0.5, 0.8, 1.0 and B = 5, 10 and 25 (if tta=True in the script).
- LoRA-MCL with BS decoding with B = 1, 2 and 5. This is performed for the relaxed variant if lora_mcl_relaxed=True in the script and with the annealed variant if lora_mcl_annealed=True.

Note that you can perform sampling-based decoding for each method by setting do_sampling=True in the script.
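For reference, one of the DBS settings listed above maps onto Hugging Face generate keyword arguments roughly as follows; the number of beam groups is our assumption (it is not specified above), and num_beam_groups must divide num_beams:

```python
# Illustrative DBS configuration with beam size B = 10 and diversity penalty lambda = 0.5,
# to be passed as model.generate(**inputs, **dbs_generation_kwargs).
dbs_generation_kwargs = dict(
    num_beams=10,             # B
    num_beam_groups=5,        # assumption: 5 groups of 2 beams
    diversity_penalty=0.5,    # lambda
    num_return_sequences=10,
)
```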
The instructions in the Inference and Evaluation section compute evaluation metrics by default
(if cfg.do_compute_metrics=True in main_hydra.py).
Metrics computation is performed using the Conette, AAC-Metrics and COCO-Caption libraries.
Custom adaptations of the AAC-Metrics library are located in metrics/aac_metrics_custom/. These modifications enable sentence-based oracle metrics computation, since standard captioning metrics do not support multiple hypotheses.
The computation is performed using the mh_evaluate function (defined in aac_metrics_custom/functional/evaluate.py), which is called within the MH_AllMetrics class
(defined in mh_all_metrics.py).
Below is a simplified pseudo-code illustrating the computation for each metric i (e.g., SPIDEr):
# Let i be the index of a metric (e.g., SPIDEr)
for key in candidates: # key ranges over '0', ..., str(n_hypotheses - 1); each candidates[key] is a List[str].
    # Compute the sentence-based metric for each hypothesis:
    _, outs_sents_i[key] = metric(candidates[key], mult_references)
# Extract the oracle (best) sentence-level metric for each sentence and compute the average.
mean_oracle_metric_i = extract_oracle_sentence_metrics(
    outs_sents=outs_sents_i, # Sentence-based results per hypothesis
    higher_is_better=higher_is_better_i # Whether higher values indicate better performance
)

The function used to extract oracle scores is defined as follows (simplified for clarity):
def extract_oracle_sentence_metrics(
    outs_sents: Dict[str, Tensor],
    higher_is_better: bool
):
    """
    Extract the oracle-based sentence-level metric from the outs_sents dictionary.

    Args:
        outs_sents: Dictionary containing sentence-level metrics for each hypothesis.
            Expected keys: '0', '1', ..., str(num_hypotheses - 1).
            Each outs_sents[key] is a list or tensor of float values.
        higher_is_better: Whether higher metric values indicate better performance.

    Returns:
        The mean oracle sentence-level metric over all evaluation examples.
    """
    N_sentences = len(outs_sents['0']) # Number of examples in the evaluation set
    # Tensor of shape [N_sentences, N_hypotheses]
    metric_values = torch.stack(
        [outs_sents[hypothesis_idx] for hypothesis_idx in outs_sents.keys()],
        dim=1
    )
    # Compute the oracle value for each sentence
    if higher_is_better:
        individual_oracle_sents = torch.max(metric_values, dim=1).values
    else:
        individual_oracle_sents = torch.min(metric_values, dim=1).values
    return torch.mean(individual_oracle_sents).item()

The results can be extracted as csv files by running
bash extract/extract_results.sh

The results will be stored as csv files in results/saved_csv.
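The oracle computation in extract_oracle_sentence_metrics above can be sanity-checked with a small dependency-free mirror; the helper name oracle_mean is ours, and plain Python lists stand in for tensors:

```python
def oracle_mean(outs_sents, higher_is_better):
    """Pick the best hypothesis score per sentence, then average.

    Mirrors the torch-based extract_oracle_sentence_metrics, using
    plain lists of floats instead of tensors.
    """
    pick = max if higher_is_better else min
    per_hyp = list(outs_sents.values())                   # one score list per hypothesis
    per_sentence = [pick(scores) for scores in zip(*per_hyp)]
    return sum(per_sentence) / len(per_sentence)

scores = {'0': [0.50, 0.20], '1': [0.30, 0.60]}           # 2 hypotheses, 2 sentences
oracle_mean(scores, higher_is_better=True)                # (0.50 + 0.60) / 2 = 0.55
oracle_mean(scores, higher_is_better=False)               # (0.30 + 0.20) / 2 = 0.25
```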
We provide a script to extract LaTeX tables and plot figures. LaTeX is required for plot rendering; install it with sudo apt-get install -y dvipng texlive-latex-extra texlive-fonts-recommended cm-super. The script can be executed with

python extract/extract_metrics_csv.py

This will generate a txt file in results/latex containing the LaTeX command for the table of quantitative results, and a plot displaying the results in results/figure.
Our image captioning codebase is built upon the excellent LLaVA repository. This code was tested with Ubuntu 22.04. It will be released soon.
This section provides instructions for reproducing the results on Machine Translation. To follow the steps, navigate to the ALMA directory with cd ALMA. The codebase is built upon the excellent ALMA repository. This code was tested with Ubuntu 22.04.
Create an environment with
conda create -n alma python=3.11
conda activate alma

and install the required packages:

bash install_alma.sh

To launch the fine-tuning, both for a 1-hypothesis model and for a 3-hypotheses model, please run:
bash train_1h.sh # For the 1-hypothesis model with rank 16 and 48
bash train_3h.sh # For the 3-hypotheses model with rank 16.

The logs will be written in logs and will have the following structure:
logs
└── ${N_HYPS}h_rank_${RANK} # Logs associated with the N_HYPS-hypotheses model with rank $RANK
    ├── best_ckpt # Best checkpoint according to the validation loss
    ├── checkpoint-X # Other checkpoints stored during training at step / epoch X
    └── tensorboard # Folder with tensorboard logs

Once the trainings have finished, you can run the inference and metrics computation with:
CKPT_FOLDER=logs
bash eval_1h.sh $CKPT_FOLDER # For the 1-hypothesis model with rank 16 and 48
bash eval_3h.sh $CKPT_FOLDER # For the 3-hypotheses model with rank 16.

This will generate an additional folder in the above structure, containing the predictions as well as metrics computed over the full test dataset (when the number of test samples is -1, as explained below).
logs
└── ${N_HYPS}h_rank_${RANK} # Logs associated with the N_HYPS-hypotheses model with rank $RANK
    ├── ...
    └── ${DATASET}-${N_TEST}samples-NR${NR} # Outputs when testing on $DATASET with N_TEST evaluation samples and num_return_sequences=${NR} for generation (num_return_sequences=1 for multi-hypotheses models). If N_TEST=-1, evaluation is performed on the full dataset.
        ├── ${DATASET}-${N_TEST}samples-B${BEAM}_Div${DIV_PEN}_NR${NR}_test.yaml # yaml file with the metrics when using beam size ${BEAM}, diversity_penalty=${DIV_PEN} and num_return_sequences=${NR}
        ├── ${DATASET}-${N_TEST}samples-B${BEAM}_Div${DIV_PEN}_NR${NR}_test.pkl # pkl file containing the predictions (same naming convention as above)
        └── ${DATASET}-${N_TEST}samples-B${BEAM}_Div${DIV_PEN}-NR${NR}.txt # txt file containing the predictions (same naming convention as above)

To run inference and evaluation with our checkpoints, first download the adapter weights (~542MB) with
python download/download_ckpts.py

Then run the above commands with CKPT_FOLDER=ckpts instead:
CKPT_FOLDER=ckpts
bash eval_1h.sh $CKPT_FOLDER # For the 1-hypothesis model with rank 16 and 48
bash eval_3h.sh $CKPT_FOLDER # For the 3-hypotheses model with rank 16.

To extract the results, first install LaTeX with sudo apt-get install -y dvipng texlive-latex-extra texlive-fonts-recommended cm-super and run:
python extract_results.py

This will write a txt file containing the LaTeX table, as well as the figure displaying the results in the (Diversity, Quality) plane (in png and pdf format), to results/.
This work was funded by the French Association for Technological Research (ANRT CIFRE contract 2022-1854) and the LISTEN Laboratory of Télécom Paris. It also benefited from access to the HPC resources of IDRIS (allocation 2024-AD011014345) by GENCI.
This repository contains source code adapted from the following GitHub repositories, for which we greatly thank the authors:

- pytorch-lightning (Apache 2.0 License)
- Hydra (MIT License)
- transformers (Apache 2.0 License) and peft (Apache 2.0 License)
- Conette (adapted in Qwen2-Audio/utils/conette/full_conette/conette-audio-captioning/) and aac-metrics (MIT License)
- CocoCaption (BSD License)
- LLaVA (adapted in llava16/) (Apache 2.0 License)
- TextCaps from MMF (BSD License)
- ALMA (adapted in ALMA/) (MIT License)
- fairseq (MIT License)
- analyzing-uncertainty-nmt (CC BY-NC 4.0 License)
If our work helped in your research, feel free to give us a star ⭐ or to cite us with the following BibTeX entry:
@article{loramcl,
title={Multiple Choice Learning of Low Rank Adapters for Language Modeling},
author={Letzelter, Victor and Malard, Hugo and Fontaine, Mathieu and Richard, Ga{\"e}l and Essid, Slim and Bursuc, Andrei and P{\'e}rez, Patrick},
journal={arXiv preprint arXiv:2507.10419},
year={2025}
}