Merge pull request #1037 from theakshaypant/akshay/fix-read_shard_from_default_path

Modifying the `shard_num` read in the workspace templates to indicate that the `data_path` flag needs to be passed as an integer.
rahulga1 authored Nov 5, 2024
2 parents 648500d + aae8185 commit 1fad044
Showing 10 changed files with 100 additions and 11 deletions.
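Each of the changed data loaders applies the same guard before using `data_path` as a shard index. As a standalone sketch of the pattern (the helper name `validate_shard_num` is illustrative, not part of the repository):

def validate_shard_num(data_path):
    """Return `data_path` as an int, or raise a descriptive ValueError.

    In these templates, `data_path` is the collaborator's shard number,
    not a filesystem path.
    """
    try:
        return int(data_path)
    except (TypeError, ValueError) as exc:
        raise ValueError(
            f"Expected `{data_path}` to be representable as `int`, as it refers "
            "to the data shard number used by the collaborator."
        ) from exc

print(validate_shard_num("1"))  # -> 1
# validate_shard_num("data/one")  # raises ValueError with a readable message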
8 changes: 8 additions & 0 deletions openfl-workspace/keras_cnn_mnist/src/tfmnist_inmemory.py
@@ -26,6 +26,14 @@ def __init__(self, data_path, batch_size, **kwargs):
# what index/rank is this collaborator.
# Then we have a way to automatically shard based on rank and size of
# collaborator list.
# `data_path` is expected to be this collaborator's shard number.
try:
    int(data_path)
except (TypeError, ValueError) as exc:
    raise ValueError(
        f"Expected `{data_path}` to be representable as `int`, as it refers "
        "to the data shard number used by the collaborator."
    ) from exc

_, num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard(
shard_num=int(data_path), **kwargs
@@ -26,11 +26,18 @@ def __init__(self, data_path, batch_size, **kwargs):
# what index/rank is this collaborator.
# Then we have a way to automatically shard based on rank and size of
# collaborator list.
# `data_path` is expected to be this collaborator's shard number.
try:
    int(data_path)
except (TypeError, ValueError) as exc:
    raise ValueError(
        f"Expected `{data_path}` to be representable as `int`, as it refers "
        "to the data shard number used by the collaborator."
    ) from exc

_, num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard(
shard_num=int(data_path), **kwargs
)

self.X_train = X_train
self.y_train = y_train
self.X_valid = X_valid
13 changes: 11 additions & 2 deletions openfl-workspace/torch_cnn_histology/src/dataloader.py
@@ -37,9 +37,18 @@ def __init__(self, data_path, batch_size, **kwargs):
"""
super().__init__(batch_size, random_seed=0, **kwargs)

# `data_path` is expected to be this collaborator's shard number.
try:
    int(data_path)
except (TypeError, ValueError) as exc:
    raise ValueError(
        f"Expected `{data_path}` to be representable as `int`, as it refers "
        "to the data shard number used by the collaborator."
    ) from exc

_, num_classes, X_train, y_train, X_valid, y_valid = load_histology_shard(
    shard_num=int(data_path), **kwargs
)
self.X_train = X_train
self.y_train = y_train
self.X_valid = X_valid
@@ -21,8 +21,18 @@ def __init__(self, data_path, batch_size, **kwargs):
"""
super().__init__(batch_size, random_seed=0, **kwargs)

# `data_path` is expected to be this collaborator's shard number.
try:
    int(data_path)
except (TypeError, ValueError) as exc:
    raise ValueError(
        f"Expected `{data_path}` to be representable as `int`, as it refers "
        "to the data shard number used by the collaborator."
    ) from exc

_, num_classes, X_train, y_train, X_valid, y_valid = load_histology_shard(
    shard_num=int(data_path), **kwargs
)

self.X_train = X_train
self.y_train = y_train
10 changes: 9 additions & 1 deletion openfl-workspace/torch_cnn_mnist/src/dataloader.py
@@ -26,10 +26,18 @@ def __init__(self, data_path, batch_size, **kwargs):
"""
super().__init__(batch_size, **kwargs)

# `data_path` is expected to be this collaborator's shard number.
try:
    int(data_path)
except (TypeError, ValueError) as exc:
    raise ValueError(
        f"Expected `{data_path}` to be representable as `int`, as it refers "
        "to the data shard number used by the collaborator."
    ) from exc

num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard(
shard_num=int(data_path), **kwargs
)

self.X_train = X_train
self.y_train = y_train
self.train_loader = self.get_train_loader()
@@ -27,9 +27,18 @@ def __init__(self, data_path, batch_size, **kwargs):
# Then we have a way to automatically shard based on rank and size
# of collaborator list.

# `data_path` is expected to be this collaborator's shard number.
try:
    int(data_path)
except (TypeError, ValueError) as exc:
    raise ValueError(
        f"Expected `{data_path}` to be representable as `int`, as it refers "
        "to the data shard number used by the collaborator."
    ) from exc

num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard(
    shard_num=int(data_path), **kwargs
)
self.X_train = X_train
self.y_train = y_train
self.train_loader = self.get_train_loader()
13 changes: 11 additions & 2 deletions openfl-workspace/torch_cnn_mnist_fed_eval/src/ptmnist_inmemory.py
@@ -27,9 +27,18 @@ def __init__(self, data_path, batch_size, **kwargs):
# Then we have a way to automatically shard based on rank and size
# of collaborator list.

# `data_path` is expected to be this collaborator's shard number.
try:
    int(data_path)
except (TypeError, ValueError) as exc:
    raise ValueError(
        f"Expected `{data_path}` to be representable as `int`, as it refers "
        "to the data shard number used by the collaborator."
    ) from exc

num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard(
    shard_num=int(data_path), **kwargs
)
self.X_train = X_train
self.y_train = y_train
self.train_loader = self.get_train_loader()
@@ -26,10 +26,18 @@ def __init__(self, data_path, batch_size, **kwargs):
# what index/rank is this collaborator.
# Then we have a way to automatically shard based on rank and size
# of collaborator list.
# `data_path` is expected to be this collaborator's shard number.
try:
    int(data_path)
except (TypeError, ValueError) as exc:
    raise ValueError(
        f"Expected `{data_path}` to be representable as `int`, as it refers "
        "to the data shard number used by the collaborator."
    ) from exc

num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard(
    shard_num=int(data_path), **kwargs
)
self.X_train = X_train
self.y_train = y_train
self.train_loader = self.get_train_loader()
10 changes: 10 additions & 0 deletions openfl-workspace/torch_unet_kvasir/src/data_loader.py
@@ -7,6 +7,7 @@
from os import listdir
from pathlib import Path


import numpy as np
import PIL
from skimage import io
@@ -121,6 +122,15 @@ def __init__(self, data_path, batch_size, **kwargs):
"""
super().__init__(batch_size, **kwargs)

# `data_path` is expected to be this collaborator's shard number.
try:
    int(data_path)
except (TypeError, ValueError) as exc:
    raise ValueError(
        f"Expected `{data_path}` to be representable as `int`, as it refers "
        "to the data shard number used by the collaborator."
    ) from exc

load_kvasir_dataset()
self.valid_dataset = KvasirDataset(True, shard_num=int(data_path), **kwargs)
self.train_dataset = KvasirDataset(False, shard_num=int(data_path), **kwargs)
@@ -7,6 +7,7 @@
from os import listdir
from pathlib import Path


import numpy as np
import PIL
from skimage import io
@@ -121,9 +122,19 @@ def __init__(self, data_path, batch_size, **kwargs):
"""
super().__init__(batch_size, **kwargs)

# `data_path` is expected to be this collaborator's shard number.
try:
    int(data_path)
except (TypeError, ValueError) as exc:
    raise ValueError(
        f"Expected `{data_path}` to be representable as `int`, as it refers "
        "to the data shard number used by the collaborator."
    ) from exc

load_kvasir_dataset()
self.valid_dataset = KvasirDataset(True, shard_num=int(data_path), **kwargs)
self.train_dataset = KvasirDataset(False, shard_num=int(data_path), **kwargs)

self.train_loader = self.get_train_loader()
self.val_loader = self.get_valid_loader()
self.batch_size = batch_size
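Before this change, a non-integer `data_path` surfaced only as Python's generic parsing error; the templates now raise a message that names the expectation. A minimal comparison of the two behaviors, runnable without any OpenFL imports:

data_path = "data/one"  # e.g. a collaborator mistakenly passing a path

# Old behavior: the bare int() call fails with a generic message.
try:
    int(data_path)
except ValueError as exc:
    print(exc)  # invalid literal for int() with base 10: 'data/one'

# New behavior: the templates re-raise with a descriptive message instead.
try:
    int(data_path)
except (TypeError, ValueError):
    print(
        f"Expected `{data_path}` to be representable as `int`, as it refers "
        "to the data shard number used by the collaborator."
    )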