Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 11, 2024
1 parent 37add8f commit c065559
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
23 changes: 16 additions & 7 deletions ice/agents/openai.py
Original file line number Diff line number Diff line change
def _print_markdown(self, obj: Any):
    """Render *obj* through the shared environment printer with markdown formatting enabled."""
    printer = env()
    printer.print(obj, format_markdown=True)


class OpenAIChatCompletionAgent(Agent):
"""An agent that uses the OpenAI ChatCompletion API to generate completions."""

Expand All @@ -153,12 +154,12 @@ def __init__(
temperature: float = 0.0,
top_p: float = 1.0,
logprobs: bool = False,
top_logprobs: int = None
top_logprobs: int = None,
):
self.model = model
self.temperature = temperature
self.top_p = top_p
self.logprobs = logprobs,
self.logprobs = (logprobs,)
self.top_logprobs = top_logprobs

async def complete(
Expand All @@ -178,12 +179,14 @@ async def complete(
if verbose:
self._print_markdown(completion)
return completion

async def predict(self, *, context, default="", verbose=False) -> dict[str, float]:
"""Generate a probability distribution over the next token given some context."""
if verbose:
self._print_markdown(context)
response = await self._complete(context, top_logprobs=5, logprobs=True, max_tokens=1)
response = await self._complete(
context, top_logprobs=5, logprobs=True, max_tokens=1
)
prediction = self._extract_prediction(response)
if verbose:
self._print_markdown(prediction)
Expand Down Expand Up @@ -248,11 +251,16 @@ async def _complete(self, prompt, **kwargs) -> dict:
"top_p": self.top_p,
"n": 1,
"logprobs": self.logprobs,
"top_logprobs": self.top_logprobs
"top_logprobs": self.top_logprobs,
}
)
messages = [{"role": "system", "content": "You are a helpful assistant. Your answers follow instructions and remain grounded in the context."},
{"role": "user", "content": prompt}]
messages = [
{
"role": "system",
"content": "You are a helpful assistant. Your answers follow instructions and remain grounded in the context.",
},
{"role": "user", "content": prompt},
]
response = await openai_chatcomplete(messages, **kwargs)
if "choices" not in response:
raise ValueError(f"No choices in response: {response}")
def _print_markdown(self, obj: Any):
    """Display *obj* via the environment's printer, rendered as markdown."""
    environment = env()
    environment.print(obj, format_markdown=True)


class OpenAIEmbeddingAgent(Agent):
"""An agent that uses the OpenAI API to generate a relevance score by cosine similarity between two text embeddings."""

Expand Down
2 changes: 1 addition & 1 deletion ice/apis/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ async def openai_chatcomplete(
"max_tokens": max_tokens,
"n": n,
"logprobs": logprobs,
"top_logprobs": top_logprobs
"top_logprobs": top_logprobs,
}
if logit_bias:
params["logit_bias"] = logit_bias # type: ignore[assignment]
Expand Down

0 comments on commit c065559

Please sign in to comment.