add explicit credentials to AzureMLAssetDataset #161

Draft: wants to merge 12 commits into base: develop
3 changes: 2 additions & 1 deletion kedro_azureml/datasets/asset_dataset.py
@@ -77,6 +77,7 @@ def __init__(
        self,
        azureml_dataset: str,
        dataset: Union[str, Type[AbstractDataset], Dict[str, Any]],
        credentials=None,
        root_dir: str = "data",
        filepath_arg: str = "filepath",
        azureml_type: AzureMLDataAssetType = "uri_folder",
@@ -107,7 +108,7 @@ def __init__(
        self._version_cache = Cache(maxsize=2)  # type: Cache
        self._download = True
        self._local_run = True
        self._azureml_config = None
        self._azureml_config = AzureMLConfig(**credentials) if credentials else None
        self._azureml_type = azureml_type
        if self._azureml_type not in get_args(AzureMLDataAssetType):
            raise DatasetError(
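To make the new constructor argument concrete, here is a minimal sketch of passing explicit credentials when instantiating the dataset directly in Python. The credential keys shown (subscription_id, resource_group, workspace_name) are an assumption about AzureMLConfig's schema, not something this diff confirms:

from kedro_azureml.datasets import AzureMLAssetDataset

# Hypothetical credentials mapping, e.g. as resolved from credentials.yml.
# The keys are assumed to match the fields of AzureMLConfig.
credentials = {
    "subscription_id": "00000000-0000-0000-0000-000000000000",
    "resource_group": "my-resource-group",
    "workspace_name": "my-workspace",
}

dataset = AzureMLAssetDataset(
    azureml_dataset="my_asset",
    dataset={"type": "pandas.CSVDataset", "filepath": "data.csv"},
    credentials=credentials,  # new in this PR; stored as AzureMLConfig(**credentials)
)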
49 changes: 26 additions & 23 deletions kedro_azureml/hooks.py
@@ -14,14 +14,16 @@ def after_context_created(self, context) -> None:
        context.config_loader.config_patterns.update(
            {"azureml": ["azureml*", "azureml*/**", "**/azureml*"]}
        )
        self.azure_config = AzureMLConfig(**context.config_loader["azureml"]["azure"])

    @hook_impl
    def after_catalog_created(self, catalog):
        for dataset_name, dataset in catalog._data_sets.items():
            if isinstance(dataset, AzureMLAssetDataset):
                dataset.azure_config = self.azure_config
                catalog.add(dataset_name, dataset, replace=True)
        azure_config = AzureMLConfig(**context.config_loader["azureml"]["azure"])

        azure_creds = {"azureml": azure_config.__dict__}

        context.config_loader["credentials"] = {
            **context.config_loader["credentials"],
            **azure_creds,
        }


    @hook_impl
    def before_pipeline_run(self, run_params, pipeline, catalog):
@@ -31,21 +33,22 @@ def before_pipeline_run(self, run_params, pipeline, catalog):
            pipeline: The ``Pipeline`` object representing the pipeline to be run.
            catalog: The ``DataCatalog`` from which to fetch data.
        """
        for dataset_name, dataset in catalog._data_sets.items():
            if isinstance(dataset, AzureMLAssetDataset):
                if AzurePipelinesRunner.__name__ not in run_params["runner"]:
                    # when running locally and using an AzureMLAssetDataset
                    # as an intermediate dataset, we don't want to download it,
                    # but we still set it to run locally with a local version.
                    if dataset_name not in pipeline.inputs():
                        dataset.as_local_intermediate()
                # when running remotely we still want to provide information
                # from the azureml config for getting the dataset version
                # during remote runs
                else:
                    dataset.as_remote()

                catalog.add(dataset_name, dataset, replace=True)

        for input in pipeline.all_inputs():
Author (inline review comment):
This is where the main difference happens. Instead of looping over the catalog, we loop over the pipeline's inputs and then check whether the catalog has each input. That way, dataset factories get a chance to be instantiated, and the resulting datasets are then handled as usual (call as_remote(), etc.).

            if input in catalog:
                dataset = catalog._get_dataset(input)
                if isinstance(dataset, AzureMLAssetDataset):
                    if AzurePipelinesRunner.__name__ not in run_params["runner"]:
                        # when running locally and using an AzureMLAssetDataset
                        # as an intermediate dataset, we don't want to download it,
                        # but we still set it to run locally with a local version.
                        if input not in pipeline.inputs():
                            dataset.as_local_intermediate()
                    # when running remotely we still want to provide information
                    # from the azureml config for getting the dataset version
                    # during remote runs
                    else:
                        dataset.as_remote()

                    catalog.add(input, dataset, replace=True)

azureml_local_run_hook = AzureMLLocalRunHook()
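To illustrate the dataset-factory case the inline comment describes: with a pattern-based catalog entry like the hypothetical one below, no concrete dataset exists in catalog._data_sets until an input name is resolved against the pattern. The old loop therefore never saw such datasets, while the new membership check followed by catalog._get_dataset(input) instantiates them on demand:

# Hypothetical dataset factory entry in catalog.yml (not part of this PR)
"{name}_azureml":
  type: kedro_azureml.datasets.AzureMLAssetDataset
  azureml_dataset: "{name}"
  root_dir: data/02_intermediate
  credentials: azureml
  dataset:
    type: pandas.CSVDataset
    filepath: "{name}.csv"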
1 change: 1 addition & 0 deletions tests/conf/e2e/catalog.yml
@@ -17,6 +17,7 @@ model_input_table:
  type: kedro_azureml.datasets.AzureMLAssetDataset
  azureml_dataset: e2e_tests_no_pdp
  root_dir: data/02_intermediate
  credentials: azureml
  dataset:
    type: pandas.CSVDataset
    filepath: model_input_table.csv
1 change: 1 addition & 0 deletions tests/conf/e2e_pipeline_data_passing/catalog.yml
@@ -23,6 +23,7 @@ model_input_table:
  type: kedro_azureml.datasets.AzureMLAssetDataset
  azureml_dataset: e2e_tests_pdp
  root_dir: data/02_intermediate
  credentials: azureml
  dataset:
    type: pandas.CSVDataset
    filepath: model_input_table.csv
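For reference, the credentials: azureml key in these catalog entries resolves against the credentials config. With the hook change above, that entry is injected automatically from the azure section of azureml.yml, but an equivalent explicit credentials.yml would look roughly like this (key names assumed to match AzureMLConfig, not confirmed by this diff):

# conf/local/credentials.yml (illustrative; the hook now injects this entry itself)
azureml:
  subscription_id: "00000000-0000-0000-0000-000000000000"
  resource_group: my-resource-group
  workspace_name: my-workspace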