diff --git a/align_system/algorithms/chat_kdma_predicting_adm.py b/align_system/algorithms/chat_kdma_predicting_adm.py
new file mode 100644
index 00000000..87854707
--- /dev/null
+++ b/align_system/algorithms/chat_kdma_predicting_adm.py
@@ -0,0 +1,211 @@
+import json
+import yaml
+import os
+from typing import Union, List, Dict, Tuple, Optional, TextIO
+from align_system.algorithms.lib.chat.chat_language_model import ChatLanguageModel
+from align_system.algorithms.lib.algorithmic_decision_maker import AlgorithmicDecisionMaker
+
+class ChatKDMAPredictingADM(ChatLanguageModel, AlgorithmicDecisionMaker):
+
+    def predict_outcomes(self,
+                         scenario_text: str,
+                         probe_text: str,
+                         choices: List[str],
+                         log_file: Optional[TextIO] = None,
+                         max_tokens: int = 512,
+                         temperature: float = 0.6,
+                         outcome_template_file: str = 'pred_outcome.txt') -> List[str]:
+        """
+        Predicts outcomes for the given scenario, probe, and choices.
+
+        :param scenario_text: Scenario text.
+        :param probe_text: Probe text.
+        :param choices: Choice texts.
+        :param log_file: Optional log file.
+        :param max_tokens: Maximum number of tokens to generate.
+        :param temperature: Temperature for sampling.
+        :param outcome_template_file: Template file for outcome predictions.
+        :return: List of generated predictions.
+        """
+        return self.generate_from_template(
+            outcome_template_file,
+            [
+                {
+                    'scenario': scenario_text,
+                    'probe': probe_text,
+                    'choice': choice,
+                }
+                for choice in choices
+            ],
+            log_file=log_file,
+            max_tokens=max_tokens,
+            temperature=temperature
+        )
+
+
+    def predict_kdma_scores(self,
+                            scenario_text: str,
+                            probe_text: str,
+                            choice_texts: List[str],
+                            predicted_outcomes: Optional[List[str]] = None,
+                            generate_reasoning: bool = True,
+                            log_file: Optional[TextIO] = None,
+                            max_new_tokens: int = 512,
+                            temperature: float = 0.6,
+                            kdma_template_file: str = 'pred_kdma_RO.txt',
+                            kdma_descriptions_file: str = 'lib/templates/bbn_kdma_descriptions.yml') -> Union[List[Dict[str, float]], Tuple[List[Dict[str, float]], List[Dict[str, str]]]]:
+        """
+        Predicts KDMA scores for each choice text under the given scenario and probe.
+
+        :param scenario_text: Scenario text.
+        :param probe_text: Probe text.
+        :param choice_texts: Choice texts.
+        :param predicted_outcomes: Predicted outcomes.
+        :param generate_reasoning: Flag to generate reasoning.
+        :param log_file: Optional log file.
+        :param max_new_tokens: Maximum number of new tokens to generate.
+        :param temperature: Temperature for sampling.
+        :param kdma_template_file: Template file for KDMA prediction.
+        :param kdma_descriptions_file: Template file for KDMA descriptions.
+        :return: KDMA predictions. If generate_reasoning is True, returns predictions and reasonings.
+ """ + choice_ids = [f'choice_{i}' for i in range(len(choice_texts))] + substitutions = [] + info = [] + + relative_dir = os.path.dirname(__file__) + kdma_descriptions_file_path = os.path.join(relative_dir, kdma_descriptions_file) + + with open(kdma_descriptions_file_path, 'r') as f: + kdma_descriptions = yaml.load(f, Loader=yaml.FullLoader) + + if predicted_outcomes is None: + predicted_outcomes = [None] * len(choice_texts) + + for choice_id, choice, outcome in zip(choice_ids, choice_texts, predicted_outcomes): + for kdma, kdma_info in kdma_descriptions.items(): + substitution = { + 'kdma': kdma_info['name'], + 'kdma_description': kdma_info['description'], + 'scenario': scenario_text, + 'probe': probe_text, + 'choice': choice, + } + + if outcome is not None: + substitution['outcome'] = outcome + + substitutions.append(substitution) + info.append((choice_id, kdma)) + + def parse_kdma_score_response(response: str) -> Dict[str, Union[float, str]]: + """ + Parses KDMA score response. + + :param response: Response to parse. + :return: Dictionary with KDMA score and reasoning if generate_reasoning. + """ + if generate_reasoning: + start_idx = response.find('{') + end_idx = response.rfind('}') + response_json = json.loads(response[start_idx:end_idx+1]) + assert 'score' in response_json, 'score not found in response' + assert 'reasoning' in response_json, 'reasoning not found in response' + else: + # find the first numeric character + char = None + for c in response: + if c.isnumeric(): + char = c + break + assert char is not None, 'Could not find numeric character in response' + response_json = { + 'score': float(response[response.find(char):]) + } + return response_json + + generations = self.generate_from_template( + kdma_template_file, + substitutions, + parse_kdma_score_response, + log_file=log_file, + max_tokens=max_new_tokens, + temperature=temperature, + ) + + predicted_kdmas = {} + reasonings = {} + for (choice_id, kdma), generation in zip(info, generations): + predicted_choice_kdmas = predicted_kdmas.get(choice_id, {}) + predicted_kdmas[choice_id] = predicted_choice_kdmas + + choice_reasonings = reasonings.get(choice_id, {}) + reasonings[choice_id] = choice_reasonings + + predicted_choice_kdmas[kdma] = generation['score'] + + if generate_reasoning: + choice_reasonings[kdma] = generation['reasoning'] + + predicted_kdmas = [ + predicted_kdmas[choice_id] + for choice_id in choice_ids + ] + if generate_reasoning: + reasonings = [ + reasonings[choice_id] + for choice_id in choice_ids + ] + + if generate_reasoning: + return predicted_kdmas, reasonings + else: + return predicted_kdmas + + + def __call__(self, sample, **kwargs): + target_kdmas = sample['target_kdmas'] + scenario_text = sample['scenario'] + if sample['state'] is not None: + scenario_text += f'\n{sample["state"]}' + + predicted_outcomes = self.predict_outcomes( + scenario_text, + sample['probe'], + sample['choices'], + **kwargs + ) + + predicted_kdmas, generated_reasoning = self.predict_kdma_scores( + scenario_text, + sample['probe'], + sample['choices'], + predicted_outcomes=predicted_outcomes, + **kwargs + ) + + def mse(target_kdmas, predicted_kdmas): + kdmas = set(target_kdmas.keys()) & set(predicted_kdmas.keys()) + + if len(kdmas) == 0: + return 0 + + return sum([(target_kdmas[kdma] - predicted_kdmas[kdma])**2 for kdma in kdmas]) / len(kdmas) + + # find index of min mse + choice_idx = 0 + min_mse = float('inf') + for i, choice in enumerate(sample['choices']): + mse_ = mse(target_kdmas, predicted_kdmas[i]) + if mse_ < 
diff --git a/align_system/algorithms/lib/__init__.py b/align_system/algorithms/lib/__init__.py
new file mode 100644
index 00000000..b937c46b
--- /dev/null
+++ b/align_system/algorithms/lib/__init__.py
@@ -0,0 +1,17 @@
+from importlib import reload
+
+def reload_all():
+    # Useful function for developing in an interactive environment without having to restart the kernel
+
+    from align_system.algorithms.lib import util
+    from align_system.algorithms.lib import language_model as lm
+    from align_system.algorithms.lib.chat import dialog_tokenizer as dt
+    from align_system.algorithms.lib.chat import chat_language_model as clm
+
+    from align_system.algorithms import chat_kdma_predicting_adm as kpa
+    from align_system.algorithms import llama_2_single_kdma_adm as ska
+
+
+    # Reload in the correct order
+    for module in [util, lm, dt, clm, kpa, ska]:
+        reload(module)
\ No newline at end of file
diff --git a/align_system/algorithms/lib/algorithmic_decision_maker.py b/align_system/algorithms/lib/algorithmic_decision_maker.py
new file mode 100644
index 00000000..cb1b28f2
--- /dev/null
+++ b/align_system/algorithms/lib/algorithmic_decision_maker.py
@@ -0,0 +1,29 @@
+from abc import abstractmethod
+
+# ADM sub-classes implement all the algorithm-specific logic
+class AlgorithmicDecisionMaker:
+
+    @abstractmethod
+    def __call__(self, sample, **kwargs):
+        '''
+        sample = {
+            target_kdmas: { ... }
+            scenario,
+            state,
+            probe,
+            choices: [
+                choice_text,
+                ...
+            ]
+        }
+        returns {
+            choice: idx, [required]
+            predicted_kdmas: { [optional]
+                0: {
+                    kdma_name: kdma_value,
+                },
+                1: { ... }
+            }
+        }
+        '''
+        pass
\ No newline at end of file
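A minimal sketch of the contract this base class defines follows; the decision logic here is a placeholder, not an algorithm from this PR.

```python
from align_system.algorithms.lib.algorithmic_decision_maker import AlgorithmicDecisionMaker

class FirstChoiceADM(AlgorithmicDecisionMaker):
    # Placeholder ADM: always selects the first choice.
    def __call__(self, sample, **kwargs):
        return {
            'choice': 0,  # required: index into sample['choices']
            # 'predicted_kdmas' is optional and omitted here
        }
```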
diff --git a/align_system/algorithms/lib/chat/__init__.py b/align_system/algorithms/lib/chat/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/align_system/algorithms/lib/chat/chat_language_model.py b/align_system/algorithms/lib/chat/chat_language_model.py
new file mode 100644
index 00000000..c9754574
--- /dev/null
+++ b/align_system/algorithms/lib/chat/chat_language_model.py
@@ -0,0 +1,161 @@
+from typing import List, Dict, Optional, Callable, Union, TextIO
+
+from align_system.algorithms.lib.language_model import LanguageModel
+from align_system.algorithms.lib.util import read_template, format_template, dialog_from_string, dialog_to_string
+from jinja2.exceptions import TemplateError
+
+class ChatLanguageModel(LanguageModel):
+
+    def generate_responses(self,
+                           dialogs: List[Dict[str, str]],
+                           log_file: Optional[TextIO] = None,
+                           max_new_tokens: int = 512,
+                           temperature: float = 0.6) -> List[str]:
+        """
+        Generates responses for given dialogs.
+
+        :param dialogs: List of dialogs.
+        :param log_file: Optional file to log the process.
+        :param max_new_tokens: Maximum number of new tokens to generate.
+        :param temperature: Temperature for sampling.
+        :return: Generated responses.
+        """
+        # If logging is requested, write the dialogs into the log file
+        if log_file is not None:
+            log_file.write('**Dialogs:**\n')
+            for i, dialog in enumerate(dialogs):
+                log_file.write(f'*Dialog {i}:*\n{dialog_to_string(dialog)}\n')
+            log_file.flush()
+
+        # Prepare lists for the last user dialogs and prefixes.
+        # Prefix refers to the assistant's response in the last turn of a dialog.
+        user_last_dialogs = []
+        prefixes = []
+        for dialog in dialogs:
+            prefix = ''
+            if dialog[-1]['role'] == 'assistant':
+                prefix = dialog[-1]['content']
+                dialog = dialog[:-1]
+            user_last_dialogs.append(dialog)
+            prefixes.append(prefix)
+
+        # Tokenization step (falls back to folding the system message into the
+        # first user turn if the chat template does not support a system role)
+        try:
+            prompt_token_lists = [
+                [self.tokenizer.apply_chat_template(dialog, tokenize=True)]
+                for dialog in user_last_dialogs
+            ]
+        except TemplateError as e:
+            systemless_dialogs = []
+            for dialog in user_last_dialogs:
+                if dialog[0]['role'] == 'system':
+                    dialog[0]['role'] = 'user'
+                    if dialog[1]['role'] == 'user':
+                        dialog[0]['content'] = f"{dialog[0]['content']}\n\n{dialog[1]['content']}"
+                        del dialog[1]
+                systemless_dialogs.append(dialog)
+
+            prompt_token_lists = [
+                [self.tokenizer.apply_chat_template(dialog, tokenize=True)]
+                for dialog in systemless_dialogs
+            ]
+
+
+
+        # Add the prefix tokens to the prompt tokens
+        for prompt_tokens, prefix in zip(prompt_token_lists, prefixes):
+            if len(prefix) > 0:
+                prefix_tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
+                prompt_tokens[0] += prefix_tokens
+
+        # Generate responses using tokens
+        prompt_token_lists = [x[0] for x in prompt_token_lists]
+        responses = self.generate_from_tokens(prompt_token_lists, max_new_tokens=max_new_tokens, temperature=temperature)
+        prefixed_responses = [
+            f'{prefix}{response}'
+            for prefix, response in zip(prefixes, responses)
+        ]
+
+        # If logging is requested, write the generated responses into the log file
+        if log_file is not None:
+            log_file.write('**Generated Responses:**\n')
+            for i, response in enumerate(prefixed_responses):
+                log_file.write(f'*Response {i}:*\n{response}\n')
+            log_file.flush()
+
+        return prefixed_responses
+
+    def generate_from_template(
+        self,
+        template_files: Union[List[str], str],
+        substitution_dicts: Union[List[Dict[str, str]], Dict[str, str]],
+        parse_generation_fn: Optional[Callable[[str], str]] = None,
+        batch_size: int = 5,
+        log_file: Optional[TextIO] = None,
+        max_tokens: int = 512,
+        temperature: float = 0.6,
+        max_retry: int = 10,
+        verbose: bool = False) -> List[str]:
+        """
+        Generates responses for given templates with substitutions.
+
+        :param template_files: Template files to use for generation.
+        :param substitution_dicts: Substitution dictionaries for the templates.
+        :param parse_generation_fn: Function to parse the generated responses.
+        :param batch_size: Batch size for generating responses.
+        :param log_file: Optional file to log the process.
+        :param max_tokens: Maximum number of tokens to generate.
+        :param temperature: Temperature for sampling.
+        :param max_retry: Maximum number of attempts to generate a valid output.
+        :param verbose: If True, verbose logging is enabled.
+        :return: Generated responses.
+        """
+        if isinstance(substitution_dicts, dict):
+            substitution_dicts = [substitution_dicts]
+
+        if isinstance(template_files, str):
+            template_files = [template_files] * len(substitution_dicts)
+
+        assert len(template_files) == len(substitution_dicts), 'Number of templates and substitutions do not match'
+
+        # Create a dialog for each template/substitution pair
+        dialogs = {
+            i: dialog_from_string(format_template(read_template(template_file), **substitutions))
+            for i, (template_file, substitutions) in enumerate(zip(template_files, substitution_dicts))
+        }
+
+        outputs = {}
+        input_counts = {}
+        while len(dialogs) > 0:
+            sample_ids = list(dialogs.keys())[:batch_size]
+            batch = [dialogs[i] for i in sample_ids]
+            generations = self.generate_responses(batch, log_file=log_file, max_new_tokens=max_tokens, temperature=temperature)
+
+            # Process the generated responses
+            for sample_id, generation in zip(sample_ids, generations):
+                input_counts[sample_id] = input_counts.get(sample_id, 0) + 1
+
+                # If the maximum number of retries is exceeded, raise an error
+                if input_counts[sample_id] > max_retry:
+                    raise Exception(f'Could not generate valid output for sample [{sample_id}]')
+
+                # If there's a specific function to parse the generations, try to apply it
+                if parse_generation_fn is not None:
+                    try:
+                        outputs[sample_id] = parse_generation_fn(generation)
+                        del dialogs[sample_id]
+                    except Exception as e:
+                        if verbose:
+                            print(f'Error: could not parse output for sample [{sample_id}]')
+                            print(e)
+                        pass
+                else:
+                    outputs[sample_id] = generation
+                    del dialogs[sample_id]
+
+        assert len(outputs) == len(substitution_dicts), 'Unexpected state: number of outputs and substitutions do not match'
+
+        return [
+            outputs[i]
+            for i in range(len(outputs))
+        ]
\ No newline at end of file
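The sketch below illustrates the template format that `generate_from_template` expects: `{{...}}` placeholders filled by `format_template`, `=== role` sections split into a chat dialog by `dialog_from_string`, and a trailing assistant section used as a forced response prefix. The template body and model name are invented for illustration; the real `pred_outcome.txt` is not shown in this diff.

```python
from align_system.algorithms.lib.chat.chat_language_model import ChatLanguageModel

chat_lm = ChatLanguageModel.load_model('meta-llama/Llama-2-13b-chat-hf')  # assumed model name

# Invented template text, roughly the shape a file under lib/templates/ might take.
example_template = """\
=== system
You are an assistant supporting triage decisions.
=== user
Scenario: {{scenario}}
Probe: {{probe}}
Choice: {{choice}}
Predict the outcome of taking this choice.
=== assistant
The predicted outcome is"""

# generate_from_template fills the placeholders, converts the "=== role" sections
# into a chat dialog, and batches the generation calls.
outcomes = chat_lm.generate_from_template(
    'pred_outcome.txt',  # resolved relative to lib/templates/
    [{'scenario': 'Two casualties, one tourniquet.',
      'probe': 'Who do you treat first?',
      'choice': 'Treat casualty A'}],
)
```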
+ """ + if isinstance(substitution_dicts, dict): + substitution_dicts = [substitution_dicts] + + if isinstance(template_files, str): + template_files = [template_files] * len(substitution_dicts) + + assert len(template_files) == len(substitution_dicts), 'Number of templates and substitutions do not match' + + # Create a dialogue for each template/substitution pair + dialogs = { + i: dialog_from_string(format_template(read_template(template_file), **substitutions)) + for i, (template_file, substitutions) in enumerate(zip(template_files, substitution_dicts)) + } + + outputs = {} + input_counts = {} + while len(dialogs) > 0: + sample_ids = list(dialogs.keys())[:batch_size] + batch = [dialogs[i] for i in sample_ids] + generations = self.generate_responses(batch, log_file=log_file, max_new_tokens=max_tokens, temperature=temperature) + + # Process the generated responses + for sample_id, generation in zip(sample_ids, generations): + input_counts[sample_id] = input_counts.get(sample_id, 0) + 1 + + # If the maximum number of try-outs is exceeded, throw an error + if input_counts[sample_id] > max_retry: + raise Exception(f'Could not generate valid output for sample [{sample_id}]') + + # If there's a specific function to parse the generations, try to apply it + if parse_generation_fn is not None: + try: + outputs[sample_id] = parse_generation_fn(generation) + del dialogs[sample_id] + except Exception as e: + if verbose: + print(f'Error: could not parse output for sample [{sample_id}]') + print(e) + pass + else: + outputs[sample_id] = generation + del dialogs[sample_id] + + assert len(outputs) == len(substitution_dicts), 'Unexpected state: number of outputs and substitutions do not match' + + return [ + outputs[i] + for i in range(len(outputs)) + ] \ No newline at end of file diff --git a/align_system/algorithms/lib/language_model.py b/align_system/algorithms/lib/language_model.py new file mode 100644 index 00000000..39e6bb1d --- /dev/null +++ b/align_system/algorithms/lib/language_model.py @@ -0,0 +1,167 @@ +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer +from typing import List, Union, Optional, TextIO + +class LanguageModel: + """ + Class to define the Language Model. + """ + + @classmethod + def load_model(cls, + hf_model_name: str, + precision: torch.dtype = torch.float32, + device: str = 'cuda') -> 'LanguageModel': + """ + Load the language model. + + :param hf_model_name: Name of the model in Huggingface. + :param precision: Precision of the model's weights. + :param device: Device to run the model on. + :return: Initialized LanguageModel object. + """ + # Load the model from Huggingface + model = AutoModelForCausalLM.from_pretrained(hf_model_name, torch_dtype=precision) + tokenizer = AutoTokenizer.from_pretrained(hf_model_name) + model = model.to(device) + return cls(model, tokenizer) + + def __init__(self, + model: AutoModelForCausalLM, + tokenizer: AutoTokenizer) -> None: + """ + Initializes the language model. + + :param model: Pretrained Huggingface model. + :param tokenizer: Tokenizer from Huggingface. + """ + self.model = model + self.tokenizer = tokenizer + + def generate_from_tokens(self, + prompt_token_lists: List[List[int]], + log_file: Union[None, str, object] = None, + max_new_tokens: int = 512, + temperature: float = 0.6, + padding: str='left') -> List[str]: + """ + Generates text from the given list of tokens. + + :param prompt_token_lists: List of lists of tokens to generate the text. + :param log_file: Path to the log file. 
+        :param max_new_tokens: Maximum number of new tokens to be generated.
+        :param temperature: Temperature for probability adjustment.
+        :param padding: Padding direction, either 'left' or 'right'.
+        :return: Generated text.
+        """
+        # Move to the model's device and unpack
+        prompt_token_lists = [
+            torch.tensor(prompt_tokens).to(self.model.device).unsqueeze(0)
+            for prompt_tokens in prompt_token_lists
+        ]
+
+        max_length = max([prompt_tokens.size(1) for prompt_tokens in prompt_token_lists])
+
+        pad_token_id = self.tokenizer.pad_token_id
+
+        # Padding function for the desired direction
+        assert padding == 'left' or padding == 'right', f"Padding must be either 'left' or 'right', got {padding}"
+        pad_fn = lambda prompt_token_size: (max_length - prompt_token_size, 0) if padding == 'left' else (0, max_length - prompt_token_size)
+
+        # Pad each sequence to the max length
+        padded_prompt_token_lists = [
+            torch.nn.functional.pad(prompt_tokens, pad_fn(prompt_tokens.size(1)), value=pad_token_id)
+            for prompt_tokens in prompt_token_lists
+        ]
+
+        attention_masks = [
+            torch.nn.functional.pad(torch.ones_like(prompt_tokens), pad_fn(prompt_tokens.size(1)), value=0)
+            for prompt_tokens in prompt_token_lists
+        ]
+
+        position_ids = [
+            torch.nn.functional.pad(torch.arange(prompt_tokens.size(1)).unsqueeze(0), pad_fn(prompt_tokens.size(1)), value=0)
+            for prompt_tokens in prompt_token_lists
+        ]
+
+
+        # Stack the padded sequences
+        stacked_prompt_tokens = torch.cat(padded_prompt_token_lists, dim=0)
+        stacked_attention_masks = torch.cat(attention_masks, dim=0)
+        stacked_position_ids = torch.cat(position_ids, dim=0)
+
+        if log_file is not None:
+            prompt_texts = [
+                self.tokenizer.decode(prompt_tokens.squeeze(0), skip_special_tokens=True)
+                for prompt_tokens in padded_prompt_token_lists
+            ]
+            log_file.write('**Prompt texts:**\n')
+            for i, prompt_text in enumerate(prompt_texts):
+                log_file.write(f'Prompt {i}:\n{prompt_text}\n')
+
+            log_file.flush()
+
+
+
+        # Generate outputs for all dialogs in a batch
+        # TODO ensure the batch size is not too large for the GPU
+        outputs = self.model.generate(
+            stacked_prompt_tokens,
+            attention_mask=stacked_attention_masks,
+            # position_ids=stacked_position_ids,  # TODO figure out why including the position ids breaks the model
+            return_dict_in_generate=True,
+            output_scores=True,
+            max_new_tokens=max_new_tokens,
+            temperature=temperature
+        )
+
+        # Decode the generated outputs
+        decoded_outputs = [
+            self.tokenizer.decode(output_tokens[len(prompt_tokens.squeeze(0)):], skip_special_tokens=True)
+            for output_tokens, prompt_tokens in zip(outputs.sequences, padded_prompt_token_lists)
+        ]
+
+        if log_file is not None:
+            log_file.write('**Generated texts:**\n')
+            for i, decoded_output in enumerate(decoded_outputs):
+                log_file.write(f'*Generation {i}:*\n{decoded_output}\n')
+            log_file.flush()
+
+        return decoded_outputs
+
+    def generate(self,
+                 prompt_texts: List[str],
+                 log_file: Optional[TextIO] = None,
+                 max_new_tokens: int = 512,
+                 temperature: float = 0.6) -> List[str]:
+        """
+        Generates text from the given list of inputs.
+
+        :param prompt_texts: List of prompts to generate from.
+        :param log_file: Optional file object to write to
+        :param max_new_tokens: Maximum number of new tokens to be generated.
+        :param temperature: Temperature for probability adjustment.
+ """ + # Convert the text to tokens and generate the text + prompt_token_lists = [self.tokenizer.encode(prompt_text) for prompt_text in prompt_texts] + return self.generate_from_tokens(prompt_token_lists, log_file, max_new_tokens, temperature) + + def generate_with_prefixes(self, + prompt_texts: List[str], + prefixes: List[str], + log_file: Optional[TextIO] = None, + max_new_tokens: int = 512, + temperature: float = 0.6) -> List[str]: + """ + Generates text from the given list of inputs with prefixes. + + :param prompt_texts: List of prompts to generate from. + :param prefixes: List of prefixes to prepend to the generated text. + :param log_file: Optional file object to write to + :param max_new_tokens: Maximum number of new tokens to be generated. + :param temperature: Temperature for probability adjustment. + """ + # Combine the inputs with prefixes and generate the text + combined_texts = [f'{prompt}{prefix}' for prompt, prefix in zip(prompt_texts, prefixes)] + generations = self.generate(combined_texts, log_file, max_new_tokens, temperature) + return [f'{prefix}{generation}' for prefix, generation in zip(prefixes, generations)] \ No newline at end of file diff --git a/align_system/algorithms/lib/util.py b/align_system/algorithms/lib/util.py new file mode 100644 index 00000000..fad60e49 --- /dev/null +++ b/align_system/algorithms/lib/util.py @@ -0,0 +1,91 @@ +import re +import os +from typing import List, Dict + + +def dialog_from_string(dialog_string: str) -> List[Dict[str, str]]: + """ + Transforms the dialog in string format to a list of dictionary format. + + :param dialog_string: Dialog in string format. + :return: Dialog in the list of dictionary format. + """ + # Dictionary to map string markers to role names + dialog_markers = { + '=== system': 'system', + '=== user': 'user', + '=== assistant': 'assistant', + } + dialog = [] + lines = dialog_string.split('\n') + + current_role = '' + current_content = '' + for line in lines: + if line.strip() in dialog_markers: # If a line indicates a role change + if current_role and current_content: # Save the previous role's dialog + dialog.append({ + 'role': current_role, + 'content': current_content.strip() + }) + current_role = dialog_markers[line.strip()] # Set the new role + current_content = '' + else: # Continue appending content if the role hasn't changed + current_content += f'{line}\n' + # Append the last piece of dialog + if current_role and current_content: + dialog.append({ + 'role': current_role, + 'content': current_content.strip() + }) + return dialog + + +def dialog_to_string(dialog: List[Dict[str, str]]) -> str: + """ + Transforms the dialog in list of dictionary to string format. + + :param dialog: Dialog in list of dictionary format. + :return: Dialog in string format. + """ + output = '' + for dialog_piece in dialog: + role = dialog_piece['role'] + content = dialog_piece['content'] + output += f"=== {role}\n" + output += f"{content}\n" + + return output + + +def format_template(template: str, **substitutions: str) -> str: + """ + Replaces placeholders in a template with provided substitutions. + + :param template: The template with placeholders indicated as {{placeholder}}. + :param substitutions: The substitutions to replace in the template. + :return: The template with all placeholders substituted. 
+ """ + for key, value in substitutions.items(): + key = '{{%s}}' % key + if not key in template: + raise Exception(f'Could not find key {key} in template') + template = template.replace(key, value) + + # ensure there are no strings surrounded by {{ }} + matches = re.findall(r'{{.*?}}', template) + # if there are any matches, raise an exception + if len(matches) > 0: + raise Exception(f'Unsubstituited key(s) in template: {matches}') + + return template + + +def read_template(template_file_name: str, template_dir='templates') -> str: + current_directory = os.path.dirname(os.path.abspath(__file__)) + full_path = os.path.join(current_directory, template_dir, template_file_name) + + with open(full_path, 'r') as template_file: + template = template_file.read() + + return template \ No newline at end of file