-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add chat function #118
base: dev
Are you sure you want to change the base?
add chat function #118
Changes from 7 commits
da64e94
7db1729
7019a6f
bfed94a
fe45ecc
2e30a84
c42b0de
0a1e701
bfdcd90
aab8772
17169c2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# AWS Lambda base image with the Python 3.9 runtime.
FROM public.ecr.aws/lambda/python:3.9

# Install dependencies into the task root (-t .) so Lambda can import them.
COPY requirements.txt ./
RUN python3.9 -m pip install --upgrade pip
RUN python3.9 -m pip install -r requirements.txt -t .

# Copy the function source code.
COPY . .

# Command can be overwritten by providing a different command in the template directly.
CMD ["app.lambda_handler"]
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
import json | ||
from tools.utils import DynamoDBManager, OpenAIManager | ||
from tools.response import HttpResponse | ||
from role.role_tetris import TetrisAssistant | ||
import logging | ||
|
||
# ロギングの設定 | ||
logger = logging.getLogger() | ||
logger.setLevel(logging.INFO) | ||
|
||
class ChatHandler:
    """Route an API Gateway event (GET/POST/DELETE) to chat operations.

    Conversation history is persisted in DynamoDB and replies are produced
    by the OpenAI chat API, optionally via function calling.
    """

    def __init__(self, event):
        self.event = event
        self.db_manager = DynamoDBManager()
        self.openai_manager = OpenAIManager()
        self.http_response = HttpResponse()
        self.tetris_assistant = TetrisAssistant()

    def get_message_from_event(self):
        """Extract 'input_text' from the event body.

        Raises:
            ValueError: if the body is missing/empty or lacks 'input_text'.
        """
        body_content = self.event.get('body', None)
        if not body_content:
            raise ValueError("The 'body' field in the event is missing or empty.")
        try:
            return json.loads(body_content)['input_text']
        except KeyError:
            raise ValueError("Invalid input. 'input_text' key is required.")

    def handle_get_request(self):
        """Health-check style endpoint."""
        return self.http_response.success('hello world')

    def handle_delete_request(self):
        """Delete all stored conversation items for a user/character pair.

        NOTE: in a RESTful design DELETE should not carry a body, but the
        identifiers are currently passed in the body for lack of a better
        transport.
        """
        try:
            data = json.loads(self.event["body"])
            user_id = data['identity_id']
            char_name = data['character_name']
            try:
                self.db_manager.delete_items_with_secondary_index(user_id, char_name)
                return self.http_response.success('delete success')
            except Exception as e:
                logger.exception("An error occurred: %s", e)
                return self.http_response.server_error(f'Error during delete operation: {e}')
        except KeyError as e:
            logger.exception("An error occurred: %s", e)
            return self.http_response.client_error(f'Error during delete operation: {e}')

    def gpt_function_call(self, response_data, messages, functions, user_id, char_name, max_order_id, input_text):
        """Run the function requested by the model, feed its result back for a
        second completion, persist the whole exchange, and return the reply."""
        logger.info("function call defined")
        # Execute the function chosen by the model and capture its output.
        # (The 4th return value is superseded by create_function_args below.)
        func_name, args, function_response, _ = self.openai_manager.execute_function_call(response_data)

        # Build the function-call arguments for the second API request.
        function_args = self.openai_manager.create_function_args(func_name, args)
        messages.append({"role": "assistant", "content": None, "function_call": function_args})
        messages.append({"role": "function", "content": function_response, "name": func_name})

        response_2nd = self.openai_manager.get_chat_response_func(messages, functions)
        response_content = response_2nd.choices[0]["message"]["content"]

        # Record the full exchange in DynamoDB in order: user turn,
        # model's function call, function result, final assistant reply.
        self.db_manager.store_conversation(user_id, char_name, max_order_id, "user", input_text)
        self.db_manager.store_conversation(user_id, char_name, max_order_id + 1, "assistant", None, name=None, function_call=function_args)
        self.db_manager.store_conversation(user_id, char_name, max_order_id + 2, "function", function_response, name=func_name, function_call=None)
        self.db_manager.store_conversation(user_id, char_name, max_order_id + 3, "assistant", response_content)

        return response_content

    def gpt_simple_response(self, response_data, user_id, char_name, max_order_id, input_text):
        """Persist a plain (no function call) exchange and return the reply."""
        logger.info("function call undefined")
        response_content = response_data["message"]["content"]
        self.db_manager.store_conversation(user_id, char_name, max_order_id, "user", input_text)
        self.db_manager.store_conversation(user_id, char_name, max_order_id + 1, "assistant", response_content)
        return response_content

    def call_openai_api(self, messages, functions, user_id, char_name, input_text, max_order_id):
        """Run the first completion and dispatch on whether the model
        requested a function call."""
        response_1st = self.openai_manager.get_chat_response_func(messages, functions)
        response_data = response_1st["choices"][0]
        # Bug fix: the original returned None when finish_reason was
        # "function_call" but the function name was empty; fall back to a
        # simple response in that case instead.
        if (response_data["finish_reason"] == "function_call"
                and response_data["message"]["function_call"]["name"]):
            return self.gpt_function_call(response_data, messages, functions, user_id, char_name, max_order_id, input_text)
        return self.gpt_simple_response(response_data, user_id, char_name, max_order_id, input_text)

    def handle_post_request(self):
        """Main chat endpoint: load history, call OpenAI, store and return the reply."""
        self.openai_manager.get_secret()

        # Extract the request payload.
        try:
            data = json.loads(self.event["body"])
            logger.info('Event: %s', json.dumps(data))
            user_id = data['identity_id']
            char_name = data['character_name']
            input_text = data['input_text']
        except (KeyError, ValueError) as e:
            # Bug fix: missing keys raise KeyError, which the original
            # ValueError-only handler let escape as an unhandled error.
            return self.http_response.client_error(f'Error during post operation: {e}')

        # Fetch past conversation and build the message list.
        try:
            messages = self.tetris_assistant.get_chat_messages()
            functions = self.tetris_assistant.get_chat_functions()
            max_order_id, items = self.db_manager.get_max_conversation_id(user_id, char_name)

            # Replay stored turns into messages; "name" and "function_call"
            # are only present for function-call turns.
            messages.extend([
                {
                    "role": item["role"],
                    "content": item["content"],
                    **({"name": item["name"]} if "name" in item else {}),
                    **({"function_call": item["function_call"]} if "function_call" in item else {})
                }
                for item in items
            ])
            messages.append({"role": "user", "content": input_text})

            # Call the OpenAI API and return the reply.
            response_content = self.call_openai_api(messages, functions, user_id, char_name, input_text, max_order_id)
            return self.http_response.success(response_content)

        except Exception as e:
            return self.http_response.server_error(f'Error during post operation: {e}')

    def handle(self):
        """Dispatch on the HTTP method; reject unsupported methods."""
        http_method = self.event.get('httpMethod', '')
        if http_method == 'GET':
            return self.handle_get_request()
        elif http_method == 'DELETE':
            return self.handle_delete_request()
        elif http_method == 'POST':
            return self.handle_post_request()
        # Bug fix: the original implicitly returned None for unsupported
        # methods; return an explicit client error instead.
        return self.http_response.client_error(f'Unsupported HTTP method: {http_method}')
||
def lambda_handler(event, context):
    """AWS Lambda entry point: delegate the incoming event to ChatHandler."""
    return ChatHandler(event).handle()
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
{ | ||
"httpMethod": "DELETE", | ||
"body": "{\"identity_id\": \"id_hoge\", \"character_name\": \"test_user\"}", | ||
"resource": "/ask", | ||
"path": "/ask" | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
{ | ||
"httpMethod": "GET", | ||
"resource": "/ask", | ||
"path": "/ask" | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
{ | ||
"httpMethod": "POST", | ||
"body": "{\"identity_id\": \"id_hoge\", \"input_text\": \"Rustはどんな言語ですか?\", \"character_name\": \"test_user\"}", | ||
"resource": "/ask", | ||
"path": "/ask" | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
{ | ||
"httpMethod": "POST", | ||
"body": "{\"identity_id\": \"id_hoge\", \"input_text\": \"ミノを消した時の点数は?\", \"character_name\": \"test_user\"}", | ||
"resource": "/ask", | ||
"path": "/ask" | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
sam local invoke "ChatFunction" -e scripts/chat_function/events/event_get.json |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import os | ||
import openai | ||
|
||
from dotenv import load_dotenv | ||
from llama_index import VectorStoreIndex, SimpleDirectoryReader | ||
from llama_index import StorageContext, load_index_from_storage | ||
|
||
# ローカルで辞書を作成し、storageに保存する | ||
# .envにopenaiのAPIkeyを記述する | ||
# ./dataに.txtや.md等のテキストファイルを格納し、ローカル環境で実行する | ||
# ./storageに辞書データを出力し、そのデータを用いて辞書検索を行う | ||
# Build a dictionary (vector index) locally and save it under ./storage.
# - Put the OpenAI API key in .env
# - Place .txt / .md text files under ./data and run locally
# - The persisted index under ./storage is later used for dictionary search
def main():
    """Build (or reload) a llama-index vector index and run a sample query."""
    # Load the API key from the environment (.env supported via dotenv).
    load_dotenv()
    try:
        openai.api_key = os.environ["OPENAI_API_KEY"]
    except KeyError:
        # Bug fix: the original only printed a warning here and then crashed
        # with NameError when assigning the undefined variable; fail fast
        # with a clear message instead.
        raise SystemExit("OPENAI_API_KEY environment variable not found. Please make sure it is set.")

    # Flip to False to reload the previously persisted index instead of
    # rebuilding it (the original hard-coded this choice as `if(1):`).
    build_index = True
    if build_index:
        # Build the index from the text files in ./data and persist it.
        documents = SimpleDirectoryReader(input_dir="./data").load_data()
        print("documents: ", documents)
        index = VectorStoreIndex.from_documents(documents)
        index.storage_context.persist()
    else:
        # Rebuild the storage context and load the persisted index.
        storage_context = StorageContext.from_defaults(persist_dir='storage')
        index = load_index_from_storage(storage_context)

    # Run a sample query against the index.
    query_engine = index.as_query_engine()
    response = query_engine.query("Dockerイメージの取得途中で止まる原因は?")
    print("response: ", response)


if __name__ == "__main__":
    main()
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. local動作用まで作っていただいて感謝です! |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
requests
# urllib3 must be pinned to 1.26 so that openai and boto3 can be used together
urllib3==1.26
langchain==0.0.234
llama-index==0.7.9
openai==0.27.8
python-dotenv
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import json | ||
from langchain import OpenAI | ||
from llama_index import VectorStoreIndex, SimpleDirectoryReader | ||
from llama_index import StorageContext, load_index_from_storage | ||
|
||
class TetrisAssistant:
    """Static prompt and function-schema definitions for the Tetris chat assistant."""

    @staticmethod
    def get_chat_messages():
        """Return the seed message list: a system prompt plus one
        few-shot user/assistant exchange about Tetris rules."""
        # The system prompt pins the reply language to the user's, because
        # llama_index sometimes answers in English otherwise.
        system_message = {
            "role": "system",
            "content":
                """
                質問には100文字以内で回答する。
                必要に応じて"function"を実行し、その応答を用いて、"user"の言語に合わせた言語で100文字以内で回答する。
                """
        }
        example_question = {
            "role": "user",
            "content": "テトリスのルールを教えて。"
        }
        example_answer = {
            "role": "assistant",
            "content": "テトリスは、異なる形のブロックを落としてきて、行を埋めるゲームです。行を完全に埋めると、その行は消えます。ブロックが画面上部に積み上げられるとゲームオーバーです。"
        }
        return [system_message, example_question, example_answer]

    @staticmethod
    def get_chat_functions():
        """Return the OpenAI function-calling schema exposed to the model."""
        search_parameters = {
            "type": "object",
            "properties": {
                "SearchContent": {
                    "type": "string",
                    "description": "テトリス対戦に関する情報を検索 e.g.ミノを消した時の点数は?"
                }
            },
            "required": ["SearchContent"]
        }
        return [
            {
                "name": "search_tetris_index",
                "description":
                    """
                    テトリスをpythonで操作することを通してプログラミングを学ぶ時に用いる資料を検索する。
                    この資料は以下の内容を含む。
                    ・テトリスのルール(点数,ルール,ボード情報,フィールド情報)
                    ・プログラムによるテトリスの操作方法(実行コマンド,ミノの説明,コマンド、各種ファイル,アートの作り方)
                    ・環境構築(AI,docker,Git,Windows(WSL,PowerShell),Linux,Mac)
                    """,
                "parameters": search_parameters
            }
        ]
|
||
class TetrisIndexSearch:
    """Query helper over the persisted llama-index vector store under ./storage."""

    def __init__(self):
        # Rebuild the storage context and load the prebuilt index from it.
        context = StorageContext.from_defaults(persist_dir='./storage')
        self.index = load_index_from_storage(context)

    def search_tetris_index(self, SearchContent):
        """Run *SearchContent* as a query against the index.

        Returns a JSON string of the form {"response": <answer text>}.
        """
        engine = self.index.as_query_engine()
        answer = engine.query(SearchContent)
        return json.dumps({"response": answer.response})
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"graph_dict": {}} |
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
import os

# Name of the DynamoDB table that stores the chat log.
DYNAMODB_TABLE_NAME = 'tetris_chat_log'
# OpenAI chat model used for completions (function-calling capable snapshot).
OPENAI_MODEL_NAME = "gpt-3.5-turbo-0613"
# Secondary index for querying the chat log by user id and character name.
DYNAMODB_INDEX_NAME = 'user_id-char_name-index'
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import json | ||
|
||
class HttpResponse:
    """Builders for API Gateway (Lambda proxy) JSON responses."""

    # CORS headers so browser clients on other origins can call the API
    # (requested in review: match the other Lambda functions).
    CORS_HEADERS = {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Headers": "Content-Type",
        "Access-Control-Allow-Methods": "OPTIONS,GET,POST,DELETE",
    }

    @staticmethod
    def generate_response(status_code, message):
        """Return a Lambda proxy response dict with *status_code* and a
        JSON body of the form {"message": <message>}."""
        return {
            "statusCode": status_code,
            "headers": HttpResponse.CORS_HEADERS,
            "body": json.dumps({
                "message": message,
            }),
        }

    @classmethod
    def success(cls, message):
        """200 OK."""
        return cls.generate_response(200, message)

    @classmethod
    def redirect(cls, message):
        """300 Multiple Choices."""
        return cls.generate_response(300, message)

    @classmethod
    def client_error(cls, message):
        """400 Bad Request."""
        return cls.generate_response(400, message)

    @classmethod
    def server_error(cls, message):
        """500 Internal Server Error."""
        return cls.generate_response(500, message)
Reviewer note: an option that automatically creates and deletes the ECR repository for the Lambda image.