Skip to content

Commit

Permalink
add month para for data collection
Browse files Browse the repository at this point in the history
  • Loading branch information
PascalEgn committed Jul 16, 2024
1 parent 19d1b36 commit fd9bcd0
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 11 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ Set the environment variables for inspire-prod es database and run the [`create_d
export ES_USERNAME=XXXX
export ES_PASSWORD=XXXX
poetry run python scripts/create_dataset.py --year-from $YEAR_FROM --year-to $YEAR_TO
poetry run python scripts/create_dataset.py --year-from $YEAR_FROM --month-from $MONTH_FROM --year-to $YEAR_TO --month-to $MONTH_TO
# $MONTH_FROM and $MONTH_TO are optional: --month-from defaults to 1 and --month-to defaults to 12
```


Expand Down
30 changes: 22 additions & 8 deletions scripts/create_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,12 @@ def __init__(self, index, **kwargs):


class InspireClassifierSearch(object):
def __init__(self, index, query_filters, year_from, year_to):
def __init__(self, index, query_filters, year_from, year_to, month_from, month_to):
self.search = LiteratureSearch(index=index)
self.year_from = year_from
self.month_from = month_from
self.year_to = year_to
self.month_to = month_to

# Training, validation and test data

Expand All @@ -77,8 +79,8 @@ def __init__(self, index, query_filters, year_from, year_to):
& Q(
"range",
metadata__acquisition_source__datetime={
"gte": self.year_from,
"lt": self.year_to,
"gte": f"{self.year_from}-{self.month_from}",
"lt": f"{self.year_to}-{self.month_to}",
},
),
]
Expand All @@ -90,7 +92,8 @@ def __init__(self, index, query_filters, year_from, year_to):
self.inspire_categories_field = "inspire_categories.term"
self.query_filters = [
query_filters
& Q("range", _created={"gte": self.year_from, "lt": self.year_to}),
& Q("range", _created={"gte": f"{self.year_from}-{self.month_from}",
"lt": f"{self.year_to}-{self.month_to}",}),
]

def _postprocess_record_data(self, record_data):
Expand All @@ -113,13 +116,15 @@ def get_decision_query(self):
return query


def get_data_for_decisions(year_from, year_to):
def get_data_for_decisions(year_from, year_to, month_from, month_to):
for decision in DECISIONS_MAPPING:
inspire_search = InspireClassifierSearch(
index=DECISIONS_MAPPING[decision]["index"],
query_filters=DECISIONS_MAPPING[decision]["filter_query"],
year_from=year_from,
year_to=year_to,
month_from=month_from,
month_to=month_to,
)
query = inspire_search.get_decision_query()
for record_es_data in tqdm(query.scan()):
Expand All @@ -144,14 +149,23 @@ def prepare_inspire_classifier_dataset(data, save_data_path):

@click.command()
@click.option("--year-from", type=int, required=True)
@click.option("--month-from", type=int, required=False, default=1)
@click.option("--year-to", type=int, required=True)
def get_inspire_classifier_dataset(year_from, year_to):
@click.option("--month-to", type=int, required=False, default=12)
def get_inspire_classifier_dataset(year_from, year_to, month_from, month_to):
if year_to < year_from:
raise ValueError("year_to must be before year_from")
if month_to < month_from:
raise ValueError("month_to must be before month_from")
if month_to > 12 or month_from > 12 or month_to < 1 or month_from < 1:
raise ValueError("month_to and month_from must be between 1 and 12")
month_from = f"{month_from:02d}-01"
month_to = f"{month_to:02d}-31"
print(f"Fetching {year_from}-{month_from} to {year_to}-{month_to}")
inspire_classifier_dataset_path = os.path.join(
os.getcwd(), "inspire_classifier_dataset.pkl"
os.getcwd(), f"inspire_classifier_dataset_{year_from}-{month_from}_{year_to}-{month_to}.pkl"
)
data = get_data_for_decisions(year_from, year_to)
data = get_data_for_decisions(year_from, year_to, month_from, month_to)
prepare_inspire_classifier_dataset(data, inspire_classifier_dataset_path)


Expand Down
4 changes: 2 additions & 2 deletions scripts/train_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def train_classifier(


# Adjust necessary data
NUMBER_OF_CLASSIFIER_EPOCHS = 15
NUMBER_OF_LANGUAGE_MODEL_EPOCHS = 15
NUMBER_OF_CLASSIFIER_EPOCHS = 10
NUMBER_OF_LANGUAGE_MODEL_EPOCHS = 10
TRAIN_TEST_SPLIT = 0.8

train_classifier(
Expand Down

0 comments on commit fd9bcd0

Please sign in to comment.