2 changes: 2 additions & 0 deletions avalanche/training/plugins/__init__.py
@@ -27,3 +27,5 @@
from .update_fecam import *
from .feature_distillation import *
from .il2m import IL2MPlugin
from .incdet_ewc import HuberEWCPlugin
from .pseudo_annotation import PseudoAnnotationPlugin
97 changes: 97 additions & 0 deletions avalanche/training/plugins/incdet_ewc.py
@@ -0,0 +1,97 @@
################################################################################
# Copyright (c) 2025.
# Copyrights licensed under the MIT License.
# See the accompanying LICENSE file for terms.
#
# Date: 2025-09-13
# Author(s): Muhammad Aniq, GPT-5 assistant
# Website: avalanche.continualai.org
################################################################################
from typing import Optional

import torch
import torch.nn.functional as F # noqa: F401 (kept for future extensions)

from avalanche.training.plugins.ewc import EWCPlugin


class HuberEWCPlugin(EWCPlugin):
"""
Elastic Weight Consolidation with Huber loss regularization.

This plugin extends the standard :class:`EWCPlugin` by replacing the
quadratic penalty with a Huber penalty to improve stability when the
regularization strength is large. The importance matrix computation and
state handling are inherited from :class:`EWCPlugin`.

The Huber loss is applied to the scaled parameter difference
z = sqrt(importance) * (theta - theta_old) with threshold ``beta``.
"""

def __init__(
self,
ewc_lambda: float,
*,
beta: float = 1.0,
mode: str = "separate",
decay_factor: Optional[float] = None,
keep_importance_data: bool = False,
):
super().__init__(
ewc_lambda=ewc_lambda,
mode=mode,
decay_factor=decay_factor,
keep_importance_data=keep_importance_data,
)
self.beta = float(beta)

@staticmethod
def _huber_sum(x: torch.Tensor, beta: float) -> torch.Tensor:
"""Element-wise Huber with threshold beta, summed over all elements."""
abs_x = x.abs()
quad = 0.5 * (x**2)
lin = beta * (abs_x - 0.5 * beta)
return torch.where(abs_x <= beta, quad, lin).sum()

def before_backward(self, strategy, **kwargs):
"""
Compute Huber-based EWC penalty and add it to strategy.loss.
"""
exp_counter = strategy.clock.train_exp_counter
if exp_counter == 0:
return

device = strategy.device
penalty = torch.tensor(0.0, device=device)

if self.mode == "separate":
for experience in range(exp_counter):
for k, cur_param in strategy.model.named_parameters():
if k not in self.saved_params[experience]:
continue
saved_param = self.saved_params[experience][k]
imp = self.importances[experience][k]
new_shape = cur_param.shape
delta = cur_param - saved_param.expand(new_shape)
# Scale by sqrt(importance)
scaled_delta = imp.expand(new_shape).sqrt() * delta
# Use custom huber to avoid version mismatches
penalty = penalty + self._huber_sum(scaled_delta, self.beta)
elif self.mode == "online":
prev_exp = exp_counter - 1
for k, cur_param in strategy.model.named_parameters():
if k not in self.saved_params[prev_exp]:
continue
saved_param = self.saved_params[prev_exp][k]
imp = self.importances[prev_exp][k]
new_shape = cur_param.shape
delta = cur_param - saved_param.expand(new_shape)
scaled_delta = imp.expand(new_shape).sqrt() * delta
penalty = penalty + self._huber_sum(scaled_delta, self.beta)
else:
raise ValueError("Wrong EWC mode.")

strategy.loss += self.ewc_lambda * penalty


__all__ = ["HuberEWCPlugin"]
95 changes: 95 additions & 0 deletions avalanche/training/plugins/pseudo_annotation.py
@@ -0,0 +1,95 @@
################################################################################
# Copyright (c) 2025.
# Copyrights licensed under the MIT License.
# See the accompanying LICENSE file for terms.
#
# Date: 2025-09-13
# Author(s): Muhammad Aniq, GPT-5 assistant
# Website: avalanche.continualai.org
################################################################################
from copy import deepcopy
from typing import List

import torch
from torch.utils.data import TensorDataset # noqa: F401 (kept for reference)

from avalanche.benchmarks.utils.classification_dataset import (
_make_taskaware_tensor_classification_dataset,
)
from avalanche.training.plugins.strategy_plugin import SupervisedPlugin


class PseudoAnnotationPlugin(SupervisedPlugin):
"""
Simple pseudo-annotation plugin for classification benchmarks.

Before training on a new experience, use the previous model to generate
high-confidence predictions on current data for classes seen so far and
append those pseudo-labeled samples to the adapted dataset.

Note: This is a minimal classification-oriented implementation (not object
detection). It assumes inputs are tensors after transforms and that the
benchmark exposes class timelines.
"""

def __init__(self, confidence_thresh: float = 0.9):
super().__init__()
self.confidence_thresh = float(confidence_thresh)
self._prev_model = None
self._seen_classes: set[int] = set()

def after_training_exp(self, strategy, **kwargs):
self._prev_model = deepcopy(strategy.model)
self._prev_model.to(strategy.device)
self._prev_model.eval()

if hasattr(strategy.experience, "classes_in_this_experience"):
self._seen_classes.update(
map(int, strategy.experience.classes_in_this_experience)
)

@torch.no_grad()
def after_train_dataset_adaptation(self, strategy, **kwargs):
if self._prev_model is None:
return

if not hasattr(strategy.experience, "classes_in_this_experience"):
return

new_classes = set(map(int, strategy.experience.classes_in_this_experience))
old_classes = sorted(list(self._seen_classes - new_classes))
if len(old_classes) == 0:
return

# Use the adapted dataset (already set to train transforms)
from avalanche.training.utils import load_all_dataset

assert strategy.adapted_dataset is not None
x, *rest = load_all_dataset(strategy.adapted_dataset)
x = x.to(strategy.device)

logits = self._prev_model(x)
probs = torch.softmax(logits, dim=1)

xs: List[torch.Tensor] = []
ys: List[int] = []
for idx in range(probs.shape[0]):
p = probs[idx]
for c in old_classes:
if c < p.numel() and float(p[c]) >= self.confidence_thresh:
xs.append(x[idx].detach().cpu())
ys.append(int(c))

if len(ys) == 0:
return

x_t = torch.stack(xs, dim=0)
y_t = torch.tensor(ys, dtype=torch.long)
pseudo_ds = _make_taskaware_tensor_classification_dataset(
x_t, y_t, task_labels=0
)

strategy.adapted_dataset = strategy.adapted_dataset.concat(pseudo_ds)


__all__ = ["PseudoAnnotationPlugin"]
1 change: 1 addition & 0 deletions docs/gitbook/SUMMARY.md
@@ -22,6 +22,7 @@
* [Contribute to Avalanche](from-zero-to-hero-tutorial/09_contribute-to-avalanche.md)

## How-Tos
* [IncDet‑EWC (Huber EWC + Pseudo‑Annotation)](how-tos/incdet_ewc.md)

* [AvalancheDataset](how-tos/avalanchedataset/README.md)
* [avalanche-datasets](how-tos/avalanchedataset/avalanche-datasets.md)
48 changes: 48 additions & 0 deletions docs/gitbook/how-tos/incdet_ewc.md
@@ -0,0 +1,48 @@
# IncDet‑EWC (Huber EWC + Pseudo‑Annotation)

## Paper summary
This paper presents IncDet, a framework that successfully adapts Elastic Weight Consolidation (EWC) to the task of incremental object detection. While EWC is effective in general incremental learning, it has previously been shown to fail when directly applied to object detection.

Through controlled experiments, the authors identify two core issues responsible for this failure:
1. Missing Old Class Annotations: When training on a new set of classes, images may contain objects from old classes that are not annotated. This causes the model to incorrectly learn to classify these old-class objects as background, leading to catastrophic forgetting.
2. Unstable Training: EWC's quadratic regularization loss can cause gradient explosion when its strength is raised to balance performance between old and new classes, which destabilizes training.

To address these problems, the paper proposes two corresponding solutions:
1. Pseudo Annotation: To compensate for missing labels, the old model is used to predict bounding boxes for old-class objects in the new training images. These "pseudo" annotations are then combined with the ground-truth annotations for the new classes, preventing the model from misclassifying old objects as background.
2. Huber Regularization: A novel Huber regularization loss is introduced to replace EWC's original quadratic loss. It adaptively clips the gradient for each parameter based on its importance to the old tasks, preventing gradient explosion and enabling stable training with a better trade-off between remembering old classes and learning new ones (a minimal sketch follows this list).
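
As a rough illustration, this mirrors what the `HuberEWCPlugin` added in this PR computes for each parameter tensor; the standalone helper below is illustrative, not part of the library API:

```python
import torch


def huber_ewc_penalty(theta, theta_old, importance, beta=0.5):
    # Scale the parameter drift by the square root of its importance.
    z = importance.sqrt() * (theta - theta_old)
    abs_z = z.abs()
    # Quadratic inside |z| <= beta, linear outside, so each element's gradient
    # contribution is capped at beta * sqrt(importance).
    quad = 0.5 * z**2
    lin = beta * (abs_z - 0.5 * beta)
    return torch.where(abs_z <= beta, quad, lin).sum()
```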

These solutions are integrated into the IncDet framework, a general and flexible pipeline for incremental object detection. The process involves:
- Initial Training: A base model is trained on an initial set of classes.
- Predict & Aggregate: The trained model generates pseudo-annotations for old classes on new images, which are then aggregated with the manual annotations for the new classes (see the sketch after this list).
- Incremental Fine-tuning: The model is fine-tuned using the combined annotations and the Huber regularization to learn the new classes while retaining knowledge of the old ones. This cycle can be executed recursively as more classes are added.
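
A minimal classification-style sketch of the predict-and-aggregate step, mirroring the loop inside the `PseudoAnnotationPlugin` added in this PR (the helper name and signature are illustrative only):

```python
import torch


@torch.no_grad()
def pseudo_label(prev_model, x, old_classes, thresh=0.95):
    """Collect (sample, old-class) pairs whose predicted probability clears the threshold."""
    probs = torch.softmax(prev_model(x), dim=1)
    xs, ys = [], []
    for i in range(probs.shape[0]):
        for c in old_classes:
            if c < probs.shape[1] and float(probs[i, c]) >= thresh:
                xs.append(x[i].cpu())
                ys.append(int(c))
    if not ys:
        return None  # no confident old-class predictions in this batch
    return torch.stack(xs), torch.tensor(ys, dtype=torch.long)
```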

The framework was implemented using both Fast R-CNN and Faster R-CNN, demonstrating its versatility. Experiments on the PASCAL VOC and COCO datasets show that IncDet achieves new state-of-the-art results, surpassing previous methods in both final performance and in minimizing the performance gap compared to joint training on all classes. The proposed method is also more computationally and memory-efficient during training compared to prior auxiliary-based approaches.

Use `HuberEWCPlugin` together with `PseudoAnnotationPlugin` to mitigate forgetting and to leverage the previous model's confident predictions on past classes. Note that these plugins are classification-oriented adaptations of the paper's ideas rather than a full object-detection pipeline.

## Quickstart (SplitMNIST)
```bash
python run_incdet_ewc.py --device auto --train_mb_size 64 --eval_mb_size 64 \
--train_epochs 1 --lr 0.01 --momentum 0.9 --ewc_lambda 1000.0 --beta 0.5 \
--confidence_thresh 0.95 --n_experiences 2 --seed 42
```

## CIFAR
```bash
python examples/incdet_ewc_cifar.py --benchmark cifar10 --use_huber 1 --train_mb_size 128 --train_epochs 2
```

## Minimal API
```python
from avalanche.training.plugins import HuberEWCPlugin, PseudoAnnotationPlugin

plugins = [
HuberEWCPlugin(ewc_lambda=1000.0, beta=0.5),
PseudoAnnotationPlugin(confidence_thresh=0.95),
]
# pass plugins=plugins into your Naive(...) strategy
```
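
For context, here is a minimal end-to-end sketch on SplitMNIST using the hyperparameters from the quickstart above; the model, optimizer, and strategy choices are assumptions, not requirements:

```python
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import SGD

from avalanche.benchmarks.classic import SplitMNIST
from avalanche.models import SimpleMLP
from avalanche.training.plugins import HuberEWCPlugin, PseudoAnnotationPlugin
from avalanche.training.supervised import Naive

benchmark = SplitMNIST(n_experiences=2, seed=42)
model = SimpleMLP(num_classes=10)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

strategy = Naive(
    model=model,
    optimizer=SGD(model.parameters(), lr=0.01, momentum=0.9),
    criterion=CrossEntropyLoss(),
    train_mb_size=64,
    train_epochs=1,
    eval_mb_size=64,
    device=device,
    plugins=[
        HuberEWCPlugin(ewc_lambda=1000.0, beta=0.5),
        PseudoAnnotationPlugin(confidence_thresh=0.95),
    ],
)

for experience in benchmark.train_stream:
    strategy.train(experience)
    strategy.eval(benchmark.test_stream)
```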

## References
- IncDet: https://ieeexplore.ieee.org/document/9127478 — DOI: https://doi.org/10.1109/TNNLS.2020.3002583
- EWC: https://www.pnas.org/doi/10.1073/pnas.1611835114
87 changes: 87 additions & 0 deletions examples/incdet_ewc_cifar.py
@@ -0,0 +1,87 @@
import argparse
import torch
from torch.optim import SGD
from torch.nn import CrossEntropyLoss

from avalanche.benchmarks.classic.ccifar10 import SplitCIFAR10
from avalanche.benchmarks.classic.ccifar100 import SplitCIFAR100
from avalanche.models import SimpleCNN
from avalanche.training.supervised.strategy_wrappers import Naive
from avalanche.training.plugins import EvaluationPlugin
from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics
from avalanche.logging import InteractiveLogger

from avalanche.training.plugins import HuberEWCPlugin, EWCPlugin


def _resolve_device():
if torch.backends.mps.is_available():
return torch.device("mps")
return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def run(benchmark_name: str, use_huber: bool, device, mb_size: int, epochs: int):
if benchmark_name == "cifar10":
bench = SplitCIFAR10(n_experiences=5, seed=42)
num_classes = 10
elif benchmark_name == "cifar100":
bench = SplitCIFAR100(n_experiences=10, seed=42)
num_classes = 100
else:
raise ValueError("benchmark_name must be 'cifar10' or 'cifar100'")

    # SimpleCNN is a small CNN that accepts CIFAR-sized (3x32x32) inputs
model = SimpleCNN(num_classes=num_classes)
optimizer = SGD(model.parameters(), lr=0.05, momentum=0.9)
criterion = CrossEntropyLoss()

eval_plugin = EvaluationPlugin(
accuracy_metrics(epoch=True, experience=True, stream=True),
loss_metrics(epoch=True, experience=True, stream=True),
loggers=[InteractiveLogger()],
)

plugin = (
HuberEWCPlugin(ewc_lambda=50.0, beta=0.5)
if use_huber
else EWCPlugin(ewc_lambda=50.0)
)

strategy = Naive(
model=model,
optimizer=optimizer,
criterion=criterion,
train_mb_size=mb_size,
train_epochs=epochs,
eval_mb_size=mb_size,
device=device,
plugins=[plugin],
evaluator=eval_plugin,
)

for exp in bench.train_stream:
print(f"Start of experience {exp.current_experience}")
strategy.train(exp, num_workers=2, pin_memory=False)
print("End of experience", exp.current_experience)
print("Evaluation on test stream")
strategy.eval(bench.test_stream, num_workers=2, pin_memory=False)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--benchmark", type=str, default="cifar10", choices=["cifar10", "cifar100"]
)
parser.add_argument("--use_huber", type=int, default=1)
parser.add_argument("--train_mb_size", type=int, default=128)
parser.add_argument("--train_epochs", type=int, default=2)
args = parser.parse_args()

device = _resolve_device()
run(
args.benchmark,
bool(args.use_huber),
device,
args.train_mb_size,
args.train_epochs,
)