-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathfactory.py
31 lines (27 loc) · 1.04 KB
/
factory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from .lua_generate import LuaGenerator
from .py_generate import PyGenerator
from .rs_generate import RsGenerator
from .generator_types import Generator
from .model import ModelBase, GPT4, GPT35, StarChat, GPTDavinci, WizardCoder
def generator_factory(lang: str) -> Generator:
    """Build the code generator that matches a language identifier.

    Args:
        lang: Language identifier. Accepted values are "py"/"python",
            "rs"/"rust" and "lua"; matching is exact and case-sensitive.

    Returns:
        A newly constructed generator instance for that language.

    Raises:
        ValueError: If ``lang`` is not one of the accepted identifiers.
    """
    if lang in ("py", "python"):
        return PyGenerator()
    if lang in ("rs", "rust"):
        return RsGenerator()
    if lang == "lua":
        return LuaGenerator()
    raise ValueError(f"Invalid language for generator: {lang}")
def model_factory(model_name: str) -> ModelBase:
    """Instantiate the model wrapper registered under ``model_name``.

    Accepted names (exact, case-sensitive, checked in this order):
      * "gpt-4"                            -> GPT4()
      * "gpt-3.5-turbo"                    -> GPT35()
      * "starchat" or "star-chat"          -> StarChat()
      * "wizardcoder" or "wizard-coder"    -> WizardCoder()
      * any name beginning "text-davinci"  -> GPTDavinci(model_name)

    Args:
        model_name: Identifier of the model to construct.

    Returns:
        A newly constructed model wrapper.

    Raises:
        ValueError: If ``model_name`` matches none of the accepted names.
    """
    if model_name == "gpt-4":
        return GPT4()
    if model_name == "gpt-3.5-turbo":
        return GPT35()
    if model_name in ("starchat", "star-chat"):
        return StarChat()
    if model_name in ("wizardcoder", "wizard-coder"):
        return WizardCoder()
    if model_name.startswith("text-davinci"):
        # Unlike the branches above, this is a prefix match, so the full
        # name is passed through to GPTDavinci.
        return GPTDavinci(model_name)
    raise ValueError(f"Invalid model name: {model_name}")