-
Notifications
You must be signed in to change notification settings - Fork 0
/
config.py
84 lines (69 loc) · 2.02 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import json
import sys
from typing import Optional

from pydantic import BaseModel, ValidationError

from services.llms.model_selector import ModelProvider, ModelSelector
# Default path of the JSON parameters file opened by retrieve_parameters().
# It is a relative path, so it resolves against the current working directory.
PARAMETERS_FILE = "params.json"
class CheckpointerAuthParams(BaseModel):
    """Credentials for the checkpointer connection (``CheckpointerParams.auth_params``)."""

    username: str
    # NOTE(review): a plain ``str`` password is included in the model's repr and in
    # validation error output; consider pydantic's ``SecretStr`` — confirm with callers.
    password: str
    # Presumably toggles an SSL/TLS connection to the backend — confirm with the client code.
    ssl: bool
class CheckpointerParams(BaseModel):
    """Connection settings for the checkpointer backend (``ParametrizationAgent.checkpointer``)."""

    endpoint: str
    port: int
    # NOTE(review): a numbered database index, presumably a Redis-style DB number — confirm.
    db_number: int
    # Optional credentials; ``None`` when the section is omitted from the config.
    auth_params: Optional[CheckpointerAuthParams] = None
class WebRetrieverParams(BaseModel):
    """Settings for the researcher worker's web retriever (``ResearcherWorkerParams.web_retriever``)."""

    enabled: bool
    urls: list[str]
    is_advanced_search: bool
    # Upper bound on retrieved resources, presumably per query — confirm with the retriever code.
    max_number_of_resources: int
class KdbRetrieverParams(BaseModel):
    """Settings for the researcher worker's "kdb" retriever (``ResearcherWorkerParams.kdb_retriever_params``).

    NOTE(review): "kdb" presumably stands for a knowledge base; confirm against the
    retriever implementation.
    """

    kdb_id: str
    kdb_max_number_of_results: int
    kdb_region: str
class ToolsParams(BaseModel):
    """One tool available to a worker (an entry of ``WokerParams.tools``)."""

    name: str
    description: str
    # Free-form dict, not validated beyond being a dict. Presumably named ``tool_schema``
    # rather than ``schema`` to avoid clashing with BaseModel's own ``schema`` attribute.
    tool_schema: dict
    endpoint: str
    # Free-form dict (not validated further); presumably settings for calling ``endpoint``.
    endpoint_config: dict
class WokerParams(BaseModel):
    """Configuration of one worker listed in ``ParametrizationAgent.workers``.

    NOTE(review): "Woker" looks like a typo for "Worker". The class name is kept
    because ParametrizationAgent (and possibly other modules) reference it.
    """

    id: str
    name: str
    task: str
    tools: list[ToolsParams]
class AssistantWorkerParams(BaseModel):
    """Configuration of the assistant worker (``ParametrizationAgent.assistant_worker``)."""

    id: str
    # Free-text task description; presumably fed to the worker's prompt — confirm.
    task: str
class ResearcherWorkerParams(BaseModel):
    """Configuration of the researcher worker and its web and "kdb" retrievers."""

    id: str
    task: str
    enabled: bool
    web_retriever: WebRetrieverParams
    # NOTE(review): named ``kdb_retriever_params`` while its sibling is ``web_retriever``.
    # Field names are the keys expected in params.json, so renaming would break configs.
    kdb_retriever_params: KdbRetrieverParams
class ModelProviderParams(BaseModel):
    """Model provider selection and settings (``ParametrizationAgent.model_provider``)."""

    # Validated against ModelProvider from services.llms.model_selector
    # (presumably an enum of supported providers — confirm).
    provider: ModelProvider
    # NOTE(review): pydantic v2 releases before 2.10 emit a UserWarning for fields that
    # start with "model_" (protected namespace) — confirm the installed pydantic version.
    model_id: str
    temperature: float
    # Optional free-form provider-specific settings; ``None`` when omitted.
    provider_args: Optional[dict] = None
class ParametrizationAgent(BaseModel):
    """Root schema of the parameters file, built by validate_parametrization_file()."""

    # Optional section; ``None`` when the config omits it.
    # NOTE(review): the "model_" prefix may trigger pydantic v2's protected-namespace
    # warning on older v2 releases — confirm the installed version.
    model_provider: Optional[ModelProviderParams] = None
    checkpointer: CheckpointerParams
    workers: list[WokerParams]
    assistant_worker: AssistantWorkerParams
    researcher_worker: ResearcherWorkerParams
def validate_parametrization_file(json_data):
    """Validate decoded parameters data against the ParametrizationAgent schema.

    Args:
        json_data: The decoded JSON object (a mapping of field names to values),
            e.g. the result of ``json.load`` on the parameters file.

    Returns:
        The validated ``ParametrizationAgent`` instance.

    Raises:
        pydantic.ValidationError: if the data does not match the schema. A short
            diagnostic is written to stderr before the error is re-raised.
        TypeError: if ``json_data`` is not a mapping (e.g. a top-level JSON list),
            because it is unpacked with ``**``.
    """
    try:
        parametrization = ParametrizationAgent(**json_data)
    except ValidationError as e:
        # Diagnostics go to stderr so they never mix with the program's normal output.
        print("the provided config format is not valid: ", e, file=sys.stderr)
        # Bare ``raise`` re-raises the active exception unchanged.
        raise
    return parametrization
def retrieve_parameters(path: Optional[str] = None):
    """Load the JSON parameters file and validate it.

    Args:
        path: Location of the JSON parameters file. When ``None`` (the default),
            ``PARAMETERS_FILE`` is used. It is looked up at call time, so it
            resolves against the current working directory.

    Returns:
        The validated ``ParametrizationAgent`` instance.

    Raises:
        OSError: if the file cannot be opened (e.g. ``FileNotFoundError``).
        json.JSONDecodeError: if the file content is not valid JSON.
        pydantic.ValidationError: if the content does not match the schema
            (re-raised by validate_parametrization_file).
    """
    if path is None:
        path = PARAMETERS_FILE
    # JSON text is UTF-8; be explicit so decoding does not depend on the platform locale.
    with open(path, "r", encoding="utf-8") as file:
        data = json.load(file)
    # Validation only needs the parsed data, so it runs after the file is closed.
    return validate_parametrization_file(data)
if __name__ == "__main__":
    # Running this module directly loads and validates the parameters file. The
    # result is discarded, so success means simply that no exception was raised.
    retrieve_parameters()