Skip to content

Commit

Permalink
#66 support filtering for relation types
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Sep 16, 2023
1 parent 0fbc8be commit 3ea49e7
Show file tree
Hide file tree
Showing 10 changed files with 39 additions and 1 deletion.
Empty file added arekit_ss/filters/__init__.py
Empty file.
13 changes: 13 additions & 0 deletions arekit_ss/filters/label_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from arekit.common.text_opinions.base import TextOpinion
from arekit.contrib.utils.pipelines.text_opinion.filters.base import TextOpinionFilter


class LabelTextOpinionFilter(TextOpinionFilter):

def __init__(self, relation_types):
assert(isinstance(relation_types, list))
self.__relation_types = set(relation_types)

def filter(self, text_opinion, parsed_doc, entity_service_provider):
assert(isinstance(text_opinion, TextOpinion))
return type(text_opinion.Label).__name__ in self.__relation_types
13 changes: 13 additions & 0 deletions arekit_ss/filters/object_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from arekit.common.text_opinions.base import TextOpinion
from arekit.contrib.utils.pipelines.text_opinion.filters.base import TextOpinionFilter


class ObjectOpinionFilter(TextOpinionFilter):

def __init__(self, relation_types):
assert(isinstance(relation_types, list))
self.__relation_types = set(relation_types)

def filter(self, text_opinion, parsed_doc, entity_service_provider):
assert(isinstance(text_opinion, TextOpinion))
return type(text_opinion.Label).__name__ in self.__relation_types
8 changes: 7 additions & 1 deletion arekit_ss/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from arekit.contrib.utils.data.writers.json_opennre import OpenNREJsonWriter
from arekit.contrib.utils.data.writers.sqlite_native import SQliteWriter

from arekit_ss.filters.label_type import LabelTextOpinionFilter
from arekit_ss.framework.samplers_list import create_sampler_pipeline_item
from arekit_ss.sources import src_list
from arekit_ss.sources.config import SourcesConfig
Expand Down Expand Up @@ -37,6 +38,7 @@
parser.add_argument("--prompt", type=str, default="{text},`{s_val}`,`{t_val}`, `{label_val}`")
parser.add_argument("--text_parser", type=str, default="nn")
parser.add_argument("--doc_ids", type=str, default=None)
parser.add_argument("--relation_types", type=str, default=None)
parser.add_argument("--docs_limit", type=int, default=None)
parser.add_argument("--terms_per_context", type=int, default=50)
parser.add_argument('--no-vectorize', dest='vectorize', action='store_false',
Expand All @@ -53,7 +55,7 @@
elif args.writer in ['jsonl', 'json']:
writer = OpenNREJsonWriter(text_columns=["text_a", "text_b"])
elif args.writer == "sqlite":
writer = SQliteWriter()
writer = SQliteWriter(skip_existed=False)
else:
raise Exception("writer `{}` is not supported!".format(args.writer))

Expand All @@ -73,6 +75,10 @@
cfg.text_parser = text_parsing_pipelines[args.text_parser](cfg)
cfg.splits = args.splits

# Setup filters for text opinions extraction.
if args.relation_types is not None:
cfg.optional_filters.append(LabelTextOpinionFilter(args.relation_types.split("|")))

# Extract data to be serialized in a form of the pipeline.
dpp = auto_import(name=source["pipeline"])
data_folding, data_type_pipelines = dpp(cfg)
Expand Down
1 change: 1 addition & 0 deletions arekit_ss/sources/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def __init__(self):
self.entities_parser = None
self.text_parser = None
self.splits = None
self.optional_filters = []

def get_supported_datatypes(self):
""" String split name to data-types converter.
Expand Down
1 change: 1 addition & 0 deletions arekit_ss/sources/nerel/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def build_nerel_datapipeline(cfg):
terms_per_context=cfg.terms_per_context,
label_formatter=NerelAnyLabelFormatter(),
docs_limit=cfg.docs_limit,
custom_text_opinion_filters=cfg.optional_filters,
doc_ops=None,
text_parser=cfg.text_parser)

Expand Down
1 change: 1 addition & 0 deletions arekit_ss/sources/nerel_bio/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def build_nerel_bio_datapipeline(cfg):
terms_per_context=cfg.terms_per_context,
label_formatter=NerelBioAnyLabelFormatter(),
docs_limit=cfg.docs_limit,
custom_text_opinion_filters=cfg.optional_filters,
doc_ops=None,
text_parser=cfg.text_parser)

Expand Down
1 change: 1 addition & 0 deletions arekit_ss/sources/ruattitudes/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def build_ruattitudes_datapipeline(cfg):
version=version,
text_parser=cfg.text_parser,
label_scaler=PosNegNeuRelationsLabelScaler(),
custom_text_opinion_filters=cfg.optional_filters,
limit=cfg.docs_limit)

d = RuAttitudesDocumentProvider.read_ruattitudes_to_brat_in_memory(
Expand Down
1 change: 1 addition & 0 deletions arekit_ss/sources/rusentrel/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def build_s_rusentrel_datapipeline(cfg):
pipeline = create_text_opinion_extraction_pipeline(
rusentrel_version=version,
text_parser=cfg.text_parser,
custom_text_opinion_filters=cfg.optional_filters,
labels_fmt=RuSentRelLabelsFormatter(pos_label_type=PositiveTo, neg_label_type=NegativeTo))

data_folding = {
Expand Down
1 change: 1 addition & 0 deletions arekit_ss/sources/sentinerel/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def build_sentinerel_datapipeline(cfg):
pipelines, data_folding = create_text_opinion_extraction_pipeline(
sentinerel_version=SentiNerelVersions.V21,
terms_per_context=cfg.terms_per_context,
custom_text_opinion_filters=cfg.optional_filters,
docs_limit=cfg.docs_limit,
doc_provider=None,
text_parser=cfg.text_parser)
Expand Down

0 comments on commit 3ea49e7

Please sign in to comment.