Trait Discovery Logbook
An ongoing logbook for project management.
- We can use sparse autoencoders (SAEs) to discover precise semantic concepts in vision models. (prior work, but recent)
- Using a more specific dataset to train the SAE leads to more specific semantic concepts (known in LLMs, but not formally studied in SAEs for vision models)
- We train an SAE on a large dataset of butterfly images and show that it discovers precise, distinct traits in butterfly images (new, but not exciting yet)
- We use this SAE to find traits that are present in Heliconius erato but NOT Heliconius melpomene (these traits are known to taxonomists)
- We use this SAE to find traits that are present in Heliconius XX but NOT Heliconius YY (these traits are NOT known to taxonomists and are a viable scientific discovery)
Here are the experiments we need to run:
- Train a ReLU SAE on DINOv2/ImageNet-1K activations.
- Train a ReLU SAE on DINOv2/butterfly activations.
- Inspect visual features for both SAEs.
- Filter SAE features based on Heliconius melpomene/erato images.
- Verify whether the features match the known melpomene/erato diagnostic traits.
- Repeat the filtering and verification steps on XX/YY to find a new trait.
I'm sure this will work on the first try :)
Some individual tasks:
- [Jake] Train a ReLU SAE on DINOv2/ImageNet-1K activations. This can be done right away because we know it works.
- [Jake] Train a ReLU SAE on DINOv2/iNat21 activations. This can be done right away because we know it works.
- [Sam] Create one or more datasets of butterflies for training. These can contain images from iNat21, ToL-10M or ToL-250M.
- Train a ReLU SAE on DINOv2/Butterfly activations. This depends on the previous bullet point.
- [Sam] Get a list of butterfly mimic pairs, with and without diagnostic traits. We need Neil, Dan and/or Christopher for this.
- [Sam & Jake] Formalize the trait discovery task as a computational problem. Describe it mathematically.
- Write a script to evaluate SAEs on the butterfly mimic pairs on the task of trait discovery.
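For the formalization task above, one possible starting point (just a sketch; the max-over-patches pooling and the threshold $\tau$ are assumptions, not decisions):

Let the SAE give feature activations $f_1, \dots, f_S$, and for an image $x$ with patch activations $h_1(x), \dots, h_P(x)$ define the image-level strength of feature $j$ as

$$
a_j(x) = \max_{1 \le p \le P} f_j\big(h_p(x)\big).
$$

Given image sets $X_A$ (e.g. H. erato) and $X_B$ (e.g. H. melpomene), one candidate score for "present in A but not B" is the gap in firing frequency,

$$
\Delta_j = \frac{1}{|X_A|} \sum_{x \in X_A} \mathbf{1}\big[a_j(x) > \tau\big] \;-\; \frac{1}{|X_B|} \sum_{x \in X_B} \mathbf{1}\big[a_j(x) > \tau\big],
$$

and trait discovery would mean returning the top features by $\Delta_j$ that are not just whole-species detectors. That last condition is exactly the specificity problem and isn't captured by $\Delta_j$ alone.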
Jake and I met:
We learned that:
- H. melpomene vs erato is our biggest specimen image pair.
- H. melpomene vs elevatus is limited by elevatus, which has only 56 images.
- We need other sources of data for evaluation. And they might not be specimen images. Rather, can we identify traits using species-labeled in-situ images?
Discussion around computational framing:
- What if the SAE discovers an "erato" feature?
- Do we have to fine-tune on classification?
- We want traits that are well-describable, and we want "erato"-specific traits that are not just the "entire erato".
- We have this notion of specificity that's important for a trait.
- How do we convince CS/AI people that "two ecologists we're friends with thought this was cool and that's sufficient evaluation"?
Comimic erato/melpomene subspecies of interest:
| H. erato subspecies | Count | H. melpomene subspecies | Count | Total |
|---|---|---|---|---|
| H. erato lativita | 1827 | H. melpomene malleti | 898 | 2725 |
| H. erato cybria | 1030 | H. melpomene cythera | 55 | 1085 |
| H. erato notabilis | 310 | H. melpomene plesseni | 251 | 561 |
| H. erato hydara | 140 | H. melpomene melpomene | 237 | 377 |
| H. erato venus | 208 | H. melpomene vulcanus | 103 | 311 |
| H. erato demophoon | 221 | H. melpomene rosina | 64 | 285 |
| H. erato phyllis | 194 | H. melpomene nanna | 77 | 271 |
| H. erato erato | 33 | H. melpomene thelxiopeia | 1 | 34 |
| H. erato favorinus | 0 | H. melpomene amaryllis | 33 | 33 |
| H. erato almafreda | 3 | H. melpomene meriana | 18 | 21 |
| H. erato etylus | 3 | H. melpomene ecuadorensis | 14 | 17 |
| H. erato emma | 0 | H. melpomene algaope | 2 | 2 |
| H. erato microclea | 0 | H. melpomene xenoclea | 0 | 0 |
Todo:
- Identify remaining comimic pairs (see H. melpomene vs H. elevatus)
- Train SAE on activations of ImageNet (Jake)
- Generate iNat activations
- Train SAE on iNat activations
- Generate butterfly activations (Jake)
- Train SAE on butterfly activations (Jake)
Trait‑Discovery Report—Key Takeaways (Mr Gippity)
- Granularity: how narrowly an interpretability map localizes signal.
Two trait-agnostic metrics:
- Entropy (H): pixelwise Shannon entropy of the normalized heat‑map. ↓H ⇒ narrower, more specific.
- Spatial autocorrelation (Moran’s I or similar): measures local coherence. ↑I ⇒ structured (edges, stripes, ordered dots).
| Scenario | Entropy | Autocorr. | Interpretation |
|---|---|---|---|
| Tight, structured feature (edge, dot) | Low | High | High‑quality granular trait |
| Tight, random blob | Low | Low | Localized noise / artifact |
| Large, structured pattern (broad band) | High | High | Coarse but meaningful |
| Diffuse noise | High | Low | Worst case |
Together they:
- Remain label‑free (trait discovery friendly).
- Penalize both random noise and overly diffuse heat‑maps.
Limitations:
- Label dependence via prediction: If “quality” = predictive power, the system chases any correlates (e.g., lighting, camera artifacts).
- Resolution / scale sensitivity: H and I change with image size & preprocessing.
- Trait interactions ignored: metrics score single heat‑maps; biology often uses combinations.
- Missing causal grounding: high score ≠ causally relevant trait.
- Dynamic traits: age/seasonal changes can lower scores even when biologically critical.
- Human‑interpretability gap: high‑ranked traits may be hard to explain biologically.
- Coarse traits usually drive species/sex prediction; fine traits differentiate individuals / cryptic species.
- Over‑focusing on granularity can hurt broad‑class accuracy; ignoring it misses subtle discoveries. A hierarchical blend is ideal.
Score = α · PredictivePower + β · (–Entropy) + γ · Autocorr.
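A rough sketch of how H, I, and the combined score could be computed for a single heat-map. The 4-neighbour adjacency for Moran's I and the default weights are assumptions, not settled choices:

```python
import numpy as np


def heatmap_entropy(heatmap: np.ndarray) -> float:
    """Pixelwise Shannon entropy of a non-negative heat-map (lower = more localized)."""
    p = np.clip(heatmap, 0, None).astype(np.float64).ravel()
    p = p / (p.sum() + 1e-12)
    p = p[p > 0]
    return float(-(p * np.log(p)).sum())


def morans_i(heatmap: np.ndarray) -> float:
    """Moran's I with binary 4-neighbour weights (higher = more spatially coherent)."""
    z = heatmap.astype(np.float64) - heatmap.mean()
    # Sum of products between each pixel and its right/down neighbours (each pair counted once).
    num = (z[:-1, :] * z[1:, :]).sum() + (z[:, :-1] * z[:, 1:]).sum()
    n_pairs = z[:-1, :].size + z[:, :-1].size
    return float((z.size / n_pairs) * num / ((z**2).sum() + 1e-12))


def trait_score(heatmap, predictive_power, alpha=1.0, beta=1.0, gamma=1.0):
    """Score = alpha * PredictivePower + beta * (-Entropy) + gamma * Autocorr."""
    return (
        alpha * predictive_power
        - beta * heatmap_entropy(heatmap)
        + gamma * morans_i(heatmap)
    )
```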
Fundamental risks
- Confirmation bias: optimizes for known labels, suppressing novel unlabeled traits.
- Artifact capture: camera/system quirks can score high.
- No guarantee of biological meaning: high score may still be irrelevant.
Bottom line: useful as a filter, not as sole arbiter of trait quality.
- Keep H & I as quick, universal filters.
- Add unsupervised diversity objective (e.g., novelty search) to escape label bias.
- Periodically benchmark against expert‑annotated traits to calibrate weights.
- Inspect high‑scoring heat‑maps manually/with LLM‑judge on a small batch to sanity‑check biological plausibility.
- Store multi‑scale versions of heat‑maps to reduce resolution sensitivity.
Use the entropy + autocorrelation duo for fast triage, but pair it with unsupervised exploration and intermittent expert validation to ensure genuinely new, biologically meaningful trait discovery.
Quick notes from reading Matryoshka SAE paper:
- Using matryoshka SAEs for this project is likely a better way to go. It may help to prevent the feature splitting/absorption that is a big problem with the current basic SAE architecture.
- Performing a similar linear regression probing experiment for feature absorption is also probably a really good idea. It may be a good way to quantify the performance of these models for diagnostic trait recovery and may be a good way to compare our model's performance to existing models.
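A very rough sketch of what that probing comparison could look like for one binary trait: fit a probe (logistic regression here, rather than literal linear regression) on the ViT activations and check whether any single SAE decoder direction lines up with it. All names are placeholders, and this is only one of several ways the feature-absorption work sets up the measurement:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression


def probe_alignment(acts_ND: np.ndarray, y_N: np.ndarray, decoder_SD: np.ndarray) -> float:
    """Fit a probe for one binary trait, then return the best cosine similarity
    between the probe direction and any SAE decoder row.

    High alignment suggests a single latent cleanly represents the trait; low
    alignment (despite a well-performing probe) hints that the trait is split or
    absorbed across latents.
    """
    probe = LogisticRegression(max_iter=1000).fit(acts_ND, y_N)
    w = probe.coef_[0]
    w = w / np.linalg.norm(w)
    dec = decoder_SD / np.linalg.norm(decoder_SD, axis=1, keepdims=True)
    return float(np.max(dec @ w))
```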
Plus, a general note:
- Now that we have more time, we should consider expanding our experiments, both within butterflies and to other datasets. Within butterflies, we may want to consider an even more restricted experiment just focusing on the two most common co-mimic subspecies in the Cambridge dataset. Outside of butterflies, we could consider looking into some other sources of data that have diagnostic traits, like birds.
We likely want a test set of some kind to compare different baselines. I think CUB-200-2011 will be the de facto recommendation from reviewers, but I think this applies equally well to any image-level trait dataset.
We want to measure how precise different methods are at identifying a trait, and how many traits are recovered above some minimum precision threshold.
CUB-200-2011 has 6K images in the test set, each annotated with 312 different binary traits, like "has blue feathers" or "has red eyes". We want to find an SAE feature for each trait. We also should compare some baselines like
- Dot product with random activations (random directions in the embedding space)
- Dot product with principal components (PCA) of the activation training data.
- Dot product with k-means centroids, fit on the activation training data.
- Linear probes trained specifically for binary classification on the activations. This should represent an upper bound on precision.
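The pseudocode below covers the SAE, PCA, and k-means scoring; for the other two baselines, a minimal sketch could look like this. acts_ND and y_N stand in for an (N, D) matrix of image-level ViT activations and one trait's 0/1 labels; the sizes are placeholders:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
N, D, n_random = 1000, 768, 2048  # placeholder sizes
acts_ND = rng.standard_normal((N, D))  # stand-in for real image-level ViT activations
y_N = rng.integers(0, 2, size=N)       # stand-in for one binary trait's labels

# Random-direction baseline: score each image by its dot product with random unit vectors.
rand_dirs_KD = rng.standard_normal((n_random, D))
rand_dirs_KD /= np.linalg.norm(rand_dirs_KD, axis=1, keepdims=True)
rand_scores_NK = acts_ND @ rand_dirs_KD.T

# Linear-probe upper bound: one probe per trait, images ranked by the probe's logit.
probe = LogisticRegression(max_iter=1000).fit(acts_ND, y_N)
probe_scores_N = probe.decision_function(acts_ND)
```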
For each trait and each candidate feature (vector), we pick the best feature by measuring precision@k (this k is unrelated to the k in k-means): take the top-k images as ranked by that feature and count how many actually contain the trait, using the trait labels. Let me make this concrete.
Broadly, we want to get a matrix scores with one row per image and one column per feature, where scores[i, j] is how strongly feature j activates on image i.
Once we do, we can take the top k images for a given feature f with np.argsort(scores[:, f])[-k:]. Then we can see, for a given trait t, how many of those images actually have trait t. That's precision@k for trait t and feature f.
- Mean precision@k: What is the mean precision@k across all traits for each method?
- Aggregate recall: How many traits have at least one feature with precision@k > 0.8 for each method?
Some pseudocode below.
```python
# D = ViT residual dimension
# S = Sparse dimension (number of SAE features)
# N = Number of images in dataset.
# P = Number of patches per image.
import einops
import numpy as np
from jaxtyping import Float, Int
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

import saev

vit_dataset = saev.activations.Dataset(...)
# all_patches: (N * P, D) matrix of ViT patch activations gathered from vit_dataset.
all_patches = ...

# SAE
sae = saev.nn.load(...)
# PCA baseline
pca = PCA(n_components=2048).fit(all_patches)
# k-Means baseline
kmeans = KMeans(n_clusters=2048).fit(all_patches)
# Entry:
#
#   scores[i, j] = a_{u_j}(x_i)
#
# is the scalar ``activation strength'' of direction u_j on image x_i.
sae_scores = []
pca_scores = []
km_scores = []
for vit_acts_PD in vit_dataset:
    # SAE score: max activation of each SAE feature over this image's patches.
    sae_scores.append(einops.reduce(
        sae.encode(vit_acts_PD),
        "patches d_sae -> d_sae",
        reduction="max",
    ))
    # Might have to L2 normalize these vectors (both PCA and KMeans).
    pca_scores.append(einops.reduce(
        einops.einsum(
            vit_acts_PD,
            pca.components_,
            "patches dim, k dim -> patches k",
        ),
        "patches k -> k",
        reduction="max",
    ))
    km_scores.append(einops.reduce(
        einops.einsum(
            vit_acts_PD,
            kmeans.cluster_centers_,
            "patches dim, k dim -> patches k",
        ),
        "patches k -> k",
        reduction="max",
    ))
sae_scores_NS = np.stack(sae_scores, axis=0)
pca_scores_NS = np.stack(pca_scores, axis=0)
km_scores_NS = np.stack(km_scores, axis=0)
def evaluate(
    scores: Float[np.ndarray, "N M"],
    y_true: Int[np.ndarray, " N"],
    k: int = 100,
) -> tuple[float, int | None]:
    """Find the feature with the best precision@k for a given trait.

    Args:
        scores: A ranking table; each feature (M) assigns a scalar value to each image (N).
        y_true: 0/1 labels for a given trait (attribute) for all images.
        k: How many top-ranked images to check.

    Returns:
        The best precision@k and the index of the feature achieving it.
    """
    N, M = scores.shape
    best_prec, best_j = 0.0, None
    for j in range(M):
        # Rank images by this feature's score and check the top k against the labels.
        topk = np.argsort(scores[:, j])[-k:]
        prec = y_true[topk].mean()
        if prec > best_prec:
            best_prec, best_j = prec, j
    return best_prec, best_j


# Measure precision for each attribute in CUB-200-2011 using evaluate().
```
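To get the two summary numbers above (mean precision@k and aggregate recall), a short sketch on top of evaluate(); attrs_NT is assumed to be the (N, 312) binary attribute matrix from CUB's annotations:

```python
def summarize(scores_NM, attrs_NT, k: int = 100, threshold: float = 0.8):
    """Mean best precision@k across traits, plus how many traits have at least
    one feature with precision@k above the threshold (aggregate recall)."""
    precs = np.array([
        evaluate(scores_NM, attrs_NT[:, t], k=k)[0]
        for t in range(attrs_NT.shape[1])
    ])
    return precs.mean(), int((precs > threshold).sum())


# Compare methods, e.g.:
# summarize(sae_scores_NS, attrs_NT)
# summarize(pca_scores_NS, attrs_NT)
# summarize(km_scores_NS, attrs_NT)
```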