From 1083608995691438efac0edd86a037e6d5108bc9 Mon Sep 17 00:00:00 2001 From: Gijs Hendriksen Date: Thu, 25 May 2023 14:43:32 +0200 Subject: [PATCH] Fix posting sort order and memory issues in CIFF exporter --- geesedb/utils/ciff/to_ciff.py | 68 ++++++++++++++++++++--------------- requirements.txt | 1 + setup.py | 3 +- 3 files changed, 42 insertions(+), 30 deletions(-) diff --git a/geesedb/utils/ciff/to_ciff.py b/geesedb/utils/ciff/to_ciff.py index 493e0a0..989e9b5 100644 --- a/geesedb/utils/ciff/to_ciff.py +++ b/geesedb/utils/ciff/to_ciff.py @@ -2,6 +2,8 @@ from pathlib import Path from typing import Any +from tqdm import tqdm + from .CommonIndexFileFormat_pb2 import Header, Posting, PostingsList, DocRecord from ...connection import get_connection from .ciff_writer import MessageWriter @@ -24,7 +26,8 @@ def get_arguments(kwargs: Any) -> dict: 'docs': 'docs', 'term_dict': 'term_dict', 'term_doc': 'term_doc', - 'delimiter': '|' + 'batch_size': 1000, + 'verbose': False, } for key, item in arguments.items(): if kwargs.get(key) is not None: @@ -45,35 +48,39 @@ def create_ciff(self) -> None: self.cursor.execute("""SELECT SUM(tf) FROM term_doc""") header.total_terms_in_collection = self.cursor.fetchone()[0] header.average_doclength = header.total_terms_in_collection / header.num_docs - header.description = 'This is the first experimental output of (part of) the CommonCrawl in CIFF' + header.description = f'GeeseDB database {self.arguments["database"]}' f.write_message(header) + disable_tqdm = not self.arguments['verbose'] + # Create postings lists self.cursor.execute(""" - SELECT df, string, list(row(doc_id, tf)) + SELECT df, string, list(row(doc_id, tf) ORDER BY doc_id) FROM term_dict, term_doc WHERE term_dict.term_id = term_doc.term_id GROUP BY term_dict.term_id, df, string ORDER BY string """) - for output in self.cursor.fetchall(): - postingsList = PostingsList() - df, term, postings = output - assert len(postings) == df - cf = sum(p['tf'] for p in postings) - postingsList.term = term - postingsList.df = df - postingsList.cf = cf - old_id = 0 - for p in postings: - posting = Posting() - doc_id = p['doc_id'] - tf = p['tf'] - posting.docid = doc_id - old_id - old_id = doc_id - posting.tf = tf - postingsList.postings.append(posting) - f.write_message(postingsList) + with tqdm(total=header.num_postings_lists, disable=disable_tqdm) as pbar: + while batch := self.cursor.fetchmany(self.arguments['batch_size']): + for df, term, postings in batch: + postingsList = PostingsList() + assert len(postings) == df + cf = sum(p['tf'] for p in postings) + postingsList.term = term + postingsList.df = df + postingsList.cf = cf + old_id = 0 + for p in postings: + posting = Posting() + doc_id = p['doc_id'] + tf = p['tf'] + posting.docid = doc_id - old_id + old_id = doc_id + posting.tf = tf + postingsList.postings.append(posting) + f.write_message(postingsList) + pbar.update() # Create doc records self.cursor.execute(""" @@ -81,13 +88,15 @@ def create_ciff(self) -> None: FROM docs ORDER BY doc_id """) - for output in self.cursor.fetchall(): - docRecord = DocRecord() - doc_id, collection_id, length = output - docRecord.docid = doc_id - docRecord.collection_docid = collection_id - docRecord.doclength = length - f.write_message(docRecord) + with tqdm(total=header.num_docs, disable=disable_tqdm) as pbar: + while batch := self.cursor.fetchmany(self.arguments['batch_size']): + for doc_id, collection_id, length in batch: + docRecord = DocRecord() + docRecord.docid = doc_id + docRecord.collection_docid = collection_id + docRecord.doclength = length + f.write_message(docRecord) + pbar.update() if __name__ == '__main__': @@ -97,6 +106,7 @@ def create_ciff(self) -> None: parser.add_argument('--docs') parser.add_argument('--term_dict') parser.add_argument('--term_doc') - parser.add_argument('--delimiter') + parser.add_argument('--batch_size', type=int) + parser.add_argument('--verbose', action='store_true') args = parser.parse_args() ToCiff(**vars(args)) diff --git a/requirements.txt b/requirements.txt index fd9ba92..ac6d252 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,5 @@ google>=2 protobuf==3.20.2 numpy pandas +tqdm git+https://github.com/informagi/pycypher diff --git a/setup.py b/setup.py index 543ff88..51ee68c 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,8 @@ author='Chris Kamphuis', author_email='chris@cs.ru.nl', url='https://github.com/informagi/GeeseDB', - install_requires=['duckdb', 'numpy', 'pandas', 'protobuf', 'pycypher @ git+https://github.com/informagi/pycypher'], + install_requires=['duckdb', 'numpy', 'pandas', 'protobuf', 'tqdm', + 'pycypher @ git+https://github.com/informagi/pycypher'], packages=find_packages(), include_package_data=True, package_data={'': ['qrels.*', 'topics.*']},