forked from inspirehep/inspirehep
reworked httphooks to make them easier to use * ref: cern-sis/issues-inspire/issues/594
Showing 22 changed files with 6,645 additions and 95 deletions.
Empty file.
@@ -0,0 +1,218 @@
import datetime
import logging
from datetime import timedelta

from airflow.decorators import dag, task, task_group
from airflow.models import Variable
from hooks.generic_http_hook import GenericHttpHook
from hooks.inspirehep.inspire_http_record_management_hook import (
    InspireHTTPRecordManagementHook,
)
from tenacity import RetryError

logger = logging.getLogger(__name__)


@dag(
    start_date=datetime.datetime(2024, 11, 28),
    schedule="@daily",
    catchup=False,
    tags=["data"],
    max_active_runs=5,
)
def data_harvest_dag():
    """Defines the DAG for the HEPData harvest workflow.

    Tasks:
        1. collect_ids: Obtains all new data ids to process.
        2. download_record_versions: Fetches a data record and all its previous versions.
        3. build_record: Builds a record that is compatible with the INSPIRE data schema.
        4. load_record: Creates or updates the record on INSPIRE.
    """
    generic_http_hook = GenericHttpHook(http_conn_id="hepdata_connection")
    inspire_http_record_management_hook = InspireHTTPRecordManagementHook()

    data_schema = Variable.get("data_schema")
    url = inspire_http_record_management_hook.get_url()

    @task(task_id="collect_ids")
    def collect_ids():
        """Collects the ids of the records that have been updated in the last two days.

        Returns: list of ids
        """
        from_date = (datetime.datetime.now().date() - timedelta(days=1)).strftime(
            "%Y-%m-%d"
        )
        payload = {"inspire_ids": True, "last_updated": from_date, "sort_by": "latest"}
        hepdata_response = generic_http_hook.call_api(
            endpoint="/search/ids", method="GET", params=payload
        )

        return hepdata_response.json()

    @task_group
    def process_record(record_id):
        """Process the record by downloading the versions,
        building the record and loading it to inspirehep.
        """

        @task(max_active_tis_per_dag=5)
        def download_record_versions(id):
            """Download all versions of the record.

            Args: id (int): The id of the record.
            Returns: dict: The record versions ("base" holds the latest version).
            """
            hepdata_response = generic_http_hook.call_api(
                endpoint=f"/record/ins{id}?format=json"
            )
            payload = hepdata_response.json()

            # "base" is the latest version; every earlier version is fetched explicitly.
            record = {"base": payload}
            for version in range(1, payload["record"]["version"]):
                response = generic_http_hook.call_api(
                    endpoint=f"/record/ins{id}?format=json&version={version}"
                )
                response.raise_for_status()
                record[version] = response.json()

            return record

        @task.virtualenv(
            requirements=["inspire-schemas"],
            system_site_packages=False,
        )
        def build_record(data_schema, inspire_url, payload, **context):
            """Build the record from the payload.

            Args: data_schema (str): The schema of the data.
                inspire_url (str): The base URL of the INSPIRE API.
                payload (dict): The payload of the record.
            Returns: dict: The built record.
            """
            import datetime
            import re

            from inspire_schemas.builders import DataBuilder

            def add_version_specific_dois(record, builder):
                """Add dois to the record."""
                for data_table in record["data_tables"]:
                    builder.add_doi(data_table["doi"], material="part")
                for resource_with_doi in record["resources_with_doi"]:
                    builder.add_doi(resource_with_doi["doi"], material="part")

                builder.add_doi(record["record"]["hepdata_doi"], material="version")

            def add_keywords(record, builder):
                """Add keywords to the record."""
                for keyword, item in record.get("data_keywords", {}).items():
                    if keyword == "cmenergies":
                        if len(item) >= 1 and "lte" in item[0] and "gte" in item[0]:
                            builder.add_keyword(
                                f"{keyword}: {item[0]['lte']}-{item[0]['gte']}"
                            )
                    elif keyword == "observables":
                        builder.add_keyword(f"{keyword}: {','.join(item)}")
                    else:
                        for value in item:
                            builder.add_keyword(value)

            builder = DataBuilder(source="hepdata")

            builder.add_creation_date(datetime.datetime.now(datetime.UTC).isoformat())

            base_record = payload["base"]

            for collaboration in base_record["record"]["collaborations"]:
                builder.add_collaboration(collaboration)

            builder.add_abstract(base_record["record"]["data_abstract"])

            add_keywords(base_record["record"], builder)

            doi = base_record["record"].get("doi")
            inspire_id = base_record["record"]["inspire_id"]

            if doi:
                builder.add_literature(
                    doi={"value": doi},
                    record={"$ref": f"{inspire_url}/api/literature/{inspire_id}"},
                )
            else:
                builder.add_literature(
                    record={"$ref": f"{inspire_url}/api/literature/{inspire_id}"},
                )

            for resource in base_record["record"]["resources"]:
                if resource["url"].startswith(
                    "https://www.hepdata.net/record/resource/"
                ):
                    continue
                builder.add_url(resource["url"], description=resource["description"])

            builder.add_title(base_record["record"]["title"])

            builder.add_acquisition_source(
                method="hepcrawl",
                submission_number=base_record["record"]["inspire_id"],
                datetime=datetime.datetime.now(datetime.UTC).isoformat(),
            )

            # Strip the ".vN" version suffix from the HEPData DOI, if present.
            mtc = re.match(r"(.*?)\.v\d+", base_record["record"]["hepdata_doi"])
            if mtc:
                builder.add_doi(doi=mtc.group(1), material="data")
            else:
                builder.add_doi(
                    doi=base_record["record"]["hepdata_doi"], material="data"
                )

            for _, record_version in payload.items():
                add_version_specific_dois(record_version, builder)

            data = builder.record
            data["$schema"] = data_schema
            return data

        @task
        def load_record(new_record):
            """Load the record to inspirehep.

            Args: new_record (dict): The record to create or update in inspire.
            """
            # A failed lookup (no record for this DOI yet) surfaces as a
            # RetryError, in which case a new record is created instead.
            try:
                response = inspire_http_record_management_hook.get_record(
                    pid_type="doi", control_number=new_record["dois"][0]["value"]
                )
            except RetryError:
                logger.info("Creating Record")
                post_response = inspire_http_record_management_hook.post_record(
                    data=new_record, pid_type="data"
                )
                return post_response.json()

            old_record = response["metadata"]
            revision_id = response.get("revision_id", 0)
            old_record.update(new_record)
            logger.info(f"Updating Record: {old_record['control_number']}")
            response = inspire_http_record_management_hook.update_record(
                data=old_record,
                pid_type="data",
                control_number=old_record["control_number"],
                revision_id=revision_id + 1,
            )
            return response.json()

        hepdata_record_versions = download_record_versions(record_id)
        record = build_record(
            data_schema=data_schema, inspire_url=url, payload=hepdata_record_versions
        )
        load_record(record)

    process_record.expand(record_id=collect_ids())


data_harvest_dag()
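
For local debugging, the DAG above could be exercised end to end with a sketch like the one below. This is an assumption, not part of this commit: it relies on Airflow 2.5+ (where the @dag-decorated function returns a DAG object exposing dag.test()) and on the hepdata_connection and INSPIRE connections plus the data_schema Variable being configured locally.

if __name__ == "__main__":
    # The @dag-decorated function returns the DAG object when called.
    dag = data_harvest_dag()
    # dag.test() runs every task in-process, in dependency order.
    dag.test()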
@@ -0,0 +1,81 @@
import logging

import requests
from airflow.providers.http.hooks.http import HttpHook
from hooks.tenacity_config import tenacity_retry_kwargs
from requests import Response

logger = logging.getLogger()


class GenericHttpHook(HttpHook):
    """
    Hook to interact with the Inspire API.
    It overrides the original `run` method in HttpHook so that
    the `data` argument is sent as request data, not as query params.
    """

    def __init__(self, http_conn_id, method="GET", headers=None):
        self._headers = headers
        super().__init__(method=method, http_conn_id=http_conn_id)

    @property
    def tenacity_retry_kwargs(self) -> dict:
        return tenacity_retry_kwargs()

    @property
    def headers(self) -> dict:
        return self._headers

    @headers.setter
    def headers(self, headers):
        self._headers = headers

    def run(
        self,
        endpoint: str,
        method: str = None,
        json: dict = None,
        data: dict = None,
        params: dict = None,
        headers: dict = None,
        extra_options: dict = None,
    ):
        extra_options = extra_options or {}
        method = method or self.method
        headers = headers or self.headers
        session = self.get_conn(headers)

        if not self.base_url.endswith("/") and not endpoint.startswith("/"):
            url = self.base_url + "/" + endpoint
        else:
            url = self.base_url + endpoint

        req = requests.Request(
            method, url, json=json, data=data, params=params, headers=headers
        )

        prepped_request = session.prepare_request(req)
        self.log.info("Sending '%s' to url: %s", method, url)
        return self.run_and_check(session, prepped_request, extra_options)

    def call_api(
        self,
        endpoint: str,
        method: str = None,
        data: dict = None,
        params: dict = None,
        headers: dict = None,
    ) -> Response:
        # Delegates to run() through run_with_advanced_retry so the shared
        # tenacity retry policy applies to every call.
        return self.run_with_advanced_retry(
            _retry_args=self.tenacity_retry_kwargs,
            endpoint=endpoint,
            headers=headers,
            json=data,
            params=params,
            method=method,
        )

    def get_url(self) -> str:
        self.get_conn()
        return self.base_url
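
For reference, this is roughly how the hook is driven from the DAG above; a sketch, assuming an Airflow connection named hepdata_connection whose host points at the HEPData API, and an illustrative date value.

from hooks.generic_http_hook import GenericHttpHook

# call_api() retries transient HTTP failures via the shared tenacity policy
# before the calling task itself fails.
hook = GenericHttpHook(http_conn_id="hepdata_connection")
response = hook.call_api(
    endpoint="/search/ids",
    method="GET",
    # Example parameters mirroring the collect_ids task.
    params={"inspire_ids": True, "last_updated": "2024-11-27", "sort_by": "latest"},
)
ids = response.json()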