Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First pass at copying over the changes to enable chat_kdma_predicting_adm to work with any chat LLM #31

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 211 additions & 0 deletions align_system/algorithms/chat_kdma_predicting_adm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
import json
import os
import re
from typing import Union, List, Dict, Tuple, Optional, TextIO

import yaml

from align_system.algorithms.lib.algorithmic_decision_maker import AlgorithmicDecisionMaker
from align_system.algorithms.lib.chat.chat_language_model import ChatLanguageModel

class ChatKDMAPredictingADM(ChatLanguageModel, AlgorithmicDecisionMaker):
    """Decision maker that uses a chat LLM to predict per-choice KDMA scores.

    For each answer choice it (optionally) predicts an outcome narrative,
    predicts a score for every KDMA described in the descriptions YAML, and
    finally selects the choice whose predicted scores are closest (mean
    squared error) to the sample's target KDMA values.
    """

    def predict_outcomes(self,
                         scenario_text: str,
                         probe_text: str,
                         choices: List[str],
                         log_file: Optional[TextIO] = None,
                         max_tokens: int = 512,
                         temperature: float = 0.6,
                         outcome_template_file: str = 'pred_outcome.txt') -> List[str]:
        """
        Predict an outcome for each choice under the given scenario and probe.

        :param scenario_text: Scenario text.
        :param probe_text: Probe text.
        :param choices: Choice texts, one per candidate answer.
        :param log_file: Optional log file.
        :param max_tokens: Maximum number of tokens to generate.
        :param temperature: Temperature for sampling.
        :param outcome_template_file: Template file for outcomes.
        :return: List of generated outcome predictions, one per choice.
        """
        # One template substitution per choice; the scenario/probe are
        # repeated so each choice is evaluated independently.
        return self.generate_from_template(
            outcome_template_file,
            [
                {
                    'scenario': scenario_text,
                    'probe': probe_text,
                    'choice': choice,
                }
                for choice in choices
            ],
            log_file=log_file,
            max_tokens=max_tokens,
            temperature=temperature
        )

    def predict_kdma_scores(self,
                            scenario_text: str,
                            probe_text: str,
                            choice_texts: List[str],
                            predicted_outcomes: Optional[List[str]] = None,
                            generate_reasoning: bool = True,
                            log_file: Optional[TextIO] = None,
                            max_new_tokens: int = 512,
                            temperature: float = 0.6,
                            kdma_template_file: str = 'pred_kdma_RO.txt',
                            kdma_descriptions_file: str = 'lib/templates/bbn_kdma_descriptions.yml') -> Union[List[Dict[str, float]], Tuple[List[Dict[str, float]], List[Dict[str, str]]]]:
        """
        Predict KDMA scores for each choice under the given scenario and probe.

        :param scenario_text: Scenario text.
        :param probe_text: Probe text.
        :param choice_texts: Choice texts, one per candidate answer.
        :param predicted_outcomes: Optional predicted outcomes (parallel to
            choice_texts); when None, the template sees no 'outcome' field.
        :param generate_reasoning: Flag to request JSON score+reasoning output.
        :param log_file: Optional log file.
        :param max_new_tokens: Maximum number of new tokens to generate.
        :param temperature: Temperature for sampling.
        :param kdma_template_file: Template file for KDMA prediction.
        :param kdma_descriptions_file: YAML file of KDMA names/descriptions,
            resolved relative to this module's directory.
        :return: Per-choice dicts of {kdma: score}. If generate_reasoning is
            True, returns (scores, reasonings) with parallel structure.
        """
        choice_ids = [f'choice_{i}' for i in range(len(choice_texts))]
        substitutions = []
        # Parallel bookkeeping: info[i] identifies which (choice, kdma) pair
        # substitutions[i] (and thus generations[i]) belongs to.
        info = []

        relative_dir = os.path.dirname(__file__)
        kdma_descriptions_file_path = os.path.join(relative_dir, kdma_descriptions_file)

        with open(kdma_descriptions_file_path, 'r') as f:
            kdma_descriptions = yaml.load(f, Loader=yaml.FullLoader)

        if predicted_outcomes is None:
            predicted_outcomes = [None] * len(choice_texts)

        # One generation per (choice, kdma) pair.
        for choice_id, choice, outcome in zip(choice_ids, choice_texts, predicted_outcomes):
            for kdma, kdma_info in kdma_descriptions.items():
                substitution = {
                    'kdma': kdma_info['name'],
                    'kdma_description': kdma_info['description'],
                    'scenario': scenario_text,
                    'probe': probe_text,
                    'choice': choice,
                }

                if outcome is not None:
                    substitution['outcome'] = outcome

                substitutions.append(substitution)
                info.append((choice_id, kdma))

        def parse_kdma_score_response(response: str) -> Dict[str, Union[float, str]]:
            """
            Parse one model response into {'score': float} (plus 'reasoning'
            when generate_reasoning is set).

            :param response: Raw model response text.
            :return: Dictionary with KDMA score and reasoning if generate_reasoning.
            """
            if generate_reasoning:
                start_idx = response.find('{')
                end_idx = response.rfind('}')
                # Guard against a response with no JSON object at all;
                # slicing with -1 indices would silently corrupt the text
                # before json.loads failed with a confusing error.
                assert start_idx != -1 and end_idx != -1, \
                    'Could not find JSON object in response'
                response_json = json.loads(response[start_idx:end_idx+1])
                assert 'score' in response_json, 'score not found in response'
                assert 'reasoning' in response_json, 'reasoning not found in response'
            else:
                # Extract the first number in the response. (Calling float()
                # on the entire tail of the string raised ValueError whenever
                # any text followed the number, e.g. "7 out of 10".)
                match = re.search(r'\d+(?:\.\d+)?', response)
                assert match is not None, 'Could not find numeric character in response'
                response_json = {
                    'score': float(match.group())
                }
            return response_json

        generations = self.generate_from_template(
            kdma_template_file,
            substitutions,
            parse_kdma_score_response,
            log_file=log_file,
            max_tokens=max_new_tokens,
            temperature=temperature,
        )

        # Regroup flat (choice, kdma) generations into per-choice dicts.
        predicted_kdmas = {}
        reasonings = {}
        for (choice_id, kdma), generation in zip(info, generations):
            predicted_choice_kdmas = predicted_kdmas.setdefault(choice_id, {})
            choice_reasonings = reasonings.setdefault(choice_id, {})

            predicted_choice_kdmas[kdma] = generation['score']

            if generate_reasoning:
                choice_reasonings[kdma] = generation['reasoning']

        # Convert dicts keyed by choice_id into lists ordered like the input.
        predicted_kdmas = [
            predicted_kdmas[choice_id]
            for choice_id in choice_ids
        ]
        if generate_reasoning:
            reasonings = [
                reasonings[choice_id]
                for choice_id in choice_ids
            ]
            return predicted_kdmas, reasonings
        else:
            return predicted_kdmas

    def __call__(self, sample, **kwargs):
        """Run the full pipeline on one sample and select the best choice.

        See AlgorithmicDecisionMaker.__call__ for the sample/return schema.
        """
        target_kdmas = sample['target_kdmas']
        scenario_text = sample['scenario']
        if sample['state'] is not None:
            scenario_text += f'\n{sample["state"]}'

        # NOTE(review): **kwargs is forwarded to BOTH predict_* methods, so
        # only kwargs they both accept (log_file, temperature) are safe;
        # passing generate_reasoning=False would also break the tuple
        # unpacking below — confirm intended usage with callers.
        predicted_outcomes = self.predict_outcomes(
            scenario_text,
            sample['probe'],
            sample['choices'],
            **kwargs
        )

        predicted_kdmas, generated_reasoning = self.predict_kdma_scores(
            scenario_text,
            sample['probe'],
            sample['choices'],
            predicted_outcomes=predicted_outcomes,
            **kwargs
        )

        def mse(target_kdmas, predicted_kdmas):
            """Mean squared error over the KDMAs present in both dicts."""
            kdmas = set(target_kdmas.keys()) & set(predicted_kdmas.keys())

            if len(kdmas) == 0:
                # NOTE(review): no overlapping KDMAs yields a perfect score
                # of 0, making an unscored choice maximally attractive —
                # confirm this is the intended tie-breaking behavior.
                return 0

            return sum((target_kdmas[kdma] - predicted_kdmas[kdma])**2 for kdma in kdmas) / len(kdmas)

        # Pick the choice whose predicted KDMA scores best match the target.
        choice_idx = 0
        min_mse = float('inf')
        for i in range(len(sample['choices'])):
            mse_ = mse(target_kdmas, predicted_kdmas[i])
            if mse_ < min_mse:
                min_mse = mse_
                choice_idx = i

        return {
            'choice': choice_idx,
            'predicted_kdmas': predicted_kdmas,
            'info': {
                'predicted_outcomes': predicted_outcomes,
                'generated_reasoning': generated_reasoning,
            }
        }
17 changes: 17 additions & 0 deletions align_system/algorithms/lib/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from importlib import reload

def reload_all():
    """Re-import and reload the project's algorithm modules in dependency order.

    Useful during interactive development (e.g. in a notebook) so code edits
    take effect without restarting the kernel.
    """
    from align_system.algorithms.lib import util
    from align_system.algorithms.lib import language_model as language_model_mod
    from align_system.algorithms.lib.chat import dialog_tokenizer as dialog_tokenizer_mod
    from align_system.algorithms.lib.chat import chat_language_model as chat_language_model_mod

    from align_system.algorithms import chat_kdma_predicting_adm as chat_adm_mod
    from align_system.algorithms import llama_2_single_kdma_adm as llama_adm_mod

    # Low-level modules first so higher-level ones pick up the refreshed code.
    ordered_modules = (
        util,
        language_model_mod,
        dialog_tokenizer_mod,
        chat_language_model_mod,
        chat_adm_mod,
        llama_adm_mod,
    )
    for module in ordered_modules:
        reload(module)
29 changes: 29 additions & 0 deletions align_system/algorithms/lib/algorithmic_decision_maker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from abc import ABC, abstractmethod


# ADM sub-classes implement all the algorithm-specific logic.
# Inheriting ABC gives @abstractmethod teeth: without the ABCMeta metaclass
# the decorator is inert, and a subclass that forgot to implement __call__
# would instantiate silently instead of raising TypeError.
class AlgorithmicDecisionMaker(ABC):
    """Abstract interface for algorithmic decision makers (ADMs)."""

    @abstractmethod
    def __call__(self, sample, **kwargs):
        '''
        sample = {
            target_kdmas: { ... }
            scenario,
            state,
            probe,
            choices: [
                choice_text,
                ...
            ]
        }
        returns {
            choice: idx, [required]
            predicted_kdmas: { [optional]
                0: {
                    kdma_name: kdma_value,
                },
                1: { ... }
            }
        }
        '''
        pass
Empty file.
Loading