Skip to content

Commit

Permalink
feat: 🎸 continue from previous session
Browse files Browse the repository at this point in the history
  • Loading branch information
GreyDGL committed May 12, 2023
1 parent 7feb81b commit 6bdd576
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 43 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ outputs/
logs/
utils/logs/
archive/
test_history/

# C extensions
*.so
Expand Down
5 changes: 5 additions & 0 deletions utils/chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,11 @@ def send_message(self, message, conversation_id):
message.answer = result
message.request_end_timestamp = end_time
message.time_escaped = end_time - start_time
# add additional logic for reloading (only for PentestGPT continue from previous sessions)
if conversation_id not in self.conversation_dict:
conversation: Conversation = Conversation()
conversation.conversation_id = conversation_id
self.conversation_dict[conversation_id] = conversation
conversation: Conversation = self.conversation_dict[conversation_id]
conversation.message_list.append(message)
return text
Expand Down
111 changes: 68 additions & 43 deletions utils/pentest_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ class pentestGPT:

def __init__(self, reasoning_model="gpt-4", useAPI=False):
self.log_dir = "logs"
self.save_dir = "test_history"
self.task_log = {} # the information that can be saved to continue in the next session
self.useAPI = useAPI
if useAPI is False:
self.chatGPTAgent = ChatGPT(ChatGPTConfig())
Expand Down Expand Up @@ -95,21 +97,65 @@ def log_conversation(self, source, text):
source = "exception"
self.history[source].append((timestamp, text))

def _feed_init_prompts(self):
# 1. User firstly provide basic information of the task
init_description = prompt_ask(
"Please describe the penetration testing task in one line, including the target IP, task type, etc.\n> ",
multiline=False,
)
self.log_conversation("user", init_description)
self.task_log['task description'] = init_description
## Provide the information to the reasoning session for the task initialization.
prefixed_init_description = self.prompts.task_description + init_description
with self.console.status(
"[bold green] Generating Task Information..."
) as status:
_response = self.reasoning_handler(prefixed_init_description)
self.console.print("- Task information generated. \n", style="bold green")
# 2. Reasoning session generates the first thing to do and provide the information to the generation session
with self.console.status("[bold green]Processing...") as status:
first_generation_response = self.test_generation_handler(
self.prompts.todo_to_command + self.prompts.first_todo
)
# 3. Show user the first thing to do.
self.console.print(
"PentestGPT suggests you to do the following: ", style="bold green"
)
self.console.print(_response)
self.log_conversation(
"PentestGPT", "PentestGPT suggests you to do the following: \n" + _response
)
self.console.print("You may start with:", style="bold green")
self.console.print(first_generation_response)
self.log_conversation(
"PentestGPT", "You may start with: \n" + first_generation_response
)

def initialize(self, previous_session_ids=None):
# initialize the backbone sessions and test the connection to chatGPT
# define three sessions: testGenerationSession, testReasoningSession, and InputParsingSession
if (
previous_session_ids is not None and self.useAPI is False
): # TODO: add support for API usage
self.test_generation_session_id = previous_session_ids.get(
"test_generation_session_id", None
"test_generation", None
)
self.test_reasoning_session_id = previous_session_ids.get(
"test_reasoning_session_id", None
"reasoning", None
)
self.input_parsing_session_id = previous_session_ids.get(
"input_parsing_session_id", None
"parsing", None
)
# debug the three sessions
print("Previous session ids: " + str(previous_session_ids))
print("Test generation session id: " + str(self.test_generation_session_id))
print("Test reasoning session id: " + str(self.test_reasoning_session_id))
print("Input parsing session id: " + str(self.input_parsing_session_id))
print("-----------------")
self.task_log = previous_session_ids.get("task_log", {})
self.console.print("Task log: " + str(self.task_log), style="bold green")
print("You may use discussion function to remind yourself of the task.")

## verify that all the sessions are not None
if (
self.test_generation_session_id is None
Expand All @@ -120,6 +166,7 @@ def initialize(self, previous_session_ids=None):
"[bold red] Error: the previous session ids are not valid. Loading new sessions"
)
self.initialize()

else:
with self.console.status(
"[bold green] Initialize ChatGPT Sessions..."
Expand All @@ -146,6 +193,8 @@ def initialize(self, previous_session_ids=None):
except Exception as e:
logger.error(e)
self.console.print("- ChatGPT Sessions Initialized.", style="bold green")
self._feed_init_prompts()


def reasoning_handler(self, text) -> str:
# summarize the contents if necessary.
Expand Down Expand Up @@ -380,7 +429,6 @@ def input_handler(self) -> str:
self.log_conversation("pentestGPT", response)

### (2.3) local task handler

while True:
local_task_response = self.local_input_handler()
if local_task_response == "continue":
Expand Down Expand Up @@ -432,6 +480,7 @@ def input_handler(self) -> str:
## (2) pass the information to the reasoning session.
with self.console.status("[bold green] PentestGPT Thinking...") as status:
response = self.reasoning_handler(self.prompts.discussion + user_input)
print("debug, finished reasoning")
## (3) print the results
self.console.print("PentestGPT:\n", style="bold green")
self.console.print(response + "\n", style="yellow")
Expand Down Expand Up @@ -477,6 +526,7 @@ def save_session(self):
Save the current session for next round of usage.
The test information is saved in the directory `./test_history`
"""
self.console.print("Before you quit, you may want to save the current session.", style="bold green")
# 1. Require a save name from the user. If not, use the current time as the save name.
save_name = prompt_ask(
"Please enter the name of the current session. (Default with current timestamp)\n> ",
Expand All @@ -486,11 +536,12 @@ def save_session(self):
save_name = str(time.time())
# 2. Save the current session
with open(os.path.join(self.save_dir, save_name), "w") as f:
# store the three ids
# store the three ids and task_log
session_ids = {
"reasoning": self.test_generation_session_id,
"reasoning": self.test_reasoning_session_id,
"test_generation": self.test_generation_session_id,
"parsing": self.input_parsing_session_id,
"task_log": self.task_log,
}
json.dump(session_ids, f)
self.console.print(
Expand All @@ -507,7 +558,6 @@ def _preload_session(self) -> dict:
None if no previous session is found.
"""
# 1. get user input for the saved_session_name
self._preload_session()
continue_from_previous = confirm(
"Do you want to continue from previous session?"
)
Expand All @@ -518,13 +568,18 @@ def _preload_session(self) -> dict:
print("No previous session found. Please start a new session.")
return None
else: # print all the files
print("Please select the previous session you want to continue:")
print("Please select the previous session by its index (integer):")
for i, filename in enumerate(filenames):
print(str(i) + ". " + filename)
# ask for the user input
previous_testing_name = filenames[
int(input("Please key in your option: "))
]
try:
previous_testing_name = filenames[
int(input("Please key in your option (integer): "))
]
print("You selected: " + previous_testing_name)
except ValueError as e:
print("You input an invalid option. Will start a new session.")
return None

elif continue_from_previous is False:
return None
Expand Down Expand Up @@ -554,39 +609,9 @@ def main(self):
loaded_ids = self._preload_session()
self.initialize(previous_session_ids=loaded_ids)

# 1. User firstly provide basic information of the task
init_description = prompt_ask(
"Please describe the penetration testing task in one line, including the target IP, task type, etc.\n> ",
multiline=False,
)
self.log_conversation("user", init_description)
## Provide the information to the reasoning session for the task initialization.
prefixed_init_description = self.prompts.task_description + init_description
with self.console.status(
"[bold green] Generating Task Information..."
) as status:
_response = self.reasoning_handler(prefixed_init_description)
self.console.print("- Task information generated. \n", style="bold green")
# 2. Reasoning session generates the first thing to do and provide the information to the generation session
with self.console.status("[bold green]Processing...") as status:
first_generation_response = self.test_generation_handler(
self.prompts.todo_to_command + self.prompts.first_todo
)
# 3. Show user the first thing to do.
self.console.print(
"PentestGPT suggests you to do the following: ", style="bold green"
)
self.console.print(_response)
self.log_conversation(
"PentestGPT", "PentestGPT suggests you to do the following: \n" + _response
)
self.console.print("You may start with:", style="bold green")
self.console.print(first_generation_response)
self.log_conversation(
"PentestGPT", "You may start with: \n" + first_generation_response
)

# 4. enter the main loop.

# enter the main loop.
while True:
try:
result = self.input_handler()
Expand Down

0 comments on commit 6bdd576

Please sign in to comment.