From faecd08a679203f330db16f55f765eeb35c8dc99 Mon Sep 17 00:00:00 2001 From: Brian Muhia Date: Thu, 11 Jan 2024 19:54:14 +0300 Subject: [PATCH] get top logprobs from 'content' field, extract prediction from field names --- ice/agents/openai.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/ice/agents/openai.py b/ice/agents/openai.py index b5ffaf81..1179b3f8 100644 --- a/ice/agents/openai.py +++ b/ice/agents/openai.py @@ -235,13 +235,6 @@ async def relevance( "OpenAI ChatCompletion has no option to return a relevance score." ) - # async def predict( - # self, *, context: str, default: str = "", verbose: bool = False - # ) -> dict[str, float]: - # raise NotImplementedError( - # "OpenAI ChatCompletion does not support getting probabilities." - # ) - async def _complete(self, prompt, **kwargs) -> dict: """Send a completion request to the OpenAI API with the given prompt and parameters.""" kwargs.update( @@ -272,8 +265,8 @@ def _extract_completion(self, response: dict) -> str: def _extract_prediction(self, response: dict) -> dict[str, float]: """Extract the prediction dictionary from the completion response.""" - answer = response["choices"][0]["logprobs"]["top_logprobs"][0] - return {k: math.exp(p) for (k, p) in answer.items()} + answer = response["choices"][0]["logprobs"]["content"][0]["top_logprobs"] + return {a['token']: math.exp(a['logprob']) for a in answer} def _compute_relative_probs( self, choices: tuple[str, ...], choice_prefix: str, prediction: dict[str, float]