Skip to content

Commit

Permalink
up
Browse files Browse the repository at this point in the history
  • Loading branch information
csaybar committed Dec 2, 2024
1 parent f5023e6 commit 4b3361b
Show file tree
Hide file tree
Showing 6 changed files with 257 additions and 292 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -165,4 +165,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
demo.py
demo.py
demo2.py
136 changes: 85 additions & 51 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
</p>

<p align="center">
<em>A Python package for enhancing the spatial resolution of Sentinel-2 satellite images to 2.5 meters</em> 🚀
<em>A Python package for enhancing the spatial resolution of Sentinel-2 satellite images up to 2.5 meters</em> 🚀
</p>


Expand Down Expand Up @@ -38,35 +38,28 @@

## **Overview** 📊

**supers2** is a Python package designed to enhance the spatial resolution of Sentinel-2 satellite images to 2.5 meters using advanced neural network models. It facilitates downloading (cubo package), preparing, and processing the Sentinel-2 data and applies deep learning models to enhance the spatial resolution of the imagery.
**supers2** is a Python package designed to enhance the spatial resolution of Sentinel-2 satellite images to 2.5 meters using a set of neural network models.

## **Installation** ⚙️

Install the latest version from PyPI:

```bash
pip install cubo supers2
pip install supers2
```

## **How to use** 🛠️

### **Basic usage: enhancing spatial resolution of Sentinel-2 images** 🌍

#### **Load libraries**

```python
import cubo
import matplotlib.pyplot as plt
import numpy as np
import torch

import supers2
import torch
import cubo

```

#### **Download Sentinel-2 L2A cube**

```python
# Create a Sentinel-2 L2A data cube for a specific location and date range
## Download Sentinel-2 L2A cube
da = cubo.create(
lat=4.31,
lon=-76.2,
Expand All @@ -77,19 +70,7 @@ da = cubo.create(
edge_size=128,
resolution=10
)
```

#### **Prepare the data (CPU and GPU usage)**

When converting the NumPy array to a PyTorch tensor, the use of `cuda()` is optional and depends on whether the user has access to a GPU. Below is the explanation for both cases:

- **GPU:** If a GPU is available and CUDA is installed, you can transfer the tensor to the GPU using `.cuda()`. This improves the processing speed, especially for large datasets or deep learning models.

- **CPU:** If no GPU is available, the tensor will be processed on the CPU, which is the default behavior in PyTorch. In this case, simply omit the `.cuda()` call.

Here’s how you can handle both scenarios dynamically:

```python
# Convert the data array to NumPy and scale
original_s2_numpy = (da[11].compute().to_numpy() / 10_000).astype("float32")

Expand All @@ -98,48 +79,101 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create the tensor and move it to the appropriate device (CPU or GPU)
X = torch.from_numpy(original_s2_numpy).float().to(device)

# Set up the model to enhance the spatial resolution
models = supers2.setmodel(device=device)

# Apply spatial resolution enhancement
superX = supers2.predict(X, models=models, resolution="2.5m")

# Visualize the results
# Plot the original and enhanced-resolution images
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(X[[2, 1, 0]].permute(1, 2, 0).cpu().numpy()*4)
ax[0].set_title("Original S2")
ax[1].imshow(superX[[2, 1, 0]].permute(1, 2, 0).cpu().numpy()*4)
ax[1].set_title("Enhanced Resolution S2")
plt.show()
```

#### **Define the resolution enhancement model**
<p align="center">
<img src="./assets/images/example1.png" width="100%">
</p>


## Chante the model settings 🛠️

At the end of the document, you can find a table with the available models and their characteristics.

```python
# Set up the model to enhance the spatial resolution
models = supers2.setmodel(
SR_model_loss="l1",
SR_model_name="cnn",
SR_model_size="small",
Fusionx2_model_size="lightweight",
Fusionx4_model_size="lightweight"
resolution = "2.5m", # Set the desired resolution
sr_model_snippet = "sr__opensrbaseline__cnn__medium__l1", # RGBN model from 10m to 2.5m
fusionx2_model_snippet = "fusionx2__opensrbaseline__cnn__large__l1", # RedESWIR model from 20m to 10m
fusionx4_model_snippet = "fusionx4__opensrbaseline__cnn__large__l1", #RedESWIR model from 10m to 2.5m
weights_path = None, # Path to the weights file
device = "cpu" # Use the CPU
)
```
### **Apply spatial resolution enhancement**

```python
# Apply the model to enhance the image resolution to 2.5 meters
# Apply spatial resolution enhancement
superX = supers2.predict(X, models=models, resolution="2.5m")
```

### **Visualize the results** 🎨
### **Predict only RGBNIR bands** 🌍

```python
superX = supers2.predict_rgbnir(X[[2, 1, 0, 6]])
```

#### **Display images**
### **Estimate the uncertainty of the model** 📊

```python
import matplotlib.pyplot as plt
from supers2.trained_models import SRmodels

# Plot the original and enhanced-resolution images
# Get the available models
models = list(SRmodels.model_dump()["object"].keys())

# Get only swin transformer models
swin2sr_models = [model for model in models if "swin" in model]

map_mean, map_std = supers2.uncertainty(
X[[2, 1, 0, 6]],
models=swin2sr_models
)

# Visualize the uncertainty
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(X[[2, 1, 0]].permute(1, 2, 0).cpu().numpy()*4)
ax[0].set_title("Original S2")
ax[1].imshow(superX[[2, 1, 0]].permute(1, 2, 0).cpu().numpy()*4)
ax[1].set_title("Enhanced Resolution S2")
ax[0].imshow(mean_map[0:3].cpu().numpy().transpose(1, 2, 0)*3)
ax[0].set_title("Mean")
ax[1].imshow(std_map[0:3].cpu().numpy().transpose(1, 2, 0)*100)
ax[1].set_title("Standard Deviation")
plt.show()
```

<p align="center">
<img src="./assets/images/example1.png" width="100%">
</p>

## **Supported features and filters**
### Estimate the Local Attention Map of the model 📊


```python
kde_map, complexity_metric, robustness_metric, robustness_vector = supers2.lam(
X=X[[2, 1, 0, 6]].cpu(), # The input tensor
model=models.srx4, # The SR model
h=240, # The height of the window
w=240, # The width of the window
window=128, # The window size
scales = ["1x", "2x", "3x", "4x", "5x", "6x", "7x", "8x"]
)

# Visualize the results
plt.imshow(kde_map)
plt.title("Kernel Density Estimation")
plt.show()

plt.plot(robustness_vector)
plt.title("Robustness Vector")
plt.show()
```


- **Enhance spatial resolution to 2.5 meters:** Use advanced CNN models to enhance Sentinel-2 imagery.
- **Neural network-based approach:** Integration of multiple model sizes to fit different computing needs (small, lightweight).
- **Python integration:** Easily interact with data cubes through the Python API, supporting seamless workflows.
### Use the opensr-test and supers2 to analyze the hallucination pixels 📊
6 changes: 5 additions & 1 deletion supers2/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
from supers2.main import predict, setmodel, predict_large, predict_rgbnir
from supers2.main import predict, setmodel, predict_large, predict_rgbnir, uncertainty
from supers2.xai.lam import lam
from supers2.trained_models import SRmodels

models = list(SRmodels.model_dump()["object"].keys())

14 changes: 4 additions & 10 deletions supers2/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@

from supers2.dataclass import SRexperiment
from supers2.setup import load_model
from supers2.utils import define_iteration
from supers2.utils import define_iteration, gdal_create
from supers2.trained_models import SRmodels


def setmodel(
resolution: Literal["2.5m", "5m", "10m"] = "2.5m",
sr_model_snippet: str = "sr__opensrbaseline__cnn__lightweight__l1",
Expand Down Expand Up @@ -282,11 +281,7 @@ def predict_large(

# Create the output image
with rio.open(output_fullname, "w", **output_metadata) as dst:
data_np = np.zeros(
(metadata["count"], metadata["height"] * res_n, metadata["width"] * res_n),
dtype=np.uint16,
)
dst.write(data_np)
pass

# Check if the models are loaded
if models is None:
Expand All @@ -295,8 +290,7 @@ def predict_large(
# Iterate over the image
with rio.open(output_fullname, "r+") as dst:
with rio.open(image_fullname) as src:
for index, point in enumerate(tqdm.tqdm(nruns)):

for index, point in enumerate(tqdm.tqdm(nruns)):
# Read a block of the image
window = rio.windows.Window(point[1], point[0], 128, 128)
X = torch.from_numpy(src.read(window=window)).float().to(device)
Expand Down Expand Up @@ -431,7 +425,7 @@ def uncertainty(
)

# Run the model
X_torch = torch.from_numpy((X / 10_000)).float().to(device)
X_torch = X.float().to(device)
prediction = model_object(X_torch[None]).squeeze().cpu()

# Store the prediction
Expand Down
55 changes: 55 additions & 0 deletions supers2/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import itertools
import rasterio
import pathlib

from rasterio.crs import CRS


def define_iteration(dimension: tuple, chunk_size: int, overlap: int = 0):
Expand Down Expand Up @@ -57,3 +61,54 @@ def fix_lastchunk(iterchunks, s2dim, chunk_size):
itercontainer.append((index_i, index_j))

return itercontainer


def gdal_create(
outfilename: str,
dtype: str = 'uint16',
driver: str = 'GTiff',
count: int = 13,
width: int = 5120,
height: int = 5120,
nodata: int = 65535,
crs: int = 4326,
affine: tuple = (-180, 0.5, 90, -0.5),
**kwargs,
) -> pathlib.Path:
"""
Fast creation of a new raster file using rasterio.
Args:
outfilename (str): Output filename.
dtype (str): Data type of the raster.
driver (str): GDAL driver to use.
count (int): Number of bands in the raster.
width (int): Width of the raster.
height (int): Height of the raster.
nodata (int): NoData value.
crs (int): EPSG code of the raster.
affine (tuple): Affine transformation of the raster.
Returns:
pathlib.Path: Path to the created raster file.
"""
# Define the metadata for the new file
meta = {
'driver': driver,
'dtype': dtype,
'nodata': nodata,
'width': width,
'height': height,
'count': count,
'crs': CRS.from_epsg(crs),
'transform': rasterio.transform.from_origin(*affine),
}

# Merge the metadata with the additional kwargs
meta.update(kwargs)

with rasterio.open(outfilename, 'w', **meta) as dst:
pass

return pathlib.Path(outfilename)
Loading

0 comments on commit 4b3361b

Please sign in to comment.