Skip to content

Commit

Permalink
Merge branch '00023' of github.com:RichardObi/medigan into 00023
Browse files Browse the repository at this point in the history
  • Loading branch information
RichardObi committed Nov 29, 2023
2 parents e0cb2d2 + fe48dc8 commit 9427fd5
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
8 changes: 4 additions & 4 deletions tests/model_integration_test_manual.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import logging

MODEL_ID = "YOUR_MODEL_ID_HERE"
MODEL_ID = 23 #"00023_PIX2PIXHD_BREAST_DCEMRI" #"00002_DCGAN_MMG_MASS_ROI" # "00007_BEZIERCURVE_TUMOUR_MASK"
MODEL_ID = 23 # "00023_PIX2PIXHD_BREAST_DCEMRI" #"00002_DCGAN_MMG_MASS_ROI" # "00007_BEZIERCURVE_TUMOUR_MASK"
NUM_SAMPLES = 2
OUTPUT_PATH = f"output/{MODEL_ID}/"
try:
Expand All @@ -22,10 +22,9 @@
num_samples=NUM_SAMPLES,
output_path=OUTPUT_PATH,
input_path="input/",
gpu_id= 0,
gpu_id=0,
image_size=448,
install_dependencies=True,

)

data_loader = generators.get_as_torch_dataloader(
Expand All @@ -42,7 +41,8 @@

if len(data_loader) != NUM_SAMPLES:
logging.warning(
f"{MODEL_ID}: The number of samples in the dataloader (={len(data_loader)}) is not equal the number of samples requested (={NUM_SAMPLES}).")
f"{MODEL_ID}: The number of samples in the dataloader (={len(data_loader)}) is not equal the number of samples requested (={NUM_SAMPLES})."
)

#### Get the object at index 0 from the dataloader
data_dict = next(iter(data_loader))
Expand Down
6 changes: 4 additions & 2 deletions tests/test_model_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,10 @@ def test_get_dataloader_method(self, model_id):
self.logger.debug(f"{model_id}: len(data_loader): {len(data_loader)}")

if len(data_loader) != self.num_samples:
logging.warning(f"{model_id}: The number of samples in the dataloader (={len(data_loader)}) is not equal the number of samples requested (={self.num_samples}). "
f"Hint: Revise if the model's internal generate() function returned tuples as required in get_as_torch_dataloader().")
logging.warning(
f"{model_id}: The number of samples in the dataloader (={len(data_loader)}) is not equal the number of samples requested (={self.num_samples}). "
f"Hint: Revise if the model's internal generate() function returned tuples as required in get_as_torch_dataloader()."
)

#### Get the object at index 0 from the dataloader
data_dict = next(iter(data_loader))
Expand Down

0 comments on commit 9427fd5

Please sign in to comment.