
Commit faecd08

get top logprobs from 'content' field, extract prediction from field names
poppingtonic authored Jan 11, 2024
1 parent 5851685 commit faecd08
Showing 1 changed file with 2 additions and 9 deletions.
11 changes: 2 additions & 9 deletions ice/agents/openai.py
@@ -235,13 +235,6 @@ async def relevance(
             "OpenAI ChatCompletion has no option to return a relevance score."
         )
 
-    # async def predict(
-    #     self, *, context: str, default: str = "", verbose: bool = False
-    # ) -> dict[str, float]:
-    #     raise NotImplementedError(
-    #         "OpenAI ChatCompletion does not support getting probabilities."
-    #     )
-
     async def _complete(self, prompt, **kwargs) -> dict:
         """Send a completion request to the OpenAI API with the given prompt and parameters."""
         kwargs.update(
@@ -272,8 +265,8 @@ def _extract_completion(self, response: dict) -> str:
 
     def _extract_prediction(self, response: dict) -> dict[str, float]:
         """Extract the prediction dictionary from the completion response."""
-        answer = response["choices"][0]["logprobs"]["top_logprobs"][0]
-        return {k: math.exp(p) for (k, p) in answer.items()}
+        answer = response["choices"][0]["logprobs"]["content"][0]["top_logprobs"]
+        return {a['token']: math.exp(a['logprob']) for a in answer}
 
     def _compute_relative_probs(
         self, choices: tuple[str, ...], choice_prefix: str, prediction: dict[str, float]
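
For context: the legacy Completions-style response exposed top_logprobs as a dict keyed by candidate token, while the Chat Completions response nests alternatives under logprobs -> content, a per-token list whose entries each carry their own top_logprobs list of {token, logprob} objects. Below is a minimal sketch (made-up tokens and values, not a real API response) of the shape the updated _extract_prediction expects, and how exponentiating each logprob recovers a probability per candidate token:

import math

# A minimal sketch (made-up tokens and values, not a real API response) of the
# Chat Completions logprobs shape: choices[0]["logprobs"]["content"] is a list
# of per-token entries, each with its own top_logprobs list of alternatives.
sample_response = {
    "choices": [
        {
            "logprobs": {
                "content": [
                    {
                        "token": " Yes",
                        "logprob": -0.11,
                        "top_logprobs": [
                            {"token": " Yes", "logprob": -0.11},
                            {"token": " No", "logprob": -2.30},
                        ],
                    }
                ]
            }
        }
    ]
}

# Same extraction as the updated _extract_prediction: read the alternatives for
# the first generated token and convert each logprob back to a probability.
answer = sample_response["choices"][0]["logprobs"]["content"][0]["top_logprobs"]
prediction = {a["token"]: math.exp(a["logprob"]) for a in answer}
print(prediction)  # roughly {' Yes': 0.90, ' No': 0.10}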
