diff --git a/geesedb/index/fulltext_from_ciff.py b/geesedb/index/fulltext_from_ciff.py index 14701b7..d13b90d 100644 --- a/geesedb/index/fulltext_from_ciff.py +++ b/geesedb/index/fulltext_from_ciff.py @@ -5,9 +5,9 @@ import os import duckdb from typing import Any, List, Union, Tuple +from ciff_toolkit.read import CiffReader from ..connection import get_connection -from ..utils import CommonIndexFileFormat_pb2 as Ciff class FullTextFromCiff: @@ -109,50 +109,37 @@ def fill_tables(self) -> None: with open(self.arguments['protobuf_file'], 'rb') as f: data = f.read() - # start with reading header info - next_pos, pos = 0, 0 - header = Ciff.Header() - next_pos, pos = self.decode(data, pos) - header.ParseFromString(data[pos:pos + next_pos]) - pos += next_pos + with CiffReader(self.arguments['protobuf_file']) as reader: + for term_id, postings_list in enumerate(reader.read_postings_lists()): + self.connection.begin() + q = f'INSERT INTO {self.arguments["table_names"][1]} ' \ + f'({",".join(self.arguments["columns_names_term_dict"])}) ' \ + f"VALUES ({term_id},{postings_list.df},'{postings_list.term}')" + try: + self.cursor.execute(q) + except RuntimeError: + print(q) + + docid = 0 + for posting in postings_list.postings: + docid += posting.docid + q = f'INSERT INTO {self.arguments["table_names"][2]} ' \ + f'({",".join(self.arguments["columns_names_term_doc"])}) ' \ + f'VALUES ({term_id},{docid},{posting.tf})' + self.cursor.execute(q) + self.connection.commit() - # read posting lists - postings_list = Ciff.PostingsList() - for term_id in range(header.num_postings_lists): self.connection.begin() - next_pos, pos = self.decode(data, pos) - postings_list.ParseFromString(data[pos:pos + next_pos]) - pos += next_pos - q = f'INSERT INTO {self.arguments["table_names"][1]} ' \ - f'({",".join(self.arguments["columns_names_term_dict"])}) ' \ - f"VALUES ({term_id},{postings_list.df},'{postings_list.term}')" - try: - self.cursor.execute(q) - except RuntimeError: - print(q) - for posting in postings_list.postings: - q = f'INSERT INTO {self.arguments["table_names"][2]} ' \ - f'({",".join(self.arguments["columns_names_term_doc"])}) ' \ - f'VALUES ({term_id},{posting.docid},{posting.tf})' + for n, doc_record in enumerate(reader.read_documents()): + if n % 1000 == 0: + self.connection.commit() + self.connection.begin() + q = f'INSERT INTO {self.arguments["table_names"][0]} ' \ + f'({",".join(self.arguments["columns_names_docs"])}) ' \ + f"VALUES ('{doc_record.collection_docid}',{doc_record.docid},{doc_record.doclength})" self.cursor.execute(q) self.connection.commit() - # read doc information - doc_record = Ciff.DocRecord() - self.connection.begin() - for n in range(header.num_docs): - if n % 1000 == 0: - self.connection.commit() - self.connection.begin() - next_pos, pos = self.decode(data, pos) - doc_record.ParseFromString(data[pos:pos + next_pos]) - pos += next_pos - q = f'INSERT INTO {self.arguments["table_names"][0]} ' \ - f'({",".join(self.arguments["columns_names_docs"])}) ' \ - f"VALUES ('{doc_record.collection_docid}',{doc_record.docid},{doc_record.doclength})" - self.cursor.execute(q) - self.connection.commit() - if __name__ == '__main__': parser = argparse.ArgumentParser() diff --git a/geesedb/utils/__init__.py b/geesedb/utils/__init__.py index 86f3401..c8e33b1 100644 --- a/geesedb/utils/__init__.py +++ b/geesedb/utils/__init__.py @@ -1,5 +1,4 @@ -from .ciff import CommonIndexFileFormat_pb2 from .ciff.to_csv import ToCSV from .ciff.to_ciff import ToCiff -__all__ = ['CommonIndexFileFormat_pb2', 'ToCSV', 'ToCiff'] \ No newline at end of file +__all__ = ['ToCSV', 'ToCiff'] \ No newline at end of file diff --git a/geesedb/utils/ciff/CommonIndexFileFormat.proto b/geesedb/utils/ciff/CommonIndexFileFormat.proto deleted file mode 100644 index b52a970..0000000 --- a/geesedb/utils/ciff/CommonIndexFileFormat.proto +++ /dev/null @@ -1,57 +0,0 @@ -syntax = "proto3"; - -package io.osirrc.ciff; - -// An index stored in CIFF is a single file comprised of exactly the following: -// - A Header protobuf message, -// - Exactly the number of PostingsList messages specified in the num_postings_lists field of the Header -// - Exactly the number of DocRecord messages specified in the num_doc_records field of the Header -// The protobuf messages are defined below. - -// This is the CIFF header. It always comes first. -message Header { - int32 version = 1; // Version. - - int32 num_postings_lists = 2; // Exactly the number of PostingsList messages that follow the Header. - int32 num_docs = 3; // Exactly the number of DocRecord messages that follow the PostingsList messages. - - // The total number of postings lists in the collection; the vocabulary size. This might differ from - // num_postings_lists, for example, because we only export the postings lists of query terms. - int32 total_postings_lists = 4; - - // The total number of documents in the collection; might differ from num_doc_records for a similar reason as above. - int32 total_docs = 5; - - // The total number of terms in the entire collection. This is the sum of all document lengths of all documents in - // the collection. - int64 total_terms_in_collection = 6; - - // The average document length. We store this value explicitly in case the exporting application wants a particular - // level of precision. - double average_doclength = 7; - - // Description of this index, meant for human consumption. Describing, for example, the exporting application, - // document processing and tokenization pipeline, etc. - string description = 8; -} - -// An individual posting. -message Posting { - int32 docid = 1; - int32 tf = 2; -} - -// A postings list, comprised of one ore more postings. -message PostingsList { - string term = 1; // The term. - int64 df = 2; // The document frequency. - int64 cf = 3; // The collection frequency. - repeated Posting postings = 4; -} - -// A record containing metadata about an individual document. -message DocRecord { - int32 docid = 1; // Refers to the docid in the postings lists. - string collection_docid = 2; // Refers to a docid in the external collection. - int32 doclength = 3; // Length of this document. -} diff --git a/geesedb/utils/ciff/CommonIndexFileFormat_pb2.py b/geesedb/utils/ciff/CommonIndexFileFormat_pb2.py deleted file mode 100644 index fc80d8e..0000000 --- a/geesedb/utils/ciff/CommonIndexFileFormat_pb2.py +++ /dev/null @@ -1,278 +0,0 @@ -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: CommonIndexFileFormat.proto - -import sys -_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -from google.protobuf import descriptor_pb2 -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='CommonIndexFileFormat.proto', - package='io.osirrc.ciff', - syntax='proto3', - serialized_pb=_b('\n\x1b\x43ommonIndexFileFormat.proto\x12\x0eio.osirrc.ciff\"\xcc\x01\n\x06Header\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x1a\n\x12num_postings_lists\x18\x02 \x01(\x05\x12\x10\n\x08num_docs\x18\x03 \x01(\x05\x12\x1c\n\x14total_postings_lists\x18\x04 \x01(\x05\x12\x12\n\ntotal_docs\x18\x05 \x01(\x05\x12!\n\x19total_terms_in_collection\x18\x06 \x01(\x03\x12\x19\n\x11\x61verage_doclength\x18\x07 \x01(\x01\x12\x13\n\x0b\x64\x65scription\x18\x08 \x01(\t\"$\n\x07Posting\x12\r\n\x05\x64ocid\x18\x01 \x01(\x05\x12\n\n\x02tf\x18\x02 \x01(\x05\"_\n\x0cPostingsList\x12\x0c\n\x04term\x18\x01 \x01(\t\x12\n\n\x02\x64\x66\x18\x02 \x01(\x03\x12\n\n\x02\x63\x66\x18\x03 \x01(\x03\x12)\n\x08postings\x18\x04 \x03(\x0b\x32\x17.io.osirrc.ciff.Posting\"G\n\tDocRecord\x12\r\n\x05\x64ocid\x18\x01 \x01(\x05\x12\x18\n\x10\x63ollection_docid\x18\x02 \x01(\t\x12\x11\n\tdoclength\x18\x03 \x01(\x05\x62\x06proto3') -) - - - - -_HEADER = _descriptor.Descriptor( - name='Header', - full_name='io.osirrc.ciff.Header', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='version', full_name='io.osirrc.ciff.Header.version', index=0, - number=1, type=5, cpp_type=1, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='num_postings_lists', full_name='io.osirrc.ciff.Header.num_postings_lists', index=1, - number=2, type=5, cpp_type=1, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='num_docs', full_name='io.osirrc.ciff.Header.num_docs', index=2, - number=3, type=5, cpp_type=1, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='total_postings_lists', full_name='io.osirrc.ciff.Header.total_postings_lists', index=3, - number=4, type=5, cpp_type=1, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='total_docs', full_name='io.osirrc.ciff.Header.total_docs', index=4, - number=5, type=5, cpp_type=1, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='total_terms_in_collection', full_name='io.osirrc.ciff.Header.total_terms_in_collection', index=5, - number=6, type=3, cpp_type=2, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='average_doclength', full_name='io.osirrc.ciff.Header.average_doclength', index=6, - number=7, type=1, cpp_type=5, label=1, - has_default_value=False, default_value=float(0), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='description', full_name='io.osirrc.ciff.Header.description', index=7, - number=8, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=48, - serialized_end=252, -) - - -_POSTING = _descriptor.Descriptor( - name='Posting', - full_name='io.osirrc.ciff.Posting', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='docid', full_name='io.osirrc.ciff.Posting.docid', index=0, - number=1, type=5, cpp_type=1, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='tf', full_name='io.osirrc.ciff.Posting.tf', index=1, - number=2, type=5, cpp_type=1, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=254, - serialized_end=290, -) - - -_POSTINGSLIST = _descriptor.Descriptor( - name='PostingsList', - full_name='io.osirrc.ciff.PostingsList', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='term', full_name='io.osirrc.ciff.PostingsList.term', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='df', full_name='io.osirrc.ciff.PostingsList.df', index=1, - number=2, type=3, cpp_type=2, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='cf', full_name='io.osirrc.ciff.PostingsList.cf', index=2, - number=3, type=3, cpp_type=2, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='postings', full_name='io.osirrc.ciff.PostingsList.postings', index=3, - number=4, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=292, - serialized_end=387, -) - - -_DOCRECORD = _descriptor.Descriptor( - name='DocRecord', - full_name='io.osirrc.ciff.DocRecord', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='docid', full_name='io.osirrc.ciff.DocRecord.docid', index=0, - number=1, type=5, cpp_type=1, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='collection_docid', full_name='io.osirrc.ciff.DocRecord.collection_docid', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='doclength', full_name='io.osirrc.ciff.DocRecord.doclength', index=2, - number=3, type=5, cpp_type=1, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=389, - serialized_end=460, -) - -_POSTINGSLIST.fields_by_name['postings'].message_type = _POSTING -DESCRIPTOR.message_types_by_name['Header'] = _HEADER -DESCRIPTOR.message_types_by_name['Posting'] = _POSTING -DESCRIPTOR.message_types_by_name['PostingsList'] = _POSTINGSLIST -DESCRIPTOR.message_types_by_name['DocRecord'] = _DOCRECORD -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -Header = _reflection.GeneratedProtocolMessageType('Header', (_message.Message,), dict( - DESCRIPTOR = _HEADER, - __module__ = 'CommonIndexFileFormat_pb2' - # @@protoc_insertion_point(class_scope:io.osirrc.ciff.Header) - )) -_sym_db.RegisterMessage(Header) - -Posting = _reflection.GeneratedProtocolMessageType('Posting', (_message.Message,), dict( - DESCRIPTOR = _POSTING, - __module__ = 'CommonIndexFileFormat_pb2' - # @@protoc_insertion_point(class_scope:io.osirrc.ciff.Posting) - )) -_sym_db.RegisterMessage(Posting) - -PostingsList = _reflection.GeneratedProtocolMessageType('PostingsList', (_message.Message,), dict( - DESCRIPTOR = _POSTINGSLIST, - __module__ = 'CommonIndexFileFormat_pb2' - # @@protoc_insertion_point(class_scope:io.osirrc.ciff.PostingsList) - )) -_sym_db.RegisterMessage(PostingsList) - -DocRecord = _reflection.GeneratedProtocolMessageType('DocRecord', (_message.Message,), dict( - DESCRIPTOR = _DOCRECORD, - __module__ = 'CommonIndexFileFormat_pb2' - # @@protoc_insertion_point(class_scope:io.osirrc.ciff.DocRecord) - )) -_sym_db.RegisterMessage(DocRecord) - - -# @@protoc_insertion_point(module_scope) diff --git a/geesedb/utils/ciff/ciff_writer.py b/geesedb/utils/ciff/ciff_writer.py deleted file mode 100644 index a5fdbd4..0000000 --- a/geesedb/utils/ciff/ciff_writer.py +++ /dev/null @@ -1,59 +0,0 @@ -from pathlib import Path -from typing import IO, Iterable, Optional, Union - -from google.protobuf.internal.encoder import _VarintEncoder # type: ignore -from google.protobuf.message import Message - -from ..ciff.CommonIndexFileFormat_pb2 import DocRecord, Header, PostingsList - - -class MessageWriter: - filename: Optional[Path] = None - output: Optional[IO[bytes]] = None - - def __init__(self, output: Union[Path, IO[bytes]]) -> None: - if isinstance(output, Path): - self.filename = output - else: - self.output = output - - self.varint_encoder = _VarintEncoder() - - def write_message(self, message: Message): - self.write_serialized(message.SerializeToString()) - - def write_serialized(self, serialized_message: bytes): - if self.output is None: - raise ValueError('cannot write to closed file') - - self.varint_encoder(self.output.write, len(serialized_message)) - self.output.write(serialized_message) - - def write(self, data: bytes): - if self.output is None: - raise ValueError('cannot write to closed file') - - self.output.write(data) - - def __enter__(self): - if self.filename is not None: - self.output = open(self.filename, 'wb') - - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if self.filename is not None: - self.output.close() - - -class CiffWriter(MessageWriter): - def write_header(self, header: Header): - self.write_message(header) - - def write_documents(self, documents: Iterable[DocRecord]): - for doc in documents: - self.write_message(doc) - - def write_postings_lists(self, postings_lists: Iterable[PostingsList]): - for pl in postings_lists: - self.write_message(pl) \ No newline at end of file diff --git a/geesedb/utils/ciff/to_ciff.py b/geesedb/utils/ciff/to_ciff.py index 989e9b5..7af242a 100644 --- a/geesedb/utils/ciff/to_ciff.py +++ b/geesedb/utils/ciff/to_ciff.py @@ -3,10 +3,10 @@ from typing import Any from tqdm import tqdm +from ciff_toolkit.ciff_pb2 import Header, Posting, PostingsList, DocRecord +from ciff_toolkit.write import CiffWriter -from .CommonIndexFileFormat_pb2 import Header, Posting, PostingsList, DocRecord from ...connection import get_connection -from .ciff_writer import MessageWriter class ToCiff: @@ -35,68 +35,76 @@ def get_arguments(kwargs: Any) -> dict: return arguments def create_ciff(self) -> None: - with MessageWriter(Path(self.arguments['ciff'])) as f: - # Create header - header = Header() - header.version = 1 # We work with ciff v1 - self.cursor.execute("""SELECT COUNT(*) FROM term_dict""") - header.num_postings_lists = self.cursor.fetchone()[0] - self.cursor.execute("""SELECT COUNT(*) FROM docs""") - header.num_docs = self.cursor.fetchone()[0] - header.total_postings_lists = header.num_postings_lists - header.total_docs = header.num_docs - self.cursor.execute("""SELECT SUM(tf) FROM term_doc""") - header.total_terms_in_collection = self.cursor.fetchone()[0] - header.average_doclength = header.total_terms_in_collection / header.num_docs - header.description = f'GeeseDB database {self.arguments["database"]}' - f.write_message(header) - - disable_tqdm = not self.arguments['verbose'] - - # Create postings lists - self.cursor.execute(""" - SELECT df, string, list(row(doc_id, tf) ORDER BY doc_id) - FROM term_dict, term_doc - WHERE term_dict.term_id = term_doc.term_id - GROUP BY term_dict.term_id, df, string - ORDER BY string - """) - with tqdm(total=header.num_postings_lists, disable=disable_tqdm) as pbar: - while batch := self.cursor.fetchmany(self.arguments['batch_size']): - for df, term, postings in batch: - postingsList = PostingsList() - assert len(postings) == df - cf = sum(p['tf'] for p in postings) - postingsList.term = term - postingsList.df = df - postingsList.cf = cf - old_id = 0 - for p in postings: - posting = Posting() - doc_id = p['doc_id'] - tf = p['tf'] - posting.docid = doc_id - old_id - old_id = doc_id - posting.tf = tf - postingsList.postings.append(posting) - f.write_message(postingsList) - pbar.update() - - # Create doc records - self.cursor.execute(""" - SELECT doc_id, collection_id, len - FROM docs - ORDER BY doc_id - """) - with tqdm(total=header.num_docs, disable=disable_tqdm) as pbar: - while batch := self.cursor.fetchmany(self.arguments['batch_size']): - for doc_id, collection_id, length in batch: - docRecord = DocRecord() - docRecord.docid = doc_id - docRecord.collection_docid = collection_id - docRecord.doclength = length - f.write_message(docRecord) - pbar.update() + disable_tqdm = not self.arguments['verbose'] + + with CiffWriter(self.arguments['ciff']) as writer: + header = self.get_ciff_header() + writer.write_header(header) + + postings_lists = tqdm(self.get_ciff_postings_lists(), total=header.num_postings_lists, disable=disable_tqdm) + writer.write_postings_lists(postings_lists) + + doc_records = tqdm(self.get_ciff_doc_records(), total=header.num_docs, disable=disable_tqdm) + writer.write_documents(doc_records) + + def get_ciff_header(self): + header = Header() + header.version = 1 # We work with ciff v1 + self.cursor.execute("""SELECT COUNT(*) FROM term_dict""") + header.num_postings_lists = self.cursor.fetchone()[0] + self.cursor.execute("""SELECT COUNT(*) FROM docs""") + header.num_docs = self.cursor.fetchone()[0] + header.total_postings_lists = header.num_postings_lists + header.total_docs = header.num_docs + self.cursor.execute("""SELECT SUM(tf) FROM term_doc""") + header.total_terms_in_collection = self.cursor.fetchone()[0] + header.average_doclength = header.total_terms_in_collection / header.num_docs + header.description = f'GeeseDB database {self.arguments["database"]}' + + return header + + def get_ciff_postings_lists(self): + self.cursor.execute(""" + SELECT df, string, list(row(doc_id, tf) ORDER BY doc_id) + FROM term_dict, term_doc + WHERE term_dict.term_id = term_doc.term_id + GROUP BY term_dict.term_id, df, string + ORDER BY string + """) + while batch := self.cursor.fetchmany(self.arguments['batch_size']): + for df, term, postings in batch: + postings_list = PostingsList() + assert len(postings) == df + cf = sum(p['tf'] for p in postings) + postings_list.term = term + postings_list.df = df + postings_list.cf = cf + old_id = 0 + for p in postings: + posting = Posting() + doc_id = p['doc_id'] + tf = p['tf'] + posting.docid = doc_id - old_id + old_id = doc_id + posting.tf = tf + postings_list.postings.append(posting) + + yield postings_list + + def get_ciff_doc_records(self): + self.cursor.execute(""" + SELECT doc_id, collection_id, len + FROM docs + ORDER BY doc_id + """) + while batch := self.cursor.fetchmany(self.arguments['batch_size']): + for doc_id, collection_id, length in batch: + doc_record = DocRecord() + doc_record.docid = doc_id + doc_record.collection_docid = collection_id + doc_record.doclength = length + + yield doc_record if __name__ == '__main__': diff --git a/geesedb/utils/ciff/to_csv.py b/geesedb/utils/ciff/to_csv.py index 78d0fe8..9095982 100755 --- a/geesedb/utils/ciff/to_csv.py +++ b/geesedb/utils/ciff/to_csv.py @@ -4,7 +4,7 @@ import gzip from typing import Union, Any, Tuple -from . import CommonIndexFileFormat_pb2 as Ciff +from ciff_toolkit.read import CiffReader class ToCSV: @@ -58,37 +58,19 @@ def create_csv_files(self) -> None: data = f.read() next_pos, pos = 0, 0 - # start with reading header info - header = Ciff.Header() - next_pos, pos = self.decode(data, pos) - header.ParseFromString(data[pos:pos+next_pos]) - pos += next_pos + with CiffReader(self.arguments['protobuf_file']) as reader: + with open(self.arguments['output_term_dict'], 'w') as term_dict_writer, \ + open(self.arguments['output_term_doc'], 'w') as term_doc_writer: + for term_id, postings_list in enumerate(reader.read_postings_lists()): + term_dict_writer.write(f'{term_id}|{postings_list.term}|{postings_list.df}\n') + docid = 0 + for posting in postings_list.postings: + docid += posting.docid + term_doc_writer.write(f'{term_id}|{docid}|{posting.tf}\n') - # read posting lists - term_dict_writer = open(self.arguments['output_term_dict'], 'w') - term_doc_writer = open(self.arguments['output_term_doc'], 'w') - postings_list = Ciff.PostingsList() - for term_id in range(header.num_postings_lists): - next_pos, pos = self.decode(data, pos) - postings_list.ParseFromString(data[pos:pos+next_pos]) - pos += next_pos - term_dict_writer.write(f'{term_id}|{postings_list.term}|{postings_list.df}\n') - docid = 0 - for posting in postings_list.postings: - docid += posting.docid - term_doc_writer.write(f'{term_id}|{docid}|{posting.tf}\n') - term_dict_writer.close() - term_doc_writer.close() - - # read doc information - docs_writer = open(self.arguments['output_docs'], 'w') - doc_record = Ciff.DocRecord() - for n in range(header.num_docs): - next_pos, pos = self.decode(data, pos) - doc_record.ParseFromString(data[pos:pos+next_pos]) - pos += next_pos - docs_writer.write(f'{doc_record.collection_docid}|{doc_record.docid}|{doc_record.doclength}\n') - docs_writer.close() + with open(self.arguments['output_docs'], 'w') as docs_writer: + for doc_record in reader.read_documents(): + docs_writer.write(f'{doc_record.collection_docid}|{doc_record.docid}|{doc_record.doclength}\n') if __name__ == '__main__': diff --git a/requirements.txt b/requirements.txt index ac6d252..19ade1f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ duckdb google>=2 -protobuf==3.20.2 numpy pandas +ciff-toolkit tqdm git+https://github.com/informagi/pycypher diff --git a/setup.py b/setup.py index 51ee68c..398dad7 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ author='Chris Kamphuis', author_email='chris@cs.ru.nl', url='https://github.com/informagi/GeeseDB', - install_requires=['duckdb', 'numpy', 'pandas', 'protobuf', 'tqdm', + install_requires=['duckdb', 'numpy', 'pandas', 'ciff-toolkit', 'tqdm', 'pycypher @ git+https://github.com/informagi/pycypher'], packages=find_packages(), include_package_data=True,