From 6bdd57662ac746cea562acc5509cba8e480b207f Mon Sep 17 00:00:00 2001 From: Grey_D Date: Fri, 12 May 2023 21:43:35 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=F0=9F=8E=B8=20continue=20from=20previo?= =?UTF-8?q?us=20session?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + utils/chatgpt.py | 5 ++ utils/pentest_gpt.py | 111 ++++++++++++++++++++++++++----------------- 3 files changed, 74 insertions(+), 43 deletions(-) diff --git a/.gitignore b/.gitignore index 02c8d63..5380b95 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ outputs/ logs/ utils/logs/ archive/ +test_history/ # C extensions *.so diff --git a/utils/chatgpt.py b/utils/chatgpt.py index b30b833..f085f28 100644 --- a/utils/chatgpt.py +++ b/utils/chatgpt.py @@ -243,6 +243,11 @@ def send_message(self, message, conversation_id): message.answer = result message.request_end_timestamp = end_time message.time_escaped = end_time - start_time + # add additional logic for reloading (only for PentestGPT continue from previous sessions) + if conversation_id not in self.conversation_dict: + conversation: Conversation = Conversation() + conversation.conversation_id = conversation_id + self.conversation_dict[conversation_id] = conversation conversation: Conversation = self.conversation_dict[conversation_id] conversation.message_list.append(message) return text diff --git a/utils/pentest_gpt.py b/utils/pentest_gpt.py index 6554af9..e5bcc7f 100644 --- a/utils/pentest_gpt.py +++ b/utils/pentest_gpt.py @@ -51,6 +51,8 @@ class pentestGPT: def __init__(self, reasoning_model="gpt-4", useAPI=False): self.log_dir = "logs" + self.save_dir = "test_history" + self.task_log = {} # the information that can be saved to continue in the next session self.useAPI = useAPI if useAPI is False: self.chatGPTAgent = ChatGPT(ChatGPTConfig()) @@ -95,6 +97,40 @@ def log_conversation(self, source, text): source = "exception" self.history[source].append((timestamp, text)) + def _feed_init_prompts(self): + # 1. User firstly provide basic information of the task + init_description = prompt_ask( + "Please describe the penetration testing task in one line, including the target IP, task type, etc.\n> ", + multiline=False, + ) + self.log_conversation("user", init_description) + self.task_log['task description'] = init_description + ## Provide the information to the reasoning session for the task initialization. + prefixed_init_description = self.prompts.task_description + init_description + with self.console.status( + "[bold green] Generating Task Information..." + ) as status: + _response = self.reasoning_handler(prefixed_init_description) + self.console.print("- Task information generated. \n", style="bold green") + # 2. Reasoning session generates the first thing to do and provide the information to the generation session + with self.console.status("[bold green]Processing...") as status: + first_generation_response = self.test_generation_handler( + self.prompts.todo_to_command + self.prompts.first_todo + ) + # 3. Show user the first thing to do. + self.console.print( + "PentestGPT suggests you to do the following: ", style="bold green" + ) + self.console.print(_response) + self.log_conversation( + "PentestGPT", "PentestGPT suggests you to do the following: \n" + _response + ) + self.console.print("You may start with:", style="bold green") + self.console.print(first_generation_response) + self.log_conversation( + "PentestGPT", "You may start with: \n" + first_generation_response + ) + def initialize(self, previous_session_ids=None): # initialize the backbone sessions and test the connection to chatGPT # define three sessions: testGenerationSession, testReasoningSession, and InputParsingSession @@ -102,14 +138,24 @@ def initialize(self, previous_session_ids=None): previous_session_ids is not None and self.useAPI is False ): # TODO: add support for API usage self.test_generation_session_id = previous_session_ids.get( - "test_generation_session_id", None + "test_generation", None ) self.test_reasoning_session_id = previous_session_ids.get( - "test_reasoning_session_id", None + "reasoning", None ) self.input_parsing_session_id = previous_session_ids.get( - "input_parsing_session_id", None + "parsing", None ) + # debug the three sessions + print("Previous session ids: " + str(previous_session_ids)) + print("Test generation session id: " + str(self.test_generation_session_id)) + print("Test reasoning session id: " + str(self.test_reasoning_session_id)) + print("Input parsing session id: " + str(self.input_parsing_session_id)) + print("-----------------") + self.task_log = previous_session_ids.get("task_log", {}) + self.console.print("Task log: " + str(self.task_log), style="bold green") + print("You may use discussion function to remind yourself of the task.") + ## verify that all the sessions are not None if ( self.test_generation_session_id is None @@ -120,6 +166,7 @@ def initialize(self, previous_session_ids=None): "[bold red] Error: the previous session ids are not valid. Loading new sessions" ) self.initialize() + else: with self.console.status( "[bold green] Initialize ChatGPT Sessions..." @@ -146,6 +193,8 @@ def initialize(self, previous_session_ids=None): except Exception as e: logger.error(e) self.console.print("- ChatGPT Sessions Initialized.", style="bold green") + self._feed_init_prompts() + def reasoning_handler(self, text) -> str: # summarize the contents if necessary. @@ -380,7 +429,6 @@ def input_handler(self) -> str: self.log_conversation("pentestGPT", response) ### (2.3) local task handler - while True: local_task_response = self.local_input_handler() if local_task_response == "continue": @@ -432,6 +480,7 @@ def input_handler(self) -> str: ## (2) pass the information to the reasoning session. with self.console.status("[bold green] PentestGPT Thinking...") as status: response = self.reasoning_handler(self.prompts.discussion + user_input) + print("debug, finished reasoning") ## (3) print the results self.console.print("PentestGPT:\n", style="bold green") self.console.print(response + "\n", style="yellow") @@ -477,6 +526,7 @@ def save_session(self): Save the current session for next round of usage. The test information is saved in the directory `./test_history` """ + self.console.print("Before you quit, you may want to save the current session.", style="bold green") # 1. Require a save name from the user. If not, use the current time as the save name. save_name = prompt_ask( "Please enter the name of the current session. (Default with current timestamp)\n> ", @@ -486,11 +536,12 @@ def save_session(self): save_name = str(time.time()) # 2. Save the current session with open(os.path.join(self.save_dir, save_name), "w") as f: - # store the three ids + # store the three ids and task_log session_ids = { - "reasoning": self.test_generation_session_id, + "reasoning": self.test_reasoning_session_id, "test_generation": self.test_generation_session_id, "parsing": self.input_parsing_session_id, + "task_log": self.task_log, } json.dump(session_ids, f) self.console.print( @@ -507,7 +558,6 @@ def _preload_session(self) -> dict: None if no previous session is found. """ # 1. get user input for the saved_session_name - self._preload_session() continue_from_previous = confirm( "Do you want to continue from previous session?" ) @@ -518,13 +568,18 @@ def _preload_session(self) -> dict: print("No previous session found. Please start a new session.") return None else: # print all the files - print("Please select the previous session you want to continue:") + print("Please select the previous session by its index (integer):") for i, filename in enumerate(filenames): print(str(i) + ". " + filename) # ask for the user input - previous_testing_name = filenames[ - int(input("Please key in your option: ")) - ] + try: + previous_testing_name = filenames[ + int(input("Please key in your option (integer): ")) + ] + print("You selected: " + previous_testing_name) + except ValueError as e: + print("You input an invalid option. Will start a new session.") + return None elif continue_from_previous is False: return None @@ -554,39 +609,9 @@ def main(self): loaded_ids = self._preload_session() self.initialize(previous_session_ids=loaded_ids) - # 1. User firstly provide basic information of the task - init_description = prompt_ask( - "Please describe the penetration testing task in one line, including the target IP, task type, etc.\n> ", - multiline=False, - ) - self.log_conversation("user", init_description) - ## Provide the information to the reasoning session for the task initialization. - prefixed_init_description = self.prompts.task_description + init_description - with self.console.status( - "[bold green] Generating Task Information..." - ) as status: - _response = self.reasoning_handler(prefixed_init_description) - self.console.print("- Task information generated. \n", style="bold green") - # 2. Reasoning session generates the first thing to do and provide the information to the generation session - with self.console.status("[bold green]Processing...") as status: - first_generation_response = self.test_generation_handler( - self.prompts.todo_to_command + self.prompts.first_todo - ) - # 3. Show user the first thing to do. - self.console.print( - "PentestGPT suggests you to do the following: ", style="bold green" - ) - self.console.print(_response) - self.log_conversation( - "PentestGPT", "PentestGPT suggests you to do the following: \n" + _response - ) - self.console.print("You may start with:", style="bold green") - self.console.print(first_generation_response) - self.log_conversation( - "PentestGPT", "You may start with: \n" + first_generation_response - ) - # 4. enter the main loop. + + # enter the main loop. while True: try: result = self.input_handler()