chore(load): improve loading to support loading from dir of TSVs
Showing 8 changed files with 419 additions and 215 deletions.
One of the changed files was deleted; its contents are not shown.
load_into_knowledge_store.py (new file)
@@ -0,0 +1,182 @@
#!/usr/bin/sudo python
"""
Usage:
- Run app: poetry run python run.py
"""
import glob
import os

from gen3.auth import Gen3Auth
from gen3.tools.metadata.discovery import output_expanded_discovery_metadata
from gen3.utils import get_or_create_event_loop_for_thread
from langchain.document_loaders.csv_loader import CSVLoader
from langchain.text_splitter import TokenTextSplitter

from gen3discoveryai import config, logging
from gen3discoveryai.topic_chains.question_answer import TopicChainQuestionAnswerRAG


def load_tsvs_from_dir(
    directory, source_column_name="guid", token_splitter_chunk_size=1000, delimiter="\t"
):
""" | ||
Load TSVs from specified directory in the knowledge database. | ||
This expects filenames to START with a configured topic and will aggregate | ||
documents from all files that begin with that topic name. This will recursively retrieve | ||
all filenames in the directory and subdirectories. | ||
In the following example, both TSVs starting with "default" would populate documents | ||
for the "default" topic knowledge store and the nested "anothertopic.tsv" would populate | ||
documents for the "anothertopic" topic. | ||
- default_data_1.tsv | ||
- default_data_2.tsv | ||
- some folder | ||
- anothertopic.tsv | ||
Args: | ||
directory: path to directory where relevant TSVs are | ||
source_column_name: what column to get the "source" information from for the document | ||
token_splitter_chunk_size: how many tokens to chunk the content into per doc | ||
delimiter: \t or , or whatever else is delimited the TSV/CSV-like file | ||
""" | ||
    files = glob.glob(f"{directory.rstrip('/')}/**/*.*", recursive=True)
    topics = config.TOPICS.split(",")

    logging.info(f"Loading TSVs for topics: {topics}")

    # group the files by the configured topic their basename starts with
    topics_files = {}

    for topic in topics:
        topics_files[topic] = []
        for file in files:
            if os.path.basename(file).startswith(topic):
                topics_files[topic].append(file)

    for topic, topic_file_list in topics_files.items():
        topic_documents = []
        for file in topic_file_list:
            # Load the document, split it into chunks, embed each chunk, and
            # load it into the vector store.
            loader = CSVLoader(
                source_column=source_column_name,
                file_path=file,
                csv_args={
                    "delimiter": delimiter,
                    "quotechar": '"',
                },
            )
            data = loader.load()

            # 4097 tokens is OpenAI's context limit for gpt-3.5-turbo, so
            # splitting into chunks of 1000 tokens fits roughly 4 retrieved
            # chunks with ~97 tokens left for the query
            text_splitter = TokenTextSplitter.from_tiktoken_encoder(
                chunk_size=token_splitter_chunk_size, chunk_overlap=0
            )
            documents = text_splitter.split_documents(data)

            topic_documents.extend(documents)

        topic_chain = TopicChainQuestionAnswerRAG(
            topic=topic,
            # metadata shouldn't matter much here; we just need the topic chain
            # initialized so we can store the data
            metadata={"model_name": "gpt-3.5-turbo", "model_temperature": 0.33},
        )

        _store_documents_in_chain(topic_chain, topic_documents)


def _store_documents_in_chain(topic_chain, topic_documents):
    """
    Tiny helper to store documents in the provided chain. This makes
    testing/mocking simpler in unit tests.
    """
    topic_chain.store_knowledge(topic_documents)


def main():
    """
    Get all discovery metadata and load it into the knowledge library, keyed
    by GUID.

    This relies on the commons from whatever API key you have configured. See
    the Gen3 SDK's `Gen3Auth` class for info.
    """
    auth = Gen3Auth()
    loop = get_or_create_event_loop_for_thread()
    output_file = loop.run_until_complete(
        output_expanded_discovery_metadata(auth, output_format="tsv")
    )

    # Load the document, split it into chunks, embed each chunk, and load it
    # into the vector store.
    loader = CSVLoader(
        source_column="guid",
        file_path=output_file,
        csv_args={
            "delimiter": "\t",
            "quotechar": '"',
        },
    )
    data = loader.load()

    # 4097 tokens is OpenAI's context limit for gpt-3.5-turbo, so splitting
    # into chunks of 1000 tokens fits roughly 4 retrieved chunks with ~97
    # tokens left for the query
    text_splitter = TokenTextSplitter.from_tiktoken_encoder(
        chunk_size=1000, chunk_overlap=0
    )
    documents = text_splitter.split_documents(data)

    topic_chain = TopicChainQuestionAnswerRAG(
        topic="bdc",
        metadata={"model_name": "gpt-3.5-turbo", "model_temperature": 0.33},
    )

    topic_chain.store_knowledge(documents)


def aggmds():
    """
    Use the aggregate MDS.
    """
    auth = Gen3Auth()
    # loop = get_or_create_event_loop_for_thread()
    # output_file = loop.run_until_complete(
    #     output_expanded_discovery_metadata(auth, output_format="tsv", use_agg_mds=True)
    # )

    # TODO: remove the __manifest column
    output_file = "brh-data-commons-org-discovery_metadata.tsv"

    # Load the document, split it into chunks, embed each chunk, and load it
    # into the vector store.
    loader = CSVLoader(
        source_column="guid",
        file_path=output_file,
        csv_args={
            "delimiter": "\t",
            "quotechar": '"',
        },
    )
    data = loader.load()

    # 4097 tokens is OpenAI's context limit for gpt-3.5-turbo, so splitting
    # into chunks of 1000 tokens fits roughly 4 retrieved chunks with ~97
    # tokens left for the query
    text_splitter = TokenTextSplitter.from_tiktoken_encoder(
        chunk_size=1000, chunk_overlap=0
    )
    documents = text_splitter.split_documents(data)

    topic_chain = TopicChainQuestionAnswerRAG(
        topic="default",
        metadata={"model_name": "gpt-3.5-turbo", "model_temperature": 0.33},
    )

    topic_chain.store_knowledge(documents)


if __name__ == "__main__":
    main()
    # aggmds()
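
For reference, a minimal sketch (not part of this commit) of how the new loader might be invoked, assuming config.TOPICS is already set by the service configuration and "./tsvs" is a hypothetical directory containing topic-prefixed files:

from load_into_knowledge_store import load_tsvs_from_dir

# Assumptions: config.TOPICS is set (e.g. "default,bdc") and ./tsvs is a
# hypothetical directory containing files like default_data_1.tsv
load_tsvs_from_dir(
    directory="./tsvs",
    source_column_name="guid",
    token_splitter_chunk_size=1000,
    delimiter="\t",
)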
(new test file)
@@ -0,0 +1,37 @@
import os
from unittest.mock import patch

from gen3discoveryai import config

from load_into_knowledge_store import load_tsvs_from_dir


@patch("load_into_knowledge_store._store_documents_in_chain")
@patch("load_into_knowledge_store.TopicChainQuestionAnswerRAG")
def test_load_from_tsvs(topic_chain, store_documents_in_chain):
    """
    Test that loading from TSVs pulls the correct information from various
    files and aggregates files that begin with the same topic name.
    """
    config.TOPICS = "default,bdc"

    directory = os.path.abspath(
        os.path.dirname(os.path.abspath(__file__)).rstrip("/") + "/../tests/tsvs"
    )

    load_tsvs_from_dir(
        directory=directory,
        source_column_name="guid",
        token_splitter_chunk_size=1000,
        delimiter="\t",
    )

    # restore the default config for other tests
    config.TOPICS = "default"

    # one chain per configured topic ("default" and "bdc")
    assert topic_chain.call_count == 2
    assert store_documents_in_chain.call_count == 2

    for item in store_documents_in_chain.call_args_list:
        assert len(item.args[1]) > 0  # documents
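
A note on the decorator ordering the test depends on: unittest.mock applies stacked @patch decorators bottom-up, so the innermost patch supplies the first mock argument. A minimal sketch with hypothetical names ("mymodule", "helper_a", "helper_b"):

from unittest.mock import patch

@patch("mymodule.helper_b")  # outermost patch -> second argument
@patch("mymodule.helper_a")  # innermost patch -> first argument
def test_example(helper_a_mock, helper_b_mock):
    # helper_a_mock replaces mymodule.helper_a; helper_b_mock replaces
    # mymodule.helper_b for the duration of the test
    assert helper_a_mock is not helper_b_mock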