Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updating files for security issues #20

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions BrainMaGe/tester/test_ma.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def postprocess_prediction(seg):


def infer_ma(cfg, device, save_brain, weights):

cfg = os.path.abspath(cfg)

if os.path.isfile(cfg):
Expand Down Expand Up @@ -152,7 +151,7 @@ def infer_ma(cfg, device, save_brain, weights):
int(params["num_classes"]),
int(params["base_filters"]),
)
checkpoint = torch.load(str(params["weights"]),map_location=torch.device(device))
checkpoint = torch.load(str(params["weights"]), map_location=torch.device(device))
model.load_state_dict(checkpoint["model_state_dict"])

if device != "cpu":
Expand Down
167 changes: 167 additions & 0 deletions BrainMaGe/tester/test_ma_new.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Mar 25 19:18:35 2020

@author: siddhesh
"""

from __future__ import print_function, division
import os
import sys
import time
import pandas as pd
import torch
import nibabel as nib
import tqdm
import numpy as np
from skimage.transform import resize
from skimage.measure import label
from scipy.ndimage.morphology import binary_fill_holes
from BrainMaGe.models.networks import fetch_model
from BrainMaGe.utils import csv_creator_adv
from BrainMaGe.utils.utils_test import (
pad_image,
process_image,
interpolate_image,
padder_and_cropper,
)


def postprocess_prediction(seg):
mask = seg != 0
lbls = label(mask, connectivity=3)
lbls_sizes = [np.sum(lbls == i) for i in np.unique(lbls)]
largest_region = np.argmax(lbls_sizes[1:]) + 1
seg[lbls != largest_region] = 0
return seg


def infer_ma(cfg, device, save_brain, weights):
cfg = os.path.abspath(cfg)

if os.path.isfile(cfg):
params_df = pd.read_csv(
cfg,
sep=" = ",
names=["param_name", "param_value"],
comment="#",
skip_blank_lines=True,
engine="python",
).fillna(" ")
else:
print("Missing test_params.cfg file? Please give one!")
sys.exit(0)
params = {}
for i in range(params_df.shape[0]):
params[params_df.iloc[i, 0]] = params_df.iloc[i, 1]
params["weights"] = weights
start = time.asctime()
startstamp = time.time()
print("\nHostname :" + str(os.getenv("HOSTNAME")))
print("\nStart Time :" + str(start))
print("\nStart Stamp:" + str(startstamp))
sys.stdout.flush()

print("Generating Test csv")
if not os.path.exists(os.path.join(params["results_dir"])):
os.mkdir(params["results_dir"])
if not params["csv_provided"] == "True":
print("Since CSV were not provided, we are gonna create for you")
csv_creator_adv.generate_csv(
params["test_dir"],
to_save=params["results_dir"],
mode=params["mode"],
ftype="test",
modalities=params["modalities"],
)
test_csv = os.path.join(params["results_dir"], "test.csv")
else:
test_csv = params["test_csv"]

model = fetch_model(
params["model"],
int(params["num_modalities"]),
int(params["num_classes"]),
int(params["base_filters"]),
)
checkpoint = torch.load(str(params["weights"]), map_location=torch.device(device))
model.load_state_dict(checkpoint["model_state_dict"])

if device != "cpu":
model.cuda()
model.eval()

test_df = pd.read_csv(test_csv)
test_df.ID = test_df.ID.astype(str)
temp_dir = os.path.join(params["results_dir"], "Temp")
os.makedirs(temp_dir, exist_ok=True)

print("Resampling the images to isotropic resolution of 1mm x 1mm x 1mm")
print("Also Converting the images to RAI and brats for smarter use.")

for index, patient in tqdm.tqdm(test_df.iterrows()):
os.makedirs(os.path.join(temp_dir, patient["ID"]), exist_ok=True)
patient_path = patient["Image_path"]

patient_nib = nib.load(patient_path)

image_data = patient_nib.get_fdata()
image = process_image(image_data)
image = resize(
image, (128, 128, 128), order=3, mode="edge", cval=0, anti_aliasing=False
)
image = image[np.newaxis, np.newaxis, ...]
image = torch.FloatTensor(image)

if device != "cpu":
image = image.cuda()

with torch.no_grad():
output = model(image)
output = output.cpu().numpy()[0][0]
to_save = interpolate_image(output, patient_nib.shape)
to_save[to_save >= 0.9] = 1
to_save[to_save < 0.9] = 0
for i in range(to_save.shape[2]):
if np.any(to_save[:, :, i]):
to_save[:, :, i] = binary_fill_holes(to_save[:, :, i])
to_save = postprocess_prediction(to_save).astype(np.uint8)
to_save_nib = nib.Nifti1Image(to_save, patient_nib.affine)

os.makedirs(
os.path.join(params["results_dir"], patient["ID"]), exist_ok=True
)

output_path = os.path.join(
params["results_dir"], patient["ID"], patient["ID"] + "_mask.nii.gz"
)

nib.save(to_save_nib, output_path)

if save_brain:
image = nib.load(patient["Image_path"])
image_data = image.get_fdata()
mask = nib.load(
os.path.join(
params["results_dir"],
patient["ID"],
patient["ID"] + "_mask.nii.gz",
)
)
mask_data = mask.get_fdata().astype(np.int8)
image_data[mask_data == 0] = 0
to_save_brain = nib.Nifti1Image(image_data, image.affine)
nib.save(
to_save_brain,
os.path.join(
params["results_dir"],
patient["ID"],
patient["ID"] + "_brain.nii.gz",
),
)

print("*" * 60)
print("Final output stored in : %s" % (params["results_dir"]))
print("Thank you for using BrainMaGe")
print("*" * 60)
8 changes: 5 additions & 3 deletions BrainMaGe/tester/test_single_multi_4.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ def postprocess_prediction(seg):
return seg


def infer_single_multi_4(input_paths, output_path, weights, mask_path=None, device="cpu"):
def infer_single_multi_4(
input_paths, output_path, weights, mask_path=None, device="cpu"
):
"""
Inference using multi modality network

Expand Down Expand Up @@ -69,7 +71,7 @@ def infer_single_multi_4(input_paths, output_path, weights, mask_path=None, devi
num_filters=16,
)

checkpoint = torch.load(str(weights), map_location=torch.device('cpu'))
checkpoint = torch.load(str(weights), map_location=torch.device("cpu"))
model.load_state_dict(checkpoint["model_state_dict"])

if device != "cpu":
Expand Down Expand Up @@ -105,7 +107,7 @@ def infer_single_multi_4(input_paths, output_path, weights, mask_path=None, devi
print("Done with running the model.")

if mask_path is not None:
raise NotImplementedError('Sorry, masking is not implemented (yet).')
raise NotImplementedError("Sorry, masking is not implemented (yet).")

print("Final output stored in : %s" % (output_path))
print("Thank you for using BrainMaGe")
Expand Down
32 changes: 2 additions & 30 deletions BrainMaGe/trainer/lightning_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@
@author: siddhesh
"""

import torch
from torch.utils.data import DataLoader
import pytorch_lightning as ptl
import torch
from BrainMaGe.models.networks import fetch_model
from BrainMaGe.utils.cyclicLR import CyclicCosAnnealingLR
from BrainMaGe.utils.losses import dice_loss, dice
from BrainMaGe.utils.data import SkullStripDataset
from BrainMaGe.utils.losses import dice, dice_loss
from BrainMaGe.utils.optimizers import fetch_optimizer


Expand Down Expand Up @@ -79,29 +77,3 @@ def configure_optimizers(self):
eta_min=5e-6,
)
return [optimizer], [scheduler]

@ptl.data_loader
def train_dataloader(self):
dataset_train = SkullStripDataset(
self.params["train_csv"], self.params, test=False
)
return DataLoader(
dataset_train,
batch_size=int(self.params["batch_size"]),
shuffle=True,
num_workers=4,
pin_memory=True,
)

@ptl.data_loader
def val_dataloader(self):
dataset_valid = SkullStripDataset(
self.params["validation_csv"], self.params, test=False
)
return DataLoader(
dataset_valid,
batch_size=int(self.params["batch_size"]),
shuffle=False,
num_workers=4,
pin_memory=True,
)
38 changes: 30 additions & 8 deletions BrainMaGe/trainer/trainer_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,17 @@


import os
import time
import sys
import torch
import time

import pandas as pd
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from BrainMaGe.utils.csv_creator_adv import generate_csv
import torch
from BrainMaGe.trainer.lightning_networks import SkullStripper
from BrainMaGe.utils.csv_creator_adv import generate_csv
from BrainMaGe.utils.data import SkullStripDataset
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch.utils.data import DataLoader


def train_network(cfg, device, weights):
Expand Down Expand Up @@ -99,8 +102,8 @@ def train_network(cfg, device, weights):
print("Using device:", device)
if device.type == "cuda":
print("Memory Usage:")
print("Allocated:", round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1), "GB")
print("Cached: ", round(torch.cuda.memory_cached(0) / 1024 ** 3, 1), "GB")
print("Allocated:", round(torch.cuda.memory_allocated(0) / 1024**3, 1), "GB")
print("Cached: ", round(torch.cuda.memory_cached(0) / 1024**3, 1), "GB")
sys.stdout.flush()

# We generate CSV for training if not provided
Expand Down Expand Up @@ -163,4 +166,23 @@ def train_network(cfg, device, weights):
num_sanity_val_steps=5,
resume_from_checkpoint=res_ckpt,
)
trainer.fit(model)

dataset_train = SkullStripDataset(params["train_csv"], params, test=False)
train_dataloader = DataLoader(
dataset_train,
batch_size=int(params["batch_size"]),
shuffle=True,
num_workers=4,
pin_memory=True,
)

dataset_valid = SkullStripDataset(params["validation_csv"], params, test=False)
val_dataloader = DataLoader(
dataset_valid,
batch_size=int(params["batch_size"]),
shuffle=False,
num_workers=4,
pin_memory=True,
)

trainer.fit(model, train_dataloader, val_dataloader)
1 change: 0 additions & 1 deletion BrainMaGe/utils/convert_ckpt_to_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import torch

if __name__ == "__main__":

parser = ArgumentParser(description="Convert the .ckpt files to .pt files")
parser.add_argument(
"-i",
Expand Down
2 changes: 1 addition & 1 deletion brain_mage_run
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ if __name__ == "__main__":
trainer_main.train_network(params_file, DEVICE, weights)
elif args.test == "True":
if args.mode.lower() == "ma" or args.mode.lower() == "bids":
test_ma.infer_ma(params_file, DEVICE, args.save_brain, weights)
test_ma_new.infer_ma(params_file, DEVICE, args.save_brain, weights)
elif args.mode.lower() == "multi-4":
test_multi_4.infer_multi_4(params_file, DEVICE, args.save_brain, weights)
else:
Expand Down
4 changes: 2 additions & 2 deletions brain_mage_single_run
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ if __name__ == "__main__":
if os.path.isdir(base_dir):
weights = os.path.join(base_dir, "resunet_ma.pt")
else:
# this control path is needed if someone installs brainmage into their virtual environment directly
# this control path is needed if someone installs brainmage into their virtual environment directly
base_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
base_dir = os.path.join(os.path.dirname(base_dir), "BrainMaGe/weights")
if os.path.isdir(base_dir):
Expand All @@ -108,7 +108,7 @@ if __name__ == "__main__":
if os.path.isfile(weights):
print("Weight file used :", weights)
else:
sys.exit('Weights file at \'' + weights + '\' was not found...')
sys.exit("Weights file at '" + weights + "' was not found...")

# Running Inference
test_single_run.infer_single_ma(
Expand Down
6 changes: 3 additions & 3 deletions brain_mage_single_run_multi_4
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ if __name__ == "__main__":

args = parser.parse_args()

input_paths = args.input_paths.split(',')
input_paths = args.input_paths.split(",")
output_path = args.output_path
mask_path = args.mask_path
DEVICE = args.device
Expand All @@ -98,7 +98,7 @@ if __name__ == "__main__":
if os.path.isdir(base_dir):
weights = os.path.join(base_dir, "resunet_multi_4.pt")
else:
# this control path is needed if someone installs brainmage into their virtual environment directly
# this control path is needed if someone installs brainmage into their virtual environment directly
base_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
base_dir = os.path.join(os.path.dirname(base_dir), "BrainMaGe/weights")
if os.path.isdir(base_dir):
Expand All @@ -107,7 +107,7 @@ if __name__ == "__main__":
if os.path.isfile(weights):
print("Weight file used :", weights)
else:
sys.exit('Weights file at \'' + weights + '\' was not found...')
sys.exit("Weights file at '" + weights + "' was not found...")

# Running Inference
test_single_multi_4.infer_single_multi_4(
Expand Down
Loading