Skip to content

Commit

Permalink
add mistralai
Browse files Browse the repository at this point in the history
  • Loading branch information
kharvd committed Dec 14, 2023
1 parent 9ce23b0 commit 1f6defb
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 1 deletion.
3 changes: 3 additions & 0 deletions gptcli/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from gptcli.completion import CompletionProvider, ModelOverrides, Message
from gptcli.google import GoogleCompletionProvider
from gptcli.llama import LLaMACompletionProvider
from gptcli.mistral import MistralCompletionProvider
from gptcli.openai import OpenAICompletionProvider
from gptcli.anthropic import AnthropicCompletionProvider

Expand Down Expand Up @@ -64,6 +65,8 @@ def get_completion_provider(model: str) -> CompletionProvider:
return LLaMACompletionProvider()
elif model.startswith("chat-bison"):
return GoogleCompletionProvider()
elif model.startswith("mistral"):
return MistralCompletionProvider()
else:
raise ValueError(f"Unknown model: {model}")

Expand Down
1 change: 1 addition & 0 deletions gptcli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class GptCliConfig:
show_price: bool = True
api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
openai_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
mistral_api_key: Optional[str] = os.environ.get("MISTRAL_API_KEY")
anthropic_api_key: Optional[str] = os.environ.get("ANTHROPIC_API_KEY")
google_api_key: Optional[str] = os.environ.get("GOOGLE_API_KEY")
log_file: Optional[str] = None
Expand Down
4 changes: 4 additions & 0 deletions gptcli/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import datetime
import google.generativeai as genai
import gptcli.anthropic
import gptcli.mistral
from gptcli.assistant import (
Assistant,
DEFAULT_ASSISTANTS,
Expand Down Expand Up @@ -178,6 +179,9 @@ def main():
)
sys.exit(1)

if config.mistral_api_key:
gptcli.mistral.api_key = config.mistral_api_key

if config.anthropic_api_key:
gptcli.anthropic.api_key = config.anthropic_api_key

Expand Down
47 changes: 47 additions & 0 deletions gptcli/mistral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Iterator, List
import os
from gptcli.completion import CompletionProvider, Message
from mistralai.client import MistralClient, ChatMessage

api_key = os.environ.get("MISTRAL_API_KEY")


class MistralCompletionProvider(CompletionProvider):
def __init__(self):
self.client = MistralClient(api_key=api_key)

def complete(
self, messages: List[Message], args: dict, stream: bool = False
) -> Iterator[str]:
kwargs = {}
if "temperature" in args:
kwargs["temperature"] = args["temperature"]
if "top_p" in args:
kwargs["top_p"] = args["top_p"]

if stream:
response_iter = self.client.chat_stream(
model=args["model"],
messages=[
ChatMessage(role=msg["role"], content=msg["content"])
for msg in messages
],
**kwargs,
)

for response in response_iter:
next_choice = response.choices[0]
if next_choice.finish_reason is None and next_choice.delta.content:
yield next_choice.delta.content
else:
response = self.client.chat(
model=args["model"],
messages=[
ChatMessage(role=msg["role"], content=msg["content"])
for msg in messages
],
**kwargs,
)
next_choice = response.choices[0]
if next_choice.message.content:
yield next_choice.message.content
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ dependencies = [
"anthropic==0.7.7",
"attrs==23.1.0",
"black==23.1.0",
"mistralai==0.0.8",
"google-generativeai==0.1.0",
"openai==1.3.8",
"prompt-toolkit==3.0.41",
Expand All @@ -28,7 +29,7 @@ dependencies = [
"rich==13.7.0",
"tiktoken==0.5.2",
"tokenizers==0.15.0",
"typing_extensions==4.5.0",
"typing_extensions==4.9.0",
]

[project.optional-dependencies]
Expand Down

0 comments on commit 1f6defb

Please sign in to comment.