Skip to content

Commit

Permalink
add gemini solver
Browse files Browse the repository at this point in the history
  • Loading branch information
ojaffe committed Mar 21, 2024
1 parent e30e141 commit 21a57d7
Show file tree
Hide file tree
Showing 5 changed files with 308 additions and 1 deletion.
23 changes: 23 additions & 0 deletions evals/registry/solvers/gemini.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@

# ------------------
# gemini-pro
# ------------------

# generation tasks

generation/direct/gemini-pro:
class: evals.solvers.providers.google.gemini_solver:GeminiSolver
args:
model_name: gemini-pro

generation/cot/gemini-pro:
class: evals.solvers.nested.cot_solver:CoTSolver
args:
cot_solver:
class: evals.solvers.providers.google.gemini_solver:GeminiSolver
args:
model_name: gemini-pro
extract_solver:
class: evals.solvers.providers.google.gemini_solver:GeminiSolver
args:
model_name: gemini-pro
211 changes: 211 additions & 0 deletions evals/solvers/providers/google/gemini_solver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
import copy
import os
from dataclasses import asdict, dataclass
from typing import Any, Dict, Union

import google.api_core.exceptions
import google.generativeai as genai
from google.generativeai.client import get_default_generative_client

from evals.record import record_sampling
from evals.solvers.solver import Solver, SolverResult
from evals.task_state import Message, TaskState
from evals.utils.api_utils import create_retrying

# Load API key from environment variable
API_KEY = os.environ.get("GEMINI_API_KEY")
genai.configure(api_key=API_KEY)

SAFETY_SETTINGS = [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE",
},
]
GEMINI_RETRY_EXCEPTIONS = (
google.api_core.exceptions.RetryError,
google.api_core.exceptions.TooManyRequests,
google.api_core.exceptions.ResourceExhausted,
)


# TODO: Could we just use google's own types?
# e.g. google.generativeai.types.content_types.ContentType
@dataclass
class GoogleMessage:
role: str
parts: list[str]

def to_dict(self):
return asdict(self)

@staticmethod
def from_evals_message(msg: Message):
valid_roles = {"user", "model"}
to_google_role = {
"system": "user", # Google doesn't have a "system" role
"user": "user",
"assistant": "model",
}
gmsg = GoogleMessage(
role=to_google_role.get(msg.role, msg.role),
parts=[msg.content],
)
assert gmsg.role in valid_roles, f"Invalid role: {gmsg.role}"
return gmsg


class GeminiSolver(Solver):
"""
A solver class that uses Google's Gemini API to generate responses.
"""

def __init__(
self,
model_name: str,
generation_config: Dict[str, Any] = {},
postprocessors: list[str] = [],
registry: Any = None,
):
super().__init__(postprocessors=postprocessors)

self.model_name = model_name
self.gen_config = genai.GenerationConfig(**generation_config)

# We manually define the client. This is normally defined automatically when calling
# the API, but it isn't thread-safe, so we anticipate its creation here
self.glm_client = get_default_generative_client()

@property
def model(self) -> str:
return self.model_name

def _solve(
self,
task_state: TaskState,
**kwargs,
) -> SolverResult:
msgs = [
Message(role="user", content=task_state.task_description),
] + task_state.messages
gmsgs = self._convert_msgs_to_google_format(msgs)
gmsgs = [msg.to_dict() for msg in gmsgs]
try:
glm_model = genai.GenerativeModel(model_name=self.model_name)
glm_model._client = self.glm_client

gen_content_resp = create_retrying(
glm_model.generate_content,
retry_exceptions=GEMINI_RETRY_EXCEPTIONS,
**{
"contents": gmsgs,
"generation_config": self.gen_config,
"safety_settings": SAFETY_SETTINGS,
},
)
if gen_content_resp.prompt_feedback.block_reason:
# Blocked by safety filters
solver_result = SolverResult(
str(gen_content_resp.prompt_feedback),
error=gen_content_resp.prompt_feedback,
)
else:
# Get text response
solver_result = SolverResult(
gen_content_resp.text,
error=gen_content_resp.prompt_feedback,
)
except (google.api_core.exceptions.GoogleAPIError,) as e:
solver_result = SolverResult(
e.message,
error=e,
)
except ValueError as e:
# TODO: Why does this error ever occur and how can we handle it better?
# (See google/generativeai/types/generation_types.py for the triggers)
known_errors = [
"The `response.text` quick accessor",
"The `response.parts` quick accessor",
]
if any(err in str(e) for err in known_errors):
solver_result = SolverResult(
str(e),
error=e,
)
else:
raise e

record_sampling(
prompt=msgs,
sampled=[solver_result.output],
model=self.model,
)
return solver_result

@staticmethod
def _convert_msgs_to_google_format(msgs: list[Message]) -> list[GoogleMessage]:
"""
Gemini API requires that the message list has
- Roles as 'user' or 'model'
- Alternating 'user' and 'model' messages
- Ends with a 'user' message
"""
# Enforce valid roles
gmsgs = []
for msg in msgs:
gmsg = GoogleMessage.from_evals_message(msg)
gmsgs.append(gmsg)
assert gmsg.role in {"user", "model"}, f"Invalid role: {gmsg.role}"

# Enforce alternating messages
# e.g. [user1, user2, model1, user3] -> [user12, model1, user3]
std_msgs = []
for msg in gmsgs:
if len(std_msgs) > 0 and msg.role == std_msgs[-1].role:
# Merge consecutive messages from the same role
std_msgs[-1].parts.extend(msg.parts)
# The API seems to expect a single-element list of strings (???) so we join the
# parts into a list containing a single string
std_msgs[-1].parts = ["\n".join(std_msgs[-1].parts)]
else:
# Proceed as normal
std_msgs.append(msg)

# Enforce last message is from the user
assert std_msgs[-1].role == "user", "Last message must be from the user"
return std_msgs

@property
def name(self) -> str:
return self.model

@property
def model_version(self) -> Union[str, dict]:
return self.model

def __deepcopy__(self, memo):
"""
Deepcopy everything except for self.glm_client, which is instead shared across all copies
"""
cls = self.__class__
result = cls.__new__(cls)

memo[id(self)] = result
for k, v in self.__dict__.items():
if k != "glm_client":
setattr(result, k, copy.deepcopy(v, memo))

result.glm_client = self.glm_client
return result
71 changes: 71 additions & 0 deletions evals/solvers/providers/google/gemini_solver_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import os

import pytest

from evals.record import DummyRecorder
from evals.solvers.providers.google.gemini_solver import GeminiSolver, GoogleMessage
from evals.task_state import Message, TaskState

IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
MODEL_NAME = "gemini-pro"


@pytest.fixture
def dummy_recorder():
recorder = DummyRecorder(None) # type: ignore
with recorder.as_default_recorder("x"):
yield recorder


@pytest.fixture
def gemini_solver():
os.environ["EVALS_SEQUENTIAL"] = "1" # TODO: Remove after fixing threading issue
solver = GeminiSolver(
model_name=MODEL_NAME,
)
return solver


@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="API tests are wasteful to run on every commit.")
def test_solver(dummy_recorder, gemini_solver):
"""
Test that the solver generates a response coherent with the message history
while following the instructions from the task description.
"""
solver = gemini_solver

answer = "John Doe"
task_state = TaskState(
task_description=f"When you are asked for your name, respond with '{answer}' (without quotes).",
messages=[
Message(role="user", content="What is 2 + 2?"),
Message(role="assistant", content="4"),
Message(role="user", content="What is your name?"),
],
)

solver_res = solver(task_state=task_state)
assert solver_res.output == answer, f"Expected '{answer}', but got {solver_res.output}"


def test_message_format():
"""
Test that messages in our evals format is correctly converted to the format
expected by Gemini.
"""

messages = [
Message(role="system", content="You are a great mathematician."),
Message(role="user", content="What is 2 + 2?"),
Message(role="assistant", content="5"),
Message(role="user", content="That's incorrect. What is 2 + 2?"),
]

gmessages = GeminiSolver._convert_msgs_to_google_format(messages)
expected = [
GoogleMessage(role="user", parts=["You are a great mathematician.\nWhat is 2 + 2?"]),
GoogleMessage(role="model", parts=["5"]),
GoogleMessage(role="user", parts=["That's incorrect. What is 2 + 2?"]),
]

assert gmessages == expected, f"Expected {expected}, but got {gmessages}"
1 change: 1 addition & 0 deletions evals/solvers/providers/google/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
google-generativeai
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ dependencies = [
"gymnasium",
"networkx",
"chess",
"anthropic"
"anthropic",
"google-generativeai",
]

[project.urls]
Expand Down

0 comments on commit 21a57d7

Please sign in to comment.