diff --git a/evals/elsuite/track_the_stat/README.md b/evals/elsuite/track_the_stat/README.md
new file mode 100644
index 0000000000..20c1580b2f
--- /dev/null
+++ b/evals/elsuite/track_the_stat/README.md
@@ -0,0 +1,134 @@
+# Track the Stat
+
+This eval measures how well models can implicitly keep track of task state, by
+asking models to compute the rolling median or the rolling mode over a sequence
+of integers.
+
+## Usage
+
+Run with:
+
+```bash
+oaieval <solver> track_the_stat
+```
+
+We have found that `generation/direct/gpt-4-0125-preview` works well on this
+eval. For more examples of tested solvers, see
+[`./scripts/run_experiments.sh`](./scripts/run_experiments.sh).
+
+## Evaluation Process
+
+The evaluation process is as follows for a given sample from our dataset:
+
+1. The `TASK_DESCRIPTION` prompt is shown to the solver.
+2. The sample contains an integer to use as a seed for a random number
+   generator.
+3. The random number generator generates 300 random integers between 0 and 100,
+   with replacement.
+4. The integers are shown one by one to the solver.
+5. At each turn (i.e., after each integer is shown), the solver needs to respond
+   with the current rolling median or the current rolling mode of the integers
+   seen so far.
+6. The solver's response is parsed and compared to the correct rolling median or
+   rolling mode.
+7. If the solver's response is incorrect, or a violation is raised (i.e., the
+   solver answered in an invalid format), the evaluation stops and we record how
+   many turns the solver lasted. If the solver's response is correct, we move on
+   to the next integer.
+
+## Prompts
+
+We refer readers to the [`./prompts/`](./prompts/) folder for the
+`TASK_DESCRIPTION` used in the eval.
+
+## Metrics
+
+Below are the metrics returned by the eval:
+
+| **Metric**        | **Notes**                                                                                                                                   |
+|-------------------|---------------------------------------------------------------------------------------------------------------------------------------------|
+| avg_max_length    | The maximum sequence length the model can handle before failing, averaged across the samples. Higher is better. Best possible is 300.        |
+| stddev_max_length | The standard deviation on the above.                                                                                                          |
+| median_max_length | The median of the maximum sequence length the model can handle before failing, across the samples. Higher is better. Best possible is 300.   |
+| max_max_length    | The maximum sequence length the model handled before failing, across all samples.                                                            |
+| min_max_length    | The minimum sequence length the model handled before failing, across all samples.                                                            |
+| violation_rate    | How often the model responds in an invalid format, i.e., not using the `[<task>: <answer>]` format.                                          |
+
+## Variants
+
+The eval has two variants: median and mode. In the median variant, the solver
+needs to track the rolling median. In the mode variant, the solver needs to
+track the rolling mode.
+
+```bash
+oaieval <solver> track_the_stat.<variant>
+```
+
+## Custom Solvers
+
+We implement 3 custom solvers for this eval in [./solvers.py](./solvers.py):
+
+1. `ExplicitStateSolver`: A nested solver that injects an explicit
+   representation of the task state after each number is seen. For example, for
+   the median task we inject the sorted list of numbers seen so far. For the
+   mode task, we inject a dictionary that maps each number seen so far to its
+   count. We view this solver as a baseline for the task, providing the
+   performance of the models on _explicit_ state tracking, rather than the
+   default _implicit_ state tracking.
+2. `RandomBaselineSolver`: A solver that randomly chooses a number from the
+   numbers seen so far as the rolling median or mode. For even-length lists in
+   the median variant, it chooses two random numbers and returns their
+   arithmetic mean. We view this baseline as equivalent to randomly guessing.
+3. `TrackTheStatHuman`: A helper solver class that wraps the `HumanCliSolver`
+   class such that users do not have to wrap their answer in the
+   `[median: <answer>]` or `[mode: <answer>]` format and can instead just
+   directly type the number.
+
+## Token Usage Estimates
+
+Below are token usage estimates for a given run (one run = all samples) of the
+eval.
+
+For the mode task:
+
+| Model (state tracking)        | Input     | Output    | Total      |
+| ----------------------------- | --------- | --------- | ---------- |
+| gpt-3.5-turbo-0125 (implicit) | 670,000   | 10,000    | 680,000    |
+| gpt-3.5-turbo-0125 (explicit) | 2,710,000 | 30,000    | 2,740,000  |
+| gpt-4-base (implicit)         | 9,030,000 | 2,110,000 | 11,150,000 |
+| gpt-4-base (explicit)         | 3,720,000 | 960,000   | 4,680,000  |
+| gpt-4-0125-preview (implicit) | 3,050,000 | 30,000    | 3,080,000  |
+| gpt-4-0125-preview (explicit) | 8,580,000 | 50,000    | 8,630,000  |
+
+For the median task:
+
+| Model (state tracking)        | Input     | Output  | Total     |
+| ----------------------------- | --------- | ------- | --------- |
+| gpt-3.5-turbo-0125 (implicit) | 430,000   | 10,000  | 440,000   |
+| gpt-3.5-turbo-0125 (explicit) | 880,000   | 10,000  | 890,000   |
+| gpt-4-base (implicit)         | 2,900,000 | 760,000 | 3,660,000 |
+| gpt-4-base (explicit)         | 3,250,000 | 810,000 | 4,060,000 |
+| gpt-4-0125-preview (implicit) | 690,000   | 10,000  | 700,000   |
+| gpt-4-0125-preview (explicit) | 1,430,000 | 20,000  | 1,450,000 |
+
+## Future modifications
+
+- Identify new variants of the task beyond median or mode, where the explicit
+  state is either impossible to represent or not useful for the task. This would
+  allow us to more comfortably measure implicit state tracking, even for CoT
+  solvers.
+- Identify more realistic and/or complex tasks.
+- Introduce distractors.
+
+## Version History
+
+- v0: Initial version released
+
+## Contribution Statement
+
+Eval design, implementation, and results evaluation were primarily conducted by
+Giulio Starace, under the guidance of (alphabetically by last-name) Steven
+Adler, Andrei Alexandru, James Aung, and Chan Jun Shern, who provided research
+input, report revisions, and project management support.
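[Editor's note] Before the implementation below, here is a minimal, self-contained Python sketch of what the README's turn-by-turn check amounts to: parse the bracketed answer and compare it to the rolling statistic, using the rounding and tie-breaking rules described above. The helper names (`rolling_mode`, `parse_answer`) and the sample numbers are illustrative and are not the eval's own code.

```python
import re
import statistics
from collections import Counter
from typing import Optional


def rolling_mode(numbers: list[int]) -> int:
    """Most frequent value seen so far; ties broken by the largest number, per the task rules."""
    counts = Counter(numbers)
    best = max(counts.values())
    return max(n for n, c in counts.items() if c == best)


def parse_answer(response: str, task: str) -> Optional[float]:
    """Pull the number out of a '[median: 1.5]' / '[mode: 3]' style response."""
    match = re.search(rf"\[{task}: (\d+(?:\.\d+)?)\]", response.strip().lower())
    return float(match.group(1)) if match else None


# One example turn: the integers shown so far and a hypothetical solver reply.
seen = [1, 2, 1, 3]
reply = "[median: 1.5]"
answer = parse_answer(reply, "median")

# A None answer counts as a violation; otherwise compare after rounding to 1 decimal.
correct = answer is not None and round(answer, 1) == round(statistics.median(seen), 1)
print(correct, rolling_mode(seen))  # True 1
```

The eval itself repeats this check after each of the 300 integers and stops at the first incorrect or malformed answer, as implemented in `eval.py` below.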
diff --git a/evals/elsuite/track_the_stat/eval.py b/evals/elsuite/track_the_stat/eval.py new file mode 100644 index 0000000000..d1ca65d719 --- /dev/null +++ b/evals/elsuite/track_the_stat/eval.py @@ -0,0 +1,96 @@ +import logging +import random +from typing import Any, Optional + +import numpy as np + +from evals.elsuite.track_the_stat import prompts, utils +from evals.eval import SolverEval +from evals.record import RecorderBase, record_metrics +from evals.solvers.solver import Solver +from evals.task_state import Message, TaskState + +logging.getLogger("httpx").setLevel(logging.WARNING) +logger = logging.getLogger(__name__) + + +class TrackTheStat(SolverEval): + def __init__(self, task: str, n_samples: Optional[int] = 250, *args, **kwargs): + super().__init__(*args, **kwargs) + assert task in [ + "median", + "mode", + ], f"task must be either 'median' or 'mode', but got {task}" + self.task = task + # warn, color in yellow + logger.warning( + utils.yellow_string( + "By nature of what is being evaluated, this eval assumes that the " + "solver cannot make use of external scratchpads or similar solutions " + "to explicitly write down the task state at every step. Using solvers " + "that allow for this functionality will likely produce invalid results." + ) + ) + self.task_desc = prompts.TASK_DESCRIPTION.format( + task=task, + task_further_details=prompts.task_to_further_details[task], + task_example=prompts.task_to_example[task], + ) + self.task_fn = utils.task_to_fn[task] + self.n_samples = n_samples + self.rng = random.Random(self.seed) + + def eval_sample(self, solver: Solver, sample: Any, rng: random.Random) -> None: + capped_inf_list = np.random.default_rng(sample["seed"]).integers(0, 100, size=300) + metrics = self._eval_sample(solver, capped_inf_list) + + record_metrics(**metrics) + + def _eval_sample(self, solver: Solver, capped_inf_list: list[int]) -> dict: + violation = False + task_state = TaskState(task_description=self.task_desc, messages=[]) + for i, num in enumerate(capped_inf_list): + curr_list = capped_inf_list[: i + 1] + task_state.messages.append(Message(role="user", content=str(num))) + task_state.current_state = utils.compute_state(curr_list, self.task) + solver_output = solver(task_state).output + solver_response = utils.parse_solver_output(solver_output, self.task) + if solver_response is None: + violation = True + break + if round(solver_response, 1) != round(self.task_fn(curr_list), 1): + break + task_state.messages.append(Message(role="assistant", content=solver_output)) + + return { + "max_length": len(curr_list) - 1, + "violation": violation, + } + + def run(self, recorder: RecorderBase): + samples = self._get_samples() + self.eval_all_samples(recorder, samples) + logged_metrics: list[dict] = recorder.get_metrics() + + agg_metrics = self._compute_agg_metrics(logged_metrics) + return agg_metrics + + def _compute_agg_metrics(self, logged_metrics: list[dict]) -> dict: + max_lengths = np.array([x["max_length"] for x in logged_metrics]) + + agg_metrics = { + "avg_max_length": np.mean(max_lengths), + "stddev_max_length": np.std(max_lengths), + "median_max_length": np.median(max_lengths), + "max_max_length": np.max(max_lengths), + "min_max_length": np.min(max_lengths), + "violation_rate": np.mean([x["violation"] for x in logged_metrics]), + } + # necessary for serialization, json doesn't like np floats + agg_metrics = {k: float(v) for k, v in agg_metrics.items()} + return agg_metrics + + def _get_samples(self) -> list[dict]: + samples = [{"seed": num} for num in 
range(self.n_samples)] + samples = self.rng.sample(samples, self.n_samples) # shuffling + return samples diff --git a/evals/elsuite/track_the_stat/prompts/__init__.py b/evals/elsuite/track_the_stat/prompts/__init__.py new file mode 100644 index 0000000000..31d1b2de39 --- /dev/null +++ b/evals/elsuite/track_the_stat/prompts/__init__.py @@ -0,0 +1,27 @@ +from . import mode +from . import median + + +TASK_DESCRIPTION = """\ +In the following messages, you will be given a sequence of numbers.\ + At each turn, you will be shown a number as input, and you should respond with the\ + {task} of all the input numbers shown to you so far. + +{task_further_details} + +Here is an example of what this may look like. +{task_example} + +Format your response as [{task}: ] (square brackets included), as shown in\ +the transcript above. The task will begin now. +""" + +task_to_example = { + "median": median.MEDIAN_EXAMPLE, + "mode": mode.MODE_EXAMPLE, +} + +task_to_further_details = { + "median": median.MEDIAN_FURTHER_DETAILS, + "mode": mode.MODE_FURTHER_DETAILS, +} diff --git a/evals/elsuite/track_the_stat/prompts/median.py b/evals/elsuite/track_the_stat/prompts/median.py new file mode 100644 index 0000000000..aae3c0ecc8 --- /dev/null +++ b/evals/elsuite/track_the_stat/prompts/median.py @@ -0,0 +1,33 @@ +MEDIAN_EXAMPLE = """\ +```example +input: 1 +ideal_response: [median: 1]\ + # your response; 1 is the only number shown so far +--- +input: 2 +ideal_response: [median: 1.5]\ + # even number of numbers, so median = mean(1,2) = 1.5 +--- +input: 1 +ideal_response: [median: 1]\ + # 1 is now the middle number when sorting the numbers +--- +input: 3 +ideal_response: [median: 1.5]\ + # middle numbers are now 1 and 2, so once again median = mean(1,2) = 1.5 +--- +input: 3 +ideal_response: [median: 2]\ +# the sorted list is [1 1 2 3 3]; odd length, so median is the middle number, 2 +--- +input: 0 +ideal_response: [median: 1.5]\ +# the sorted list is [0 1 1 2 3 3]; even length, so median is mean(1,2) = 1.5 +```\ +""" + + +MEDIAN_FURTHER_DETAILS = """\ +NOTE: In case of lists containing an even number of elements, you should respond with the\ + arithmetic mean of the middle two numbers of the sorted list.\ +""" diff --git a/evals/elsuite/track_the_stat/prompts/mode.py b/evals/elsuite/track_the_stat/prompts/mode.py new file mode 100644 index 0000000000..5756e7e55c --- /dev/null +++ b/evals/elsuite/track_the_stat/prompts/mode.py @@ -0,0 +1,29 @@ +MODE_EXAMPLE = """\ +```example +input: 1 +ideal_response: [mode: 1]\ + # your response; 1 is the only number shown so far +--- +input: 2 +ideal_response: [mode: 2]\ + # 1 and 2 are tied modes (both appeared once), 2 > 1 +--- +input: 1 +ideal_response: [mode: 1]\ + # 1 now has appeared more than any other number +--- +input: 3 +ideal_response: [mode: 1] +--- +input: 3 +ideal_response: [mode: 3]\ + # 3 is tied with 1 in terms of appearances, 3 > 1 +--- +input: 0 +ideal_response: [mode: 3] +```\ +""" + +MODE_FURTHER_DETAILS = """\ +NOTE: In case of ties, you should respond with the largest number that is part of the tie.\ +""" diff --git a/evals/elsuite/track_the_stat/scripts/make_plots.py b/evals/elsuite/track_the_stat/scripts/make_plots.py new file mode 100644 index 0000000000..b40e4a3586 --- /dev/null +++ b/evals/elsuite/track_the_stat/scripts/make_plots.py @@ -0,0 +1,296 @@ +from pathlib import Path +import argparse +import json + +from tqdm.auto import tqdm +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns + +from evals.utils import log_utils + + 
+def zero_if_none(input_num): + if input_num is None: + return 0 + else: + return input_num + + +MODELS = [ + "gpt-4-0125-preview", + "gpt-4-base", + "gpt-3.5-turbo-0125", + "gemini-pro-1.0", + "mixtral-8x7b-instruct", + "llama-2-70b-chat", + "random_baseline", + "human_baseline", +] +# separate list for OAI models for token counting, not supported in others. +OAI_MODELS = [ + "gpt-4-0125-preview", + "gpt-3.5-turbo-0125", + "gpt-4-base", +] + +STAT_TO_LABEL = { + "avg_max_length": "Average maximum sequence length achieved [no. of turns]", + "violation_rate": "Violation rate", +} + + +def make_results_dict(log_dir: Path) -> dict: + results_dict = prepare_results_dict() + results_dict = fill_results_dict(results_dict, log_dir) + return results_dict + + +def get_model(spec): + # this is hilariously ugly but it works for now (sorry) + if "gpt-4-turbo-preview" in spec["completion_fns"][0]: + return "gpt-4-0125-preview" + elif "gpt-3.5-turbo" in spec["completion_fns"][0]: + return "gpt-3.5-turbo-0125" + elif "gpt-4-base" in spec["completion_fns"][0]: + return "gpt-4-base" + elif "gemini-pro" in spec["completion_fns"][0]: + return "gemini-pro-1.0" + elif "mixtral-8x7b-instruct" in spec["completion_fns"][0]: + return "mixtral-8x7b-instruct" + elif "llama-2-70b-chat" in spec["completion_fns"][0]: + return "llama-2-70b-chat" + elif "random_baseline" in spec["completion_fns"][0]: + return "random_baseline" + elif "human" in spec["completion_fns"][0]: + return "human_baseline" + + +def get_state_tracking(spec): + if "explicit" in spec["completion_fns"][0]: + return "explicit" + else: + return "implicit" + + +def fill_results_dict(results_dict, log_dir): + print("Parsing logs...") + final_results = log_utils.get_final_results_from_dir(log_dir) + specs = log_utils.get_specs_from_dir(log_dir) + files = list(final_results.keys()) + + for file in tqdm(files): + final_result = final_results[file] + spec = specs[file] + task = spec["split"] + model = get_model(spec) + state_tracking = get_state_tracking(spec) + for stat in results_dict: + results_dict[stat][task][model][state_tracking]["raw"].append( + final_result[stat] + ) + # compute means/std_errs + for file in tqdm(files): + spec = specs[file] + task = spec["split"] + model = get_model(spec) + state_tracking = get_state_tracking(spec) + for stat in results_dict: + data_points = results_dict[stat][task][model][state_tracking]["raw"] + results_dict[stat][task][model][state_tracking]["mean"] = np.mean( + data_points + ) + results_dict[stat][task][model][state_tracking]["std_err"] = np.std( + data_points + ) / np.sqrt(len(data_points) if len(data_points) > 1 else 1) + return results_dict + + +def prepare_results_dict(): + results_dict = { + stat: { + task: { + model: { + state_tracking: {"raw": []} + for state_tracking in ["implicit", "explicit"] + } + for model in MODELS + } + for task in ["mode", "median"] + } + for stat in ["avg_max_length", "violation_rate"] + } + return results_dict + + +def make_bar_plot(results_dict: dict, task: str, stat: str, save_path: Path): + sns.set_context("paper") + sns.set_style("whitegrid") + + data = results_dict[stat][task] + + # the random baseline and human baseline aren't plotted as bars + models = MODELS[:-2] + + state_tracking_kinds = ["explicit", "implicit"] + + means = [ + [data[model][cat]["mean"] for cat in state_tracking_kinds] for model in models + ] + std_errs = [ + [data[model][cat]["std_err"] for cat in state_tracking_kinds] + for model in models + ] + cmap = plt.get_cmap("Paired") + colors = 
np.array([cmap(i) for i in range(len(state_tracking_kinds))]) + + # Plotting + x = np.arange(len(models)) # the label locations + + width = 0.4 + + fig, ax = plt.subplots(1, 1, figsize=(8, 6), dpi=300) + + explicit_bars = ax.barh( + x + width / 2, + [mean[0] for mean in means], + width, + xerr=[err[0] for err in std_errs], + label="Explicitly tracked state baseline", + color=colors[0], + ) + implicit_bars = ax.barh( + x - width / 2, + [mean[1] for mean in means], + width, + xerr=[err[1] for err in std_errs], + label="Implicitly tracked state", + color=colors[1], + ) + + ax.set_xlabel(STAT_TO_LABEL[stat]) + # maximum x + xerr value times 1.2 + x_max = ( + max([m for mean in means for m in mean]) + + max([e for err in std_errs for e in err]) + ) * 1.2 + ax.set_xlim([0, x_max]) + ax.set_yticks(x) + ax.set_yticklabels(models) + + ax.bar_label(implicit_bars, padding=3, fmt="%.2f") + ax.bar_label(explicit_bars, padding=3, fmt="%.2f") + + # plot random and human baselines + random_baseline = data["random_baseline"]["implicit"]["mean"] + random_err = data["random_baseline"]["implicit"]["std_err"] + ax.axvline(random_baseline, color="red", linestyle="--", label="Random baseline") + ax.axvspan( + random_baseline - random_err, + random_baseline + random_err, + color="red", + alpha=0.05, + ) + + human_baseline = data["human_baseline"]["implicit"]["mean"] + human_err = data["human_baseline"]["implicit"]["std_err"] + ax.axvline( + human_baseline, + color="#366a9d", + linestyle=":", + label="Human baseline (implicit)", + ) + + ax.axvspan( + human_baseline - human_err, + human_baseline + human_err, + color="#366a9d", + alpha=0.05, + ) + + # get rid of horizontal grid lines + ax.grid(axis="y", which="both") + + ax.legend() + + fig.tight_layout() + + plt.savefig(save_path, bbox_inches="tight", dpi=300) + + +def count_tokens(log_dir) -> dict[str, dict[str, dict[str, int]]]: + """ + model -> task -> input, output, total tokens + """ + token_counts = { + model: { + task: { + state_tracking: {kind: 0 for kind in ["input", "output", "total"]} + for state_tracking in ["implicit", "explicit"] + } + for task in ["mode", "median"] + } + for model in OAI_MODELS + } + globbed_logs = list(log_dir.glob("*.log")) + already_examined = set() + for log in tqdm(globbed_logs, total=len(globbed_logs), desc="Counting tokens"): + spec = log_utils.extract_spec(log) + task = spec["split"] + model = get_model(spec) + state_tracking = get_state_tracking(spec) + + if model not in OAI_MODELS: + continue + + # dont care about repeats, this is a rough estimate anyway + if (model, task, state_tracking) in already_examined: + continue + already_examined.add((model, task, state_tracking)) + + samplings = log_utils.extract_individual_results(log, "sampling") + for sampling in samplings: + usage = sampling["usage"] + token_counts[model][task][state_tracking]["input"] += zero_if_none( + usage["prompt_tokens"] + ) + token_counts[model][task][state_tracking]["output"] += zero_if_none( + usage["completion_tokens"] + ) + token_counts[model][task][state_tracking]["total"] += zero_if_none( + usage["total_tokens"] + ) + return token_counts + + +def main(args: argparse.Namespace): + log_dir = Path(args.log_dir) + save_dir = Path(args.save_dir) + save_dir.mkdir(exist_ok=True, parents=True) + + results_dict = make_results_dict(log_dir) + + for stat in tqdm(results_dict.keys(), desc=f"Plotting..."): + for task in tqdm(["mode", "median"], desc=f"Plotting {stat}"): + save_path = save_dir / f"{task}_{stat}.png" + make_bar_plot(results_dict, task, stat, 
save_path) + save_path = save_dir / f"{stat}.json" + with open(save_path, "w") as f: + json.dump(results_dict[stat], f, indent=2) + + token_counts = count_tokens(log_dir) + save_path = save_dir / "token_counts.json" + with open(save_path, "w") as f: + json.dump(token_counts, f, indent=2) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--log_dir", type=str, required=True, help="Where the logs are stored" + ) + parser.add_argument( + "--save_dir", type=str, required=True, help="Where to save the plots" + ) + args = parser.parse_args() + main(args) diff --git a/evals/elsuite/track_the_stat/scripts/run_experiments.sh b/evals/elsuite/track_the_stat/scripts/run_experiments.sh new file mode 100644 index 0000000000..8307866418 --- /dev/null +++ b/evals/elsuite/track_the_stat/scripts/run_experiments.sh @@ -0,0 +1,92 @@ +#!/bin/bash + +usage() { + echo "Usage: $0 -l logdir" + echo " -l logdir Specify the directory for log files" + exit 1 +} + +# Check if no arguments were provided +if [ $# -eq 0 ]; then + usage + exit 1 +fi + +# Parse command-line options +while getopts 's:l:' flag; do + case "${flag}" in + l) logdir=${OPTARG} ;; + *) usage ;; + esac +done + +# Check if mandatory arguments were provided +if [ -z "$logdir" ]; then + usage + exit 1 +fi + +NUM_REPEATS=3 + +export EVALS_THREADS=10 +export EVALS_THREADS_TIMEOUT=5 + +declare -a SOLVERS=( + # 4-turbo-preview + "generation/direct/gpt-4-turbo-preview" + "track_the_stat/explicit_state/gpt-4-turbo-preview" + # 3.5-turbo + "generation/direct/gpt-3.5-turbo" + "track_the_stat/explicit_state/gpt-3.5-turbo" + # 4-base + "generation/hhh/gpt-4-base" + "track_the_stat/explicit_state/hhh/gpt-4-base" + # gemini pro + "generation/direct/gemini-pro" + "track_the_stat/explicit_state/gemini-pro" + # mixtral-8x7b-instruct + "generation/direct/mixtral-8x7b-instruct" + "track_the_stat/explicit_state/mixtral-8x7b-instruct" + # llama chat 70b + "generation/direct/llama-2-70b-chat" + "track_the_stat/explicit_state/llama-2-70b-chat" + # random baseline + "track_the_stat/random_baseline" +) +declare -a TASKS=( + "mode" + "median" +) + +# Check if GEMINI_API_KEY is set +if [ -z "$GEMINI_API_KEY" ]; then + echo "Enter your Gemini API Key:" + read -s GEMINI_API_KEY + export GEMINI_API_KEY +fi + +# Check if TOGETHER_API_KEY is set +if [ -z "$TOGETHER_API_KEY" ]; then + echo "Enter your Together API Key:" + read -s TOGETHER_API_KEY + export TOGETHER_API_KEY +fi + +start_time=$SECONDS +for ((i = 1; i <= NUM_REPEATS; i++)); do + for task in "${TASKS[@]}"; do + for solver in "${SOLVERS[@]}"; do + if [[ $solver == *"gemini"* ]]; then + export EVALS_SEQUENTIAL=1 + else + export EVALS_SEQUENTIAL=0 + fi + solver_dotted=${solver//\//.} + record_path="${logdir}/${solver_dotted}_${task}_${i}" + echo "Running $solver on $task (repeat $i)" + oaieval $solver "track_the_stat.${task}" \ + --record_path "$record_path.log" --seed $i + done + done +done +echo "Total time: $((SECONDS - start_time)) seconds" diff --git a/evals/elsuite/track_the_stat/solvers.py b/evals/elsuite/track_the_stat/solvers.py new file mode 100644 index 0000000000..65721002cc --- /dev/null +++ b/evals/elsuite/track_the_stat/solvers.py @@ -0,0 +1,98 @@ +import random +from typing import Any + +from evals.elsuite.track_the_stat import utils +from evals.solvers.solver import NestedSolver, Solver, SolverResult, SolverSpec +from evals.task_state import Message, TaskState + + +class ExplicitStateSolver(NestedSolver): + def __init__( + self, + underlying_solver: SolverSpec, + 
state_role: str = "assistant", + *args, + **kwargs, + ): + super().__init__(underlying_solver=underlying_solver, *args, **kwargs) + self.state_role = state_role + + @property + def underlying_solver(self) -> Solver: + return self.get_solver("underlying_solver") + + def _render_state(self, current_state: dict) -> str: + rendered_state_string = f"{current_state['state_label']}\n{current_state['state_data']}" + return rendered_state_string + + def _build_message(self, task_state: TaskState) -> str: + message_string = "The current state, useful for solving the task\n" + self._render_state( + task_state.current_state + ) + return Message(role=self.state_role, content=message_string) + + def _solve(self, task_state: TaskState) -> SolverResult: + precomputed_state_message = self._build_message(task_state) + task_state.messages.append(precomputed_state_message) + + solver_result = self.underlying_solver(task_state=task_state) + return solver_result + + +class RandomBaselineSolver(Solver): + def __init__(self, registry: Any = None, *args, **kwargs): + super().__init__() + + def _solve(self, task_state: TaskState) -> SolverResult: + task = task_state.current_state["task_name"] + random_output = self._task_solve(task, task_state) + solver_result = SolverResult(output=f"[{task}: {random_output}]") + return solver_result + + def _task_solve(self, task: str, task_state: TaskState) -> str: + if task == "mode": + return self._mode_solve(task_state) + elif task == "median": + return self._median_solve(task_state) + + def _mode_solve(self, task_state: TaskState) -> str: + """ + Picks a random number from the numbers seen so far + """ + numbers = list(task_state.current_state["state_data"].keys()) + random_mode = random.choice(numbers) + return str(random_mode) + + def _median_solve(self, task_state: TaskState) -> str: + """ + Picks a random number from the numbers seen so far + (in case of even number of numbers, picks the average of two random numbers) + """ + numbers = task_state.current_state["state_data"] + if len(numbers) % 2 == 0: + random_1, random_2 = random.choices(numbers, k=2) + random_median = (random_1 + random_2) / 2 + else: + random_median = random.choice(numbers) + return str(round(random_median, 1)) + + +class TrackTheStatHuman(NestedSolver): + def __init__(self, human_cli_solver: SolverSpec, *args, **kwargs): + super().__init__(human_cli_solver=human_cli_solver, *args, **kwargs) + + @property + def human_cli_solver(self) -> Solver: + return self.get_solver("human_cli_solver") + + def _solve(self, task_state: TaskState) -> SolverResult: + human_result = self.human_cli_solver(task_state=task_state) + task = task_state.current_state["task_name"] + # wrap the result in [: ] if not already wrapped + output = utils.parse_solver_output(human_result.output, task) + if output is None: # there is a violation -- output is not wrapped + return SolverResult( + output=f"[{task}: {human_result.output}]", + ) + else: # no violation -- output is already wrapped + return human_result diff --git a/evals/elsuite/track_the_stat/utils.py b/evals/elsuite/track_the_stat/utils.py new file mode 100644 index 0000000000..55467c5100 --- /dev/null +++ b/evals/elsuite/track_the_stat/utils.py @@ -0,0 +1,78 @@ +import re +from collections import Counter +from typing import Union + +import numpy as np + + +def yellow_string(str: str) -> str: + return f"\033[1;33m{str}\033[0m" + + +def median(numbers: list[int]) -> int: + """ + Returns the median of the given list of numbers. 
If the list has an even + number of elements, the arithmetic mean of the two middle elements of the + sorted list is returned. + """ + return np.median(numbers) + + +def mode(numbers: list[int]) -> int: + """ + Returns the mode of the given list of numbers. If there are multiple modes, + the largest mode is returned. + """ + frequency = {} + for number in numbers: + frequency[number] = frequency.get(number, 0) + 1 + + max_frequency = max(frequency.values()) + candidates = [number for number, freq in frequency.items() if freq == max_frequency] + + return max(candidates) + + +task_to_fn = {"median": median, "mode": mode} + + +def parse_solver_output(solver_output: str, task: str) -> Union[int, None]: + solver_string = solver_output.strip().lower() + pattern = rf"\[{task}: (\d+(?:\.\d+)?)\]" + + match = re.search(pattern, solver_string) + + if match: + try: + output = float(match.group(1)) + except ValueError: + output = None + else: + output = None + + return output + + +def compute_mode_state(curr_list: list[int]) -> dict: + counter = Counter(curr_list) + return dict(counter) + + +def compute_median_state(curr_list: list[int]) -> dict: + sorted_list = sorted(curr_list) + return sorted_list + + +def compute_state(curr_list: list[int], task) -> dict: + if task == "mode": + return { + "task_name": task, + "state_label": "number to count", + "state_data": compute_mode_state(curr_list), + } + else: + return { + "task_name": task, + "state_label": "sorted list of shown numbers", + "state_data": compute_median_state(curr_list), + } diff --git a/evals/registry/evals/track_the_stat.yaml b/evals/registry/evals/track_the_stat.yaml new file mode 100644 index 0000000000..c64ce0ed40 --- /dev/null +++ b/evals/registry/evals/track_the_stat.yaml @@ -0,0 +1,22 @@ +track_the_stat: + id: track_the_stat.mode + metrics: + [ + "avg_max_length", + "stddev_max_length", + "median_max_length", + "max_max_length", + "min_max_length", + "violation_rate", + ] + description: "Perform a sequential task by keeping track of state implicitly" + +track_the_stat.mode: + class: evals.elsuite.track_the_stat.eval:TrackTheStat + args: + task: mode + +track_the_stat.median: + class: evals.elsuite.track_the_stat.eval:TrackTheStat + args: + task: median diff --git a/evals/registry/solvers/track_the_stat.yaml b/evals/registry/solvers/track_the_stat.yaml new file mode 100644 index 0000000000..fc061f9583 --- /dev/null +++ b/evals/registry/solvers/track_the_stat.yaml @@ -0,0 +1,82 @@ +track_the_stat/explicit_state/gemini-pro: + class: evals.elsuite.track_the_stat.solvers:ExplicitStateSolver + args: + underlying_solver: + class: evals.solvers.providers.google.gemini_solver:GeminiSolver + args: + model_name: gemini-pro + state_role: "user" + +track_the_stat/explicit_state/llama-2-70b-chat: + class: evals.elsuite.track_the_stat.solvers:ExplicitStateSolver + args: + underlying_solver: + class: evals.solvers.together_solver:TogetherSolver + args: + completion_fn_options: + model: meta-llama/Llama-2-70b-chat-hf + extra_options: + temperature: 1 + max_tokens: 512 + +track_the_stat/explicit_state/mixtral-8x7b-instruct: + class: evals.elsuite.track_the_stat.solvers:ExplicitStateSolver + args: + underlying_solver: + class: evals.solvers.together_solver:TogetherSolver + args: + completion_fn_options: + model: mistralai/Mixtral-8x7B-Instruct-v0.1 + extra_options: + temperature: 1 + max_tokens: 512 + +track_the_stat/explicit_state/gpt-3.5-turbo: + class: evals.elsuite.track_the_stat.solvers:ExplicitStateSolver + args: + underlying_solver: + class: 
evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-3.5-turbo + extra_options: + temperature: 1 + max_tokens: 512 + +track_the_stat/explicit_state/gpt-4-turbo-preview: + class: evals.elsuite.track_the_stat.solvers:ExplicitStateSolver + args: + underlying_solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-turbo-preview + extra_options: + temperature: 1 + max_tokens: 512 + +track_the_stat/explicit_state/hhh/gpt-4-base: + class: evals.elsuite.track_the_stat.solvers:ExplicitStateSolver + args: + underlying_solver: + class: evals.solvers.nested.hhh_solver:HHHSolver + args: + solver: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-base + extra_options: + temperature: 1 + max_tokens: 512 + +track_the_stat/human_cli: + class: evals.elsuite.track_the_stat.solvers:TrackTheStatHuman + args: + human_cli_solver: + class: evals.solvers.human_cli_solver:HumanCliSolver + args: + registry: null + +track_the_stat/random_baseline: + class: evals.elsuite.track_the_stat.solvers:RandomBaselineSolver diff --git a/evals/utils/log_utils.py b/evals/utils/log_utils.py index d54a846f41..6ef2b5e8ff 100644 --- a/evals/utils/log_utils.py +++ b/evals/utils/log_utils.py @@ -14,6 +14,17 @@ def get_final_results_from_dir(log_dir: Union[str, Path]) -> dict[Path, dict]: return final_results_dict +def get_specs_from_dir(log_dir: Union[str, Path]) -> dict[Path, dict]: + """ + Given a directory of log files, return a dictionary mapping log file paths to specs. + """ + specs_dict = {} + for path in Path(log_dir).glob("**/*.log"): + spec = extract_spec(path) + specs_dict[path] = spec + return specs_dict + + def extract_final_results(path: Path) -> dict: """ Given a path to a log file, find and return the "final_report" dictionary. @@ -31,7 +42,7 @@ def extract_final_results(path: Path) -> dict: raise ValueError(f"Could not find final_report in {path}") -def extract_individual_results(path: Path) -> list[dict]: +def extract_individual_results(path: Path, type_string: str = "metrics") -> list[dict]: """ Given a path to a log file, grab all the individual sample results. """ @@ -42,7 +53,7 @@ def extract_individual_results(path: Path) -> list[dict]: try: loaded_line = json.loads(line) if "type" in loaded_line: - if loaded_line["type"] == "metrics": + if loaded_line["type"] == type_string: all_data.append(loaded_line["data"]) except json.decoder.JSONDecodeError: print(f"Skipping line: {line}")
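[Editor's note] As a usage note on the `log_utils.py` changes above: the new `get_specs_from_dir` helper and the `type_string` parameter on `extract_individual_results` are what let `make_plots.py` pair each run's spec with its raw `sampling` events. A rough sketch of how they combine for token counting, assuming a directory of `.log` files like those produced by `run_experiments.sh` (the path below is illustrative, not from this PR):

```python
from pathlib import Path

from evals.utils import log_utils

log_dir = Path("logs/track_the_stat")  # illustrative path

# Map each run log to its spec (new helper added in this diff).
specs = log_utils.get_specs_from_dir(log_dir)

for path, spec in specs.items():
    # Pull raw "sampling" events instead of the default "metrics" events,
    # mirroring count_tokens() in make_plots.py.
    samplings = log_utils.extract_individual_results(path, "sampling")
    total_tokens = sum(s["usage"]["total_tokens"] or 0 for s in samplings)
    print(spec["split"], spec["completion_fns"][0], total_tokens)
```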