Skip to content

Commit

Permalink
add naive rag with bing search
Browse files Browse the repository at this point in the history
  • Loading branch information
ks6088ts committed Sep 29, 2024
1 parent 02acd8a commit a324b4c
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 2 deletions.
8 changes: 8 additions & 0 deletions .env.template
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,11 @@ AZURE_OPENAI_API_KEY="<YOUR_API_KEY>"
AZURE_OPENAI_API_VERSION="2024-08-01-preview"
AZURE_OPENAI_DEPLOYMENT_CHAT="gpt-4o"
AZURE_OPENAI_DEPLOYMENT_EMBEDDING="text-embedding-3-small"

# Bing search resource
BING_SUBSCRIPTION_KEY="<YOUR_SUBSCRIPTION_KEY>"
BING_SEARCH_URL="https://api.bing.microsoft.com/v7.0/search"

# LangSmith
LANGCHAIN_TRACING_V2="true"
LANGCHAIN_API_KEY="<YOUR_API_KEY>"
Empty file.
16 changes: 16 additions & 0 deletions sandbox_python/llms/chains/generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from os import getenv

from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import AzureChatOpenAI

llm = AzureChatOpenAI(
temperature=0,
api_key=getenv("AZURE_OPENAI_API_KEY"),
api_version=getenv("AZURE_OPENAI_API_VERSION"),
azure_endpoint=getenv("AZURE_OPENAI_ENDPOINT"),
model=getenv("AZURE_OPENAI_DEPLOYMENT_CHAT"),
)
prompt = hub.pull("rlm/rag-prompt")

generation_chain = prompt | llm | StrOutputParser()
5 changes: 4 additions & 1 deletion sandbox_python/llms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,15 @@ def get_retriever(
embedding: Embeddings,
collection_name: str,
persist_directory: str,
k: int,
) -> VectorStoreRetriever:
return Chroma(
collection_name=collection_name,
persist_directory=persist_directory,
embedding_function=embedding,
).as_retriever()
).as_retriever(
search_kwargs={"k": k},
)


def get_text_splitter() -> TextSplitter:
Expand Down
Empty file.
15 changes: 15 additions & 0 deletions sandbox_python/llms/tools/bing_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from os import getenv

from langchain_community.tools.bing_search import BingSearchResults
from langchain_community.utilities import BingSearchAPIWrapper


def get_bing_search_tool(k: int = 1):
return BingSearchResults(
api_wrapper=BingSearchAPIWrapper(
bing_search_url=getenv("BING_SEARCH_URL"),
bing_subscription_key=getenv("BING_SUBSCRIPTION_KEY"),
k=k,
),
num_results=k,
)
77 changes: 76 additions & 1 deletion scripts/llms_cli.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import os
import sys
from pprint import pprint

sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")

from logging import getLogger

import typer
from langchain_core.documents import Document

from sandbox_python.llms import core
from sandbox_python.llms.chains.generation import generation_chain
from sandbox_python.llms.tools.bing_search import get_bing_search_tool

app = typer.Typer()

Expand Down Expand Up @@ -60,13 +64,84 @@ def search(
embedding=core.get_embedding(),
collection_name=collection_name,
persist_directory=persist_directory,
k=3,
).invoke(query)

print(f"got {len(got_documents)} documents")

for idx, document in enumerate(got_documents):
print(f"{idx+1} =============")
print(document.page_content)
pprint(document)


@app.command()
def bing_search(
query: str = "GitHub",
k: int = 3,
verbose: bool = False,
):
if verbose:
import logging

logging.basicConfig(level=logging.DEBUG)

documents_str = get_bing_search_tool(k=k).invoke(
{
"query": query,
}
)
documents = eval(documents_str)

for idx, document in enumerate(documents):
print(f"{idx+1} =============")
pprint(document)


@app.command()
def rag(
question="初版の発行日と出版社を教えてください。",
vector_store=True,
bing_search=True,
verbose=False,
):
if verbose:
import logging

logging.basicConfig(level=logging.DEBUG)

documents = []
if vector_store:
got_documents = core.get_retriever(
embedding=core.get_embedding(),
collection_name="rag-chroma",
persist_directory="./.chroma",
k=1,
).invoke(question)
documents.extend(got_documents)

if bing_search:
got_documents_str = get_bing_search_tool(k=2).invoke(
{
"query": question,
}
)
got_documents = eval(got_documents_str)
for document in got_documents:
documents.append(Document(page_content=document["snippet"]))

for idx, document in enumerate(documents):
print(f"{idx+1} =============")
pprint(document)

generation = generation_chain.invoke(
{
"context": documents,
"question": question,
}
)

print("Answer =============")
pprint(generation)


if __name__ == "__main__":
Expand Down

0 comments on commit a324b4c

Please sign in to comment.