diff --git a/openfl/federated/task/fl_model.py b/openfl/federated/task/fl_model.py index eb7365225f9..8029cc08196 100644 --- a/openfl/federated/task/fl_model.py +++ b/openfl/federated/task/fl_model.py @@ -77,7 +77,7 @@ def __init__(self, build_model, optimizer=None, loss_fn=None, **kwargs): self.runner.validate = lambda *args, **kwargs: build_model.validate( self.runner, *args, **kwargs ) - + if hasattr(self.model, "train_epoch"): self.runner.train_epoch = lambda *args, **kwargs: build_model.train_epoch( self.runner, *args, **kwargs