diff --git a/openfl/interface/model.py b/openfl/interface/model.py index 3cadb25051..9852124c6d 100644 --- a/openfl/interface/model.py +++ b/openfl/interface/model.py @@ -11,7 +11,6 @@ from click import confirm, group, option, pass_context, style from openfl.federated import Plan -from openfl.pipelines import NoCompressionPipeline from openfl.protocols import utils from openfl.utilities.click_types import InputSpec from openfl.utilities.dataloading import get_dataloader @@ -168,13 +167,14 @@ def get_model( ) data_loader = get_dataloader(plan, prefer_minimal=True, input_shape=input_shape) task_runner = plan.get_task_runner(data_loader=data_loader) + tensor_pipe = plan.get_tensor_pipe() model_protobuf_path = Path(model_protobuf_path).resolve() logger.info("Loading OpenFL model protobuf: 🠆 %s", model_protobuf_path) model_protobuf = utils.load_proto(model_protobuf_path) - tensor_dict, _ = utils.deconstruct_model_proto(model_protobuf, NoCompressionPipeline()) + tensor_dict, _ = utils.deconstruct_model_proto(model_protobuf, tensor_pipe) # This may break for multiple models. # task_runner.set_tensor_dict will need to handle multiple models