Skip to content

Commit

Permalink
Merge branch 'develop' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
payalcha authored Dec 6, 2024
2 parents ea7c509 + 2c7bd71 commit cdce7bb
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 16 deletions.
1 change: 0 additions & 1 deletion openfl-workspace/tf_3dunet_brats/plan/cols.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,3 @@
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

collaborators:
- one
10 changes: 0 additions & 10 deletions openfl-workspace/tf_3dunet_brats/plan/plan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,6 @@ data_loader:
template: src.tf_brats_dataloader.TensorFlowBratsDataLoader
network:
defaults: plan/defaults/network.yaml
settings:
agg_addr: DESKTOP-AOKV1IJ.localdomain
agg_port: auto
cert_folder: cert
client_reconnect_interval: 5
disable_client_auth: false
disable_tls: false
hash_salt: auto
template: openfl.federation.Network
task_runner:
defaults: plan/defaults/task_runner.yaml
settings:
Expand Down Expand Up @@ -80,4 +71,3 @@ tasks:
epochs: 1
metrics:
- loss
num_batches: 1
3 changes: 2 additions & 1 deletion openfl-workspace/tf_3dunet_brats/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
keras==2.13.1
nibabel
numpy

setuptools>=65.5.1 # not directly required, pinned by Snyk to avoid a vulnerability
tensorflow>=2
tensorflow==2.13.0
13 changes: 11 additions & 2 deletions openfl-workspace/tf_3dunet_brats/src/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,24 @@ def create_file_list(self):
Split into training and testing sets.
"""
searchpath = os.path.join(self.data_path, '*/*_seg.nii.gz')
extension = '_seg.nii.gz'
flair_extension = '_flair.nii.gz'
searchpath = os.path.join(self.data_path, "*/*" + extension)
filenames = tf.io.gfile.glob(searchpath)

# check for uncompressed files
if not filenames:
extension = '_seg.nii'
flair_extension = '_flair.nii'
searchpath = os.path.join(self.data_path, "*/*" + extension)
filenames = tf.io.gfile.glob(searchpath)

# Create a dictionary of tuples with image filename and label filename

self.num_files = len(filenames)
self.filenames = {}
for idx, filename in enumerate(filenames):
self.filenames[idx] = [filename.replace('_seg.nii.gz', '_flair.nii.gz'), filename]
self.filenames[idx] = [filename.replace(extension, flair_extension), filename]

def z_normalize_img(self, img):
"""
Expand Down
4 changes: 2 additions & 2 deletions openfl-workspace/tf_3dunet_brats/src/tf_3dunet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def create_model(self,
initial_filters=initial_filters,
batch_norm=batch_norm)

self.optimizer = tf.keras.optimizers.Adam()
self.optimizer = tf.keras.optimizers.legacy.Adam()

model.compile(
loss=dice_loss,
Expand Down Expand Up @@ -193,7 +193,7 @@ def create_model(self,
)

model.compile(loss=dice_loss,
optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=0.01),
metrics=[dice_coef, soft_dice_coef]
)

Expand Down

0 comments on commit cdce7bb

Please sign in to comment.