From 41f0b5c205b1e8b9bfa5fddf12102e59b3f99e07 Mon Sep 17 00:00:00 2001 From: Tim Beyer Date: Tue, 27 May 2025 15:56:08 +0200 Subject: [PATCH 01/13] save plots --- evaluate/distributional_paper/make_plots.py | 524 ++++++++++++++++++++ 1 file changed, 524 insertions(+) create mode 100644 evaluate/distributional_paper/make_plots.py diff --git a/evaluate/distributional_paper/make_plots.py b/evaluate/distributional_paper/make_plots.py new file mode 100644 index 0000000..7a1ffe2 --- /dev/null +++ b/evaluate/distributional_paper/make_plots.py @@ -0,0 +1,524 @@ +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from scipy.stats import rankdata +from scipy.interpolate import interp1d +from matplotlib.colors import LogNorm + + +pd.set_option("display.max_colwidth", None) +pd.set_option("display.max_columns", None) +pd.set_option("display.expand_frame_repr", False) + +from src.io_utils import get_filtered_and_grouped_paths, collect_results, num_model_params + + +def generate_sample_sizes(total_samples: int) -> tuple[int, ...]: + if total_samples < 1: + raise ValueError("total_samples must be ≥ 1") + bases = (1, 2, 5) # 1-2-5 pattern for each power of ten + result = [] + power = 0 + while True: + scale = 10 ** power + for b in bases: + value = b * scale + if value > total_samples: + # Stop once the next milestone exceeds the target + result.append(total_samples) if result[-1] != total_samples else None + return tuple(result) + result.append(value) + if value == total_samples: + return tuple(result) + power += 1 + +def _dominance_frontier(xs: np.ndarray, ys: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """ + Return the non-dominated (Pareto-optimal) points, ordered by cost. + The frontier is defined as points for which no other point has + *both* lower cost (x) and lower mean p_harmful (y). + + Parameters + ---------- + xs, ys : 1-D arrays of equal length + Coordinates of the candidate points. + + Returns + ------- + frontier_xs, frontier_ys : 1-D arrays + Coordinates of the Pareto frontier, sorted by xs ascending. + """ + order = np.argsort(xs) # sort by cost + xs_sorted, ys_sorted = xs[order], ys[order] + + frontier_x, frontier_y = [0], [0] + best_y_so_far = 0 + for x_val, y_val in zip(xs_sorted, ys_sorted): + if y_val > best_y_so_far: # strictly better in y + frontier_x.append(x_val) + frontier_y.append(y_val) + best_y_so_far = y_val + frontier_x.append(xs_sorted[-1]) + frontier_y.append(frontier_y[-1]) + return np.asarray(frontier_x), np.asarray(frontier_y) + + +# ------------------------------------------------------------------ +# 1. Empirical‑copula Pareto frontier (no Archimedean fit required) +# ------------------------------------------------------------------ +def _copula_frontier(xs: np.ndarray, + ys: np.ndarray, + eps: float = 1e-9) -> tuple[np.ndarray, np.ndarray]: + """ + Pareto frontier estimator based on the empirical copula level + set at alpha* = min_i C_n(U_i). + + Parameters + ---------- + xs, ys : 1-D arrays + Coordinates of the candidate points. + eps : float + Numerical tolerance when selecting the boundary. + + Returns + ------- + frontier_xs, frontier_ys : 1-D arrays + Estimated frontier, ordered by xs ascending. + """ + n = xs.size + # pseudo‑observations U_k (Eq. 7 in the paper) + u = rankdata(-xs, method="ordinal") / (n + 1.0) + v = rankdata(ys, method="ordinal") / (n + 1.0) + U = np.column_stack((u, v)) + + # empirical copula values at the sample points + # C_n(U_i) = 1/n * number of points dominated by U_i + dom_matrix = (U[:, None, :] <= U[None, :, :]).all(axis=2) + C_vals = dom_matrix.mean(axis=1) + + alpha_star = C_vals.min() # Lemma 2.1 + on_boundary = np.abs(C_vals - alpha_star) < eps + + fx, fy = xs[on_boundary], ys[on_boundary] + order = np.argsort(fx) + return -fx[order], fy[order] + + +# ------------------------------------------------------------------ +# 2. Thin wrapper so you can switch methods with one argument +# ------------------------------------------------------------------ +def _pareto_frontier(xs: np.ndarray, + ys: np.ndarray, + method: str = "empirical_copula", + **kwargs): + if method == "empirical_copula": + return _copula_frontier(xs, ys, **kwargs) + elif method == "basic": + # your original dominance‑based frontier + return _dominance_frontier(xs, ys) # rename old function + else: + raise ValueError(f"Unknown frontier method '{method}'") + + +def pareto_plot( + results: dict[str,np.ndarray], + baseline: dict[str,np.ndarray] | None = None, + title: str = "Pareto Frontier", + sample_levels_to_plot: tuple[int, ...]|None = None, + frontier_method: str = "basic", + metric: tuple[str, ...] = ('scores', 'strong_reject', 'p_harmful'), + plot_points: bool = True, + plot_frontiers: bool = True, + plot_envelope: bool = False, + verbose: bool = True, + cumulative: bool = False, + flops_per_step: int | None = None, + n_x_points: int = 10000, + x_scale="log", + threshold: float|None = None, + color_scale: str = "linear", +): + """ + Scatter the full design-space AND overlay Pareto frontiers + for selected sampling counts. + + Parameters + ---------- + data : (x, y) where x is ignored and y has shape (n_steps, total_samples) + Your original score tensor. + sampling_cost_factor : float, optional + Multiplies the sampling cost term j. + frontier_samples : tuple[int, ...], optional + Which n_sample values to draw Pareto lines for. + + Returns + ------- + None + """ + y = np.array(results[metric]) # (B, n_steps, n_samples) + if threshold is not None: + y = y > threshold + n_runs, n_steps, n_total_samples = y.shape + if sample_levels_to_plot is None: + sample_levels_to_plot = generate_sample_sizes(n_total_samples) + + flops_sampling = np.array(results["flops_sampling"]) # (B, n_steps) + if "flops" in results: + flops_optimization = np.array(results["flops"]) # (B, n_steps) + else: + flops_optimization = np.zeros_like(flops_sampling) # (B, n_steps) + if flops_per_step is not None: + flops_optimization += flops_per_step(np.arange(flops_optimization.shape[1])) + + + def subsample_and_aggregate(step_idx, sample_idx, cumulative, y, opt_flops, sampling_flops, rng): + opt_flop = np.mean(opt_flops[:, :step_idx+1].sum(axis=1)) + sampling_flop = np.mean(sampling_flops[:, step_idx]) * sample_idx + if cumulative and step_idx > 0: + samples_at_end = y[:, step_idx, rng.choice(n_total_samples, size=sample_idx, replace=False)].max(axis=-1) + samples_up_to_now = y[:, :step_idx, rng.choice(n_total_samples, size=1, replace=False)].max(axis=1)[:, 0] + values = np.stack([samples_up_to_now, samples_at_end], axis=1).max(axis=1) + return (opt_flop + sampling_flop, step_idx, sample_idx, values.mean(0)) + return (opt_flop + sampling_flop, step_idx, sample_idx, y[:, step_idx, rng.choice(n_total_samples, size=sample_idx, replace=False)].max(axis=-1).mean(axis=0)) + + + def get_pts(y, opt_flops, sampling_flops): + n_runs, n_steps, total_samples = y.shape + rng = np.random.default_rng() + pts = [] # (cost, step, n_samples, mean_p) + for j in range(1, total_samples + 1, 1): + for i in range(0, n_steps, 1): + pts.append(subsample_and_aggregate(i, j, cumulative, y, opt_flops, sampling_flops, rng)) + pts = np.asarray(pts) + return pts + + pts = get_pts(y, flops_optimization, flops_sampling) + cost, step_idx, n_samp, mean_p = pts.T + max_cost = max(cost) + if x_scale == "log": + x_interp = np.logspace(13, np.log10(max_cost+1), n_x_points) + else: + x_interp = np.linspace(0, max_cost+1, n_x_points) + + + # ---------- scatter all points ---------- + plt.figure(figsize=(9, 6)) + if plot_points: + if color_scale == "log": + color_norm = LogNorm() + else: + color_norm = None + sc = plt.scatter(cost, mean_p, c=n_samp, cmap="viridis", alpha=0.15, s=3, norm=color_norm) + plt.xlabel("FLOPS") + if threshold is None: + plt.ylabel("Mean p_harmful") + else: + plt.ylabel(f"Mean ASR (threshold: {threshold})") + + # ---------- overlay Pareto frontiers ---------- + cmap = plt.get_cmap("viridis") + if color_scale == "log": + norm = LogNorm(n_samp.min(), n_samp.max()) + else: + norm = plt.Normalize(n_samp.min(), n_samp.max()) + rng = np.random.default_rng() + + n_smoothing = 50 + if plot_frontiers: + for j in sample_levels_to_plot: + xs = [] + ys = [] + for _ in range(n_smoothing): + pts = [] + for i in range(0, n_steps, 1): + pts.append(subsample_and_aggregate(i, j, cumulative, y, flops_optimization, flops_sampling, rng)) + + pts = np.asarray(pts) + cost, _, _, mean_p = pts.T + + fx, fy = _pareto_frontier(cost, mean_p, method=frontier_method) + xs.append(fx) + ys.append(fy) + y_interp = [interp1d(x_, y_, kind="previous", bounds_error=False, fill_value=(0, max(y_)))(x_interp) for x_, y_ in zip(xs, ys)] + + color = cmap(norm(j)) + plt.plot( + x_interp, + np.mean(y_interp, axis=0), + marker="o", + linewidth=1.8, + markersize=2, + label=f"{j} samples", + color=color, + ) + + if plot_envelope: + # ---------- overlay Pareto frontiers ---------- + n_smoothing = 50 + y_interps = [] + for j in range(1, n_total_samples+1): + xs = [] + ys = [] + for n in range(n_smoothing): + pts = [] + for i in range(0, n_steps, 1): + pts.append(subsample_and_aggregate(i, j, cumulative, y, flops_optimization, flops_sampling, rng)) + + pts = np.asarray(pts) + cost, step_idx, n_samp, mean_p = pts.T + + fx, fy = _pareto_frontier(cost, mean_p, method=frontier_method) + xs.append(fx) + ys.append(fy) + + y_interp = [interp1d(x_, y_, kind="previous", bounds_error=False, fill_value=(0, max(y_)))(x_interp) for x_, y_ in zip(xs, ys)] + y_interps.append(np.mean(y_interp, axis=0)) + y_interps = np.array(y_interps) + argmax = np.argmax(y_interps, axis=0) + argmax = np.maximum.accumulate(argmax) + y_envelope = np.max(y_interps, axis=0) + + # color = "g" #[ + # plt.plot( + # np.arange(max_cost+1), + # y_envelope, + # marker="o", + # linewidth=1.8, + # markersize=4, + # label=f"envelope", + # color=color, + # ) + color = [cmap(norm(argmax[i])) for i in range(len(argmax))] + plt.scatter(x_interp, y_envelope, c=color, s=2) + + title_suffix = "" + + y = np.array(baseline[metric]) # (B, n_steps, n_samples) + if threshold is not None: + y = y > threshold + + baseline_flops_sampling = np.array(baseline["flops_sampling"]) + if "flops" in baseline: + baseline_flops_optimization = np.array(baseline["flops"]) # (B, n_steps) + else: + baseline_flops_optimization = np.zeros_like(baseline_flops_sampling) # (B, n_steps) + if flops_per_step is not None: + baseline_flops_optimization += flops_per_step(np.arange(baseline_flops_optimization.shape[1])) + + if y is not None: + title_suffix = f" ({n_runs}, {y.shape[0]})" + if verbose: + print(n_runs, "for main") + print(y.shape[0], "for baseline") + n_runs, n_steps, n_total_samples = y.shape + assert n_total_samples == 1 + + rng = np.random.default_rng() + pts = [] # (cost, step, n_samples, mean_p) + for i in range(0, n_steps, 1): + for j in range(1, n_total_samples + 1, 1): + pts.append(subsample_and_aggregate(i, j, cumulative, y, baseline_flops_optimization, baseline_flops_sampling, rng)) + + pts = np.asarray(pts) + cost, step_idx, n_samp, mean_p = pts.T + + # ---------- scatter all points ---------- + # sc = plt.scatter(cost, mean_p, c="r", alpha=0.35, s=4) + + # ---------- overlay Pareto frontiers ---------- + if plot_frontiers or plot_envelope: + mask = n_samp == 1 + fx, fy = _pareto_frontier(cost[mask], mean_p[mask], method=frontier_method) + y_interp = interp1d(fx, fy, kind="previous", bounds_error=False, fill_value=(0, max(fy)))(x_interp) + plt.plot( + x_interp, + y_interp, + marker="o", + linewidth=1.8, + markersize=2, + label=f"greedy", + color="r", + ) + plt.title(title + title_suffix) + plt.grid(True, linewidth=0.3) + + plt.xscale(x_scale) + plt.legend(title="Frontiers", loc="upper left" if x_scale == "log" else "lower right") + plt.tight_layout() + plt.savefig(f"evaluate/distributional_paper/pareto_plots/{title}.pdf") + + + +# ---------------------------------------------------------------------------------- +# Pareto plots – simplified +# ---------------------------------------------------------------------------------- +import numpy as np + +MODELS = { + "meta-llama/Meta-Llama-3.1-8B-Instruct": "Meta Llama 3.1 8B", + "google/gemma-3-1b-it": "Gemma 3.1 1B", + "GraySwanAI/Llama-3-8B-Instruct-RR": "Llama 3 CB", +} + +FLOPS_PER_STEP = { + "autodan": lambda s, c: 69845248149248 // num_model_params("Qwen/Qwen2.5-0.5B-Instruct") * c, + "gcg": lambda s, c: 14958709489152 // num_model_params("Qwen/Qwen2.5-0.5B-Instruct") * c, + "beast": lambda s, c: 10447045889280 // num_model_params("Qwen/Qwen2.5-0.5B-Instruct") * c, + "pair": lambda s, c: 83795198566400 + 78737584640 // num_model_params("Qwen/Qwen2.5-0.5B-Instruct") * c, +} # for 0.5B model + +# Attack-specific configuration ----------------------------------------------------- +ATTACKS = [ + ("pair", dict( + title_suffix="PAIR", + cumulative=True, + sample_params=lambda: { + "generation_config": {"num_return_sequences": 50, "temperature": 0.7}, + }, + baseline_params=lambda: { + "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, + }, + )), + ("autodan", dict( + title_suffix="AutoDAN", + cumulative=False, + sample_params=lambda: { + "generation_config": {"num_return_sequences": 50, "temperature": 0.7}, + "early_stopping_threshold": 0, + }, + baseline_params=lambda: { + "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, + }, + )), + ("gcg", dict( + title_suffix="GCG", + cumulative=False, + sample_params=lambda: { + "generation_config": {"num_return_sequences": 50, "temperature": 0.7}, + "num_steps": 250, + "loss": "ce", + "token_selection": "default", + "use_prefix_cache": True, + }, + baseline_params=lambda: { + "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, + "num_steps": 250, + "loss": "ce", + "token_selection": "default", + "use_prefix_cache": True, + }, + )), + ("gcg", dict( + title_suffix="GCG Entropy Loss", + cumulative=False, + sample_params=lambda: { + "generation_config": {"num_return_sequences": 50, "temperature": 0.7}, + "num_steps": 250, + "loss": "entropy_adaptive", + "token_selection": "default", + "use_prefix_cache": True, + }, + baseline_params=lambda: { + "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, + "num_steps": 250, + "loss": "ce", + "token_selection": "default", + "use_prefix_cache": True, + }, + )), + ("bon", dict( + title_suffix="BoN", + cumulative=False, + sample_params=lambda: {"num_steps": 1000, "generation_config": {"temperature": 0.7}}, + baseline_params=lambda: { + # BoN's baseline is *Direct* with one deterministic sample + "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, + }, + baseline_attack="direct", + postprocess=lambda data, metric: data.__setitem__( + metric, np.array(data[metric]).transpose(0, 2, 1) + ), + )), + ("bon", dict( + title_suffix="BoN Repro", + cumulative=False, + sample_params=lambda: {"num_steps": 1000, "generation_config": {"temperature": 1.0}}, + baseline_params=lambda: { + # BoN's baseline is *Direct* with one deterministic sample + "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, + }, + baseline_attack="direct", + postprocess=lambda data, metric: data.__setitem__( + metric, np.array(data[metric]).transpose(0, 2, 1) + ), + )), + ("direct", dict( + title_suffix="Direct", + cumulative=True, + sample_params=lambda: { + "generation_config": {"num_return_sequences": 1000, "temperature": 0.7}, + }, + baseline_params=lambda: { + "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, + }, + skip_if_empty=True, # gracefully continue if no paths were found + )), +] + +METRIC = ("scores", "strong_reject", "p_harmful") +GROUP_BY = {("model",)} +DATASET_IDX = list(range(50)) + + +# Helper --------------------------------------------------------------------------- +def run_attack( + model: str, + model_title: str, + atk_name: str, + cfg: dict, +): + print("Attack:", atk_name) + + # ---------- helper to fetch data ---------- + def fetch(attack: str, attack_params: dict): + filter_by = dict( + model=model, + attack=attack, + attack_params=attack_params, + dataset_params={"idx": DATASET_IDX}, + ) + paths = get_filtered_and_grouped_paths(filter_by, GROUP_BY) + results = collect_results(paths, infer_sampling_flops=True) + assert len(results) == 1, len(results) + return list(results.values())[0] + + # ---------- sampled run ---------- + sampled_data = fetch(cfg.get("attack_override", atk_name), cfg["sample_params"]()) + + # Attack-specific post-processing + if post := cfg.get("postprocess"): + post(sampled_data, METRIC) + + # ---------- baseline run ---------- + baseline_attack = cfg.get("baseline_attack", atk_name) + baseline_data = fetch(baseline_attack, cfg["baseline_params"]()) + + # ---------- plot ---------- + pareto_plot( + sampled_data, + baseline_data, + title=f"{model_title} {cfg['title_suffix']}", + cumulative=cfg["cumulative"], + metric=METRIC, + flops_per_step=lambda x: FLOPS_PER_STEP.get(atk_name, lambda x, c: 0)(x, num_model_params(model)), + threshold=None, + ) + +# Main loop ------------------------------------------------------------------------ +for model_key, model_title in MODELS.items(): + print("Model:", model_key) + for atk_name, atk_cfg in ATTACKS: + try: + run_attack(model_key, model_title, atk_name, atk_cfg) + except Exception as e: + print(f"Error running attack {atk_name}, atk_cfg: {atk_cfg['title_suffix']}: {e}") From 6c9d3b5e0793b483baecde4f058cbe21d5c7e2cb Mon Sep 17 00:00:00 2001 From: Tim Beyer Date: Wed, 28 May 2025 12:31:27 +0200 Subject: [PATCH 02/13] update plots --- evaluate/distributional_paper/make_plots.py | 168 +++++++++++++++++--- 1 file changed, 146 insertions(+), 22 deletions(-) diff --git a/evaluate/distributional_paper/make_plots.py b/evaluate/distributional_paper/make_plots.py index 7a1ffe2..f16641f 100644 --- a/evaluate/distributional_paper/make_plots.py +++ b/evaluate/distributional_paper/make_plots.py @@ -1,4 +1,7 @@ import matplotlib.pyplot as plt +import scienceplots +plt.style.use("science") + import numpy as np import pandas as pd from scipy.stats import rankdata @@ -196,7 +199,7 @@ def get_pts(y, opt_flops, sampling_flops): cost, step_idx, n_samp, mean_p = pts.T max_cost = max(cost) if x_scale == "log": - x_interp = np.logspace(13, np.log10(max_cost+1), n_x_points) + x_interp = np.logspace(11, np.log10(max_cost+1), n_x_points) else: x_interp = np.linspace(0, max_cost+1, n_x_points) @@ -209,11 +212,11 @@ def get_pts(y, opt_flops, sampling_flops): else: color_norm = None sc = plt.scatter(cost, mean_p, c=n_samp, cmap="viridis", alpha=0.15, s=3, norm=color_norm) - plt.xlabel("FLOPS") + plt.xlabel("Cost (FLOPS (optimization + sampling))", fontsize=14) if threshold is None: - plt.ylabel("Mean p_harmful") + plt.ylabel("Mean p_harmful", fontsize=14) else: - plt.ylabel(f"Mean ASR (threshold: {threshold})") + plt.ylabel(f"Mean ASR (threshold: {threshold})", fontsize=14) # ---------- overlay Pareto frontiers ---------- cmap = plt.get_cmap("viridis") @@ -242,9 +245,12 @@ def get_pts(y, opt_flops, sampling_flops): y_interp = [interp1d(x_, y_, kind="previous", bounds_error=False, fill_value=(0, max(y_)))(x_interp) for x_, y_ in zip(xs, ys)] color = cmap(norm(j)) + y_mean = np.mean(y_interp, axis=0) + # Filter out leading zeros + nonzero_mask = y_mean > 0 plt.plot( - x_interp, - np.mean(y_interp, axis=0), + x_interp[nonzero_mask], + y_mean[nonzero_mask], marker="o", linewidth=1.8, markersize=2, @@ -253,7 +259,6 @@ def get_pts(y, opt_flops, sampling_flops): ) if plot_envelope: - # ---------- overlay Pareto frontiers ---------- n_smoothing = 50 y_interps = [] for j in range(1, n_total_samples+1): @@ -278,18 +283,10 @@ def get_pts(y, opt_flops, sampling_flops): argmax = np.maximum.accumulate(argmax) y_envelope = np.max(y_interps, axis=0) - # color = "g" #[ - # plt.plot( - # np.arange(max_cost+1), - # y_envelope, - # marker="o", - # linewidth=1.8, - # markersize=4, - # label=f"envelope", - # color=color, - # ) - color = [cmap(norm(argmax[i])) for i in range(len(argmax))] - plt.scatter(x_interp, y_envelope, c=color, s=2) + # Filter out leading zeros + nonzero_mask = y_envelope > 0 + color = [cmap(norm(argmax[i])) for i in range(len(argmax)) if nonzero_mask[i]] + plt.scatter(x_interp[nonzero_mask], y_envelope[nonzero_mask], c=color, s=2) title_suffix = "" @@ -330,9 +327,10 @@ def get_pts(y, opt_flops, sampling_flops): mask = n_samp == 1 fx, fy = _pareto_frontier(cost[mask], mean_p[mask], method=frontier_method) y_interp = interp1d(fx, fy, kind="previous", bounds_error=False, fill_value=(0, max(fy)))(x_interp) + nonzero_mask = y_interp > 0 plt.plot( - x_interp, - y_interp, + x_interp[nonzero_mask], + y_interp[nonzero_mask], marker="o", linewidth=1.8, markersize=2, @@ -341,7 +339,7 @@ def get_pts(y, opt_flops, sampling_flops): ) plt.title(title + title_suffix) plt.grid(True, linewidth=0.3) - + plt.ylim(bottom=0) plt.xscale(x_scale) plt.legend(title="Frontiers", loc="upper left" if x_scale == "log" else "lower right") plt.tight_layout() @@ -463,6 +461,17 @@ def get_pts(y, opt_flops, sampling_flops): }, skip_if_empty=True, # gracefully continue if no paths were found )), + ("direct", dict( + title_suffix="Direct temp 1.0", + cumulative=True, + sample_params=lambda: { + "generation_config": {"num_return_sequences": 1000, "temperature": 1.0}, + }, + baseline_params=lambda: { + "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, + }, + skip_if_empty=True, # gracefully continue if no paths were found + )), ] METRIC = ("scores", "strong_reject", "p_harmful") @@ -522,3 +531,118 @@ def fetch(attack: str, attack_params: dict): run_attack(model_key, model_title, atk_name, atk_cfg) except Exception as e: print(f"Error running attack {atk_name}, atk_cfg: {atk_cfg['title_suffix']}: {e}") + + + +# Helper --------------------------------------------------------------------------- +def run_attack_2( + model: str, + model_title: str, + atk_name: str, + cfg: dict, +): + print("Attack:", atk_name) + + # ---------- helper to fetch data ---------- + def fetch(attack: str, attack_params: dict): + filter_by = dict( + model=model, + attack=attack, + attack_params=attack_params, + dataset_params={"idx": DATASET_IDX}, + ) + paths = get_filtered_and_grouped_paths(filter_by, GROUP_BY) + results = collect_results(paths, infer_sampling_flops=True) + assert len(results) == 1, len(results) + return list(results.values())[0] + + # ---------- sampled run ---------- + sampled_data = fetch(cfg.get("attack_override", atk_name), cfg["sample_params"]()) + + # Attack-specific post-processing + if post := cfg.get("postprocess"): + post(sampled_data, METRIC) + + data = np.array(sampled_data[("scores", "strong_reject", "p_harmful")])[:, 0] + # Create histogram plot + plt.figure(figsize=(10, 6)) + plt.hist(data.flatten(), bins=100, alpha=0.7, edgecolor='black') + plt.xlabel('p_harmful', fontsize=14) + plt.ylabel('Frequency', fontsize=14) + plt.title(f'{model_title} - {atk_name} - p_harmful Distribution', fontsize=16) + plt.grid(True, alpha=0.3) + plt.tight_layout() + + # Save the plot + filename = f"evaluate/distributional_paper/histograms/{model_title}_{cfg['title_suffix']}.png" + plt.savefig(filename, dpi=300, bbox_inches='tight') + plt.close() + +for model_key, model_title in MODELS.items(): + print("Model:", model_key) + for atk_name, atk_cfg in ATTACKS: + if atk_name != "direct": continue + try: + run_attack_2(model_key, model_title, atk_name, atk_cfg) + except Exception as e: + print(f"Error running attack {atk_name}, atk_cfg: {atk_cfg['title_suffix']}: {e}") + + +# Helper --------------------------------------------------------------------------- +def run_attack_2( + model: str, + model_title: str, + atk_name: str, + cfg: dict, +): + print("Attack:", atk_name) + + # ---------- helper to fetch data ---------- + def fetch(attack: str, attack_params: dict): + filter_by = dict( + model=model, + attack=attack, + attack_params=attack_params, + dataset_params={"idx": DATASET_IDX}, + ) + paths = get_filtered_and_grouped_paths(filter_by, GROUP_BY) + results = collect_results(paths, infer_sampling_flops=True) + assert len(results) == 1, len(results) + return list(results.values())[0] + + # ---------- sampled run ---------- + sampled_data = fetch(cfg.get("attack_override", atk_name), cfg["sample_params"]()) + + # Attack-specific post-processing + if post := cfg.get("postprocess"): + post(sampled_data, METRIC) + + plt.figure(figsize=(10, 6)) + data_list = [] + positions = [] + for i in range(0, 250, 25): + data = np.array(sampled_data[("scores", "strong_reject", "p_harmful")])[:, i] + data_list.append(data.flatten()) + positions.append(i) + + # Create violin plot + plt.violinplot(data_list, positions=positions, widths=20, showmeans=True, showmedians=True) + plt.xlabel('Step', fontsize=14) + plt.ylabel('Frequency', fontsize=14) + plt.title(f'{model_title} - {atk_name} - p_harmful Distribution', fontsize=16) + plt.grid(True, alpha=0.3) + plt.tight_layout() + + # Save the plot + filename = f"evaluate/distributional_paper/histograms/{model_title}_{cfg['title_suffix']}.png" + plt.savefig(filename, dpi=300, bbox_inches='tight') + plt.close() + +for model_key, model_title in MODELS.items(): + print("Model:", model_key) + for atk_name, atk_cfg in ATTACKS: + if atk_name != "gcg": continue + try: + run_attack_2(model_key, model_title, atk_name, atk_cfg) + except Exception as e: + print(f"Error running attack {atk_name}, atk_cfg: {atk_cfg['title_suffix']}: {e}") From 7bfef126ca9fbba496c1d1f003a570a857876d56 Mon Sep 17 00:00:00 2001 From: Tim Beyer Date: Wed, 28 May 2025 12:33:36 +0200 Subject: [PATCH 03/13] speed up num_return_sequences != 1 --- src/lm_utils.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/lm_utils.py b/src/lm_utils.py index 58a0c50..55e540d 100644 --- a/src/lm_utils.py +++ b/src/lm_utils.py @@ -25,6 +25,7 @@ def generate_ragged_batched( initial_batch_size: int = 256, use_cache: bool = True, verbose: bool = False, + num_return_sequences: int = 1, **kwargs, ) -> list[list[str]]: """ @@ -48,7 +49,12 @@ def generate_ragged_batched( # Shorter sequences will come first to maximize batch size sorted_indexed_inputs = sorted(list(enumerate(input_list)), key=lambda x: x[1].size(0)) - sorted_input_list = [item for _, item in sorted_indexed_inputs] + + # Duplicate each prompt for multiple return sequences here is faster because it + # avoids the slower for-loop in generate_ragged_batched. We can't easily move this + # inside generate_ragged directly because we need to do it outside of the + # with_max_batchsize context. + sorted_input_list = [it for _, item in sorted_indexed_inputs for it in [item] * num_return_sequences] original_indices = [index for index, _ in sorted_indexed_inputs] def func(chunk): @@ -59,6 +65,7 @@ def func(chunk): token_list=chunk if input_type == "tokens" else None, embedding_list=chunk if input_type == "embeddings" else None, use_cache=use_cache, + num_return_sequences=1, **kwargs, ) sorted_outputs = with_max_batchsize(func, sorted_input_list, initial_batch_size=initial_batch_size, verbose=verbose) @@ -66,7 +73,7 @@ def func(chunk): # Unsort the outputs to match the original input order outputs = [None] * len(input_list) for i, original_index in enumerate(original_indices): - outputs[original_index] = sorted_outputs[i] + outputs[original_index] = [sorted_outputs[i*num_return_sequences + j][0] for j in range(num_return_sequences)] return outputs From 28ebcf4f9e42950786c6f59d20c15903cd524614 Mon Sep 17 00:00:00 2001 From: Tim Beyer Date: Wed, 28 May 2025 15:14:57 +0200 Subject: [PATCH 04/13] add flop ratio plots --- evaluate/distributional_paper/make_plots.py | 227 +++++++++++++++++++- 1 file changed, 221 insertions(+), 6 deletions(-) diff --git a/evaluate/distributional_paper/make_plots.py b/evaluate/distributional_paper/make_plots.py index f16641f..a502f66 100644 --- a/evaluate/distributional_paper/make_plots.py +++ b/evaluate/distributional_paper/make_plots.py @@ -346,6 +346,163 @@ def get_pts(y, opt_flops, sampling_flops): plt.savefig(f"evaluate/distributional_paper/pareto_plots/{title}.pdf") +def flops_ratio_plot( + results: dict[str,np.ndarray], + baseline: dict[str,np.ndarray] | None = None, + title: str = "FLOPS Ratio Analysis", + sample_levels_to_plot: tuple[int, ...]|None = None, + metric: tuple[str, ...] = ('scores', 'strong_reject', 'p_harmful'), + cumulative: bool = False, + flops_per_step: int | None = None, + threshold: float|None = None, + color_scale: str = "linear", + verbose: bool = True, +): + """ + Plot p_harmful vs the ratio of optimization FLOPS to sampling FLOPS. + + Parameters + ---------- + results : dict + Results dictionary containing metric data and FLOPS information + baseline : dict, optional + Baseline results for comparison + title : str + Plot title + sample_levels_to_plot : tuple[int, ...], optional + Which sample levels to highlight + metric : tuple[str, ...] + Metric to plot on y-axis + cumulative : bool + Whether to use cumulative aggregation + flops_per_step : callable, optional + Function to compute optimization FLOPS per step + threshold : float, optional + Threshold for binary classification + color_scale : str + Color scale type ("linear" or "log") + verbose : bool + Whether to print debug information + """ + y = np.array(results[metric]) # (B, n_steps, n_samples) + if threshold is not None: + y = y > threshold + n_runs, n_steps, n_total_samples = y.shape + if sample_levels_to_plot is None: + sample_levels_to_plot = generate_sample_sizes(n_total_samples) + + flops_sampling = np.array(results["flops_sampling"]) # (B, n_steps) + if "flops" in results: + flops_optimization = np.array(results["flops"]) # (B, n_steps) + else: + flops_optimization = np.zeros_like(flops_sampling) # (B, n_steps) + if flops_per_step is not None: + flops_optimization += flops_per_step(np.arange(flops_optimization.shape[1])) + + def subsample_and_aggregate_ratio(step_idx, sample_idx, cumulative, y, opt_flops, sampling_flops, rng): + n_runs, n_steps, n_total_samples = y.shape + opt_flop = np.mean(opt_flops[:, :step_idx+1].sum(axis=1)) + sampling_flop = np.mean(sampling_flops[:, step_idx]) * sample_idx + + # Calculate ratio (sampling / total), handle division by zero + ratio = sampling_flop / (sampling_flop + opt_flop + 1e-9) + + if cumulative and step_idx > 0: + samples_at_end = y[:, step_idx, rng.choice(n_total_samples, size=sample_idx, replace=False)].max(axis=-1) + samples_up_to_now = y[:, :step_idx, rng.choice(n_total_samples, size=1, replace=False)].max(axis=1)[:, 0] + values = np.stack([samples_up_to_now, samples_at_end], axis=1).max(axis=1) + return (ratio, step_idx, sample_idx, values.mean(0), opt_flop, sampling_flop) + return (ratio, step_idx, sample_idx, y[:, step_idx, rng.choice(n_total_samples, size=sample_idx, replace=False)].max(axis=-1).mean(axis=0), opt_flop, sampling_flop) + + def get_ratio_pts(y, opt_flops, sampling_flops): + n_runs, n_steps, n_total_samples = y.shape + rng = np.random.default_rng() + pts = [] # (ratio, step, n_samples, mean_p, opt_flop, sampling_flop) + for j in range(1, n_total_samples + 1, 1): + for i in range(0, n_steps, 1): + pts.append(subsample_and_aggregate_ratio(i, j, cumulative, y, opt_flops, sampling_flops, rng)) + pts = np.asarray(pts) + return pts + + pts = get_ratio_pts(y, flops_optimization, flops_sampling) + ratio, step_idx, n_samp, mean_p, opt_flop, sampling_flop = pts.T + + # Filter out infinite ratios for plotting + finite_mask = np.isfinite(ratio) + ratio_finite = ratio[finite_mask] + mean_p_finite = mean_p[finite_mask] + n_samp_finite = n_samp[finite_mask] + + plt.figure(figsize=(10, 6)) + + # Scatter plot + if color_scale == "log": + color_norm = LogNorm() + else: + color_norm = None + + sc = plt.scatter(ratio_finite, mean_p_finite, c=n_samp_finite, cmap="viridis", alpha=0.3, s=15, norm=color_norm) + plt.colorbar(sc, label="Number of Samples") + + # Highlight specific sample levels + cmap = plt.get_cmap("viridis") + if color_scale == "log": + norm = LogNorm(n_samp.min(), n_samp.max()) + else: + norm = plt.Normalize(n_samp.min(), n_samp.max()) + + for j in sample_levels_to_plot: + mask = (n_samp == j) & finite_mask + if np.any(mask): + color = cmap(norm(j)) + plt.scatter(ratio[mask], mean_p[mask], + color=color, s=50, alpha=0.8, + edgecolors='black', linewidth=0.5, + label=f"{j} samples") + + plt.xlabel("Sampling FLOPS / Total FLOPS", fontsize=14) + if threshold is None: + plt.ylabel("Mean p_harmful", fontsize=14) + else: + plt.ylabel(f"Mean ASR (threshold: {threshold})", fontsize=14) + + plt.grid(True, alpha=0.3) + plt.title(title, fontsize=16) + + # Add baseline if provided + if baseline is not None: + y_baseline = np.array(baseline[metric]) + if threshold is not None: + y_baseline = y_baseline > threshold + + baseline_flops_sampling = np.array(baseline["flops_sampling"]) + if "flops" in baseline: + baseline_flops_optimization = np.array(baseline["flops"]) + else: + baseline_flops_optimization = np.zeros_like(baseline_flops_sampling) + if flops_per_step is not None: + baseline_flops_optimization += flops_per_step(np.arange(baseline_flops_optimization.shape[1])) + + baseline_pts = get_ratio_pts(y_baseline, baseline_flops_optimization, baseline_flops_sampling) + baseline_ratio, _, baseline_n_samp, baseline_mean_p, _, _ = baseline_pts.T + + baseline_finite_mask = np.isfinite(baseline_ratio) + plt.scatter(baseline_ratio[baseline_finite_mask], baseline_mean_p[baseline_finite_mask], + color="red", s=50, alpha=0.8, marker="^", + edgecolors='black', linewidth=0.5, label="Greedy") + + plt.xscale("log") + plt.xlim(1e-5, 1) + plt.ylim(bottom=0) + plt.legend(loc='upper left') + plt.tight_layout() + plt.savefig(f"evaluate/distributional_paper/flops_ratio_plots/{title}.pdf", bbox_inches='tight') + plt.close() + + if verbose: + print(f"FLOPS ratio range: {ratio_finite.min():.2e} to {ratio_finite.max():.2e}") + print(f"Mean p_harmful range: {mean_p_finite.min():.4f} to {mean_p_finite.max():.4f}") + # ---------------------------------------------------------------------------------- # Pareto plots – simplified @@ -523,15 +680,73 @@ def fetch(attack: str, attack_params: dict): threshold=None, ) -# Main loop ------------------------------------------------------------------------ + +def run_attack_flops_ratio( + model: str, + model_title: str, + atk_name: str, + cfg: dict, +): + print("FLOPS Ratio Attack:", atk_name) + + # ---------- helper to fetch data ---------- + def fetch(attack: str, attack_params: dict): + filter_by = dict( + model=model, + attack=attack, + attack_params=attack_params, + dataset_params={"idx": DATASET_IDX}, + ) + paths = get_filtered_and_grouped_paths(filter_by, GROUP_BY) + results = collect_results(paths, infer_sampling_flops=True) + assert len(results) == 1, len(results) + return list(results.values())[0] + + # ---------- sampled run ---------- + sampled_data = fetch(cfg.get("attack_override", atk_name), cfg["sample_params"]()) + + # Attack-specific post-processing + if post := cfg.get("postprocess"): + post(sampled_data, METRIC) + + # ---------- baseline run ---------- + baseline_attack = cfg.get("baseline_attack", atk_name) + baseline_data = fetch(baseline_attack, cfg["baseline_params"]()) + + # ---------- plot ---------- + flops_ratio_plot( + sampled_data, + baseline_data, + title=f"{model_title} {cfg['title_suffix']} FLOPS Ratio", + cumulative=cfg["cumulative"], + metric=METRIC, + flops_per_step=lambda x: FLOPS_PER_STEP.get(atk_name, lambda x, c: 0)(x, num_model_params(model)), + threshold=None, + ) + + +# # Main loop ------------------------------------------------------------------------ +# for model_key, model_title in MODELS.items(): +# print("Model:", model_key) +# for atk_name, atk_cfg in ATTACKS: +# try: +# run_attack(model_key, model_title, atk_name, atk_cfg) +# except Exception as e: +# print(f"Error running attack {atk_name}, atk_cfg: {atk_cfg['title_suffix']}: {e}") + + +# FLOPS Ratio plots main loop ------------------------------------------------------ +print("\n" + "="*80) +print("GENERATING FLOPS RATIO PLOTS") +print("="*80) + for model_key, model_title in MODELS.items(): print("Model:", model_key) for atk_name, atk_cfg in ATTACKS: - try: - run_attack(model_key, model_title, atk_name, atk_cfg) - except Exception as e: - print(f"Error running attack {atk_name}, atk_cfg: {atk_cfg['title_suffix']}: {e}") - + # try: + run_attack_flops_ratio(model_key, model_title, atk_name, atk_cfg) + # except Exception as e: + # print(f"Error running FLOPS ratio attack {atk_name}, atk_cfg: {atk_cfg['title_suffix']}: {e}") # Helper --------------------------------------------------------------------------- From 57db6ab751fd2d6b80e95d418c4a8e1780e8e9e3 Mon Sep 17 00:00:00 2001 From: Tim Beyer Date: Fri, 30 May 2025 10:41:30 +0200 Subject: [PATCH 05/13] better plots --- evaluate/distributional_paper/make_plots.py | 350 +++++++++++++++----- 1 file changed, 272 insertions(+), 78 deletions(-) diff --git a/evaluate/distributional_paper/make_plots.py b/evaluate/distributional_paper/make_plots.py index a502f66..d13cd0a 100644 --- a/evaluate/distributional_paper/make_plots.py +++ b/evaluate/distributional_paper/make_plots.py @@ -6,7 +6,7 @@ import pandas as pd from scipy.stats import rankdata from scipy.interpolate import interp1d -from matplotlib.colors import LogNorm +from matplotlib.colors import LogNorm, PowerNorm pd.set_option("display.max_colwidth", None) @@ -175,6 +175,7 @@ def pareto_plot( def subsample_and_aggregate(step_idx, sample_idx, cumulative, y, opt_flops, sampling_flops, rng): + n_runs, n_steps, n_total_samples = y.shape opt_flop = np.mean(opt_flops[:, :step_idx+1].sum(axis=1)) sampling_flop = np.mean(sampling_flops[:, step_idx]) * sample_idx if cumulative and step_idx > 0: @@ -203,12 +204,18 @@ def get_pts(y, opt_flops, sampling_flops): else: x_interp = np.linspace(0, max_cost+1, n_x_points) + # Create figure with subplots: main plot + 2 bar charts + fig = plt.figure(figsize=(18, 6)) + + # Main Pareto plot (left half, spanning both rows) + ax1 = plt.subplot2grid((1, 4), (0, 0), colspan=2) # ---------- scatter all points ---------- - plt.figure(figsize=(9, 6)) if plot_points: if color_scale == "log": color_norm = LogNorm() + elif color_scale == "sqrt": + color_norm = PowerNorm(gamma=0.5) # Square root normalization else: color_norm = None sc = plt.scatter(cost, mean_p, c=n_samp, cmap="viridis", alpha=0.15, s=3, norm=color_norm) @@ -216,21 +223,27 @@ def get_pts(y, opt_flops, sampling_flops): if threshold is None: plt.ylabel("Mean p_harmful", fontsize=14) else: - plt.ylabel(f"Mean ASR (threshold: {threshold})", fontsize=14) + plt.ylabel(f"Max ASR (threshold: {threshold})", fontsize=12) # ---------- overlay Pareto frontiers ---------- cmap = plt.get_cmap("viridis") if color_scale == "log": norm = LogNorm(n_samp.min(), n_samp.max()) + elif color_scale == "sqrt": + norm = PowerNorm(gamma=0.5, vmin=n_samp.min(), vmax=n_samp.max()) else: norm = plt.Normalize(n_samp.min(), n_samp.max()) rng = np.random.default_rng() n_smoothing = 50 + frontier_data = {} # Store frontier data for bar charts + if plot_frontiers: for j in sample_levels_to_plot: xs = [] ys = [] + if j == n_total_samples: + n_smoothing = 1 for _ in range(n_smoothing): pts = [] for i in range(0, n_steps, 1): @@ -248,6 +261,15 @@ def get_pts(y, opt_flops, sampling_flops): y_mean = np.mean(y_interp, axis=0) # Filter out leading zeros nonzero_mask = y_mean > 0 + + # Store data for bar charts + frontier_data[j] = { + 'x': x_interp[nonzero_mask], + 'y': y_mean[nonzero_mask], + 'color': color, + 'max_asr': np.max(y_mean[nonzero_mask]) if np.any(nonzero_mask) else 0 + } + plt.plot( x_interp[nonzero_mask], y_mean[nonzero_mask], @@ -259,7 +281,7 @@ def get_pts(y, opt_flops, sampling_flops): ) if plot_envelope: - n_smoothing = 50 + n_smoothing = n_total_samples y_interps = [] for j in range(1, n_total_samples+1): xs = [] @@ -290,60 +312,188 @@ def get_pts(y, opt_flops, sampling_flops): title_suffix = "" - y = np.array(baseline[metric]) # (B, n_steps, n_samples) - if threshold is not None: - y = y > threshold - - baseline_flops_sampling = np.array(baseline["flops_sampling"]) - if "flops" in baseline: - baseline_flops_optimization = np.array(baseline["flops"]) # (B, n_steps) - else: - baseline_flops_optimization = np.zeros_like(baseline_flops_sampling) # (B, n_steps) - if flops_per_step is not None: - baseline_flops_optimization += flops_per_step(np.arange(baseline_flops_optimization.shape[1])) - - if y is not None: - title_suffix = f" ({n_runs}, {y.shape[0]})" - if verbose: - print(n_runs, "for main") - print(y.shape[0], "for baseline") - n_runs, n_steps, n_total_samples = y.shape - assert n_total_samples == 1 + # Handle baseline data + baseline_max_asr = 0 + baseline_frontier_data = None - rng = np.random.default_rng() - pts = [] # (cost, step, n_samples, mean_p) - for i in range(0, n_steps, 1): - for j in range(1, n_total_samples + 1, 1): - pts.append(subsample_and_aggregate(i, j, cumulative, y, baseline_flops_optimization, baseline_flops_sampling, rng)) + if baseline is not None: + y_baseline = np.array(baseline[metric]) # (B, n_steps, n_samples) + if threshold is not None: + y_baseline = y_baseline > threshold - pts = np.asarray(pts) - cost, step_idx, n_samp, mean_p = pts.T + baseline_flops_sampling = np.array(baseline["flops_sampling"]) + if "flops" in baseline: + baseline_flops_optimization = np.array(baseline["flops"]) # (B, n_steps) + else: + baseline_flops_optimization = np.zeros_like(baseline_flops_sampling) # (B, n_steps) + if flops_per_step is not None: + baseline_flops_optimization += flops_per_step(np.arange(baseline_flops_optimization.shape[1])) - # ---------- scatter all points ---------- - # sc = plt.scatter(cost, mean_p, c="r", alpha=0.35, s=4) + if y_baseline is not None: + title_suffix = f" ({n_runs}, {y_baseline.shape[0]})" + if verbose: + print(n_runs, "for main") + print(y_baseline.shape[0], "for baseline") + n_runs_baseline, n_steps_baseline, n_total_samples_baseline = y_baseline.shape + assert n_total_samples_baseline == 1 + + rng = np.random.default_rng() + pts = [] # (cost, step, n_samples, mean_p) + for i in range(0, n_steps_baseline, 1): + for j in range(1, n_total_samples_baseline + 1, 1): + pts.append(subsample_and_aggregate(i, j, cumulative, y_baseline, baseline_flops_optimization, baseline_flops_sampling, rng)) + + pts = np.asarray(pts) + cost_baseline, step_idx_baseline, n_samp_baseline, mean_p_baseline = pts.T + + # ---------- overlay Pareto frontiers ---------- + if plot_frontiers or plot_envelope: + mask = n_samp_baseline == 1 + fx, fy = _pareto_frontier(cost_baseline[mask], mean_p_baseline[mask], method=frontier_method) + y_interp_baseline = interp1d(fx, fy, kind="previous", bounds_error=False, fill_value=(0, max(fy)))(x_interp) + nonzero_mask_baseline = y_interp_baseline > 0 + + # Store baseline data for bar charts + baseline_max_asr = np.max(y_interp_baseline[nonzero_mask_baseline]) if np.any(nonzero_mask_baseline) else 0 + baseline_frontier_data = { + 'x': x_interp[nonzero_mask_baseline], + 'y': y_interp_baseline[nonzero_mask_baseline], + 'max_asr': baseline_max_asr + } + + plt.plot( + x_interp[nonzero_mask_baseline], + y_interp_baseline[nonzero_mask_baseline], + marker="o", + linewidth=1.8, + markersize=2, + label=f"greedy", + color="r", + ) - # ---------- overlay Pareto frontiers ---------- - if plot_frontiers or plot_envelope: - mask = n_samp == 1 - fx, fy = _pareto_frontier(cost[mask], mean_p[mask], method=frontier_method) - y_interp = interp1d(fx, fy, kind="previous", bounds_error=False, fill_value=(0, max(fy)))(x_interp) - nonzero_mask = y_interp > 0 - plt.plot( - x_interp[nonzero_mask], - y_interp[nonzero_mask], - marker="o", - linewidth=1.8, - markersize=2, - label=f"greedy", - color="r", - ) plt.title(title + title_suffix) - plt.grid(True, linewidth=0.3) + plt.grid(True, alpha=0.3) plt.ylim(bottom=0) plt.xscale(x_scale) plt.legend(title="Frontiers", loc="upper left" if x_scale == "log" else "lower right") + + # ---------- Bar Chart 1: Max ASR Comparison (Vertical Slice) ---------- + ax2 = plt.subplot2grid((1, 4), (0, 2)) + + methods = [] + max_asrs = [] + colors = [] + + # Add baseline (delta = 0 for baseline) + if baseline_frontier_data is not None: + methods.append("Greedy") + max_asrs.append(0.0) # Delta from itself is 0 + colors.append("red") + + # Add sampling methods (calculate delta from baseline) + for j in sample_levels_to_plot: + if j in frontier_data: + methods.append(f"{j} samples") + delta_asr = frontier_data[j]['max_asr'] - baseline_max_asr if baseline_frontier_data is not None else 0 + max_asrs.append(delta_asr) + colors.append(frontier_data[j]['color']) + + if methods: + bars = plt.bar(methods, max_asrs, color=colors, alpha=0.7, edgecolor='black') + plt.xlabel("Method", fontsize=12) + if threshold is None: + plt.ylabel(r"$\Delta$ $p_{harmful}$ - greedy", fontsize=12) + else: + plt.ylabel(r"Max $p_{harmful}$" + f" (threshold: {threshold})", fontsize=12) + plt.title(r"Max $p_{harmful}$ Comparison", fontsize=12) + plt.xticks(rotation=45, ha='right') + plt.grid(True, alpha=0.3, axis='y') + # Increase ylim by 2% on top and bottom + ymin, ymax = plt.ylim() + margin = (ymax - ymin) * 0.03 + plt.ylim(ymin - margin, ymax + margin) + + # ----- add labels with a 4-point gap ----- + for bar, value in zip(bars, max_asrs): + # choose label position: above for positive, below for negative + offset_pt = 4 # visual gap in points + y = bar.get_height() + va = 'bottom' if y >= 0 else 'top' + offset = (0, offset_pt if y >= 0 else -offset_pt) + + ax2.annotate(f'{value:.3f}', + xy=(bar.get_x() + bar.get_width()/2, y), + xytext=offset, + textcoords='offset points', + ha='center', va=va, fontsize=10) + + # ---------- Bar Chart 2: FLOPS Efficiency to Reach Greedy ASR (Horizontal Slice) ---------- + ax3 = plt.subplot2grid((1, 4), (0, 3)) + + if baseline_frontier_data is not None and baseline_max_asr > 0: + methods_flops = [] + flops_required = [] + colors_flops = [] + + # Find FLOPS required to reach baseline ASR for each sampling method + target_asr = baseline_max_asr + + for j in sample_levels_to_plot: + if j in frontier_data: + # Find the minimum FLOPS where ASR >= target_asr + y_vals = frontier_data[j]['y'] + x_vals = frontier_data[j]['x'] + + # Find points where ASR >= target_asr + valid_indices = y_vals >= target_asr + if np.any(valid_indices): + min_flops = np.min(x_vals[valid_indices]) + methods_flops.append(f"{j} samples") + flops_required.append(min_flops) + colors_flops.append(frontier_data[j]['color']) + + # Add baseline (find minimum FLOPS where it reaches target ASR) + if baseline_frontier_data['x'].size > 0: + # Find the minimum FLOPS where baseline ASR >= target_asr + baseline_y_vals = baseline_frontier_data['y'] + baseline_x_vals = baseline_frontier_data['x'] + baseline_valid_indices = baseline_y_vals >= target_asr + if np.any(baseline_valid_indices): + baseline_flops = np.min(baseline_x_vals[baseline_valid_indices]) + else: + # Fallback to minimum FLOPS if no point reaches target ASR + baseline_flops = np.min(baseline_x_vals) + methods_flops.insert(0, "Greedy") + flops_required.insert(0, baseline_flops) + colors_flops.insert(0, "red") + + if methods_flops: + bars = plt.bar(methods_flops, flops_required, color=colors_flops, alpha=0.7, edgecolor='black') + plt.xlabel("Method", fontsize=12) + plt.ylabel("FLOPS Required", fontsize=12) + plt.title(r"FLOPS to Reach Greedy $p_{harmful}$" + f" ( = {target_asr:.3f})", fontsize=12) + plt.xticks(rotation=45, ha='right') + plt.yscale('log') + plt.grid(True, alpha=0.3, axis='y') + # Increase ylim by 2% on top and bottom + ymin, ymax = plt.ylim() + plt.ylim(ymin, ymax * 1.1) + + # Add value labels on bars + # for bar, value in zip(bars, flops_required): + # plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() * 1.1, + # f'{value:.2e}', ha='center', va='bottom', fontsize=9, rotation=45) + # --- constant 5-point vertical gap --- + for bar, value in zip(bars, flops_required): + ax3.annotate(f'{value:.2e}', + xy=(bar.get_x() + bar.get_width()/2, value), # anchor at top of bar + xytext=(0, 5), # 5 points straight up + textcoords='offset points', + ha='center', va='bottom', rotation=45, fontsize=9) + plt.tight_layout() plt.savefig(f"evaluate/distributional_paper/pareto_plots/{title}.pdf") + plt.close() def flops_ratio_plot( @@ -356,6 +506,7 @@ def flops_ratio_plot( flops_per_step: int | None = None, threshold: float|None = None, color_scale: str = "linear", + color_by: str = "samples", verbose: bool = True, ): """ @@ -381,6 +532,8 @@ def flops_ratio_plot( Threshold for binary classification color_scale : str Color scale type ("linear" or "log") + color_by : str + What to color points by: "samples" (number of samples) or "total_flops" (total FLOP usage) verbose : bool Whether to print debug information """ @@ -427,44 +580,64 @@ def get_ratio_pts(y, opt_flops, sampling_flops): pts = get_ratio_pts(y, flops_optimization, flops_sampling) ratio, step_idx, n_samp, mean_p, opt_flop, sampling_flop = pts.T + # Calculate total FLOPS for coloring option + total_flop = opt_flop + sampling_flop + # Filter out infinite ratios for plotting finite_mask = np.isfinite(ratio) ratio_finite = ratio[finite_mask] mean_p_finite = mean_p[finite_mask] n_samp_finite = n_samp[finite_mask] + total_flop_finite = total_flop[finite_mask] plt.figure(figsize=(10, 6)) + # Choose color values based on color_by parameter + if color_by == "samples": + color_values = n_samp_finite + color_label = "Number of Samples" + elif color_by == "total_flops": + color_values = total_flop_finite + color_label = "Total FLOPS" + else: + raise ValueError(f"color_by must be 'samples' or 'total_flops', got '{color_by}'") + # Scatter plot if color_scale == "log": color_norm = LogNorm() + elif color_scale == "sqrt": + color_norm = PowerNorm(gamma=0.5) # Square root normalization else: color_norm = None - sc = plt.scatter(ratio_finite, mean_p_finite, c=n_samp_finite, cmap="viridis", alpha=0.3, s=15, norm=color_norm) - plt.colorbar(sc, label="Number of Samples") + sc = plt.scatter(ratio_finite, mean_p_finite, c=color_values, cmap="viridis", alpha=0.3, s=15, norm=color_norm) + if color_by != "samples": + plt.colorbar(sc, label=color_label) - # Highlight specific sample levels - cmap = plt.get_cmap("viridis") - if color_scale == "log": - norm = LogNorm(n_samp.min(), n_samp.max()) - else: - norm = plt.Normalize(n_samp.min(), n_samp.max()) + # Highlight specific sample levels (only when coloring by samples) + if color_by == "samples": + cmap = plt.get_cmap("viridis") + if color_scale == "log": + norm = LogNorm(n_samp.min(), n_samp.max()) + elif color_scale == "sqrt": + norm = PowerNorm(gamma=0.5, vmin=n_samp.min(), vmax=n_samp.max()) + else: + norm = plt.Normalize(n_samp.min(), n_samp.max()) - for j in sample_levels_to_plot: - mask = (n_samp == j) & finite_mask - if np.any(mask): - color = cmap(norm(j)) - plt.scatter(ratio[mask], mean_p[mask], - color=color, s=50, alpha=0.8, - edgecolors='black', linewidth=0.5, - label=f"{j} samples") + for j in sample_levels_to_plot: + mask = (n_samp == j) & finite_mask + if np.any(mask): + color = cmap(norm(j)) + plt.scatter(ratio[mask], mean_p[mask], + color=color, s=50, alpha=0.8, + edgecolors='black', linewidth=0.5, + label=f"{j} samples") plt.xlabel("Sampling FLOPS / Total FLOPS", fontsize=14) if threshold is None: plt.ylabel("Mean p_harmful", fontsize=14) else: - plt.ylabel(f"Mean ASR (threshold: {threshold})", fontsize=14) + plt.ylabel(f"Max ASR (threshold: {threshold})", fontsize=12) plt.grid(True, alpha=0.3) plt.title(title, fontsize=16) @@ -563,6 +736,24 @@ def get_ratio_pts(y, opt_flops, sampling_flops): "use_prefix_cache": True, }, )), + ("gcg", dict( + title_suffix="GCG 500", + cumulative=False, + sample_params=lambda: { + "generation_config": {"num_return_sequences": 500, "temperature": 0.7}, + "num_steps": 250, + "loss": "ce", + "token_selection": "default", + "use_prefix_cache": True, + }, + baseline_params=lambda: { + "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, + "num_steps": 250, + "loss": "ce", + "token_selection": "default", + "use_prefix_cache": True, + }, + )), ("gcg", dict( title_suffix="GCG Entropy Loss", cumulative=False, @@ -632,7 +823,7 @@ def get_ratio_pts(y, opt_flops, sampling_flops): ] METRIC = ("scores", "strong_reject", "p_harmful") -GROUP_BY = {("model",)} +GROUP_BY = {"model", "attack_params"} DATASET_IDX = list(range(50)) @@ -643,7 +834,7 @@ def run_attack( atk_name: str, cfg: dict, ): - print("Attack:", atk_name) + print("Attack:", atk_name, cfg["title_suffix"]) # ---------- helper to fetch data ---------- def fetch(attack: str, attack_params: dict): @@ -678,6 +869,7 @@ def fetch(attack: str, attack_params: dict): metric=METRIC, flops_per_step=lambda x: FLOPS_PER_STEP.get(atk_name, lambda x, c: 0)(x, num_model_params(model)), threshold=None, + color_scale="sqrt", ) @@ -722,17 +914,19 @@ def fetch(attack: str, attack_params: dict): metric=METRIC, flops_per_step=lambda x: FLOPS_PER_STEP.get(atk_name, lambda x, c: 0)(x, num_model_params(model)), threshold=None, + # color_by="total_flops", + color_scale="sqrt", ) # # Main loop ------------------------------------------------------------------------ -# for model_key, model_title in MODELS.items(): -# print("Model:", model_key) -# for atk_name, atk_cfg in ATTACKS: -# try: -# run_attack(model_key, model_title, atk_name, atk_cfg) -# except Exception as e: -# print(f"Error running attack {atk_name}, atk_cfg: {atk_cfg['title_suffix']}: {e}") +for model_key, model_title in MODELS.items(): + print("Model:", model_key) + for atk_name, atk_cfg in ATTACKS: + try: + run_attack(model_key, model_title, atk_name, atk_cfg) + except Exception as e: + print(f"Error running attack {atk_name}, atk_cfg: {atk_cfg['title_suffix']}: {e}") # FLOPS Ratio plots main loop ------------------------------------------------------ @@ -743,10 +937,10 @@ def fetch(attack: str, attack_params: dict): for model_key, model_title in MODELS.items(): print("Model:", model_key) for atk_name, atk_cfg in ATTACKS: - # try: + try: run_attack_flops_ratio(model_key, model_title, atk_name, atk_cfg) - # except Exception as e: - # print(f"Error running FLOPS ratio attack {atk_name}, atk_cfg: {atk_cfg['title_suffix']}: {e}") + except Exception as e: + print(f"Error running FLOPS ratio attack {atk_name}, atk_cfg: {atk_cfg['title_suffix']}: {e}") # Helper --------------------------------------------------------------------------- @@ -822,7 +1016,7 @@ def fetch(attack: str, attack_params: dict): ) paths = get_filtered_and_grouped_paths(filter_by, GROUP_BY) results = collect_results(paths, infer_sampling_flops=True) - assert len(results) == 1, len(results) + assert len(results) == 1, f"Should only have exactly one type of result, got {len(results)}, {list(results.keys())}" return list(results.values())[0] # ---------- sampled run ---------- From 924b771cb45e33199f4f3c581dbd89b43b953e09 Mon Sep 17 00:00:00 2001 From: Tim Beyer Date: Fri, 30 May 2025 10:55:08 +0200 Subject: [PATCH 06/13] better run_sampling.py no longer overwrites previous items --- run_sampling.py | 40 +++++++++++++++++++++++++++++++++++----- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/run_sampling.py b/run_sampling.py index d429201..b44ddee 100644 --- a/run_sampling.py +++ b/run_sampling.py @@ -110,6 +110,10 @@ def main(cfg: DictConfig) -> None: print(filter_by) # get all paths paths = get_filtered_and_grouped_paths(filter_by, None)[("all", )] + if not paths: + logging.info("No paths found, exiting") + return + n = 0 pbar = tqdm(paths, file=sys.stdout) backend = "hf" @@ -159,12 +163,13 @@ def main(cfg: DictConfig) -> None: model, tokenizer, tokens, - initial_batch_size=len(tokens), + initial_batch_size=None, num_return_sequences=n_to_generate, max_new_tokens=gen_config["max_new_tokens"], temperature=gen_config["temperature"], top_p=gen_config["top_p"], - top_k=gen_config["top_k"] + top_k=gen_config["top_k"], + verbose=True ) # (n_steps, n_to_generate) # have to also add classifier scores for new completions @@ -185,7 +190,7 @@ def main(cfg: DictConfig) -> None: else: modified_prompt.append({"role": "assistant", "content": completion}) modified_prompts.append(modified_prompt) - results = judge(modified_prompts) + results = judge(modified_prompts, verbose=True) if all(r is None for r in results): continue i = 0 @@ -202,11 +207,36 @@ def main(cfg: DictConfig) -> None: pbar.set_description(f"{len(subrun['steps']) * n_to_generate} | {n} total") + log_dir = os.path.join(cfg.save_dir, date_time_string) + i = 0 + while os.path.exists(os.path.join(log_dir, str(i), f"run.json")): + i += 1 + log_file = os.path.join(log_dir, str(i), f"run.json") + os.makedirs(os.path.dirname(log_file), exist_ok=True) + json.dump(attack_run, open(log_file, "w"), indent=2, cls=CompactJSONEncoder) - json.dump(attack_run, open(path, "w"), indent=2, cls=CompactJSONEncoder) db = get_mongodb_connection() collection = db.runs - collection.update_many({"log_file": path}, {"$set": {"config.attack_params.generation_config.num_return_sequences": cfg.num_return_sequences}}) + + # Find all entries that match the original log_file path + matching_entries = list(collection.find({"log_file": path})) + + # Create new entries with updated log_file and num_return_sequences + new_entries = [] + for entry in matching_entries: + new_entry = entry.copy() + # Remove the _id field so MongoDB will generate a new one + if "_id" in new_entry: + del new_entry["_id"] + # Update the log_file to the new path + new_entry["log_file"] = log_file + # Update the num_return_sequences in the config + new_entry["config"]["attack_params"]["generation_config"]["num_return_sequences"] = cfg.num_return_sequences + new_entries.append(new_entry) + + # Insert the new entries if any were found + if new_entries: + collection.insert_many(new_entries) except Exception as e: logging.error(f"Error in {path}. Original exception: {e}") raise Exception(f"Error in {path}. Original exception: {e}") from e From f681f93b87d5f723a3e3cd4002cd8885e98cc23c Mon Sep 17 00:00:00 2001 From: Tim Beyer Date: Tue, 3 Jun 2025 16:34:02 +0200 Subject: [PATCH 07/13] various improvements to plots --- conf/hydra/launcher/a100h100.yaml | 2 +- conf/sampling.yaml | 7 +- evaluate/distributional_paper/make_plots.py | 1215 ++++++++++++------- purge_orphans.py | 4 + src/attacks/direct.py | 1 + src/io_utils.py | 32 +- src/judges.py | 5 +- src/lm_utils.py | 11 +- 8 files changed, 815 insertions(+), 462 deletions(-) create mode 100644 purge_orphans.py diff --git a/conf/hydra/launcher/a100h100.yaml b/conf/hydra/launcher/a100h100.yaml index 59a3c53..512cf87 100644 --- a/conf/hydra/launcher/a100h100.yaml +++ b/conf/hydra/launcher/a100h100.yaml @@ -13,6 +13,6 @@ max_num_timeout: 0 additional_parameters: {} setup: [] gres: gpu:1 -mem_gb: 32 +mem_gb: 64 cpus_per_task: 4 partition: gpu_a100,gpu_h100 \ No newline at end of file diff --git a/conf/sampling.yaml b/conf/sampling.yaml index 911f803..7e0b936 100644 --- a/conf/sampling.yaml +++ b/conf/sampling.yaml @@ -13,16 +13,19 @@ hydra: dir: ${root_dir}/multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}/ name: testing -num_return_sequences: 50 # we generate completions such that the final result has num_return_sequences completions +save_dir: ${root_dir}/outputs/ +num_return_sequences: 100 # we generate completions such that the final result has num_return_sequences completions filter_by: model: meta-llama/Meta-Llama-3.1-8B-Instruct attack: gcg attack_params: num_steps: 250 + loss: ce + token_selection: default generation_config: num_return_sequences: 50 dataset_params: - idx: "list(range(0,20))" + idx: 0 diff --git a/evaluate/distributional_paper/make_plots.py b/evaluate/distributional_paper/make_plots.py index d13cd0a..a13148a 100644 --- a/evaluate/distributional_paper/make_plots.py +++ b/evaluate/distributional_paper/make_plots.py @@ -4,10 +4,12 @@ import numpy as np import pandas as pd +import seaborn as sns from scipy.stats import rankdata from scipy.interpolate import interp1d from matplotlib.colors import LogNorm, PowerNorm - +from scipy.interpolate import griddata +from brokenaxes import brokenaxes pd.set_option("display.max_colwidth", None) pd.set_option("display.max_columns", None) @@ -35,6 +37,7 @@ def generate_sample_sizes(total_samples: int) -> tuple[int, ...]: return tuple(result) power += 1 + def _dominance_frontier(xs: np.ndarray, ys: np.ndarray) -> tuple[np.ndarray, np.ndarray]: """ Return the non-dominated (Pareto-optimal) points, ordered by cost. @@ -123,6 +126,111 @@ def _pareto_frontier(xs: np.ndarray, raise ValueError(f"Unknown frontier method '{method}'") +# ------------------------------------------------------------------ +# Common helper functions +# ------------------------------------------------------------------ +def fetch_data(model: str, attack: str, attack_params: dict, dataset_idx: list[int], group_by: set[str]): + """Common data fetching logic used across all plotting functions.""" + filter_by = dict( + model=model, + attack=attack, + attack_params=attack_params, + dataset_params={"idx": dataset_idx}, + ) + paths = get_filtered_and_grouped_paths(filter_by, group_by) + + results = collect_results(paths, infer_sampling_flops=True) + assert len(results) == 1, f"Should only have exactly one type of result, got {len(results)}, {list(results.keys())}" + return list(results.values())[0] + + +def preprocess_data(results: dict[str, np.ndarray], metric: tuple[str, ...], threshold: float|None, flops_per_step_fn): + """Common data preprocessing logic.""" + y = np.array(results[metric]) # (B, n_steps, n_samples) + if threshold is not None: + y = y > threshold + + flops_sampling_prefill_cache = np.array(results["flops_sampling_prefill_cache"]) # (B, n_steps) + flops_sampling_generation = np.array(results["flops_sampling_generation"]) # (B, n_steps) + + if "flops" in results: + flops_optimization = np.array(results["flops"]) # (B, n_steps) + else: + flops_optimization = np.zeros_like(flops_sampling_generation) # (B, n_steps) + if flops_per_step_fn is not None: + flops_optimization += flops_per_step_fn(np.arange(flops_optimization.shape[1])) + + return y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation + + +def subsample_and_aggregate(step_idx: int, sample_idx: int, cumulative: bool, y: np.ndarray, + opt_flops: np.ndarray, sampling_prefill_flops: np.ndarray, + sampling_generation_flops: np.ndarray, rng: np.random.Generator, + return_ratio: bool = False, n_smoothing: int = 1): + """ + Unified subsampling and aggregation function. + + Parameters + ---------- + return_ratio : bool + If True, returns sampling ratio instead of total cost + n_smoothing : int + Number of smoothing iterations for variance reduction + """ + n_runs, n_steps, n_total_samples = y.shape + opt_flop = np.mean(opt_flops[:, :step_idx+1].sum(axis=1)) + sampling_flop = np.mean(sampling_generation_flops[:, step_idx]) * sample_idx + np.mean(sampling_prefill_flops[:, step_idx]) + total_flop = opt_flop + sampling_flop + + # Calculate value with smoothing + values = [] + for _ in range(n_smoothing): + # + rng = np.random.default_rng(sample_idx+n_smoothing) + if cumulative and step_idx > 0: + samples_up_to_now = y[:, :step_idx, rng.choice(n_total_samples, size=1, replace=False)].max(axis=1)[:, 0] + samples_at_end = y[:, step_idx, rng.choice(n_total_samples, size=sample_idx, replace=False)].max(axis=-1) + values.append(np.stack([samples_up_to_now, samples_at_end], axis=1).max(axis=1).mean(axis=0)) + else: + values.append(y[:, step_idx, rng.choice(n_total_samples, size=sample_idx, replace=False)].max(axis=-1).mean(axis=0)) + + mean_value = np.mean(values) + + if return_ratio: + ratio = sampling_flop / (total_flop + 1e-9) + return (ratio, step_idx, sample_idx, mean_value, opt_flop, sampling_flop) + else: + return (total_flop, step_idx, sample_idx, mean_value) + + +def get_points(y: np.ndarray, opt_flops: np.ndarray, sampling_prefill_flops: np.ndarray, + sampling_generation_flops: np.ndarray, return_ratio: bool = False, + n_smoothing: int = 1, cumulative: bool = False): + """Generate points for plotting with optional ratio calculation.""" + n_runs, n_steps, total_samples = y.shape + rng = np.random.default_rng(42) # Fixed seed for reproducibility + pts = [] + + for j in range(1, total_samples + 1, 1): + for i in range(0, n_steps, 1): + pts.append(subsample_and_aggregate( + i, j, cumulative, y, opt_flops, sampling_prefill_flops, + sampling_generation_flops, rng, return_ratio, n_smoothing + )) + + return np.asarray(pts) + + +def setup_color_normalization(color_scale: str, values: np.ndarray): + """Setup color normalization based on scale type.""" + if color_scale == "log": + return LogNorm(values.min(), values.max()) + elif color_scale == "sqrt": + return PowerNorm(gamma=0.5, vmin=values.min(), vmax=values.max()) + else: + return plt.Normalize(values.min(), values.max()) + + def pareto_plot( results: dict[str,np.ndarray], baseline: dict[str,np.ndarray] | None = None, @@ -144,63 +252,20 @@ def pareto_plot( """ Scatter the full design-space AND overlay Pareto frontiers for selected sampling counts. - - Parameters - ---------- - data : (x, y) where x is ignored and y has shape (n_steps, total_samples) - Your original score tensor. - sampling_cost_factor : float, optional - Multiplies the sampling cost term j. - frontier_samples : tuple[int, ...], optional - Which n_sample values to draw Pareto lines for. - - Returns - ------- - None """ - y = np.array(results[metric]) # (B, n_steps, n_samples) - if threshold is not None: - y = y > threshold + y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation = preprocess_data( + results, metric, threshold, flops_per_step + ) n_runs, n_steps, n_total_samples = y.shape if sample_levels_to_plot is None: sample_levels_to_plot = generate_sample_sizes(n_total_samples) - flops_sampling = np.array(results["flops_sampling"]) # (B, n_steps) - if "flops" in results: - flops_optimization = np.array(results["flops"]) # (B, n_steps) - else: - flops_optimization = np.zeros_like(flops_sampling) # (B, n_steps) - if flops_per_step is not None: - flops_optimization += flops_per_step(np.arange(flops_optimization.shape[1])) - - - def subsample_and_aggregate(step_idx, sample_idx, cumulative, y, opt_flops, sampling_flops, rng): - n_runs, n_steps, n_total_samples = y.shape - opt_flop = np.mean(opt_flops[:, :step_idx+1].sum(axis=1)) - sampling_flop = np.mean(sampling_flops[:, step_idx]) * sample_idx - if cumulative and step_idx > 0: - samples_at_end = y[:, step_idx, rng.choice(n_total_samples, size=sample_idx, replace=False)].max(axis=-1) - samples_up_to_now = y[:, :step_idx, rng.choice(n_total_samples, size=1, replace=False)].max(axis=1)[:, 0] - values = np.stack([samples_up_to_now, samples_at_end], axis=1).max(axis=1) - return (opt_flop + sampling_flop, step_idx, sample_idx, values.mean(0)) - return (opt_flop + sampling_flop, step_idx, sample_idx, y[:, step_idx, rng.choice(n_total_samples, size=sample_idx, replace=False)].max(axis=-1).mean(axis=0)) - - - def get_pts(y, opt_flops, sampling_flops): - n_runs, n_steps, total_samples = y.shape - rng = np.random.default_rng() - pts = [] # (cost, step, n_samples, mean_p) - for j in range(1, total_samples + 1, 1): - for i in range(0, n_steps, 1): - pts.append(subsample_and_aggregate(i, j, cumulative, y, opt_flops, sampling_flops, rng)) - pts = np.asarray(pts) - return pts - - pts = get_pts(y, flops_optimization, flops_sampling) + pts = get_points(y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation, + return_ratio=False, cumulative=cumulative) cost, step_idx, n_samp, mean_p = pts.T max_cost = max(cost) if x_scale == "log": - x_interp = np.logspace(11, np.log10(max_cost+1), n_x_points) + x_interp = np.logspace(11, np.log10(max_cost)+0.001, n_x_points) else: x_interp = np.linspace(0, max_cost+1, n_x_points) @@ -211,28 +276,17 @@ def get_pts(y, opt_flops, sampling_flops): ax1 = plt.subplot2grid((1, 4), (0, 0), colspan=2) # ---------- scatter all points ---------- + color_norm = setup_color_normalization(color_scale, n_samp) if plot_points: - if color_scale == "log": - color_norm = LogNorm() - elif color_scale == "sqrt": - color_norm = PowerNorm(gamma=0.5) # Square root normalization - else: - color_norm = None sc = plt.scatter(cost, mean_p, c=n_samp, cmap="viridis", alpha=0.15, s=3, norm=color_norm) plt.xlabel("Cost (FLOPS (optimization + sampling))", fontsize=14) if threshold is None: - plt.ylabel("Mean p_harmful", fontsize=14) + plt.ylabel(r"$\overline{p_{harmful}}$", fontsize=14) else: - plt.ylabel(f"Max ASR (threshold: {threshold})", fontsize=12) + plt.ylabel(r"$\overline{{ASR}}\quad (p_{{harmful}} \geq {threshold})$".format(threshold=threshold), fontsize=14) # ---------- overlay Pareto frontiers ---------- cmap = plt.get_cmap("viridis") - if color_scale == "log": - norm = LogNorm(n_samp.min(), n_samp.max()) - elif color_scale == "sqrt": - norm = PowerNorm(gamma=0.5, vmin=n_samp.min(), vmax=n_samp.max()) - else: - norm = plt.Normalize(n_samp.min(), n_samp.max()) rng = np.random.default_rng() n_smoothing = 50 @@ -247,7 +301,7 @@ def get_pts(y, opt_flops, sampling_flops): for _ in range(n_smoothing): pts = [] for i in range(0, n_steps, 1): - pts.append(subsample_and_aggregate(i, j, cumulative, y, flops_optimization, flops_sampling, rng)) + pts.append(subsample_and_aggregate(i, j, cumulative, y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation, rng)) pts = np.asarray(pts) cost, _, _, mean_p = pts.T @@ -257,7 +311,7 @@ def get_pts(y, opt_flops, sampling_flops): ys.append(fy) y_interp = [interp1d(x_, y_, kind="previous", bounds_error=False, fill_value=(0, max(y_)))(x_interp) for x_, y_ in zip(xs, ys)] - color = cmap(norm(j)) + color = cmap(color_norm(j)) y_mean = np.mean(y_interp, axis=0) # Filter out leading zeros nonzero_mask = y_mean > 0 @@ -289,7 +343,7 @@ def get_pts(y, opt_flops, sampling_flops): for n in range(n_smoothing): pts = [] for i in range(0, n_steps, 1): - pts.append(subsample_and_aggregate(i, j, cumulative, y, flops_optimization, flops_sampling, rng)) + pts.append(subsample_and_aggregate(i, j, cumulative, y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation, rng)) pts = np.asarray(pts) cost, step_idx, n_samp, mean_p = pts.T @@ -307,7 +361,7 @@ def get_pts(y, opt_flops, sampling_flops): # Filter out leading zeros nonzero_mask = y_envelope > 0 - color = [cmap(norm(argmax[i])) for i in range(len(argmax)) if nonzero_mask[i]] + color = [cmap(color_norm(argmax[i])) for i in range(len(argmax)) if nonzero_mask[i]] plt.scatter(x_interp[nonzero_mask], y_envelope[nonzero_mask], c=color, s=2) title_suffix = "" @@ -317,17 +371,9 @@ def get_pts(y, opt_flops, sampling_flops): baseline_frontier_data = None if baseline is not None: - y_baseline = np.array(baseline[metric]) # (B, n_steps, n_samples) - if threshold is not None: - y_baseline = y_baseline > threshold - - baseline_flops_sampling = np.array(baseline["flops_sampling"]) - if "flops" in baseline: - baseline_flops_optimization = np.array(baseline["flops"]) # (B, n_steps) - else: - baseline_flops_optimization = np.zeros_like(baseline_flops_sampling) # (B, n_steps) - if flops_per_step is not None: - baseline_flops_optimization += flops_per_step(np.arange(baseline_flops_optimization.shape[1])) + y_baseline, baseline_flops_optimization, baseline_flops_sampling_prefill_cache, baseline_flops_sampling_generation = preprocess_data( + baseline, metric, threshold, flops_per_step + ) if y_baseline is not None: title_suffix = f" ({n_runs}, {y_baseline.shape[0]})" @@ -337,13 +383,8 @@ def get_pts(y, opt_flops, sampling_flops): n_runs_baseline, n_steps_baseline, n_total_samples_baseline = y_baseline.shape assert n_total_samples_baseline == 1 - rng = np.random.default_rng() - pts = [] # (cost, step, n_samples, mean_p) - for i in range(0, n_steps_baseline, 1): - for j in range(1, n_total_samples_baseline + 1, 1): - pts.append(subsample_and_aggregate(i, j, cumulative, y_baseline, baseline_flops_optimization, baseline_flops_sampling, rng)) - - pts = np.asarray(pts) + pts = get_points(y_baseline, baseline_flops_optimization, baseline_flops_sampling_prefill_cache, + baseline_flops_sampling_generation, return_ratio=False, cumulative=cumulative) cost_baseline, step_idx_baseline, n_samp_baseline, mean_p_baseline = pts.T # ---------- overlay Pareto frontiers ---------- @@ -393,19 +434,19 @@ def get_pts(y, opt_flops, sampling_flops): # Add sampling methods (calculate delta from baseline) for j in sample_levels_to_plot: if j in frontier_data: - methods.append(f"{j} samples") + methods.append(f"{j} samples" if j != 1 else "1 sample") delta_asr = frontier_data[j]['max_asr'] - baseline_max_asr if baseline_frontier_data is not None else 0 max_asrs.append(delta_asr) colors.append(frontier_data[j]['color']) if methods: bars = plt.bar(methods, max_asrs, color=colors, alpha=0.7, edgecolor='black') - plt.xlabel("Method", fontsize=12) if threshold is None: - plt.ylabel(r"$\Delta$ $p_{harmful}$ - greedy", fontsize=12) + plt.ylabel(r"$\Delta$ $p_{harmful}$", fontsize=14) + plt.title(r"$p_{harmful}$ vs. \#samples", fontsize=14) else: - plt.ylabel(r"Max $p_{harmful}$" + f" (threshold: {threshold})", fontsize=12) - plt.title(r"Max $p_{harmful}$ Comparison", fontsize=12) + plt.title(r"$\overline{{ASR}}$ vs. \#samples".format(threshold=threshold), fontsize=14) + plt.ylabel(r"$\Delta$ $\overline{{ASR}}\quad (p_{{harmful}} \geq {threshold})$".format(threshold=threshold), fontsize=14) plt.xticks(rotation=45, ha='right') plt.grid(True, alpha=0.3, axis='y') # Increase ylim by 2% on top and bottom @@ -469,20 +510,17 @@ def get_pts(y, opt_flops, sampling_flops): if methods_flops: bars = plt.bar(methods_flops, flops_required, color=colors_flops, alpha=0.7, edgecolor='black') - plt.xlabel("Method", fontsize=12) - plt.ylabel("FLOPS Required", fontsize=12) - plt.title(r"FLOPS to Reach Greedy $p_{harmful}$" + f" ( = {target_asr:.3f})", fontsize=12) + plt.ylabel("FLOPS", fontsize=12) + plt.title(r"FLOPS to reach same $p_{harmful}$ as greedy" + f" ( = {target_asr:.3f})", fontsize=12) plt.xticks(rotation=45, ha='right') plt.yscale('log') plt.grid(True, alpha=0.3, axis='y') - # Increase ylim by 2% on top and bottom + # Increase ylim by ymin, ymax = plt.ylim() - plt.ylim(ymin, ymax * 1.1) + import math + margin = ((math.log10(ymax) - math.log10(ymin)) * 0.2) + plt.ylim(ymin, ymax * (1+margin)) - # Add value labels on bars - # for bar, value in zip(bars, flops_required): - # plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() * 1.1, - # f'{value:.2e}', ha='center', va='bottom', fontsize=9, rotation=45) # --- constant 5-point vertical gap --- for bar, value in zip(bars, flops_required): ax3.annotate(f'{value:.2e}', @@ -511,73 +549,16 @@ def flops_ratio_plot( ): """ Plot p_harmful vs the ratio of optimization FLOPS to sampling FLOPS. - - Parameters - ---------- - results : dict - Results dictionary containing metric data and FLOPS information - baseline : dict, optional - Baseline results for comparison - title : str - Plot title - sample_levels_to_plot : tuple[int, ...], optional - Which sample levels to highlight - metric : tuple[str, ...] - Metric to plot on y-axis - cumulative : bool - Whether to use cumulative aggregation - flops_per_step : callable, optional - Function to compute optimization FLOPS per step - threshold : float, optional - Threshold for binary classification - color_scale : str - Color scale type ("linear" or "log") - color_by : str - What to color points by: "samples" (number of samples) or "total_flops" (total FLOP usage) - verbose : bool - Whether to print debug information """ - y = np.array(results[metric]) # (B, n_steps, n_samples) - if threshold is not None: - y = y > threshold + y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation = preprocess_data( + results, metric, threshold, flops_per_step + ) n_runs, n_steps, n_total_samples = y.shape if sample_levels_to_plot is None: sample_levels_to_plot = generate_sample_sizes(n_total_samples) - flops_sampling = np.array(results["flops_sampling"]) # (B, n_steps) - if "flops" in results: - flops_optimization = np.array(results["flops"]) # (B, n_steps) - else: - flops_optimization = np.zeros_like(flops_sampling) # (B, n_steps) - if flops_per_step is not None: - flops_optimization += flops_per_step(np.arange(flops_optimization.shape[1])) - - def subsample_and_aggregate_ratio(step_idx, sample_idx, cumulative, y, opt_flops, sampling_flops, rng): - n_runs, n_steps, n_total_samples = y.shape - opt_flop = np.mean(opt_flops[:, :step_idx+1].sum(axis=1)) - sampling_flop = np.mean(sampling_flops[:, step_idx]) * sample_idx - - # Calculate ratio (sampling / total), handle division by zero - ratio = sampling_flop / (sampling_flop + opt_flop + 1e-9) - - if cumulative and step_idx > 0: - samples_at_end = y[:, step_idx, rng.choice(n_total_samples, size=sample_idx, replace=False)].max(axis=-1) - samples_up_to_now = y[:, :step_idx, rng.choice(n_total_samples, size=1, replace=False)].max(axis=1)[:, 0] - values = np.stack([samples_up_to_now, samples_at_end], axis=1).max(axis=1) - return (ratio, step_idx, sample_idx, values.mean(0), opt_flop, sampling_flop) - return (ratio, step_idx, sample_idx, y[:, step_idx, rng.choice(n_total_samples, size=sample_idx, replace=False)].max(axis=-1).mean(axis=0), opt_flop, sampling_flop) - - def get_ratio_pts(y, opt_flops, sampling_flops): - n_runs, n_steps, n_total_samples = y.shape - rng = np.random.default_rng() - pts = [] # (ratio, step, n_samples, mean_p, opt_flop, sampling_flop) - for j in range(1, n_total_samples + 1, 1): - for i in range(0, n_steps, 1): - pts.append(subsample_and_aggregate_ratio(i, j, cumulative, y, opt_flops, sampling_flops, rng)) - pts = np.asarray(pts) - return pts - - pts = get_ratio_pts(y, flops_optimization, flops_sampling) + pts = get_points(y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation, + return_ratio=True, cumulative=cumulative) ratio, step_idx, n_samp, mean_p, opt_flop, sampling_flop = pts.T # Calculate total FLOPS for coloring option @@ -592,46 +573,35 @@ def get_ratio_pts(y, opt_flops, sampling_flops): plt.figure(figsize=(10, 6)) - # Choose color values based on color_by parameter - if color_by == "samples": - color_values = n_samp_finite - color_label = "Number of Samples" - elif color_by == "total_flops": - color_values = total_flop_finite - color_label = "Total FLOPS" - else: - raise ValueError(f"color_by must be 'samples' or 'total_flops', got '{color_by}'") + # Create dual color encoding: hue based on samples, strength based on total FLOPS + # Normalize sample counts for hue + sample_norm = setup_color_normalization("linear", n_samp_finite) + # Normalize total FLOPS for alpha/strength + flops_norm = setup_color_normalization(color_scale, total_flop_finite) - # Scatter plot - if color_scale == "log": - color_norm = LogNorm() - elif color_scale == "sqrt": - color_norm = PowerNorm(gamma=0.5) # Square root normalization - else: - color_norm = None - - sc = plt.scatter(ratio_finite, mean_p_finite, c=color_values, cmap="viridis", alpha=0.3, s=15, norm=color_norm) - if color_by != "samples": - plt.colorbar(sc, label=color_label) - - # Highlight specific sample levels (only when coloring by samples) - if color_by == "samples": - cmap = plt.get_cmap("viridis") - if color_scale == "log": - norm = LogNorm(n_samp.min(), n_samp.max()) - elif color_scale == "sqrt": - norm = PowerNorm(gamma=0.5, vmin=n_samp.min(), vmax=n_samp.max()) - else: - norm = plt.Normalize(n_samp.min(), n_samp.max()) + # Get base colors from viridis colormap based on sample count + cmap = plt.get_cmap("viridis") + base_colors = cmap(sample_norm(n_samp_finite)) - for j in sample_levels_to_plot: - mask = (n_samp == j) & finite_mask - if np.any(mask): - color = cmap(norm(j)) - plt.scatter(ratio[mask], mean_p[mask], - color=color, s=50, alpha=0.8, - edgecolors='black', linewidth=0.5, - label=f"{j} samples") + # Scatter plot with dual color encoding + sc = plt.scatter(ratio_finite, mean_p_finite, c=base_colors, s=15, alpha=0.05) + + # Create custom colorbar for samples (hue) + sm = plt.cm.ScalarMappable(cmap=cmap, norm=sample_norm) + sm.set_array([]) + + # Highlight specific sample levels + for j in sample_levels_to_plot: + mask = (n_samp == j) & finite_mask + if np.any(mask): + # Use the same dual coloring for highlighted points + highlight_base_color = cmap(sample_norm(j)) + # Create colors for this sample level + + plt.scatter(ratio[mask], mean_p[mask], + c=highlight_base_color, s=50, alpha=0.9, + edgecolors='black', linewidth=0.5, + label=f"{j} samples") plt.xlabel("Sampling FLOPS / Total FLOPS", fontsize=14) if threshold is None: @@ -644,26 +614,73 @@ def get_ratio_pts(y, opt_flops, sampling_flops): # Add baseline if provided if baseline is not None: - y_baseline = np.array(baseline[metric]) - if threshold is not None: - y_baseline = y_baseline > threshold - - baseline_flops_sampling = np.array(baseline["flops_sampling"]) - if "flops" in baseline: - baseline_flops_optimization = np.array(baseline["flops"]) - else: - baseline_flops_optimization = np.zeros_like(baseline_flops_sampling) - if flops_per_step is not None: - baseline_flops_optimization += flops_per_step(np.arange(baseline_flops_optimization.shape[1])) + y_baseline, baseline_flops_optimization, baseline_flops_sampling_prefill_cache, baseline_flops_sampling_generation = preprocess_data( + baseline, metric, threshold, flops_per_step + ) - baseline_pts = get_ratio_pts(y_baseline, baseline_flops_optimization, baseline_flops_sampling) + baseline_pts = get_points(y_baseline, baseline_flops_optimization, baseline_flops_sampling_prefill_cache, + baseline_flops_sampling_generation, return_ratio=True, cumulative=cumulative) baseline_ratio, _, baseline_n_samp, baseline_mean_p, _, _ = baseline_pts.T - baseline_finite_mask = np.isfinite(baseline_ratio) - plt.scatter(baseline_ratio[baseline_finite_mask], baseline_mean_p[baseline_finite_mask], - color="red", s=50, alpha=0.8, marker="^", - edgecolors='black', linewidth=0.5, label="Greedy") - + baseline_finite_mask = np.isfinite(baseline_ratio) & np.isfinite(baseline_mean_p) + if np.any(baseline_finite_mask): + # For baseline, just plot the raw ratios + plt.scatter(baseline_ratio[baseline_finite_mask], baseline_mean_p[baseline_finite_mask], + color="red", s=60, alpha=0.9, marker="^", + edgecolors='black', linewidth=0.5, label="Greedy", zorder=6) + + # Add subtle iso-FLOP lines (fitted quadratics) + if n_total_samples == 500: # Only if we have enough data points + # Select 5 FLOP levels spanning the range + flop_min, flop_max = np.min(total_flop_finite), np.max(total_flop_finite) + iso_flop_levels = np.logspace(np.log10(flop_min), np.log10(flop_max), 5) + + for i, flop_level in enumerate(iso_flop_levels): + # Find points near this FLOP level (within 20% tolerance) + tolerance = 0.15 + near_flop_mask = np.abs(total_flop_finite - flop_level) / flop_level < tolerance + + if np.sum(near_flop_mask) >= 3: # Need at least 3 points for quadratic fit + x_iso = ratio_finite[near_flop_mask] + y_iso = mean_p_finite[near_flop_mask] + + # Sort by x for smooth curve + sort_idx = np.argsort(x_iso) + x_iso_sorted = x_iso[sort_idx] + y_iso_sorted = y_iso[sort_idx] + + # Fit quadratic in log-space for x + try: + log_x = np.log10(x_iso_sorted) + coeffs = np.polyfit(log_x, y_iso_sorted, 2) + + # Generate smooth curve + x_smooth = np.logspace(np.log10(x_iso_sorted.min()) - 0.25, + np.log10(x_iso_sorted.max()) + 0.25, 50) + log_x_smooth = np.log10(x_smooth) + y_smooth = np.polyval(coeffs, log_x_smooth) + + # Plot the iso-FLOP line with label for first one only + label = "Iso-FLOP lines" if i == 0 else None + plt.plot(x_smooth, y_smooth, '--', color='gray', alpha=0.8, + linewidth=1, zorder=1, label=label) + + # Add text annotation for FLOP level at the end of the curve + if len(x_smooth) > 0 and len(y_smooth) > 0: + # Find a good position for the text (middle of the curve) + mid_idx = 0 + text_x = x_smooth[mid_idx] + text_y = y_smooth[mid_idx] + + # Format FLOP level in scientific notation + flop_text = f"{flop_level:.1e}" + plt.text(text_x, text_y, flop_text, fontsize=8, alpha=0.8, + ha='center', va='top', color='black', + bbox=dict(boxstyle='round,pad=0.2', facecolor='white', + alpha=0.7, edgecolor='none')) + + except (np.linalg.LinAlgError, ValueError) as e: + raise e plt.xscale("log") plt.xlim(1e-5, 1) plt.ylim(bottom=0) @@ -675,12 +692,516 @@ def get_ratio_pts(y, opt_flops, sampling_flops): if verbose: print(f"FLOPS ratio range: {ratio_finite.min():.2e} to {ratio_finite.max():.2e}") print(f"Mean p_harmful range: {mean_p_finite.min():.4f} to {mean_p_finite.max():.4f}") + print(f"Total FLOPS range: {total_flop_finite.min():.2e} to {total_flop_finite.max():.2e}") + + +def ideal_ratio_plot( + results: dict[str,np.ndarray], + baseline: dict[str,np.ndarray] | None = None, + title: str = "Ideal Sampling FLOPS Ratio", + sample_levels_to_plot: tuple[int, ...]|None = None, + metric: tuple[str, ...] = ('scores', 'strong_reject', 'p_harmful'), + cumulative: bool = False, + flops_per_step: int | None = None, + threshold: float|None = None, + n_p_harmful_points: int = 100, + verbose: bool = True, +): + """ + Plot the ideal sampling FLOPS ratio for achieving different levels of harmfulness. + For each p_harmful level, finds the point that achieves that level with minimum total FLOPS + and plots the corresponding sampling ratio. + """ + y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation = preprocess_data( + results, metric, threshold, flops_per_step + ) + + n_smoothing = 50 + pts = get_points(y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation, + return_ratio=True, n_smoothing=n_smoothing, cumulative=cumulative) + ratio, step_idx, n_samp, mean_p, total_flop = pts.T[:5] + + # Filter out infinite ratios and invalid points + finite_mask = np.isfinite(ratio) & np.isfinite(mean_p) & np.isfinite(total_flop) + ratio_finite = ratio[finite_mask] + mean_p_finite = mean_p[finite_mask] + total_flop_finite = total_flop[finite_mask] + + # Create p_harmful levels to evaluate + p_harmful_min = np.min(mean_p_finite) + p_harmful_max = np.max(mean_p_finite) + p_harmful_levels = np.linspace(p_harmful_min, p_harmful_max, n_p_harmful_points) + + # Find ideal ratio for each p_harmful level + ideal_ratios = [] + max_ratios = [] # Track maximum ratios explored at each level + min_ratios = [] # Track minimum ratios explored at each level + achieved_p_levels = [] + + for p_level in p_harmful_levels: + # Find all points that achieve at least this p_harmful level + achieving_mask = mean_p_finite >= p_level + + if np.any(achieving_mask): + # Among achieving points, find the one with minimum total FLOPS + achieving_flops = total_flop_finite[achieving_mask] + achieving_ratios = ratio_finite[achieving_mask] + + min_flops_idx = np.argmin(achieving_flops) + ideal_ratio = achieving_ratios[min_flops_idx] + + # Find the maximum and minimum ratios explored at this level + max_ratio = np.max(achieving_ratios) + min_ratio = np.min(achieving_ratios) + + ideal_ratios.append(ideal_ratio) + max_ratios.append(max_ratio) + min_ratios.append(min_ratio) + achieved_p_levels.append(p_level) + + ideal_ratios = np.array(ideal_ratios) + max_ratios = np.array(max_ratios) + min_ratios = np.array(min_ratios) + achieved_p_levels = np.array(achieved_p_levels) + + plt.figure(figsize=(12, 8)) + + # Create the FLOP landscape: interpolated surface with color indicating total FLOPS + # Use raw ratios instead of normalized ones + landscape_p_harmful = [] + landscape_ratios = [] + landscape_total_flops = [] + + for i, (p_val, ratio_val, flop_val) in enumerate(zip(mean_p_finite, ratio_finite, total_flop_finite)): + landscape_p_harmful.append(p_val) + landscape_ratios.append(ratio_val) + landscape_total_flops.append(flop_val) + + landscape_p_harmful = np.array(landscape_p_harmful) + landscape_ratios = np.array(landscape_ratios) + landscape_total_flops = np.array(landscape_total_flops) + + # Create interpolated surface + # Define grid for interpolation + p_grid = np.linspace(np.min(landscape_p_harmful), np.max(landscape_p_harmful), 100) + ratio_grid = np.logspace(np.log10(1e-5), np.log10(1.0), 100) + P_grid, Ratio_grid = np.meshgrid(p_grid, ratio_grid) + + # Interpolate FLOPS values onto the grid + try: + flops_grid = griddata( + (landscape_p_harmful, landscape_ratios), + landscape_total_flops, + (P_grid, Ratio_grid), + method='linear', + fill_value=np.nan + ) + except Exception as e: + raise ValueError(f"Error interpolating FLOPS values.") + + # Create mask to only show values within the explored bounds + mask = np.ones_like(flops_grid, dtype=bool) + for i, p_val in enumerate(p_grid): + # Find the closest p_harmful level to get min/max bounds + closest_idx = np.argmin(np.abs(achieved_p_levels - p_val)) + if closest_idx < len(min_ratios) and closest_idx < len(max_ratios): + min_bound = min_ratios[closest_idx] + max_bound = max_ratios[closest_idx] + + # Mask out values outside the bounds + ratio_col = Ratio_grid[:, i] + outside_bounds = (ratio_col < min_bound) | (ratio_col > max_bound) + mask[outside_bounds, i] = False + + # Apply mask + flops_grid_masked = np.where(mask, flops_grid, np.nan) + + # Create contour plot of the FLOP landscape + contour = plt.contourf(P_grid, Ratio_grid, flops_grid_masked, levels=50, + cmap='plasma', alpha=0.8, extend='both') + + # Add colorbar for total FLOPS + cbar = plt.colorbar(contour, label='Total FLOPS') + cbar.formatter.set_powerlimits((0, 0)) # Use scientific notation + + # Plot the ideal ratio curve (raw ratios) - this traces the minimum through the landscape + plt.plot(achieved_p_levels, ideal_ratios, 'b-', linewidth=3, label='Ideal Ratio (Min FLOPS)', marker='o', markersize=4, zorder=5) + + # Plot the maximum ratio explored curve (raw ratios) + plt.plot(achieved_p_levels, max_ratios, 'g--', linewidth=2, label='Max Ratio Explored', marker='s', markersize=3, zorder=5) + + # Plot the minimum ratio explored curve (raw ratios) + plt.plot(achieved_p_levels, min_ratios, 'r:', linewidth=2, label='Min Ratio Explored', marker='^', markersize=3, zorder=5) + + # Handle baseline if provided + if baseline is not None: + y_baseline, baseline_flops_optimization, baseline_flops_sampling_prefill_cache, baseline_flops_sampling_generation = preprocess_data( + baseline, metric, threshold, flops_per_step + ) + + baseline_pts = get_points(y_baseline, baseline_flops_optimization, baseline_flops_sampling_prefill_cache, + baseline_flops_sampling_generation, return_ratio=True, n_smoothing=n_smoothing, cumulative=cumulative) + baseline_ratio, _, baseline_n_samp, baseline_mean_p, baseline_total_flop = baseline_pts.T[:5] + + baseline_finite_mask = np.isfinite(baseline_ratio) & np.isfinite(baseline_mean_p) + if np.any(baseline_finite_mask): + # For baseline, just plot the raw ratios + plt.scatter(baseline_mean_p[baseline_finite_mask], baseline_ratio[baseline_finite_mask], + color="red", s=80, alpha=0.9, marker="^", + edgecolors='black', linewidth=0.5, label="Greedy", zorder=6) + + plt.xlabel("p_harmful", fontsize=14) + plt.ylabel("Sampling FLOPS / Total FLOPS", fontsize=14) + plt.grid(True, alpha=0.3) + plt.title(title, fontsize=16) + plt.xlim(0, 1) + plt.yscale('log') + plt.ylim(1e-5, 1.0) + plt.xlim(left=0) + plt.legend() + plt.tight_layout() + plt.savefig(f"evaluate/distributional_paper/ideal_ratio_plots/{title}.pdf", bbox_inches='tight') + plt.close() + + if verbose: + print(f"p_harmful range: {p_harmful_min:.4f} to {p_harmful_max:.4f}") + print(f"Ideal ratio range: {ideal_ratios.min():.4f} to {ideal_ratios.max():.4f}") + print(f"Max ratio range: {max_ratios.min():.4f} to {max_ratios.max():.4f}") + print(f"Min ratio range: {min_ratios.min():.4f} to {min_ratios.max():.4f}") + print(f"Total FLOPS landscape range: {landscape_total_flops.min():.2e} to {landscape_total_flops.max():.2e}") + print(f"Number of points in landscape: {len(landscape_total_flops)}") + print(f"Number of p_harmful levels with solutions: {len(achieved_p_levels)}") + + +# Helper --------------------------------------------------------------------------- +def run_analysis( + model: str, + model_title: str, + atk_name: str, + cfg: dict, + analysis_type: str = "pareto", +): + """ + Unified function to run different types of analysis. + + Parameters + ---------- + analysis_type : str + Type of analysis: "pareto", "flops_ratio", "ideal_ratio", "histogram", "histogram_2" + """ + print(f"{analysis_type.title()} Analysis:", atk_name, cfg.get("title_suffix", "")) + + # ---------- sampled run ---------- + sampled_data = fetch_data(model, cfg.get("attack_override", atk_name), cfg["sample_params"](), + DATASET_IDX, GROUP_BY) + + # Attack-specific post-processing + if post := cfg.get("postprocess"): + post(sampled_data, METRIC) + + # ---------- baseline run (not needed for histograms) ---------- + baseline_data = None + if analysis_type not in ["histogram", "histogram_2"]: + baseline_attack = cfg.get("baseline_attack", atk_name) + baseline_data = fetch_data(model, baseline_attack, cfg["baseline_params"](), + DATASET_IDX, GROUP_BY) + + # ---------- generate plot based on analysis type ---------- + flops_per_step_fn = lambda x: FLOPS_PER_STEP.get(atk_name, lambda x, c: 0)(x, num_model_params(model)) + + if analysis_type == "pareto": + pareto_plot( + sampled_data, + baseline_data, + title=f"{model_title} {cfg['title_suffix']}", + cumulative=cfg["cumulative"], + metric=METRIC, + flops_per_step=flops_per_step_fn, + threshold=None, + color_scale="sqrt", + ) + elif analysis_type == "flops_ratio": + flops_ratio_plot( + sampled_data, + baseline_data, + title=f"{model_title} {cfg['title_suffix']} FLOPS Ratio", + cumulative=cfg["cumulative"], + metric=METRIC, + flops_per_step=flops_per_step_fn, + threshold=None, + color_scale="sqrt", + ) + elif analysis_type == "ideal_ratio": + ideal_ratio_plot( + sampled_data, + baseline_data, + title=f"{model_title} {cfg['title_suffix']} Ideal Ratio", + cumulative=cfg["cumulative"], + metric=METRIC, + flops_per_step=flops_per_step_fn, + threshold=None, + ) + elif analysis_type == "histogram": + + plt.figure(figsize=(10, 6)) + data_list = [] + positions = [] + data = np.array(sampled_data[("scores", "strong_reject", "p_harmful")]) + for i in np.arange(data.shape[1]): + data_list.append(data[:, i].flatten()) + positions.append(i) + + # Create 2D heatmap + # Define bins for p_harmful values (y-axis) + p_harmful_bins = np.linspace(0, 1, 101) # 50 bins from 0 to 1 + + # Create 2D histogram matrix + heatmap_data = np.zeros((len(p_harmful_bins)-1, len(positions))) + + for i, (pos, d) in enumerate(zip(positions, data_list)): + # Calculate histogram for this position + counts, _ = np.histogram(d, bins=p_harmful_bins) + heatmap_data[:, i] = counts / len(d) + + # Create the heatmap + im = plt.imshow(heatmap_data, + aspect='auto', + origin='lower', + extent=[positions[0], positions[-1], 0, 1], + cmap='viridis', + norm=LogNorm(vmin=1/len(d), vmax=heatmap_data.max()) + ) + # Add colorbar + cbar = plt.colorbar(im, label='Density') + + # Calculate and plot median and mean lines + medians = [] + means = [] + for data_at_pos in data_list: + medians.append(np.median(data_at_pos)) + means.append(np.mean(data_at_pos)) + + # Plot mean line + plt.plot(positions, means, color='orange', linewidth=2, label='Mean', alpha=0.8) + plt.plot(positions, np.maximum.accumulate(means), color='red', linewidth=2, label='Max', alpha=0.8) + + # Add legend + plt.legend(loc='upper right', framealpha=0.8) + + plt.xscale("log") + plt.xlabel('Step Index', fontsize=14) + plt.ylabel('p_harmful', fontsize=14) + plt.title(f'{model_title} - {atk_name} - p_harmful Distribution', fontsize=16) + plt.grid(True, alpha=0.3) + plt.tight_layout() + + # Save the plot + filename = f"evaluate/distributional_paper/histograms/{model_title}_{cfg['title_suffix']}.pdf" + plt.savefig(filename, dpi=300, bbox_inches='tight') + plt.close() + elif analysis_type == "ridge": + sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)}) + + # Create ridge plot for p_harmful distributions across steps + data = np.array(sampled_data[("scores", "strong_reject", "p_harmful")]) + + # Prepare data for ridge plot + ridge_data = [] + for step_idx in [0, 1, 2, 5, 10, 20, 50, 99, 249]: + step_data = data[:, step_idx, :].flatten() # Get p_harmful values for this step + # Round/bucketize the data into five values: 0, 0.25, 0.5, 0.75, 1.0 + # step_data = np.round(step_data * 4) / 4 + for value in step_data: + ridge_data.append({'step': f'Step {step_idx}', 'p_harmful': value}) + + df = pd.DataFrame(ridge_data) + + # Create the ridge plot + plt.figure(figsize=(12, 8)) + + # Get unique steps and create color palette + unique_steps = sorted(df['step'].unique(), key=lambda x: int(x.split()[1])) + n_steps = len(unique_steps) + pal = sns.cubehelix_palette(n_steps, rot=-.25, light=.7) + + # Initialize the FacetGrid object + g = sns.FacetGrid(df, row="step", hue="step", aspect=15, height=.5, palette=pal, + row_order=unique_steps) + + # Draw the densities + g.map(sns.kdeplot, "p_harmful", bw_adjust=0.5, clip_on=True, fill=True, alpha=1, linewidth=1.5) + g.map(sns.kdeplot, "p_harmful", bw_adjust=0.5, clip_on=True, color="w", lw=2) + + # Add vertical lines for mean and median + def add_mean_lines(x, **kwargs): + ax = plt.gca() + mean_val = np.mean(x) + ax.axvline(mean_val, color='red', linestyle='-', alpha=0.7, linewidth=1, ymax=0.8) + + g.map(add_mean_lines, "p_harmful") + + # Add reference line at y=0 + g.refline(y=0, linewidth=2, linestyle="-", color=None, clip_on=False) + + # Set the subplots to overlap + g.figure.subplots_adjust(hspace=-.4) + + # Remove axes details that don't play well with overlap + g.set_titles("") + g.set(yticks=[], ylabel="") + g.despine(bottom=True, left=True) + g.set(xlim=(0, 1)) + + # Add overall title + g.figure.suptitle(f'{model_title} - {cfg["title_suffix"]} - p_harmful Ridge Plot', + fontsize=16, y=0.98) + + # Save the plot + filename = f"evaluate/distributional_paper/ridge_plots/{model_title}_{cfg['title_suffix']}.pdf" + plt.savefig(filename, dpi=300, bbox_inches='tight') + plt.close() + elif analysis_type == "histogram_2": + # Create histogram plot + plt.figure(figsize=(10, 6)) + threshold = 0.0 + show_top = False # Set to False to hide the top subplot + bins = np.linspace(0, 1, 21) + + + data_low = np.array(sampled_data[("scores","strong_reject","p_harmful")])[:, 0].flatten() + data_low = data_low[data_low > threshold] + + data_high = np.array(sampled_data[("scores","strong_reject","p_harmful")])[:, -1].flatten() + data_high = data_high[data_high > threshold] + + # Create CDF plot + fig, ax = plt.subplots(figsize=(10, 6)) + + # Calculate survival functions (1 - CDF) + data_low_sorted = np.sort(data_low) + data_high_sorted = np.sort(data_high) + + # Calculate proportion of items with value <= x, then invert to get survival function + cdf_low = np.arange(1, len(data_low_sorted) + 1) / len(data_low_sorted) + cdf_high = np.arange(1, len(data_high_sorted) + 1) / len(data_high_sorted) + + survival_low = 1 - cdf_low + survival_high = 1 - cdf_high + + # Plot survival functions + ax.plot(data_low_sorted, survival_low, label="First Step", linewidth=2, alpha=0.8) + ax.plot(data_high_sorted, survival_high, label="Last Step", linewidth=2, alpha=0.8) + + ax.set_xlabel("p_harmful", fontsize=14) + ax.set_ylabel("Survival Probability (P(X $>$ x))", fontsize=14) + ax.set_xlim(threshold, 1) + ax.set_ylim(0, 1) + ax.legend(fontsize=12) + ax.grid(True, alpha=0.3) + + ax.set_title(f"{model_title} - {cfg['title_suffix']} - p_harmful Survival Function", + fontsize=16) + + # Save the plot + filename = f"evaluate/distributional_paper/cdf_plots/{model_title}_{cfg['title_suffix']}.pdf" + plt.savefig(filename, dpi=300, bbox_inches='tight') + plt.close() + + if show_top: + fig, (ax_top, ax_bottom) = plt.subplots( + 2, 1, sharex=True, + figsize=(10, 6), + gridspec_kw={"height_ratios": [1, 3]} # top is shorter + ) + axes = (ax_top, ax_bottom) + else: + fig, ax_bottom = plt.subplots(figsize=(10, 4)) + axes = (ax_bottom,) + + # draw the same two histograms on both axes + for ax in axes: + ax.hist(data_high, bins=bins, alpha=0.7, label="Last Step") + ax.hist(data_low, bins=bins, alpha=0.7, label="First Step") + + if show_top: + # ------------- tune y limits so the top shows only the tall bar(s) ----- + counts_start = np.sort(np.histogram(data_low.flatten(), bins=bins)[0]) + counts_end = np.sort(np.histogram(data_high.flatten(), bins=bins)[0]) + + ylims = ( + (0, max(counts_start[-2], counts_end[-2]) * 1.1), + (min(counts_start[-1], counts_end[-1]) * 0.9, max(counts_start[-1], counts_end[-1]) * 1.1), + ) + ax_top.set_ylim(*ylims[1]) + ax_bottom.set_ylim(*ylims[0]) + + # ---------------- cosmetic fixes ----------------- + ax_top.spines.bottom.set_visible(False) + ax_bottom.spines.top.set_visible(False) + ax_top.tick_params(axis="x", which="both", bottom=False, labelbottom=False) + + # little diagonal "break" marks + d = .5 # size of diagonal lines + kwargs = dict(marker=[(-1, -1), (1, 1)], markersize=12, + linestyle="none", color="k", clip_on=False) + ax_top.plot([0, 1], [0, 0], transform=ax_top.transAxes, **kwargs) + ax_bottom.plot([0, 1], [1, 1], transform=ax_bottom.transAxes, **kwargs) + + # --------- the usual labels ------------ + ax_bottom.set_xlabel("p_harmful", fontsize=14) + ax_bottom.set_ylabel("Frequency (\\%)", fontsize=14) + ax_bottom.set_xlim(threshold, 1) + + if show_top: + ax_top.set_title(f"{model_title} - {cfg['title_suffix']} - p_harmful distribution", + fontsize=16) + else: + ax_bottom.set_title(f"{model_title} - {cfg['title_suffix']} - p_harmful distribution", + fontsize=16) + + # Convert y-axis tick labels to percentages + if not show_top: + # Get current y-tick locations + yticks = ax_bottom.get_yticks() + # Convert to percentages by dividing by total count and multiplying by 100 + ytick_labels = [f"{100*tick / len(data_low.flatten()):.0f}%" for tick in yticks] + ax_bottom.set_yticklabels(ytick_labels) + else: + for ax in (ax_bottom,): + # Get current y-tick locations + yticks = ax.get_yticks() + # Convert to percentages by dividing by total count and multiplying by 100 + ytick_labels = [f"{100*tick / len(data_low.flatten()):.0f}%" for tick in yticks] + ax.set_yticklabels(ytick_labels) + # Convert y-axis tick labels to percentages + for ax in (ax_top,): + # Get current y-tick locations + yticks = ax.get_yticks() + # Convert to percentages by dividing by total count and multiplying by 100 + ytick_labels = [f"{100*tick / len(data_low.flatten()):.1f}%" for tick in yticks] + ax.set_yticklabels(ytick_labels) + ax_top.legend() + + if not show_top: + ax_bottom.legend() + + for ax in axes: + ax.grid(True, alpha=0.3) + + plt.tight_layout() + + # Save the plot + filename = f"evaluate/distributional_paper/histograms_2/{model_title}_{cfg['title_suffix']}.pdf" + plt.savefig(filename, dpi=300, bbox_inches='tight') + plt.close() + else: + raise ValueError(f"Unknown analysis type: {analysis_type}") + # ---------------------------------------------------------------------------------- -# Pareto plots – simplified +# Configuration and Constants # ---------------------------------------------------------------------------------- -import numpy as np MODELS = { "meta-llama/Meta-Llama-3.1-8B-Instruct": "Meta Llama 3.1 8B", @@ -690,7 +1211,7 @@ def get_ratio_pts(y, opt_flops, sampling_flops): FLOPS_PER_STEP = { "autodan": lambda s, c: 69845248149248 // num_model_params("Qwen/Qwen2.5-0.5B-Instruct") * c, - "gcg": lambda s, c: 14958709489152 // num_model_params("Qwen/Qwen2.5-0.5B-Instruct") * c, + "gcg": lambda s, c: int(1e14) + 14958709489152 // num_model_params("Qwen/Qwen2.5-0.5B-Instruct") * c, "beast": lambda s, c: 10447045889280 // num_model_params("Qwen/Qwen2.5-0.5B-Instruct") * c, "pair": lambda s, c: 83795198566400 + 78737584640 // num_model_params("Qwen/Qwen2.5-0.5B-Instruct") * c, } # for 0.5B model @@ -826,232 +1347,28 @@ def get_ratio_pts(y, opt_flops, sampling_flops): GROUP_BY = {"model", "attack_params"} DATASET_IDX = list(range(50)) - -# Helper --------------------------------------------------------------------------- -def run_attack( - model: str, - model_title: str, - atk_name: str, - cfg: dict, -): - print("Attack:", atk_name, cfg["title_suffix"]) - - # ---------- helper to fetch data ---------- - def fetch(attack: str, attack_params: dict): - filter_by = dict( - model=model, - attack=attack, - attack_params=attack_params, - dataset_params={"idx": DATASET_IDX}, - ) - paths = get_filtered_and_grouped_paths(filter_by, GROUP_BY) - results = collect_results(paths, infer_sampling_flops=True) - assert len(results) == 1, len(results) - return list(results.values())[0] - - # ---------- sampled run ---------- - sampled_data = fetch(cfg.get("attack_override", atk_name), cfg["sample_params"]()) - - # Attack-specific post-processing - if post := cfg.get("postprocess"): - post(sampled_data, METRIC) - - # ---------- baseline run ---------- - baseline_attack = cfg.get("baseline_attack", atk_name) - baseline_data = fetch(baseline_attack, cfg["baseline_params"]()) - - # ---------- plot ---------- - pareto_plot( - sampled_data, - baseline_data, - title=f"{model_title} {cfg['title_suffix']}", - cumulative=cfg["cumulative"], - metric=METRIC, - flops_per_step=lambda x: FLOPS_PER_STEP.get(atk_name, lambda x, c: 0)(x, num_model_params(model)), - threshold=None, - color_scale="sqrt", - ) - - -def run_attack_flops_ratio( - model: str, - model_title: str, - atk_name: str, - cfg: dict, -): - print("FLOPS Ratio Attack:", atk_name) - - # ---------- helper to fetch data ---------- - def fetch(attack: str, attack_params: dict): - filter_by = dict( - model=model, - attack=attack, - attack_params=attack_params, - dataset_params={"idx": DATASET_IDX}, - ) - paths = get_filtered_and_grouped_paths(filter_by, GROUP_BY) - results = collect_results(paths, infer_sampling_flops=True) - assert len(results) == 1, len(results) - return list(results.values())[0] - - # ---------- sampled run ---------- - sampled_data = fetch(cfg.get("attack_override", atk_name), cfg["sample_params"]()) - - # Attack-specific post-processing - if post := cfg.get("postprocess"): - post(sampled_data, METRIC) - - # ---------- baseline run ---------- - baseline_attack = cfg.get("baseline_attack", atk_name) - baseline_data = fetch(baseline_attack, cfg["baseline_params"]()) - - # ---------- plot ---------- - flops_ratio_plot( - sampled_data, - baseline_data, - title=f"{model_title} {cfg['title_suffix']} FLOPS Ratio", - cumulative=cfg["cumulative"], - metric=METRIC, - flops_per_step=lambda x: FLOPS_PER_STEP.get(atk_name, lambda x, c: 0)(x, num_model_params(model)), - threshold=None, - # color_by="total_flops", - color_scale="sqrt", - ) - - -# # Main loop ------------------------------------------------------------------------ -for model_key, model_title in MODELS.items(): - print("Model:", model_key) - for atk_name, atk_cfg in ATTACKS: - try: - run_attack(model_key, model_title, atk_name, atk_cfg) - except Exception as e: - print(f"Error running attack {atk_name}, atk_cfg: {atk_cfg['title_suffix']}: {e}") - - -# FLOPS Ratio plots main loop ------------------------------------------------------ -print("\n" + "="*80) -print("GENERATING FLOPS RATIO PLOTS") -print("="*80) - -for model_key, model_title in MODELS.items(): - print("Model:", model_key) - for atk_name, atk_cfg in ATTACKS: - try: - run_attack_flops_ratio(model_key, model_title, atk_name, atk_cfg) - except Exception as e: - print(f"Error running FLOPS ratio attack {atk_name}, atk_cfg: {atk_cfg['title_suffix']}: {e}") - - -# Helper --------------------------------------------------------------------------- -def run_attack_2( - model: str, - model_title: str, - atk_name: str, - cfg: dict, -): - print("Attack:", atk_name) - - # ---------- helper to fetch data ---------- - def fetch(attack: str, attack_params: dict): - filter_by = dict( - model=model, - attack=attack, - attack_params=attack_params, - dataset_params={"idx": DATASET_IDX}, - ) - paths = get_filtered_and_grouped_paths(filter_by, GROUP_BY) - results = collect_results(paths, infer_sampling_flops=True) - assert len(results) == 1, len(results) - return list(results.values())[0] - - # ---------- sampled run ---------- - sampled_data = fetch(cfg.get("attack_override", atk_name), cfg["sample_params"]()) - - # Attack-specific post-processing - if post := cfg.get("postprocess"): - post(sampled_data, METRIC) - - data = np.array(sampled_data[("scores", "strong_reject", "p_harmful")])[:, 0] - # Create histogram plot - plt.figure(figsize=(10, 6)) - plt.hist(data.flatten(), bins=100, alpha=0.7, edgecolor='black') - plt.xlabel('p_harmful', fontsize=14) - plt.ylabel('Frequency', fontsize=14) - plt.title(f'{model_title} - {atk_name} - p_harmful Distribution', fontsize=16) - plt.grid(True, alpha=0.3) - plt.tight_layout() - - # Save the plot - filename = f"evaluate/distributional_paper/histograms/{model_title}_{cfg['title_suffix']}.png" - plt.savefig(filename, dpi=300, bbox_inches='tight') - plt.close() - -for model_key, model_title in MODELS.items(): - print("Model:", model_key) - for atk_name, atk_cfg in ATTACKS: - if atk_name != "direct": continue - try: - run_attack_2(model_key, model_title, atk_name, atk_cfg) - except Exception as e: - print(f"Error running attack {atk_name}, atk_cfg: {atk_cfg['title_suffix']}: {e}") - - -# Helper --------------------------------------------------------------------------- -def run_attack_2( - model: str, - model_title: str, - atk_name: str, - cfg: dict, -): - print("Attack:", atk_name) - - # ---------- helper to fetch data ---------- - def fetch(attack: str, attack_params: dict): - filter_by = dict( - model=model, - attack=attack, - attack_params=attack_params, - dataset_params={"idx": DATASET_IDX}, - ) - paths = get_filtered_and_grouped_paths(filter_by, GROUP_BY) - results = collect_results(paths, infer_sampling_flops=True) - assert len(results) == 1, f"Should only have exactly one type of result, got {len(results)}, {list(results.keys())}" - return list(results.values())[0] - - # ---------- sampled run ---------- - sampled_data = fetch(cfg.get("attack_override", atk_name), cfg["sample_params"]()) - - # Attack-specific post-processing - if post := cfg.get("postprocess"): - post(sampled_data, METRIC) - - plt.figure(figsize=(10, 6)) - data_list = [] - positions = [] - for i in range(0, 250, 25): - data = np.array(sampled_data[("scores", "strong_reject", "p_harmful")])[:, i] - data_list.append(data.flatten()) - positions.append(i) - - # Create violin plot - plt.violinplot(data_list, positions=positions, widths=20, showmeans=True, showmedians=True) - plt.xlabel('Step', fontsize=14) - plt.ylabel('Frequency', fontsize=14) - plt.title(f'{model_title} - {atk_name} - p_harmful Distribution', fontsize=16) - plt.grid(True, alpha=0.3) - plt.tight_layout() - - # Save the plot - filename = f"evaluate/distributional_paper/histograms/{model_title}_{cfg['title_suffix']}.png" - plt.savefig(filename, dpi=300, bbox_inches='tight') - plt.close() - -for model_key, model_title in MODELS.items(): - print("Model:", model_key) - for atk_name, atk_cfg in ATTACKS: - if atk_name != "gcg": continue - try: - run_attack_2(model_key, model_title, atk_name, atk_cfg) - except Exception as e: - print(f"Error running attack {atk_name}, atk_cfg: {atk_cfg['title_suffix']}: {e}") +def main(fail: bool = False): + # for analysis_type in ["pareto", "flops_ratio", "ideal_ratio", "histogram", "histogram_2", "ridge"]: + for analysis_type in ["flops_ratio", "ideal_ratio", "histogram", "histogram_2", "ridge"]: + print("\n" + "="*80) + print(f"GENERATING {analysis_type.upper().replace('_', ' ')} PLOTS") + print("="*80) + + for model_key, model_title in MODELS.items(): + print("Model:", model_key) + for atk_name, atk_cfg in ATTACKS: + try: + run_analysis(model_key, model_title, atk_name, atk_cfg, analysis_type) + except Exception as e: + if fail: + raise e + print(f"Error running {analysis_type} analysis for {atk_name}, " + f"cfg: {atk_cfg.get('title_suffix', 'unknown')}: {e}") + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description='Generate plots for distributional paper') + parser.add_argument('--fail', action='store_true', help='Override flag to fail') + args = parser.parse_args() + + main(args.fail) diff --git a/purge_orphans.py b/purge_orphans.py new file mode 100644 index 0000000..c69ca74 --- /dev/null +++ b/purge_orphans.py @@ -0,0 +1,4 @@ +from src.io_utils import delete_orphaned_runs + +if __name__ == "__main__": + delete_orphaned_runs() \ No newline at end of file diff --git a/src/attacks/direct.py b/src/attacks/direct.py index d042f2a..5366c2e 100644 --- a/src/attacks/direct.py +++ b/src/attacks/direct.py @@ -140,6 +140,7 @@ def run( model_completions=model_completions, time_taken=(t1 - t0) / B, loss=loss, + flops=0, model_input=model_input, model_input_tokens=model_input_tokens, ) diff --git a/src/io_utils.py b/src/io_utils.py index ef062a9..3ed68f6 100644 --- a/src/io_utils.py +++ b/src/io_utils.py @@ -396,6 +396,27 @@ def check_match(doc_fragment, filter_fragment): return doc_fragment == filter_fragment +def normalize_value_for_grouping(value): + """ + Normalize a value for consistent grouping. + + Converts numeric values to a canonical form to ensure that 0 and 0.0, + or 1 and 1.0, etc. are treated as identical for grouping purposes. + For dictionaries and lists, recursively normalizes all contained values. + """ + if isinstance(value, dict): + return {k: normalize_value_for_grouping(v) for k, v in value.items()} + elif isinstance(value, (list, tuple)): + normalized_list = [normalize_value_for_grouping(item) for item in value] + return type(value)(normalized_list) # preserve the original type (list or tuple) + elif isinstance(value, (int, float)): + # Convert to int if it's a whole number, otherwise keep as float + if isinstance(value, float) and value.is_integer(): + return int(value) + return value + return value + + def get_nested_value(data: dict, path: list[str], default="unknown"): """ Safely retrieves a value from a nested dictionary using a path list/tuple. @@ -463,11 +484,13 @@ def get_filtered_and_grouped_paths(filter_by, group_by) -> dict[tuple[str], list for key_spec in group_by: if isinstance(key_spec, str): value = get_nested_value(config_data, [key_spec]) - group_key_parts.append(f"{key_spec}={value}") + normalized_value = normalize_value_for_grouping(value) + group_key_parts.append(f"{key_spec}={normalized_value}") elif isinstance(key_spec, (list, tuple)): value = get_nested_value(config_data, key_spec) - key_name = '.'.join(map(str, key_spec)) # Ensure sub-keys are strings for join - group_key_parts.append(f"{key_name}={value}") + normalized_value = normalize_value_for_grouping(value) + key_name = '.'.join(map(str, key_spec)) + group_key_parts.append(f"{key_name}={normalized_value}") else: group_key_parts.append(f"invalid_group_spec={key_spec}") @@ -529,7 +552,8 @@ def collect_results(paths, infer_sampling_flops=False) -> dict[tuple[str], dict[ if infer_sampling_flops: max_new_tokens = results["config"]["attack_params"]["generation_config"]["max_new_tokens"] model_params = num_model_params(results["config"]["model_params"]["id"]) - step["flops_sampling"] = model_params * (max_new_tokens + len(step["model_input_tokens"])) * 2 + step["flops_sampling_prefill_cache"] = model_params * len(step["model_input_tokens"]) * 2 + step["flops_sampling_generation"] = model_params * max_new_tokens * 2 for metric in step.keys(): # this will fill collected_metrics with values from step[metric] # and handles nested containers diff --git a/src/judges.py b/src/judges.py index 3e8e490..30cc09d 100644 --- a/src/judges.py +++ b/src/judges.py @@ -48,9 +48,9 @@ def judge( pass @torch.no_grad() - def __call__(self, chats: List[List[Dict[str, str]]]) -> Dict[str, List[float]]: + def __call__(self, chats: List[List[Dict[str, str]]], verbose: bool = False) -> Dict[str, List[float]]: """Allows calling the judge instance directly.""" - return with_max_batchsize(self.judge, chats) + return with_max_batchsize(self.judge, chats, verbose=verbose) @staticmethod def validate_chats( @@ -402,4 +402,3 @@ def judge( categories.append(answer_category) return {"category": categories} - diff --git a/src/lm_utils.py b/src/lm_utils.py index 55e540d..1cd217c 100644 --- a/src/lm_utils.py +++ b/src/lm_utils.py @@ -22,7 +22,7 @@ def generate_ragged_batched( tokenizer: PreTrainedTokenizerBase, token_list: list[torch.IntTensor] | None = None, embedding_list: list[torch.FloatTensor] | None = None, - initial_batch_size: int = 256, + initial_batch_size: int | None = None, use_cache: bool = True, verbose: bool = False, num_return_sequences: int = 1, @@ -106,9 +106,14 @@ def with_max_batchsize(function: Callable, *inputs, initial_batch_size: int | No outputs = [] i = 0 + + def next_power_of_two(n): + return 1 << (n - 1).bit_length() + if initial_batch_size is None: - initial_batch_size = input_length - batch_size = min(initial_batch_size, input_length) + initial_batch_size = next_power_of_two(input_length) + + batch_size = min(initial_batch_size, next_power_of_two(input_length)) pbar = tqdm(total=input_length, desc=f"Running function b={batch_size}", file=sys.stdout) if verbose else None while i < input_length: From a671a62bc13af9ffd90083fa1fbbf0dcf698f4c1 Mon Sep 17 00:00:00 2001 From: Tim Beyer Date: Thu, 12 Jun 2025 14:55:35 +0200 Subject: [PATCH 08/13] Simplify main API - use attack,dataset,model instead of attack_name,dataset_name,model_name --- conf/config.yaml | 6 +++--- conf/models/models.yaml | 9 +++++++++ conf/sweep_across_models.yaml | 8 ++++---- run_attacks.py | 22 ++++++++++++---------- run_ensemble.py | 20 ++++++++++---------- run_sampling.py | 35 ++++++++++++++++------------------- src/io_utils.py | 34 +++++++++++++++++++--------------- 7 files changed, 73 insertions(+), 61 deletions(-) diff --git a/conf/config.yaml b/conf/config.yaml index 3a17844..0a5b536 100644 --- a/conf/config.yaml +++ b/conf/config.yaml @@ -21,6 +21,6 @@ overwrite: false classifiers: ["strong_reject"] # Options: "strong_reject", "cais", "llama_guard_3_8b", "overrefusal", etc. (can be list[str]|null) # overrides for parallelism -attack_name: null -dataset_name: adv_behaviors -model_name: null \ No newline at end of file +attack: null +dataset: adv_behaviors +model: null \ No newline at end of file diff --git a/conf/models/models.yaml b/conf/models/models.yaml index ef07515..c6afa7b 100644 --- a/conf/models/models.yaml +++ b/conf/models/models.yaml @@ -387,6 +387,15 @@ mistralai/Ministral-8B-Instruct-2410: dtype: bfloat16 chat_template: llama-3-instruct trust_remote_code: True +/ceph/ssd/shared/hf_models/cat-llama3.2it-fixed: + id: /ceph/ssd/shared/hf_models/cat-llama3.2it-fixed + tokenizer_id: meta-llama/Llama-3.2-3B-Instruct + short_name: Llama + developer_name: Meta + compile: False + dtype: bfloat16 + chat_template: llama-3-instruct + trust_remote_code: True GSAI-ML/LLaDA-8B-Instruct: id: GSAI-ML/LLaDA-8B-Instruct tokenizer_id: GSAI-ML/LLaDA-8B-Instruct diff --git a/conf/sweep_across_models.yaml b/conf/sweep_across_models.yaml index e4e07da..ec14773 100644 --- a/conf/sweep_across_models.yaml +++ b/conf/sweep_across_models.yaml @@ -10,7 +10,7 @@ hydra: mode: MULTIRUN sweeper: params: - model_name: google/gemma-2-2b-it, mistralai/Mistral-7B-Instruct-v0.3, meta-llama/Meta-Llama-3.1-8B-Instruct, qwen/Qwen2-7B-Instruct, HuggingFaceH4/zephyr-7b-beta, meta-llama/Llama-2-7b-chat-hf, ContinuousAT/Llama-2-7B-CAT,ContinuousAT/Zephyr-CAT,ContinuousAT/Phi-CAT, microsoft/Phi-3-mini-4k-instruct, cais/zephyr_7b_r2d2, GraySwanAI/Llama-3-8B-Instruct-RR, GraySwanAI/Mistral-7B-Instruct-RR + model: google/gemma-2-2b-it, mistralai/Mistral-7B-Instruct-v0.3, meta-llama/Meta-Llama-3.1-8B-Instruct, qwen/Qwen2-7B-Instruct, HuggingFaceH4/zephyr-7b-beta, meta-llama/Llama-2-7b-chat-hf, ContinuousAT/Llama-2-7B-CAT,ContinuousAT/Zephyr-CAT,ContinuousAT/Phi-CAT, microsoft/Phi-3-mini-4k-instruct, cais/zephyr_7b_r2d2, GraySwanAI/Llama-3-8B-Instruct-RR, GraySwanAI/Mistral-7B-Instruct-RR datasets.adv_behaviors.idx: range(0, 300) run: dir: ${root_dir}/multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}/ @@ -24,6 +24,6 @@ name: testing save_dir: ${root_dir}/outputs/ # overrides for parallelism -attack_name: null -dataset_name: null -model_name: null \ No newline at end of file +attack: null +dataset: null +model: null \ No newline at end of file diff --git a/run_attacks.py b/run_attacks.py index b4dbffd..cf34b28 100644 --- a/run_attacks.py +++ b/run_attacks.py @@ -28,21 +28,23 @@ def select_configs(cfg: DictConfig, name: str | ListConfig | None) -> list[tuple def collect_configs(cfg: DictConfig) -> list[RunConfig]: - models_to_run = select_configs(cfg.models, cfg.model_name) - datasets_to_run = select_configs(cfg.datasets, cfg.dataset_name) - attacks_to_run = select_configs(cfg.attacks, cfg.attack_name) + if hasattr(cfg, 'model_name') or hasattr(cfg, 'dataset_name') or hasattr(cfg, 'attack_name'): + raise ValueError("model_name, dataset_name, and attack_name are deprecated. Use model, dataset, and attack instead.") + models_to_run = select_configs(cfg.models, cfg.model) + datasets_to_run = select_configs(cfg.datasets, cfg.dataset) + attacks_to_run = select_configs(cfg.attacks, cfg.attack) all_run_configs = [] - for model_name, model_params in models_to_run: - for dataset_name, dataset_params in datasets_to_run: - temp_dataset = PromptDataset.from_name(dataset_name)(dataset_params) + for model, model_params in models_to_run: + for dataset, dataset_params in datasets_to_run: + temp_dataset = PromptDataset.from_name(dataset)(dataset_params) dset_len = len(temp_dataset) dataset_params["idx"] = temp_dataset.config_idx - for attack_name, attack_params in attacks_to_run: + for attack, attack_params in attacks_to_run: run_config = RunConfig( - model_name, - dataset_name, - attack_name, + model, + dataset, + attack, model_params, dataset_params, attack_params, diff --git a/run_ensemble.py b/run_ensemble.py index 3e35e7c..2c59975 100644 --- a/run_ensemble.py +++ b/run_ensemble.py @@ -11,7 +11,7 @@ args.add_argument("--skip_h100", action="store_true") args = args.parse_args() -model_name = args.model +model = args.model attacks_a100 = [ "gcg", "autodan", @@ -27,8 +27,8 @@ # Base command template template_a100 = ( - "python run_attacks.py -m ++model_name={model_name} ++attack_name={attack_name} " - "++datasets.adv_behaviors.idx={indices} ++dataset_name=adv_behaviors " + "python run_attacks.py -m ++model={model} ++attack={attack} " + "++datasets.adv_behaviors.idx={indices} ++dataset=adv_behaviors " "++hydra.launcher.timeout_min=300 " ) @@ -38,8 +38,8 @@ ) # Base command template template_pgd = ( - "python run_attacks.py ++model_name={model_name} ++attack_name=pgd " - '++datasets.adv_behaviors.idx="{indices}" ++dataset_name=adv_behaviors ' + "python run_attacks.py ++model={model} ++attack=pgd " + '++datasets.adv_behaviors.idx="{indices}" ++dataset=adv_behaviors ' "++hydra.launcher.timeout_min=300 ++hydra.launcher.qos=deadline " "++attacks.pgd.epsilon=0.5 ++attacks.pgd.normalize_alpha=true ++attacks.pgd.normalize_gradient=true ++attacks.pgd.alpha=0.01 " "-m " @@ -51,23 +51,23 @@ if not args.skip_a100: commands.append( template_a100.format( - model_name=model_name, - attack_name=",".join(attacks_a100), + model=model, + attack=",".join(attacks_a100), indices=indices, ) ) if not args.skip_h100: commands.append( template_h100.format( - model_name=model_name, - attack_name=",".join(attacks_h100), + model=model, + attack=",".join(attacks_h100), indices=indices, ) ) if args.pgd: commands.append( template_pgd.format( - model_name=model_name, + model=model, indices=list(range(args.min_idx, args.max_idx)), ) ) diff --git a/run_sampling.py b/run_sampling.py index b44ddee..bbaff62 100644 --- a/run_sampling.py +++ b/run_sampling.py @@ -6,9 +6,9 @@ Example: filter_by: - model_name: + model: - google/gemma-3-1b-it - attack_name: + attack: - gcg - pgd @@ -16,7 +16,6 @@ completions the runs should have in the end. """ import os -from dataclasses import dataclass from datetime import datetime os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # determinism @@ -32,8 +31,8 @@ from src.dataset import json from src.errors import print_exceptions -from src.io_utils import (CompactJSONEncoder, free_vram, get_filtered_and_grouped_paths, - get_mongodb_connection, load_model_and_tokenizer) +from src.io_utils import (CompactJSONEncoder, RunConfig, free_vram, get_filtered_and_grouped_paths, + get_mongodb_connection, load_model_and_tokenizer, filter_config) from src.judges import Judge from src.lm_utils import generate_ragged_batched @@ -41,17 +40,6 @@ torch.backends.cuda.matmul.allow_tf32 = True -@dataclass -class RunConfig: - model: str - dataset: str - attack: str - model_params: dict - dataset_params: dict - attack_params: dict - config: dict - - def eval_list_expressions_in_dict(d): if isinstance(d, dict): return {k: eval_list_expressions_in_dict(v) for k, v in d.items()} @@ -135,7 +123,18 @@ def main(cfg: DictConfig) -> None: logging.info(f"Skipping {path}, it already has {n_already_generated} completions") continue attack_run["config"]["attack_params"]["generation_config"]["num_return_sequences"] = cfg.num_return_sequences - model_params = OmegaConf.create(attack_run["config"]["model_params"]) + run_config = RunConfig( + model=attack_run["config"]["model"], + dataset=attack_run["config"]["dataset"], + attack=attack_run["config"]["attack"], + model_params=OmegaConf.structured(attack_run["config"]["model_params"]), + dataset_params=OmegaConf.structured(attack_run["config"]["dataset_params"]), + attack_params=OmegaConf.structured(attack_run["config"]["attack_params"]), + ) + run_config = filter_config(run_config, -1) + if run_config is None: + continue + model_params = run_config.model_params if model_params != last_model_params: pbar.set_description(f"Loading new model and tokenizer {model_params}") last_model_params = model_params @@ -242,7 +241,5 @@ def main(cfg: DictConfig) -> None: raise Exception(f"Error in {path}. Original exception: {e}") from e - - if __name__ == "__main__": main() diff --git a/src/io_utils.py b/src/io_utils.py index 3ed68f6..d602fdc 100644 --- a/src/io_utils.py +++ b/src/io_utils.py @@ -545,21 +545,25 @@ def collect_results(paths, infer_sampling_flops=False) -> dict[tuple[str], dict[ for k, v in paths.items(): aggregated_results = defaultdict(list) for path in v: - results = cached_json_load(path) - for run in results["runs"]: - collected_metrics = defaultdict(list) - for step in run["steps"]: - if infer_sampling_flops: - max_new_tokens = results["config"]["attack_params"]["generation_config"]["max_new_tokens"] - model_params = num_model_params(results["config"]["model_params"]["id"]) - step["flops_sampling_prefill_cache"] = model_params * len(step["model_input_tokens"]) * 2 - step["flops_sampling_generation"] = model_params * max_new_tokens * 2 - for metric in step.keys(): - # this will fill collected_metrics with values from step[metric] - # and handles nested containers - _gather(step[metric], (metric,), collected_metrics) - for metric, v in collected_metrics.items(): - aggregated_results[metric].append(v) + try: + results = cached_json_load(path) + for run in results["runs"]: + collected_metrics = defaultdict(list) + for step in run["steps"]: + if infer_sampling_flops: + max_new_tokens = results["config"]["attack_params"]["generation_config"]["max_new_tokens"] + model_params = num_model_params(results["config"]["model_params"]["id"]) + step["flops_sampling_prefill_cache"] = model_params * len(step["model_input_tokens"]) * 2 + step["flops_sampling_generation"] = model_params * max_new_tokens * 2 + for metric in step.keys(): + # this will fill collected_metrics with values from step[metric] + # and handles nested containers + _gather(step[metric], (metric,), collected_metrics) + for metric, v in collected_metrics.items(): + aggregated_results[metric].append(v) + except Exception as e: + print(f"Error loading {path}") + raise e all_results[k] = aggregated_results return all_results From 0b0bb1cf7c0a1396e83e789957283f4318e377c3 Mon Sep 17 00:00:00 2001 From: Tim Beyer Date: Fri, 20 Jun 2025 15:31:44 +0200 Subject: [PATCH 09/13] Add staged changes --- src/attacks/beast.py | 13 +++++++------ src/io_utils.py | 3 +++ src/lm_utils.py | 8 ++------ 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/attacks/beast.py b/src/attacks/beast.py index 7b4bf34..0398c82 100644 --- a/src/attacks/beast.py +++ b/src/attacks/beast.py @@ -10,18 +10,19 @@ Implementation adapted from https://github.com/dreadnode/research/blob/main/notebooks/Mistral%20-%20BEAST%20Beam%20Attack.ipynb """ +import copy +import sys import time from dataclasses import dataclass, field from functools import partial -import copy + import torch from tqdm import trange from transformers import AutoModelForCausalLM, AutoTokenizer -from src.lm_utils import (generate_ragged_batched, get_disallowed_ids, - with_max_batchsize, prepare_conversation, get_flops) - -from src.attacks.attack import Attack, AttackResult, GenerationConfig, SingleAttackRunResult, AttackStepResult +from .attack import (Attack, AttackResult, AttackStepResult, GenerationConfig, SingleAttackRunResult) +from src.lm_utils import (generate_ragged_batched, get_disallowed_ids, get_flops, + prepare_conversation, with_max_batchsize) from src.types import Conversation @@ -120,7 +121,7 @@ def run(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer, dataset) -> per_sample_losses = [torch.log(torch.tensor(initial_ppl)).item()] per_sample_flops = [flops] beams: list[torch.LongTensor] = [torch.LongTensor([]) for b in beams] - for i in (pbar := trange(1, self.config.num_steps)): + for i in (pbar := trange(1, self.config.num_steps, file=sys.stdout)): flops = 0 t1 = time.time() # Get next K1 x K2 candidates diff --git a/src/io_utils.py b/src/io_utils.py index d602fdc..ae2f8d6 100644 --- a/src/io_utils.py +++ b/src/io_utils.py @@ -82,6 +82,9 @@ def load_model_and_tokenizer(model_params): case path if "llama-2" in path: tokenizer.pad_token = tokenizer.unk_token tokenizer.model_max_length = 4096 + case path if "llama2" in path: + tokenizer.pad_token = tokenizer.unk_token + tokenizer.model_max_length = 4096 case path if "meta-llama/meta-llama-3-8b-instruct" in path: tokenizer.model_max_length = 8192 tokenizer.eos_token_id = 128009 # want to use <|eot_id|> instead of <|eos_id|> (https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/discussions/4) diff --git a/src/lm_utils.py b/src/lm_utils.py index 1cd217c..e3caa33 100644 --- a/src/lm_utils.py +++ b/src/lm_utils.py @@ -386,12 +386,8 @@ def sample_next_token(logits: torch.Tensor) -> torch.Tensor: still_active = (~finished & ~finished_at_this_step)[~finished].cpu() for j in range(len(past_key_values.key_cache)): - if not is_gemma: - past_key_values.key_cache[j] = past_key_values.key_cache[j][still_active] - past_key_values.value_cache[j] = past_key_values.value_cache[j][still_active] - else: - past_key_values.key_cache[j] = past_key_values.key_cache[j][still_active] - past_key_values.value_cache[j] = past_key_values.value_cache[j][still_active] + past_key_values.key_cache[j] = past_key_values.key_cache[j][still_active] + past_key_values.value_cache[j] = past_key_values.value_cache[j][still_active] finished |= finished_at_this_step if finished.all(): From 4168c6ce72ba1148eb4559bdbd7823787f90ce7a Mon Sep 17 00:00:00 2001 From: Tim Beyer Date: Fri, 20 Jun 2025 15:32:19 +0200 Subject: [PATCH 10/13] plots --- evaluate/distributional_paper/make_plots.py | 765 ++++++++++++++------ 1 file changed, 546 insertions(+), 219 deletions(-) diff --git a/evaluate/distributional_paper/make_plots.py b/evaluate/distributional_paper/make_plots.py index a13148a..022d433 100644 --- a/evaluate/distributional_paper/make_plots.py +++ b/evaluate/distributional_paper/make_plots.py @@ -9,7 +9,8 @@ from scipy.interpolate import interp1d from matplotlib.colors import LogNorm, PowerNorm from scipy.interpolate import griddata -from brokenaxes import brokenaxes +import logging +logging.basicConfig(level=logging.INFO) pd.set_option("display.max_colwidth", None) pd.set_option("display.max_columns", None) @@ -140,6 +141,7 @@ def fetch_data(model: str, attack: str, attack_params: dict, dataset_idx: list[i paths = get_filtered_and_grouped_paths(filter_by, group_by) results = collect_results(paths, infer_sampling_flops=True) + print(group_by, filter_by, len(paths), len(results)) assert len(results) == 1, f"Should only have exactly one type of result, got {len(results)}, {list(results.keys())}" return list(results.values())[0] @@ -269,11 +271,11 @@ def pareto_plot( else: x_interp = np.linspace(0, max_cost+1, n_x_points) - # Create figure with subplots: main plot + 2 bar charts - fig = plt.figure(figsize=(18, 6)) + # Create figure with subplots: main plot + 2x2 grid on the right + fig = plt.figure(figsize=(18, 8)) # Main Pareto plot (left half, spanning both rows) - ax1 = plt.subplot2grid((1, 4), (0, 0), colspan=2) + ax1 = plt.subplot2grid((2, 4), (0, 0), colspan=2, rowspan=2) # ---------- scatter all points ---------- color_norm = setup_color_normalization(color_scale, n_samp) @@ -378,8 +380,8 @@ def pareto_plot( if y_baseline is not None: title_suffix = f" ({n_runs}, {y_baseline.shape[0]})" if verbose: - print(n_runs, "for main") - print(y_baseline.shape[0], "for baseline") + logging.info(f"{n_runs} for main") + logging.info(f"{y_baseline.shape[0]} for baseline") n_runs_baseline, n_steps_baseline, n_total_samples_baseline = y_baseline.shape assert n_total_samples_baseline == 1 @@ -408,7 +410,7 @@ def pareto_plot( marker="o", linewidth=1.8, markersize=2, - label=f"greedy", + label=f"Baseline", color="r", ) @@ -419,7 +421,7 @@ def pareto_plot( plt.legend(title="Frontiers", loc="upper left" if x_scale == "log" else "lower right") # ---------- Bar Chart 1: Max ASR Comparison (Vertical Slice) ---------- - ax2 = plt.subplot2grid((1, 4), (0, 2)) + ax2 = plt.subplot2grid((2, 4), (0, 2)) methods = [] max_asrs = [] @@ -427,7 +429,7 @@ def pareto_plot( # Add baseline (delta = 0 for baseline) if baseline_frontier_data is not None: - methods.append("Greedy") + methods.append("Baseline") max_asrs.append(0.0) # Delta from itself is 0 colors.append("red") @@ -443,9 +445,9 @@ def pareto_plot( bars = plt.bar(methods, max_asrs, color=colors, alpha=0.7, edgecolor='black') if threshold is None: plt.ylabel(r"$\Delta$ $p_{harmful}$", fontsize=14) - plt.title(r"$p_{harmful}$ vs. \#samples", fontsize=14) + # plt.title(r"$p_{harmful}$ vs. \#samples", fontsize=14) else: - plt.title(r"$\overline{{ASR}}$ vs. \#samples".format(threshold=threshold), fontsize=14) + # plt.title(r"$\overline{{ASR}}$ vs. \#samples".format(threshold=threshold), fontsize=14) plt.ylabel(r"$\Delta$ $\overline{{ASR}}\quad (p_{{harmful}} \geq {threshold})$".format(threshold=threshold), fontsize=14) plt.xticks(rotation=45, ha='right') plt.grid(True, alpha=0.3, axis='y') @@ -458,18 +460,17 @@ def pareto_plot( for bar, value in zip(bars, max_asrs): # choose label position: above for positive, below for negative offset_pt = 4 # visual gap in points - y = bar.get_height() - va = 'bottom' if y >= 0 else 'top' - offset = (0, offset_pt if y >= 0 else -offset_pt) + va = 'bottom' if value >= 0 else 'top' + offset = (0, offset_pt if value >= 0 else -offset_pt) ax2.annotate(f'{value:.3f}', - xy=(bar.get_x() + bar.get_width()/2, y), + xy=(bar.get_x() + bar.get_width()/2, bar.get_height()), xytext=offset, textcoords='offset points', ha='center', va=va, fontsize=10) - # ---------- Bar Chart 2: FLOPS Efficiency to Reach Greedy ASR (Horizontal Slice) ---------- - ax3 = plt.subplot2grid((1, 4), (0, 3)) + # ---------- Bar Chart 2: FLOPS Efficiency to Reach Baseline ASR (Horizontal Slice) ---------- + ax3 = plt.subplot2grid((2, 4), (0, 3)) if baseline_frontier_data is not None and baseline_max_asr > 0: methods_flops = [] @@ -504,14 +505,13 @@ def pareto_plot( else: # Fallback to minimum FLOPS if no point reaches target ASR baseline_flops = np.min(baseline_x_vals) - methods_flops.insert(0, "Greedy") + methods_flops.insert(0, "Baseline") flops_required.insert(0, baseline_flops) colors_flops.insert(0, "red") if methods_flops: bars = plt.bar(methods_flops, flops_required, color=colors_flops, alpha=0.7, edgecolor='black') - plt.ylabel("FLOPS", fontsize=12) - plt.title(r"FLOPS to reach same $p_{harmful}$ as greedy" + f" ( = {target_asr:.3f})", fontsize=12) + plt.ylabel(r"FLOPS for Baseline $p_{harmful}$" + f" ( = {target_asr:.3f})", fontsize=12) plt.xticks(rotation=45, ha='right') plt.yscale('log') plt.grid(True, alpha=0.3, axis='y') @@ -529,6 +529,132 @@ def pareto_plot( textcoords='offset points', ha='center', va='bottom', rotation=45, fontsize=9) + # ---------- Bar Chart 3: Speedup vs Baseline (Bottom Right) ---------- + ax4 = plt.subplot2grid((2, 4), (1, 3)) + + # Create speedup plot + speedup_methods = [] + speedups = [] + speedup_colors = [] + + # Calculate speedup for each method (baseline_flops / method_flops) + baseline_flops = flops_required[0] if methods_flops[0] == "Baseline" else None + + if baseline_flops is not None: + for i, (method, flops, color) in enumerate(zip(methods_flops, flops_required, colors_flops)): + if method != "Baseline": # Skip baseline itself + speedup = baseline_flops / flops if flops > 0 else 0 + speedup_methods.append(method) + speedups.append(speedup) + speedup_colors.append(color) + + if speedup_methods: + bars = plt.bar(speedup_methods, speedups, color=speedup_colors, alpha=0.7, edgecolor='black') + plt.ylabel("Speedup (FLOPS) vs Baseline", fontsize=12) + plt.xticks(rotation=45, ha='right') + plt.grid(True, alpha=0.3, axis='y') + + # Add horizontal line at y=1 for reference + plt.axhline(y=1, color='red', linestyle='--', alpha=0.7, linewidth=1) + + # Increase ylim by small margin + ymin, ymax = plt.ylim() + margin = (ymax - ymin) * 0.05 + plt.ylim(max(0, ymin - margin), ymax + margin) + + # Add value labels on bars + for bar, value in zip(bars, speedups): + ax4.annotate(f'{value:.2f}x', + xy=(bar.get_x() + bar.get_width()/2, bar.get_height()), + xytext=(0, 5), + textcoords='offset points', + ha='center', va='bottom', fontsize=10) + + # ---------- Line Plot 4: Continuous FLOPS to Reach Baseline ASR (Bottom Left) ---------- + ax5 = plt.subplot2grid((2, 4), (1, 2)) + + if baseline_frontier_data is not None and baseline_max_asr > 0: + target_asr = baseline_max_asr + + # Generate continuous range of sample counts + sample_range = range(1, n_total_samples + 1) + continuous_flops = [] + continuous_samples = [] + + # Calculate frontier data for all sample counts (not just sample_levels_to_plot) + rng_continuous = np.random.default_rng() + n_smoothing_continuous = 10 # Reduced for performance + + for j in sample_range: + xs = [] + ys = [] + for _ in range(n_smoothing_continuous): + pts = [] + for i in range(0, n_steps, 1): + pts.append(subsample_and_aggregate(i, j, cumulative, y, flops_optimization, + flops_sampling_prefill_cache, flops_sampling_generation, rng_continuous)) + + pts = np.asarray(pts) + cost, _, _, mean_p = pts.T + + fx, fy = _pareto_frontier(cost, mean_p, method=frontier_method) + xs.append(fx) + ys.append(fy) + + # Interpolate and average + y_interp = [interp1d(x_, y_, kind="previous", bounds_error=False, fill_value=(0, max(y_)))(x_interp) + for x_, y_ in zip(xs, ys)] + y_mean = np.mean(y_interp, axis=0) + + # Find minimum FLOPS where ASR >= target_asr + nonzero_mask = y_mean > 0 + if np.any(nonzero_mask): + y_vals = y_mean[nonzero_mask] + x_vals = x_interp[nonzero_mask] + + valid_indices = y_vals >= target_asr + if np.any(valid_indices): + min_flops = np.min(x_vals[valid_indices]) + continuous_flops.append(min_flops) + continuous_samples.append(j) + + if continuous_flops: + # Plot the continuous line + plt.plot(continuous_samples, continuous_flops, 'b-', linewidth=2, alpha=0.8, label='All Samples') + + # Highlight the baseline point + if baseline_frontier_data['x'].size > 0: + baseline_y_vals = baseline_frontier_data['y'] + baseline_x_vals = baseline_frontier_data['x'] + baseline_valid_indices = baseline_y_vals >= target_asr + if np.any(baseline_valid_indices): + baseline_flops = np.min(baseline_x_vals[baseline_valid_indices]) + plt.axhline(y=baseline_flops, color='red', linestyle='--', alpha=0.7, linewidth=2, label='Baseline') + + # Highlight the discrete sample levels from the bar chart + for j in sample_levels_to_plot: + if j in [s for s in continuous_samples]: + idx = continuous_samples.index(j) + color = cmap(color_norm(j)) + plt.scatter(j, continuous_flops[idx], color=color, s=60, alpha=0.9, + edgecolors='black', linewidth=0.5, zorder=5) + + plt.xlabel("Number of Samples", fontsize=12) + plt.ylabel("FLOPS to Reach Baseline ASR", fontsize=12) + plt.xscale('log') + plt.yscale('log') + plt.grid(True, alpha=0.3) + plt.legend(fontsize=10) + + # Set reasonable x-axis limits + plt.xlim(1, n_total_samples) + + # Increase ylim by small margin + ymin, ymax = plt.ylim() + import math + margin = ((math.log10(ymax) - math.log10(ymin)) * 0.1) + plt.ylim(ymin / (1+margin), ymax * (1+margin)) + plt.tight_layout() plt.savefig(f"evaluate/distributional_paper/pareto_plots/{title}.pdf") plt.close() @@ -627,7 +753,7 @@ def flops_ratio_plot( # For baseline, just plot the raw ratios plt.scatter(baseline_ratio[baseline_finite_mask], baseline_mean_p[baseline_finite_mask], color="red", s=60, alpha=0.9, marker="^", - edgecolors='black', linewidth=0.5, label="Greedy", zorder=6) + edgecolors='black', linewidth=0.5, label="Baseline", zorder=6) # Add subtle iso-FLOP lines (fitted quadratics) if n_total_samples == 500: # Only if we have enough data points @@ -690,9 +816,9 @@ def flops_ratio_plot( plt.close() if verbose: - print(f"FLOPS ratio range: {ratio_finite.min():.2e} to {ratio_finite.max():.2e}") - print(f"Mean p_harmful range: {mean_p_finite.min():.4f} to {mean_p_finite.max():.4f}") - print(f"Total FLOPS range: {total_flop_finite.min():.2e} to {total_flop_finite.max():.2e}") + logging.info(f"FLOPS ratio range: {ratio_finite.min():.2e} to {ratio_finite.max():.2e}") + logging.info(f"Mean p_harmful range: {mean_p_finite.min():.4f} to {mean_p_finite.max():.4f}") + logging.info(f"Total FLOPS range: {total_flop_finite.min():.2e} to {total_flop_finite.max():.2e}") def ideal_ratio_plot( @@ -828,10 +954,10 @@ def ideal_ratio_plot( plt.plot(achieved_p_levels, ideal_ratios, 'b-', linewidth=3, label='Ideal Ratio (Min FLOPS)', marker='o', markersize=4, zorder=5) # Plot the maximum ratio explored curve (raw ratios) - plt.plot(achieved_p_levels, max_ratios, 'g--', linewidth=2, label='Max Ratio Explored', marker='s', markersize=3, zorder=5) + # plt.plot(achieved_p_levels, max_ratios, 'g--', linewidth=2, label='Max Ratio Explored', marker='s', markersize=3, zorder=5) # Plot the minimum ratio explored curve (raw ratios) - plt.plot(achieved_p_levels, min_ratios, 'r:', linewidth=2, label='Min Ratio Explored', marker='^', markersize=3, zorder=5) + # plt.plot(achieved_p_levels, min_ratios, 'r:', linewidth=2, label='Min Ratio Explored', marker='^', markersize=3, zorder=5) # Handle baseline if provided if baseline is not None: @@ -848,7 +974,7 @@ def ideal_ratio_plot( # For baseline, just plot the raw ratios plt.scatter(baseline_mean_p[baseline_finite_mask], baseline_ratio[baseline_finite_mask], color="red", s=80, alpha=0.9, marker="^", - edgecolors='black', linewidth=0.5, label="Greedy", zorder=6) + edgecolors='black', linewidth=0.5, label="Baseline", zorder=6) plt.xlabel("p_harmful", fontsize=14) plt.ylabel("Sampling FLOPS / Total FLOPS", fontsize=14) @@ -864,13 +990,247 @@ def ideal_ratio_plot( plt.close() if verbose: - print(f"p_harmful range: {p_harmful_min:.4f} to {p_harmful_max:.4f}") - print(f"Ideal ratio range: {ideal_ratios.min():.4f} to {ideal_ratios.max():.4f}") - print(f"Max ratio range: {max_ratios.min():.4f} to {max_ratios.max():.4f}") - print(f"Min ratio range: {min_ratios.min():.4f} to {min_ratios.max():.4f}") - print(f"Total FLOPS landscape range: {landscape_total_flops.min():.2e} to {landscape_total_flops.max():.2e}") - print(f"Number of points in landscape: {len(landscape_total_flops)}") - print(f"Number of p_harmful levels with solutions: {len(achieved_p_levels)}") + logging.info(f"p_harmful range: {p_harmful_min:.4f} to {p_harmful_max:.4f}") + logging.info(f"Ideal ratio range: {ideal_ratios.min():.4f} to {ideal_ratios.max():.4f}") + logging.info(f"Max ratio range: {max_ratios.min():.4f} to {max_ratios.max():.4f}") + logging.info(f"Min ratio range: {min_ratios.min():.4f} to {min_ratios.max():.4f}") + logging.info(f"Total FLOPS landscape range: {landscape_total_flops.min():.2e} to {landscape_total_flops.max():.2e}") + logging.info(f"Number of points in landscape: {len(landscape_total_flops)}") + logging.info(f"Number of p_harmful levels with solutions: {len(achieved_p_levels)}") + + +def flops_breakdown_plot( + results: dict[str,np.ndarray], + baseline: dict[str,np.ndarray] | None = None, + title: str = "FLOPS Breakdown Analysis", + sample_levels_to_plot: tuple[int, ...]|None = None, + metric: tuple[str, ...] = ('scores', 'strong_reject', 'p_harmful'), + cumulative: bool = False, + flops_per_step: int | None = None, + threshold: float|None = None, + color_scale: str = "linear", + verbose: bool = True, +): + """ + Plot optimization FLOPS vs sampling FLOPS with p_harmful as a 2D surface. + """ + y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation = preprocess_data( + results, metric, threshold, flops_per_step + ) + n_runs, n_steps, n_total_samples = y.shape + if sample_levels_to_plot is None: + sample_levels_to_plot = generate_sample_sizes(n_total_samples) + + pts = get_points(y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation, + return_ratio=False, cumulative=cumulative) + cost, step_idx, n_samp, mean_p = pts.T + + # Calculate individual FLOP components + opt_flops = [] + sampling_flops = [] + p_harmful_vals = [] + n_samples_vals = [] + + rng = np.random.default_rng(42) # Fixed seed for reproducibility + + for j in range(1, n_total_samples + 1, 1): + for i in range(0, n_steps, 1): + opt_flop = np.mean(flops_optimization[:, :i+1].sum(axis=1)) + sampling_flop = np.mean(flops_sampling_generation[:, i]) * j + np.mean(flops_sampling_prefill_cache[:, i]) + p_vals = [] + for n in range(10): + # Calculate p_harmful value with same logic as other functions + if cumulative and i > 0: + samples_up_to_now = y[:, :i, rng.choice(n_total_samples, size=1, replace=False)].max(axis=1)[:, 0] + samples_at_end = y[:, i, rng.choice(n_total_samples, size=j, replace=False)].max(axis=-1) + p_val = np.stack([samples_up_to_now, samples_at_end], axis=1).max(axis=1).mean(axis=0) + else: + p_val = y[:, i, rng.choice(n_total_samples, size=j, replace=False)].max(axis=-1).mean(axis=0) + p_vals.append(p_val) + + opt_flops.append(opt_flop+sampling_flop) + sampling_flops.append(sampling_flop) + p_harmful_vals.append(np.mean(p_vals)) + n_samples_vals.append(j) + + opt_flops = np.array(opt_flops) + sampling_flops = np.array(sampling_flops) + p_harmful_vals = np.array(p_harmful_vals) + n_samples_vals = np.array(n_samples_vals) + + plt.figure(figsize=(12, 8)) + + # Create 2D surface plot using griddata interpolation + # Define grid for interpolation + sampling_min, sampling_max = sampling_flops.min(), sampling_flops.max() + opt_min, opt_max = opt_flops.min(), opt_flops.max() + + # Use log space for sampling FLOPS if range is large + if sampling_max / sampling_min > 100: + sampling_grid = np.logspace(np.log10(sampling_min), np.log10(sampling_max), 100) + else: + sampling_grid = np.linspace(sampling_min, sampling_max, 100) + + # Use log space for optimization FLOPS if range is large + if opt_max / opt_min > 100: + opt_grid = np.logspace(np.log10(opt_min), np.log10(opt_max), 100) + else: + opt_grid = np.linspace(opt_min, opt_max, 100) + + Sampling_grid, Opt_grid = np.meshgrid(sampling_grid, opt_grid) + + # Interpolate p_harmful values onto the grid + try: + p_harmful_grid = griddata( + (sampling_flops, opt_flops), + p_harmful_vals, + (Sampling_grid, Opt_grid), + method='linear', + fill_value=np.nan + ) + except Exception as e: + if verbose: + logging.info(f"Linear interpolation failed: {e}, trying nearest neighbor") + p_harmful_grid = griddata( + (sampling_flops, opt_flops), + p_harmful_vals, + (Sampling_grid, Opt_grid), + method='nearest', + fill_value=0 + ) + + # Create contour plot + levels = np.linspace(np.nanmin(p_harmful_vals), np.nanmax(p_harmful_vals), 50) + contour = plt.contourf(Sampling_grid, Opt_grid, p_harmful_grid, levels=levels, + cmap='plasma', extend='both') + + + # Add colorbar + cbar = plt.colorbar(contour) + if threshold is None: + cbar.set_label(r"$p_{harmful}$", fontsize=14) + else: + cbar.set_label(f"ASR (threshold: {threshold})", fontsize=14) + + + # Add baseline if provided + if baseline is not None: + y_baseline, baseline_flops_optimization, baseline_flops_sampling_prefill_cache, baseline_flops_sampling_generation = preprocess_data( + baseline, metric, threshold, flops_per_step + ) + + baseline_pts = get_points(y_baseline, baseline_flops_optimization, baseline_flops_sampling_prefill_cache, + baseline_flops_sampling_generation, return_ratio=False, cumulative=cumulative) + baseline_cost, baseline_step_idx, baseline_n_samp, baseline_mean_p = baseline_pts.T + + # Calculate baseline FLOP components + baseline_opt_flops = [] + baseline_sampling_flops = [] + + for i in range(0, y_baseline.shape[1], 1): + opt_flop = np.mean(baseline_flops_optimization[:, :i+1].sum(axis=1)) + sampling_flop = np.mean(baseline_flops_sampling_generation[:, i]) * 1 + np.mean(baseline_flops_sampling_prefill_cache[:, i]) + + baseline_opt_flops.append(opt_flop+sampling_flop) + baseline_sampling_flops.append(sampling_flop) + + baseline_opt_flops = np.array(baseline_opt_flops) + baseline_sampling_flops = np.array(baseline_sampling_flops) + + plt.scatter(baseline_sampling_flops, baseline_opt_flops, + s=60, alpha=0.9, marker="^", + edgecolors='red', linewidth=2, + color='white', label="Baseline") + + plt.xlabel("Sampling FLOPS", fontsize=14) + plt.ylabel("Total FLOPS", fontsize=14) + plt.grid(True, alpha=0.3) + plt.title(title, fontsize=16) + + # Use log scale for both axes if the range is large + if sampling_max / sampling_min > 100: + plt.xscale('log') + if opt_max / opt_min > 100: + plt.yscale('log') + + plt.legend(loc='upper left') + plt.tight_layout() + plt.savefig(f"evaluate/distributional_paper/flops_breakdown/{title}.pdf", bbox_inches='tight') + plt.close() + + if verbose: + logging.info(f"Sampling FLOPS range: {sampling_flops.min():.2e} to {sampling_flops.max():.2e}") + logging.info(f"Optimization FLOPS range: {opt_flops.min():.2e} to {opt_flops.max():.2e}") + logging.info(f"p_harmful range: {p_harmful_vals.min():.4f} to {p_harmful_vals.max():.4f}") + logging.info(f"Surface grid shape: {p_harmful_grid.shape}") + logging.info(f"Valid surface points: {np.sum(~np.isnan(p_harmful_grid))}/{p_harmful_grid.size}") + +def ridge_plot( + sampled_data: dict[str,np.ndarray], + model_title: str, + cfg: dict, +): + sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)}) + + # Create ridge plot for p_harmful distributions across steps + data = np.array(sampled_data[("scores", "strong_reject", "p_harmful")]) + + # Prepare data for ridge plot + ridge_data = [] + step_idxs = [0] + list(generate_sample_sizes(data.shape[1]-1)) + if data.shape[1]-1 not in step_idxs: + step_idxs.append(data.shape[1]-1) + for step_idx in step_idxs: + step_data = data[:, step_idx, :].flatten() # Get p_harmful values for this step + # Round/bucketize the data into five values: 0, 0.25, 0.5, 0.75, 1.0 + # step_data = np.round(step_data * 4) / 4 + for value in step_data: + ridge_data.append({'step': f'Step {step_idx}', 'p_harmful': value}) + df = pd.DataFrame(ridge_data) + print(df) + + # Create ridge plot for p_harmful distributions across steps + unique_steps = sorted(df['step'].unique(), key=lambda x: int(x.split()[1])) + n_steps = len(unique_steps) + pal = sns.cubehelix_palette(n_steps, rot=-.25, light=.7) + + # Initialize the FacetGrid object + g = sns.FacetGrid(df, row="step", hue="step", aspect=15, height=.5, palette=pal, + row_order=unique_steps) + + # Draw the densities + g.map(sns.kdeplot, "p_harmful", bw_adjust=0.5, clip_on=True, fill=True, alpha=1, linewidth=1.5) + g.map(sns.kdeplot, "p_harmful", bw_adjust=0.5, clip_on=True, color="w", lw=2) + + # Add vertical lines for mean and median + def add_mean_lines(x, **kwargs): + ax = plt.gca() + mean_val = np.mean(x) + ax.axvline(mean_val, color='red', linestyle='-', alpha=0.7, linewidth=1, ymax=0.8) + + g.map(add_mean_lines, "p_harmful") + + # Add reference line at y=0 + g.refline(y=0, linewidth=2, linestyle="-", color=None, clip_on=False) + + # Set the subplots to overlap + g.figure.subplots_adjust(hspace=-.4) + + # Remove axes details that don't play well with overlap + g.set_titles("") + g.set(yticks=[], ylabel="") + g.despine(bottom=True, left=True) + g.set(xlim=(0, 1)) + + # Add title to ridge plot + g.figure.suptitle(f'{model_title} - {cfg["title_suffix"]} - p_harmful Ridge Plot', + fontsize=14, y=0.95) + + # Save the ridge plot + filename = f"evaluate/distributional_paper/ridge_plots/{model_title}_{cfg['title_suffix']}.pdf" + g.figure.savefig(filename, dpi=300, bbox_inches='tight') + plt.close(g.figure) + # Helper --------------------------------------------------------------------------- @@ -889,7 +1249,7 @@ def run_analysis( analysis_type : str Type of analysis: "pareto", "flops_ratio", "ideal_ratio", "histogram", "histogram_2" """ - print(f"{analysis_type.title()} Analysis:", atk_name, cfg.get("title_suffix", "")) + logging.info(f"{analysis_type.title()} Analysis: {atk_name} {cfg.get('title_suffix', '')}") # ---------- sampled run ---------- sampled_data = fetch_data(model, cfg.get("attack_override", atk_name), cfg["sample_params"](), @@ -1000,66 +1360,11 @@ def run_analysis( plt.savefig(filename, dpi=300, bbox_inches='tight') plt.close() elif analysis_type == "ridge": - sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)}) - - # Create ridge plot for p_harmful distributions across steps - data = np.array(sampled_data[("scores", "strong_reject", "p_harmful")]) - - # Prepare data for ridge plot - ridge_data = [] - for step_idx in [0, 1, 2, 5, 10, 20, 50, 99, 249]: - step_data = data[:, step_idx, :].flatten() # Get p_harmful values for this step - # Round/bucketize the data into five values: 0, 0.25, 0.5, 0.75, 1.0 - # step_data = np.round(step_data * 4) / 4 - for value in step_data: - ridge_data.append({'step': f'Step {step_idx}', 'p_harmful': value}) - - df = pd.DataFrame(ridge_data) - - # Create the ridge plot - plt.figure(figsize=(12, 8)) - - # Get unique steps and create color palette - unique_steps = sorted(df['step'].unique(), key=lambda x: int(x.split()[1])) - n_steps = len(unique_steps) - pal = sns.cubehelix_palette(n_steps, rot=-.25, light=.7) - - # Initialize the FacetGrid object - g = sns.FacetGrid(df, row="step", hue="step", aspect=15, height=.5, palette=pal, - row_order=unique_steps) - - # Draw the densities - g.map(sns.kdeplot, "p_harmful", bw_adjust=0.5, clip_on=True, fill=True, alpha=1, linewidth=1.5) - g.map(sns.kdeplot, "p_harmful", bw_adjust=0.5, clip_on=True, color="w", lw=2) - - # Add vertical lines for mean and median - def add_mean_lines(x, **kwargs): - ax = plt.gca() - mean_val = np.mean(x) - ax.axvline(mean_val, color='red', linestyle='-', alpha=0.7, linewidth=1, ymax=0.8) - - g.map(add_mean_lines, "p_harmful") - - # Add reference line at y=0 - g.refline(y=0, linewidth=2, linestyle="-", color=None, clip_on=False) - - # Set the subplots to overlap - g.figure.subplots_adjust(hspace=-.4) - - # Remove axes details that don't play well with overlap - g.set_titles("") - g.set(yticks=[], ylabel="") - g.despine(bottom=True, left=True) - g.set(xlim=(0, 1)) - - # Add overall title - g.figure.suptitle(f'{model_title} - {cfg["title_suffix"]} - p_harmful Ridge Plot', - fontsize=16, y=0.98) - - # Save the plot - filename = f"evaluate/distributional_paper/ridge_plots/{model_title}_{cfg['title_suffix']}.pdf" - plt.savefig(filename, dpi=300, bbox_inches='tight') - plt.close() + ridge_plot( + sampled_data, + model_title, + cfg, + ) elif analysis_type == "histogram_2": # Create histogram plot plt.figure(figsize=(10, 6)) @@ -1194,6 +1499,17 @@ def add_mean_lines(x, **kwargs): filename = f"evaluate/distributional_paper/histograms_2/{model_title}_{cfg['title_suffix']}.pdf" plt.savefig(filename, dpi=300, bbox_inches='tight') plt.close() + elif analysis_type == "flops_breakdown": + flops_breakdown_plot( + sampled_data, + baseline_data, + title=f"{model_title} {cfg['title_suffix']} FLOPS Breakdown", + cumulative=cfg["cumulative"], + metric=METRIC, + flops_per_step=flops_per_step_fn, + threshold=None, + color_scale="sqrt", + ) else: raise ValueError(f"Unknown analysis type: {analysis_type}") @@ -1205,8 +1521,9 @@ def add_mean_lines(x, **kwargs): MODELS = { "meta-llama/Meta-Llama-3.1-8B-Instruct": "Meta Llama 3.1 8B", - "google/gemma-3-1b-it": "Gemma 3.1 1B", + "google/gemma-3-1b-it": "Gemma 3 1B", "GraySwanAI/Llama-3-8B-Instruct-RR": "Llama 3 CB", + "Unispac/Llama2-7B-Chat-Augmented": "Llama 2 DeepAlign", } FLOPS_PER_STEP = { @@ -1218,151 +1535,161 @@ def add_mean_lines(x, **kwargs): # Attack-specific configuration ----------------------------------------------------- ATTACKS = [ - ("pair", dict( - title_suffix="PAIR", - cumulative=True, - sample_params=lambda: { - "generation_config": {"num_return_sequences": 50, "temperature": 0.7}, - }, - baseline_params=lambda: { - "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, - }, - )), - ("autodan", dict( - title_suffix="AutoDAN", + # ("pair", dict( + # title_suffix="PAIR", + # cumulative=True, + # sample_params=lambda: { + # "generation_config": {"num_return_sequences": 50, "temperature": 0.7}, + # }, + # baseline_params=lambda: { + # "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, + # }, + # )), + ("beast", dict( + title_suffix="BEAST", cumulative=False, sample_params=lambda: { "generation_config": {"num_return_sequences": 50, "temperature": 0.7}, - "early_stopping_threshold": 0, - }, - baseline_params=lambda: { - "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, - }, - )), - ("gcg", dict( - title_suffix="GCG", - cumulative=False, - sample_params=lambda: { - "generation_config": {"num_return_sequences": 50, "temperature": 0.7}, - "num_steps": 250, - "loss": "ce", - "token_selection": "default", - "use_prefix_cache": True, - }, - baseline_params=lambda: { - "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, - "num_steps": 250, - "loss": "ce", - "token_selection": "default", - "use_prefix_cache": True, - }, - )), - ("gcg", dict( - title_suffix="GCG 500", - cumulative=False, - sample_params=lambda: { - "generation_config": {"num_return_sequences": 500, "temperature": 0.7}, - "num_steps": 250, - "loss": "ce", - "token_selection": "default", - "use_prefix_cache": True, - }, - baseline_params=lambda: { - "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, - "num_steps": 250, - "loss": "ce", - "token_selection": "default", - "use_prefix_cache": True, - }, - )), - ("gcg", dict( - title_suffix="GCG Entropy Loss", - cumulative=False, - sample_params=lambda: { - "generation_config": {"num_return_sequences": 50, "temperature": 0.7}, - "num_steps": 250, - "loss": "entropy_adaptive", - "token_selection": "default", - "use_prefix_cache": True, - }, - baseline_params=lambda: { - "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, - "num_steps": 250, - "loss": "ce", - "token_selection": "default", - "use_prefix_cache": True, - }, - )), - ("bon", dict( - title_suffix="BoN", - cumulative=False, - sample_params=lambda: {"num_steps": 1000, "generation_config": {"temperature": 0.7}}, - baseline_params=lambda: { - # BoN's baseline is *Direct* with one deterministic sample - "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, - }, - baseline_attack="direct", - postprocess=lambda data, metric: data.__setitem__( - metric, np.array(data[metric]).transpose(0, 2, 1) - ), - )), - ("bon", dict( - title_suffix="BoN Repro", - cumulative=False, - sample_params=lambda: {"num_steps": 1000, "generation_config": {"temperature": 1.0}}, - baseline_params=lambda: { - # BoN's baseline is *Direct* with one deterministic sample - "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, - }, - baseline_attack="direct", - postprocess=lambda data, metric: data.__setitem__( - metric, np.array(data[metric]).transpose(0, 2, 1) - ), - )), - ("direct", dict( - title_suffix="Direct", - cumulative=True, - sample_params=lambda: { - "generation_config": {"num_return_sequences": 1000, "temperature": 0.7}, - }, - baseline_params=lambda: { - "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, - }, - skip_if_empty=True, # gracefully continue if no paths were found - )), - ("direct", dict( - title_suffix="Direct temp 1.0", - cumulative=True, - sample_params=lambda: { - "generation_config": {"num_return_sequences": 1000, "temperature": 1.0}, }, baseline_params=lambda: { "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, }, - skip_if_empty=True, # gracefully continue if no paths were found )), + # ("autodan", dict( + # title_suffix="AutoDAN", + # cumulative=False, + # sample_params=lambda: { + # "generation_config": {"num_return_sequences": 50, "temperature": 0.7}, + # "early_stopping_threshold": 0, + # }, + # baseline_params=lambda: { + # "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, + # }, + # )), + # ("gcg", dict( + # title_suffix="GCG", + # cumulative=False, + # sample_params=lambda: { + # "generation_config": {"num_return_sequences": 50, "temperature": 0.7}, + # "num_steps": 250, + # "loss": "ce", + # "token_selection": "default", + # "use_prefix_cache": True, + # }, + # baseline_params=lambda: { + # "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, + # "num_steps": 250, + # "loss": "ce", + # "token_selection": "default", + # "use_prefix_cache": True, + # }, + # )), + # ("gcg", dict( + # title_suffix="GCG 500", + # cumulative=False, + # sample_params=lambda: { + # "generation_config": {"num_return_sequences": 500, "temperature": 0.7}, + # "num_steps": 250, + # "loss": "ce", + # "token_selection": "default", + # "use_prefix_cache": True, + # }, + # baseline_params=lambda: { + # "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, + # "num_steps": 250, + # "loss": "ce", + # "token_selection": "default", + # "use_prefix_cache": True, + # }, + # )), + # ("gcg", dict( + # title_suffix="GCG Entropy Loss", + # cumulative=False, + # sample_params=lambda: { + # "generation_config": {"num_return_sequences": 50, "temperature": 0.7}, + # "num_steps": 250, + # "loss": "entropy_adaptive", + # "token_selection": "default", + # "use_prefix_cache": True, + # }, + # baseline_params=lambda: { + # "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, + # "num_steps": 250, + # "loss": "ce", + # "token_selection": "default", + # "use_prefix_cache": True, + # }, + # )), + # ("bon", dict( + # title_suffix="BoN", + # cumulative=False, + # sample_params=lambda: {"num_steps": 1000, "generation_config": {"temperature": 0.7}}, + # baseline_params=lambda: { + # # BoN's baseline is *Direct* with one deterministic sample + # "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, + # }, + # baseline_attack="direct", + # postprocess=lambda data, metric: data.__setitem__( + # metric, np.array(data[metric]).transpose(0, 2, 1) + # ), + # )), + # ("bon", dict( + # title_suffix="BoN Repro", + # cumulative=False, + # sample_params=lambda: {"num_steps": 1000, "generation_config": {"temperature": 1.0}}, + # baseline_params=lambda: { + # # BoN's baseline is *Direct* with one deterministic sample + # "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, + # }, + # baseline_attack="direct", + # postprocess=lambda data, metric: data.__setitem__( + # metric, np.array(data[metric]).transpose(0, 2, 1) + # ), + # )), + # ("direct", dict( + # title_suffix="Direct", + # cumulative=True, + # sample_params=lambda: { + # "generation_config": {"num_return_sequences": 1000, "temperature": 0.7}, + # }, + # baseline_params=lambda: { + # "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, + # }, + # skip_if_empty=True, # gracefully continue if no paths were found + # )), + # ("direct", dict( + # title_suffix="Direct temp 1.0", + # cumulative=True, + # sample_params=lambda: { + # "generation_config": {"num_return_sequences": 1000, "temperature": 1.0}, + # }, + # baseline_params=lambda: { + # "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, + # }, + # skip_if_empty=True, # gracefully continue if no paths were found + # )), ] METRIC = ("scores", "strong_reject", "p_harmful") GROUP_BY = {"model", "attack_params"} -DATASET_IDX = list(range(50)) +DATASET_IDX = list(range(75)) def main(fail: bool = False): - # for analysis_type in ["pareto", "flops_ratio", "ideal_ratio", "histogram", "histogram_2", "ridge"]: - for analysis_type in ["flops_ratio", "ideal_ratio", "histogram", "histogram_2", "ridge"]: - print("\n" + "="*80) - print(f"GENERATING {analysis_type.upper().replace('_', ' ')} PLOTS") - print("="*80) + for analysis_type in ["pareto", "flops_ratio", "ideal_ratio", "histogram", "histogram_2", "ridge", "flops_breakdown"]: + # for analysis_type in [ "ridge"]: + logging.info("\n" + "="*80) + logging.info(f"GENERATING {analysis_type.upper().replace('_', ' ')} PLOTS") + logging.info("="*80) for model_key, model_title in MODELS.items(): - print("Model:", model_key) + logging.info(f"Model: {model_key}") for atk_name, atk_cfg in ATTACKS: try: run_analysis(model_key, model_title, atk_name, atk_cfg, analysis_type) except Exception as e: if fail: raise e - print(f"Error running {analysis_type} analysis for {atk_name}, " + logging.info(f"Error running {analysis_type} analysis for {atk_name}, " f"cfg: {atk_cfg.get('title_suffix', 'unknown')}: {e}") if __name__ == "__main__": From 0ec861295c2ad240070e83d12fd3464bd89a7384 Mon Sep 17 00:00:00 2001 From: Tim Beyer Date: Mon, 23 Jun 2025 15:52:07 +0200 Subject: [PATCH 11/13] various fixes - allow loading of text attributes in `collect_results` - fix beast for gemma - small qol improvements - add deeply aligned llama2 to pool --- conf/attacks/attacks.yaml | 1 + conf/models/models.yaml | 9 ++ run_attacks.py | 2 - slurm_status.py | 204 ++++++++++++++++++++++++++++++++++++++ src/attacks/beast.py | 29 ++++-- src/io_utils.py | 4 +- 6 files changed, 235 insertions(+), 14 deletions(-) create mode 100755 slurm_status.py diff --git a/conf/attacks/attacks.yaml b/conf/attacks/attacks.yaml index 5c48913..a218893 100644 --- a/conf/attacks/attacks.yaml +++ b/conf/attacks/attacks.yaml @@ -101,6 +101,7 @@ beast: allow_non_ascii: True allow_special: False use_prefix_cache: True + mask_undecided_tokens: False bon: name: bon type: discrete diff --git a/conf/models/models.yaml b/conf/models/models.yaml index c6afa7b..d2d41cb 100644 --- a/conf/models/models.yaml +++ b/conf/models/models.yaml @@ -169,6 +169,15 @@ meta-llama/Llama-2-7b-chat-hf: dtype: bfloat16 chat_template: llama-2-chat trust_remote_code: True +Unispac/Llama2-7B-Chat-Augmented: + id: Unispac/Llama2-7B-Chat-Augmented + tokenizer_id: Unispac/Llama2-7B-Chat-Augmented + short_name: Llama + developer_name: Meta + compile: False + dtype: bfloat16 + chat_template: llama-2-chat + trust_remote_code: True ContinuousAT/Llama-2-7B-CAT: id: ContinuousAT/Llama-2-7B-CAT tokenizer_id: meta-llama/Llama-2-7b-chat-hf diff --git a/run_attacks.py b/run_attacks.py index cf34b28..da4f7c3 100644 --- a/run_attacks.py +++ b/run_attacks.py @@ -28,8 +28,6 @@ def select_configs(cfg: DictConfig, name: str | ListConfig | None) -> list[tuple def collect_configs(cfg: DictConfig) -> list[RunConfig]: - if hasattr(cfg, 'model_name') or hasattr(cfg, 'dataset_name') or hasattr(cfg, 'attack_name'): - raise ValueError("model_name, dataset_name, and attack_name are deprecated. Use model, dataset, and attack instead.") models_to_run = select_configs(cfg.models, cfg.model) datasets_to_run = select_configs(cfg.datasets, cfg.dataset) attacks_to_run = select_configs(cfg.attacks, cfg.attack) diff --git a/slurm_status.py b/slurm_status.py new file mode 100755 index 0000000..d14ed68 --- /dev/null +++ b/slurm_status.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python3 +""" +Simple SLURM Job Status Visualizer for Multirun Directories +Shows a nice overview of job completion status with color coding. + +Usage: + python3 slurm_status.py # Show overview of last 20 runs + python3 slurm_status.py -n 50 # Show last 50 runs + python3 slurm_status.py -d # Show detailed view with individual job info +""" + +import os +import time +import argparse +from pathlib import Path +from datetime import datetime +import re + +# Color codes for terminal output +class Colors: + GREEN = '\033[92m' # Success + RED = '\033[91m' # Failed + YELLOW = '\033[93m' # Running/Unknown + BLUE = '\033[94m' # Info + PURPLE = '\033[95m' # Header + CYAN = '\033[96m' # Timestamp + WHITE = '\033[97m' # Default + BOLD = '\033[1m' # Bold + END = '\033[0m' # Reset + +def check_job_status(job_dir): + """Check the status of a single SLURM job""" + # Look for log files with the correct naming pattern + log_files = list(job_dir.glob("*_log.out")) + + if not log_files: + return "NO_LOG", "No log file found" + + log_out = log_files[0] + log_err = log_out.with_suffix('.err') + + try: + # Check the last few lines of the output log + with open(log_out, 'r') as f: + lines = f.readlines() + + with open(log_err, 'r') as f: + err_lines = "\n".join(f.readlines()).lower() + + # Look for completion indicators in the last 10 lines + last_lines = ''.join(lines[-10:]).lower() + if any(word in err_lines for word in ["error", "failed", "exception", "traceback"]): + return "FAILED", "Error detected in logs" + elif "job completed successfully" in last_lines: + return "SUCCESS", "Completed successfully" + elif "exiting after successful completion" in last_lines: + return "SUCCESS", "Completed successfully" + elif any(word in last_lines for word in ["error", "failed", "exception", "traceback"]): + return "FAILED", "Error detected in logs" + else: + # Check the last few lines of the output log + last_lines = ''.join(err_lines[-10:]).lower() + if any(word in last_lines for word in ["error", "failed", "exception", "traceback"]): + if "out of memory" in last_lines: + return "FAILED", "Out of memory" + if "cancelled at" in last_lines: + return "FAILED", "Cancelled" + return "FAILED", "Error detected in logs" + + # Check if log file was modified in the last 2 minutes (job is likely still running) + if time.time() - os.path.getmtime(log_out) < 120: + return "RUNNING", "Job is currently running" + + return "UNKNOWN", "Status unclear" + + except Exception as e: + return "ERROR", f"Could not read logs: {str(e)}" + +def get_run_timestamp(run_path): + """Extract timestamp from run path for sorting""" + try: + date_str = run_path.parent.name # YYYY-MM-DD + time_str = run_path.name # HH-MM-SS + return datetime.strptime(f"{date_str} {time_str}", "%Y-%m-%d %H-%M-%S") + except: + return datetime.min + +def get_multirun_status(max_runs=20, detailed=False): + """Get status of all multirun jobs""" + multirun_path = Path("multirun") + + if not multirun_path.exists(): + print(f"{Colors.RED}Error: multirun directory not found{Colors.END}") + return + + # Find all .submitit directories + submitit_dirs = list(multirun_path.glob("*/*/.submitit")) + + if not submitit_dirs: + print(f"{Colors.YELLOW}No SLURM jobs found in multirun directory{Colors.END}") + return + + # Sort by timestamp (most recent first) + submitit_dirs.sort(key=lambda x: get_run_timestamp(x.parent), reverse=True) + + print(f"{Colors.PURPLE}{Colors.BOLD}🚀 SLURM Job Status Overview{Colors.END}") + print(f"{Colors.PURPLE}{'='*80}{Colors.END}") + + total_runs = 0 + successful_runs = 0 + failed_runs = 0 + running_runs = 0 + + for submitit_dir in submitit_dirs[:max_runs]: + run_path = submitit_dir.parent + run_name = f"{run_path.parent.name}/{run_path.name}" + timestamp = get_run_timestamp(run_path) + + # Find all job directories + job_dirs = [d for d in submitit_dir.iterdir() if d.is_dir() and re.match(r'\d+_\d+', d.name)] + + if not job_dirs: + continue + + total_runs += 1 + + # Check status of all jobs in this run + job_statuses = [] + for job_dir in sorted(job_dirs): + status, message = check_job_status(job_dir) + job_statuses.append((job_dir.name, status, message)) + + # Determine overall run status + success_count = sum(1 for _, status, _ in job_statuses if status == "SUCCESS") + failed_count = sum(1 for _, status, _ in job_statuses if status == "FAILED") + running_count = sum(1 for _, status, _ in job_statuses if status == "RUNNING") + unknown_count = sum(1 for _, status, _ in job_statuses if status not in ["SUCCESS", "FAILED", "RUNNING"]) + total_jobs = len(job_statuses) + + if success_count == total_jobs: + run_status = "SUCCESS" + status_icon = "✅" + status_color = Colors.GREEN + successful_runs += 1 + elif failed_count > 0: + run_status = "FAILED" + status_icon = "❌" + status_color = Colors.RED + failed_runs += 1 + elif running_count > 0: + run_status = "RUNNING" + status_icon = "🏃" + status_color = Colors.YELLOW + running_runs += 1 + else: + run_status = "PARTIAL" + status_icon = "⚠️" + status_color = Colors.YELLOW + + # Print run summary + time_str = timestamp.strftime("%m-%d %H:%M") if timestamp != datetime.min else "unknown" + print(f"{status_color}{status_icon} {run_name}{Colors.END} " + f"{Colors.CYAN}[{time_str}]{Colors.END} " + f"({total_jobs} jobs) " + f"{Colors.GREEN}{success_count}✓{Colors.END} " + f"{Colors.RED}{failed_count}✗{Colors.END} " + f"{Colors.YELLOW}{running_count}🏃{Colors.END} " + f"{Colors.YELLOW}{unknown_count}?{Colors.END}") + + # Show detailed job info if requested + if detailed and (failed_count > 0 or running_count > 0 or unknown_count > 0): + for job_name, status, message in job_statuses: + if status != "SUCCESS": + color = Colors.RED if status == "FAILED" else Colors.YELLOW + print(f" {color}└─ {job_name}: {message}{Colors.END}") + + # Print summary + print(f"\n{Colors.PURPLE}{'='*80}{Colors.END}") + print(f"{Colors.BOLD}📊 Summary:{Colors.END}") + print(f" Total runs: {total_runs}") + print(f" {Colors.GREEN}✅ Successful: {successful_runs}{Colors.END}") + print(f" {Colors.RED}❌ Failed: {failed_runs}{Colors.END}") + print(f" {Colors.YELLOW}🏃 Running: {running_runs}{Colors.END}") + print(f" {Colors.YELLOW}⚠️ Partial/Other: {total_runs - successful_runs - failed_runs - running_runs}{Colors.END}") + +def main(): + """Main function""" + parser = argparse.ArgumentParser(description="SLURM Job Status Visualizer") + parser.add_argument("-n", "--num-runs", type=int, default=20, + help="Number of recent runs to show (default: 20)") + parser.add_argument("-d", "--detailed", action="store_true", + help="Show detailed information for failed/unknown jobs") + + args = parser.parse_args() + + try: + get_multirun_status(max_runs=args.num_runs, detailed=args.detailed) + except KeyboardInterrupt: + print(f"\n{Colors.YELLOW}Interrupted by user{Colors.END}") + except Exception as e: + print(f"{Colors.RED}Error: {str(e)}{Colors.END}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/attacks/beast.py b/src/attacks/beast.py index 0398c82..bdb031c 100644 --- a/src/attacks/beast.py +++ b/src/attacks/beast.py @@ -41,6 +41,7 @@ class BEASTConfig: allow_non_ascii: bool = False allow_special: bool = False use_prefix_cache: bool = True + mask_undecided_tokens: bool = False class BEASTAttack(Attack): @@ -73,6 +74,7 @@ def run(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer, dataset) -> self.config.allow_non_ascii, self.config.allow_special ).to(model.device) + self.disallowed_ids = self.disallowed_ids[self.disallowed_ids < model.get_input_embeddings().weight.size(0)] for conversation in dataset: t0 = time.time() @@ -89,7 +91,7 @@ def run(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer, dataset) -> prompts.append(conversation) # Compute KV cache for prefix tokens - if self.config.use_prefix_cache: + if self.config.use_prefix_cache and "gemma" not in model.name_or_path: self.populate_prefix_cache(model, pre_tokens, prompt_tokens) flops_prefix = get_flops(model, pre_tokens.numel() + prompt_tokens.numel(), 0, "forward") else: @@ -256,8 +258,11 @@ def get_perplexity( padded_attack_tokens_list.append(attack_tokens) # Replace original list with padded version attack_tokens_list = padded_attack_tokens_list - attention_mask = torch.zeros(T, dtype=torch.long, device=attack_tokens_list[0].device) - attention_mask[:T_cur] = 1 + if self.config.mask_undecided_tokens: + attention_mask = torch.zeros(T, dtype=torch.long, device=attack_tokens_list[0].device) + attention_mask[:T_cur] = 1 + else: + attention_mask = None if self.prefix_cache is not None: # With prefix cache, we don't need to include prefix tokens @@ -271,20 +276,24 @@ def get_perplexity( torch.cat([pre_tokens, prompt_tokens, attack_tokens, post_tokens, target_tokens]) for attack_tokens in attack_tokens_list ] - attention_mask = torch.cat([torch.ones(pre_tokens.size(0) + prompt_tokens.size(0)), attention_mask, torch.ones(post_tokens.size(0) + target_tokens.size(0))]) - attention_mask = attention_mask.to(model.device) + if attention_mask is not None: + attention_mask = torch.cat([torch.ones(pre_tokens.size(0) + prompt_tokens.size(0)), attention_mask, torch.ones(post_tokens.size(0) + target_tokens.size(0))]) + attention_mask = attention_mask[:-1].to(model.device) tensor = torch.stack(tokens_to_concat) def get_log_probs(target_tokens, attention_mask, x): + B = x.size(0) # Expand prefix cache to match batch size if available - cache = None if self.prefix_cache is not None: cache = copy.deepcopy(self.prefix_cache) - for i in range(len(cache)): - cache.key_cache[i] = cache.key_cache[i].expand(x.size(0), -1, -1, -1) - cache.value_cache[i] = cache.value_cache[i].expand(x.size(0), -1, -1, -1) - attention_mask = attention_mask.unsqueeze(0).repeat(x.size(0), 1).to(model.device) + for i in range(len(cache.key_cache)): + cache.key_cache[i] = cache.key_cache[i].expand(B, -1, -1, -1) + cache.value_cache[i] = cache.value_cache[i].expand(B, -1, -1, -1) + else: + cache = None + if attention_mask is not None: + attention_mask = attention_mask.unsqueeze(0).repeat(B, 1).to(model.device) # Get logits and compute log probabilities logits = model(input_ids=x.to(model.device), past_key_values=cache, attention_mask=attention_mask).logits diff --git a/src/io_utils.py b/src/io_utils.py index ae2f8d6..0c5aa4d 100644 --- a/src/io_utils.py +++ b/src/io_utils.py @@ -513,14 +513,14 @@ def _gather(value, prefix: tuple[str], out): and store it under its full path. """ # leaf node: number or list of numbers - if isinstance(value, (int, float)) or isinstance(value, list) and isinstance(value[0], (int, float)): + if isinstance(value, (int, float, str)) or isinstance(value, list) and isinstance(value[0], (int, float, str)): if len(prefix) == 1: prefix = prefix[0] out[prefix].append(value) elif isinstance(value, dict): # keep descending for k, v in value.items(): _gather(v, prefix + (k,), out) - elif isinstance(value, list): # either a list of dicts or a list of numbers + elif isinstance(value, list): # list of containers if value and isinstance(value[0], (dict, list)): for v in value: _gather(v, prefix, out) # sub-lists of dicts From 5d5571b44af14e33158a1345798cadf0d467f4d3 Mon Sep 17 00:00:00 2001 From: Tim Beyer Date: Mon, 23 Jun 2025 15:52:43 +0200 Subject: [PATCH 12/13] fix plots --- evaluate/distributional_paper/make_plots.py | 248 ++++++++++---------- 1 file changed, 123 insertions(+), 125 deletions(-) diff --git a/evaluate/distributional_paper/make_plots.py b/evaluate/distributional_paper/make_plots.py index 022d433..70730de 100644 --- a/evaluate/distributional_paper/make_plots.py +++ b/evaluate/distributional_paper/make_plots.py @@ -141,7 +141,6 @@ def fetch_data(model: str, attack: str, attack_params: dict, dataset_idx: list[i paths = get_filtered_and_grouped_paths(filter_by, group_by) results = collect_results(paths, infer_sampling_flops=True) - print(group_by, filter_by, len(paths), len(results)) assert len(results) == 1, f"Should only have exactly one type of result, got {len(results)}, {list(results.keys())}" return list(results.values())[0] @@ -1187,7 +1186,6 @@ def ridge_plot( for value in step_data: ridge_data.append({'step': f'Step {step_idx}', 'p_harmful': value}) df = pd.DataFrame(ridge_data) - print(df) # Create ridge plot for p_harmful distributions across steps unique_steps = sorted(df['step'].unique(), key=lambda x: int(x.split()[1])) @@ -1535,16 +1533,16 @@ def run_analysis( # Attack-specific configuration ----------------------------------------------------- ATTACKS = [ - # ("pair", dict( - # title_suffix="PAIR", - # cumulative=True, - # sample_params=lambda: { - # "generation_config": {"num_return_sequences": 50, "temperature": 0.7}, - # }, - # baseline_params=lambda: { - # "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, - # }, - # )), + ("pair", dict( + title_suffix="PAIR", + cumulative=True, + sample_params=lambda: { + "generation_config": {"num_return_sequences": 50, "temperature": 0.7}, + }, + baseline_params=lambda: { + "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, + }, + )), ("beast", dict( title_suffix="BEAST", cumulative=False, @@ -1555,119 +1553,119 @@ def run_analysis( "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, }, )), - # ("autodan", dict( - # title_suffix="AutoDAN", - # cumulative=False, - # sample_params=lambda: { - # "generation_config": {"num_return_sequences": 50, "temperature": 0.7}, - # "early_stopping_threshold": 0, - # }, - # baseline_params=lambda: { - # "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, - # }, - # )), - # ("gcg", dict( - # title_suffix="GCG", - # cumulative=False, - # sample_params=lambda: { - # "generation_config": {"num_return_sequences": 50, "temperature": 0.7}, - # "num_steps": 250, - # "loss": "ce", - # "token_selection": "default", - # "use_prefix_cache": True, - # }, - # baseline_params=lambda: { - # "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, - # "num_steps": 250, - # "loss": "ce", - # "token_selection": "default", - # "use_prefix_cache": True, - # }, - # )), - # ("gcg", dict( - # title_suffix="GCG 500", - # cumulative=False, - # sample_params=lambda: { - # "generation_config": {"num_return_sequences": 500, "temperature": 0.7}, - # "num_steps": 250, - # "loss": "ce", - # "token_selection": "default", - # "use_prefix_cache": True, - # }, - # baseline_params=lambda: { - # "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, - # "num_steps": 250, - # "loss": "ce", - # "token_selection": "default", - # "use_prefix_cache": True, - # }, - # )), - # ("gcg", dict( - # title_suffix="GCG Entropy Loss", - # cumulative=False, - # sample_params=lambda: { - # "generation_config": {"num_return_sequences": 50, "temperature": 0.7}, - # "num_steps": 250, - # "loss": "entropy_adaptive", - # "token_selection": "default", - # "use_prefix_cache": True, - # }, - # baseline_params=lambda: { - # "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, - # "num_steps": 250, - # "loss": "ce", - # "token_selection": "default", - # "use_prefix_cache": True, - # }, - # )), - # ("bon", dict( - # title_suffix="BoN", - # cumulative=False, - # sample_params=lambda: {"num_steps": 1000, "generation_config": {"temperature": 0.7}}, - # baseline_params=lambda: { - # # BoN's baseline is *Direct* with one deterministic sample - # "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, - # }, - # baseline_attack="direct", - # postprocess=lambda data, metric: data.__setitem__( - # metric, np.array(data[metric]).transpose(0, 2, 1) - # ), - # )), - # ("bon", dict( - # title_suffix="BoN Repro", - # cumulative=False, - # sample_params=lambda: {"num_steps": 1000, "generation_config": {"temperature": 1.0}}, - # baseline_params=lambda: { - # # BoN's baseline is *Direct* with one deterministic sample - # "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, - # }, - # baseline_attack="direct", - # postprocess=lambda data, metric: data.__setitem__( - # metric, np.array(data[metric]).transpose(0, 2, 1) - # ), - # )), - # ("direct", dict( - # title_suffix="Direct", - # cumulative=True, - # sample_params=lambda: { - # "generation_config": {"num_return_sequences": 1000, "temperature": 0.7}, - # }, - # baseline_params=lambda: { - # "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, - # }, - # skip_if_empty=True, # gracefully continue if no paths were found - # )), - # ("direct", dict( - # title_suffix="Direct temp 1.0", - # cumulative=True, - # sample_params=lambda: { - # "generation_config": {"num_return_sequences": 1000, "temperature": 1.0}, - # }, - # baseline_params=lambda: { - # "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, - # }, - # skip_if_empty=True, # gracefully continue if no paths were found - # )), + ("autodan", dict( + title_suffix="AutoDAN", + cumulative=False, + sample_params=lambda: { + "generation_config": {"num_return_sequences": 50, "temperature": 0.7}, + "early_stopping_threshold": 0, + }, + baseline_params=lambda: { + "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, + }, + )), + ("gcg", dict( + title_suffix="GCG", + cumulative=False, + sample_params=lambda: { + "generation_config": {"num_return_sequences": 50, "temperature": 0.7}, + "num_steps": 250, + "loss": "ce", + "token_selection": "default", + "use_prefix_cache": True, + }, + baseline_params=lambda: { + "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, + "num_steps": 250, + "loss": "ce", + "token_selection": "default", + "use_prefix_cache": True, + }, + )), + ("gcg", dict( + title_suffix="GCG 500", + cumulative=False, + sample_params=lambda: { + "generation_config": {"num_return_sequences": 500, "temperature": 0.7}, + "num_steps": 250, + "loss": "ce", + "token_selection": "default", + "use_prefix_cache": True, + }, + baseline_params=lambda: { + "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, + "num_steps": 250, + "loss": "ce", + "token_selection": "default", + "use_prefix_cache": True, + }, + )), + ("gcg", dict( + title_suffix="GCG Entropy Loss", + cumulative=False, + sample_params=lambda: { + "generation_config": {"num_return_sequences": 50, "temperature": 0.7}, + "num_steps": 250, + "loss": "entropy_adaptive", + "token_selection": "default", + "use_prefix_cache": True, + }, + baseline_params=lambda: { + "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, + "num_steps": 250, + "loss": "ce", + "token_selection": "default", + "use_prefix_cache": True, + }, + )), + ("bon", dict( + title_suffix="BoN", + cumulative=False, + sample_params=lambda: {"num_steps": 1000, "generation_config": {"temperature": 0.7}}, + baseline_params=lambda: { + # BoN's baseline is *Direct* with one deterministic sample + "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, + }, + baseline_attack="direct", + postprocess=lambda data, metric: data.__setitem__( + metric, np.array(data[metric]).transpose(0, 2, 1) + ), + )), + ("bon", dict( + title_suffix="BoN Repro", + cumulative=False, + sample_params=lambda: {"num_steps": 1000, "generation_config": {"temperature": 1.0}}, + baseline_params=lambda: { + # BoN's baseline is *Direct* with one deterministic sample + "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, + }, + baseline_attack="direct", + postprocess=lambda data, metric: data.__setitem__( + metric, np.array(data[metric]).transpose(0, 2, 1) + ), + )), + ("direct", dict( + title_suffix="Direct", + cumulative=True, + sample_params=lambda: { + "generation_config": {"num_return_sequences": 1000, "temperature": 0.7}, + }, + baseline_params=lambda: { + "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, + }, + skip_if_empty=True, # gracefully continue if no paths were found + )), + ("direct", dict( + title_suffix="Direct temp 1.0", + cumulative=True, + sample_params=lambda: { + "generation_config": {"num_return_sequences": 1000, "temperature": 1.0}, + }, + baseline_params=lambda: { + "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, + }, + skip_if_empty=True, # gracefully continue if no paths were found + )), ] METRIC = ("scores", "strong_reject", "p_harmful") From c3332e8ee640f20955e7e5ab6e2200236ae07dc6 Mon Sep 17 00:00:00 2001 From: Tim Beyer Date: Tue, 15 Jul 2025 15:38:05 +0200 Subject: [PATCH 13/13] update plots --- evaluate/distributional_paper/make_plots.py | 2491 +++++++++++++++---- 1 file changed, 1983 insertions(+), 508 deletions(-) diff --git a/evaluate/distributional_paper/make_plots.py b/evaluate/distributional_paper/make_plots.py index 70730de..9775d1c 100644 --- a/evaluate/distributional_paper/make_plots.py +++ b/evaluate/distributional_paper/make_plots.py @@ -10,6 +10,7 @@ from matplotlib.colors import LogNorm, PowerNorm from scipy.interpolate import griddata import logging +from matplotlib.ticker import MaxNLocator logging.basicConfig(level=logging.INFO) pd.set_option("display.max_colwidth", None) @@ -18,10 +19,10 @@ from src.io_utils import get_filtered_and_grouped_paths, collect_results, num_model_params - +s_harm_tex = r"$s_{harm}$" def generate_sample_sizes(total_samples: int) -> tuple[int, ...]: if total_samples < 1: - raise ValueError("total_samples must be ≥ 1") + return tuple() bases = (1, 2, 5) # 1-2-5 pattern for each power of ten result = [] power = 0 @@ -70,6 +71,29 @@ def _dominance_frontier(xs: np.ndarray, ys: np.ndarray) -> tuple[np.ndarray, np. return np.asarray(frontier_x), np.asarray(frontier_y) +def _non_cumulative_dominance_frontier(xs: np.ndarray, ys: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """ + Return all points ordered by cost, without dominance filtering. + This creates a non-cumulative frontier that includes all points. + + Parameters + ---------- + xs, ys : 1-D arrays of equal length + Coordinates of the candidate points. + + Returns + ------- + frontier_xs, frontier_ys : 1-D arrays + All points, sorted by xs ascending. + """ + order = np.argsort(xs) # sort by cost + xs_sorted, ys_sorted = xs[order], ys[order] + + frontier_x, frontier_y = [0, *xs_sorted], [0, *ys_sorted] + + return np.asarray(frontier_x), np.asarray(frontier_y) + + # ------------------------------------------------------------------ # 1. Empirical‑copula Pareto frontier (no Archimedean fit required) # ------------------------------------------------------------------ @@ -123,6 +147,9 @@ def _pareto_frontier(xs: np.ndarray, elif method == "basic": # your original dominance‑based frontier return _dominance_frontier(xs, ys) # rename old function + elif method == "non_cumulative": + # non-cumulative frontier that includes all points + return _non_cumulative_dominance_frontier(xs, ys) else: raise ValueError(f"Unknown frontier method '{method}'") @@ -270,21 +297,34 @@ def pareto_plot( else: x_interp = np.linspace(0, max_cost+1, n_x_points) - # Create figure with subplots: main plot + 2x2 grid on the right - fig = plt.figure(figsize=(18, 8)) + # Create figure with subplots: legend + main plot + 2x2 grid on the right + # fig = plt.figure(figsize=(5.4, 2.4)) # hero when slicing at bottom=0.1 + fig = plt.figure(figsize=(5.4, 2.8)) # hero when slicing at bottom=0.1 # Main Pareto plot (left half, spanning both rows) - ax1 = plt.subplot2grid((2, 4), (0, 0), colspan=2, rowspan=2) + ax1 = plt.subplot2grid((2, 3), (0, 1), colspan=2, rowspan=2) # ---------- scatter all points ---------- color_norm = setup_color_normalization(color_scale, n_samp) if plot_points: - sc = plt.scatter(cost, mean_p, c=n_samp, cmap="viridis", alpha=0.15, s=3, norm=color_norm) - plt.xlabel("Cost (FLOPS (optimization + sampling))", fontsize=14) - if threshold is None: - plt.ylabel(r"$\overline{p_{harmful}}$", fontsize=14) - else: - plt.ylabel(r"$\overline{{ASR}}\quad (p_{{harmful}} \geq {threshold})$".format(threshold=threshold), fontsize=14) + # Subsample points for plotting, considering logarithmic cost spacing + if len(cost) > 1000: + # Sample uniformly in log space + log_cost = np.log10(cost + 1e-10) + log_indices = np.argsort(log_cost) + step = len(log_indices) // 1000 + subsample_indices = log_indices[::step][:1000] + + cost_sub = cost[subsample_indices] + mean_p_sub = mean_p[subsample_indices] + n_samp_sub = n_samp[subsample_indices] + else: + cost_sub = cost + mean_p_sub = mean_p + n_samp_sub = n_samp + + sc = plt.scatter(cost_sub, mean_p_sub, c=n_samp_sub, cmap="viridis", alpha=0.15, s=3, norm=color_norm) + # ---------- overlay Pareto frontiers ---------- cmap = plt.get_cmap("viridis") @@ -310,24 +350,26 @@ def pareto_plot( fx, fy = _pareto_frontier(cost, mean_p, method=frontier_method) xs.append(fx) ys.append(fy) - y_interp = [interp1d(x_, y_, kind="previous", bounds_error=False, fill_value=(0, max(y_)))(x_interp) for x_, y_ in zip(xs, ys)] + y_interp = [interp1d(x_, y_, kind="previous", bounds_error=False, fill_value=np.nan)(x_interp) for x_, y_ in zip(xs, ys)] color = cmap(color_norm(j)) - y_mean = np.mean(y_interp, axis=0) - # Filter out leading zeros - nonzero_mask = y_mean > 0 + y_mean = np.nanmean(y_interp, axis=0) + # Filter out NaN values and zeros + valid_mask = ~np.isnan(y_mean) & (y_mean > 0) + x_pts = x_interp[valid_mask] + y_pts = y_mean[valid_mask] # Store data for bar charts frontier_data[j] = { - 'x': x_interp[nonzero_mask], - 'y': y_mean[nonzero_mask], + 'x': x_pts, + 'y': y_pts, 'color': color, - 'max_asr': np.max(y_mean[nonzero_mask]) if np.any(nonzero_mask) else 0 + 'max_asr': np.max(y_pts) if np.any(valid_mask) else 0 } plt.plot( - x_interp[nonzero_mask], - y_mean[nonzero_mask], + x_pts, + y_pts, marker="o", linewidth=1.8, markersize=2, @@ -353,17 +395,17 @@ def pareto_plot( xs.append(fx) ys.append(fy) - y_interp = [interp1d(x_, y_, kind="previous", bounds_error=False, fill_value=(0, max(y_)))(x_interp) for x_, y_ in zip(xs, ys)] - y_interps.append(np.mean(y_interp, axis=0)) + y_interp = [interp1d(x_, y_, kind="previous", bounds_error=False, fill_value=np.nan)(x_interp) for x_, y_ in zip(xs, ys)] + y_interps.append(np.nanmean(y_interp, axis=0)) y_interps = np.array(y_interps) - argmax = np.argmax(y_interps, axis=0) + argmax = np.nanargmax(y_interps, axis=0) argmax = np.maximum.accumulate(argmax) - y_envelope = np.max(y_interps, axis=0) + y_envelope = np.nanmax(y_interps, axis=0) - # Filter out leading zeros - nonzero_mask = y_envelope > 0 - color = [cmap(color_norm(argmax[i])) for i in range(len(argmax)) if nonzero_mask[i]] - plt.scatter(x_interp[nonzero_mask], y_envelope[nonzero_mask], c=color, s=2) + # Filter out NaN values and zeros + valid_mask = ~np.isnan(y_envelope) & (y_envelope > 0) + color = [cmap(color_norm(argmax[i])) for i in range(len(argmax)) if valid_mask[i]] + plt.scatter(x_interp[valid_mask], y_envelope[valid_mask], c=color, s=2) title_suffix = "" @@ -387,101 +429,81 @@ def pareto_plot( pts = get_points(y_baseline, baseline_flops_optimization, baseline_flops_sampling_prefill_cache, baseline_flops_sampling_generation, return_ratio=False, cumulative=cumulative) cost_baseline, step_idx_baseline, n_samp_baseline, mean_p_baseline = pts.T + max_cost_baseline = max(cost_baseline) # ---------- overlay Pareto frontiers ---------- if plot_frontiers or plot_envelope: mask = n_samp_baseline == 1 fx, fy = _pareto_frontier(cost_baseline[mask], mean_p_baseline[mask], method=frontier_method) - y_interp_baseline = interp1d(fx, fy, kind="previous", bounds_error=False, fill_value=(0, max(fy)))(x_interp) - nonzero_mask_baseline = y_interp_baseline > 0 - + # if len(fx) > 1: + # fx[0] = min(f["x"][0] for f in frontier_data.values()) + # print(fx[0], fx[1]) + y_interp_baseline = interp1d(fx, fy, kind="previous", bounds_error=False, fill_value=np.nan)(x_interp) + if max_cost_baseline / max_cost < 0.7: + max_cost_baseline = max_cost + valid_mask_baseline = ~np.isnan(y_interp_baseline) & (y_interp_baseline > 0) & (x_interp < max_cost_baseline) # Store baseline data for bar charts - baseline_max_asr = np.max(y_interp_baseline[nonzero_mask_baseline]) if np.any(nonzero_mask_baseline) else 0 + baseline_max_asr = np.max(y_interp_baseline[valid_mask_baseline]) if np.any(valid_mask_baseline) else 0 baseline_frontier_data = { - 'x': x_interp[nonzero_mask_baseline], - 'y': y_interp_baseline[nonzero_mask_baseline], + 'x': x_interp[valid_mask_baseline], + 'y': y_interp_baseline[valid_mask_baseline], 'max_asr': baseline_max_asr } - plt.plot( - x_interp[nonzero_mask_baseline], - y_interp_baseline[nonzero_mask_baseline], + x_interp[valid_mask_baseline], + y_interp_baseline[valid_mask_baseline], marker="o", linewidth=1.8, markersize=2, - label=f"Baseline", + label=f"Baseline (greedy)", color="r", ) - plt.title(title + title_suffix) + plt.xlabel("Total FLOPs", fontsize=13) + if threshold is None: + plt.ylabel(r"$s_{harm@n}$", fontsize=18) + else: + plt.ylabel(r"${ASR}@n$", fontsize=18) plt.grid(True, alpha=0.3) - plt.ylim(bottom=0) + # plt.ylim(bottom=0.1) plt.xscale(x_scale) - plt.legend(title="Frontiers", loc="upper left" if x_scale == "log" else "lower right") - - # ---------- Bar Chart 1: Max ASR Comparison (Vertical Slice) ---------- - ax2 = plt.subplot2grid((2, 4), (0, 2)) - - methods = [] - max_asrs = [] - colors = [] + if "autodan" in title.lower(): + loc = "lower right" + elif x_scale == "log": + loc = "upper left" + else: + loc = "lower right" + + handles, labels = plt.gca().get_legend_handles_labels() + # Create legend subplot and move legend there + ax0 = plt.subplot2grid((2, 3), (0, 0), colspan=1, rowspan=2) + ax0.axis('off') # Remove all axes + # Get legend from current plot and move to ax0 + handles = [*handles[:-1][::-1], handles[-1]] + labels = [*labels[:-1][::-1], labels[-1]] + ax0.legend(handles, labels, loc='center', fontsize=12) + plt.tight_layout() + if threshold is None: + plt.savefig(f"evaluate/distributional_paper/pareto_plots/{title.replace(' ', '_')}.pdf") + else: + plt.savefig(f"evaluate/distributional_paper/pareto_plots/{title.replace(' ', '_')}_t={threshold}.pdf") + plt.close() - # Add baseline (delta = 0 for baseline) - if baseline_frontier_data is not None: - methods.append("Baseline") - max_asrs.append(0.0) # Delta from itself is 0 - colors.append("red") + fig = plt.figure(figsize=(6, 5)) - # Add sampling methods (calculate delta from baseline) - for j in sample_levels_to_plot: - if j in frontier_data: - methods.append(f"{j} samples" if j != 1 else "1 sample") - delta_asr = frontier_data[j]['max_asr'] - baseline_max_asr if baseline_frontier_data is not None else 0 - max_asrs.append(delta_asr) - colors.append(frontier_data[j]['color']) - - if methods: - bars = plt.bar(methods, max_asrs, color=colors, alpha=0.7, edgecolor='black') - if threshold is None: - plt.ylabel(r"$\Delta$ $p_{harmful}$", fontsize=14) - # plt.title(r"$p_{harmful}$ vs. \#samples", fontsize=14) - else: - # plt.title(r"$\overline{{ASR}}$ vs. \#samples".format(threshold=threshold), fontsize=14) - plt.ylabel(r"$\Delta$ $\overline{{ASR}}\quad (p_{{harmful}} \geq {threshold})$".format(threshold=threshold), fontsize=14) - plt.xticks(rotation=45, ha='right') - plt.grid(True, alpha=0.3, axis='y') - # Increase ylim by 2% on top and bottom - ymin, ymax = plt.ylim() - margin = (ymax - ymin) * 0.03 - plt.ylim(ymin - margin, ymax + margin) - - # ----- add labels with a 4-point gap ----- - for bar, value in zip(bars, max_asrs): - # choose label position: above for positive, below for negative - offset_pt = 4 # visual gap in points - va = 'bottom' if value >= 0 else 'top' - offset = (0, offset_pt if value >= 0 else -offset_pt) - - ax2.annotate(f'{value:.3f}', - xy=(bar.get_x() + bar.get_width()/2, bar.get_height()), - xytext=offset, - textcoords='offset points', - ha='center', va=va, fontsize=10) - - # ---------- Bar Chart 2: FLOPS Efficiency to Reach Baseline ASR (Horizontal Slice) ---------- - ax3 = plt.subplot2grid((2, 4), (0, 3)) + bar_chart_margin_multiplier = 5 if baseline_frontier_data is not None and baseline_max_asr > 0: methods_flops = [] flops_required = [] colors_flops = [] - # Find FLOPS required to reach baseline ASR for each sampling method + # Find FLOPs required to reach baseline ASR for each sampling method target_asr = baseline_max_asr for j in sample_levels_to_plot: if j in frontier_data: - # Find the minimum FLOPS where ASR >= target_asr + # Find the minimum FLOPs where ASR >= target_asr y_vals = frontier_data[j]['y'] x_vals = frontier_data[j]['x'] @@ -493,24 +515,82 @@ def pareto_plot( flops_required.append(min_flops) colors_flops.append(frontier_data[j]['color']) - # Add baseline (find minimum FLOPS where it reaches target ASR) + # Add baseline (find minimum FLOPs where it reaches target ASR) if baseline_frontier_data['x'].size > 0: - # Find the minimum FLOPS where baseline ASR >= target_asr + # Find the minimum FLOPs where baseline ASR >= target_asr baseline_y_vals = baseline_frontier_data['y'] baseline_x_vals = baseline_frontier_data['x'] baseline_valid_indices = baseline_y_vals >= target_asr if np.any(baseline_valid_indices): baseline_flops = np.min(baseline_x_vals[baseline_valid_indices]) else: - # Fallback to minimum FLOPS if no point reaches target ASR + # Fallback to minimum FLOPs if no point reaches target ASR baseline_flops = np.min(baseline_x_vals) methods_flops.insert(0, "Baseline") flops_required.insert(0, baseline_flops) colors_flops.insert(0, "red") + else: + methods_flops = [] + flops_required = [] + colors_flops = [] + + # ---------- Bar Chart 1: Max ASR Comparison (Vertical Slice) ---------- + def add_asr_bar_chart(): + ax2 = plt.subplot2grid((2, 2), (0, 0)) + + methods = [] + max_asrs = [] + colors = [] + + # Add baseline (delta = 0 for baseline) + if baseline_frontier_data is not None: + methods.append("Baseline") + max_asrs.append(0.0) # Delta from itself is 0 + colors.append("red") + + # Add sampling methods (calculate delta from baseline) + for j in sample_levels_to_plot: + if j in frontier_data: + methods.append(f"{j} samples" if j != 1 else "1 sample") + delta_asr = frontier_data[j]['max_asr'] - baseline_max_asr if baseline_frontier_data is not None else 0 + max_asrs.append(delta_asr) + colors.append(frontier_data[j]['color']) + + if methods: + bars = plt.bar(methods, max_asrs, color=colors, alpha=0.7, edgecolor='black') + if threshold is None: + plt.ylabel(r"$\Delta$ $s_{harm@n}$" , fontsize=17) + else: + plt.ylabel(r"$\Delta$ ${ASR}@n$", fontsize=17) + plt.xticks(rotation=45, ha='right') + plt.grid(True, alpha=0.3, axis='y') + # Increase ylim by 2% on top and bottom + ymin, ymax = plt.ylim() + margin = (ymax - ymin) * 0.03 * bar_chart_margin_multiplier + plt.ylim(ymin - margin, ymax + margin) + + # ----- add labels with a 4-point gap ----- + for bar, value in zip(bars, max_asrs): + # choose label position: above for positive, below for negative + offset_pt = 4 # visual gap in points + va = 'bottom' if value >= 0 else 'top' + offset = (0, offset_pt if value >= 0 else -offset_pt) + + ax2.annotate(f'{value:.2f}', + xy=(bar.get_x() + bar.get_width()/2, bar.get_height()), + xytext=offset, + textcoords='offset points', + ha='center', va=va, fontsize=10) + add_asr_bar_chart() + + + # ---------- Bar Chart 2: FLOPs Efficiency to Reach Baseline ASR (Horizontal Slice) ---------- + def add_flops_bar_chart(): + ax3 = plt.subplot2grid((2, 2), (0, 1)) if methods_flops: bars = plt.bar(methods_flops, flops_required, color=colors_flops, alpha=0.7, edgecolor='black') - plt.ylabel(r"FLOPS for Baseline $p_{harmful}$" + f" ( = {target_asr:.3f})", fontsize=12) + plt.ylabel("FLOPs to match baseline", fontsize=12) plt.xticks(rotation=45, ha='right') plt.yscale('log') plt.grid(True, alpha=0.3, axis='y') @@ -521,77 +601,756 @@ def pareto_plot( plt.ylim(ymin, ymax * (1+margin)) # --- constant 5-point vertical gap --- - for bar, value in zip(bars, flops_required): - ax3.annotate(f'{value:.2e}', - xy=(bar.get_x() + bar.get_width()/2, value), # anchor at top of bar - xytext=(0, 5), # 5 points straight up - textcoords='offset points', - ha='center', va='bottom', rotation=45, fontsize=9) - - # ---------- Bar Chart 3: Speedup vs Baseline (Bottom Right) ---------- - ax4 = plt.subplot2grid((2, 4), (1, 3)) - - # Create speedup plot - speedup_methods = [] - speedups = [] - speedup_colors = [] - - # Calculate speedup for each method (baseline_flops / method_flops) - baseline_flops = flops_required[0] if methods_flops[0] == "Baseline" else None - - if baseline_flops is not None: - for i, (method, flops, color) in enumerate(zip(methods_flops, flops_required, colors_flops)): - if method != "Baseline": # Skip baseline itself - speedup = baseline_flops / flops if flops > 0 else 0 - speedup_methods.append(method) - speedups.append(speedup) - speedup_colors.append(color) - - if speedup_methods: - bars = plt.bar(speedup_methods, speedups, color=speedup_colors, alpha=0.7, edgecolor='black') - plt.ylabel("Speedup (FLOPS) vs Baseline", fontsize=12) + # for bar, value in zip(bars, flops_required): + # ax3.annotate(f'{value:.2e}', + # xy=(bar.get_x() + bar.get_width()/2, value), # anchor at top of bar + # xytext=(0, 5), # 5 points straight up + # textcoords='offset points', + # ha='center', va='bottom', rotation=45, fontsize=9) + + add_flops_bar_chart() + + # ---------- Bar Chart 3: Speedup vs Baseline (Bottom Left) ---------- + def add_speedup_bar_chart(): + ax4 = plt.subplot2grid((2, 2), (1, 0)) + + # Create speedup plot + speedup_methods = [] + speedups = [] + speedup_colors = [] + + # Calculate speedup for each method (baseline_flops / method_flops) + baseline_flops = flops_required[0] if methods_flops and methods_flops[0] == "Baseline" else None + + if baseline_flops is not None: + for i, (method, flops, color) in enumerate(zip(methods_flops, flops_required, colors_flops)): + if method != "Baseline": # Skip baseline itself + speedup = baseline_flops / flops if flops > 0 else 0 + speedup_methods.append(method) + speedups.append(speedup) + speedup_colors.append(color) + + if speedup_methods: + bars = plt.bar(speedup_methods, speedups, color=speedup_colors, alpha=0.7, edgecolor='black') + plt.ylabel("Speedup (FLOPs)", fontsize=12) + plt.xticks(rotation=45, ha='right') + plt.grid(True, alpha=0.3, axis='y') + + # Add horizontal line at y=1 for reference + plt.axhline(y=1, color='red', linestyle='--', alpha=0.7, linewidth=1) + + # Increase ylim by small margin + ymin, ymax = plt.ylim() + margin = (ymax - ymin) * 0.05 * bar_chart_margin_multiplier + plt.ylim(max(0, ymin - margin), ymax + margin) + + # Add value labels on bars + for bar, value in zip(bars, speedups): + ax4.annotate(f'{value:.1f}x', + xy=(bar.get_x() + bar.get_width()/2, bar.get_height()), + xytext=(0, 5), + textcoords='offset points', + ha='center', va='bottom', fontsize=10) + + add_speedup_bar_chart() + + # --------- Base Plot 4: ASR @ max greedy FLOPs (vertical slice), bottom right ---------- + def add_asr_at_max_greedy_flops_bar_chart(): + ax5 = plt.subplot2grid((2, 2), (1, 1)) + + methods = [] + max_asrs = [] + colors = [] + + # Add baseline (delta = 0 for baseline) + if baseline_frontier_data is not None: + methods.append("Baseline") + max_asrs.append(0.0) # Delta from itself is 0 + colors.append("red") + + # Add sampling methods (calculate delta from baseline) + for j in sample_levels_to_plot: + if j in frontier_data: + methods.append(f"{j} samples" if j != 1 else "1 sample") + if baseline_frontier_data["x"].size == 0: + continue + baseline_max_flops = baseline_frontier_data['x'][-1] + x_idx_of_same_flops_as_baseline = np.argmax(frontier_data[j]['x'] > baseline_max_flops) - 1 + delta_asr = frontier_data[j]['y'][x_idx_of_same_flops_as_baseline] - baseline_max_asr if baseline_frontier_data is not None else 0 + max_asrs.append(delta_asr) + colors.append(frontier_data[j]['color']) + + if methods: + bars = plt.bar(methods, max_asrs, color=colors, alpha=0.7, edgecolor='black') + if threshold is None: + plt.ylabel(r"$\Delta$ $s_{harm@n}$" , fontsize=14) + else: + plt.ylabel(r"$\Delta$ ${ASR}@n$", fontsize=14) plt.xticks(rotation=45, ha='right') plt.grid(True, alpha=0.3, axis='y') - - # Add horizontal line at y=1 for reference - plt.axhline(y=1, color='red', linestyle='--', alpha=0.7, linewidth=1) - - # Increase ylim by small margin + # Increase ylim by 2% on top and bottom ymin, ymax = plt.ylim() - margin = (ymax - ymin) * 0.05 - plt.ylim(max(0, ymin - margin), ymax + margin) + margin = (ymax - ymin) * 0.03 * bar_chart_margin_multiplier + plt.ylim(ymin - margin, ymax + margin) + + # ----- add labels with a 4-point gap ----- + for bar, value in zip(bars, max_asrs): + # choose label position: above for positive, below for negative + offset_pt = 4 # visual gap in points + va = 'bottom' if value >= 0 else 'top' + offset = (0, offset_pt if value >= 0 else -offset_pt) - # Add value labels on bars - for bar, value in zip(bars, speedups): - ax4.annotate(f'{value:.2f}x', + ax5.annotate(f'{value:.2f}', xy=(bar.get_x() + bar.get_width()/2, bar.get_height()), - xytext=(0, 5), + xytext=offset, textcoords='offset points', - ha='center', va='bottom', fontsize=10) + ha='center', va=va, fontsize=10) + add_asr_at_max_greedy_flops_bar_chart() + # # ---------- Line Plot 4: Continuous FLOPs to Reach Baseline ASR (Bottom Left) ---------- + # ax5 = plt.subplot2grid((2, 5), (1, 3)) + + # if baseline_frontier_data is not None and baseline_max_asr > 0: + # target_asr = baseline_max_asr + + # # Generate continuous range of sample counts + # sample_range = range(1, n_total_samples + 1) + # continuous_flops = [] + # continuous_samples = [] + + # # Calculate frontier data for all sample counts (not just sample_levels_to_plot) + # rng_continuous = np.random.default_rng() + # n_smoothing_continuous = 10 # Reduced for performance + + # for j in sample_range: + # xs = [] + # ys = [] + # for _ in range(n_smoothing_continuous): + # pts = [] + # for i in range(0, n_steps, 1): + # pts.append(subsample_and_aggregate(i, j, cumulative, y, flops_optimization, + # flops_sampling_prefill_cache, flops_sampling_generation, rng_continuous)) + + # pts = np.asarray(pts) + # cost, _, _, mean_p = pts.T + + # fx, fy = _pareto_frontier(cost, mean_p, method=frontier_method) + # xs.append(fx) + # ys.append(fy) + + # # Interpolate and average + # y_interp = [interp1d(x_, y_, kind="previous", bounds_error=False, fill_value=(0, max(y_)))(x_interp) + # for x_, y_ in zip(xs, ys)] + # y_mean = np.mean(y_interp, axis=0) + + # # Find minimum FLOPs where ASR >= target_asr + # nonzero_mask = y_mean > 0 + # if np.any(nonzero_mask): + # y_vals = y_mean[nonzero_mask] + # x_vals = x_interp[nonzero_mask] + + # valid_indices = y_vals >= target_asr + # if np.any(valid_indices): + # min_flops = np.min(x_vals[valid_indices]) + # continuous_flops.append(min_flops) + # continuous_samples.append(j) + + # if continuous_flops: + # # Plot the continuous line + # plt.plot(continuous_samples, continuous_flops, 'b-', linewidth=2, alpha=0.8, label='All Samples') + + # # Highlight the baseline point + # if baseline_frontier_data['x'].size > 0: + # baseline_y_vals = baseline_frontier_data['y'] + # baseline_x_vals = baseline_frontier_data['x'] + # baseline_valid_indices = baseline_y_vals >= target_asr + # if np.any(baseline_valid_indices): + # baseline_flops = np.min(baseline_x_vals[baseline_valid_indices]) + # plt.axhline(y=baseline_flops, color='red', linestyle='--', alpha=0.7, linewidth=2, label='Baseline') + + # # Highlight the discrete sample levels from the bar chart + # for j in sample_levels_to_plot: + # if j in [s for s in continuous_samples]: + # idx = continuous_samples.index(j) + # color = cmap(color_norm(j)) + # plt.scatter(j, continuous_flops[idx], color=color, s=60, alpha=0.9, + # edgecolors='black', linewidth=0.5, zorder=5) + + # plt.xlabel("Number of Samples", fontsize=12) + # plt.ylabel("FLOPs to Reach Baseline ASR", fontsize=12) + # plt.xscale('log') + # plt.yscale('log') + # plt.grid(True, alpha=0.3) + # plt.legend(fontsize=10) + + # # Set reasonable x-axis limits + # plt.xlim(1, n_total_samples) + + # # Increase ylim by small margin + # ymin, ymax = plt.ylim() + # import math + # margin = ((math.log10(ymax) - math.log10(ymin)) * 0.1) + # plt.ylim(ymin / (1+margin), ymax * (1+margin)) - # ---------- Line Plot 4: Continuous FLOPS to Reach Baseline ASR (Bottom Left) ---------- - ax5 = plt.subplot2grid((2, 4), (1, 2)) + plt.tight_layout() + if threshold is None: + plt.savefig(f"evaluate/distributional_paper/bar_charts/{title.replace(' ', '_')}.pdf") + else: + plt.savefig(f"evaluate/distributional_paper/bar_charts/{title.replace(' ', '_')}_t={threshold}.pdf") + plt.close() + # Create a separate figure for just the legend + fig_legend = plt.figure(figsize=(4, 1)) + ax_legend = fig_legend.add_subplot(111) + ax_legend.axis('off') + + # Create legend elements for sample levels + legend_elements = [] + # Add baseline if it exists + if baseline is not None: + legend_elements.append(plt.Line2D([0], [0], color="red", linewidth=2, + label="Baseline (Greedy)")) + cmap = plt.get_cmap("viridis") + color_norm = setup_color_normalization("linear", np.array(sample_levels_to_plot)) - if baseline_frontier_data is not None and baseline_max_asr > 0: - target_asr = baseline_max_asr + for j in sample_levels_to_plot: + if j in frontier_data: + color = cmap(color_norm(j)) + legend_elements.append(plt.Line2D([0], [0], color=color, linewidth=2, + label=f"{j} samples")) - # Generate continuous range of sample counts - sample_range = range(1, n_total_samples + 1) - continuous_flops = [] - continuous_samples = [] - # Calculate frontier data for all sample counts (not just sample_levels_to_plot) - rng_continuous = np.random.default_rng() - n_smoothing_continuous = 10 # Reduced for performance + # Create horizontal legend + ax_legend.legend(handles=legend_elements, loc='center', ncol=len(legend_elements), + fontsize=10, frameon=False, columnspacing=1.0, handletextpad=0.5) + + + plt.tight_layout() + plt.savefig(f"evaluate/distributional_paper/pareto_plots/legend_{n_total_samples}.pdf", bbox_inches='tight') + plt.close() + - for j in sample_range: +def non_cumulative_pareto_plot( + results: dict[str,np.ndarray], + baseline: dict[str,np.ndarray] | None = None, + title: str = "Non-Cumulative Pareto Frontier", + sample_levels_to_plot: tuple[int, ...]|None = None, + metric: tuple[str, ...] = ('scores', 'strong_reject', 'p_harmful'), + plot_points: bool = False, + plot_frontiers: bool = True, + plot_envelope: bool = False, + plot_baseline: bool = False, + verbose: bool = True, + flops_per_step: int | None = None, + n_x_points: int = 10000, + x_scale="linear", + threshold: float|None = None, + color_scale: str = "linear", +): + """ + Scatter the full design-space AND overlay non-cumulative Pareto frontiers + for selected sampling counts. Uses the non_cumulative frontier method + which includes all points without dominance filtering. + """ + y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation = preprocess_data( + results, metric, threshold, flops_per_step + ) + n_runs, n_steps, n_total_samples = y.shape + if sample_levels_to_plot is None: + sample_levels_to_plot = generate_sample_sizes(n_total_samples) + + pts = get_points(y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation, + return_ratio=False, cumulative=False) + cost, step_idx, n_samp, mean_p = pts.T + max_step = max(step_idx) + if x_scale == "log": + x_interp = np.logspace(0, np.log10(max_step+1), n_x_points) + else: + x_interp = np.linspace(0, max_step+1, n_x_points) + + # Create figure with subplots: legend + main plot + 2x2 grid on the right + # fig = plt.figure(figsize=(5.4, 2.4)) # hero when slicing at bottom=0.1 + fig = plt.figure(figsize=(5.4, 2.8)) # hero when slicing at bottom=0.1 + + # Main Pareto plot (left half, spanning both rows) + ax1 = plt.subplot2grid((2, 3), (0, 1), colspan=2, rowspan=2) + + # ---------- scatter all points ---------- + color_norm = setup_color_normalization(color_scale, n_samp) + if plot_points: + # Subsample points for plotting, considering step spacing + if len(step_idx) > 1000: + # Sample uniformly in step space + step_indices = np.argsort(step_idx) + step = len(step_indices) // 1000 + subsample_indices = step_indices[::step][:1000] + + step_idx_sub = step_idx[subsample_indices] + mean_p_sub = mean_p[subsample_indices] + n_samp_sub = n_samp[subsample_indices] + else: + step_idx_sub = step_idx + mean_p_sub = mean_p + n_samp_sub = n_samp + + sc = plt.scatter(step_idx_sub, mean_p_sub, c=n_samp_sub, cmap="viridis", alpha=0.15, s=3, norm=color_norm) + + + # ---------- overlay non-cumulative Pareto frontiers ---------- + cmap = plt.get_cmap("viridis") + rng = np.random.default_rng() + + n_smoothing = 50 + frontier_data = {} # Store frontier data for bar charts + + if plot_frontiers: + # Only plot the maximum number of samples frontier + j = n_total_samples + xs = [] + ys = [] + n_smoothing = 1 # Use single smoothing for max samples + for _ in range(n_smoothing): + pts = [] + for i in range(0, n_steps, 1): + pts.append(subsample_and_aggregate(i, j, False, y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation, rng)) + + pts = np.asarray(pts) + cost, step_idx_pts, _, mean_p = pts.T + + fx, fy = _pareto_frontier(step_idx_pts, mean_p, method="non_cumulative") + xs.append(fx) + ys.append(fy) + y_interp = [interp1d(x_, y_, kind="previous", bounds_error=False, fill_value=np.nan)(x_interp) for x_, y_ in zip(xs, ys)] + + color = cmap(color_norm(j)) + y_mean = np.nanmean(y_interp, axis=0) + # Filter out NaN values and zeros + valid_mask = ~np.isnan(y_mean) & (y_mean > 0) + + x_pts = x_interp[valid_mask] + y_pts = y_mean[valid_mask] + # Store data for bar charts + frontier_data[j] = { + 'x': x_pts, + 'y': y_pts, + 'color': color, + 'max_asr': np.max(y_pts) if np.any(valid_mask) else 0 + } + + plt.plot( + x_pts, + y_pts, + linewidth=1.2, + label="Steps", + color=color, + ) + + if plot_envelope: + n_smoothing = n_total_samples + y_interps = [] + for j in range(1, n_total_samples+1): xs = [] ys = [] - for _ in range(n_smoothing_continuous): + for n in range(n_smoothing): pts = [] for i in range(0, n_steps, 1): - pts.append(subsample_and_aggregate(i, j, cumulative, y, flops_optimization, - flops_sampling_prefill_cache, flops_sampling_generation, rng_continuous)) + pts.append(subsample_and_aggregate(i, j, False, y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation, rng)) + + pts = np.asarray(pts) + cost, step_idx_pts, n_samp, mean_p = pts.T + + fx, fy = _pareto_frontier(step_idx_pts, mean_p, method="non_cumulative") + xs.append(fx) + ys.append(fy) + + y_interp = [interp1d(x_, y_, kind="previous", bounds_error=False, fill_value=np.nan)(x_interp) for x_, y_ in zip(xs, ys)] + y_interps.append(np.nanmean(y_interp, axis=0)) + y_interps = np.array(y_interps) + argmax = np.nanargmax(y_interps, axis=0) + argmax = np.maximum.accumulate(argmax) + y_envelope = np.nanmax(y_interps, axis=0) + + # Filter out NaN values and zeros + valid_mask = ~np.isnan(y_envelope) & (y_envelope > 0) + color = [cmap(color_norm(argmax[i])) for i in range(len(argmax)) if valid_mask[i]] + plt.scatter(x_interp[valid_mask], y_envelope[valid_mask], c=color, s=2) + + title_suffix = "" + + # Handle baseline data + baseline_max_asr = 0 + baseline_frontier_data = None + + if baseline is not None and plot_baseline: + y_baseline, baseline_flops_optimization, baseline_flops_sampling_prefill_cache, baseline_flops_sampling_generation = preprocess_data( + baseline, metric, threshold, flops_per_step + ) + + if y_baseline is not None: + title_suffix = f" ({n_runs}, {y_baseline.shape[0]})" + if verbose: + logging.info(f"{n_runs} for main") + logging.info(f"{y_baseline.shape[0]} for baseline") + n_runs_baseline, n_steps_baseline, n_total_samples_baseline = y_baseline.shape + assert n_total_samples_baseline == 1 + + pts = get_points(y_baseline, baseline_flops_optimization, baseline_flops_sampling_prefill_cache, + baseline_flops_sampling_generation, return_ratio=False, cumulative=False) + cost_baseline, step_idx_baseline, n_samp_baseline, mean_p_baseline = pts.T + max_step_baseline = max(step_idx_baseline) + + # ---------- overlay Pareto frontiers ---------- + if plot_frontiers or plot_envelope: + mask = n_samp_baseline == 1 + fx, fy = _pareto_frontier(step_idx_baseline[mask], mean_p_baseline[mask], method="non_cumulative") + y_interp_baseline = interp1d(fx, fy, kind="previous", bounds_error=False, fill_value=np.nan)(x_interp) + if max_step_baseline / max_step < 0.7: + max_step_baseline = max_step + valid_mask_baseline = ~np.isnan(y_interp_baseline) & (y_interp_baseline > 0) & (x_interp < max_step_baseline) + # Store baseline data for bar charts + baseline_max_asr = np.max(y_interp_baseline[valid_mask_baseline]) if np.any(valid_mask_baseline) else 0 + baseline_frontier_data = { + 'x': x_interp[valid_mask_baseline], + 'y': y_interp_baseline[valid_mask_baseline], + 'max_asr': baseline_max_asr + } + plt.plot( + x_interp[valid_mask_baseline], + y_interp_baseline[valid_mask_baseline], + linewidth=1.2, + label=f"Baseline", + color="r", + ) + + plt.xlabel("Optimization Steps", fontsize=13) + if threshold is None: + plt.ylabel(r"$s_{harm@n}$", fontsize=18) + else: + plt.ylabel(r"${ASR}@n$", fontsize=18) + plt.grid(True, alpha=0.3) + # plt.ylim(bottom=0.1) + plt.xscale(x_scale) + if "autodan" in title.lower(): + loc = "lower right" + elif x_scale == "log": + loc = "upper left" + else: + loc = "lower right" + + handles, labels = plt.gca().get_legend_handles_labels() + # Create legend subplot and move legend there + ax0 = plt.subplot2grid((2, 3), (0, 0), colspan=1, rowspan=2) + ax0.axis('off') # Remove all axes + # Get legend from current plot and move to ax0 + if len(handles) > 1: + # If we have baseline, put it last + if plot_baseline and len(handles) > 1: + handles = [*handles[:-1][::-1], handles[-1]] + labels = [*labels[:-1][::-1], labels[-1]] + else: + handles = handles[::-1] + labels = labels[::-1] + ax0.legend(handles, labels, loc='center', fontsize=12) + plt.tight_layout() + if threshold is None: + plt.savefig(f"evaluate/distributional_paper/non_cumulative_pareto_plots/{title.replace(' ', '_')}.pdf") + else: + plt.savefig(f"evaluate/distributional_paper/non_cumulative_pareto_plots/{title.replace(' ', '_')}_t={threshold}.pdf") + plt.close() + + # Note: Skipping the bar charts for non-cumulative version to keep it simple + # The non-cumulative version is primarily for visualization of all points + + if verbose: + logging.info(f"Non-cumulative Pareto plot saved for {title}") + + +def multi_attack_non_cumulative_pareto_plot( + attacks_data: dict, # {attack_name: (results_dict, config)} + model_title: str, + title: str = "Multi-Attack Non-Cumulative Pareto", + metric: tuple[str, ...] = ('scores', 'strong_reject', 'p_harmful'), + threshold: float|None = None, + n_x_points: int = 10000, + x_scale: str = "linear", + verbose: bool = True, +): + """ + Create a non-cumulative Pareto plot showing multiple attacks on the same axes. + Each attack shows its frontier with 50 samples. + """ + + # Color scheme for attacks + attack_colors = { + "gcg": "#1f77b4", # blue + "autodan": "#ff7f0e", # orange + "beast": "#2ca02c", # green + "pair": "#d62728", # red + "bon": "#9467bd", # purple + "direct": "#8c564b", # brown + } + + plt.figure(figsize=(4, 3)) + + # Filter to only show specific attacks + desired_attacks = {"PAIR", "BEAST", "AutoDAN", "GCG"} + filtered_attacks_data = {} + + for config_key, (results, config) in attacks_data.items(): + if config_key in desired_attacks: + filtered_attacks_data[config_key] = (results, config) + + attacks_data = filtered_attacks_data + + if not attacks_data: + logging.warning("No desired attacks found in data") + return + + # Use percentage of optimization steps (0-100%) as x-axis + x_interp = np.linspace(0, 100, n_x_points) + + # Process each attack + rng = np.random.default_rng(42) + + for config_key, (results, config) in attacks_data.items(): + y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation = preprocess_data( + results, metric, threshold, None + ) + n_runs, n_steps, n_total_samples = y.shape + + # Extract original attack name from config for color mapping + original_attack_name = None + for atk_name, atk_cfg in ATTACKS: + if atk_cfg.get('title_suffix') == config_key: + original_attack_name = atk_name + break + + color = attack_colors.get(original_attack_name, "black") + + # Use 50 samples for each attack + target_samples = 1#min(50, n_total_samples) + n_smoothing = 1 # Single smoothing for cleaner lines + xs = [] + ys = [] + + for _ in range(n_smoothing): + pts = [] + for i in range(0, n_steps, 1): + pts.append(subsample_and_aggregate(i, target_samples, False, y, + flops_optimization, flops_sampling_prefill_cache, + flops_sampling_generation, rng)) + + pts = np.asarray(pts) + cost, step_idx_pts, _, mean_p = pts.T + + # Convert step indices to percentages (0-100%) + step_percentages = (step_idx_pts / (n_steps - 1)) * 100 + + fx, fy = _pareto_frontier(step_percentages, mean_p, method="non_cumulative") + xs.append(fx) + ys.append(fy) + + y_interp = [interp1d(x_, y_, kind="previous", bounds_error=False, + fill_value=np.nan)(x_interp) for x_, y_ in zip(xs, ys)] + y_mean = np.nanmean(y_interp, axis=0) + + # Calculate delta from first step's value + # Find the first valid (non-NaN) value as baseline + valid_indices = ~np.isnan(y_mean) + if np.any(valid_indices): + first_valid_value = y_mean[valid_indices][0] + y_delta = y_mean - first_valid_value + + # Filter out NaN values + valid_mask = ~np.isnan(y_delta) + if np.any(valid_mask): + label = f"{config_key}" + + plt.plot(x_interp[valid_mask], y_delta[valid_mask], + linewidth=1.2, + label=label, color=color) + + plt.xlabel(r"Optimization Progress (\%)", fontsize=15) + if threshold is None: + plt.ylabel(r"$\Delta$ $s_{harm@1}$", fontsize=16) + else: + plt.ylabel(r"$\Delta$ ${ASR}@1$", fontsize=16) + + plt.grid(True, alpha=0.3) + plt.xlim(0, 100) # Set x-axis limits to 0-100% + plt.axhline(y=0, color='black', linestyle='--', alpha=0.5, linewidth=1) # Reference line at delta=0 + plt.title(f"{model_title}", fontsize=15) + + # Get legend handles and labels before saving main plot + handles, labels = plt.gca().get_legend_handles_labels() + # Sort handles and labels alphabetically by labels + sorted_pairs = sorted(zip(handles, labels), key=lambda x: x[1]) + handles, labels = zip(*sorted_pairs) if sorted_pairs else ([], []) + + plt.tight_layout() + if threshold is None: + plt.savefig(f"evaluate/distributional_paper/multi_attack_non_cumulative_pareto_plots/{title.replace(' ', '_')}.pdf") + else: + plt.savefig(f"evaluate/distributional_paper/multi_attack_non_cumulative_pareto_plots/{title.replace(' ', '_')}_t={threshold}.pdf") + plt.close() + + # Create a separate figure for just the legend + fig_legend = plt.figure(figsize=(4, 1)) + ax_legend = fig_legend.add_subplot(111) + ax_legend.axis('off') + + # Create horizontal legend + ax_legend.legend(handles=handles, loc='center', ncol=2, + fontsize=12, frameon=False, columnspacing=1.0, handletextpad=0.5) + + plt.tight_layout() + if threshold is None: + plt.savefig(f"evaluate/distributional_paper/multi_attack_non_cumulative_pareto_plots/legend_{title.replace(' ', '_')}.pdf", bbox_inches='tight') + else: + plt.savefig(f"evaluate/distributional_paper/multi_attack_non_cumulative_pareto_plots/legend_{title.replace(' ', '_')}_t={threshold}.pdf", bbox_inches='tight') + plt.close() + + if verbose: + logging.info(f"Multi-attack non-cumulative Pareto plot saved for {model_title}") + logging.info(f"Legend saved separately") + + +def comparative_pareto_plot( + model: str, + model_title: str, + attacks_data: dict, # {attack_name: (results_dict, config)} + title: str = "Comparative Pareto Analysis", + metric: tuple[str, ...] = ('scores', 'strong_reject', 'p_harmful'), + threshold: float|None = None, + n_x_points: int = 10000, + x_scale: str = "log", + flops_per_step_fns: dict = None, # {attack_name: flops_per_step_fn} + frontier_method: str = "basic", + verbose: bool = True, + baseline_attacks: set = {"gcg", "beast", "pair", "autodan"}, # Attacks to show baseline points for +): + """ + Create a comparative Pareto plot showing multiple attacks on the same axes. + - For gcg, autodan, beast, pair: use 50-sample frontier + - For bon, direct: use envelope curve + """ + + # Define which attacks use which approach + envelope_attacks = {"bon", "direct"} + + # Color scheme for attacks + attack_colors = { + "gcg": "#1f77b4", # blue + "autodan": "#ff7f0e", # orange + "beast": "#2ca02c", # green + "pair": "#d62728", # red + "bon": "#9467bd", # purple + "direct": "#8c564b", # brown + } + + plt.figure(figsize=(8, 5)) + + # Calculate overall x-axis range from all attacks + all_costs = [] + for config_key, (results, config) in attacks_data.items(): + flops_per_step_fn = flops_per_step_fns.get(config_key) if flops_per_step_fns else None + y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation = preprocess_data( + results, metric, threshold, flops_per_step_fn + ) + pts = get_points(y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation, + return_ratio=False, cumulative=config.get("cumulative", False)) + cost, _, _, _ = pts.T + all_costs.extend(cost) + + max_cost = max(all_costs) + if x_scale == "log": + x_interp = np.logspace(11, np.log10(max_cost)+0.001, n_x_points) + else: + x_interp = np.linspace(0, max_cost+1, n_x_points) + + # Process each attack + rng = np.random.default_rng(42) + + for config_key, (results, config) in attacks_data.items(): + flops_per_step_fn = flops_per_step_fns.get(config_key) if flops_per_step_fns else None + y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation = preprocess_data( + results, metric, threshold, flops_per_step_fn + ) + n_runs, n_steps, n_total_samples = y.shape + + # Extract original attack name from config for color mapping + # We need to find the original attack name by looking at the ATTACKS list + original_attack_name = None + for atk_name, atk_cfg in ATTACKS: + if atk_cfg.get('title_suffix') == config_key: + original_attack_name = atk_name + break + + color = attack_colors.get(original_attack_name, "black") + + if original_attack_name in envelope_attacks: + # Use envelope approach with more permissive extrapolation + n_smoothing = 10 # Limit smoothing for performance + y_interps = [] + all_xs = [] # Collect all x values to determine meaningful range + + for j in range(1, n_total_samples+1): + xs = [] + ys = [] + for n in range(n_smoothing): + pts = [] + for i in range(0, n_steps, 1): + pts.append(subsample_and_aggregate(i, j, config.get("cumulative", False), y, + flops_optimization, flops_sampling_prefill_cache, + flops_sampling_generation, rng)) + + pts = np.asarray(pts) + cost, _, _, mean_p = pts.T + + fx, fy = _pareto_frontier(cost, mean_p, method=frontier_method) + xs.append(fx) + ys.append(fy) + + # Collect all x values for range determination + for x_ in xs: + if len(x_) > 0: + all_xs.extend(x_) + + # For envelope attacks, use 0 fill for left side and last value for right side to avoid gaps + y_interp = [] + for x_, y_ in zip(xs, ys): + if len(x_) > 0 and len(y_) > 0: + interp_func = interp1d(x_, y_, kind="previous", bounds_error=False, + fill_value=(0, y_[-1])) + y_interp.append(interp_func(x_interp)) + else: + y_interp.append(np.zeros_like(x_interp)) + + y_interps.append(np.mean(y_interp, axis=0) if y_interp else np.zeros_like(x_interp)) + + y_interps = np.array(y_interps) + y_envelope = np.max(y_interps, axis=0) + + # For envelope, only filter out leading zeros, but cap at reasonable x-range + max_meaningful_x = np.max(all_xs) if all_xs else x_interp[-1] + valid_mask = (y_envelope > 0) & (x_interp <= max_meaningful_x * 1.1) # Allow 10% extension + + if np.any(valid_mask): + plt.plot(x_interp[valid_mask], y_envelope[valid_mask], + marker="o", linewidth=2.5, markersize=3, + label=config_key, color=color) + + else: + # Use 50-sample frontier approach + target_samples = n_total_samples + n_smoothing = 50 + xs = [] + ys = [] + + for _ in range(n_smoothing): + pts = [] + for i in range(0, n_steps, 1): + pts.append(subsample_and_aggregate(i, target_samples, config.get("cumulative", False), y, + flops_optimization, flops_sampling_prefill_cache, + flops_sampling_generation, rng)) pts = np.asarray(pts) cost, _, _, mean_p = pts.T @@ -600,69 +1359,105 @@ def pareto_plot( xs.append(fx) ys.append(fy) - # Interpolate and average - y_interp = [interp1d(x_, y_, kind="previous", bounds_error=False, fill_value=(0, max(y_)))(x_interp) - for x_, y_ in zip(xs, ys)] - y_mean = np.mean(y_interp, axis=0) + y_interp = [interp1d(x_, y_, kind="previous", bounds_error=False, + fill_value=np.nan)(x_interp) for x_, y_ in zip(xs, ys)] + y_mean = np.nanmean(y_interp, axis=0) + + # Filter out NaN values and zeros + valid_mask = ~np.isnan(y_mean) & (y_mean > 0) + if np.any(valid_mask): + label = f"{config_key}" + if target_samples < n_total_samples: + label += f" ({target_samples} samples)" + + plt.plot(x_interp[valid_mask], y_mean[valid_mask], + marker="o", linewidth=2.5, markersize=3, + label=label, color=color) + + # Add baseline points for specified attacks + for baseline_attack_name in baseline_attacks: + # Find the config key for this attack (could be multiple configs for same attack) + matching_configs = [] + for config_key, (results, config) in attacks_data.items(): + # Find the original attack name for this config + for atk_name, atk_cfg in ATTACKS: + if atk_cfg.get('title_suffix') == config_key and atk_name == baseline_attack_name: + matching_configs.append((config_key, results, config)) + break + + # Use the first matching config (or could choose a specific one) + if matching_configs: + config_key, results, config = matching_configs[0] + try: + # Fetch baseline data for this attack + baseline_params = config.get("baseline_params", lambda: { + "generation_config": {"num_return_sequences": 1, "temperature": 0.0} + })() + baseline_attack = config.get("baseline_attack", baseline_attack_name) + + baseline_data = fetch_data(model, baseline_attack, baseline_params, + list(range(100)), {"model", "attack_params"}) + + # Process baseline data + flops_per_step_fn = flops_per_step_fns.get(config_key) if flops_per_step_fns else None + y_baseline, baseline_flops_optimization, baseline_flops_sampling_prefill_cache, baseline_flops_sampling_generation = preprocess_data( + baseline_data, metric, threshold, flops_per_step_fn + ) - # Find minimum FLOPS where ASR >= target_asr - nonzero_mask = y_mean > 0 - if np.any(nonzero_mask): - y_vals = y_mean[nonzero_mask] - x_vals = x_interp[nonzero_mask] + if y_baseline is not None: + n_runs_baseline, n_steps_baseline, n_total_samples_baseline = y_baseline.shape - valid_indices = y_vals >= target_asr - if np.any(valid_indices): - min_flops = np.min(x_vals[valid_indices]) - continuous_flops.append(min_flops) - continuous_samples.append(j) - - if continuous_flops: - # Plot the continuous line - plt.plot(continuous_samples, continuous_flops, 'b-', linewidth=2, alpha=0.8, label='All Samples') - - # Highlight the baseline point - if baseline_frontier_data['x'].size > 0: - baseline_y_vals = baseline_frontier_data['y'] - baseline_x_vals = baseline_frontier_data['x'] - baseline_valid_indices = baseline_y_vals >= target_asr - if np.any(baseline_valid_indices): - baseline_flops = np.min(baseline_x_vals[baseline_valid_indices]) - plt.axhline(y=baseline_flops, color='red', linestyle='--', alpha=0.7, linewidth=2, label='Baseline') - - # Highlight the discrete sample levels from the bar chart - for j in sample_levels_to_plot: - if j in [s for s in continuous_samples]: - idx = continuous_samples.index(j) - color = cmap(color_norm(j)) - plt.scatter(j, continuous_flops[idx], color=color, s=60, alpha=0.9, - edgecolors='black', linewidth=0.5, zorder=5) - - plt.xlabel("Number of Samples", fontsize=12) - plt.ylabel("FLOPS to Reach Baseline ASR", fontsize=12) - plt.xscale('log') - plt.yscale('log') - plt.grid(True, alpha=0.3) - plt.legend(fontsize=10) + # Get the point at max step count (last step) + pts = get_points(y_baseline, baseline_flops_optimization, baseline_flops_sampling_prefill_cache, + baseline_flops_sampling_generation, return_ratio=False, cumulative=config.get("cumulative", False)) + cost_baseline, step_idx_baseline, n_samp_baseline, mean_p_baseline = pts.T + + # Find the point at max step count + max_step_mask = step_idx_baseline == (n_steps_baseline - 1) + if np.any(max_step_mask): + baseline_cost = cost_baseline[max_step_mask][0] + baseline_mean_p = mean_p_baseline[max_step_mask][0] - # Set reasonable x-axis limits - plt.xlim(1, n_total_samples) + # Plot baseline point + color = attack_colors.get(baseline_attack_name, "black") + plt.scatter(baseline_cost, baseline_mean_p, + s=100, marker="^", color=color, + edgecolors='black', linewidth=1.5, alpha=0.9, + label=f"{config_key} Baseline", zorder=10) - # Increase ylim by small margin - ymin, ymax = plt.ylim() - import math - margin = ((math.log10(ymax) - math.log10(ymin)) * 0.1) - plt.ylim(ymin / (1+margin), ymax * (1+margin)) + if verbose: + logging.info(f"Added baseline point for {baseline_attack_name}: cost={baseline_cost:.2e}, p_harmful={baseline_mean_p:.3f}") + + except Exception as e: + if verbose: + logging.warning(f"Could not add baseline point for {baseline_attack_name}: {e}") + + plt.xlabel("Total FLOPs", fontsize=14) + if threshold is None: + plt.ylabel(r"$s_{harm@n}$", fontsize=18) + else: + plt.ylabel(r"${ASR}@n$", fontsize=18) + + plt.grid(True, alpha=0.3) + plt.xscale(x_scale) + plt.legend(loc='upper left', fontsize=12) + plt.title(f"{model_title}", fontsize=16) plt.tight_layout() - plt.savefig(f"evaluate/distributional_paper/pareto_plots/{title}.pdf") + if threshold is None: + plt.savefig(f"evaluate/distributional_paper/comparative_pareto_plots/{title.replace(' ', '_')}.pdf") + else: + plt.savefig(f"evaluate/distributional_paper/comparative_pareto_plots/{title.replace(' ', '_')}_t={threshold}.pdf") plt.close() + if verbose: + logging.info(f"Comparative Pareto plot saved for {model_title}") + def flops_ratio_plot( results: dict[str,np.ndarray], baseline: dict[str,np.ndarray] | None = None, - title: str = "FLOPS Ratio Analysis", + title: str = "FLOPs Ratio Analysis", sample_levels_to_plot: tuple[int, ...]|None = None, metric: tuple[str, ...] = ('scores', 'strong_reject', 'p_harmful'), cumulative: bool = False, @@ -673,7 +1468,7 @@ def flops_ratio_plot( verbose: bool = True, ): """ - Plot p_harmful vs the ratio of optimization FLOPS to sampling FLOPS. + Plot p_harmful vs the ratio of optimization FLOPs to sampling FLOPs. """ y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation = preprocess_data( results, metric, threshold, flops_per_step @@ -686,7 +1481,7 @@ def flops_ratio_plot( return_ratio=True, cumulative=cumulative) ratio, step_idx, n_samp, mean_p, opt_flop, sampling_flop = pts.T - # Calculate total FLOPS for coloring option + # Calculate total FLOPs for coloring option total_flop = opt_flop + sampling_flop # Filter out infinite ratios for plotting @@ -698,10 +1493,10 @@ def flops_ratio_plot( plt.figure(figsize=(10, 6)) - # Create dual color encoding: hue based on samples, strength based on total FLOPS + # Create dual color encoding: hue based on samples, strength based on total FLOPs # Normalize sample counts for hue sample_norm = setup_color_normalization("linear", n_samp_finite) - # Normalize total FLOPS for alpha/strength + # Normalize total FLOPs for alpha/strength flops_norm = setup_color_normalization(color_scale, total_flop_finite) # Get base colors from viridis colormap based on sample count @@ -728,14 +1523,14 @@ def flops_ratio_plot( edgecolors='black', linewidth=0.5, label=f"{j} samples") - plt.xlabel("Sampling FLOPS / Total FLOPS", fontsize=14) + plt.xlabel("Sampling FLOPs / Total FLOPs", fontsize=14) if threshold is None: - plt.ylabel("Mean p_harmful", fontsize=14) + plt.ylabel(r"$s_{harm@n}$", fontsize=14) else: - plt.ylabel(f"Max ASR (threshold: {threshold})", fontsize=12) + plt.ylabel(r"${ASR}@n$".format(threshold=threshold), fontsize=14) plt.grid(True, alpha=0.3) - plt.title(title, fontsize=16) + # plt.title(title, fontsize=16) # Add baseline if provided if baseline is not None: @@ -809,21 +1604,24 @@ def flops_ratio_plot( plt.xscale("log") plt.xlim(1e-5, 1) plt.ylim(bottom=0) - plt.legend(loc='upper left') + # plt.legend(loc='upper left') plt.tight_layout() - plt.savefig(f"evaluate/distributional_paper/flops_ratio_plots/{title}.pdf", bbox_inches='tight') + if threshold is None: + plt.savefig(f"evaluate/distributional_paper/flops_ratio_plots/{title.replace(' ', '_')}.pdf", bbox_inches='tight') + else: + plt.savefig(f"evaluate/distributional_paper/flops_ratio_plots/{title.replace(' ', '_')}_t={threshold}.pdf", bbox_inches='tight') plt.close() if verbose: - logging.info(f"FLOPS ratio range: {ratio_finite.min():.2e} to {ratio_finite.max():.2e}") - logging.info(f"Mean p_harmful range: {mean_p_finite.min():.4f} to {mean_p_finite.max():.4f}") - logging.info(f"Total FLOPS range: {total_flop_finite.min():.2e} to {total_flop_finite.max():.2e}") + logging.info(f"FLOPs ratio range: {ratio_finite.min():.2e} to {ratio_finite.max():.2e}") + logging.info(f"Mean s_harm range: {mean_p_finite.min():.4f} to {mean_p_finite.max():.4f}") + logging.info(f"Total FLOPs range: {total_flop_finite.min():.2e} to {total_flop_finite.max():.2e}") def ideal_ratio_plot( results: dict[str,np.ndarray], baseline: dict[str,np.ndarray] | None = None, - title: str = "Ideal Sampling FLOPS Ratio", + title: str = "Ideal Sampling FLOPs Ratio", sample_levels_to_plot: tuple[int, ...]|None = None, metric: tuple[str, ...] = ('scores', 'strong_reject', 'p_harmful'), cumulative: bool = False, @@ -833,8 +1631,8 @@ def ideal_ratio_plot( verbose: bool = True, ): """ - Plot the ideal sampling FLOPS ratio for achieving different levels of harmfulness. - For each p_harmful level, finds the point that achieves that level with minimum total FLOPS + Plot the ideal sampling FLOPs ratio for achieving different levels of harmfulness. + For each p_harmful level, finds the point that achieves that level with minimum total FLOPs and plots the corresponding sampling ratio. """ y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation = preprocess_data( @@ -868,7 +1666,7 @@ def ideal_ratio_plot( achieving_mask = mean_p_finite >= p_level if np.any(achieving_mask): - # Among achieving points, find the one with minimum total FLOPS + # Among achieving points, find the one with minimum total FLOPs achieving_flops = total_flop_finite[achieving_mask] achieving_ratios = ratio_finite[achieving_mask] @@ -889,9 +1687,9 @@ def ideal_ratio_plot( min_ratios = np.array(min_ratios) achieved_p_levels = np.array(achieved_p_levels) - plt.figure(figsize=(12, 8)) + plt.figure(figsize=(6, 4)) - # Create the FLOP landscape: interpolated surface with color indicating total FLOPS + # Create the FLOP landscape: interpolated surface with color indicating total FLOPs # Use raw ratios instead of normalized ones landscape_p_harmful = [] landscape_ratios = [] @@ -912,7 +1710,7 @@ def ideal_ratio_plot( ratio_grid = np.logspace(np.log10(1e-5), np.log10(1.0), 100) P_grid, Ratio_grid = np.meshgrid(p_grid, ratio_grid) - # Interpolate FLOPS values onto the grid + # Interpolate FLOPs values onto the grid try: flops_grid = griddata( (landscape_p_harmful, landscape_ratios), @@ -922,7 +1720,7 @@ def ideal_ratio_plot( fill_value=np.nan ) except Exception as e: - raise ValueError(f"Error interpolating FLOPS values.") + raise ValueError(f"Error interpolating FLOPs values.") # Create mask to only show values within the explored bounds mask = np.ones_like(flops_grid, dtype=bool) @@ -943,14 +1741,14 @@ def ideal_ratio_plot( # Create contour plot of the FLOP landscape contour = plt.contourf(P_grid, Ratio_grid, flops_grid_masked, levels=50, - cmap='plasma', alpha=0.8, extend='both') + cmap='viridis', alpha=0.8, extend='both') - # Add colorbar for total FLOPS - cbar = plt.colorbar(contour, label='Total FLOPS') + # Add colorbar for total FLOPs + cbar = plt.colorbar(contour, label='Total FLOPs') cbar.formatter.set_powerlimits((0, 0)) # Use scientific notation # Plot the ideal ratio curve (raw ratios) - this traces the minimum through the landscape - plt.plot(achieved_p_levels, ideal_ratios, 'b-', linewidth=3, label='Ideal Ratio (Min FLOPS)', marker='o', markersize=4, zorder=5) + plt.plot(achieved_p_levels, ideal_ratios, 'k', linewidth=3, label='Ideal Ratio (Min FLOPs)', zorder=5) # Plot the maximum ratio explored curve (raw ratios) # plt.plot(achieved_p_levels, max_ratios, 'g--', linewidth=2, label='Max Ratio Explored', marker='s', markersize=3, zorder=5) @@ -972,20 +1770,22 @@ def ideal_ratio_plot( if np.any(baseline_finite_mask): # For baseline, just plot the raw ratios plt.scatter(baseline_mean_p[baseline_finite_mask], baseline_ratio[baseline_finite_mask], - color="red", s=80, alpha=0.9, marker="^", + color="red", s=60, alpha=0.9, marker="^", edgecolors='black', linewidth=0.5, label="Baseline", zorder=6) - plt.xlabel("p_harmful", fontsize=14) - plt.ylabel("Sampling FLOPS / Total FLOPS", fontsize=14) + plt.xlabel(r"$s_{harm}$", fontsize=16) + plt.ylabel("Sampling FLOPs / Total FLOPs", fontsize=16) plt.grid(True, alpha=0.3) - plt.title(title, fontsize=16) + plt.tick_params(axis='both', which='major', labelsize=14) plt.xlim(0, 1) plt.yscale('log') plt.ylim(1e-5, 1.0) - plt.xlim(left=0) - plt.legend() + plt.legend(loc='lower right') plt.tight_layout() - plt.savefig(f"evaluate/distributional_paper/ideal_ratio_plots/{title}.pdf", bbox_inches='tight') + if threshold is None: + plt.savefig(f"evaluate/distributional_paper/ideal_ratio_plots/{title.replace(' ', '_')}.pdf", bbox_inches='tight') + else: + plt.savefig(f"evaluate/distributional_paper/ideal_ratio_plots/{title.replace(' ', '_')}_t={threshold}.pdf", bbox_inches='tight') plt.close() if verbose: @@ -993,7 +1793,7 @@ def ideal_ratio_plot( logging.info(f"Ideal ratio range: {ideal_ratios.min():.4f} to {ideal_ratios.max():.4f}") logging.info(f"Max ratio range: {max_ratios.min():.4f} to {max_ratios.max():.4f}") logging.info(f"Min ratio range: {min_ratios.min():.4f} to {min_ratios.max():.4f}") - logging.info(f"Total FLOPS landscape range: {landscape_total_flops.min():.2e} to {landscape_total_flops.max():.2e}") + logging.info(f"Total FLOPs landscape range: {landscape_total_flops.min():.2e} to {landscape_total_flops.max():.2e}") logging.info(f"Number of points in landscape: {len(landscape_total_flops)}") logging.info(f"Number of p_harmful levels with solutions: {len(achieved_p_levels)}") @@ -1001,7 +1801,7 @@ def ideal_ratio_plot( def flops_breakdown_plot( results: dict[str,np.ndarray], baseline: dict[str,np.ndarray] | None = None, - title: str = "FLOPS Breakdown Analysis", + title: str = "FLOPs Breakdown Analysis", sample_levels_to_plot: tuple[int, ...]|None = None, metric: tuple[str, ...] = ('scores', 'strong_reject', 'p_harmful'), cumulative: bool = False, @@ -1011,7 +1811,7 @@ def flops_breakdown_plot( verbose: bool = True, ): """ - Plot optimization FLOPS vs sampling FLOPS with p_harmful as a 2D surface. + Plot optimization FLOPs vs sampling FLOPs with p_harmful as a 2D surface. """ y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation = preprocess_data( results, metric, threshold, flops_per_step @@ -1047,7 +1847,7 @@ def flops_breakdown_plot( p_val = y[:, i, rng.choice(n_total_samples, size=j, replace=False)].max(axis=-1).mean(axis=0) p_vals.append(p_val) - opt_flops.append(opt_flop+sampling_flop) + opt_flops.append(opt_flop + sampling_flop) sampling_flops.append(sampling_flop) p_harmful_vals.append(np.mean(p_vals)) n_samples_vals.append(j) @@ -1057,20 +1857,20 @@ def flops_breakdown_plot( p_harmful_vals = np.array(p_harmful_vals) n_samples_vals = np.array(n_samples_vals) - plt.figure(figsize=(12, 8)) + plt.figure(figsize=(4, 2.8)) # Create 2D surface plot using griddata interpolation # Define grid for interpolation sampling_min, sampling_max = sampling_flops.min(), sampling_flops.max() opt_min, opt_max = opt_flops.min(), opt_flops.max() - # Use log space for sampling FLOPS if range is large + # Use log space for sampling FLOPs if range is large if sampling_max / sampling_min > 100: sampling_grid = np.logspace(np.log10(sampling_min), np.log10(sampling_max), 100) else: sampling_grid = np.linspace(sampling_min, sampling_max, 100) - # Use log space for optimization FLOPS if range is large + # Use log space for optimization FLOPs if range is large if opt_max / opt_min > 100: opt_grid = np.logspace(np.log10(opt_min), np.log10(opt_max), 100) else: @@ -1079,14 +1879,28 @@ def flops_breakdown_plot( Sampling_grid, Opt_grid = np.meshgrid(sampling_grid, opt_grid) # Interpolate p_harmful values onto the grid + # use anisotropic interpolation try: p_harmful_grid = griddata( (sampling_flops, opt_flops), p_harmful_vals, (Sampling_grid, Opt_grid), method='linear', - fill_value=np.nan + rescale=True ) + if np.isnan(p_harmful_grid).sum() > 0: + p_harmful_grid_nearest = griddata( + (sampling_flops, opt_flops), + p_harmful_vals, + (Sampling_grid, Opt_grid), + method='nearest', + fill_value=0, + rescale=True + ) + fill_mask = np.isnan(p_harmful_grid) + impossible_mask = ((Sampling_grid + opt_min) > Opt_grid) | ((Opt_grid-Sampling_grid) > opt_max-Sampling_grid) + p_harmful_grid[fill_mask] = p_harmful_grid_nearest[fill_mask] + p_harmful_grid[impossible_mask] = np.nan except Exception as e: if verbose: logging.info(f"Linear interpolation failed: {e}, trying nearest neighbor") @@ -1095,56 +1909,56 @@ def flops_breakdown_plot( p_harmful_vals, (Sampling_grid, Opt_grid), method='nearest', - fill_value=0 + fill_value=0, + rescale=True ) - # Create contour plot + # Create contour plot (transpose the grids) levels = np.linspace(np.nanmin(p_harmful_vals), np.nanmax(p_harmful_vals), 50) - contour = plt.contourf(Sampling_grid, Opt_grid, p_harmful_grid, levels=levels, - cmap='plasma', extend='both') + contour = plt.contourf(Opt_grid, Sampling_grid, p_harmful_grid, levels=levels, + cmap='viridis', extend='both') # Add colorbar - cbar = plt.colorbar(contour) + cbar = plt.colorbar(contour, ticks=np.linspace(np.nanmin(p_harmful_vals), np.nanmax(p_harmful_vals), 5)) + cbar.ax.set_yticklabels([f'{tick:.2f}' for tick in cbar.get_ticks()]) if threshold is None: - cbar.set_label(r"$p_{harmful}$", fontsize=14) + cbar.set_label(r"$s_{harm@n}$", fontsize=17) else: - cbar.set_label(f"ASR (threshold: {threshold})", fontsize=14) + cbar.set_label(r"ASR@$n$", fontsize=17) + # Find maximum ASR at each total FLOP level, ignoring higher FLOP levels with lower ASR + total_flops = sampling_flops + opt_flops - # Add baseline if provided - if baseline is not None: - y_baseline, baseline_flops_optimization, baseline_flops_sampling_prefill_cache, baseline_flops_sampling_generation = preprocess_data( - baseline, metric, threshold, flops_per_step - ) - - baseline_pts = get_points(y_baseline, baseline_flops_optimization, baseline_flops_sampling_prefill_cache, - baseline_flops_sampling_generation, return_ratio=False, cumulative=cumulative) - baseline_cost, baseline_step_idx, baseline_n_samp, baseline_mean_p = baseline_pts.T + # Sort by total FLOPs to process in order + sort_idx = np.argsort(total_flops) + sorted_total_flops = total_flops[sort_idx] + sorted_sampling_flops = sampling_flops[sort_idx] + sorted_opt_flops = opt_flops[sort_idx] + sorted_p_harmful = p_harmful_vals[sort_idx] - # Calculate baseline FLOP components - baseline_opt_flops = [] - baseline_sampling_flops = [] + max_asr_points = [] + max_asr_seen = -np.inf - for i in range(0, y_baseline.shape[1], 1): - opt_flop = np.mean(baseline_flops_optimization[:, :i+1].sum(axis=1)) - sampling_flop = np.mean(baseline_flops_sampling_generation[:, i]) * 1 + np.mean(baseline_flops_sampling_prefill_cache[:, i]) + for i in range(len(sorted_total_flops)): + current_asr = sorted_p_harmful[i] - baseline_opt_flops.append(opt_flop+sampling_flop) - baseline_sampling_flops.append(sampling_flop) + # Only add this point if it achieves a higher ASR than we've seen before + if current_asr > max_asr_seen: + max_asr_seen = current_asr + max_asr_points.append((sorted_opt_flops[i], sorted_sampling_flops[i])) - baseline_opt_flops = np.array(baseline_opt_flops) - baseline_sampling_flops = np.array(baseline_sampling_flops) + if max_asr_points: + max_asr_points = np.array(max_asr_points) - plt.scatter(baseline_sampling_flops, baseline_opt_flops, - s=60, alpha=0.9, marker="^", - edgecolors='red', linewidth=2, - color='white', label="Baseline") + plt.plot(max_asr_points[:, 0], max_asr_points[:, 1], + color='black', linewidth=2, linestyle="--",alpha=0.8, label="Compute Optimal Frontier") - plt.xlabel("Sampling FLOPS", fontsize=14) - plt.ylabel("Total FLOPS", fontsize=14) + plt.xlabel("Total FLOPs", fontsize=14) + plt.ylabel("Sampling FLOPs", fontsize=14) + plt.tick_params(axis='both', which='major', labelsize=12) plt.grid(True, alpha=0.3) - plt.title(title, fontsize=16) + # plt.title(title, fontsize=16) # Use log scale for both axes if the range is large if sampling_max / sampling_min > 100: @@ -1152,39 +1966,287 @@ def flops_breakdown_plot( if opt_max / opt_min > 100: plt.yscale('log') - plt.legend(loc='upper left') + plt.legend(loc='lower left', fontsize=13, bbox_to_anchor=(-0.1, 0.95)) plt.tight_layout() - plt.savefig(f"evaluate/distributional_paper/flops_breakdown/{title}.pdf", bbox_inches='tight') + if threshold is None: + plt.savefig(f"evaluate/distributional_paper/flops_breakdown/{title.replace(' ', '_')}.pdf", bbox_inches='tight') + else: + plt.savefig(f"evaluate/distributional_paper/flops_breakdown/{title.replace(' ', '_')}_t={threshold}.pdf", bbox_inches='tight') plt.close() if verbose: - logging.info(f"Sampling FLOPS range: {sampling_flops.min():.2e} to {sampling_flops.max():.2e}") - logging.info(f"Optimization FLOPS range: {opt_flops.min():.2e} to {opt_flops.max():.2e}") + logging.info(f"Sampling FLOPs range: {sampling_flops.min():.2e} to {sampling_flops.max():.2e}") + logging.info(f"Optimization FLOPs range: {opt_flops.min():.2e} to {opt_flops.max():.2e}") logging.info(f"p_harmful range: {p_harmful_vals.min():.4f} to {p_harmful_vals.max():.4f}") logging.info(f"Surface grid shape: {p_harmful_grid.shape}") logging.info(f"Valid surface points: {np.sum(~np.isnan(p_harmful_grid))}/{p_harmful_grid.size}") +def histogram_plot( + sampled_data: dict[str,np.ndarray], + model_title: str, + atk_name: str, + cfg: dict, + threshold: float|None = None, +): + plt.figure(figsize=(5, 6)) + data_list = [] + positions = [] + data = np.array(sampled_data[("scores", "strong_reject", "p_harmful")]) + for i in np.arange(data.shape[1]): + data_list.append(data[:, i].flatten()) + positions.append(i) + + # Create 2D heatmap + # Define bins for p_harmful values (y-axis) + p_harmful_bins = np.linspace(0, 1, 101) # 50 bins from 0 to 1 + + # Create 2D histogram matrix + heatmap_data = np.zeros((len(p_harmful_bins)-1, len(positions))) + + for i, (pos, d) in enumerate(zip(positions, data_list)): + # Calculate histogram for this position + counts, _ = np.histogram(d, bins=p_harmful_bins) + heatmap_data[:, i] = counts / len(d) + + # Create the heatmap + im = plt.imshow(heatmap_data, + aspect='auto', + origin='lower', + extent=[positions[0], positions[-1], 0, 1], + cmap='viridis', + norm=LogNorm(vmin=1/len(d), vmax=heatmap_data.max()) + ) + # Add colorbar + cbar = plt.colorbar(im, label='Density') + + # Calculate and plot median and mean lines + medians = [] + means = [] + for data_at_pos in data_list: + medians.append(np.median(data_at_pos)) + means.append(np.mean(data_at_pos)) + + # Plot mean line + plt.plot(positions, means, color='orange', linewidth=2, label='Mean', alpha=0.8) + plt.plot(positions, np.maximum.accumulate(means), color='red', linewidth=2, label='Max', alpha=0.8) + + # Add legend + plt.legend(loc='upper right', framealpha=0.8) + + plt.xscale("log") + plt.xlabel('Step Index', fontsize=18) + plt.ylabel('p_harmful', fontsize=18) + # plt.title(f'{model_title} - {atk_name}', fontsize=18) + plt.grid(True, alpha=0.3) + plt.tight_layout() + + # Save the plot + if threshold is None: + filename = f"evaluate/distributional_paper/histograms/{model_title}_{cfg['title_suffix']}.pdf" + else: + filename = f"evaluate/distributional_paper/histograms/{model_title}_{cfg['title_suffix']}_t={threshold}.pdf" + plt.savefig(filename.replace(' ', '_'), dpi=300, bbox_inches='tight') + plt.close() + +def histogram_2_plot( + sampled_data: dict[str,np.ndarray], + model_title: str, + cfg: dict, + threshold: float|None = None, +): + # Create histogram plot + plt.figure(figsize=(10, 6)) + threshold = 0.0 + show_top = False # Set to False to hide the top subplot + bins = np.linspace(0, 1, 21) + + + data_low = np.array(sampled_data[("scores","strong_reject","p_harmful")])[:, 0].flatten() + data_low = data_low[data_low > threshold] + + data_high = np.array(sampled_data[("scores","strong_reject","p_harmful")])[:, -1].flatten() + data_high = data_high[data_high > threshold] + + # Calculate Fisher-Pearson skewness coefficient for both datasets + def calculate_skewness(data): + n = len(data) + mean = np.mean(data) + std = np.std(data, ddof=0) # Population standard deviation + m3 = np.sum((data - mean)**3) / n + return m3 / (std**3) if std > 0 else 0 + + skew_low = calculate_skewness(data_low) + skew_high = calculate_skewness(data_high) + + # Create CDF plot + fig, ax = plt.subplots(figsize=(10, 6)) + + # Calculate survival functions (1 - CDF) + data_low_sorted = np.sort(data_low) + data_high_sorted = np.sort(data_high) + + # Calculate proportion of items with value <= x, then invert to get survival function + cdf_low = np.arange(1, len(data_low_sorted) + 1) / len(data_low_sorted) + cdf_high = np.arange(1, len(data_high_sorted) + 1) / len(data_high_sorted) + + survival_low = 1 - cdf_low + survival_high = 1 - cdf_high + + # Plot survival functions + ax.plot(data_low_sorted, survival_low, label=r"First Step ($\gamma_1$" + f"={skew_low:.2f})", linewidth=2, alpha=0.8) + ax.plot(data_high_sorted, survival_high, label=r"Last Step ($\gamma_1$" + f"={skew_high:.2f})", linewidth=2, alpha=0.8) + + ax.set_xlabel(r"$s_{harm}$", fontsize=14) + ax.set_ylabel("Survival Probability (P(X $>$ x))", fontsize=14) + ax.set_xlim(threshold, 1) + ax.set_ylim(0, 1) + ax.legend(fontsize=12) + ax.grid(True, alpha=0.3) + + ax.set_title(f"{model_title} - {cfg['title_suffix']} - p_harmful Survival Function", + fontsize=16) + + # Save the plot + if threshold is None: + filename = f"evaluate/distributional_paper/cdf_plots/{model_title}_{cfg['title_suffix']}.pdf" + else: + filename = f"evaluate/distributional_paper/cdf_plots/{model_title}_{cfg['title_suffix']}_t={threshold}.pdf" + plt.savefig(filename.replace(' ', '_'), dpi=300, bbox_inches='tight') + plt.close() + + if show_top: + fig, (ax_top, ax_bottom) = plt.subplots( + 2, 1, sharex=True, + figsize=(10, 6), + gridspec_kw={"height_ratios": [1, 3]} # top is shorter + ) + axes = (ax_top, ax_bottom) + else: + fig, ax_bottom = plt.subplots(figsize=(10, 4)) + axes = (ax_bottom,) + + # draw the same two histograms on both axes + for ax in axes: + ax.hist(data_high, bins=bins, alpha=0.7, label=r"Last Step ($\gamma_1$" + f"={skew_high:.2f})") + ax.hist(data_low, bins=bins, alpha=0.7, label=r"First Step ($\gamma_1$" + f"={skew_low:.2f})") + + if show_top: + # ------------- tune y limits so the top shows only the tall bar(s) ----- + counts_start = np.sort(np.histogram(data_low.flatten(), bins=bins)[0]) + counts_end = np.sort(np.histogram(data_high.flatten(), bins=bins)[0]) + + ylims = ( + (0, max(counts_start[-2], counts_end[-2]) * 1.1), + (min(counts_start[-1], counts_end[-1]) * 0.9, max(counts_start[-1], counts_end[-1]) * 1.1), + ) + ax_top.set_ylim(*ylims[1]) + ax_bottom.set_ylim(*ylims[0]) + + # ---------------- cosmetic fixes ----------------- + ax_top.spines.bottom.set_visible(False) + ax_bottom.spines.top.set_visible(False) + ax_top.tick_params(axis="x", which="both", bottom=False, labelbottom=False) + + # little diagonal "break" marks + d = .5 # size of diagonal lines + kwargs = dict(marker=[(-1, -1), (1, 1)], markersize=12, + linestyle="none", color="k", clip_on=False) + ax_top.plot([0, 1], [0, 0], transform=ax_top.transAxes, **kwargs) + ax_bottom.plot([0, 1], [1, 1], transform=ax_bottom.transAxes, **kwargs) + + # --------- the usual labels ------------ + ax_bottom.set_xlabel(r"$s_{harm}$", fontsize=14) + ax_bottom.set_ylabel("Frequency (\\%)", fontsize=14) + ax_bottom.set_xlim(threshold, 1) + + if show_top: + ax_top.set_title(f"{model_title} - {cfg['title_suffix']} - p_harmful distribution", + fontsize=16) + else: + ax_bottom.set_title(f"{model_title} - {cfg['title_suffix']} - p_harmful distribution", + fontsize=16) + + # Convert y-axis tick labels to percentages + if not show_top: + # Get current y-tick locations + yticks = ax_bottom.get_yticks() + # Convert to percentages by dividing by total count and multiplying by 100 + ytick_labels = [f"{100*tick / len(data_low.flatten()):.0f}%" for tick in yticks] + ax_bottom.set_yticklabels(ytick_labels) + else: + for ax in (ax_bottom,): + # Get current y-tick locations + yticks = ax.get_yticks() + # Convert to percentages by dividing by total count and multiplying by 100 + ytick_labels = [f"{100*tick / len(data_low.flatten()):.0f}%" for tick in yticks] + ax.set_yticklabels(ytick_labels) + # Convert y-axis tick labels to percentages + for ax in (ax_top,): + # Get current y-tick locations + yticks = ax.get_yticks() + # Convert to percentages by dividing by total count and multiplying by 100 + ytick_labels = [f"{100*tick / len(data_low.flatten()):.1f}%" for tick in yticks] + ax.set_yticklabels(ytick_labels) + ax_top.legend() + + if not show_top: + ax_bottom.legend() + + for ax in axes: + ax.grid(True, alpha=0.3) + + plt.tight_layout() + + # Save the plot + if threshold is None: + filename = f"evaluate/distributional_paper/histograms_2/{model_title}_{cfg['title_suffix']}.pdf" + else: + filename = f"evaluate/distributional_paper/histograms_2/{model_title}_{cfg['title_suffix']}_t={threshold}.pdf" + plt.savefig(filename.replace(' ', '_'), dpi=300, bbox_inches='tight') + plt.close() + def ridge_plot( sampled_data: dict[str,np.ndarray], model_title: str, cfg: dict, + threshold: float|None = None, ): - sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)}) + sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0), 'figure.figsize': (3, 3)}) # Create ridge plot for p_harmful distributions across steps data = np.array(sampled_data[("scores", "strong_reject", "p_harmful")]) # Prepare data for ridge plot ridge_data = [] - step_idxs = [0] + list(generate_sample_sizes(data.shape[1]-1)) - if data.shape[1]-1 not in step_idxs: - step_idxs.append(data.shape[1]-1) + def log_spaced_indices(n_cols: int, k: int = 4) -> list[int]: + """ + Return k log-spaced column indices in [0, n_cols-1] inclusive. + Guarantees 0 and n_cols-1 are present; deduplicates if n_cols is small. + """ + # corner cases: 0 or 1 column → just [0]; 2–3 cols → all of them + if n_cols <= k: + return list(range(n_cols)) + + max_idx = n_cols - 1 + # make (k) points geometrically spaced in (1 … max_idx) + inner = np.geomspace(1, max_idx, num=k, dtype=int) + + # build the final list and drop duplicates, then sort + idx = np.unique(np.concatenate(([0], inner, [max_idx]))) + # if de-duplication left us with fewer than k values, pad with lin-spaced ones + if idx.size < k: + extra = np.linspace(0, max_idx, num=k, dtype=int) + idx = np.unique(np.concatenate((idx, extra)))[:k] + + return idx.tolist() + step_idxs = log_spaced_indices(data.shape[1], 4) + for step_idx in step_idxs: step_data = data[:, step_idx, :].flatten() # Get p_harmful values for this step # Round/bucketize the data into five values: 0, 0.25, 0.5, 0.75, 1.0 # step_data = np.round(step_data * 4) / 4 for value in step_data: - ridge_data.append({'step': f'Step {step_idx}', 'p_harmful': value}) + ridge_data.append({'step': f'Step {step_idx}', r"$h(Y)$": value}) df = pd.DataFrame(ridge_data) # Create ridge plot for p_harmful distributions across steps @@ -1193,42 +2255,480 @@ def ridge_plot( pal = sns.cubehelix_palette(n_steps, rot=-.25, light=.7) # Initialize the FacetGrid object - g = sns.FacetGrid(df, row="step", hue="step", aspect=15, height=.5, palette=pal, + g = sns.FacetGrid(df, row="step", hue="step", aspect=5, height=.4, palette=pal, row_order=unique_steps) # Draw the densities - g.map(sns.kdeplot, "p_harmful", bw_adjust=0.5, clip_on=True, fill=True, alpha=1, linewidth=1.5) - g.map(sns.kdeplot, "p_harmful", bw_adjust=0.5, clip_on=True, color="w", lw=2) + g.map(sns.kdeplot, r"$h(Y)$", bw_adjust=0.5, clip=(0, 1), fill=True, alpha=1, linewidth=0, zorder=1) + g.map(sns.kdeplot, r"$h(Y)$", bw_adjust=0.5, clip=(0, 1), color="w", lw=3, zorder=0) # Add vertical lines for mean and median def add_mean_lines(x, **kwargs): ax = plt.gca() mean_val = np.mean(x) - ax.axvline(mean_val, color='red', linestyle='-', alpha=0.7, linewidth=1, ymax=0.8) + median_val = np.median(x) + percentile_95 = np.percentile(x, 95) + # ax.axhline(0, color='black', linestyle='-', alpha=0.7, linewidth=0.5, ymax=0.5) + ax.axvline(median_val, color='black', linestyle='--', alpha=0.7, linewidth=1, ymax=0.5) + ax.axvline(percentile_95, color='blue', linestyle='--', alpha=0.7, linewidth=1, ymax=0.5) + ax.axvline(mean_val, color='red', linestyle='-', alpha=0.7, linewidth=1, ymax=0.5) - g.map(add_mean_lines, "p_harmful") + g.map(add_mean_lines, r"$h(Y)$") # Add reference line at y=0 - g.refline(y=0, linewidth=2, linestyle="-", color=None, clip_on=False) + g.refline(y=0, linewidth=1, linestyle="-", color=None, clip_on=False) # Set the subplots to overlap g.figure.subplots_adjust(hspace=-.4) + # g.figure.subplots_adjust(top=) # Remove axes details that don't play well with overlap - g.set_titles("") + g.set_titles(f"") g.set(yticks=[], ylabel="") g.despine(bottom=True, left=True) + g.set_xlabels(r"$h(Y)$", fontsize=14) g.set(xlim=(0, 1)) - - # Add title to ridge plot - g.figure.suptitle(f'{model_title} - {cfg["title_suffix"]} - p_harmful Ridge Plot', - fontsize=14, y=0.95) + plt.style.use("science") + # Add legend for the mean line + from matplotlib.lines import Line2D + legend_elements = [Line2D([0], [0], color='red', lw=1, alpha=0.7, label=r'$\text{Mean}$'), + Line2D([0], [0], color='black', linestyle='--', lw=1, alpha=0.7, label=r'$\text{Median}$'), + Line2D([0], [0], color='blue', linestyle='--', lw=1, alpha=0.7, label=r'$\text{95th Percentile}$')] + def put_legend_on_top(fig, handles, **legend_kw): + """ + Add a single figure-level legend centred above *fig* and + tighten the subplot area so only the legend's real height is reserved. + """ + # 1 — draw the legend (temporarily anywhere) + lg = fig.legend(handles=handles, + loc="upper center", + bbox_to_anchor=(0.5, 1), # top centre of the figure + frameon=False, **legend_kw) + + # 2 — draw the canvas *once* so we get the correct bbox + fig.canvas.draw() + renderer = fig.canvas.get_renderer() + legend_bbox = lg.get_window_extent(renderer=renderer) + + # convert pixel height to figure fraction + legend_h_px = legend_bbox.height + fig_h_px = fig.get_size_inches()[1] * fig.dpi + frac = legend_h_px / fig_h_px + + # 3 — shrink the subplot area so it sits just below the legend + pad = 0.01 # a tiny bit of breathing room + fig.subplots_adjust(top=1-frac-pad) + + return lg + + # put_legend_on_top(g.figure, legend_elements, ncol=1) # <─ that's it + # g.figure.legend(handles=legend_elements, loc='upper center', bbox_to_anchor=(0.5, 1), ncol=1, + # frameon=False, columnspacing=1.0, handletextpad=0.5) + g.figure.suptitle(f"{model_title}", fontsize=14, y=0.9, va="top") # Save the ridge plot - filename = f"evaluate/distributional_paper/ridge_plots/{model_title}_{cfg['title_suffix']}.pdf" - g.figure.savefig(filename, dpi=300, bbox_inches='tight') + if threshold is None: + filename = f"evaluate/distributional_paper/ridge_plots/{model_title}_{cfg['title_suffix']}.pdf" + else: + filename = f"evaluate/distributional_paper/ridge_plots/{model_title}_{cfg['title_suffix']}_t={threshold}.pdf" + g.figure.savefig(filename.replace(' ', '_'), bbox_inches='tight') + plt.close(g.figure) + n_steps_to_show = 4 + + # ---------- basic theming ---------- + sns.set_theme( + style="white", + rc={ + "axes.facecolor": (0, 0, 0, 0), + "figure.figsize": (1.5 * n_steps_to_show, 1.5), # widen for columns + }, + ) + + # ---------- collect the data ---------- + data = np.array(sampled_data[("scores", "strong_reject", "p_harmful")]) + + # choose exactly n_steps_to_show equally-spaced indices + all_step_idxs = [0] + list(generate_sample_sizes(data.shape[1]-1)) + if data.shape[1]-1 not in all_step_idxs: + all_step_idxs.append(data.shape[1]) + + # Select exactly n_steps_to_show equally-spaced indices + if len(all_step_idxs) > n_steps_to_show: + # Use numpy to select evenly spaced indices + indices = np.linspace(0, len(all_step_idxs) - 1, n_steps_to_show + 1, dtype=int) + indices = [indices[0], *indices[2:]] + step_idxs = [all_step_idxs[i] for i in indices] + else: + step_idxs = all_step_idxs + + ridge_rows = [] + for idx in step_idxs: + ridge_rows.extend( + { + "step": f"Step {idx}", + r"$h(Y)$": val, + } + for val in data[:, idx, :].ravel() + ) + + df = pd.DataFrame(ridge_rows) + + # ---------- build the faceted plot ---------- + unique_steps = sorted(df["step"].unique(), key=lambda x: int(x.split()[1])) + pal = sns.cubehelix_palette(int(len(unique_steps)*1.5), rot=-0.25, light=0.7) + + g = sns.FacetGrid( + df, + col="step", + hue="step", + palette=pal, + col_order=unique_steps, + sharey=False, # independent y-axis per column + aspect=1, + height=2.5, + ) + + # densities + g.map(sns.kdeplot, r"$h(Y)$", bw_adjust=0.5, fill=True, alpha=1, linewidth=0, zorder=1) + g.map(sns.kdeplot, r"$h(Y)$", bw_adjust=0.5, color="w", lw=3, zorder=0) + # # Store the y-limits from the first plot to apply to all plots + # first_ax = g.axes.flat[0] + # first_ylim = first_ax.get_ylim() + + # # Apply the same y-limits to all subplots + # for ax in g.axes.flat: + # ax.set_ylim(first_ylim) + # Add a single y-axis tick at the density value of h(Y)=0 for each subplot + + g.set(yticks=[], ylabel="") # hide y-ticks + + for ax in g.axes.flat: + # Get the KDE line data + kde_line = ax.lines[0] # The first line should be the KDE plot + x_data = kde_line.get_xdata() + y_data = kde_line.get_ydata() + + # Find the density value at h(Y)=0 by interpolating + if len(x_data) > 0 and len(y_data) > 0: + # Find the closest x value to 0 or interpolate + if 0 in x_data: + density_at_zero = y_data[x_data == 0][0] + else: + # Interpolate to find density at x=0 + from scipy.interpolate import interp1d + if x_data.min() <= 0 <= x_data.max(): + interp_func = interp1d(x_data, y_data, kind='linear', bounds_error=False, fill_value=0) + density_at_zero = interp_func(0) + else: + density_at_zero = 0 + + # Set a single y-tick at this density value + ax.set_yticks([density_at_zero/2, density_at_zero]) + ax.set_yticklabels([f'{density_at_zero/2:.1f}', f'{density_at_zero:.1f}']) + ax.tick_params(axis='y', labelsize=12, pad=-2) + + # central-tendency & cut-off lines + def add_mean_lines(x, **kwargs): + ax = plt.gca() + mean_val = np.mean(x) + median_val = np.median(x) + p95 = np.percentile(x, 95) + ax.axvline(median_val, ls="--", lw=1, color="black", ymax=0.5, alpha=0.7) + ax.axvline(p95, ls="--", lw=1, color="blue", ymax=0.5, alpha=0.7) + ax.axvline(mean_val, ls="-", lw=1, color="red", ymax=0.5, alpha=0.7) + + g.map(add_mean_lines, r"$h(Y)$") + + # Reduce horizontal spacing between subplots + # g.figure.subplots_adjust(wspace=-0.0) + + # aesthetics + g.set_titles("") # no subplot headers + tick_vals = np.linspace(0, 1, 6) # [0. , 0.2, 0.4, 0.6, 0.8, 1.] + g.set(xticks=tick_vals) + for ax in g.axes.flat: + ax.tick_params(axis='x', pad=0, labelsize=12) + g.set_xlabels("Harmfulness", fontsize=14) + g.set(xlim=(0, 1)) + # g.despine(left=True) + plt.style.use("science") + # ------------- y-axis label on first facet ------------- + first_ax = g.axes.flat[0] + first_ax.set_ylabel("Density", fontsize=13)#, labelpad=-14) + + # optional: make sure the other facets stay unlabeled + for ax in g.axes.flat[1:]: + ax.set_ylabel("") + # ------------- add "Step x" labels ------------- + for ax, step in zip(g.axes.flat, unique_steps): + if step[-1] == "9": + step = "Step " + str(int(step.split()[1])+1) + ax.text( + 0.5, 0.95, step, # centered just above each panel + ha="center", va="bottom", + transform=ax.transAxes, + fontsize=12, + fontweight="bold" # optional, adjust to taste + ) + + # build a single, vertical legend that mimics the example image + legend_elements = [ + Line2D([0], [0], color="black", lw=1, label="Median", ls="--"), + Line2D([0], [0], color="red", lw=1, label="Greedy"), + Line2D([0], [0], color="blue", lw=1, label="95th percentile", ls="--"), + ] + + # Determine if we have a single subplot + if len(g.axes.flat) == 1: + bbox_anchor = (0.2, 0.8) + else: + bbox_anchor = (0.055, 0.8) + + g.figure.legend( + handles=legend_elements, + loc="upper left", # anchor to top-left of the figure + bbox_to_anchor=bbox_anchor, # fine-tune position (x, y in fig-coords) + frameon=False, + ncol=1, # vertical stack + handletextpad=0.4, + labelspacing=0.3, + borderaxespad=0.0, + ) + + # Add horizontal time arrow above the plots + if len(g.axes.flat) > 1: + from matplotlib.patches import FancyArrowPatch + + # Get the positions of the first and last subplots + first_ax = g.axes.flat[0] + last_ax = g.axes.flat[-1] + + # Get the positions in figure coordinates + first_pos = first_ax.get_position() + last_pos = last_ax.get_position() + + # Calculate arrow position (slightly above the plots) + arrow_y = first_pos.y1 + 0.08 # 8% above the top of the plots + arrow_start_x = first_pos.x0 + 0.1 * first_pos.width # 10% into first subplot + arrow_end_x = last_pos.x1 - 0.1 * last_pos.width # 90% into last subplot + + # Create and add the arrow patch + arrow = FancyArrowPatch((arrow_start_x, arrow_y), (arrow_end_x, arrow_y), + connectionstyle="arc3", + arrowstyle='-|>', + mutation_scale=10, + linewidth=0.75, + color='black', + alpha=1.0, + transform=g.figure.transFigure) + g.figure.patches.append(arrow) + + # --------- save / close ---------- + if threshold is None: + filename = ( + f"evaluate/distributional_paper/ridge_plots/" + f"{model_title}_{cfg['title_suffix']}_side_by_side.pdf" + ) + else: + filename = ( + f"evaluate/distributional_paper/ridge_plots/" + f"{model_title}_{cfg['title_suffix']}_side_by_side_t={threshold}.pdf" + ) + g.figure.savefig(filename.replace(' ', '_'), bbox_inches="tight") plt.close(g.figure) + # ---------- NEW: ratio line plot ---------- + num_steps = data.shape[1] + ratios_1 = [] + ratios_2 = [] + + for step_idx in range(num_steps): + vals = data[:, step_idx, :].flatten() + # Original ratio: [0.1,0.5] vs [0.5,1.0] + n_low_1 = np.sum((vals > 0.50) & (vals <= 1.0)) + n_high_1 = np.sum((vals > 0.10) & (vals <= 1.0)) + ratio_1 = n_low_1 / n_high_1 if n_high_1 else np.nan + ratios_1.append(ratio_1) + + # New ratio: [0.0,0.1] vs [0.1,1.0] + n_low_2 = np.sum((vals >= 0.10) & (vals <= 1.0)) + n_high_2 = np.sum((vals >= 0.00) & (vals <= 1.0)) + ratio_2 = n_low_2 / n_high_2 if n_high_2 else np.nan + ratios_2.append(ratio_2) + + # Create figure with two subfigures + fig, (ax1) = plt.subplots(1, 1, figsize=(6.5, 2.65)) + plt.style.use("science") + + # First subfigure: original ratio plots + sns.lineplot(x=np.arange(num_steps), y=ratios_2, label=r"$P(\text{¬refusal})$", ax=ax1, marker="o" if num_steps == 1 else None) + sns.lineplot(x=np.arange(num_steps), y=ratios_1, linestyle="--", label=r"$P(\text{harmful} \mid \text{¬refusal})$", ax=ax1, marker="x" if num_steps == 1 else None) + ax1.yaxis.set_major_locator(MaxNLocator(nbins="auto", integer=False)) + ax1.set_xlabel("Step") + ax1.set_ylabel("Frequency") + ax1.set_title(f"{model_title}") + # ax1.set_ylim(bottom=0.4) + + # Place legend to the left of the first subplot + ax1.legend(bbox_to_anchor=(-0.3, 0.5), loc='center right') + + plt.tight_layout() + if threshold is None: + plt.savefig( + f"evaluate/distributional_paper/ratio_plots/{model_title.replace(' ', '_')}_{cfg['title_suffix'].replace(' ', '_')}.pdf", + bbox_inches="tight" + ) + else: + plt.savefig( + f"evaluate/distributional_paper/ratio_plots/{model_title.replace(' ', '_')}_{cfg['title_suffix'].replace(' ', '_')}_t={threshold}.pdf", + bbox_inches="tight" + ) + plt.close() + + +def optimization_progress_plot( + results: dict[str,np.ndarray], + title: str = "Optimization Progress Analysis", + metric: tuple[str, ...] = ('scores', 'strong_reject', 'p_harmful'), + cumulative: bool = False, + threshold: float|None = None, + step_subset: tuple[int, ...] = None, + verbose: bool = True, +): + """ + Plot optimization progress by showing harmfulness level vs number of samples + for different optimization steps. This reveals whether attacks actually improve + during optimization. + + Parameters + ---------- + results : dict + Attack results data + title : str + Plot title + metric : tuple + Metric to analyze + cumulative : bool + Whether to use cumulative sampling + threshold : float, optional + Threshold for binary classification + step_subset : tuple, optional + Specific steps to show. If None, shows evenly spaced steps + verbose : bool + Whether to print verbose output + """ + y, _, _, _ = preprocess_data(results, metric, threshold, None) + n_runs, n_steps, n_total_samples = y.shape + + # Select steps to show + if step_subset is None: + # Show specific steps: 0, 1, 2, 5, 10, 20, 50, 100, 200, 500 if they exist, plus the last one + candidate_steps = [0, 1, 2, 5, 10, 20, 50, 100, 200, 500] + steps_to_show = [s for s in candidate_steps if s < n_steps] + # Always include the last step if it's not already included + if n_steps > 0 and (n_steps - 1) not in steps_to_show: + steps_to_show.append(n_steps - 1) + else: + steps_to_show = [s for s in step_subset if s < n_steps] + + plt.figure(figsize=(8, 5)) + + # Color map for different steps + colors = plt.cm.viridis(np.linspace(0, 1, len(steps_to_show))) + + rng = np.random.default_rng(42) # Fixed seed for reproducibility + + # For each optimization step, create a curve + for i, step_idx in enumerate(steps_to_show): + sample_counts = range(1, n_total_samples + 1) + harm_levels = [] + + # Calculate harmfulness level for different sample counts at this step + for n_samples in sample_counts: + step_values = [] + + # Multiple random samples for smoother curves + for _ in range(50): # 50 random samples for averaging + if cumulative and step_idx > 0: + # Cumulative: max of (best from previous steps, best from current step) + samples_up_to_now = y[:, :step_idx, rng.choice(n_total_samples, size=1, replace=False)].max(axis=1)[:, 0] + samples_at_step = y[:, step_idx, rng.choice(n_total_samples, size=n_samples, replace=False)].max(axis=-1) + step_value = np.stack([samples_up_to_now, samples_at_step], axis=1).max(axis=1).mean(axis=0) + else: + # Non-cumulative: just the best from current step + step_value = y[:, step_idx, rng.choice(n_total_samples, size=n_samples, replace=False)].max(axis=-1).mean(axis=0) + + step_values.append(step_value) + + harm_levels.append(np.mean(step_values)) + + # Plot the curve for this optimization step + label = f"Step {step_idx}" if step_idx < n_steps - 1 else f"Step {step_idx + 1} (Final)" + plt.plot(sample_counts, harm_levels, + color=colors[i], linewidth=2.5, marker='o', markersize=4, + label=label, alpha=0.9) + + # Styling + plt.xlabel("Number of Samples", fontsize=14) + if threshold is None: + plt.ylabel(r"$s_{harm@n}$", fontsize=16) + else: + plt.ylabel(r"${ASR}@n$", fontsize=14) + + plt.grid(True, alpha=0.3) + plt.xlim(1, n_total_samples) + plt.ylim(bottom=0) + + # Add legend + plt.legend(loc='lower right', fontsize=11) + + plt.tight_layout() + + # Save plot + if threshold is None: + plt.savefig(f"evaluate/distributional_paper/optimization_progress/{title.replace(' ', '_')}.pdf") + else: + plt.savefig(f"evaluate/distributional_paper/optimization_progress/{title.replace(' ', '_')}_t={threshold}.pdf") + plt.close() + + if verbose: + logging.info(f"Optimization progress plot saved for {title}") + logging.info(f"Steps analyzed: {steps_to_show}") + logging.info(f"Sample range: 1 to {n_total_samples}") + + # Calculate improvement metrics + if len(steps_to_show) >= 2: + first_step = steps_to_show[0] + last_step = steps_to_show[-1] + + # Compare harm levels at max samples + first_harm = [] + last_harm = [] + + for _ in range(50): # More samples for accurate comparison + if cumulative and first_step > 0: + first_val = np.stack([ + y[:, :first_step, rng.choice(n_total_samples, size=1, replace=False)].max(axis=1)[:, 0], + y[:, first_step, rng.choice(n_total_samples, size=n_total_samples, replace=False)].max(axis=-1) + ], axis=1).max(axis=1).mean(axis=0) + else: + first_val = y[:, first_step, rng.choice(n_total_samples, size=n_total_samples, replace=False)].max(axis=-1).mean(axis=0) + + if cumulative and last_step > 0: + last_val = np.stack([ + y[:, :last_step, rng.choice(n_total_samples, size=1, replace=False)].max(axis=1)[:, 0], + y[:, last_step, rng.choice(n_total_samples, size=n_total_samples, replace=False)].max(axis=-1) + ], axis=1).max(axis=1).mean(axis=0) + else: + last_val = y[:, last_step, rng.choice(n_total_samples, size=n_total_samples, replace=False)].max(axis=-1).mean(axis=0) + first_harm.append(first_val) + last_harm.append(last_val) + + improvement = np.mean(last_harm) - np.mean(first_harm) + relative_improvement = improvement / np.mean(first_harm) if np.mean(first_harm) > 0 else 0 + + logging.info(f"Absolute improvement (first to last): {improvement:.4f}") + logging.info(f"Relative improvement: {relative_improvement:.2%}") # Helper --------------------------------------------------------------------------- @@ -1265,7 +2765,7 @@ def run_analysis( DATASET_IDX, GROUP_BY) # ---------- generate plot based on analysis type ---------- - flops_per_step_fn = lambda x: FLOPS_PER_STEP.get(atk_name, lambda x, c: 0)(x, num_model_params(model)) + flops_per_step_fn = lambda x: FLOPs_PER_STEP[atk_name](x, num_model_params(model)) if analysis_type == "pareto": pareto_plot( @@ -1275,14 +2775,24 @@ def run_analysis( cumulative=cfg["cumulative"], metric=METRIC, flops_per_step=flops_per_step_fn, - threshold=None, + threshold=0.5, + color_scale="sqrt", + ) + elif analysis_type == "non_cumulative_pareto": + non_cumulative_pareto_plot( + sampled_data, + baseline_data, + title=f"{model_title} {cfg['title_suffix']}", + metric=METRIC, + flops_per_step=flops_per_step_fn, + threshold=0.5, color_scale="sqrt", ) elif analysis_type == "flops_ratio": flops_ratio_plot( sampled_data, baseline_data, - title=f"{model_title} {cfg['title_suffix']} FLOPS Ratio", + title=f"{model_title} {cfg['title_suffix']} FLOPs Ratio", cumulative=cfg["cumulative"], metric=METRIC, flops_per_step=flops_per_step_fn, @@ -1300,235 +2810,130 @@ def run_analysis( threshold=None, ) elif analysis_type == "histogram": - - plt.figure(figsize=(10, 6)) - data_list = [] - positions = [] - data = np.array(sampled_data[("scores", "strong_reject", "p_harmful")]) - for i in np.arange(data.shape[1]): - data_list.append(data[:, i].flatten()) - positions.append(i) - - # Create 2D heatmap - # Define bins for p_harmful values (y-axis) - p_harmful_bins = np.linspace(0, 1, 101) # 50 bins from 0 to 1 - - # Create 2D histogram matrix - heatmap_data = np.zeros((len(p_harmful_bins)-1, len(positions))) - - for i, (pos, d) in enumerate(zip(positions, data_list)): - # Calculate histogram for this position - counts, _ = np.histogram(d, bins=p_harmful_bins) - heatmap_data[:, i] = counts / len(d) - - # Create the heatmap - im = plt.imshow(heatmap_data, - aspect='auto', - origin='lower', - extent=[positions[0], positions[-1], 0, 1], - cmap='viridis', - norm=LogNorm(vmin=1/len(d), vmax=heatmap_data.max()) + histogram_plot( + sampled_data, + model_title, + atk_name, + cfg, + threshold=None, ) - # Add colorbar - cbar = plt.colorbar(im, label='Density') - - # Calculate and plot median and mean lines - medians = [] - means = [] - for data_at_pos in data_list: - medians.append(np.median(data_at_pos)) - means.append(np.mean(data_at_pos)) - - # Plot mean line - plt.plot(positions, means, color='orange', linewidth=2, label='Mean', alpha=0.8) - plt.plot(positions, np.maximum.accumulate(means), color='red', linewidth=2, label='Max', alpha=0.8) - - # Add legend - plt.legend(loc='upper right', framealpha=0.8) - - plt.xscale("log") - plt.xlabel('Step Index', fontsize=14) - plt.ylabel('p_harmful', fontsize=14) - plt.title(f'{model_title} - {atk_name} - p_harmful Distribution', fontsize=16) - plt.grid(True, alpha=0.3) - plt.tight_layout() - - # Save the plot - filename = f"evaluate/distributional_paper/histograms/{model_title}_{cfg['title_suffix']}.pdf" - plt.savefig(filename, dpi=300, bbox_inches='tight') - plt.close() elif analysis_type == "ridge": ridge_plot( sampled_data, model_title, cfg, + threshold=None, ) elif analysis_type == "histogram_2": - # Create histogram plot - plt.figure(figsize=(10, 6)) - threshold = 0.0 - show_top = False # Set to False to hide the top subplot - bins = np.linspace(0, 1, 21) - - - data_low = np.array(sampled_data[("scores","strong_reject","p_harmful")])[:, 0].flatten() - data_low = data_low[data_low > threshold] - - data_high = np.array(sampled_data[("scores","strong_reject","p_harmful")])[:, -1].flatten() - data_high = data_high[data_high > threshold] - - # Create CDF plot - fig, ax = plt.subplots(figsize=(10, 6)) - - # Calculate survival functions (1 - CDF) - data_low_sorted = np.sort(data_low) - data_high_sorted = np.sort(data_high) - - # Calculate proportion of items with value <= x, then invert to get survival function - cdf_low = np.arange(1, len(data_low_sorted) + 1) / len(data_low_sorted) - cdf_high = np.arange(1, len(data_high_sorted) + 1) / len(data_high_sorted) - - survival_low = 1 - cdf_low - survival_high = 1 - cdf_high - - # Plot survival functions - ax.plot(data_low_sorted, survival_low, label="First Step", linewidth=2, alpha=0.8) - ax.plot(data_high_sorted, survival_high, label="Last Step", linewidth=2, alpha=0.8) - - ax.set_xlabel("p_harmful", fontsize=14) - ax.set_ylabel("Survival Probability (P(X $>$ x))", fontsize=14) - ax.set_xlim(threshold, 1) - ax.set_ylim(0, 1) - ax.legend(fontsize=12) - ax.grid(True, alpha=0.3) - - ax.set_title(f"{model_title} - {cfg['title_suffix']} - p_harmful Survival Function", - fontsize=16) - - # Save the plot - filename = f"evaluate/distributional_paper/cdf_plots/{model_title}_{cfg['title_suffix']}.pdf" - plt.savefig(filename, dpi=300, bbox_inches='tight') - plt.close() - - if show_top: - fig, (ax_top, ax_bottom) = plt.subplots( - 2, 1, sharex=True, - figsize=(10, 6), - gridspec_kw={"height_ratios": [1, 3]} # top is shorter - ) - axes = (ax_top, ax_bottom) - else: - fig, ax_bottom = plt.subplots(figsize=(10, 4)) - axes = (ax_bottom,) - - # draw the same two histograms on both axes - for ax in axes: - ax.hist(data_high, bins=bins, alpha=0.7, label="Last Step") - ax.hist(data_low, bins=bins, alpha=0.7, label="First Step") - - if show_top: - # ------------- tune y limits so the top shows only the tall bar(s) ----- - counts_start = np.sort(np.histogram(data_low.flatten(), bins=bins)[0]) - counts_end = np.sort(np.histogram(data_high.flatten(), bins=bins)[0]) - - ylims = ( - (0, max(counts_start[-2], counts_end[-2]) * 1.1), - (min(counts_start[-1], counts_end[-1]) * 0.9, max(counts_start[-1], counts_end[-1]) * 1.1), - ) - ax_top.set_ylim(*ylims[1]) - ax_bottom.set_ylim(*ylims[0]) - - # ---------------- cosmetic fixes ----------------- - ax_top.spines.bottom.set_visible(False) - ax_bottom.spines.top.set_visible(False) - ax_top.tick_params(axis="x", which="both", bottom=False, labelbottom=False) - - # little diagonal "break" marks - d = .5 # size of diagonal lines - kwargs = dict(marker=[(-1, -1), (1, 1)], markersize=12, - linestyle="none", color="k", clip_on=False) - ax_top.plot([0, 1], [0, 0], transform=ax_top.transAxes, **kwargs) - ax_bottom.plot([0, 1], [1, 1], transform=ax_bottom.transAxes, **kwargs) - - # --------- the usual labels ------------ - ax_bottom.set_xlabel("p_harmful", fontsize=14) - ax_bottom.set_ylabel("Frequency (\\%)", fontsize=14) - ax_bottom.set_xlim(threshold, 1) - - if show_top: - ax_top.set_title(f"{model_title} - {cfg['title_suffix']} - p_harmful distribution", - fontsize=16) - else: - ax_bottom.set_title(f"{model_title} - {cfg['title_suffix']} - p_harmful distribution", - fontsize=16) - - # Convert y-axis tick labels to percentages - if not show_top: - # Get current y-tick locations - yticks = ax_bottom.get_yticks() - # Convert to percentages by dividing by total count and multiplying by 100 - ytick_labels = [f"{100*tick / len(data_low.flatten()):.0f}%" for tick in yticks] - ax_bottom.set_yticklabels(ytick_labels) - else: - for ax in (ax_bottom,): - # Get current y-tick locations - yticks = ax.get_yticks() - # Convert to percentages by dividing by total count and multiplying by 100 - ytick_labels = [f"{100*tick / len(data_low.flatten()):.0f}%" for tick in yticks] - ax.set_yticklabels(ytick_labels) - # Convert y-axis tick labels to percentages - for ax in (ax_top,): - # Get current y-tick locations - yticks = ax.get_yticks() - # Convert to percentages by dividing by total count and multiplying by 100 - ytick_labels = [f"{100*tick / len(data_low.flatten()):.1f}%" for tick in yticks] - ax.set_yticklabels(ytick_labels) - ax_top.legend() - - if not show_top: - ax_bottom.legend() - - for ax in axes: - ax.grid(True, alpha=0.3) - - plt.tight_layout() - - # Save the plot - filename = f"evaluate/distributional_paper/histograms_2/{model_title}_{cfg['title_suffix']}.pdf" - plt.savefig(filename, dpi=300, bbox_inches='tight') - plt.close() + histogram_2_plot( + sampled_data, + model_title, + cfg, + threshold=None, + ) elif analysis_type == "flops_breakdown": flops_breakdown_plot( sampled_data, baseline_data, - title=f"{model_title} {cfg['title_suffix']} FLOPS Breakdown", + title=f"{model_title} {cfg['title_suffix']} FLOPs Breakdown", cumulative=cfg["cumulative"], metric=METRIC, flops_per_step=flops_per_step_fn, threshold=None, color_scale="sqrt", ) + elif analysis_type == "optimization_progress": + optimization_progress_plot( + sampled_data, + title=f"{model_title} {cfg['title_suffix']} Optimization Progress", + cumulative=False, # each step is considered independent + metric=METRIC, + threshold=None, + ) else: raise ValueError(f"Unknown analysis type: {analysis_type}") +def run_comparative_analysis( + model: str, + model_title: str, + analysis_type: str = "comparative_pareto", + threshold: float|None = None, +): + """ + Run comparative analysis across multiple attacks for a single model. + """ + logging.info(f"{analysis_type.title()} Analysis: {model_title}") + + # Collect data from all attacks for this model + attacks_data = {} + flops_per_step_fns = {} + + for atk_name, cfg in ATTACKS: + try: + # Fetch attack data + sampled_data = fetch_data(model, cfg.get("attack_override", atk_name), cfg["sample_params"](), + DATASET_IDX, GROUP_BY) + + # Apply post-processing if needed + if post := cfg.get("postprocess"): + post(sampled_data, METRIC) + + # Use title_suffix as key to distinguish between different configs of the same attack + config_key = cfg['title_suffix'] + + # Store the data and config + attacks_data[config_key] = (sampled_data, cfg) + + # Store flops function (using default parameter to capture current value) + flops_per_step_fns[config_key] = lambda x, attack=atk_name: FLOPs_PER_STEP[attack](x, num_model_params(model)) + + except Exception as e: + logging.warning(f"Could not load data for {atk_name} ({cfg.get('title_suffix', 'unknown config')}): {e}") + continue + + # Generate comparative plot + if analysis_type == "comparative_pareto": + comparative_pareto_plot( + model=model, + model_title=model_title, + attacks_data=attacks_data, + title=f"{model_title}", + metric=METRIC, + flops_per_step_fns=flops_per_step_fns, + threshold=None, + baseline_attacks={"gcg", "beast", "pair", "autodan"}, + ) + elif analysis_type == "multi_attack_non_cumulative_pareto": + multi_attack_non_cumulative_pareto_plot( + attacks_data=attacks_data, + model_title=model_title, + title=f"{model_title}", + metric=METRIC, + threshold=None, + ) + else: + raise ValueError(f"Unknown comparative analysis type: {analysis_type}") + # ---------------------------------------------------------------------------------- # Configuration and Constants # ---------------------------------------------------------------------------------- MODELS = { - "meta-llama/Meta-Llama-3.1-8B-Instruct": "Meta Llama 3.1 8B", + "meta-llama/Meta-Llama-3.1-8B-Instruct": "Llama 3.1 8B", "google/gemma-3-1b-it": "Gemma 3 1B", - "GraySwanAI/Llama-3-8B-Instruct-RR": "Llama 3 CB", - "Unispac/Llama2-7B-Chat-Augmented": "Llama 2 DeepAlign", + "GraySwanAI/Llama-3-8B-Instruct-RR": "Llama 3 8B CB", + "Unispac/Llama2-7B-Chat-Augmented": "Llama 2 7B DA", } -FLOPS_PER_STEP = { +FLOPs_PER_STEP = { "autodan": lambda s, c: 69845248149248 // num_model_params("Qwen/Qwen2.5-0.5B-Instruct") * c, "gcg": lambda s, c: int(1e14) + 14958709489152 // num_model_params("Qwen/Qwen2.5-0.5B-Instruct") * c, "beast": lambda s, c: 10447045889280 // num_model_params("Qwen/Qwen2.5-0.5B-Instruct") * c, "pair": lambda s, c: 83795198566400 + 78737584640 // num_model_params("Qwen/Qwen2.5-0.5B-Instruct") * c, + "direct": lambda s, c: 0, + "bon": lambda s, c: 0, } # for 0.5B model # Attack-specific configuration ----------------------------------------------------- @@ -1547,9 +2952,11 @@ def run_analysis( title_suffix="BEAST", cumulative=False, sample_params=lambda: { + "mask_undecided_tokens": False, "generation_config": {"num_return_sequences": 50, "temperature": 0.7}, }, baseline_params=lambda: { + "mask_undecided_tokens": False, "generation_config": {"num_return_sequences": 1, "temperature": 0.0}, }, )), @@ -1632,7 +3039,7 @@ def run_analysis( ), )), ("bon", dict( - title_suffix="BoN Repro", + title_suffix="BoN temp 1.0", cumulative=False, sample_params=lambda: {"num_steps": 1000, "generation_config": {"temperature": 1.0}}, baseline_params=lambda: { @@ -1670,30 +3077,98 @@ def run_analysis( METRIC = ("scores", "strong_reject", "p_harmful") GROUP_BY = {"model", "attack_params"} -DATASET_IDX = list(range(75)) +DATASET_IDX = list(range(100)) -def main(fail: bool = False): - for analysis_type in ["pareto", "flops_ratio", "ideal_ratio", "histogram", "histogram_2", "ridge", "flops_breakdown"]: - # for analysis_type in [ "ridge"]: +def main(fail: bool = False, analysis_types=None): + if analysis_types is None: + analysis_types = ["pareto", "non_cumulative_pareto", "flops_ratio", "ideal_ratio", "histogram", "histogram_2", "ridge", "flops_breakdown", "optimization_progress", "comparative_pareto", "multi_attack_non_cumulative_pareto"] + for analysis_type in analysis_types: logging.info("\n" + "="*80) logging.info(f"GENERATING {analysis_type.upper().replace('_', ' ')} PLOTS") logging.info("="*80) - for model_key, model_title in MODELS.items(): - logging.info(f"Model: {model_key}") - for atk_name, atk_cfg in ATTACKS: + if analysis_type in ["comparative_pareto", "multi_attack_non_cumulative_pareto"]: + # For comparative analysis, iterate over models only + for model_key, model_title in MODELS.items(): + logging.info(f"Model: {model_key}") try: - run_analysis(model_key, model_title, atk_name, atk_cfg, analysis_type) + run_comparative_analysis(model_key, model_title, analysis_type) except Exception as e: if fail: raise e - logging.info(f"Error running {analysis_type} analysis for {atk_name}, " - f"cfg: {atk_cfg.get('title_suffix', 'unknown')}: {e}") + logging.info(f"Error running {analysis_type} analysis for {model_title}: {e}") + else: + # For individual attack analysis, iterate over both models and attacks + for model_key, model_title in MODELS.items(): + logging.info(f"Model: {model_key}") + for atk_name, atk_cfg in ATTACKS: + try: + run_analysis(model_key, model_title, atk_name, atk_cfg, analysis_type) + except Exception as e: + if fail: + raise e + logging.info(f"Error running {analysis_type} analysis for {atk_name}, " + f"cfg: {atk_cfg.get('title_suffix', 'unknown')}: {e}") + +def make_hero_plot(): + asr_labels = ["PAIR", "AutoDAN", "GCG"] + asr_delta = [0.16, 0.21, 0.37] + speedup_labels = asr_labels + speedups = [2.7, 8.9, 137.5] + # Create side-by-side subplots + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(5, 2.65)) + cmap = plt.get_cmap("viridis") + + # Left subplot: Speedup + bars1 = ax1.bar(speedup_labels, speedups, color=cmap(np.linspace(0, 1, len(speedup_labels)+1)[1:]), alpha=0.8, edgecolor='black') + ax1.set_ylabel("Speedup (FLOPs)", fontsize=16) + ax1.set_ylim(0, 170) + # ax1.set_title("Computational Efficiency", fontsize=14) + ax1.grid(True, alpha=0.3, axis='y') + # ax1.set_xticks(["PAIR", "AutoDAN", "GCG"], rotation=45, ha='right') + + # Add horizontal line at y=1 for reference + # ax1.axhline(y=1, color='red', linestyle='--', alpha=0.7, linewidth=1) + + # Add value labels on bars + for bar, value in zip(bars1, speedups): + ax1.annotate(f'{value:.1f}x', + xy=(bar.get_x() + bar.get_width()/2, bar.get_height()), + xytext=(0, 5), + textcoords='offset points', + ha='center', va='bottom', fontsize=14) + + # Right subplot: ASR Delta + bars2 = ax2.bar(asr_labels, asr_delta, color=cmap(np.linspace(0, 1, len(asr_labels)+1)[1:]), alpha=0.8, edgecolor='black') + ax2.set_ylabel(r"$\Delta$ ASR", fontsize=16) + ax2.set_ylim(0, 0.45) + # ax2.set_title("Attack Success Rate Improvement", fontsize=14) + ax2.grid(True, alpha=0.3, axis='y') + # ax2.set_xticks(labels=["PAIR", "AutoDAN", "GCG"], rotation=45, ha='right') + # Make ticks bigger + ax1.tick_params(axis='both', which='major', labelsize=14) + ax2.tick_params(axis='both', which='major', labelsize=14) + ax1.tick_params(axis='x', which='major', labelrotation=45) + ax2.tick_params(axis='x', which='major', labelrotation=45) + + # Add value labels on bars + for bar, value in zip(bars2, asr_delta): + ax2.annotate(f'+{value:.2f}' if value > 0 else f'{value:.2f}', + xy=(bar.get_x() + bar.get_width()/2, bar.get_height()), + xytext=(0, 5), + textcoords='offset points', + ha='center', va='bottom', fontsize=14) + + plt.tight_layout() + plt.savefig("evaluate/distributional_paper/mini_hero_plot.pdf", dpi=300, bbox_inches='tight') + plt.close() if __name__ == "__main__": + make_hero_plot() import argparse parser = argparse.ArgumentParser(description='Generate plots for distributional paper') parser.add_argument('--fail', action='store_true', help='Override flag to fail') + parser.add_argument('--analysis_types', "-p", nargs='+', help='Analysis types to run') args = parser.parse_args() - main(args.fail) + main(args.fail, args.analysis_types)