Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion align_system/algorithms/outlines_adm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
CharacterTagEnum,
KDMAValue
)
import ubelt as ub

from align_system.utils import logging
from align_system.utils import adm_utils
Expand Down Expand Up @@ -381,6 +382,25 @@ def choose_action(self, scenario_state, available_actions, alignment_target, **k
return action_to_take, choice_info

def populate_action_parameters(self, scenario_state, action_to_take, dialog):
scenario_state_copy = copy.deepcopy(scenario_state)
# Don't consider the elapsed_time of the state when caching
scenario_state_copy.elapsed_time = 0
depends = '\n'.join((
repr(self.model.model),
repr(scenario_state_copy),
repr(action_to_take),
repr(dialog)))

cacher = ub.Cacher('outlines_adm_populate_action_params', depends, verbose=0)
log.debug(f'cacher.fpath={cacher.fpath}')

cached_output = cacher.tryload()
if cached_output is not None:
log.info("Cache hit for `populate_action_parameters` returning cached output")
return cached_output
else:
log.info("Cache miss for `populate_action_parameters` ..")

if action_to_take.action_type in {ActionTypeEnum.APPLY_TREATMENT,
ActionTypeEnum.TAG_CHARACTER,
ActionTypeEnum.CHECK_ALL_VITALS,
Expand Down Expand Up @@ -469,7 +489,10 @@ def populate_action_parameters(self, scenario_state, action_to_take, dialog):
selected_character_idx,
dialog)

return action_to_take, dialog
outputs = (action_to_take, dialog)
cacher.save(outputs)

return outputs

def ensure_character_id_is_populated(self,
scenario_state,
Expand Down
57 changes: 55 additions & 2 deletions align_system/algorithms/outlines_regression_adm_comparative.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import jinja2
import json
import numpy as np
import copy

import outlines
from outlines.samplers import MultinomialSampler
from rich.highlighter import JSONHighlighter
from swagger_client.models import kdma_value
import ubelt as ub

from align_system.utils import logging
from align_system.utils import adm_utils
Expand Down Expand Up @@ -130,6 +132,28 @@ def sample_relevance_predictions(self,
'''
Samples prediction of the relevance of each response to each KDMA
'''
scenario_state_copy = copy.deepcopy(scenario_state)
# Don't consider the elapsed_time of the state when caching
scenario_state_copy.elapsed_time = 0
depends = '\n'.join((
repr(self.model.model),
repr(scenario_state_copy),
repr(scenario_description),
repr(choices),
repr([t['kdma'] for t in target_kdmas]),
repr(available_actions),
repr(incontext_settings)))

cacher = ub.Cacher('comp_reg_relevance', depends, verbose=0)
log.debug(f'cacher.fpath={cacher.fpath}')

cached_output = cacher.tryload()
if cached_output is not None:
log.info("Cache hit for `sample_relevance_predictions` returning cached output")
return cached_output
else:
log.info("Cache miss for `sample_relevance_predictions` ..")

use_icl = False
if "number" in incontext_settings and incontext_settings["number"] > 0:
use_icl = True
Expand Down Expand Up @@ -216,7 +240,10 @@ def sample_relevance_predictions(self,
else:
predictions[choice][kdma_key] = 0

return predictions, reasonings, icl_example_responses
outputs = (predictions, reasonings, icl_example_responses)
cacher.save(outputs)

return outputs

def sample_kdma_score_predictions(self,
scenario_state,
Expand All @@ -236,6 +263,29 @@ def sample_kdma_score_predictions(self,
- predictions: {choice1:{kdma1:[score1(int), ...], ...}, ...}
- reasonings: {choice1:{kdma1:[reasoning1(str), ...], ...}, ...}
'''
scenario_state_copy = copy.deepcopy(scenario_state)
# Don't consider the elapsed_time of the state when caching
scenario_state_copy.elapsed_time = 0
depends = '\n'.join((
repr(self.model.model),
repr(scenario_state_copy),
repr(choices),
repr(available_actions),
repr(outcome_predictions),
repr(kdma_score_examples),
repr(enum_scores),
repr(incontext_settings)))

cacher = ub.Cacher('comp_reg_kdma_estimation', depends, verbose=0)
log.debug(f'cacher.fpath={cacher.fpath}')

cached_output = cacher.tryload()
if cached_output is not None:
log.info("Cache hit for `sample_kdma_score_predictions` returning cached output")
return cached_output
else:
log.info("Cache miss for `sample_kdma_score_predictions` ..")

use_icl = False
if "number" in incontext_settings and incontext_settings["number"] > 0:
use_icl = True
Expand Down Expand Up @@ -346,7 +396,10 @@ def sample_kdma_score_predictions(self,
# Scale score to be between 0 and 1 to match targets
predictions[choice][kdma_key].append(kdma_prediction[choice]['score'] / kdma_factor)

return predictions, reasonings, icl_example_responses
outputs = (predictions, reasonings, icl_example_responses)
cacher.save(outputs)

return outputs

# Returns the outcome prediction (if there was one) and score reasoning for the best sample of the selected choice
def get_selected_choice_reasoning(self, selected_choice, best_sample_index, outcome_predictions, reasonings, relevance_reasonings=None):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# @package _global_
# Hydra experiment config: comparative-regression ADM (Mistral-7B-Instruct-v0.3)
# run against the ADEPT TA3 live evaluation server, WITHOUT relevance prediction.
#
# NOTE(review): the pasted diff lost all YAML indentation; the nesting below is
# reconstructed from the key names and the sibling config in this PR — confirm
# against the repository before relying on it.
defaults:
  - override /adm: outlines_regression_aligned_comparative/incontext_phase1
  - override /interface: ta3

interface:
  # TA3 evaluation endpoint and session identity for this ADM submission.
  api_endpoint: "https://darpaitm.caci.com"
  session_type: adept
  training_session: null
  username: "ALIGN-ADM-ComparativeRegression-ADEPT"

adm:
  instance:
    precision: half  # half-precision weights
    model_name: mistralai/Mistral-7B-Instruct-v0.3
    sampler:
      # Greedy decoding — deterministic token selection.
      _target_: outlines.samplers.GreedySampler
  inference_kwargs:
    distribution_matching: average # no rel
    predict_relevance: false # no rel
    kdma_score_examples: true
    num_samples: 1
    predict_outcomes: false
    generator_batch_size: 5
    incontext:
      # In-context example selection: BERT similarity against the prompt.
      method: prompt_bert_similarity
      sort_actions: true
      normalization: null
      number: 5  # number of in-context examples
      leave_one_out_strategy: null
      most_similar_first: false

force_determinism: true
align_to_target: true
save_last_unstructured_state_per_scenario: true

hydra:
  run:
    # Timestamped output directory per live run.
    dir: 'multi_experiment_live/ALIGN-ADM-ComparativeRegression-Mistral-7B-Instruct-v0.3-ADEPT/${now:%Y-%m-%d__%H-%M-%S}'
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# @package _global_
# Hydra experiment config: comparative-regression ADM (Mistral-7B-Instruct-v0.3)
# run against the ADEPT TA3 live evaluation server, WITH relevance prediction
# (differs from the non-relevance sibling config only in `distribution_matching`,
# `predict_relevance`, `username`, and the output directory).
#
# NOTE(review): the pasted diff lost all YAML indentation; the nesting below is
# reconstructed from the key names and the sibling config in this PR — confirm
# against the repository before relying on it.
defaults:
  - override /adm: outlines_regression_aligned_comparative/incontext_phase1
  - override /interface: ta3

interface:
  # TA3 evaluation endpoint and session identity for this ADM submission.
  api_endpoint: "https://darpaitm.caci.com"
  session_type: adept
  training_session: null
  username: "ALIGN-ADM-RelevanceComparativeRegression-ADEPT"

adm:
  instance:
    precision: half  # half-precision weights
    model_name: mistralai/Mistral-7B-Instruct-v0.3
    sampler:
      # Greedy decoding — deterministic token selection.
      _target_: outlines.samplers.GreedySampler
  inference_kwargs:
    distribution_matching: relevance_average # use rel
    predict_relevance: true # use rel
    kdma_score_examples: true
    num_samples: 1
    predict_outcomes: false
    generator_batch_size: 5
    incontext:
      # In-context example selection: BERT similarity against the prompt.
      method: prompt_bert_similarity
      sort_actions: true
      normalization: null
      number: 5  # number of in-context examples
      leave_one_out_strategy: null
      most_similar_first: false

force_determinism: true
align_to_target: true
save_last_unstructured_state_per_scenario: true

hydra:
  run:
    # Timestamped output directory per live run.
    dir: 'multi_experiment_live/ALIGN-ADM-RelevanceComparativeRegression-Mistral-7B-Instruct-v0.3-ADEPT/${now:%Y-%m-%d__%H-%M-%S}'
25 changes: 24 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ outlines = "^0.0.46"
setuptools = "^70.1.1"
sentencepiece = "^0.2.0"
protobuf = "^5.28.3"
ubelt = "1.3.6"

[tool.poetry.scripts]
run_align_system = 'align_system.cli.run_align_system:main'
Expand Down