143 changes: 116 additions & 27 deletions dadapy/diff_imbalance.py
@@ -129,6 +129,9 @@ class DiffImbalance:
seed (int): seed of JAX random generator, default is 0. Different seeds determine different mini-batch
partitions.
l1_strength (float): strength of the L1 regularization (LASSO) term. Default is 0.
gradient_clip_value (float): maximum global norm for gradient clipping. If 0, no clipping is
applied. Default is 0. Clipping is useful when exploding gradients would otherwise
drive the weights to NaN.
point_adapt_lambda (bool): whether to use a global smoothing parameter lambda for the c_ij coefficients
in the DII (if False), or a different parameter for each point (if True). Default is True.
k_init (int): initial rank of neighbors used to set lambda. Ranks are defined starting from 1. If
@@ -180,6 +183,7 @@ def __init__(
learning_rate=1e-2,
learning_rate_decay=None,
num_points_rows=None,
gradient_clip_value=0.0,
):
"""Initialise the DiffImbalance class."""
self.nfeatures_A = data_A.shape[1]
@@ -258,6 +262,7 @@ def __init__(
self.num_epochs = num_epochs
self.batches_per_epoch = batches_per_epoch
self.l1_strength = l1_strength
self.gradient_clip_value = gradient_clip_value
self.point_adapt_lambda = point_adapt_lambda
self.k_init = k_init
self.k_final = k_final
@@ -848,7 +853,14 @@ def _init_optimizer(self):
raise ValueError(
f'Unknown learning rate decay schedule "{self.learning_rate_decay}". Choose among None, "cos" and "exp".'
)
# Set up optimizer with optional gradient clipping
if self.gradient_clip_value > 0:
optimizer = optax.chain(
optax.clip_by_global_norm(self.gradient_clip_value),
opt_class(self.lr_schedule),
)
else:
optimizer = opt_class(self.lr_schedule)

# Initialize training state
self.state = train_state.TrainState.create(
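For context, a minimal self-contained sketch of what the chained transform above does, using plain optax (optax.adam stands in for whatever opt_class resolves to; this is not dadapy code):

# Sketch: gradient clipping composed with an optimizer via optax.chain.
import jax
import jax.numpy as jnp
import optax

params = {"w": jnp.ones(3)}
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),  # rescale gradients whose global norm exceeds 1.0
    optax.adam(1e-2),
)
opt_state = optimizer.init(params)

def loss(p):
    return 1e6 * jnp.sum(p["w"] ** 2)  # deliberately steep loss

grads = jax.grad(loss)(params)                 # very large gradients...
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)  # ...but the applied update stays bounded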
@@ -1155,10 +1167,20 @@ def forward_greedy_feature_selection(
learning_rate=self.learning_rate,
learning_rate_decay=self.learning_rate_decay,
num_points_rows=self.num_points_rows,
gradient_clip_value=self.gradient_clip_value,
)

# Set initial parameters and train
try:
_, _ = dii_copy.train()
except AssertionError as e:
print(f"Training failed for feature [{feature}]: {str(e)}")
print(f"Skipping feature [{feature}] and continuing...")
single_feature_diis.append(
float("inf")
) # Use infinity as a large penalty
single_feature_errors.append(None)
continue

# Compute DII on the full dataset
if compute_error:
@@ -1185,9 +1207,18 @@
# Convert to numpy arrays for easier manipulation
single_feature_diis = np.array(single_feature_diis)

# Check if we have any valid features (not infinity)
valid_features = np.isfinite(single_feature_diis)
if not np.any(valid_features):
print("ERROR: All single features failed during training!")
return [], [], [], []

# Select the best n_best single features (only from valid ones)
valid_indices = np.where(valid_features)[0]
valid_diis = single_feature_diis[valid_indices]
n_best_actual = min(n_best, len(valid_indices))
best_valid_indices = np.argsort(valid_diis)[:n_best_actual]
selected_indices = valid_indices[best_valid_indices]

# Convert indices to lists for consistent processing
selected_features = [[idx] for idx in selected_indices]
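To make the valid-candidate selection above concrete, here is a toy run of the same pattern with made-up DII values (np.inf marks a failed training):

# Toy illustration of the isfinite/argsort selection pattern (hypothetical values).
import numpy as np

single_feature_diis = np.array([0.8, np.inf, 0.3, np.inf, 0.5])
valid_features = np.isfinite(single_feature_diis)            # [T, F, T, F, T]
valid_indices = np.where(valid_features)[0]                  # array([0, 2, 4])
valid_diis = single_feature_diis[valid_indices]              # array([0.8, 0.3, 0.5])
n_best_actual = min(2, len(valid_indices))
best_valid_indices = np.argsort(valid_diis)[:n_best_actual]  # lowest DII first
selected_indices = valid_indices[best_valid_indices]         # array([2, 4])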
@@ -1270,10 +1301,24 @@ def forward_greedy_feature_selection(
learning_rate=self.learning_rate,
learning_rate_decay=self.learning_rate_decay,
num_points_rows=self.num_points_rows,
gradient_clip_value=self.gradient_clip_value,
)

# Set initial parameters and train
try:
_, _ = dii_copy.train()
except AssertionError as e:
print(
f"Training failed for feature set {candidate_set}: {str(e)}"
)
print(
f"Skipping feature set {candidate_set} and continuing..."
)
candidate_diis.append(
float("inf")
) # Use infinity as a large penalty
candidate_errors.append(None)
continue

# Compute DII on the full dataset
if compute_error:
@@ -1305,9 +1350,18 @@ def forward_greedy_feature_selection(
if not candidate_features: # No more features to add
break

# Check if we have any valid candidates (not infinity)
valid_candidates = np.isfinite(candidate_diis)
if not np.any(valid_candidates):
print("ERROR: All candidate feature sets failed during training!")
break

# Select the best n_best candidates for the next iteration (only from valid ones)
valid_indices = np.where(valid_candidates)[0]
valid_diis = candidate_diis[valid_indices]
n_best_actual = min(n_best, len(valid_indices))
best_valid_indices = np.argsort(valid_diis)[:n_best_actual]
best_indices = valid_indices[best_valid_indices]
selected_features = [candidate_features[i] for i in best_indices]

# Print the best feature set information
@@ -1347,18 +1401,25 @@ def forward_greedy_feature_selection(
learning_rate=self.learning_rate,
learning_rate_decay=self.learning_rate_decay,
num_points_rows=self.num_points_rows,
gradient_clip_value=self.gradient_clip_value,
)

# Set initial parameters and train
try:
_, _ = dii_copy.train()
# Print and store optimal weights
print(
f"\nOptimal weights for feature set {candidate_features[best_idx]}: {dii_copy.params_final}\n"
)
# Save optimal weights
best_weights = np.array(dii_copy.params_final)
except AssertionError as e:
print(
f"Training failed for best feature set {candidate_features[best_idx]}: {str(e)}"
)
print(f"Using zero weights for this iteration...")
best_weights = np.zeros(n_features)

best_weights_list.append(best_weights)

# Print the best n-tuple information
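A hypothetical end-to-end usage sketch (constructor arguments taken from this diff; the exact signature and return values of the greedy method should be checked against the dadapy documentation):

# Hypothetical usage sketch; names inferred from this diff, not authoritative.
import numpy as np
from dadapy.diff_imbalance import DiffImbalance

rng = np.random.default_rng(0)
data_A = rng.normal(size=(200, 5)).astype(np.float32)  # candidate feature space
data_B = data_A[:, :2].copy()                          # target space

dii = DiffImbalance(
    data_A,
    data_B,
    gradient_clip_value=1.0,  # new option: bound the gradient global norm
)
results = dii.forward_greedy_feature_selection()
# On total failure the method returns four empty lists (see the guard above).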
@@ -1519,13 +1580,24 @@ def backward_greedy_feature_selection(
learning_rate=self.learning_rate,
learning_rate_decay=self.learning_rate_decay,
num_points_rows=self.num_points_rows,
gradient_clip_value=self.gradient_clip_value,
)

# Set initial parameters and train
try:
_, _ = dii_copy.train()
# Store the trained weights
trained_weights = dii_copy.params_final
except AssertionError as e:
print(
f"Training failed for feature set {candidate_set}: {str(e)}"
)
print(f"Skipping feature set {candidate_set} and continuing...")
candidate_diis.append(
float("inf")
) # Use infinity as a large penalty
candidate_errors.append(None)
continue

# Use return_final_dii to compute DII on the full dataset
dii_copy.params_final = trained_weights
@@ -1560,9 +1632,18 @@ def backward_greedy_feature_selection(
# Convert to numpy arrays for easier manipulation
candidate_diis = np.array(candidate_diis)

# Check if we have any valid candidates (not infinity)
valid_candidates = np.isfinite(candidate_diis)
if not np.any(valid_candidates):
print("ERROR: All candidate feature sets failed during training!")
break

# Select the best n_best candidates (only from valid ones)
valid_indices = np.where(valid_candidates)[0]
valid_diis = candidate_diis[valid_indices]
n_best_actual = min(n_best, len(valid_indices))
best_valid_indices = np.argsort(valid_diis)[:n_best_actual]
best_indices = valid_indices[best_valid_indices]

# Update current features for the next iteration
current_features = [candidate_features[i] for i in best_indices]
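As a toy illustration of how one backward-elimination step generates the candidate sets scored above (hypothetical indices, not dadapy code):

# One backward step: drop each feature in turn from the current set.
current = [0, 1, 2, 3]
candidates = [current[:i] + current[i + 1:] for i in range(len(current))]
# -> [[1, 2, 3], [0, 2, 3], [0, 1, 3], [0, 1, 2]]
# Each candidate is trained and scored; only the n_best sets with finite DII survive.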
@@ -1596,13 +1677,21 @@ def backward_greedy_feature_selection(
learning_rate=self.learning_rate,
learning_rate_decay=self.learning_rate_decay,
num_points_rows=self.num_points_rows,
gradient_clip_value=self.gradient_clip_value,
)

# Set initial parameters and train
try:
_, _ = dii_copy.train()
# Save optimal weights
best_weights = dii_copy.params_final
except AssertionError as e:
print(
f"Training failed for best feature set {best_feature_set}: {str(e)}"
)
print(f"Using zero weights for this iteration...")
best_weights = np.zeros(n_features)

best_weights_list.append(best_weights)

# Store results
1,248 changes: 1,106 additions & 142 deletions examples/notebook_on_differentiable_imbalance_jax.ipynb

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions tests/test_diff_imbalance_jax/test_greedy_dii.py
@@ -69,6 +69,7 @@ def test_DiffImbalance_forward_greedy():
learning_rate=1e-1,
learning_rate_decay="cos",
num_points_rows=None,
gradient_clip_value=0.0,
)
weights, imbs = dii.train()

@@ -143,6 +144,7 @@ def test_DiffImbalance_backward_greedy():
learning_rate=1e-1,
learning_rate_decay="cos",
num_points_rows=None,
gradient_clip_value=0.0,
)
weights, imbs = dii.train()

@@ -220,6 +222,7 @@ def test_DiffImbalance_greedy_symmetry_5d_gaussian():
learning_rate=1e-1,
learning_rate_decay="cos",
num_points_rows=None,
gradient_clip_value=0.0,
)
weights, imbs = dii.train()

@@ -353,6 +356,7 @@ def test_DiffImbalance_greedy_random_initialization():
learning_rate=1e-1,
learning_rate_decay="cos",
num_points_rows=None,
gradient_clip_value=0.0,
)
weights, imbs = dii.train()
