Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates for Solvers #1461

Merged
merged 5 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 45 additions & 20 deletions evals/elsuite/bluff/strategy_solver.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import copy
import re
from importlib import import_module
from typing import Optional

from evals.elsuite.bluff.bluff.cards import get_bluff_move
from evals.solvers.solver import Solver, SolverResult
from evals.solvers.utils import PersistentMemoryCache
from evals.task_state import Message, TaskState


Expand All @@ -24,7 +26,23 @@ def __init__(
self.max_attempts = max_attempts
self.rethink_strategy_after = rethink_strategy_after

def __call__(self, task_state: TaskState):
# interaction_length=1 to store reasoning step in private memory
self.interaction_cache = PersistentMemoryCache(interaction_length=1)

def _generate_response(self, task_state: TaskState):
"""
Calls base solver. Modifies task state to remove all non-reasoning messages
from assistant
"""
task_state = copy.deepcopy(task_state)
task_state.messages = [
msg
for msg in task_state.messages
if msg.role != "assistant" or msg.content.startswith("{") or len(msg.content) > 20
]
return self.base_solver(task_state).output

def _solve(self, task_state: TaskState):
"""
This solver does three things that should help the model play better:
1. Adds a strategy guide as the first message (just after the task description)
Expand All @@ -35,25 +53,12 @@ def __call__(self, task_state: TaskState):
# GENERAL NOTE.
# This function is pretty ugly. I'm not sure how to implement this better. We decided this is good enough.

# Remove assistant messages added by the main solver (i.e. non-JSON).
# We need len(msg.content) > 20 because we don't want to remove "rethink strategy".
task_state.messages = [
msg
for msg in task_state.messages
if msg.role != "assistant" or msg.content.startswith("{") or len(msg.content) > 20
]
# Before the first move in a game - strategy guide goes first
strategy_msg = Message("system", strategy)
task_state.messages.insert(0, strategy_msg)
task_state.messages = self.interaction_cache.load_private_interaction(task_state)

game = task_state.current_state

if len(game.rounds) == 1 and len(game.rounds[0].moves) < 2:
# Before the first move in a game - strategy guide goes first
strategy_msg = Message("system", strategy)

# This if is important - we might have already tried
# to bid, but gave an invalid bid, so still we have no moves
if strategy_msg not in task_state.messages:
task_state.messages.insert(0, strategy_msg)

if (
self.rethink_strategy_after is not None
and len(game.rounds) == 1 + self.rethink_strategy_after
Expand All @@ -67,15 +72,32 @@ def __call__(self, task_state: TaskState):
if strategy_update_msg not in task_state.messages:
last_system_message = task_state.messages.pop()
task_state.messages.append(strategy_update_msg)
response = self.base_solver(task_state).output
response = self._generate_response(task_state)
task_state.messages.append(Message("assistant", response))
task_state.messages.append(last_system_message)

# Manually update interaction cache, since we re-order messages
last_interaction = self.interaction_cache.last_interaction
last_interaction_messages = last_interaction.messages[:-1] + [
Message("system", strategy_update_msg),
Message("assistant", response),
Message("system", last_system_message),
]
last_interaction_private_ids = last_interaction.private_messages_ids + [
len(task_state.messages) - 3,
len(task_state.messages) - 2,
]

self.interaction_cache.last_interaction.messages = last_interaction_messages
self.interaction_cache.last_interaction.private_messages_ids = (
last_interaction_private_ids
)

# If this move_str is preserved, the game engine will have to deal with that
# (and it has some way of solving this problem)
move_str = "[INVALID MOVE]"
for _ in range(self.max_attempts):
response = self.base_solver(task_state).output
response = self._generate_response(task_state)
try:
move_str = self._parse_response(response)
# This will raise ValueError if this is not a valid move
Expand All @@ -85,6 +107,9 @@ def __call__(self, task_state: TaskState):
pass

task_state.messages.append(Message("assistant", response))
task_state.messages.append(Message("assistant", move_str))
self.interaction_cache.save_private_interaction(task_state)

return SolverResult(move_str)

@property
Expand Down
5 changes: 3 additions & 2 deletions evals/elsuite/make_me_pay/solvers/lm_con_artist_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
EXTRACT_MESSAGE,
TIME_REMINDER_MSG,
)
from evals.solvers.nested.cot_solver import CoTSolver
from evals.solvers.solver import Solver, SolverResult
from evals.task_state import Message, TaskState

Expand All @@ -28,7 +29,7 @@ def __init__(

# Add more args to base solver if we're using CoT
base_solver_args_extra = {}
if class_name == "OpenAIChatCompletionCoTSolver":
if issubclass(cls, CoTSolver):
base_solver_args_extra = {
"cot_template": COT_MESSAGE,
"extract_template": EXTRACT_MESSAGE,
Expand All @@ -49,7 +50,7 @@ def __init__(
def name(self):
return "Scaffolded-LM-Solver"

def __call__(self, task_state: TaskState, **kwargs) -> SolverResult:
def _solve(self, task_state: TaskState, **kwargs) -> SolverResult:
# Optional additional message for better LM capabilities. Only append if
# this is the start of the conversation, otherwise this is included in memory
if self.lm_system_prompt:
Expand Down
2 changes: 1 addition & 1 deletion evals/elsuite/sandbagging/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def _construct_prompt(self, task_state: TaskState) -> Sequence[Dict]:

return prompt

def __call__(self, task_state: TaskState, **kwargs) -> (Sequence[Dict], SolverResult):
def _solve(self, task_state: TaskState, **kwargs) -> (Sequence[Dict], SolverResult):
prompt = self._construct_prompt(task_state)
result = self._predict_answer(prompt, **kwargs)

Expand Down
6 changes: 3 additions & 3 deletions evals/elsuite/self_prompting/solvers/baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def __init__(
This solver simply returns an empty string as the prompt.
"""

def __call__(
def _solve(
self,
task_state: TaskState,
**kwargs,
Expand All @@ -32,7 +32,7 @@ def __init__(
This solver simply returns the original instruction as the prompt.
"""

def __call__(
def _solve(
self,
task_state: TaskState,
**kwargs,
Expand All @@ -54,7 +54,7 @@ def __init__(
This solver concatenates the given input-output examples as few-shot demonstrations.
"""

def __call__(
def _solve(
self,
task_state: TaskState,
**kwargs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
self.extract_completion_fn = OpenAIChatCompletionFn(**self.completion_fn_options)
self.extract_template = extract_template

def __call__(
def _solve(
self,
task_state: TaskState,
**kwargs,
Expand Down
9 changes: 7 additions & 2 deletions evals/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def n_ctx_from_model_name(model_name: str) -> Optional[int]:
"gpt-4": 8192,
"gpt-4-32k": 32768,
"gpt-4-base": 8192,
"gpt-4-1106-preview": 128_000,
}

# first, look for an exact match
Expand Down Expand Up @@ -135,7 +136,7 @@ def make_completion_fn(
# No match, so try to find a completion-fn-id in the registry
spec = self.get_completion_fn(name)
if spec is None:
raise ValueError(f"Could not find CompletionFn in the registry with ID {name}")
raise ValueError(f"Could not find CompletionFn/Solver in the registry with ID {name}")
if spec.args is None:
spec.args = {}
spec.args.update(kwargs)
Expand Down Expand Up @@ -195,7 +196,7 @@ def get_modelgraded_spec(self, name: str, **kwargs: dict) -> Optional[ModelGrade
)

def get_completion_fn(self, name: str) -> Optional[CompletionFnSpec]:
return self._dereference(name, self._completion_fns, "completion_fn", CompletionFnSpec)
return self._dereference(name, self._completion_fns | self._solvers, "completion_fn", CompletionFnSpec)

def get_eval(self, name: str) -> Optional[EvalSpec]:
return self._dereference(name, self._evals, "eval", EvalSpec)
Expand Down Expand Up @@ -303,6 +304,10 @@ def _load_registry(self, registry_paths: Sequence[Path], resource_type: str) ->
def _completion_fns(self) -> RawRegistry:
return self._load_registry(self._registry_paths, "completion_fns")

@functools.cached_property
def _solvers(self) -> RawRegistry:
return self._load_registry(self._registry_paths, "solvers")

@functools.cached_property
def _eval_sets(self) -> RawRegistry:
return self._load_registry(self._registry_paths, "eval_sets")
Expand Down
77 changes: 0 additions & 77 deletions evals/registry/completion_fns/bluff.yaml

This file was deleted.

92 changes: 0 additions & 92 deletions evals/registry/completion_fns/make-me-pay.yaml

This file was deleted.

Loading
Loading