Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

Permalink
[NeuralChat] Refactor ut server cases and improve code coverage (#914)
Browse files Browse the repository at this point in the history
  • Loading branch information
lvliang-intel authored Dec 14, 2023
1 parent 6641e04 commit 351dc6f
Show file tree
Hide file tree
Showing 25 changed files with 523 additions and 828 deletions.
29 changes: 0 additions & 29 deletions .github/workflows/script/unitTest/run_unit_test_neuralchat.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,38 +27,9 @@ function pytest() {
ut_log_name=${LOG_DIR}/${JOB_NAME}.log
export GLOG_minloglevel=2

# Kill the neuralchat server processes
ports="5000 6000 6001 6060 7000 7070 7777 8000 8080 9000 9090"
# Loop through each port and find associated PIDs
for port in $ports; do
# Use lsof to find the processes associated with the port
pids=$(lsof -ti :$port)
if [ -n "$pids" ]; then
echo "Processes running on port $port: $pids"
# Terminate the processes gracefully with SIGTERM
kill $pids
echo "Terminated processes on port $port."
else
echo "No processes found on port $port."
fi
done

itrex_path=$(python -c 'import intel_extension_for_transformers; import os; print(os.path.dirname(intel_extension_for_transformers.__file__))')
find . -name "test*.py" | sed 's,\.\/,coverage run --source='"${itrex_path}"' --append ,g' | sed 's/$/ --verbose/' >> run.sh
sort run.sh -o run.sh
echo -e '
ports="5000 6000 6001 6060 7000 7070 7777 8000 8080 9000 9090"
for port in $ports; do
pids=$(lsof -ti :$port)
if [ -n "$pids" ]; then
echo "Processes running on port $port: $pids"
kill $pids
echo "Terminated processes on port $port."
else
echo "No processes found on port $port."
fi
done
' >> run.sh
coverage erase

# run UT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(self) -> None:
super().__init__()
self.chatbot = None

def set_chatbot(self, bot, use_deepspeed, world_size, host, port) -> None:
def set_chatbot(self, bot, use_deepspeed=False, world_size=1, host="0.0.0.0", port="80") -> None:
self.chatbot = bot
self.use_deepspeed = use_deepspeed
self.world_size = world_size
Expand All @@ -68,13 +68,13 @@ def get_chatbot(self):
if self.chatbot is None:
raise RuntimeError("Retrievalbot instance has not been set.")
return self.chatbot

def handle_retrieval_request(self, request: RetrievalRequest) -> RetrievalResponse:
bot = self.get_chatbot()
# TODO: NeuralChatBot.retrieve_model()
result = bot.predict(request)
return RetrievalResponse(content=result)


router = RetrievalAPIRouter()
RETRIEVAL_FILE_PATH = os.getenv("RETRIEVAL_FILE_PATH", default="./photoai_retrieval_docs")+'/'
Expand Down Expand Up @@ -103,11 +103,10 @@ async def retrieval_upload_link(request: Request):
instance = plugins['retrieval']["instance"]
instance.append_localdb(append_path=link_list, persist_path=persist_path)
print(f"[askdoc - upload_link] kb appended successfully")
except Exception as e:
except Exception as e: # pragma: no cover
logger.info(f"[askdoc - upload_link] create knowledge base failes! {e}")
return Response(content="Error occurred while uploading links.", status_code=500)
return {"Succeed"}

# create new kb with link
else:
print(f"[askdoc - upload_link] create")
Expand All @@ -119,7 +118,7 @@ async def retrieval_upload_link(request: Request):
cur_path = Path(path_prefix) / f"{user_id}-{kb_id}"
os.makedirs(path_prefix, exist_ok=True)
cur_path.mkdir(parents=True, exist_ok=True)

user_upload_dir = Path(path_prefix) / f"{user_id}-{kb_id}/upload_dir"
user_persist_dir = Path(path_prefix) / f"{user_id}-{kb_id}/persist_dir"
user_upload_dir.mkdir(parents=True, exist_ok=True)
Expand All @@ -132,11 +131,10 @@ async def retrieval_upload_link(request: Request):
instance = plugins['retrieval']["instance"]
instance.create(input_path=link_list, persist_dir=str(user_persist_dir))
print(f"[askdoc - upload_link] kb created successfully")
except Exception as e:
except Exception as e: # pragma: no cover
logger.info(f"[askdoc - upload_link] create knowledge base failes! {e}")
return "Error occurred while uploading files."
return {"knowledge_base_id": kb_id}


@router.post("/v1/aiphotos/askdoc/create")
async def retrieval_create(request: Request,
Expand Down Expand Up @@ -178,8 +176,8 @@ async def retrieval_create(request: Request,
instance = plugins['retrieval']["instance"]
instance.create(input_path=str(user_upload_dir), persist_dir=str(user_persist_dir))
print(f"[askdoc - create] kb created successfully")
except Exception as e:
logger.info(f"[askdoc - create] create knowledge base failes! {e}")
except Exception as e: # pragma: no cover
logger.info(f"[askdoc - create] create knowledge base failed! {e}")
return "Error occurred while uploading files."
return {"knowledge_base_id": kb_id}

Expand Down Expand Up @@ -218,7 +216,7 @@ async def retrieval_append(request: Request,
instance = plugins['retrieval']["instance"]
instance.append_localdb(append_path=save_file_name, persist_path=persist_path)
print(f"[askdoc - append] new file successfully appended to kb")
except Exception as e:
except Exception as e: # pragma: no cover
logger.info(f"[askdoc - append] create knowledge base failes! {e}")
return "Error occurred while uploading files."
return "Succeed"
Expand Down Expand Up @@ -326,7 +324,7 @@ def save_chat_feedback_to_db(request: FeedbackRequest) -> None:
try:
with mysql_db.transaction():
mysql_db.insert(sql, None)
except:
except: # pragma: no cover
raise Exception("""Exception occurred when inserting data into MySQL,
please check the db session and your syntax.""")
else:
Expand All @@ -342,8 +340,7 @@ def get_feedback_from_db():
sql = f"SELECT * FROM feedback ;"
try:
feedback_list = mysql_db.fetch_all(sql)

except:
except: # pragma: no cover
raise Exception("""Exception occurred when querying data from MySQL, \
please check the db session and your syntax.""")
else:
Expand Down Expand Up @@ -373,5 +370,3 @@ def data_generator():
data_generator(),
media_type='text/csv',
headers={"Content-Disposition": f"attachment;filename=feedback{cur_time_str}.csv"})


Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class TextChatAPIRouter(APIRouter):
def __init__(self) -> None:
super().__init__()

def set_chatbot(self, chatbot, use_deepspeed, world_size, host, port) -> None:
def set_chatbot(self, chatbot, use_deepspeed=False, world_size=1, host="0.0.0.0", port=80) -> None:
self.chatbot = chatbot
self.use_deepspeed = use_deepspeed
self.world_size = world_size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self) -> None:
super().__init__()
self.chatbot = None

def set_chatbot(self, chatbot, use_deepspeed, world_size, host, port) -> None:
def set_chatbot(self, chatbot, use_deepspeed=False, world_size=1, host="0.0.0.0", port=80) -> None:
self.chatbot = chatbot
self.use_deepspeed = use_deepspeed
self.world_size = world_size
Expand Down

This file was deleted.

This file was deleted.

This file was deleted.

Loading

0 comments on commit 351dc6f

Please sign in to comment.