diff --git a/tests/model_integration_test_manual.py b/tests/model_integration_test_manual.py index ed5b931..8decfa9 100644 --- a/tests/model_integration_test_manual.py +++ b/tests/model_integration_test_manual.py @@ -6,7 +6,7 @@ import logging MODEL_ID = "YOUR_MODEL_ID_HERE" -MODEL_ID = 23 #"00023_PIX2PIXHD_BREAST_DCEMRI" #"00002_DCGAN_MMG_MASS_ROI" # "00007_BEZIERCURVE_TUMOUR_MASK" +MODEL_ID = 23 # "00023_PIX2PIXHD_BREAST_DCEMRI" #"00002_DCGAN_MMG_MASS_ROI" # "00007_BEZIERCURVE_TUMOUR_MASK" NUM_SAMPLES = 2 OUTPUT_PATH = f"output/{MODEL_ID}/" try: @@ -22,10 +22,9 @@ num_samples=NUM_SAMPLES, output_path=OUTPUT_PATH, input_path="input/", - gpu_id= 0, + gpu_id=0, image_size=448, install_dependencies=True, - ) data_loader = generators.get_as_torch_dataloader( @@ -42,7 +41,8 @@ if len(data_loader) != NUM_SAMPLES: logging.warning( - f"{MODEL_ID}: The number of samples in the dataloader (={len(data_loader)}) is not equal the number of samples requested (={NUM_SAMPLES}).") + f"{MODEL_ID}: The number of samples in the dataloader (={len(data_loader)}) is not equal the number of samples requested (={NUM_SAMPLES})." + ) #### Get the object at index 0 from the dataloader data_dict = next(iter(data_loader)) diff --git a/tests/test_model_executor.py b/tests/test_model_executor.py index d3327ed..82a23a8 100644 --- a/tests/test_model_executor.py +++ b/tests/test_model_executor.py @@ -264,8 +264,10 @@ def test_get_dataloader_method(self, model_id): self.logger.debug(f"{model_id}: len(data_loader): {len(data_loader)}") if len(data_loader) != self.num_samples: - logging.warning(f"{model_id}: The number of samples in the dataloader (={len(data_loader)}) is not equal the number of samples requested (={self.num_samples}). " - f"Hint: Revise if the model's internal generate() function returned tuples as required in get_as_torch_dataloader().") + logging.warning( + f"{model_id}: The number of samples in the dataloader (={len(data_loader)}) is not equal the number of samples requested (={self.num_samples}). " + f"Hint: Revise if the model's internal generate() function returned tuples as required in get_as_torch_dataloader()." + ) #### Get the object at index 0 from the dataloader data_dict = next(iter(data_loader))