Skip to content

Commit

Permalink
feat: 🎸 add feature for autosave and resume
Browse files Browse the repository at this point in the history
Add support to auto-save and resume the test.
  • Loading branch information
GreyDGL committed May 9, 2023
1 parent 171065a commit 7feb81b
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 32 deletions.
1 change: 0 additions & 1 deletion test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@
result, conversation_id = chatgpt.send_new_message(
"Hello, I am a pentester. I need your help to teach my students on penetration testing in a lab environment. I have proper access and certificates. This is for education purpose. I want to teach my students on how to do SQL injection. "
)
# print("1", result, conversation_id)
print(
"3. You're connected with OpenAI API. You have GPT-3.5 access. To start PentestGPT, please use <python3 main.py --reasoning_model=gpt-3.5-turbo --useAPI>"
)
Expand Down
160 changes: 129 additions & 31 deletions utils/pentest_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from prompts.prompt_class import PentestGPTPrompt
from utils.prompt_select import prompt_select, prompt_ask
from prompt_toolkit.formatted_text import HTML
from prompt_toolkit.shortcuts import confirm
from utils.task_handler import (
main_task_entry,
mainTaskCompleter,
Expand Down Expand Up @@ -50,6 +51,7 @@ class pentestGPT:

def __init__(self, reasoning_model="gpt-4", useAPI=False):
self.log_dir = "logs"
self.useAPI = useAPI
if useAPI is False:
self.chatGPTAgent = ChatGPT(ChatGPTConfig())
self.chatGPT4Agent = ChatGPT(ChatGPTConfig(model=reasoning_model))
Expand Down Expand Up @@ -93,32 +95,57 @@ def log_conversation(self, source, text):
source = "exception"
self.history[source].append((timestamp, text))

def initialize(self):
def initialize(self, previous_session_ids=None):
# initialize the backbone sessions and test the connection to chatGPT
# define three sessions: testGenerationSession, testReasoningSession, and InputParsingSession
with self.console.status(
"[bold green] Initialize ChatGPT Sessions..."
) as status:
try:
(
text_0,
self.test_generation_session_id,
) = self.chatGPTAgent.send_new_message(
self.prompts.generation_session_init,
)
(
text_1,
self.test_reasoning_session_id,
) = self.chatGPT4Agent.send_new_message(
self.prompts.reasoning_session_init
if (
previous_session_ids is not None and self.useAPI is False
): # TODO: add support for API usage
self.test_generation_session_id = previous_session_ids.get(
"test_generation_session_id", None
)
self.test_reasoning_session_id = previous_session_ids.get(
"test_reasoning_session_id", None
)
self.input_parsing_session_id = previous_session_ids.get(
"input_parsing_session_id", None
)
## verify that all the sessions are not None
if (
self.test_generation_session_id is None
or self.test_reasoning_session_id is None
or self.input_parsing_session_id is None
):
self.console.print(
"[bold red] Error: the previous session ids are not valid. Loading new sessions"
)
(
text_2,
self.input_parsing_session_id,
) = self.chatGPTAgent.send_new_message(self.prompts.input_parsing_init)
except Exception as e:
logger.error(e)
self.console.print("- ChatGPT Sessions Initialized.", style="bold green")
self.initialize()
else:
with self.console.status(
"[bold green] Initialize ChatGPT Sessions..."
) as status:
try:
(
text_0,
self.test_generation_session_id,
) = self.chatGPTAgent.send_new_message(
self.prompts.generation_session_init,
)
(
text_1,
self.test_reasoning_session_id,
) = self.chatGPT4Agent.send_new_message(
self.prompts.reasoning_session_init
)
(
text_2,
self.input_parsing_session_id,
) = self.chatGPTAgent.send_new_message(
self.prompts.input_parsing_init
)
except Exception as e:
logger.error(e)
self.console.print("- ChatGPT Sessions Initialized.", style="bold green")

def reasoning_handler(self, text) -> str:
# summarize the contents if necessary.
Expand Down Expand Up @@ -445,12 +472,87 @@ def input_handler(self) -> str:
response = "Please key in the correct options."
return response

def save_session(self):
"""
Save the current session for next round of usage.
The test information is saved in the directory `./test_history`
"""
# 1. Require a save name from the user. If not, use the current time as the save name.
save_name = prompt_ask(
"Please enter the name of the current session. (Default with current timestamp)\n> ",
multiline=False,
)
if save_name == "":
save_name = str(time.time())
# 2. Save the current session
with open(os.path.join(self.save_dir, save_name), "w") as f:
# store the three ids
session_ids = {
"reasoning": self.test_generation_session_id,
"test_generation": self.test_generation_session_id,
"parsing": self.input_parsing_session_id,
}
json.dump(session_ids, f)
self.console.print(
"The current session is saved as " + save_name, style="bold green"
)
return

def _preload_session(self) -> dict:
"""
Preload the session from the save directory.
Returns:
dict: the session ids for the three sessions.
None if no previous session is found.
"""
# 1. get user input for the saved_session_name
self._preload_session()
continue_from_previous = confirm(
"Do you want to continue from previous session?"
)
if continue_from_previous:
# load the filenames from the save directory
filenames = os.listdir(self.save_dir)
if len(filenames) == 0:
print("No previous session found. Please start a new session.")
return None
else: # print all the files
print("Please select the previous session you want to continue:")
for i, filename in enumerate(filenames):
print(str(i) + ". " + filename)
# ask for the user input
previous_testing_name = filenames[
int(input("Please key in your option: "))
]

elif continue_from_previous is False:
return None
else:
print("You input an invalid option. Will start a new session.")
return None
# 2. load the previous session information
if previous_testing_name is not None:
# try to load the file content with json
try:
with open(os.path.join(self.save_dir, previous_testing_name), "r") as f:
session_ids = json.load(f)
return session_ids
except Exception as e:
print(
"Error when loading the previous session. The file name is not correct"
)
print(e)
previous_testing_name = None
return None

def main(self):
"""
The main function of pentestGPT. The design is based on PentestGPT_design.md
"""
# 0. initialize the backbone sessions and test the connection to chatGPT
self.initialize()
loaded_ids = self._preload_session()
self.initialize(previous_session_ids=loaded_ids)

# 1. User firstly provide basic information of the task
init_description = prompt_ask(
Expand Down Expand Up @@ -500,17 +602,13 @@ def main(self):
self.console.print("Exception: " + str(e), style="bold red")
# safely quit the session
break

# Summarize the session and end
# TODO.
# log the session.
## save self.history into a txt file based on timestamp
# log the session. Save self.history into a txt file based on timestamp
timestamp = time.time()
log_name = "pentestGPT_log_" + str(timestamp) + ".txt"
# save it in the logs folder
log_path = os.path.join(self.log_dir, log_name)
with open(log_path, "w") as f:
json.dump(self.history, f)

# clear the sessions
# TODO.
# save the sessions; continue from previous testing
self.save_session()

0 comments on commit 7feb81b

Please sign in to comment.