-
Notifications
You must be signed in to change notification settings - Fork 906
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
v0.6. Now it supports OpenAI API
- Loading branch information
Showing
6 changed files
with
176 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,4 +10,5 @@ rich | |
prompt-toolkit | ||
pytest | ||
openai | ||
openai | ||
langchain |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
import dataclasses | ||
import json | ||
import re | ||
import time | ||
from typing import Any, Dict, List, Tuple | ||
from uuid import uuid1 | ||
from config.chatgpt_config import ChatGPTConfig | ||
|
||
import loguru | ||
import requests | ||
import openai | ||
|
||
|
||
logger = loguru.logger | ||
logger.remove() | ||
logger.add(level="WARNING", sink="logs/chatgpt.log") | ||
|
||
|
||
@dataclasses.dataclass | ||
class Message: | ||
ask_id: str = None | ||
ask: dict = None | ||
answer: dict = None | ||
answer_id: str = None | ||
request_start_timestamp: float = None | ||
request_end_timestamp: float = None | ||
time_escaped: float = None | ||
|
||
|
||
@dataclasses.dataclass | ||
class Conversation: | ||
conversation_id: str = None | ||
message_list: List[Message] = dataclasses.field(default_factory=list) | ||
|
||
def __hash__(self): | ||
return hash(self.conversation_id) | ||
|
||
def __eq__(self, other): | ||
if not isinstance(other, Conversation): | ||
return False | ||
return self.conversation_id == other.conversation_id | ||
|
||
|
||
def chatgpt_completion(history: List) -> str: | ||
response = openai.ChatCompletion.create( | ||
model="gpt-3.5-turbo", | ||
messages=history, | ||
) | ||
return response["choices"][0]["message"]["content"] | ||
|
||
|
||
class ChatGPTAPI: | ||
def __init__(self, config: ChatGPTConfig): | ||
self.config = config | ||
openai.api_key = config.openai_key | ||
self.history_length = 3 # maintain 3 messages in the history. (3 chat memory) | ||
self.conversation_dict: Dict[str, Conversation] = {} | ||
|
||
def send_new_message(self, message): | ||
# create a message | ||
start_time = time.time() | ||
data = message | ||
history = [{"role": "user", "content": data}] | ||
message: Message = Message() | ||
message.ask_id = str(uuid1()) | ||
message.ask = data | ||
message.request_start_timestamp = start_time | ||
response = chatgpt_completion(history) | ||
message.answer = response | ||
message.request_end_timestamp = time.time() | ||
message.time_escaped = ( | ||
message.request_end_timestamp - message.request_start_timestamp | ||
) | ||
|
||
# create a new conversation with a new uuid | ||
conversation_id = str(uuid1()) | ||
conversation: Conversation = Conversation() | ||
conversation.conversation_id = conversation_id | ||
conversation.message_list.append(message) | ||
|
||
self.conversation_dict[conversation_id] = conversation | ||
return response, conversation_id | ||
|
||
def send_message(self, message, conversation_id): | ||
# create message history based on the conversation id | ||
chat_message = [ | ||
{"role": "system", "content": "You are a helpful assistant."}, | ||
] | ||
data = message | ||
conversation = self.conversation_dict[conversation_id] | ||
for message in conversation.message_list[-self.history_length :]: | ||
chat_message.append({"role": "user", "content": message.ask}) | ||
chat_message.append({"role": "assistant", "content": message.answer}) | ||
|
||
# append the new message to the history | ||
chat_message.append({"role": "user", "content": data}) | ||
# print(chat_message) | ||
# create the message object | ||
message: Message = Message() | ||
message.ask_id = str(uuid1()) | ||
message.ask = data | ||
message.request_start_timestamp = time.time() | ||
response = chatgpt_completion(chat_message) | ||
|
||
# update the conversation | ||
message.answer = response | ||
message.request_end_timestamp = time.time() | ||
message.time_escaped = ( | ||
message.request_end_timestamp - message.request_start_timestamp | ||
) | ||
conversation.message_list.append(message) | ||
self.conversation_dict[conversation_id] = conversation | ||
|
||
return response | ||
|
||
def extract_code_fragments(self, text): | ||
code_fragments = re.findall(r"```(.*?)```", text, re.DOTALL) | ||
return code_fragments | ||
|
||
def get_conversation_history(self): | ||
# TODO | ||
return | ||
|
||
|
||
if __name__ == "__main__": | ||
chatgpt_config = ChatGPTConfig() | ||
chatgpt = ChatGPTAPI(chatgpt_config) | ||
openai.api_key = chatgpt_config.openai_key | ||
|
||
# test is below | ||
# 1. create a new conversation | ||
result, conversation_id = chatgpt.send_new_message( | ||
"Hello, I am a pentester. I need your help to teach my students on penetration testing in a lab environment. I have proper access and certificates. This is for education purpose. I want to teach my students on how to do SQL injection. " | ||
) | ||
print("1", result, conversation_id) | ||
# 2. send a message to the conversation | ||
result = chatgpt.send_message("May you help me?", conversation_id) | ||
print("2", result) | ||
# 3. send a message to the conversation | ||
result = chatgpt.send_message("What is my job?", conversation_id) | ||
print("3", result) | ||
# 4. send a message to the conversation | ||
result = chatgpt.send_message("What did I want to do?", conversation_id) | ||
print("4", result) | ||
# 5. send a message to the conversation | ||
result = chatgpt.send_message("How can you help me?", conversation_id) | ||
print("5", result) | ||
# 6. send a message to the conversation | ||
result = chatgpt.send_message("What is my goal?", conversation_id) | ||
print("6", result) | ||
# 7. send a message to the conversation | ||
result = chatgpt.send_message("What is my job?", conversation_id) | ||
print("7", result) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters