From 4b3361ba46b8908bb11b0c8971c4137d6011c673 Mon Sep 17 00:00:00 2001 From: Cesar Luis Aybar Camacho Date: Mon, 2 Dec 2024 14:09:04 +0100 Subject: [PATCH] up --- .gitignore | 3 +- README.md | 136 +++++++++++------- supers2/__init__.py | 6 +- supers2/main.py | 14 +- supers2/utils.py | 55 ++++++++ supers2/xai/lam.py | 335 ++++++++++++++------------------------------ 6 files changed, 257 insertions(+), 292 deletions(-) diff --git a/.gitignore b/.gitignore index 0849c1f..0a5d6c2 100644 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,5 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ -demo.py \ No newline at end of file +demo.py +demo2.py diff --git a/README.md b/README.md index 180fdb1..27ed05e 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@

- A Python package for enhancing the spatial resolution of Sentinel-2 satellite images to 2.5 meters πŸš€ + A Python package for enhancing the spatial resolution of Sentinel-2 satellite images up to 2.5 meters πŸš€

@@ -38,35 +38,28 @@

## **Overview** πŸ“Š

-**supers2** is a Python package designed to enhance the spatial resolution of Sentinel-2 satellite images to 2.5 meters using advanced neural network models. It facilitates downloading (cubo package), preparing, and processing the Sentinel-2 data and applies deep learning models to enhance the spatial resolution of the imagery.
+**supers2** is a Python package designed to enhance the spatial resolution of Sentinel-2 satellite images to 2.5 meters using a set of neural network models.

## **Installation** βš™οΈ

Install the latest version from PyPI:

```bash
-pip install cubo supers2
+pip install supers2
```

+The quick-start example below also uses the `cubo` package (available on PyPI) to download Sentinel-2 data.
+
## **How to use** πŸ› οΈ

### **Basic usage: enhancing spatial resolution of Sentinel-2 images** 🌍

-#### **Load libraries**
-
```python
-import cubo
+import matplotlib.pyplot as plt
import numpy as np
-import torch
-
import supers2
+import torch
+import cubo

-```
-
-#### **Download Sentinel-2 L2A cube**
-
-```python
-# Create a Sentinel-2 L2A data cube for a specific location and date range
+# Download a Sentinel-2 L2A cube
da = cubo.create(
    lat=4.31,
    lon=-76.2,
    collection="sentinel-2-l2a",
@@ -77,19 +70,7 @@ da = cubo.create(
    edge_size=128,
    resolution=10
)
-```
-
-#### **Prepare the data (CPU and GPU usage)**
-
-When converting the NumPy array to a PyTorch tensor, the use of `cuda()` is optional and depends on whether the user has access to a GPU. Below is the explanation for both cases:
-
-- **GPU:** If a GPU is available and CUDA is installed, you can transfer the tensor to the GPU using `.cuda()`. This improves the processing speed, especially for large datasets or deep learning models.
-
-- **CPU:** If no GPU is available, the tensor will be processed on the CPU, which is the default behavior in PyTorch. In this case, simply omit the `.cuda()` call.
-
-Here’s how you can handle both scenarios dynamically:
-
-```python
# Convert the data array to NumPy and scale
original_s2_numpy = (da[11].compute().to_numpy() / 10_000).astype("float32")

@@ -98,48 +79,101 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Check if CUDA is available, use GPU if possible, otherwise fallback to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create the tensor and move it to the appropriate device (CPU or GPU)
X = torch.from_numpy(original_s2_numpy).float().to(device)
+
+# Set up the model to enhance the spatial resolution
+models = supers2.setmodel(device=device)
+
+# Apply spatial resolution enhancement
+superX = supers2.predict(X, models=models, resolution="2.5m")
+
+# Plot the original and enhanced-resolution images
+fig, ax = plt.subplots(1, 2, figsize=(10, 5))
+ax[0].imshow(X[[2, 1, 0]].permute(1, 2, 0).cpu().numpy() * 4)
+ax[0].set_title("Original S2")
+ax[1].imshow(superX[[2, 1, 0]].permute(1, 2, 0).cpu().numpy() * 4)
+ax[1].set_title("Enhanced Resolution S2")
+plt.show()
```
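To keep the enhanced scene, you can write it to a GeoTIFF with rasterio. This is a minimal sketch, not part of the supers2 API: the output path, CRS, and transform below are illustrative placeholders, and in practice you should derive the georeferencing from your own data cube.

```python
import rasterio
from rasterio.transform import from_origin

sr = superX.cpu().numpy()  # (bands, height, width), reflectance at 2.5 m

# Illustrative georeferencing only: replace the origin, pixel size,
# and CRS with values taken from your data cube.
meta = {
    "driver": "GTiff",
    "dtype": "float32",
    "count": sr.shape[0],
    "height": sr.shape[1],
    "width": sr.shape[2],
    "crs": "EPSG:32618",  # placeholder UTM zone
    "transform": from_origin(500_000, 480_000, 2.5, 2.5),  # placeholder origin, 2.5 m pixels
}

with rasterio.open("superX.tif", "w", **meta) as dst:
    dst.write(sr)
```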

+ +

+
+
+## **Change the model settings** πŸ› οΈ
+
+At the end of the document, you can find a table with the available models and their characteristics.
+
```python
# Set up the model to enhance the spatial resolution
models = supers2.setmodel(
-    SR_model_loss="l1",
-    SR_model_name="cnn",
-    SR_model_size="small",
-    Fusionx2_model_size="lightweight",
-    Fusionx4_model_size="lightweight"
+    resolution="2.5m",  # Set the desired resolution
+    sr_model_snippet="sr__opensrbaseline__cnn__medium__l1",  # RGBN model from 10m to 2.5m
+    fusionx2_model_snippet="fusionx2__opensrbaseline__cnn__large__l1",  # RedESWIR model from 20m to 10m
+    fusionx4_model_snippet="fusionx4__opensrbaseline__cnn__large__l1",  # RedESWIR model from 10m to 2.5m
+    weights_path=None,  # Path to the weights file
+    device="cpu"  # Use the CPU
)
-```

-### **Apply spatial resolution enhancement**

-```python
-# Apply the model to enhance the image resolution to 2.5 meters
+# Apply spatial resolution enhancement
superX = supers2.predict(X, models=models, resolution="2.5m")
```

-### **Visualize the results** 🎨
+### **Predict only RGBNIR bands** 🌍
+
+```python
+superX = supers2.predict_rgbnir(X[[2, 1, 0, 6]])
+```

-#### **Display images**
+### **Estimate the uncertainty of the model** πŸ“Š

```python
-import matplotlib.pyplot as plt
+from supers2.trained_models import SRmodels

-# Plot the original and enhanced-resolution images
+# List the available model snippets (a new name is used here so the
+# `models` object returned by setmodel above is not overwritten)
+available_models = list(SRmodels.model_dump()["object"].keys())
+
+# Keep only the Swin transformer models
+swin2sr_models = [model for model in available_models if "swin" in model]
+
+map_mean, map_std = supers2.uncertainty(
+    X[[2, 1, 0, 6]],
+    models=swin2sr_models
+)
+
+# Visualize the uncertainty
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
-ax[0].imshow(X[[2, 1, 0]].permute(1, 2, 0).cpu().numpy()*4)
-ax[0].set_title("Original S2")
-ax[1].imshow(superX[[2, 1, 0]].permute(1, 2, 0).cpu().numpy()*4)
-ax[1].set_title("Enhanced Resolution S2")
+ax[0].imshow(map_mean[0:3].cpu().numpy().transpose(1, 2, 0) * 3)
+ax[0].set_title("Mean")
+ax[1].imshow(map_std[0:3].cpu().numpy().transpose(1, 2, 0) * 100)
+ax[1].set_title("Standard Deviation")
plt.show()
```
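The standard deviation map is easiest to read relative to the signal itself. As a quick follow-up (a hypothetical post-processing step, not part of the supers2 API, assuming `supers2.uncertainty` returns band-first tensors as in the plot above), you can map the per-pixel coefficient of variation to highlight where the ensemble disagrees most:

```python
# Ensemble disagreement relative to the mean signal, averaged over bands
eps = 1e-6  # guard against division by zero in dark pixels
cv = (map_std / (map_mean + eps)).mean(dim=0)

plt.imshow(cv.cpu().numpy(), cmap="magma")
plt.title("Relative uncertainty (std / mean)")
plt.colorbar()
plt.show()
```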

- -

## **Supported features and filters** ✨

### **Estimate the Local Attribution Map (LAM) of the model** πŸ“Š

```python
kde_map, complexity_metric, robustness_metric, robustness_vector = supers2.lam(
    X=X[[2, 1, 0, 6]].cpu(),  # The input tensor
    model=models.srx4,  # The SR model
    h=240,  # Top (row) coordinate of the analysis window
    w=240,  # Left (column) coordinate of the analysis window
    window=128,  # The window size
    scales=["1x", "2x", "3x", "4x", "5x", "6x", "7x", "8x"]
)

# Visualize the results
plt.imshow(kde_map)
plt.title("Kernel Density Estimation")
plt.show()

plt.plot(robustness_vector)
plt.title("Robustness Vector")
plt.show()
```

- **Enhance spatial resolution to 2.5 meters:** Use advanced CNN models to enhance Sentinel-2 imagery.
- **Neural network-based approach:** Integration of multiple model sizes to fit different computing needs (small, lightweight).
- **Python integration:** Easily interact with data cubes through the Python API, supporting seamless workflows.
\ No newline at end of file
+### **Use opensr-test and supers2 to analyze hallucinated pixels** πŸ“Š
+
+A typical workflow passes the low-resolution input, the supers2 prediction, and (when available) a high-resolution reference to the opensr-test metrics, which flag pixels where the model invents detail; see the opensr-test documentation for the exact API.
diff --git a/supers2/__init__.py b/supers2/__init__.py
index 9ba3302..d9b99ad 100644
--- a/supers2/__init__.py
+++ b/supers2/__init__.py
@@ -1,2 +1,6 @@
-from supers2.main import predict, setmodel, predict_large, predict_rgbnir
+from supers2.main import predict, setmodel, predict_large, predict_rgbnir, uncertainty
from supers2.xai.lam import lam
+from supers2.trained_models import SRmodels
+
+models = list(SRmodels.model_dump()["object"].keys())
+
diff --git a/supers2/main.py b/supers2/main.py
index ee1aa2a..c41b2f0 100644
--- a/supers2/main.py
+++ b/supers2/main.py
@@ -8,10 +8,9 @@

from supers2.dataclass import SRexperiment
from supers2.setup import load_model
-from supers2.utils import define_iteration
+from supers2.utils import define_iteration, gdal_create
from supers2.trained_models import SRmodels

-
def setmodel(
    resolution: Literal["2.5m", "5m", "10m"] = "2.5m",
    sr_model_snippet: str = "sr__opensrbaseline__cnn__lightweight__l1",
@@ -282,11 +281,7 @@ def predict_large(

    # Create the output image
    with rio.open(output_fullname, "w", **output_metadata) as dst:
-        data_np = np.zeros(
-            (metadata["count"], metadata["height"] * res_n, metadata["width"] * res_n),
-            dtype=np.uint16,
-        )
-        dst.write(data_np)
+        pass

    # Check if the models are loaded
    if models is None:
@@ -295,8 +290,7 @@
    # Iterate over the image
    with rio.open(output_fullname, "r+") as dst:
        with rio.open(image_fullname) as src:
-            for index, point in enumerate(tqdm.tqdm(nruns)):
-
+            for index, point in enumerate(tqdm.tqdm(nruns)):
                # Read a block of the image
                window = rio.windows.Window(point[1], point[0], 128, 128)
                X = torch.from_numpy(src.read(window=window)).float().to(device)
@@ -431,7 +425,7 @@ def uncertainty(
    )

    # Run the model
-    X_torch = torch.from_numpy((X / 10_000)).float().to(device)
+    X_torch = X.float().to(device)
    prediction = model_object(X_torch[None]).squeeze().cpu()

    # Store the prediction
diff --git a/supers2/utils.py b/supers2/utils.py
index fe493d4..0b2c6c3 100644
--- a/supers2/utils.py
+++ b/supers2/utils.py
@@ -1,4 +1,8 @@
import itertools
+import rasterio
+import pathlib
+
+from rasterio.crs import CRS


def define_iteration(dimension: tuple, chunk_size: int, overlap: int = 0):
@@ -57,3 +61,54 @@ def fix_lastchunk(iterchunks, s2dim, chunk_size):
            itercontainer.append((index_i, index_j))

    return itercontainer
+
+
+def gdal_create(
+    outfilename: str,
+    dtype: str = 'uint16',
+    driver: str = 'GTiff',
+    count: int = 13,
+    width: int = 5120,
+    height: int = 5120,
+    nodata: int = 65535,
+    crs: int = 4326,
+    affine: tuple = (-180, 90, 0.5, 0.5),
+    **kwargs,
+) -> pathlib.Path:
+    """
+    Fast creation of a new raster file using rasterio.
+
+    Args:
+        outfilename (str): Output filename.
+        dtype (str): Data type of the raster.
+        driver (str): GDAL driver to use.
+        count (int): Number of bands in the raster.
+        width (int): Width of the raster.
+        height (int): Height of the raster.
+        nodata (int): NoData value.
+        crs (int): EPSG code of the raster.
+        affine (tuple): (west, north, xsize, ysize) tuple passed to
+            rasterio.transform.from_origin to build the affine transform.
+
+    Returns:
+        pathlib.Path: Path to the created raster file.
+    """
+    # Define the metadata for the new file
+    meta = {
+        'driver': driver,
+        'dtype': dtype,
+        'nodata': nodata,
+        'width': width,
+        'height': height,
+        'count': count,
+        'crs': CRS.from_epsg(crs),
+        'transform': rasterio.transform.from_origin(*affine),
+    }
+
+    # Merge the metadata with the additional kwargs
+    meta.update(kwargs)
+
+    with rasterio.open(outfilename, 'w', **meta) as dst:
+        pass
+
+    return pathlib.Path(outfilename)
\ No newline at end of file
diff --git a/supers2/xai/lam.py b/supers2/xai/lam.py
index a26888a..07b7130 100644
--- a/supers2/xai/lam.py
+++ b/supers2/xai/lam.py
@@ -1,185 +1,13 @@
-from typing import Optional
+from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
+from tqdm import tqdm

from supers2.xai.utils import gini, vis_saliency_kde

-def GaussianBlurPath(sigma: float, fold: int, kernel_size: int = 5):
-    """
-    Generates a function for applying a Gaussian blur path to an image using PyTorch.
-    The function applies progressively weaker Gaussian blurs to an image and calculates
-    interpolations between each blurred image, along with derivatives for each step.
-
-    Args:
-        sigma (float): Initial standard deviation for the Gaussian blur.
-        fold (int): Number of interpolation steps for the blurring path.
-        kernel_size (int, optional): Size of the Gaussian kernel. Defaults to 5.
-
-    Returns:
-        Callable: A function that takes an image and returns a tuple:
-            - image_interpolation (torch.Tensor): Interpolated blurred images.
-            - lambda_derivative_interpolation (torch.Tensor): Derivatives of interpolated images.
-    """
-
-    def path_interpolation_func(torch_image: torch.Tensor):
-        """
-        Applies the Gaussian blur path to the input image and computes the interpolated images
-        and their derivatives using PyTorch.
-
-        Args:
-            torch_image (torch.Tensor): Input image as a torch tensor (channels, height, width).
-
-        Returns:
-            tuple: Interpolated blurred images and their derivatives along the Gaussian path.
- """ - device = torch_image.device - - # Ensure image is 4D (batch, channels, height, width) - torch_image = torch_image.unsqueeze(0) if torch_image.ndim == 3 else torch_image - torch_image = torch_image.to(device) - channels = torch_image.shape[1] - - # Initialize tensors for blurred images and derivatives - image_interpolation = torch.zeros( - (fold, *torch_image.shape[1:]), dtype=torch.float32 - ) - image_interpolation = image_interpolation.to(device) - lambda_derivative_interpolation = torch.zeros_like(image_interpolation) - lambda_derivative_interpolation = lambda_derivative_interpolation.to(device) - kernel_interpolation = torch.zeros( - (fold + 1, channels, kernel_size, kernel_size), dtype=torch.float32 - ) - kernel_interpolation = kernel_interpolation.to(device) - - # Linearly interpolate sigma values from initial to zero - sigma_interpolation = np.linspace(sigma, 0, fold + 1) - - # Create Gaussian kernels for each sigma value - for i in range(fold + 1): - kernel_interpolation[i] = isotropic_gaussian_kernel_torch( - kernel_size, sigma_interpolation[i] - ).squeeze() - - # Calculate padding size - pad_size = kernel_interpolation.shape[-1] // 2 - - # Create Gaussian kernels for each sigma and apply to image - for i in range(fold): - # Apply reflect padding first - padded_image = F.pad( - torch_image, (pad_size, pad_size, pad_size, pad_size), mode="reflect" - ) - - # Store the current blurred image - image_interpolation[i] = F.conv2d( - padded_image, kernel_interpolation[i + 1][:, None], groups=channels - ).squeeze(0) - - # Calculate derivative with respect to lambda - diff_kernel = (kernel_interpolation[i + 1] - kernel_interpolation[i])[ - :, None - ] * fold - lambda_derivative_interpolation[i] = F.conv2d( - padded_image, diff_kernel, groups=channels - ).squeeze(0) - - return ( - image_interpolation, - lambda_derivative_interpolation, - kernel_interpolation, - ) - - return path_interpolation_func - - -def isotropic_gaussian_kernel_torch( - size: int, sigma: float, epsilon: float = 1e-5 -) -> torch.Tensor: - """ - Generates an isotropic Gaussian kernel in PyTorch. - - Args: - size (int): Size of the kernel (size x size). - sigma (float): Standard deviation of the Gaussian distribution. - epsilon (float, optional): Small constant to avoid division by zero. Defaults to 1e-5. - - Returns: - torch.Tensor: Normalized Gaussian kernel. - """ - ax = torch.arange(-size // 2 + 1.0, size // 2 + 1.0) - xx, yy = torch.meshgrid(ax, ax, indexing="ij") - - # Calculate Gaussian function and normalize - kernel = torch.exp(-(xx**2 + yy**2) / (2.0 * (sigma + epsilon) ** 2)) - kernel = kernel / kernel.sum() - - # Reshape for 2D convolution (out_channels, in_channels, height, width) - return kernel.unsqueeze(0).unsqueeze(0) - - -def Path_gradient( - image: torch.Tensor, - model: torch.nn.Module, - attr_objective: callable, - path_interpolation_func: callable, -): - """ - Computes the path gradient for an image using a specified model and attribution objective. - The function calculates gradients for a series of interpolated images produced by a path - interpolation function. - - Args: - numpy_image (np.ndarray): Input image of shape (channels, height, width). - model (torch.nn.Module): The model to compute the objective on. - attr_objective (callable): Function defining the attribution objective for the model output. - path_interpolation_func (callable): Function that generates interpolated images and - their lambda derivatives along a defined path. 
- - Returns: - tuple: - - grad_accumulate_list (np.ndarray): Accumulated gradients for each interpolated image. - - results_numpy (np.ndarray): Model outputs for each interpolated image. - - image_interpolation (np.ndarray): Interpolated images created by `path_interpolation_func`. - """ - - # Prepare image for interpolation and initialize gradient accumulation array - image_interpolation, lambda_derivative_interpolation, _ = path_interpolation_func( - image - ) - - grad_accumulate_list = torch.zeros_like(image_interpolation).cpu().numpy() - result_list = [] - - # Compute gradient for each interpolated image - for i in range(image_interpolation.shape[0]): - - # Convert interpolated image to tensor and set requires_grad for backpropagation - img_tensor = image_interpolation[i].float()[None] - img_tensor.requires_grad_(True) - - # Forward pass through the model and compute attribution objective - result = model(img_tensor) - target = attr_objective(result) - target.backward() # Compute gradients - - # Extract gradient, handling NaNs if present - grad = img_tensor.grad.cpu().numpy() - grad = np.nan_to_num(grad) # Replace NaNs with 0 - - # Accumulate gradients adjusted by lambda derivatives - grad_accumulate_list[i] = ( - grad * lambda_derivative_interpolation[i].cpu().numpy() - ) - result_list.append(result.detach().cpu().numpy()) - - # Collect results and return final outputs - results_numpy = np.array(result_list) - return grad_accumulate_list, results_numpy, image_interpolation - - def attribution_objective(attr_func, h: int, w: int, window: int = 16): """ Creates an objective function to calculate attribution within a specified window @@ -258,72 +86,118 @@ def attr_grad( else: raise ValueError(f"Invalid reduction type: {reduce}. Use 'sum' or 'mean'.") +def down_up(X: torch.Tensor, scale_factor: float=0.5) -> torch.Tensor: + """Downsample and upsample an image using bilinear interpolation. -def lam( - X: torch.Tensor, - model: torch.nn.Module, - h: Optional[int] = 240, - w: Optional[int] = 240, - window: Optional[int] = 32, - fold: Optional[int] = 25, - kernel_size: Optional[int] = 13, - sigma: Optional[float] = 3.5, - robustness_metric: Optional[str] = True -): + Args: + X (torch.Tensor): The input tensor (Bands x Height x Width). + scale_factor (float, optional): The scaling factor. Defaults to 0.5. + + Returns: + torch.Tensor: The downsampled and upsampled image. """ - Computes the Local Attribution Map (LAM) for an input tensor using - a specified model and attribution function. The function calculates - the path gradient for each band in the input tensor and combines the - results to generate the LAM. + shape_init = X.shape + return torch.nn.functional.interpolate( + input=torch.nn.functional.interpolate( + input=X, + scale_factor=1/scale_factor, + mode="bilinear", + antialias=True + ), + size=shape_init[2:], + mode="bilinear", + antialias=True + ) + +def create_blur_cube(X: torch.Tensor, scales: list) -> torch.Tensor: + """Create a cube of blurred images at different scales. Args: - X (torch.Tensor): Input tensor of shape (channels, height, width). - model (torch.nn.Module): The model to compute the objective on. - attr_func (callable): Function that calculates attributions for an image. - h (int): The top coordinate of the window within the image. - w (int): The left coordinate of the window within the image. - window (int, optional): The size of the square window. Defaults to 16. - fold (int, optional): Number of interpolation steps for the blurring path. - Defaults to 10. 
- kernel_size (int, optional): Size of the Gaussian kernel. Defaults to 5. - sigma (float, optional): Initial standard deviation for the Gaussian blur. - Defaults to 3.5. - robustness_metric (bool, optional): Whether to return the robustness metric. - Defaults to True. + X (torch.Tensor): The input tensor (Bands x Height x Width). + scales (list): The scales to evaluate. Returns: - tuple: A tuple containing the following elements: - - kde_map (np.ndarray): KDE estimation of the LAM. - - complexity_metric (float): Gini index of the LAM that - measures the consistency of the attribution. The - larger the value, the more use more complex attribution - patterns to solve the task. - - robustness_metric (np.ndarray): Blurriness sensitivity of the LAM. - The sensitivity measures the average gradient magnitude of the - interpolated images. - - robustness_vector (np.ndarray): Vector of gradient magnitudes for - each interpolated image. + torch.Tensor: The cube of blurred images. """ + scales_int = [float(scale[:-1]) for scale in scales] + return torch.stack([down_up(X[None], scale) for scale in scales_int]).squeeze() - # Get the scale of the results - with torch.no_grad(): - output = model(X[None]) - scale = output.shape[-1] // X.shape[-1] - # Create the path interpolation function - path_interpolation_func = GaussianBlurPath( - sigma=sigma, fold=fold, kernel_size=kernel_size - ) - # a, b, c = path_interpolation_func(X) +def create_lam_inputs( + X: torch.Tensor, + scales: list +) -> Tuple[torch.Tensor, torch.Tensor, list]: + """Create the inputs for the Local Attribution Map (LAM). + + Args: + X (torch.Tensor): The input tensor (Bands x Height x Width). + scales (list): The scales to evaluate. + + Returns: + Tuple[torch.Tensor, torch.Tensor, list]: The cube of blurred + images, the difference between the input and the cube, + and the scales. + """ + cube = create_blur_cube(X, scales) + diff = torch.abs(X[None] - cube) + return cube[1:], diff[1:], scales[1:] + + +def lam( + X: torch.Tensor, + model: torch.nn.Module, + model_scale: float = 4, + h: int = 240, + w: int = 240, + window: int = 32, + scales: list = ["1x", "2x", "3x", "4x", "5x", "6x", "7x", "8x"] +) -> Tuple[np.ndarray, float, float, np.ndarray]: + """ Estimate the Local Attribution Map (LAM) + + Args: + X (torch.Tensor): The input tensor (Bands x Height x Width). + model (torch.nn.Module): The model to evaluate. + model_scale (float, optional): The scale of the model. Defaults to 4. + h (int, optional): The height of the window to evaluate. Defaults to 240. + w (int, optional): The width of the window to evaluate. Defaults to 240. + window (int, optional): The window size. Defaults to 32. + scales (list, optional): The scales to evaluate. Defaults to + ["1x", "2x", "3x", "4x", "5x", "6x", "7x", "8x"]. 
+
+    Returns:
+        Tuple[np.ndarray, float, float, np.ndarray]: The KDE map of the
+            attribution, the complexity metric (derived from the Gini
+            index), the robustness metric, and the robustness vector
+            (mean gradient magnitude at each blur scale).
+    """
+
+    # Create the LAM inputs
+    cube, diff, scales = create_lam_inputs(X, scales)

    # Create the attribution objective function
    attr_objective = attribution_objective(attr_grad, h, w, window=window)

-    # Compute the path gradient for the input tensor
-    grad_accumulate_list,results_numpy, image_interpolation = Path_gradient(
-        X, model, attr_objective, path_interpolation_func
-    )
-
+    # Initialize the gradient accumulation list
+    grad_accumulate_list = torch.zeros_like(cube).cpu().numpy()
+
+    # Compute gradient for each interpolated image
+    for i in tqdm(range(cube.shape[0]), desc="Computing gradients"):
+
+        # Convert interpolated image to tensor and set requires_grad for backpropagation
+        img_tensor = cube[i].float()[None]
+        img_tensor.requires_grad_(True)
+
+        # Forward pass through the model and compute attribution objective
+        result = model(img_tensor)
+        target = attr_objective(result)
+        target.backward()  # Compute gradients
+
+        # Extract gradient, handling NaNs if present
+        grad = img_tensor.grad.cpu().numpy()
+        grad = np.nan_to_num(grad)  # Replace NaNs with 0
+
+        # Weight the gradients by the blur-induced difference image
+        grad_accumulate_list[i] = (
+            grad * diff[i].cpu().numpy()
+        )
+
    # Sum the accumulated gradients across all bands
    lam_results = torch.sum(torch.from_numpy(np.abs(grad_accumulate_list)), dim=0)
    grad_2d = np.abs(lam_results.sum(axis=0))
@@ -333,9 +207,12 @@ def lam(
    # Estimate gini index
    gini_index = gini(grad_norm.flatten())

+    ## window to image size
+    #ratio_img_to_window = (X.shape[1] * model_scale) // window
+
    # KDE estimation
-    kde_map = vis_saliency_kde(grad_norm, scale=scale, bandwidth=1.0)
-    complexity_metric = (1 - gini_index) * 100
+    kde_map = vis_saliency_kde(grad_norm, scale=model_scale, bandwidth=1.0)
+    complexity_metric = (1 - gini_index) * 100  # / ratio_img_to_window

    # Estimate blurriness sensitivity
    robustness_vector = np.abs(grad_accumulate_list).mean(axis=(1, 2, 3))