Added locks for safer async execution #8
base: main
Changes from all commits

```diff
@@ -6,6 +6,7 @@
 from .base import VerifyMonitor
 from interwhen.utils.EAT_helper import compute_entropy, exponential_moving_average, exponential_moving_variance
 from interwhen.utils.DEER_helper import stream_and_compute_geom_mean
+import gc


 class EATMonitor(VerifyMonitor):
```
```diff
@@ -33,12 +34,27 @@ def __init__(self, name, model_name, alpha=0.2, delta=0.0001,
         )
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

+        # Instantiate Lock for safer async execution
+        self.lock = asyncio.Lock()
+
         # State tracking
         self.entropy = []
         self.ema_means = []
         self.ema_vars = []
         self.exit_point = None

+    def reset(self):
+        """Reset monitor state for a new problem without reloading the model."""
+        self.entropy = []
+        self.ema_means = []
+        self.ema_vars = []
+        self.exit_point = None
+        gc.collect()
+        try:
+            torch.cuda.empty_cache()
+        except Exception as e:
+            print("Error while emptying cuda cache: ", e)
+
```
Comment on lines +46 to +56

Collaborator: This function is not being used anywhere, right? If not, we can remove it.

Author: It can be useful in some cases. For instance, consider the case where a user evaluates n samples. Monitors like EAT and DEER can then be initialized just once instead of being initialized separately for each of the n samples, and reset after each iteration. This reduces the overhead of repeated initialization for these monitors.
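The reuse pattern described in the reply above can be sketched as follows. This is an illustrative usage sketch, not code from the PR: the `run_problem` driver coroutine and the constructor arguments are assumptions.

```python
# Hypothetical sketch: initialize one monitor, reuse it for n samples,
# and call reset() between problems instead of re-instantiating it.
import asyncio

async def evaluate(samples, monitor, run_problem):
    """monitor: an EAT/DEER monitor instance; run_problem: caller-supplied coroutine."""
    results = []
    for sample in samples:
        results.append(await run_problem(sample, monitor))
        monitor.reset()  # clears per-problem history; the model stays loaded
    return results

# Usage (assumed constructor arguments):
#   monitor = EATMonitor(name="eat", model_name="some-hf-model")
#   asyncio.run(evaluate(samples, monitor, run_problem))
```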
```diff
@@ -47,30 +63,26 @@ async def _verify(self, generated_text, token_index):

         # We append this tail so that we can compute entropy for next token (answer)
         partial_answer = (generated_text + "\n\n</think>" + "\n\n" + 'Final answer is \\boxed{')
-        entropy_2 = compute_entropy(
-            self.hf_model,
-            self.tokenizer,
-            partial_answer,
-        )
-
-        self.entropy.append(entropy_2)
-        ema_average = exponential_moving_average(self.entropy, self.alpha)
-        ema_variance = exponential_moving_variance(self.entropy, self.alpha, 0.0)
-
-        self.ema_means.append(ema_average[-1])
-        self.ema_vars.append(ema_variance[-1])
-
-        # Early stopping not triggered unless min_steps number of steps have been processed
-        if len(self.entropy) < self.min_steps:
-            return (True, None, token_index)
-
-        # Intervene if variance is below threshold
-        if ema_variance[-1] < self.delta:
-            self.exit_point = len(self.entropy)
-            # Return False to trigger early stop
-            return (False, generated_text, token_index)
-
-        return (True, None, token_index)
+        entropy_2 = await asyncio.to_thread(compute_entropy, self.hf_model, self.tokenizer, partial_answer)
+        async with self.lock:
+            self.entropy.append(entropy_2)
+            ema_average = exponential_moving_average(self.entropy, self.alpha)
+            ema_variance = exponential_moving_variance(self.entropy, self.alpha, 0.0)
+
+            self.ema_means.append(ema_average[-1])
+            self.ema_vars.append(ema_variance[-1])
+
+            # Early stopping not triggered unless min_steps number of steps have been processed
+            if len(self.entropy) < self.min_steps:
+                return (True, None, token_index)
+
+            # Intervene if variance is below threshold
+            if ema_variance[-1] < self.delta:
+                self.exit_point = len(self.entropy)
+                # Return False to trigger early stop
+                return (False, generated_text, token_index)
+
+            return (True, None, token_index)

     async def verify(self, step, token_index, event, event_info):
         """
```
```diff
@@ -82,20 +94,20 @@ async def verify(self, step, token_index, event, event_info):
             return step, None

         # Early stop triggered
-        if not event.is_set():
-            event_info["generated_text"] = step
-            event_info["feedback"] = feedback
-            event_info["correction_index"] = len(step)  # so we know where to slice the gen text during fix
-            event_info["entropy_history"] = self.entropy.copy()
-            event_info["ema_variance"] = self.ema_vars[-1] if self.ema_vars else None
-            event.set()
+        async with self.lock:
+            if not event.is_set():
+                event_info["generated_text"] = step
+                event_info["feedback"] = feedback
+                event_info["correction_index"] = len(step)  # so we know where to slice the gen text during fix
+                event_info["entropy_history"] = self.entropy.copy()
+                event_info["ema_variance"] = self.ema_vars[-1] if self.ema_vars else None
+                event.set()

     async def fix(self, generated_text, event_info, fix_method=None):
         """
         Appending the </think> to force the thinking process to conclude.
         """
         fixed_text = generated_text[:event_info['correction_index']] + "\n\n</think>"
-        print("VISHAAAAAAAAAAAAAAAK"*100)
         return fixed_text

     def step_extractor(self, chunk, generated_text):
```
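The `verify()` change is a check-then-act guard: without the lock, two monitors could both observe `event.is_set()` as `False` and overwrite each other's `event_info`. A standalone sketch of the guarded version, using toy names rather than the repository's classes:

```python
import asyncio
import random

async def monitor(name, event, event_info, lock):
    # Simulate each monitor deciding to trigger at a slightly different time.
    await asyncio.sleep(random.random() / 100)
    async with lock:                 # only one coroutine can win the race
        if not event.is_set():
            event_info["triggered_by"] = name
            event.set()

async def main():
    event, event_info, lock = asyncio.Event(), {}, asyncio.Lock()
    await asyncio.gather(*(monitor(f"m{i}", event, event_info, lock) for i in range(4)))
    print(event_info)                # exactly one monitor is recorded as the trigger

asyncio.run(main())
```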
```diff
@@ -128,6 +140,17 @@ def __init__(self, name, llm_server, delta=0.995, answer_start_token="</think>",
         self.max_probe_steps = max_probe_steps
         self.answer_start_token = answer_start_token
         self.confidence = []
+        # Instantiate Lock for safer async execution
+        self.lock = asyncio.Lock()
+
+    def reset(self):
+        """Reset monitor state for a new problem."""
+        self.confidence = []
+        gc.collect()
+        try:
+            torch.cuda.empty_cache()
+        except Exception as e:
+            print("Error while emptying cuda cache: ", e)
```
Comment on lines +146 to +153

Collaborator: Same here too.

Author: Addressed in an earlier comment.
```diff
@@ -137,14 +160,19 @@ async def _verify(self, generated_text, token_index):

         # We apppend this tail so that we can compute confidence for the answer
         partial_answer = (generated_text + "\n\n</think>" + "\n\n" + 'Final answer is \\boxed{')
-        self.llm_server["payload"]["prompt"] = partial_answer
-        confidence = stream_and_compute_geom_mean(self.llm_server)
-        self.confidence.append(confidence)
-
-        if confidence > self.delta:
-            return False, generated_text, token_index
-
-        return (True, None, token_index)
+        # Create copy to avoid mutating shared state
+        payload_copy = {**self.llm_server["payload"], "prompt": partial_answer}
+        server_copy = {**self.llm_server, "payload": payload_copy}
+
+        confidence = await asyncio.to_thread(stream_and_compute_geom_mean, server_copy)
+
+        async with self.lock:
+            self.confidence.append(confidence)
+            if confidence > self.delta:
+                return False, generated_text, token_index
+
+            return (True, None, token_index)

     async def verify(self, step, token_index, event, event_info):
         """
```
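The copy-before-mutate change above keeps concurrent `_verify` calls from racing on the shared `llm_server["payload"]` dict. A standalone sketch of the idea, with a hypothetical `send_request` standing in for `stream_and_compute_geom_mean`:

```python
import asyncio

def send_request(server):
    """Hypothetical blocking call standing in for stream_and_compute_geom_mean."""
    return len(server["payload"]["prompt"])

async def probe(shared_server, prompt):
    # Shallow-copy the payload and server dicts so each call gets its own
    # "prompt" without mutating the shared configuration.
    payload_copy = {**shared_server["payload"], "prompt": prompt}
    server_copy = {**shared_server, "payload": payload_copy}
    return await asyncio.to_thread(send_request, server_copy)

async def main():
    shared_server = {"url": "http://localhost:8000", "payload": {"prompt": ""}}
    results = await asyncio.gather(probe(shared_server, "a"), probe(shared_server, "bb"))
    print(results)                   # [1, 2]
    print(shared_server["payload"])  # {'prompt': ''} -- untouched

asyncio.run(main())
```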
```diff
@@ -156,18 +184,18 @@ async def verify(self, step, token_index, event, event_info):
             return step, None

         # Early stop triggered
-        if not event.is_set():
-            event_info["generated_text"] = step
-            event_info["feedback"] = feedback
-            event_info["correction_index"] = len(step)  # so we know where to slice the gen text during fix
-            event_info["confidence_history"] = self.confidence.copy()
-            event.set()
+        async with self.lock:
+            if not event.is_set():
+                event_info["generated_text"] = step
+                event_info["feedback"] = feedback
+                event_info["correction_index"] = len(step)  # so we know where to slice the gen text during fix
+                event_info["confidence_history"] = self.confidence.copy()
+                event.set()

     async def fix(self, generated_text, event_info, fix_method=None):
         """
         Appending </think> to force the thinking process to conclude.
         """
         # Append answer prompt to conclude
         fixed_text = generated_text[:event_info['correction_index']] + "\n\n</think>"
         return fixed_text
```
Collaborator: What is the latency difference after adding the lock across all the monitors?
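One way to answer this empirically (a rough sketch, not part of the PR): time a batch of concurrent state updates with and without a lock and compare. The absolute numbers depend on hardware, and in these monitors the lock is held only around cheap list updates, so the dominant cost should remain the entropy/confidence computation itself.

```python
import asyncio
import time

async def run(n, use_lock):
    history = []
    lock = asyncio.Lock() if use_lock else None

    async def update(i):
        if lock is not None:
            async with lock:
                history.append(i)
        else:
            history.append(i)

    start = time.perf_counter()
    await asyncio.gather(*(update(i) for i in range(n)))
    return time.perf_counter() - start

async def main():
    for use_lock in (False, True):
        elapsed = await run(10_000, use_lock)
        print(f"lock={use_lock}: {elapsed:.4f}s for 10,000 updates")

asyncio.run(main())
```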