From 37add8fb8326be115a2ede5b0935d8a65534a166 Mon Sep 17 00:00:00 2001
From: Brian Muhia
Date: Thu, 11 Jan 2024 13:34:45 +0300
Subject: [PATCH] Add logprobs, classify and predict to
 OpenAIChatCompletionAgent

---
 ice/agents/openai.py | 90 ++++++++++++++++++++++++++++++++++++++------
 1 file changed, 78 insertions(+), 12 deletions(-)

diff --git a/ice/agents/openai.py b/ice/agents/openai.py
index 38cb2220..7db40220 100644
--- a/ice/agents/openai.py
+++ b/ice/agents/openai.py
@@ -144,7 +144,6 @@ def _print_markdown(self, obj: Any):
         """Print the text with markdown formatting."""
         env().print(obj, format_markdown=True)
 
-
 class OpenAIChatCompletionAgent(Agent):
     """An agent that uses the OpenAI ChatCompletion API to generate completions."""
 
@@ -153,10 +152,14 @@ def __init__(
         model: str = "gpt-3.5-turbo",
         temperature: float = 0.0,
         top_p: float = 1.0,
+        logprobs: bool = False,
+        top_logprobs: Optional[int] = None,
     ):
         self.model = model
         self.temperature = temperature
         self.top_p = top_p
+        self.logprobs = logprobs
+        self.top_logprobs = top_logprobs
 
     async def complete(
         self,
@@ -175,6 +178,16 @@ async def complete(
         if verbose:
             self._print_markdown(completion)
         return completion
+
+    async def predict(self, *, context, default="", verbose=False) -> dict[str, float]:
+        """Generate a probability distribution over the next token given some context."""
+        if verbose:
+            self._print_markdown(context)
+        response = await self._complete(context, top_logprobs=5, logprobs=True, max_tokens=1)
+        prediction = self._extract_prediction(response)
+        if verbose:
+            self._print_markdown(prediction)
+        return prediction
 
     async def classify(
         self,
@@ -184,9 +197,28 @@ async def classify(
         default: Optional[str] = None,
         verbose: bool = False,
     ) -> tuple[dict[str, float], Optional[str]]:
-        raise NotImplementedError(
-            "OpenAI ChatCompletion has no option to score a classification."
-        )
+        """Generate a classification from a list of choices given some context and a question."""
+        if verbose:
+            self._print_markdown(prompt)
+            self._print_markdown(choices)
+
+        choice_prefix = longest_common_prefix(choices).rstrip()
+        prompt_with_prefix = f"{prompt}{choice_prefix}"
+
+        if prompt_with_prefix.endswith(" "):
+            prompt_with_prefix = prompt_with_prefix[:-1]
+            default = " "
+        else:
+            default = ""
+
+        prediction = await self.predict(context=prompt_with_prefix, default=default)
+
+        rel_probs = self._compute_relative_probs(choices, choice_prefix, prediction)
+
+        if verbose:
+            self._print_markdown(rel_probs)
+
+        return rel_probs, None
 
     async def relevance(
         self,
@@ -200,12 +232,12 @@ async def relevance(
             "OpenAI ChatCompletion has no option to return a relevance score."
         )
 
-    async def predict(
-        self, *, context: str, default: str = "", verbose: bool = False
-    ) -> dict[str, float]:
-        raise NotImplementedError(
-            "OpenAI ChatCompletion does not support getting probabilities."
-        )
+    # async def predict(
+    #     self, *, context: str, default: str = "", verbose: bool = False
+    # ) -> dict[str, float]:
+    #     raise NotImplementedError(
+    #         "OpenAI ChatCompletion does not support getting probabilities."
+    #     )
 
     async def _complete(self, prompt, **kwargs) -> dict:
         """Send a completion request to the OpenAI API with the given prompt and parameters."""
@@ -215,9 +247,12 @@ async def _complete(self, prompt, **kwargs) -> dict:
             "temperature": self.temperature,
             "top_p": self.top_p,
             "n": 1,
+            "logprobs": kwargs.get("logprobs", self.logprobs),
+            "top_logprobs": kwargs.get("top_logprobs", self.top_logprobs),
         }
     )
-        messages = [{"role": "user", "content": prompt}]
+        messages = [{"role": "system", "content": "You are a helpful assistant. Your answers follow instructions and remain grounded in the context."},
+                    {"role": "user", "content": prompt}]
         response = await openai_chatcomplete(messages, **kwargs)
         if "choices" not in response:
             raise ValueError(f"No choices in response: {response}")
@@ -227,11 +262,42 @@ def _extract_completion(self, response: dict) -> str:
         """Extract the answer text from the completion response."""
         return response["choices"][0]["message"]["content"].strip()
 
+    def _extract_prediction(self, response: dict) -> dict[str, float]:
+        """Extract the token -> probability mapping from the chat completion response."""
+        answer = response["choices"][0]["logprobs"]["content"][0]["top_logprobs"]
+        return {d["token"]: math.exp(d["logprob"]) for d in answer}
+
+    def _compute_relative_probs(
+        self, choices: tuple[str, ...], choice_prefix: str, prediction: dict[str, float]
+    ) -> dict[str, float]:
+        """Compute the relative probabilities of the choices based on the prediction."""
+
+        def lookup_prob(choice: str):
+            scores = 0.0
+            for token, prob in prediction.items():
+                if choice[len(choice_prefix) :].startswith(token):
+                    scores += prob
+            return scores
+
+        abs_probs = {choice: lookup_prob(choice) for choice in choices}
+        Z = sum(abs_probs.values())
+        if Z < 0.6:
+            log.warning(f"{1-Z} of unaccounted probability in classify")
+            log.warning(choice_prefix)
+            log.warning(str(prediction))
+            log.warning(str(abs_probs))
+
+        rel_probs = (
+            {choice: prob / Z for (choice, prob) in abs_probs.items()}
+            if Z != 0.0
+            else abs_probs
+        )
+        return rel_probs
+
     def _print_markdown(self, obj: Any):
         """Print the text with markdown formatting."""
         env().print(obj, format_markdown=True)
-
 class OpenAIEmbeddingAgent(Agent):
     """An agent that uses the OpenAI API to generate a relevance score by
     cosine similarity between two text embeddings."""