From d3f297965bc7e958320271f594734cbac8b447fb Mon Sep 17 00:00:00 2001 From: contradiction29 Date: Tue, 14 May 2024 22:18:56 +0900 Subject: [PATCH] Update: Change embedding logic to include postTitle and tags MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 類似度検索の精度を上げるため、投稿タイトルとタグ名を含めるようにembeddingのロジックを変更した --- app/modules/embedding.server.ts | 24 ++++++++++++++++++-- app/routes/_layout.archives.edit.$postId.tsx | 2 +- app/routes/_layout.post.tsx | 3 ++- 3 files changed, 25 insertions(+), 4 deletions(-) diff --git a/app/modules/embedding.server.ts b/app/modules/embedding.server.ts index e89d1b78..33fd5a25 100644 --- a/app/modules/embedding.server.ts +++ b/app/modules/embedding.server.ts @@ -5,19 +5,34 @@ import { prisma } from "./db.server"; interface CreateEmbeddingInput { postId : number; postContent : string; + postTitle : string; } + const OpenAIAPIKey = process.env.OPENAI_API_KEY; const OpenAIEmbeddingModel = "text-embedding-3-small" -export async function createEmbedding({ postId, postContent } : CreateEmbeddingInput) { +export async function createEmbedding({ postId, postContent, postTitle } : CreateEmbeddingInput) { if (!OpenAIAPIKey) { throw new Error("OPENAI_API_KEY is not set"); } + const allTags = await prisma.relPostTags.findMany({ + where: { postId }, + select: { + dimTag: { + select: { + tagName: true + } + } + } + }) + const allTagNames = allTags.map(tag => tag.dimTag.tagName) + + const inputText = await getEmbeddingInputText(postContent, postTitle, allTagNames) const openAI = new OpenAI({apiKey: OpenAIAPIKey}) const response = await openAI.embeddings.create({ model: OpenAIEmbeddingModel, - input: postContent, + input: inputText, }) const embedding = Array.from(response.data[0].embedding) @@ -39,4 +54,9 @@ export async function createEmbedding({ postId, postContent } : CreateEmbeddingI status: 200, message: "Embedding created successfully" }); +} + +async function getEmbeddingInputText(postContent: string, 
postTitle: string, allTagNames: string[]) { + const inputText = `タイトル: ${postTitle}\nタグ: ${allTagNames}\n本文: ${postContent}` + return inputText } \ No newline at end of file diff --git a/app/routes/_layout.archives.edit.$postId.tsx b/app/routes/_layout.archives.edit.$postId.tsx index dc2bfdce..5b8932e5 100644 --- a/app/routes/_layout.archives.edit.$postId.tsx +++ b/app/routes/_layout.archives.edit.$postId.tsx @@ -400,7 +400,7 @@ export async function action({ request, params }: ActionFunctionArgs) { timeout : 20000, }); - await createEmbedding({ postId: Number(updatedPost.postId), postContent: updatedPost.postContent }); + await createEmbedding({ postId: Number(updatedPost.postId), postContent: updatedPost.postContent, postTitle: updatedPost.postTitle}); return redirect(`/archives/${updatedPost.postId}`); } diff --git a/app/routes/_layout.post.tsx b/app/routes/_layout.post.tsx index 985f1864..4ba16f03 100644 --- a/app/routes/_layout.post.tsx +++ b/app/routes/_layout.post.tsx @@ -426,7 +426,8 @@ export async function action({ request }:ActionFunctionArgs ) { await createEmbedding({ postId: Number(newPost.postId), - postContent: newPost.postContent + postContent: newPost.postContent, + postTitle: newPost.postTitle, }); return redirect(`/archives/${newPost.postId}`);