fix(openfl-workspaces): added type check for data path as shard num (int) before loading dataset

Signed-off-by: Pant, Akshay <[email protected]>
theakshaypant committed Aug 26, 2024
1 parent 70fe34a commit 3035e0e
Showing 10 changed files with 64 additions and 88 deletions.
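
Every changed file applies the same fix: the shard number passed through the data path flag is validated as an integer up front, and an invalid value now raises a ValueError instead of only being logged before the constructor silently returned with a half-initialized loader. A minimal sketch of the shared pattern (illustrative only; parse_shard_num is a hypothetical helper, and it catches TypeError/ValueError explicitly where the diffs below use a bare except):

    def parse_shard_num(data_path):
        """Validate that the data path flag holds an integer shard number."""
        try:
            return int(data_path)
        except (TypeError, ValueError):
            # Fail fast with a clear message before any dataset is downloaded or loaded.
            raise ValueError("Pass shard number using data path flag as an int.")

    # Example: parse_shard_num("1") -> 1; parse_shard_num("path/to/data") raises ValueError.

The Kvasir loaders additionally move load_kvasir_dataset() after this check, so the dataset is no longer downloaded when the shard number is invalid.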
17 changes: 7 additions & 10 deletions openfl-workspace/keras_cnn_mnist/src/tfmnist_inmemory.py
@@ -5,9 +5,6 @@
 
 from openfl.federated import TensorFlowDataLoader
 from .mnist_utils import load_mnist_shard
-from logging import getLogger
-
-logger = getLogger(__name__)
 
 
 class TensorFlowMNISTInMemory(TensorFlowDataLoader):
@@ -29,14 +26,14 @@ def __init__(self, data_path, batch_size, **kwargs):
         # what index/rank is this collaborator.
         # Then we have a way to automatically shard based on rank and size of
         # collaborator list.
-
         try:
-            _, num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard(
-                shard_num=int(data_path), **kwargs
-            )
-        except ValueError:
-            logger.error("Please pass the shard number (integer) for the collaborator using data path flag.")
-            return
+            int(data_path)
+        except:
+            raise ValueError("Pass shard number using data path flag as an int.")
+
+        _, num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard(
+            shard_num=int(data_path), **kwargs
+        )
 
         self.X_train = X_train
         self.y_train = y_train
Another changed file (path not shown):

@@ -5,9 +5,6 @@
 
 from openfl.federated import TensorFlowDataLoader
 from .mnist_utils import load_mnist_shard
-from logging import getLogger
-
-logger = getLogger(__name__)
 
 
 class TensorFlowMNISTInMemory(TensorFlowDataLoader):
@@ -29,15 +26,14 @@ def __init__(self, data_path, batch_size, **kwargs):
         # what index/rank is this collaborator.
         # Then we have a way to automatically shard based on rank and size of
         # collaborator list.
-
         try:
-            _, num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard(
-                shard_num=int(data_path), **kwargs
-            )
-        except ValueError:
-            logger.error("Please pass the shard number (integer) for the collaborator using data path flag.")
-            return
-
+            int(data_path)
+        except:
+            raise ValueError("Pass shard number using data path flag as an int.")
+
+        _, num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard(
+            shard_num=int(data_path), **kwargs
+        )
         self.X_train = X_train
         self.y_train = y_train
         self.X_valid = X_valid
12 changes: 6 additions & 6 deletions openfl-workspace/torch_cnn_histology/src/dataloader.py
@@ -38,13 +38,13 @@ def __init__(self, data_path, batch_size, **kwargs):
         super().__init__(batch_size, random_seed=0, **kwargs)
 
         try:
-            _, num_classes, X_train, y_train, X_valid, y_valid = load_histology_shard(
-                shard_num=int(data_path), **kwargs
-            )
-        except ValueError:
-            logger.error("Please pass the shard number (integer) for the collaborator using data path flag.")
-            return
+            int(data_path)
+        except:
+            raise ValueError("Pass shard number using data path flag as an int.")
 
+        _, num_classes, X_train, y_train, X_valid, y_valid = load_histology_shard(
+            shard_num=int(data_path), **kwargs
+        )
         self.X_train = X_train
         self.y_train = y_train
         self.X_valid = X_valid
Another changed file (path not shown):

@@ -5,9 +5,6 @@
 
 from openfl.federated import PyTorchDataLoader
 from .histology_utils import load_histology_shard
-from logging import getLogger
-
-logger = getLogger(__name__)
 
 
 class PyTorchHistologyInMemory(PyTorchDataLoader):
@@ -25,11 +22,13 @@ def __init__(self, data_path, batch_size, **kwargs):
         super().__init__(batch_size, random_seed=0, **kwargs)
 
         try:
-            _, num_classes, X_train, y_train, X_valid, y_valid = load_histology_shard(
-                shard_num=int(data_path), **kwargs)
-        except ValueError:
-            logger.error("Please pass the shard number (integer) for the collaborator using data path flag.")
-            return
+            int(data_path)
+        except:
+            raise ValueError("Pass shard number using data path flag as an int.")
+
+        _, num_classes, X_train, y_train, X_valid, y_valid = load_histology_shard(
+            shard_num=int(data_path), **kwargs
+        )
 
         self.X_train = X_train
         self.y_train = y_train
13 changes: 6 additions & 7 deletions openfl-workspace/torch_cnn_mnist/src/dataloader.py
@@ -27,14 +27,13 @@ def __init__(self, data_path, batch_size, **kwargs):
         super().__init__(batch_size, **kwargs)
 
         try:
-            num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard(
-                shard_num=int(data_path), **kwargs
-            )
-        except ValueError:
-            logger.error("Please pass the shard number (integer) for the collaborator using data path flag.")
-            return
-
+            int(data_path)
+        except:
+            raise ValueError("Pass shard number using data path flag as an int.")
 
+        num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard(
+            shard_num=int(data_path), **kwargs
+        )
         self.X_train = X_train
         self.y_train = y_train
         self.train_loader = self.get_train_loader()
Another changed file (path not shown):

@@ -5,9 +5,6 @@
 
 from openfl.federated import PyTorchDataLoader
 from .mnist_utils import load_mnist_shard
-from logging import getLogger
-
-logger = getLogger(__name__)
 
 
 class PyTorchMNISTInMemory(PyTorchDataLoader):
@@ -31,13 +28,13 @@ def __init__(self, data_path, batch_size, **kwargs):
         # of collaborator list.
 
         try:
-            num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard(
-                shard_num=int(data_path), **kwargs
-            )
-        except ValueError:
-            logger.error("Please pass the shard number (integer) for the collaborator using data path flag.")
-            return
-
+            int(data_path)
+        except:
+            raise ValueError("Pass shard number using data path flag as an int.")
+
+        num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard(
+            shard_num=int(data_path), **kwargs
+        )
         self.X_train = X_train
         self.y_train = y_train
         self.train_loader = self.get_train_loader()
Another changed file (path not shown):

@@ -5,9 +5,6 @@
 
 from openfl.federated import PyTorchDataLoader
 from .mnist_utils import load_mnist_shard
-from logging import getLogger
-
-logger = getLogger(__name__)
 
 
 class PyTorchMNISTInMemory(PyTorchDataLoader):
@@ -31,13 +28,13 @@ def __init__(self, data_path, batch_size, **kwargs):
         # of collaborator list.
 
         try:
-            num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard(
-                shard_num=int(data_path), **kwargs
-            )
-        except ValueError:
-            logger.error("Please pass the shard number (integer) for the collaborator using data path flag.")
-            return
-
+            int(data_path)
+        except:
+            raise ValueError("Pass shard number using data path flag as an int.")
+
+        num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard(
+            shard_num=int(data_path), **kwargs
+        )
         self.X_train = X_train
         self.y_train = y_train
         self.train_loader = self.get_train_loader()
Another changed file (path not shown):

@@ -5,9 +5,6 @@
 
 from openfl.federated import PyTorchDataLoader
 from .mnist_utils import load_mnist_shard
-from logging import getLogger
-
-logger = getLogger(__name__)
 
 
 class PyTorchMNISTInMemory(PyTorchDataLoader):
@@ -29,15 +26,14 @@ def __init__(self, data_path, batch_size, **kwargs):
         # what index/rank is this collaborator.
         # Then we have a way to automatically shard based on rank and size
         # of collaborator list.
-
         try:
-            num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard(
-                shard_num=int(data_path), **kwargs
-            )
-        except ValueError:
-            logger.error("Please pass the shard number (integer) for the collaborator using data path flag.")
-            return
-
+            int(data_path)
+        except:
+            raise ValueError("Pass shard number using data path flag as an int.")
+
+        num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard(
+            shard_num=int(data_path), **kwargs
+        )
         self.X_train = X_train
         self.y_train = y_train
         self.train_loader = self.get_train_loader()
17 changes: 7 additions & 10 deletions openfl-workspace/torch_unet_kvasir/src/data_loader.py
@@ -19,9 +19,6 @@
 
 from openfl.federated import PyTorchDataLoader
 from openfl.utilities import validate_file_hash
-from logging import getLogger
-
-logger = getLogger(__name__)
 
 
 def read_data(image_path, mask_path):
@@ -125,14 +122,14 @@ def __init__(self, data_path, batch_size, **kwargs):
         """
         super().__init__(batch_size, **kwargs)
 
-        load_kvasir_dataset()
         try:
-            self.valid_dataset = KvasirDataset(True, shard_num=int(data_path), **kwargs)
-            self.train_dataset = KvasirDataset(False, shard_num=int(data_path), **kwargs)
-        except ValueError:
-            logger.error("Please pass the shard number (integer) for the collaborator using data path flag.")
-            return
-
+            int(data_path)
+        except:
+            raise ValueError("Pass shard number using data path flag as an int.")
+
+        load_kvasir_dataset()
+        self.valid_dataset = KvasirDataset(True, shard_num=int(data_path), **kwargs)
+        self.train_dataset = KvasirDataset(False, shard_num=int(data_path), **kwargs)
         self.train_loader = self.get_train_loader()
         self.val_loader = self.get_valid_loader()
         self.batch_size = batch_size
Another changed file (path not shown):

@@ -19,9 +19,6 @@
 
 from openfl.federated import PyTorchDataLoader
 from openfl.utilities import validate_file_hash
-from logging import getLogger
-
-logger = getLogger(__name__)
 
 
 def read_data(image_path, mask_path):
@@ -125,13 +122,14 @@ def __init__(self, data_path, batch_size, **kwargs):
         """
         super().__init__(batch_size, **kwargs)
 
-        load_kvasir_dataset()
         try:
-            self.valid_dataset = KvasirDataset(True, shard_num=int(data_path), **kwargs)
-            self.train_dataset = KvasirDataset(False, shard_num=int(data_path), **kwargs)
-        except ValueError:
-            logger.error("Please pass the shard number (integer) for the collaborator using data path flag.")
-            return
+            int(data_path)
+        except:
+            raise ValueError("Pass shard number using data path flag as an int.")
+
+        load_kvasir_dataset()
+        self.valid_dataset = KvasirDataset(True, shard_num=int(data_path), **kwargs)
+        self.train_dataset = KvasirDataset(False, shard_num=int(data_path), **kwargs)
 
         self.train_loader = self.get_train_loader()
         self.val_loader = self.get_valid_loader()
