Skip to content

Commit

Permalink
Add CLI to specify where to store temporary memmaps
Browse files Browse the repository at this point in the history
  • Loading branch information
miquelmassot committed Nov 22, 2023
1 parent a04ec63 commit e992c19
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 14 deletions.
20 changes: 19 additions & 1 deletion src/correct_images/correct_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,19 @@ def main(args=None):
default="",
help="Expected suffix for correct_images configuration and output folders.",
)
subparser_parse.add_argument(
"--memmap-location",
dest="memmap_location",
default=".",
help="Location for memmap files. Defaults to current directory.",
)
subparser_parse.add_argument(
"--memmap-size-gb",
dest="memmap_size_gb",
default=50,
type=int,
help="Size of individual memmap files in GB. Defaults to 50.",
)
subparser_parse.set_defaults(func=call_parse)

# subparser process
Expand Down Expand Up @@ -244,7 +257,12 @@ def call_parse(args):
"parse", args.force, args.suffix, camera, correct_config=None
)
# call new list-compatible implementation of parse()
corrector.parse(path_list, correct_config_list)
corrector.parse(
path_list,
correct_config_list,
args.memmap_location,
args.memmap_size_gb,
)
corrector.cleanup()

Console.info(
Expand Down
18 changes: 13 additions & 5 deletions src/correct_images/corrector.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,11 @@ def load_configuration(self, correct_config=None):
else:
self.loader.set_loader("default")

def parse(self, path_list, correct_config_list):
def parse(
self, path_list, correct_config_list, memmap_location=".", memmap_size_gb=50.0
):
self.memmap_location = memmap_location
self.memmap_size_gb = memmap_size_gb
# both path_list and correct_config_list are assumed to be valid + equivalent
for i in range(len(path_list)): # for each dive
path = path_list[i]
Expand Down Expand Up @@ -669,7 +673,9 @@ def get_altitude_and_depth_maps(self):
return
elif self.distance_metric == "depth_map":
# Load depth maps
path_depth = self.path_processed/"3d_reconstruction/depth_maps/gp_smoothed"
path_depth = (
self.path_processed / "3d_reconstruction/depth_maps/gp_smoothed"
)
if not path_depth.exists():
Console.quit("Depth maps folder", path_depth, "does not exist.")
images_to_drop = []
Expand Down Expand Up @@ -747,8 +753,7 @@ def generate_attenuation_correction_parameters(self):
* 4.0
/ (1024.0**3)
)
max_bin_size_gb = 50.0
max_bin_size = int(max_bin_size_gb / image_size_gb)
max_bin_size = int(self.memmap_size_gb / image_size_gb)

self.bin_band = 0.1
hist_bins = np.arange(
Expand Down Expand Up @@ -831,7 +836,7 @@ def generate_attenuation_correction_parameters(self):
images_map,
distances_map,
max_bin_size,
max_bin_size_gb,
self.memmap_size_gb,
distance_vector,
)
for idx_bin in range(hist_bins.size - 1)
Expand Down Expand Up @@ -1143,6 +1148,7 @@ def compute_distance_bin(
memmap_filename, memmap_handle = create_memmap(
bin_images,
dimensions,
self.memmap_location,
loader=self.loader,
)
self.memmaps_to_remove.append(memmap_filename)
Expand All @@ -1153,6 +1159,7 @@ def compute_distance_bin(
memmap_filename, memmap_handle = create_memmap(
bin_images,
dimensions,
self.memmap_location,
loader=self.loader,
)
self.memmaps_to_remove.append(memmap_filename)
Expand All @@ -1163,6 +1170,7 @@ def compute_distance_bin(
memmap_filename, memmap_handle = create_memmap(
bin_images,
dimensions,
self.memmap_location,
loader=self.loader,
)
self.memmaps_to_remove.append(memmap_filename)
Expand Down
21 changes: 13 additions & 8 deletions src/correct_images/tools/memmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import uuid
from datetime import datetime
from pathlib import Path

import cv2
import numpy as np
Expand All @@ -17,7 +18,7 @@
from ..loaders import default


def create_memmap_name() -> str:
def create_memmap_name(memmap_location) -> str:
filename_map = (
"memmap_"
+ datetime.now().strftime("%Y%m%d_%H%M%S_")
Expand All @@ -27,11 +28,15 @@ def create_memmap_name() -> str:
+ str(uuid.uuid4())
+ ".map"
)
return filename_map
memmap_location = Path(memmap_location)
if not memmap_location.exists():
memmap_location.mkdir(parents=True)
filename_map = memmap_location / filename_map
return str(filename_map)


def create_memmap(image_list, dimensions, loader=default.loader):
filename_map = create_memmap_name()
def create_memmap(image_list, dimensions, memmap_location, loader=default.loader):
filename_map = create_memmap_name(memmap_location)
Console.info("Creating memmap at", filename_map)
# If only 1 channel, do not create a 3D array
if dimensions[-1] == 1:
Expand All @@ -55,15 +60,15 @@ def create_memmap(image_list, dimensions, loader=default.loader):
return filename_map, image_memmap


def open_memmap(shape, dtype):
filename_map = create_memmap_name()
def open_memmap(shape, dtype, memmap_location):
filename_map = create_memmap_name(memmap_location)
Console.info("Creating memmap (open_memmap) at", filename_map)
image_memmap = np.memmap(filename=filename_map, mode="w+", shape=shape, dtype=dtype)
return filename_map, image_memmap


def convert_to_memmap(array, loader=default.loader):
filename_map = create_memmap_name()
def convert_to_memmap(array, memmap_location, loader=default.loader):
filename_map = create_memmap_name(memmap_location)
Console.info("Creating memmap (convert_to_memmap) at", filename_map)
image_memmap = np.memmap(
filename=filename_map, mode="w+", shape=array.shape, dtype=array.dtype
Expand Down

0 comments on commit e992c19

Please sign in to comment.