From 3ea49e7892bc02ff3f28bafb44fb371b45800740 Mon Sep 17 00:00:00 2001 From: Nicolay Rusnachenko Date: Sat, 16 Sep 2023 11:18:55 +0100 Subject: [PATCH] #66 support filtering for relation types --- arekit_ss/filters/__init__.py | 0 arekit_ss/filters/label_type.py | 13 +++++++++++++ arekit_ss/filters/object_type.py | 13 +++++++++++++ arekit_ss/sample.py | 8 +++++++- arekit_ss/sources/config.py | 1 + arekit_ss/sources/nerel/data_pipeline.py | 1 + arekit_ss/sources/nerel_bio/data_pipeline.py | 1 + arekit_ss/sources/ruattitudes/data_pipeline.py | 1 + arekit_ss/sources/rusentrel/data_pipeline.py | 1 + arekit_ss/sources/sentinerel/data_pipeline.py | 1 + 10 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 arekit_ss/filters/__init__.py create mode 100644 arekit_ss/filters/label_type.py create mode 100644 arekit_ss/filters/object_type.py diff --git a/arekit_ss/filters/__init__.py b/arekit_ss/filters/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/arekit_ss/filters/label_type.py b/arekit_ss/filters/label_type.py new file mode 100644 index 0000000..a5ab465 --- /dev/null +++ b/arekit_ss/filters/label_type.py @@ -0,0 +1,13 @@ +from arekit.common.text_opinions.base import TextOpinion +from arekit.contrib.utils.pipelines.text_opinion.filters.base import TextOpinionFilter + + +class LabelTextOpinionFilter(TextOpinionFilter): + + def __init__(self, relation_types): + assert(isinstance(relation_types, list)) + self.__relation_types = set(relation_types) + + def filter(self, text_opinion, parsed_doc, entity_service_provider): + assert(isinstance(text_opinion, TextOpinion)) + return type(text_opinion.Label).__name__ in self.__relation_types diff --git a/arekit_ss/filters/object_type.py b/arekit_ss/filters/object_type.py new file mode 100644 index 0000000..646bae6 --- /dev/null +++ b/arekit_ss/filters/object_type.py @@ -0,0 +1,13 @@ +from arekit.common.text_opinions.base import TextOpinion +from arekit.contrib.utils.pipelines.text_opinion.filters.base import TextOpinionFilter + + +class ObjectOpinionFilter(TextOpinionFilter): + + def __init__(self, relation_types): + assert(isinstance(relation_types, list)) + self.__relation_types = set(relation_types) + + def filter(self, text_opinion, parsed_doc, entity_service_provider): + assert(isinstance(text_opinion, TextOpinion)) + return type(text_opinion.Label).__name__ in self.__relation_types diff --git a/arekit_ss/sample.py b/arekit_ss/sample.py index dbbd8ab..b28ac39 100644 --- a/arekit_ss/sample.py +++ b/arekit_ss/sample.py @@ -5,6 +5,7 @@ from arekit.contrib.utils.data.writers.json_opennre import OpenNREJsonWriter from arekit.contrib.utils.data.writers.sqlite_native import SQliteWriter +from arekit_ss.filters.label_type import LabelTextOpinionFilter from arekit_ss.framework.samplers_list import create_sampler_pipeline_item from arekit_ss.sources import src_list from arekit_ss.sources.config import SourcesConfig @@ -37,6 +38,7 @@ parser.add_argument("--prompt", type=str, default="{text},`{s_val}`,`{t_val}`, `{label_val}`") parser.add_argument("--text_parser", type=str, default="nn") parser.add_argument("--doc_ids", type=str, default=None) + parser.add_argument("--relation_types", type=str, default=None) parser.add_argument("--docs_limit", type=int, default=None) parser.add_argument("--terms_per_context", type=int, default=50) parser.add_argument('--no-vectorize', dest='vectorize', action='store_false', @@ -53,7 +55,7 @@ elif args.writer in ['jsonl', 'json']: writer = OpenNREJsonWriter(text_columns=["text_a", "text_b"]) elif args.writer == "sqlite": - writer = SQliteWriter() + writer = SQliteWriter(skip_existed=False) else: raise Exception("writer `{}` is not supported!".format(args.writer)) @@ -73,6 +75,10 @@ cfg.text_parser = text_parsing_pipelines[args.text_parser](cfg) cfg.splits = args.splits + # Setup filters for text opinions extraction. + if args.relation_types is not None: + cfg.optional_filters.append(LabelTextOpinionFilter(args.relation_types.split("|"))) + # Extract data to be serialized in a form of the pipeline. dpp = auto_import(name=source["pipeline"]) data_folding, data_type_pipelines = dpp(cfg) diff --git a/arekit_ss/sources/config.py b/arekit_ss/sources/config.py index abc68f3..bd8276b 100644 --- a/arekit_ss/sources/config.py +++ b/arekit_ss/sources/config.py @@ -11,6 +11,7 @@ def __init__(self): self.entities_parser = None self.text_parser = None self.splits = None + self.optional_filters = [] def get_supported_datatypes(self): """ String split name to data-types converter. diff --git a/arekit_ss/sources/nerel/data_pipeline.py b/arekit_ss/sources/nerel/data_pipeline.py index c1b43be..a2e2611 100644 --- a/arekit_ss/sources/nerel/data_pipeline.py +++ b/arekit_ss/sources/nerel/data_pipeline.py @@ -17,6 +17,7 @@ def build_nerel_datapipeline(cfg): terms_per_context=cfg.terms_per_context, label_formatter=NerelAnyLabelFormatter(), docs_limit=cfg.docs_limit, + custom_text_opinion_filters=cfg.optional_filters, doc_ops=None, text_parser=cfg.text_parser) diff --git a/arekit_ss/sources/nerel_bio/data_pipeline.py b/arekit_ss/sources/nerel_bio/data_pipeline.py index fb498fc..19eb20e 100644 --- a/arekit_ss/sources/nerel_bio/data_pipeline.py +++ b/arekit_ss/sources/nerel_bio/data_pipeline.py @@ -19,6 +19,7 @@ def build_nerel_bio_datapipeline(cfg): terms_per_context=cfg.terms_per_context, label_formatter=NerelBioAnyLabelFormatter(), docs_limit=cfg.docs_limit, + custom_text_opinion_filters=cfg.optional_filters, doc_ops=None, text_parser=cfg.text_parser) diff --git a/arekit_ss/sources/ruattitudes/data_pipeline.py b/arekit_ss/sources/ruattitudes/data_pipeline.py index afdc24c..42af290 100644 --- a/arekit_ss/sources/ruattitudes/data_pipeline.py +++ b/arekit_ss/sources/ruattitudes/data_pipeline.py @@ -17,6 +17,7 @@ def build_ruattitudes_datapipeline(cfg): version=version, text_parser=cfg.text_parser, label_scaler=PosNegNeuRelationsLabelScaler(), + custom_text_opinion_filters=cfg.optional_filters, limit=cfg.docs_limit) d = RuAttitudesDocumentProvider.read_ruattitudes_to_brat_in_memory( diff --git a/arekit_ss/sources/rusentrel/data_pipeline.py b/arekit_ss/sources/rusentrel/data_pipeline.py index bcdaf00..9cb244b 100644 --- a/arekit_ss/sources/rusentrel/data_pipeline.py +++ b/arekit_ss/sources/rusentrel/data_pipeline.py @@ -25,6 +25,7 @@ def build_s_rusentrel_datapipeline(cfg): pipeline = create_text_opinion_extraction_pipeline( rusentrel_version=version, text_parser=cfg.text_parser, + custom_text_opinion_filters=cfg.optional_filters, labels_fmt=RuSentRelLabelsFormatter(pos_label_type=PositiveTo, neg_label_type=NegativeTo)) data_folding = { diff --git a/arekit_ss/sources/sentinerel/data_pipeline.py b/arekit_ss/sources/sentinerel/data_pipeline.py index f309884..31be2b5 100644 --- a/arekit_ss/sources/sentinerel/data_pipeline.py +++ b/arekit_ss/sources/sentinerel/data_pipeline.py @@ -11,6 +11,7 @@ def build_sentinerel_datapipeline(cfg): pipelines, data_folding = create_text_opinion_extraction_pipeline( sentinerel_version=SentiNerelVersions.V21, terms_per_context=cfg.terms_per_context, + custom_text_opinion_filters=cfg.optional_filters, docs_limit=cfg.docs_limit, doc_provider=None, text_parser=cfg.text_parser)