diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index f910b50cb..b601cbbd9 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -100,34 +100,37 @@ def chat_completion_with_backoff( def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_base_url=None, model_kwargs=None): - client_key = f"{openai_api_key}--{api_base_url}" - if client_key not in openai_clients: - client: openai.OpenAI = openai.OpenAI( - api_key=openai_api_key, - base_url=api_base_url, + try: + client_key = f"{openai_api_key}--{api_base_url}" + if client_key not in openai_clients: + client: openai.OpenAI = openai.OpenAI( + api_key=openai_api_key, + base_url=api_base_url, + ) + openai_clients[client_key] = client + else: + client: openai.OpenAI = openai_clients[client_key] + + formatted_messages = [{"role": message.role, "content": message.content} for message in messages] + + chat = client.chat.completions.create( + stream=True, + messages=formatted_messages, + model=model_name, # type: ignore + temperature=temperature, + timeout=20, + **(model_kwargs or dict()), ) - openai_clients[client_key] = client - else: - client: openai.OpenAI = openai_clients[client_key] - - formatted_messages = [{"role": message.role, "content": message.content} for message in messages] - - chat = client.chat.completions.create( - stream=True, - messages=formatted_messages, - model=model_name, # type: ignore - temperature=temperature, - timeout=20, - **(model_kwargs or dict()), - ) - - for chunk in chat: - if len(chunk.choices) == 0: - continue - delta_chunk = chunk.choices[0].delta - if isinstance(delta_chunk, str): - g.send(delta_chunk) - elif delta_chunk.content: - g.send(delta_chunk.content) - g.close() + for chunk in chat: + if len(chunk.choices) == 0: + continue + delta_chunk = chunk.choices[0].delta + if isinstance(delta_chunk, str): + g.send(delta_chunk) + elif delta_chunk.content: + g.send(delta_chunk.content) + except Exception as e: + logger.error(f"Error in llm_thread: {e}") + finally: + g.close()