diff --git a/openfl/federated/data/loader_gandlf.py b/openfl/federated/data/loader_gandlf.py index 6e1a04342a..648ebe2930 100644 --- a/openfl/federated/data/loader_gandlf.py +++ b/openfl/federated/data/loader_gandlf.py @@ -25,7 +25,10 @@ def __init__(self, data_path, feature_shape): data_path (str): The path to the directory containing the data. feature_shape (tuple): The shape of an example feature array. """ - self.train_csv = data_path + "/train.csv" + if "inference" in data_path: + self.train_csv = None + else: + self.train_csv = data_path + "/train.csv" self.val_csv = data_path + "/valid.csv" self.train_dataloader = None self.val_dataloader = None