diff --git a/fmcib/run.py b/fmcib/run.py index 3da4711..558462e 100644 --- a/fmcib/run.py +++ b/fmcib/run.py @@ -9,12 +9,7 @@ from .preprocessing import get_dataloader -def get_features( - csv_path, - weights_path=None, - spatial_size=(50, 50, 50), - precropped=False, -): +def get_features(csv_path, weights_path=None, spatial_size=(50, 50, 50), precropped=False, **kwargs): """ Extracts features from images specified in a CSV file. @@ -23,13 +18,13 @@ def get_features( weights_path (str, optional): Path to the pre-trained weights file. Default is None. spatial_size (tuple, optional): Spatial size of the input images. Default is (50, 50, 50). precropped (bool, optional): Whether the images are already pre-cropped. Default is False. - + **kwargs: Additional arguments to be passed to the dataloader. Returns: pandas.DataFrame: DataFrame containing the original data from the CSV file along with the extracted features. """ logger.info("Loading CSV file ...") df = pd.read_csv(csv_path) - dataloader = get_dataloader(csv_path, spatial_size=spatial_size, precropped=precropped) + dataloader = get_dataloader(csv_path, spatial_size=spatial_size, precropped=precropped, **kwargs) device = "cuda" if torch.cuda.is_available() else "cpu" if weights_path is None: