Merge pull request #46 from DroneML/35-support-all-resolutions
35 support all image resolutions
cwmeijer authored Jan 27, 2025
2 parents 879f7da + 6bed7e6 commit d9d3c8c
Showing 3 changed files with 76 additions and 23 deletions.
58 changes: 48 additions & 10 deletions src/segmentmytif/features.py
@@ -8,7 +8,6 @@
 from segmentmytif.logging_config import log_duration, log_array
 from segmentmytif.utils.io import save_tiff, read_geotiff
 from segmentmytif.utils.models import UNet
-from torchinfo import summary

 NUM_FLAIR_CLASSES = 19
 logger = logging.getLogger(__name__)
@@ -29,7 +28,7 @@ def from_string(s):
 def get_features(input_data: np.ndarray, input_path: Path, feature_type: FeatureType, features_path: Path, profile,
                  **extractor_kwargs):
     """
     Extract features from the input data, or load them from disk if they have already been extracted.
+    :param input_data: 'Raw' input data as stored in TIFs by a GIS user. Shape: [n_bands, height, width]
     :param input_path:
     :param feature_type: See FeatureType enum for options.
@@ -44,7 +43,7 @@ def get_features(input_data: np.ndarray, input_path: Path, feature_type: FeatureType,
     if features_path is None:
         features_path = get_features_path(input_path, feature_type)
     if not features_path.exists():
-        logger.info(f"No existing {feature_type.name} found")
+        logger.info(f"No existing {feature_type.name} features found at {features_path} for input data with shape {input_data.shape}")
         with log_duration(f"Extracting {feature_type.name} features", logger):
             features = extract_features(input_data, feature_type, **extractor_kwargs)
             log_array(features, logger, array_name=f"{feature_type.name} features")
@@ -69,21 +68,60 @@ def extract_identity_features(input_data: ndarray) -> ndarray:
 def extract_flair_features(input_data: ndarray, model_scale=1.0) -> ndarray:
+    """
+    :param input_data: Array-like input data as stored in TIFs. Shape: [n_bands, height, width]
+    :param model_scale: Scale of the model to use. Must be one of [1.0, 0.5, 0.25, 0.125]
+    :return: Features extracted from the input data
+    """
     logger.info(f"Using UNet at scale {model_scale}")
+    model, device = load_model(model_scale)
+    n_bands = input_data.shape[0]
+
+    outputs = []
+    for i_band in range(n_bands):
+        input_band = torch.from_numpy(input_data[None, i_band:i_band + 1, :, :]).float().to(device)
+        padded_input = pad(input_band, band_name=i_band)
+        padded_current_predictions = model(padded_input).detach().numpy()
+        current_predictions = padded_current_predictions[:, :, :input_band.shape[2], :input_band.shape[3]]  # unpad
+        outputs.append(current_predictions)
+    output = np.concatenate(outputs, axis=1)
+    return output[0, :, :, :]
+
+
+def load_model(model_scale: float):
+    """
+    Load the model from disk and return it along with the device it is loaded onto.
+    :param model_scale: Scale of the model to use. Must be one of [1.0, 0.5, 0.25, 0.125]
+    :return: Torch model and the device it is loaded onto
+    """
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     model = UNet(in_channels=1, num_classes=NUM_FLAIR_CLASSES, model_scale=model_scale)
     file_name = get_flair_model_file_name(model_scale)
     state = torch.load(Path("models") / file_name, map_location=device, weights_only=True)
     model.load_state_dict(state)
     model.eval()
-    n_bands = input_data.shape[0]
+    return model, device

-    outputs = []
-    for i_band in range(n_bands):
-        current_input_data = torch.from_numpy(input_data[None, i_band:i_band + 1, :, :]).float().to(device)
-        outputs.append(model(current_input_data).detach().numpy())
-    output = np.concatenate(outputs, axis=1)
-    return output[0, :, :, :]

+def pad(input_band, band_name):
+    """
+    Pad the input band, single-sided at the end of the width and height axes, to make its dimensions divisible by 16.
+    :param input_band: Input band to pad
+    :return: Padded input
+    """
+    width = input_band.shape[2]
+    height = input_band.shape[3]
+    if width % 16 == 0 and height % 16 == 0:
+        padded = input_band
+    else:
+        pad_width = 16 - width % 16
+        pad_height = 16 - height % 16
+        padded = torch.nn.functional.pad(input_band, (0, pad_height, 0, pad_width))
+        logger.info(f"Added temporary padding for band {band_name}: (original {height} x {width})"
+                    f" -> (padded {height + pad_height} x {width + pad_width})")
+
+    return padded


 def get_flair_model_file_name(model_scale: float) -> str:
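The core of this change is a pad-then-crop round-trip: inputs whose height or width is not divisible by 16 are padded single-sided before the UNet forward pass, and the predictions are cropped back to the original size. A minimal sketch of that round-trip, using a hypothetical 61 x 39 input and remainder arithmetic instead of the commit's if/else (an equivalent formulation, not the library's API):

import torch
import torch.nn.functional as F

x = torch.rand(1, 1, 61, 39)                   # NCHW tensor; 61 and 39 are not divisible by 16
pad_rows = (-x.shape[2]) % 16                  # 3, so 61 -> 64
pad_cols = (-x.shape[3]) % 16                  # 9, so 39 -> 48
padded = F.pad(x, (0, pad_cols, 0, pad_rows))  # last dim first: (left, right, top, bottom)
assert padded.shape[2] % 16 == 0 and padded.shape[3] % 16 == 0
# ... run the model on `padded`, then crop predictions back to the input size:
cropped = padded[:, :, :x.shape[2], :x.shape[3]]
assert cropped.shape == x.shape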
20 changes: 14 additions & 6 deletions tests/test_features.py
@@ -8,10 +8,18 @@ def test_extract_identity_features(self):
         input_data = np.array(get_generated_multiband_image())
         result = extract_features(input_data, FeatureType.IDENTITY)
         assert np.array_equal(result, input_data)

-    def test_extract_flair_features(self):
-        n_bands = 3
-        input_data = np.array(get_generated_multiband_image(n_bands=n_bands))
+    @pytest.mark.parametrize(["n_bands", "width", "height"],
+                             [
+                                 (3, 1, 1),  # too small to be processed by the model, requires padding
+                                 (3, 8, 8),  # too small to be processed by the model, requires padding
+                                 (3, 16, 16),  # smallest size that can natively be processed by the model
+                                 (3, 61, 39),  # not divisible by 16, so requires padding in both directions
+                                 (3, 64, 48),  # smallest dimensions larger than the case above that need no padding
+                                 (1, 512, 512),  # size of the model's training data (easiest case)
+                                 (3, 1210, 718),  # not divisible by 16, so requires padding in both directions
+                             ])
+    def test_extract_flair_features(self, n_bands, width, height):
+        input_data = np.array(get_generated_multiband_image(n_bands=n_bands, width=width, height=height))
         result = extract_features(input_data, FeatureType.FLAIR, model_scale=0.125)
         assert np.array_equal(result.shape, [n_bands * NUM_FLAIR_CLASSES] + list(input_data.shape[1:]))

@@ -21,8 +29,8 @@ def test_extract_features_unsupported_type(self):
             extract_features(input_data, "UNSUPPORTED_TYPE")


-def get_generated_multiband_image(n_bands=3):
-    return np.random.random(size=[n_bands, 512, 512])
+def get_generated_multiband_image(n_bands=3, width=512, height=512):
+    return np.random.random(size=[n_bands, width, height])


 @pytest.mark.parametrize(["model_scale", "file_name"],
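Reading the shape assertion above: the UNet is applied per band and emits NUM_FLAIR_CLASSES channels for each, so the stacked features keep the original spatial size but have n_bands * NUM_FLAIR_CLASSES channels. A quick sketch with values copied from one of the parametrized cases (illustration only, not library code):

NUM_FLAIR_CLASSES = 19
n_bands, width, height = 3, 61, 39
expected_shape = [n_bands * NUM_FLAIR_CLASSES, width, height]  # [57, 61, 39]
# internally each band is padded to 64 x 48 for the forward pass,
# but the returned features are cropped back to 61 x 39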
21 changes: 14 additions & 7 deletions tests/test_main.py
@@ -1,5 +1,4 @@
-import time
-from contextlib import contextmanager
+import shutil
 from pathlib import Path

 import dask.array as da
@@ -16,22 +15,29 @@
 @pytest.mark.parametrize("test_image, test_labels, feature_type",
                          [
                              ("test_image.tif", "test_image_labels.tif", FeatureType.IDENTITY),
-                             pytest.param("test_image.tif", "test_image_labels.tif", FeatureType.FLAIR,
-                                          marks=pytest.mark.xfail(reason="model can only handle 512x512")),
+                             ("test_image.tif", "test_image_labels.tif", FeatureType.FLAIR),
+                             ("test_image_512x512.tif", "test_image_labels_512x512.tif", FeatureType.IDENTITY),
+                             ("test_image_512x512.tif", "test_image_labels_512x512.tif", FeatureType.FLAIR),
                          ])
 def test_integration(tmpdir, test_image, test_labels, feature_type):
-    input_path = TEST_DATA_FOLDER / test_image
-    labels_path = TEST_DATA_FOLDER / test_labels
+    input_path = copy_file_and_get_new_path(test_image, tmpdir)
+    labels_path = copy_file_and_get_new_path(test_labels, tmpdir)
     predictions_path = Path(tmpdir) / f"{test_image}_predictions_{str(feature_type)}.tif"

-    read_input_and_labels_and_save_predictions(input_path, labels_path, predictions_path, feature_type=feature_type,
+    read_input_and_labels_and_save_predictions(input_path, labels_path,
+                                               predictions_path,
+                                               feature_type=feature_type,
                                                model_scale=0.125)  # scale down feature-extraction-model for testing

     assert predictions_path.exists()


+def copy_file_and_get_new_path(test_image, tmpdir):
+    input_path = Path(tmpdir) / test_image
+    shutil.copy(TEST_DATA_FOLDER / test_image, input_path)
+    return input_path
+
+
 @pytest.mark.parametrize("input_path, feature_type, expected_path", [
     ("input.tiff", FeatureType.FLAIR, "input_FLAIR.tiff"),
     ("../path/to/input.tiff", FeatureType.FLAIR, "../path/to/input_FLAIR.tiff"),
@@ -65,3 +71,4 @@ def test_prepare_training_data(array_type):
     input_data = da.from_array(random_data)

     prepare_training_data(input_data, labels)
+
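For orientation, a hypothetical invocation of the API these tests exercise; the keyword arguments mirror the integration test above, but the module path for the entry point is an assumption, not taken from this diff:

from pathlib import Path
from segmentmytif.main import read_input_and_labels_and_save_predictions  # module path assumed
from segmentmytif.features import FeatureType

read_input_and_labels_and_save_predictions(
    Path("input.tif"),        # raster of any resolution, which this PR enables
    Path("labels.tif"),       # matching label raster
    Path("predictions.tif"),  # output path
    feature_type=FeatureType.FLAIR,
    model_scale=0.125,        # scaled-down UNet, as used in the tests
)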