From fd9bcd01c7b94e52764cd70ab2f78d4c807598fa Mon Sep 17 00:00:00 2001 From: PascalEgn Date: Tue, 16 Jul 2024 14:12:21 +0200 Subject: [PATCH] add month para for data collection --- README.md | 4 +++- scripts/create_dataset.py | 30 ++++++++++++++++++++++-------- scripts/train_classifier.py | 4 ++-- 3 files changed, 27 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 845f680..eec57bc 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,9 @@ Set the enviroment variables for inspire-prod es database and run the [`create_d export ES_USERNAME=XXXX export ES_PASSWORD=XXXX -poetry run python scripts/create_dataset.py --year-from $YEAR_FROM --year-to $YEAR_TO +poetry run python scripts/create_dataset.py --year-from $YEAR_FROM --month-from $MONTH_FROM --year-to $YEAR_TO --month-to $MONTH_TO + +($MONTH_FROM and $MONTH_TO are optional parameters) ``` diff --git a/scripts/create_dataset.py b/scripts/create_dataset.py index 29c7ada..f994d25 100644 --- a/scripts/create_dataset.py +++ b/scripts/create_dataset.py @@ -54,10 +54,12 @@ def __init__(self, index, **kwargs): class InspireClassifierSearch(object): - def __init__(self, index, query_filters, year_from, year_to): + def __init__(self, index, query_filters, year_from, year_to, month_from, month_to): self.search = LiteratureSearch(index=index) self.year_from = year_from + self.month_from = month_from self.year_to = year_to + self.month_to = month_to # Training, validation and test data @@ -77,8 +79,8 @@ def __init__(self, index, query_filters, year_from, year_to): & Q( "range", metadata__acquisition_source__datetime={ - "gte": self.year_from, - "lt": self.year_to, + "gte": f"{self.year_from}-{self.month_from}", + "lt": f"{self.year_to}-{self.month_to}", }, ), ] @@ -90,7 +92,8 @@ def __init__(self, index, query_filters, year_from, year_to): self.inspire_categories_field = "inspire_categories.term" self.query_filters = [ query_filters - & Q("range", _created={"gte": self.year_from, "lt": self.year_to}), + & Q("range", _created={"gte": f"{self.year_from}-{self.month_from}", + "lt": f"{self.year_to}-{self.month_to}",}), ] def _postprocess_record_data(self, record_data): @@ -113,13 +116,15 @@ def get_decision_query(self): return query -def get_data_for_decisions(year_from, year_to): +def get_data_for_decisions(year_from, year_to, month_from, month_to): for decision in DECISIONS_MAPPING: inspire_search = InspireClassifierSearch( index=DECISIONS_MAPPING[decision]["index"], query_filters=DECISIONS_MAPPING[decision]["filter_query"], year_from=year_from, year_to=year_to, + month_from=month_from, + month_to=month_to, ) query = inspire_search.get_decision_query() for record_es_data in tqdm(query.scan()): @@ -144,14 +149,23 @@ def prepare_inspire_classifier_dataset(data, save_data_path): @click.command() @click.option("--year-from", type=int, required=True) +@click.option("--month-from", type=int, required=False, default=1) @click.option("--year-to", type=int, required=True) -def get_inspire_classifier_dataset(year_from, year_to): +@click.option("--month-to", type=int, required=False, default=12) +def get_inspire_classifier_dataset(year_from, year_to, month_from, month_to): if year_to < year_from: raise ValueError("year_to must be before year_from") + if month_to < month_from: + raise ValueError("month_to must be before month_from") + if month_to > 12 or month_from > 12 or month_to < 1 or month_from < 1: + raise ValueError("month_to and month_from must be between 1 and 12") + month_from = f"{month_from:02d}-01" + month_to = f"{month_to:02d}-31" + print(f"Fetching {year_from}-{month_from} to {year_to}-{month_to}") inspire_classifier_dataset_path = os.path.join( - os.getcwd(), "inspire_classifier_dataset.pkl" + os.getcwd(), f"inspire_classifier_dataset_{year_from}-{month_from}_{year_to}-{month_to}.pkl" ) - data = get_data_for_decisions(year_from, year_to) + data = get_data_for_decisions(year_from, year_to, month_from, month_to) prepare_inspire_classifier_dataset(data, inspire_classifier_dataset_path) diff --git a/scripts/train_classifier.py b/scripts/train_classifier.py index 996b1d0..b5e69cb 100644 --- a/scripts/train_classifier.py +++ b/scripts/train_classifier.py @@ -55,8 +55,8 @@ def train_classifier( # Adjust necessary data -NUMBER_OF_CLASSIFIER_EPOCHS = 15 -NUMBER_OF_LANGUAGE_MODEL_EPOCHS = 15 +NUMBER_OF_CLASSIFIER_EPOCHS = 10 +NUMBER_OF_LANGUAGE_MODEL_EPOCHS = 10 TRAIN_TEST_SPLIT = 0.8 train_classifier(