Skip to content

Commit

Permalink
Add logprobs, classify and predict to OpenAIChatCompletionAgent
Browse files Browse the repository at this point in the history
  • Loading branch information
poppingtonic authored Jan 11, 2024
1 parent e50edc5 commit 37add8f
Showing 1 changed file with 78 additions and 12 deletions.
90 changes: 78 additions & 12 deletions ice/agents/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ def _print_markdown(self, obj: Any):
"""Print the text with markdown formatting."""
env().print(obj, format_markdown=True)


class OpenAIChatCompletionAgent(Agent):
"""An agent that uses the OpenAI ChatCompletion API to generate completions."""

Expand All @@ -153,10 +152,14 @@ def __init__(
model: str = "gpt-3.5-turbo",
temperature: float = 0.0,
top_p: float = 1.0,
logprobs: bool = False,
top_logprobs: int = None
):
self.model = model
self.temperature = temperature
self.top_p = top_p
# No trailing comma after `logprobs`: `self.logprobs = logprobs,` would store
# the one-element tuple `(False,)`, which is always truthy and is not the
# boolean that gets forwarded as the "logprobs" request parameter.
self.logprobs = logprobs
self.top_logprobs = top_logprobs

async def complete(
self,
Expand All @@ -175,6 +178,16 @@ async def complete(
if verbose:
self._print_markdown(completion)
return completion

async def predict(self, *, context, default="", verbose=False) -> dict[str, float]:
    """Return a probability distribution over the next token after ``context``.

    A single completion token is requested with log-probabilities enabled,
    and the response is converted by ``_extract_prediction``.

    Args:
        context: Prompt text forwarded to ``_complete``.
        default: Not used by this method.
        verbose: When true, print the context and the resulting prediction
            as markdown.

    Returns:
        The value produced by ``_extract_prediction`` for the response.
    """
    if verbose:
        self._print_markdown(context)

    # Only the first generated token matters, so ask for exactly one token
    # together with its top-5 alternatives.
    request_options = {"top_logprobs": 5, "logprobs": True, "max_tokens": 1}
    token_probs = self._extract_prediction(
        await self._complete(context, **request_options)
    )

    if verbose:
        self._print_markdown(token_probs)
    return token_probs

async def classify(
self,
Expand All @@ -184,9 +197,28 @@ async def classify(
default: Optional[str] = None,
verbose: bool = False,
) -> tuple[dict[str, float], Optional[str]]:
raise NotImplementedError(
"OpenAI ChatCompletion has no option to score a classification."
)
"""Generate a classification from a list of choices given some context and a question."""
if verbose:
self._print_markdown(prompt)
self._print_markdown(choices)

choice_prefix = longest_common_prefix(choices).rstrip()
prompt_with_prefix = f"{prompt}{choice_prefix}"

if prompt_with_prefix.endswith(" "):
prompt_with_prefix = prompt_with_prefix[:-1]
default = " "
else:
default = ""

prediction = await self.predict(context=prompt_with_prefix, default=default)

rel_probs = self._compute_relative_probs(choices, choice_prefix, prediction)

if verbose:
self._print_markdown(rel_probs)

return rel_probs, None

async def relevance(
self,
Expand All @@ -200,12 +232,12 @@ async def relevance(
"OpenAI ChatCompletion has no option to return a relevance score."
)

async def predict(
self, *, context: str, default: str = "", verbose: bool = False
) -> dict[str, float]:
raise NotImplementedError(
"OpenAI ChatCompletion does not support getting probabilities."
)
# async def predict(
# self, *, context: str, default: str = "", verbose: bool = False
# ) -> dict[str, float]:
# raise NotImplementedError(
# "OpenAI ChatCompletion does not support getting probabilities."
# )

async def _complete(self, prompt, **kwargs) -> dict:
"""Send a completion request to the OpenAI API with the given prompt and parameters."""
Expand All @@ -215,9 +247,12 @@ async def _complete(self, prompt, **kwargs) -> dict:
"temperature": self.temperature,
"top_p": self.top_p,
"n": 1,
"logprobs": self.logprobs,
"top_logprobs": self.top_logprobs
}
)
messages = [{"role": "user", "content": prompt}]
messages = [{"role": "system", "content": "You are a helpful assistant. Your answers follow instructions and remain grounded in the context."},
{"role": "user", "content": prompt}]
response = await openai_chatcomplete(messages, **kwargs)
if "choices" not in response:
raise ValueError(f"No choices in response: {response}")
Expand All @@ -227,11 +262,42 @@ def _extract_completion(self, response: dict) -> str:
"""Extract the answer text from the completion response."""
return response["choices"][0]["message"]["content"].strip()

def _extract_prediction(self, response: dict) -> dict[str, float]:
"""Extract the prediction dictionary from the completion response."""
answer = response["choices"][0]["logprobs"]["top_logprobs"][0]
return {k: math.exp(p) for (k, p) in answer.items()}

def _compute_relative_probs(
self, choices: tuple[str, ...], choice_prefix: str, prediction: dict[str, float]
) -> dict[str, float]:
"""Compute the relative probabilities of the choices based on the prediction."""

def lookup_prob(choice: str):
scores = 0.0
for token, prob in prediction.items():
if choice[len(choice_prefix) :].startswith(token):
scores += prob
return scores

abs_probs = {choice: lookup_prob(choice) for choice in choices}
Z = sum(abs_probs.values())
if Z < 0.6:
log.warning(f"{1-Z} of unaccounted probability in classify")
log.warning(choice_prefix)
log.warning(str(prediction))
log.warning(str(abs_probs))

rel_probs = (
{choice: prob / Z for (choice, prob) in abs_probs.items()}
if Z != 0.0
else abs_probs
)
return rel_probs

def _print_markdown(self, obj: Any):
    """Print ``obj`` through the environment with markdown formatting enabled."""
    environment = env()
    environment.print(obj, format_markdown=True)


class OpenAIEmbeddingAgent(Agent):
"""An agent that uses the OpenAI API to generate a relevance score by cosine similarity between two text embeddings."""

Expand Down

0 comments on commit 37add8f

Please sign in to comment.