implement split_strategy option
ladrians committed Aug 24, 2024
1 parent 0ab5458 commit a2809d4
Showing 2 changed files with 55 additions and 23 deletions.
atlassian_confluence/confluence_config.md (4 changes: 3 additions & 1 deletion)
@@ -13,7 +13,9 @@ confluence:
  include_attachments: !!bool true|false (default)
  include_children: !!bool true|false (default)
  cloud: !!bool true|false (default)
-  namespace: !!str 'namespace name' # Must match the associated RAG assistant, check the index section
+  download_dir: !!str path to a folder where metadata is stored (mandatory for delta ingestion)
+  split_strategy: !!str None | id (create an id.json for each page)
+  namespace: !!str 'namespace name' # Must match the associated RAG assistant, check the index section (deprecated)
saia:
  base_url: !!str 'string' # GeneXus Enterprise AI Base URL
  api_token: !!str 'string'
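For reference, a configuration exercising the new option might look like the following. All values are placeholders; only keys visible in this diff (or read by `ingest_confluence` below) are shown, and `split_strategy: id` expects `download_dir` to already exist:

```yaml
confluence:
  email: 'bot@example.com'               # placeholder credentials
  include_attachments: false
  include_children: true
  cloud: true
  download_dir: '/data/confluence_meta'  # per-page <id>.json.custom files land here
  split_strategy: 'id'                   # omit for a single bulk file
saia:
  base_url: 'https://api.example.com'    # GeneXus Enterprise AI Base URL
  api_token: '<token>'
  profile: 'my-assistant'
  upload_operation_log: false
  max_parallel_executions: 5
```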
saia_ingest/ingestor.py (74 changes: 52 additions & 22 deletions)
@@ -58,9 +58,9 @@ def check_valid_profile(rag_api, profile_name):
        logging.getLogger().error(f"Invalid profile {profile_name}")
    return ret

-def save_to_file(lc_documents, prefix='module'):
+def save_to_file(lc_documents, prefix='module', path=None, name=None):
    try:
-        debug_folder = os.path.join(os.getcwd(), 'debug')
+        debug_folder = os.path.join(os.getcwd(), 'debug') if path is None else path
        create_folder(debug_folder)

        serialized_docs = []
@@ -73,7 +73,7 @@ def save_to_file(lc_documents, prefix='module'):

        now = datetime.now()
        formatted_timestamp = now.strftime("%Y%m%d%H%M%S") # Format the datetime object as YYYYMMDDHHMMSS
-        filename = '%s_%s.json' % (prefix, formatted_timestamp)
+        filename = '%s_%s.json' % (prefix, formatted_timestamp) if name is None else name
        file_path = os.path.join(debug_folder, filename)
        with open(file_path, 'w', encoding='utf8') as json_file:
            json.dump(serialized_docs, json_file, ensure_ascii=False, indent=4)
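Condensed, the patched helper resolves both the destination folder and the file name from the new arguments. A minimal sketch (the serialization loop is collapsed in this view, so `doc.to_json()` is a stand-in, as is `os.makedirs` for the repo's `create_folder`):

```python
import json
import os
from datetime import datetime

def save_to_file(lc_documents, prefix='module', path=None, name=None):
    # Default to ./debug unless an explicit folder (e.g. download_dir) is given
    folder = os.path.join(os.getcwd(), 'debug') if path is None else path
    os.makedirs(folder, exist_ok=True)

    serialized_docs = [doc.to_json() for doc in lc_documents]  # stand-in serialization

    # Timestamped default name; an explicit name such as '<id>.json.custom' wins
    stamp = datetime.now().strftime("%Y%m%d%H%M%S")
    filename = f'{prefix}_{stamp}.json' if name is None else name
    file_path = os.path.join(folder, filename)
    with open(file_path, 'w', encoding='utf8') as json_file:
        json.dump(serialized_docs, json_file, ensure_ascii=False, indent=4)
    return file_path  # callers below use the return value as the written path
```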
@@ -178,7 +178,7 @@ def ingest_jira(
        ret = False
    finally:
        return ret

def ingest_confluence(
    configuration: str,
    timestamp: datetime = None,
@@ -187,6 +187,8 @@
    ret = True
    start_time = time.time()
    try:
+        message_response = ""
+
        config = get_yaml_config(configuration)
        confluence_level = config.get('confluence', {})
        user_name = confluence_level.get('email', None)
@@ -198,15 +200,17 @@
        include_children = confluence_level.get('include_children', None)
        cloud = confluence_level.get('cloud', None)
        confluence_namespace = confluence_level.get('namespace', None)
+        download_dir = confluence_level.get('download_dir', None)
+        split_strategy = confluence_level.get('split_strategy', None)

        embeddings_level = config.get('embeddings', {})
-        openapi_key = embeddings_level.get('openapi_key', None)
+        openapi_key = embeddings_level.get('openapi_key', '')
        chunk_size = embeddings_level.get('chunk_size', None)
        chunk_overlap = embeddings_level.get('chunk_overlap', None)
        embeddings_model = embeddings_level.get('model', 'text-embedding-ada-002')

        vectorstore_level = config.get('vectorstore', {})
-        vectorstore_api_key = vectorstore_level.get('api_key', None)
+        vectorstore_api_key = vectorstore_level.get('api_key', '')

        os.environ['OPENAI_API_KEY'] = openapi_key
        os.environ['CONFLUENCE_USERNAME'] = user_name
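The `openapi_key` and `vectorstore_api_key` defaults switch from `None` to `''` because the values are assigned into `os.environ`, which only accepts strings. A short demonstration of why the old default could blow up on a missing key:

```python
import os

try:
    os.environ['OPENAI_API_KEY'] = None  # old default with a missing key
except TypeError as e:
    print(f"os.environ rejects None: {e}")

os.environ['OPENAI_API_KEY'] = ''  # new default: accepted, simply empty
```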
@@ -218,12 +222,12 @@
        documents = []

        if page_ids is not None:
-            try:
-                list_documents = load_documents(loader, page_ids=page_ids, include_attachments=include_attachments, include_children=include_children)
-                for item in list_documents:
-                    documents.append(item)
-            except Exception as e:
-                logging.getLogger().error(f"Error processing {page_ids}: {e}")
+            try:
+                list_documents = load_documents(loader, page_ids=page_ids, include_attachments=include_attachments, include_children=include_children, timestamp=timestamp)
+                for item in list_documents:
+                    documents.append(item)
+            except Exception as e:
+                logging.getLogger().error(f"Error processing {page_ids}: {e}")
        elif space_keys is not None:
            for key in space_keys:
                try:
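Besides the re-indentation, the `page_ids` branch now forwards `timestamp` into `load_documents`, which is what enables delta ingestion. A hypothetical driver, assuming `timestamp` acts as a changed-since cutoff honored by `load_documents` (the configuration path is illustrative):

```python
from datetime import datetime, timedelta

# Hypothetical invocation: re-ingest only pages touched in the last 24 hours;
# all other settings come from the YAML configuration file.
since = datetime.now() - timedelta(days=1)
ok = ingest_confluence('config/confluence.yaml', timestamp=since)
```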
@@ -235,32 +239,58 @@
                    logging.getLogger().error(f"Error processing {key}: {e}")
                    continue

-        lc_documents = split_documents(documents, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
-
-        docs_file = save_to_file(lc_documents, prefix='confluence')
+        if split_strategy is not None:
+            if split_strategy == 'id':
+                ids = []
+                if not os.path.exists(download_dir):
+                    raise Exception(f"Download directory {download_dir} does not exist")
+                for doc in documents:
+                    lc_documents = split_documents([doc], chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+                    metadata = doc.metadata
+                    doc_id = metadata.get("id", "")
+                    name = f"{doc_id}.json.custom"
+                    docs_file = save_to_file(lc_documents, prefix='confluence', path=download_dir, name=name)
+                    ids.append(docs_file)
+        else:
+            lc_documents = split_documents(documents, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+            docs_file = save_to_file(lc_documents, prefix='confluence')

        # Saia
        saia_level = config.get('saia', {})
        saia_base_url = saia_level.get('base_url', None)
        saia_api_token = saia_level.get('api_token', None)
        saia_profile = saia_level.get('profile', None)
        upload_operation_log = saia_level.get('upload_operation_log', False)
        max_parallel_executions = saia_level.get('max_parallel_executions', 5)

        if saia_base_url is not None:

            ragApi = RagApi(saia_base_url, saia_api_token, saia_profile)

-            target_file = f"{docs_file}.custom"
-            shutil.copyfile(docs_file, target_file)
+            if split_strategy is None:
+                doc_count = 1
+                target_file = f"{docs_file}.custom"
+                shutil.copyfile(docs_file, target_file)
+
-            response_body = ragApi.upload_document_with_metadata_file(target_file) # ToDo check .metadata
-            if response_body is None:
-                logging.getLogger().error("Error uploading document")
-                return False
+                _ = ragApi.upload_document_with_metadata_file(target_file)
+            else:
+                saia_file_ids_to_delete = search_failed_to_delete(ids)
+                with concurrent.futures.ThreadPoolExecutor(max_workers=max_parallel_executions) as executor:
+                    futures = [executor.submit(ragApi.delete_profile_document, id, saia_profile) for id in saia_file_ids_to_delete]
+                    concurrent.futures.wait(futures)
+
+                doc_count = 0
+                for doc_id_path in ids:
+                    ret = saia_file_upload(saia_base_url, saia_api_token, saia_profile, doc_id_path, False)
+                    if not ret:
+                        message_response += f"Error uploading document {doc_id_path} {ret}\n"
+                    else:
+                        doc_count += 1
+                        message_response += f"{doc_id_path}\n"

            if upload_operation_log:
                end_time = time.time()
-                message_response = f"bulk ingest ({end_time - start_time:.2f}s)"
+                message_response += f"bulk ingest {doc_count} items ({end_time - start_time:.2f}s)"
                ret = operation_log_upload(saia_base_url, saia_api_token, saia_profile, "ALL", message_response, 0)

        else:
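Taken together, the `id` strategy swaps the single bulk `.custom` upload for one file per page plus a delete-then-reupload pass against the profile. A condensed sketch of that path, assuming the module's existing helpers (`split_documents`, `save_to_file`, `search_failed_to_delete`, `saia_file_upload`) with the signatures used above:

```python
import concurrent.futures

def ingest_per_page(documents, download_dir, chunk_size, chunk_overlap,
                    rag_api, base_url, api_token, profile,
                    max_parallel_executions=5):
    # 1. One chunked payload per Confluence page, keyed by its page id
    ids = []
    for doc in documents:
        chunks = split_documents([doc], chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        doc_id = doc.metadata.get("id", "")
        ids.append(save_to_file(chunks, prefix='confluence',
                                path=download_dir, name=f"{doc_id}.json.custom"))

    # 2. Delete stale copies first, bounded by max_parallel_executions
    to_delete = search_failed_to_delete(ids)
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_parallel_executions) as executor:
        futures = [executor.submit(rag_api.delete_profile_document, doc_id, profile)
                   for doc_id in to_delete]
        concurrent.futures.wait(futures)

    # 3. Re-upload one file per page, counting successes for the operation log
    doc_count, log = 0, ""
    for path in ids:
        if saia_file_upload(base_url, api_token, profile, path, False):
            doc_count += 1
            log += f"{path}\n"
        else:
            log += f"Error uploading document {path}\n"
    return doc_count, log
```

Note the design choice in the diff: deletions run in parallel, while uploads stay sequential, so `doc_count` and the message log are built without any locking.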
