@@ -53,7 +53,7 @@ def check_service_status(self) -> BaseResponse:
            return BaseResponse(code=404, msg=f"Code base {self.engine_name} not found")
        return BaseResponse(code=200, msg=f"Code base {self.engine_name} found")
-    def _process(self, query: str, history: List[History], model, llm_config: LLMConfig, embed_config: EmbedConfig):
+    def _process(self, query: str, history: List[History], model, llm_config: LLMConfig, embed_config: EmbedConfig, local_graph_path=""):
        '''process'''

        codes_res = search_code(query=query, cb_name=self.engine_name, code_limit=self.code_limit,
@@ -67,7 +67,8 @@ def _process(self, query: str, history: List[History], model, llm_config: LLMCon
                               embed_model_path=embed_config.embed_model_path,
                               embed_engine=embed_config.embed_engine,
                               model_device=embed_config.model_device,
-                              embed_config=embed_config
+                              embed_config=embed_config,
+                              local_graph_path=local_graph_path
                               )

        context = codes_res['context']
@@ -115,6 +116,7 @@ def chat(
        model_name: str = Body("", ),
        temperature: float = Body(0.5, ),
        model_device: str = Body("", ),
+        local_graph_path: str = Body("", ),
        **kargs
        ):
        params = locals()
@@ -127,9 +129,9 @@ def chat(
        self.local_doc_url = local_doc_url if isinstance(local_doc_url, bool) else local_doc_url.default
        self.request = request
        self.cb_search_type = cb_search_type
-        return self._chat(query, history, llm_config, embed_config, **kargs)
+        return self._chat(query, history, llm_config, embed_config, local_graph_path, **kargs)

-    def _chat(self, query: str, history: List[History], llm_config: LLMConfig, embed_config: EmbedConfig, **kargs):
+    def _chat(self, query: str, history: List[History], llm_config: LLMConfig, embed_config: EmbedConfig, local_graph_path: str, **kargs):
        history = [History(**h) if isinstance(h, dict) else h for h in history]

        service_status = self.check_service_status()
@@ -140,7 +142,7 @@ def chat_iterator(query: str, history: List[History]):
            # model = getChatModel()
            model = getChatModelFromConfig(llm_config)

-            result, content = self.create_task(query, history, model, llm_config, embed_config, **kargs)
+            result, content = self.create_task(query, history, model, llm_config, embed_config, local_graph_path, **kargs)
            # logger.info('result={}'.format(result))
            # logger.info('content={}'.format(content))
@@ -156,9 +158,9 @@ def chat_iterator(query: str, history: List[History]):
        return StreamingResponse(chat_iterator(query, history),
                                 media_type="text/event-stream")

-    def create_task(self, query: str, history: List[History], model, llm_config: LLMConfig, embed_config: EmbedConfig):
+    def create_task(self, query: str, history: List[History], model, llm_config: LLMConfig, embed_config: EmbedConfig, local_graph_path: str):
        '''build the llm generation task'''
-        chain, context, result = self._process(query, history, model, llm_config, embed_config)
+        chain, context, result = self._process(query, history, model, llm_config, embed_config, local_graph_path)
        logger.info('chain={}'.format(chain))
        try:
            content = chain({"context": context, "question": query})
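
The thread of this change: chat accepts a new local_graph_path body field and forwards it through _chat -> create_task -> _process into search_code. Below is a minimal client sketch exercising the new field; the host, port, and /code_chat route are assumptions for illustration, not taken from this diff.

import requests

# Hypothetical client call; only the local_graph_path field comes from this
# change -- the service address and route are illustrative assumptions.
resp = requests.post(
    "http://127.0.0.1:7861/code_chat",
    json={
        "query": "Where is the service status checked?",
        "history": [],
        "local_graph_path": "/data/code_graphs",  # new field added by this diff
    },
    stream=True,  # the handler returns a text/event-stream StreamingResponse
)
for chunk in resp.iter_content(chunk_size=None):
    print(chunk.decode("utf-8"), end="")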