From 4b3361ba46b8908bb11b0c8971c4137d6011c673 Mon Sep 17 00:00:00 2001
From: Cesar Luis Aybar Camacho
Date: Mon, 2 Dec 2024 14:09:04 +0100
Subject: [PATCH] up
---
.gitignore | 3 +-
README.md | 136 +++++++++++-------
supers2/__init__.py | 6 +-
supers2/main.py | 14 +-
supers2/utils.py | 55 ++++++++
supers2/xai/lam.py | 335 ++++++++++++++------------------------------
6 files changed, 257 insertions(+), 292 deletions(-)
diff --git a/.gitignore b/.gitignore
index 0849c1f..0a5d6c2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -165,4 +165,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
-demo.py
\ No newline at end of file
+demo.py
+demo2.py
diff --git a/README.md b/README.md
index 180fdb1..27ed05e 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@
- A Python package for enhancing the spatial resolution of Sentinel-2 satellite images to 2.5 meters π
+ A Python package for enhancing the spatial resolution of Sentinel-2 satellite images up to 2.5 meters π
@@ -38,35 +38,28 @@
## **Overview** π
-**supers2** is a Python package designed to enhance the spatial resolution of Sentinel-2 satellite images to 2.5 meters using advanced neural network models. It facilitates downloading (cubo package), preparing, and processing the Sentinel-2 data and applies deep learning models to enhance the spatial resolution of the imagery.
+**supers2** is a Python package designed to enhance the spatial resolution of Sentinel-2 satellite images to 2.5 meters using a set of neural network models.
## **Installation** βοΈ
Install the latest version from PyPI:
```bash
-pip install cubo supers2
+pip install supers2
```
## **How to use** π οΈ
### **Basic usage: enhancing spatial resolution of Sentinel-2 images** π
-#### **Load libraries**
-
```python
-import cubo
+import matplotlib.pyplot as plt
import numpy as np
-import torch
-
import supers2
+import torch
+import cubo
-```
-
-#### **Download Sentinel-2 L2A cube**
-
-```python
-# Create a Sentinel-2 L2A data cube for a specific location and date range
+## Download Sentinel-2 L2A cube
da = cubo.create(
lat=4.31,
lon=-76.2,
@@ -77,19 +70,7 @@ da = cubo.create(
edge_size=128,
resolution=10
)
-```
-
-#### **Prepare the data (CPU and GPU usage)**
-
-When converting the NumPy array to a PyTorch tensor, the use of `cuda()` is optional and depends on whether the user has access to a GPU. Below is the explanation for both cases:
-
-- **GPU:** If a GPU is available and CUDA is installed, you can transfer the tensor to the GPU using `.cuda()`. This improves the processing speed, especially for large datasets or deep learning models.
-
-- **CPU:** If no GPU is available, the tensor will be processed on the CPU, which is the default behavior in PyTorch. In this case, simply omit the `.cuda()` call.
-
-Hereβs how you can handle both scenarios dynamically:
-```python
# Convert the data array to NumPy and scale
original_s2_numpy = (da[11].compute().to_numpy() / 10_000).astype("float32")
@@ -98,48 +79,101 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Create the tensor and move it to the appropriate device (CPU or GPU)
X = torch.from_numpy(original_s2_numpy).float().to(device)
+
+# Set up the model to enhance the spatial resolution
+models = supers2.setmodel(device=device)
+
+# Apply spatial resolution enhancement
+superX = supers2.predict(X, models=models, resolution="2.5m")
+
+# Visualize the results
+# Plot the original and enhanced-resolution images
+fig, ax = plt.subplots(1, 2, figsize=(10, 5))
+ax[0].imshow(X[[2, 1, 0]].permute(1, 2, 0).cpu().numpy()*4)
+ax[0].set_title("Original S2")
+ax[1].imshow(superX[[2, 1, 0]].permute(1, 2, 0).cpu().numpy()*4)
+ax[1].set_title("Enhanced Resolution S2")
+plt.show()
```
-#### **Define the resolution enhancement model**
+
+
+
+
+
+## Chante the model settings π οΈ
+
+At the end of the document, you can find a table with the available models and their characteristics.
+
```python
# Set up the model to enhance the spatial resolution
models = supers2.setmodel(
- SR_model_loss="l1",
- SR_model_name="cnn",
- SR_model_size="small",
- Fusionx2_model_size="lightweight",
- Fusionx4_model_size="lightweight"
+ resolution = "2.5m", # Set the desired resolution
+ sr_model_snippet = "sr__opensrbaseline__cnn__medium__l1", # RGBN model from 10m to 2.5m
+ fusionx2_model_snippet = "fusionx2__opensrbaseline__cnn__large__l1", # RedESWIR model from 20m to 10m
+ fusionx4_model_snippet = "fusionx4__opensrbaseline__cnn__large__l1", #RedESWIR model from 10m to 2.5m
+ weights_path = None, # Path to the weights file
+ device = "cpu" # Use the CPU
)
-```
-### **Apply spatial resolution enhancement**
-```python
-# Apply the model to enhance the image resolution to 2.5 meters
+# Apply spatial resolution enhancement
superX = supers2.predict(X, models=models, resolution="2.5m")
```
-### **Visualize the results** π¨
+### **Predict only RGBNIR bands** π
+
+```python
+superX = supers2.predict_rgbnir(X[[2, 1, 0, 6]])
+```
-#### **Display images**
+### **Estimate the uncertainty of the model** π
```python
-import matplotlib.pyplot as plt
+from supers2.trained_models import SRmodels
-# Plot the original and enhanced-resolution images
+# Get the available models
+models = list(SRmodels.model_dump()["object"].keys())
+
+# Get only swin transformer models
+swin2sr_models = [model for model in models if "swin" in model]
+
+map_mean, map_std = supers2.uncertainty(
+ X[[2, 1, 0, 6]],
+ models=swin2sr_models
+)
+
+# Visualize the uncertainty
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
-ax[0].imshow(X[[2, 1, 0]].permute(1, 2, 0).cpu().numpy()*4)
-ax[0].set_title("Original S2")
-ax[1].imshow(superX[[2, 1, 0]].permute(1, 2, 0).cpu().numpy()*4)
-ax[1].set_title("Enhanced Resolution S2")
+ax[0].imshow(mean_map[0:3].cpu().numpy().transpose(1, 2, 0)*3)
+ax[0].set_title("Mean")
+ax[1].imshow(std_map[0:3].cpu().numpy().transpose(1, 2, 0)*100)
+ax[1].set_title("Standard Deviation")
plt.show()
```
-
-
-
-## **Supported features and filters** β¨
+### Estimate the Local Attention Map of the model π
+
+
+```python
+kde_map, complexity_metric, robustness_metric, robustness_vector = supers2.lam(
+ X=X[[2, 1, 0, 6]].cpu(), # The input tensor
+ model=models.srx4, # The SR model
+ h=240, # The height of the window
+ w=240, # The width of the window
+ window=128, # The window size
+ scales = ["1x", "2x", "3x", "4x", "5x", "6x", "7x", "8x"]
+)
+
+# Visualize the results
+plt.imshow(kde_map)
+plt.title("Kernel Density Estimation")
+plt.show()
+
+plt.plot(robustness_vector)
+plt.title("Robustness Vector")
+plt.show()
+```
+
-- **Enhance spatial resolution to 2.5 meters:** Use advanced CNN models to enhance Sentinel-2 imagery.
-- **Neural network-based approach:** Integration of multiple model sizes to fit different computing needs (small, lightweight).
-- **Python integration:** Easily interact with data cubes through the Python API, supporting seamless workflows.
\ No newline at end of file
+### Use the opensr-test and supers2 to analyze the hallucination pixels π
diff --git a/supers2/__init__.py b/supers2/__init__.py
index 9ba3302..d9b99ad 100644
--- a/supers2/__init__.py
+++ b/supers2/__init__.py
@@ -1,2 +1,6 @@
-from supers2.main import predict, setmodel, predict_large, predict_rgbnir
+from supers2.main import predict, setmodel, predict_large, predict_rgbnir, uncertainty
from supers2.xai.lam import lam
+from supers2.trained_models import SRmodels
+
+models = list(SRmodels.model_dump()["object"].keys())
+
diff --git a/supers2/main.py b/supers2/main.py
index ee1aa2a..c41b2f0 100644
--- a/supers2/main.py
+++ b/supers2/main.py
@@ -8,10 +8,9 @@
from supers2.dataclass import SRexperiment
from supers2.setup import load_model
-from supers2.utils import define_iteration
+from supers2.utils import define_iteration, gdal_create
from supers2.trained_models import SRmodels
-
def setmodel(
resolution: Literal["2.5m", "5m", "10m"] = "2.5m",
sr_model_snippet: str = "sr__opensrbaseline__cnn__lightweight__l1",
@@ -282,11 +281,7 @@ def predict_large(
# Create the output image
with rio.open(output_fullname, "w", **output_metadata) as dst:
- data_np = np.zeros(
- (metadata["count"], metadata["height"] * res_n, metadata["width"] * res_n),
- dtype=np.uint16,
- )
- dst.write(data_np)
+ pass
# Check if the models are loaded
if models is None:
@@ -295,8 +290,7 @@ def predict_large(
# Iterate over the image
with rio.open(output_fullname, "r+") as dst:
with rio.open(image_fullname) as src:
- for index, point in enumerate(tqdm.tqdm(nruns)):
-
+ for index, point in enumerate(tqdm.tqdm(nruns)):
# Read a block of the image
window = rio.windows.Window(point[1], point[0], 128, 128)
X = torch.from_numpy(src.read(window=window)).float().to(device)
@@ -431,7 +425,7 @@ def uncertainty(
)
# Run the model
- X_torch = torch.from_numpy((X / 10_000)).float().to(device)
+ X_torch = X.float().to(device)
prediction = model_object(X_torch[None]).squeeze().cpu()
# Store the prediction
diff --git a/supers2/utils.py b/supers2/utils.py
index fe493d4..0b2c6c3 100644
--- a/supers2/utils.py
+++ b/supers2/utils.py
@@ -1,4 +1,8 @@
import itertools
+import rasterio
+import pathlib
+
+from rasterio.crs import CRS
def define_iteration(dimension: tuple, chunk_size: int, overlap: int = 0):
@@ -57,3 +61,54 @@ def fix_lastchunk(iterchunks, s2dim, chunk_size):
itercontainer.append((index_i, index_j))
return itercontainer
+
+
+def gdal_create(
+ outfilename: str,
+ dtype: str = 'uint16',
+ driver: str = 'GTiff',
+ count: int = 13,
+ width: int = 5120,
+ height: int = 5120,
+ nodata: int = 65535,
+ crs: int = 4326,
+ affine: tuple = (-180, 0.5, 90, -0.5),
+ **kwargs,
+) -> pathlib.Path:
+ """
+ Fast creation of a new raster file using rasterio.
+
+ Args:
+ outfilename (str): Output filename.
+ dtype (str): Data type of the raster.
+ driver (str): GDAL driver to use.
+ count (int): Number of bands in the raster.
+ width (int): Width of the raster.
+ height (int): Height of the raster.
+ nodata (int): NoData value.
+ crs (int): EPSG code of the raster.
+ affine (tuple): Affine transformation of the raster.
+
+
+ Returns:
+ pathlib.Path: Path to the created raster file.
+ """
+ # Define the metadata for the new file
+ meta = {
+ 'driver': driver,
+ 'dtype': dtype,
+ 'nodata': nodata,
+ 'width': width,
+ 'height': height,
+ 'count': count,
+ 'crs': CRS.from_epsg(crs),
+ 'transform': rasterio.transform.from_origin(*affine),
+ }
+
+ # Merge the metadata with the additional kwargs
+ meta.update(kwargs)
+
+ with rasterio.open(outfilename, 'w', **meta) as dst:
+ pass
+
+ return pathlib.Path(outfilename)
\ No newline at end of file
diff --git a/supers2/xai/lam.py b/supers2/xai/lam.py
index a26888a..07b7130 100644
--- a/supers2/xai/lam.py
+++ b/supers2/xai/lam.py
@@ -1,185 +1,13 @@
-from typing import Optional
+from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
+from tqdm import tqdm
from supers2.xai.utils import gini, vis_saliency_kde
-def GaussianBlurPath(sigma: float, fold: int, kernel_size: int = 5):
- """
- Generates a function for applying a Gaussian blur path to an image using PyTorch.
- The function applies progressively weaker Gaussian blurs to an image and calculates
- interpolations between each blurred image, along with derivatives for each step.
-
- Args:
- sigma (float): Initial standard deviation for the Gaussian blur.
- fold (int): Number of interpolation steps for the blurring path.
- kernel_size (int, optional): Size of the Gaussian kernel. Defaults to 5.
-
- Returns:
- Callable: A function that takes an image and returns a tuple:
- - image_interpolation (torch.Tensor): Interpolated blurred images.
- - lambda_derivative_interpolation (torch.Tensor): Derivatives of interpolated images.
- """
-
- def path_interpolation_func(torch_image: torch.Tensor):
- """
- Applies the Gaussian blur path to the input image and computes the interpolated images
- and their derivatives using PyTorch.
-
- Args:
- torch_image (torch.Tensor): Input image as a torch tensor (channels, height, width).
-
- Returns:
- tuple: Interpolated blurred images and their derivatives along the Gaussian path.
- """
- device = torch_image.device
-
- # Ensure image is 4D (batch, channels, height, width)
- torch_image = torch_image.unsqueeze(0) if torch_image.ndim == 3 else torch_image
- torch_image = torch_image.to(device)
- channels = torch_image.shape[1]
-
- # Initialize tensors for blurred images and derivatives
- image_interpolation = torch.zeros(
- (fold, *torch_image.shape[1:]), dtype=torch.float32
- )
- image_interpolation = image_interpolation.to(device)
- lambda_derivative_interpolation = torch.zeros_like(image_interpolation)
- lambda_derivative_interpolation = lambda_derivative_interpolation.to(device)
- kernel_interpolation = torch.zeros(
- (fold + 1, channels, kernel_size, kernel_size), dtype=torch.float32
- )
- kernel_interpolation = kernel_interpolation.to(device)
-
- # Linearly interpolate sigma values from initial to zero
- sigma_interpolation = np.linspace(sigma, 0, fold + 1)
-
- # Create Gaussian kernels for each sigma value
- for i in range(fold + 1):
- kernel_interpolation[i] = isotropic_gaussian_kernel_torch(
- kernel_size, sigma_interpolation[i]
- ).squeeze()
-
- # Calculate padding size
- pad_size = kernel_interpolation.shape[-1] // 2
-
- # Create Gaussian kernels for each sigma and apply to image
- for i in range(fold):
- # Apply reflect padding first
- padded_image = F.pad(
- torch_image, (pad_size, pad_size, pad_size, pad_size), mode="reflect"
- )
-
- # Store the current blurred image
- image_interpolation[i] = F.conv2d(
- padded_image, kernel_interpolation[i + 1][:, None], groups=channels
- ).squeeze(0)
-
- # Calculate derivative with respect to lambda
- diff_kernel = (kernel_interpolation[i + 1] - kernel_interpolation[i])[
- :, None
- ] * fold
- lambda_derivative_interpolation[i] = F.conv2d(
- padded_image, diff_kernel, groups=channels
- ).squeeze(0)
-
- return (
- image_interpolation,
- lambda_derivative_interpolation,
- kernel_interpolation,
- )
-
- return path_interpolation_func
-
-
-def isotropic_gaussian_kernel_torch(
- size: int, sigma: float, epsilon: float = 1e-5
-) -> torch.Tensor:
- """
- Generates an isotropic Gaussian kernel in PyTorch.
-
- Args:
- size (int): Size of the kernel (size x size).
- sigma (float): Standard deviation of the Gaussian distribution.
- epsilon (float, optional): Small constant to avoid division by zero. Defaults to 1e-5.
-
- Returns:
- torch.Tensor: Normalized Gaussian kernel.
- """
- ax = torch.arange(-size // 2 + 1.0, size // 2 + 1.0)
- xx, yy = torch.meshgrid(ax, ax, indexing="ij")
-
- # Calculate Gaussian function and normalize
- kernel = torch.exp(-(xx**2 + yy**2) / (2.0 * (sigma + epsilon) ** 2))
- kernel = kernel / kernel.sum()
-
- # Reshape for 2D convolution (out_channels, in_channels, height, width)
- return kernel.unsqueeze(0).unsqueeze(0)
-
-
-def Path_gradient(
- image: torch.Tensor,
- model: torch.nn.Module,
- attr_objective: callable,
- path_interpolation_func: callable,
-):
- """
- Computes the path gradient for an image using a specified model and attribution objective.
- The function calculates gradients for a series of interpolated images produced by a path
- interpolation function.
-
- Args:
- numpy_image (np.ndarray): Input image of shape (channels, height, width).
- model (torch.nn.Module): The model to compute the objective on.
- attr_objective (callable): Function defining the attribution objective for the model output.
- path_interpolation_func (callable): Function that generates interpolated images and
- their lambda derivatives along a defined path.
-
- Returns:
- tuple:
- - grad_accumulate_list (np.ndarray): Accumulated gradients for each interpolated image.
- - results_numpy (np.ndarray): Model outputs for each interpolated image.
- - image_interpolation (np.ndarray): Interpolated images created by `path_interpolation_func`.
- """
-
- # Prepare image for interpolation and initialize gradient accumulation array
- image_interpolation, lambda_derivative_interpolation, _ = path_interpolation_func(
- image
- )
-
- grad_accumulate_list = torch.zeros_like(image_interpolation).cpu().numpy()
- result_list = []
-
- # Compute gradient for each interpolated image
- for i in range(image_interpolation.shape[0]):
-
- # Convert interpolated image to tensor and set requires_grad for backpropagation
- img_tensor = image_interpolation[i].float()[None]
- img_tensor.requires_grad_(True)
-
- # Forward pass through the model and compute attribution objective
- result = model(img_tensor)
- target = attr_objective(result)
- target.backward() # Compute gradients
-
- # Extract gradient, handling NaNs if present
- grad = img_tensor.grad.cpu().numpy()
- grad = np.nan_to_num(grad) # Replace NaNs with 0
-
- # Accumulate gradients adjusted by lambda derivatives
- grad_accumulate_list[i] = (
- grad * lambda_derivative_interpolation[i].cpu().numpy()
- )
- result_list.append(result.detach().cpu().numpy())
-
- # Collect results and return final outputs
- results_numpy = np.array(result_list)
- return grad_accumulate_list, results_numpy, image_interpolation
-
-
def attribution_objective(attr_func, h: int, w: int, window: int = 16):
"""
Creates an objective function to calculate attribution within a specified window
@@ -258,72 +86,118 @@ def attr_grad(
else:
raise ValueError(f"Invalid reduction type: {reduce}. Use 'sum' or 'mean'.")
+def down_up(X: torch.Tensor, scale_factor: float=0.5) -> torch.Tensor:
+ """Downsample and upsample an image using bilinear interpolation.
-def lam(
- X: torch.Tensor,
- model: torch.nn.Module,
- h: Optional[int] = 240,
- w: Optional[int] = 240,
- window: Optional[int] = 32,
- fold: Optional[int] = 25,
- kernel_size: Optional[int] = 13,
- sigma: Optional[float] = 3.5,
- robustness_metric: Optional[str] = True
-):
+ Args:
+ X (torch.Tensor): The input tensor (Bands x Height x Width).
+ scale_factor (float, optional): The scaling factor. Defaults to 0.5.
+
+ Returns:
+ torch.Tensor: The downsampled and upsampled image.
"""
- Computes the Local Attribution Map (LAM) for an input tensor using
- a specified model and attribution function. The function calculates
- the path gradient for each band in the input tensor and combines the
- results to generate the LAM.
+ shape_init = X.shape
+ return torch.nn.functional.interpolate(
+ input=torch.nn.functional.interpolate(
+ input=X,
+ scale_factor=1/scale_factor,
+ mode="bilinear",
+ antialias=True
+ ),
+ size=shape_init[2:],
+ mode="bilinear",
+ antialias=True
+ )
+
+def create_blur_cube(X: torch.Tensor, scales: list) -> torch.Tensor:
+ """Create a cube of blurred images at different scales.
Args:
- X (torch.Tensor): Input tensor of shape (channels, height, width).
- model (torch.nn.Module): The model to compute the objective on.
- attr_func (callable): Function that calculates attributions for an image.
- h (int): The top coordinate of the window within the image.
- w (int): The left coordinate of the window within the image.
- window (int, optional): The size of the square window. Defaults to 16.
- fold (int, optional): Number of interpolation steps for the blurring path.
- Defaults to 10.
- kernel_size (int, optional): Size of the Gaussian kernel. Defaults to 5.
- sigma (float, optional): Initial standard deviation for the Gaussian blur.
- Defaults to 3.5.
- robustness_metric (bool, optional): Whether to return the robustness metric.
- Defaults to True.
+ X (torch.Tensor): The input tensor (Bands x Height x Width).
+ scales (list): The scales to evaluate.
Returns:
- tuple: A tuple containing the following elements:
- - kde_map (np.ndarray): KDE estimation of the LAM.
- - complexity_metric (float): Gini index of the LAM that
- measures the consistency of the attribution. The
- larger the value, the more use more complex attribution
- patterns to solve the task.
- - robustness_metric (np.ndarray): Blurriness sensitivity of the LAM.
- The sensitivity measures the average gradient magnitude of the
- interpolated images.
- - robustness_vector (np.ndarray): Vector of gradient magnitudes for
- each interpolated image.
+ torch.Tensor: The cube of blurred images.
"""
+ scales_int = [float(scale[:-1]) for scale in scales]
+ return torch.stack([down_up(X[None], scale) for scale in scales_int]).squeeze()
- # Get the scale of the results
- with torch.no_grad():
- output = model(X[None])
- scale = output.shape[-1] // X.shape[-1]
- # Create the path interpolation function
- path_interpolation_func = GaussianBlurPath(
- sigma=sigma, fold=fold, kernel_size=kernel_size
- )
- # a, b, c = path_interpolation_func(X)
+def create_lam_inputs(
+ X: torch.Tensor,
+ scales: list
+) -> Tuple[torch.Tensor, torch.Tensor, list]:
+ """Create the inputs for the Local Attribution Map (LAM).
+
+ Args:
+ X (torch.Tensor): The input tensor (Bands x Height x Width).
+ scales (list): The scales to evaluate.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor, list]: The cube of blurred
+ images, the difference between the input and the cube,
+ and the scales.
+ """
+ cube = create_blur_cube(X, scales)
+ diff = torch.abs(X[None] - cube)
+ return cube[1:], diff[1:], scales[1:]
+
+
+def lam(
+ X: torch.Tensor,
+ model: torch.nn.Module,
+ model_scale: float = 4,
+ h: int = 240,
+ w: int = 240,
+ window: int = 32,
+ scales: list = ["1x", "2x", "3x", "4x", "5x", "6x", "7x", "8x"]
+) -> Tuple[np.ndarray, float, float, np.ndarray]:
+ """ Estimate the Local Attribution Map (LAM)
+
+ Args:
+ X (torch.Tensor): The input tensor (Bands x Height x Width).
+ model (torch.nn.Module): The model to evaluate.
+ model_scale (float, optional): The scale of the model. Defaults to 4.
+ h (int, optional): The height of the window to evaluate. Defaults to 240.
+ w (int, optional): The width of the window to evaluate. Defaults to 240.
+ window (int, optional): The window size. Defaults to 32.
+ scales (list, optional): The scales to evaluate. Defaults to
+ ["1x", "2x", "3x", "4x", "5x", "6x", "7x", "8x"].
+
+ Returns:
+ Tuple[np.ndarray, float, float, np.ndarray]: _description_
+ """
+
+ # Create the LAM inputs
+ cube, diff, scales = create_lam_inputs(X, scales)
# Create the attribution objective function
attr_objective = attribution_objective(attr_grad, h, w, window=window)
- # Compute the path gradient for the input tensor
- grad_accumulate_list,results_numpy, image_interpolation = Path_gradient(
- X, model, attr_objective, path_interpolation_func
- )
-
+ # Initialize the gradient accumulation list
+ grad_accumulate_list = torch.zeros_like(cube).cpu().numpy()
+
+ # Compute gradient for each interpolated image
+ for i in tqdm(range(cube.shape[0]), desc="Computing gradients"):
+
+ # Convert interpolated image to tensor and set requires_grad for backpropagation
+ img_tensor = cube[i].float()[None]
+ img_tensor.requires_grad_(True)
+
+ # Forward pass through the model and compute attribution objective
+ result = model(img_tensor)
+ target = attr_objective(result)
+ target.backward() # Compute gradients
+
+ # Extract gradient, handling NaNs if present
+ grad = img_tensor.grad.cpu().numpy()
+ grad = np.nan_to_num(grad) # Replace NaNs with 0
+
+ # Accumulate gradients adjusted by lambda derivatives
+ grad_accumulate_list[i] = (
+ grad * diff[i].cpu().numpy()
+ )
+
# Sum the accumulated gradients across all bands
lam_results = torch.sum(torch.from_numpy(np.abs(grad_accumulate_list)), dim=0)
grad_2d = np.abs(lam_results.sum(axis=0))
@@ -333,9 +207,12 @@ def lam(
# Estimate gini index
gini_index = gini(grad_norm.flatten())
+ ## window to image size
+ #ratio_img_to_window = (X.shape[1] * model_scale) // window
+
# KDE estimation
- kde_map = vis_saliency_kde(grad_norm, scale=scale, bandwidth=1.0)
- complexity_metric = (1 - gini_index) * 100
+ kde_map = vis_saliency_kde(grad_norm, scale=model_scale, bandwidth=1.0)
+ complexity_metric = (1 - gini_index) * 100 # / ratio_img_to_window
# Estimate blurriness sensitivity
robustness_vector = np.abs(grad_accumulate_list).mean(axis=(1, 2, 3))