Skip to content

Commit

Permalink
#42 done. #53 related
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Aug 14, 2023
1 parent 3623170 commit 24a5151
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
9 changes: 9 additions & 0 deletions arekit_ss/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
# Command-line options with string values and their defaults.
parser.add_argument("--writer", type=str, default="csv")
parser.add_argument("--source", type=str, default="ruattitudes")
parser.add_argument("--sampler", type=str, default="nn")
# Optional ':'-separated list of split names. The default None is passed
# through to cfg.splits, where SourcesConfig.get_supported_datatypes treats
# it as "every known data-type".
parser.add_argument("--splits", type=str, default=None,
help="Manual selection of the data-types related splits that "
"should be chosen for the sampling process; types should be "
"separated by ':' sign; for example: 'train:test'")
# Both language options are optional and default to None.
parser.add_argument("--src_lang", type=str, default=None, required=False)
parser.add_argument("--dest_lang", type=str, default=None, required=False)
parser.add_argument("--output_dir", type=str, default="_out")
Expand Down Expand Up @@ -60,11 +64,16 @@
# Copy the command-line choices into the sources config.
cfg.docs_limit = args.docs_limit
cfg.entities_parser = source["entity_parser"]
# Look up the text-parser factory by its command-line name and build the
# parser from the config filled so far.
cfg.text_parser = text_parsing_pipelines[args.text_parser](cfg)
# Raw value of --splits: None or a ':'-separated string of split names;
# it is parsed later by cfg.get_supported_datatypes().
cfg.splits = args.splits

# Extract data to be serialized in a form of the pipeline.
dpp = source["pipeline"]
# The source pipeline factory returns the data folding together with a
# container of pipelines indexed by data-type.
data_folding, data_type_pipelines = dpp(cfg)

# Filter only those data_types that were chosen.
# Iterate over the pipelines themselves rather than over the set of chosen
# data-types: set iteration order is arbitrary, while walking the mapping
# keeps the order in which the source declared its pipelines.
chosen_data_types = cfg.get_supported_datatypes()
data_type_pipelines = {k: data_type_pipelines[k] for k in data_type_pipelines
                       if k in chosen_data_types}

# Prepare serializer and pass data_type_pipelines.
pipeline_item = create_sampler_pipeline_item(
args=args, writer=writer,
Expand Down
25 changes: 25 additions & 0 deletions arekit_ss/sources/config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from arekit.common.experiment.data_type import DataType


class SourcesConfig:

def __init__(self):
Expand All @@ -7,3 +10,25 @@ def __init__(self):
self.docs_limit = None
self.entities_parser = None
self.text_parser = None
self.splits = None

def get_supported_datatypes(self):
    """ Convert the configured split names into AREkit data-types.

    `self.splits` is either None or a string of split names separated
    by the ':' sign (for example: 'train:test'). Names are matched
    case-insensitively, surrounding spaces are ignored and empty
    segments (e.g. caused by a trailing ':') are skipped.

    :return: set of DataType values; every known data-type when
        `self.splits` is None.
    :raises KeyError: when a split name is unknown, or when the string
        does not contain any split name.
    """

    # AREkit 0.23.1 has the predefined type DataType which describes the
    # splits in a form of Enum.
    split_to_data_type = {
        "train": DataType.Train,
        "test": DataType.Test,
        "dev": DataType.Dev,
        "etalon": DataType.Etalon
    }

    # No manual selection: every known data-type is supported.
    if self.splits is None:
        return set(split_to_data_type.values())

    # Normalize the user input before the lookup.
    chosen_names = {name.strip().lower() for name in self.splits.split(":")}
    chosen_names.discard("")

    # Report every bad name at once together with the accepted values,
    # instead of failing on the first one with a bare KeyError.
    unknown_names = sorted(chosen_names.difference(split_to_data_type))
    if unknown_names or not chosen_names:
        raise KeyError(
            "Unsupported splits {unknown} in '{value}'; expected ':'-separated "
            "names from {supported}".format(
                unknown=unknown_names,
                value=self.splits,
                supported=sorted(split_to_data_type)))

    return {split_to_data_type[name] for name in chosen_names}


0 comments on commit 24a5151

Please sign in to comment.