Skip to content
This repository has been archived by the owner on Nov 29, 2023. It is now read-only.

Add Optical Flow Baseline #49

Merged
merged 6 commits into from
Jul 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,3 @@ The data used here is a combination of the UK Met Office's rainfall radar data,
satellite data (12 channels), derived data from the MSG satellites (cloud masks, etc.), and
numerical weather prediction data. Currently, some example transformed EUMETSAT data can be downloaded
from the tagged release, as well as included under ```datasets/```.

##
22 changes: 22 additions & 0 deletions satflow/baseline/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
## Baseline

To see if our ML models are actually improving the predictions of the cloud masks, we
have benchmarked them against OpenCV's dense optical flow predictions, as well as a naive
baseline of just predicting the current image for all future timesteps.

As we come up with and implement more metrics to compare these models, they will be added
here. Currently, the only metric tested is the mean squared error between the predicted frame
and the ground truth frame. To get a sense of whether there is a temporal dependence, the mean loss is
computed not just over all predictions, but also for each of the future timesteps, going up to 4 hours (48 timesteps)
into the future.

On average, the optical flow approach has an MSE of 0.1541. The naive baseline has an MSE of 0.1566,
so optical flow beats the naive baseline by about 1.6%.

## Caveats

We tried computing the optical flow between consecutive, or even widely temporally separated, cloud masks,
but the resulting flow usually ended up not actually changing anything. Instead, we used the
MSG HRV (High Resolution Visible) satellite channel to compute the optical flow. This was chosen as it is the highest
resolution satellite channel available, and it resulted in optical flow actually computing some movement.
This flow field was then applied to the cloud masks directly to obtain the flow results.
Empty file added satflow/baseline/__init__.py
Empty file.
89 changes: 89 additions & 0 deletions satflow/baseline/optical_flow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import cv2
from satflow.data.datasets import OpticalFlowDataset
import webdataset as wds
import yaml
import torch.nn.functional as F
import numpy as np


def load_config(config_file):
    """Parse the YAML file at ``config_file`` and return its top-level ``config`` entry.

    Raises ``KeyError`` if the parsed document has no ``config`` key.
    """
    # NOTE(review): FullLoader can construct more than plain data types; only point this
    # at config files you trust.
    with open(config_file, "r") as stream:
        parsed = yaml.load(stream, Loader=yaml.FullLoader)
    return parsed["config"]


# NOTE: both locations are absolute paths from the original author's machine; edit them
# to point at your own checkout and data before running this script.
_CONFIG_PATH = (
    "/home/jacob/Development/satflow/satflow/configs/datamodule/optical_flow_datamodule.yaml"
)
_SHARD_PATTERN = "/run/media/jacob/data/satflow-flow-144-tiled-{00001..00149}.tar"

config = load_config(_CONFIG_PATH)
dset = wds.WebDataset(_SHARD_PATTERN)

# The dataset takes a list of sources; here there is just the one WebDataset.
dataset = OpticalFlowDataset([dset], config=config)

import matplotlib.pyplot as plt
import torch


def warp_flow(img, flow):
    """Warp ``img`` by the negated ``flow`` field and return the remapped image.

    The sampling map handed to ``cv2.remap`` is, per pixel, its own (x, y) coordinate minus
    the flow vector at that pixel; sampling is bilinear.

    Args:
        img: Image to warp.
        flow: Array whose first two axes are (height, width) and whose last axis holds the
            x displacement at index 0 and the y displacement at index 1.

    Returns:
        The result of ``cv2.remap``.
    """
    height, width = flow.shape[:2]
    # Negating allocates a fresh array, so the caller's ``flow`` is never modified by the
    # in-place additions below.
    sample_map = -flow
    sample_map[..., 0] += np.arange(width)  # add each pixel's column index to x
    sample_map[..., 1] += np.arange(height)[:, np.newaxis]  # add each pixel's row index to y
    return cv2.remap(img, sample_map, None, cv2.INTER_LINEAR)


# Number of future frames scored per sample. Read from the same config the dataset was
# built with instead of repeating a literal; 48 is the value in the shipped config.
# NOTE(review): assumes the dataset yields exactly `forecast_times` target frames per sample.
num_steps = int(config.get("forecast_times", 48))

# Per-timestep sums of the MSE over all samples seen so far (divide by `count` for the mean).
total_losses = np.zeros(num_steps)  # optical-flow prediction
baseline_losses = np.zeros(num_steps)  # naive "current frame persists" prediction
count = 0
# Sums over samples of each sample's loss averaged across all timesteps.
overall_loss = 0.0
overall_baseline = 0.0


def _mse(prediction, target):
    """Return the mean squared error between two same-shaped arrays as a Python float."""
    # Cast both sides explicitly: the targets are not float32, and relying on implicit
    # dtype promotion inside mse_loss is fragile across torch versions.
    return F.mse_loss(
        torch.from_numpy(np.asarray(prediction, dtype=np.float32)),
        torch.from_numpy(np.asarray(target, dtype=np.float32)),
    ).item()


def _save_mean_losses():
    """Write the running per-timestep mean losses to the two ``.npy`` result files."""
    np.save("optical_flow_mse_loss.npy", total_losses / count)
    np.save("baseline_current_image_mse_loss.npy", baseline_losses / count)


for data in dataset:
    count += 1
    # The previous mask is yielded by the dataset but not needed here.
    _prev_frame, curr_frame, next_frames, image, prev_image = data
    # Move the leading axis to the end so the frame is laid out the way the warp expects
    # (presumably channel-first -> channel-last; confirm against the dataset).
    curr_frame = np.moveaxis(curr_frame, [0], [2]).astype(np.float32)
    # Dense flow is estimated once, from the two satellite images, and then applied
    # repeatedly to the cloud mask to roll the prediction forward.
    flow = cv2.calcOpticalFlowFarneback(prev_image, image, None, 0.5, 3, 15, 3, 5, 1.2, 0)

    sample_flow_losses = np.empty(num_steps)
    sample_base_losses = np.empty(num_steps)
    warped_frame = curr_frame
    for step in range(num_steps):
        # Each step warps the previous prediction once more; the warp result is expanded
        # with a trailing axis so it matches the target's layout again.
        warped_frame = np.expand_dims(warp_flow(warped_frame.astype(np.float32), flow), axis=-1)
        target = np.expand_dims(next_frames[step], axis=-1)
        sample_flow_losses[step] = _mse(warped_frame, target)
        sample_base_losses[step] = _mse(curr_frame, target)

    total_losses += sample_flow_losses
    baseline_losses += sample_base_losses
    overall_loss += float(sample_flow_losses.sum() / num_steps)
    overall_baseline += float(sample_base_losses.sum() / num_steps)
    print(
        f"Avg Total Loss: {np.mean(total_losses) / count} "
        f"Avg Baseline Loss: {np.mean(baseline_losses) / count} \n "
        f"Overall Loss: {overall_loss / count} Baseline: {overall_baseline / count}"
    )
    if count % 100 == 0:
        # Periodic checkpoint so a long run can be inspected or interrupted.
        _save_mean_losses()

if count:
    _save_mean_losses()
else:
    # Dividing by a zero count would only write NaN-filled result files.
    print("No samples were produced by the dataset; nothing saved.")
2 changes: 1 addition & 1 deletion satflow/configs/callbacks/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ model_checkpoint:
early_stopping:
_target_: pytorch_lightning.callbacks.EarlyStopping
monitor: "val/loss" # name of the logged metric which determines when model is improving
patience: 15 # how many epochs of not improving until training stops
patience: 30 # how many epochs of not improving until training stops
mode: "min" # can be "max" or "min"
min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement
35 changes: 35 additions & 0 deletions satflow/configs/datamodule/optical_flow_datamodule.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# @package _group_
_target_: satflow.data.datamodules.MaskFlowDataModule
batch_size: 1
data_dir: ${data_dir} # data_dir is specified in config.yaml
shuffle: 0
sources:
train: "satflow-flow-144-tiled-{00001..00105}.tar"
val: "satflow-flow-144-tiled-{00106..00129}.tar"
test: "satflow-flow-144-tiled-{00130..00149}.tar"
num_workers: 1
pin_memory: False
config:
visualize: False
num_timesteps: 1
skip_timesteps: 1
forecast_times: 48
output_shape: 400
target_type: "cloudmask"
num_crops: 1
use_topo: False
use_latlon: False
use_time: False
time_aux: False
use_mask: True
use_image: False
add_pixel_coords: False
time_as_channels: False
# Only HRV is enabled. NIR1.6, VIS0.8 and VIS0.6 (commented out below) give RGB for near normal view
bands: [
"HRV",
#"IR016",
#"VIS006",
#"VIS008",
]
transforms: {}
85 changes: 85 additions & 0 deletions satflow/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,91 @@ def __iter__(self) -> Iterator[T_co]:
yield mask, target_mask


class OpticalFlowDataset(SatFlowDataset):
    """Iterable dataset feeding the optical-flow baseline.

    For every usable timestep it yields the cloud mask and satellite image at that step and
    at the step before it, together with the stack of cloud masks for the following
    ``forecast_times`` steps.
    """

    @staticmethod
    def _as_int8_mask(mask):
        """Replace NaN/inf with 0, round to the nearest integer and cast to ``int8``.

        Non-finite values must be removed before the integer cast; afterwards there is
        nothing left for ``nan_to_num`` to act on.
        """
        return np.round(np.nan_to_num(mask, posinf=0.0, neginf=0.0)).astype(np.int8)

    def __iter__(self) -> Iterator[T_co]:
        """Yield ``(prev_mask, mask, target_mask, image, prev_image)`` tuples.

        ``target_mask`` stacks the masks for steps ``idx + 1 .. idx + forecast_times`` along
        axis 0 in chronological order, so ``target_mask[k]`` is ``k + 1`` steps ahead of
        ``mask``. ``mask`` and ``prev_mask`` carry an extra leading axis of length 1.

        Samples with missing timesteps are skipped, as are crops whose final target mask is
        constant. Iteration ends as soon as any underlying source is exhausted.
        """
        sources = [iter(ds) for ds in self.datasets]
        band = self.bands[0].lower()
        while True:
            for source in sources:
                try:
                    sample = next(source)
                except StopIteration:
                    # A StopIteration escaping a generator body is turned into a
                    # RuntimeError (PEP 479), so end the iteration explicitly.
                    return
                # NOTE(review): pickle.loads runs arbitrary code from the shard contents;
                # only read shards from a trusted source.
                timesteps = pickle.loads(sample["time.pyd"])
                available_steps = len(timesteps)  # number of available timesteps
                # Every timestep 1..available_steps-1 must be present for the first band,
                # and there must be more frames than one input window plus the horizon.
                sample_keys = {key for key in sample.keys() if band in key}
                expected_keys = (f"{band}.{step:03d}.npy" for step in range(1, available_steps))
                min_keys = self.num_timesteps * self.skip_timesteps + self.forecast_times
                if len(sample_keys) <= min_keys or not all(
                    key in sample_keys for key in expected_keys
                ):
                    continue  # Skip this sample as it is missing timesteps, or has none
                # Start at 2 so that idx - 1 is still >= 1, the first index covered by the
                # key check; stop early enough that idx + forecast_times stays in range.
                for idx in range(2, available_steps - self.forecast_times):
                    for _ in range(self.num_crops):  # Do random crops as well for training
                        logger.debug(f"IDX: {idx}")
                        image, mask = self.get_timestep(
                            sample,
                            idx,
                            return_target=True,
                            return_image=True,
                        )  # Current timestep
                        # Augment the mask once, then replay the recorded transform on every
                        # other array so they all receive the same augmentation.
                        data = self.aug(image=mask)
                        replay = data["replay"]
                        mask = data["image"]
                        image = self.aug.replay(replay, image=image)["image"]
                        prev_image, prev_mask = self.get_timestep(
                            sample,
                            idx - 1,
                            return_target=True,
                            return_image=True,
                        )  # Timestep just before the current one
                        prev_mask = self.aug.replay(replay, image=prev_mask)["image"]
                        prev_image = self.aug.replay(replay, image=prev_image)["image"]
                        # Fetch the furthest target first: it decides whether to keep the crop.
                        _, final_mask = self.get_timestep(
                            sample,
                            idx + self.forecast_times,
                            return_target=True,
                            return_image=False,
                        )
                        final_mask = self.aug.replay(replay, image=final_mask)["image"]
                        final_mask = np.expand_dims(final_mask, axis=0)

                        if np.isclose(np.min(final_mask), np.max(final_mask)):
                            continue  # Ignore if target timestep has no clouds, or only clouds
                        # Collect the intermediate steps in ascending order and put the final
                        # step last, so the stack is chronological.
                        future_masks = []
                        for step in range(idx + 1, idx + self.forecast_times):
                            _, t_mask = self.get_timestep(
                                sample,
                                step,
                                return_target=True,
                                return_image=False,
                            )
                            t_mask = self.aug.replay(replay, image=t_mask)["image"]
                            future_masks.append(np.expand_dims(t_mask, axis=0))
                        future_masks.append(final_mask)
                        target_mask = self._as_int8_mask(np.concatenate(future_masks))
                        # Masks become int8 with a leading axis of length 1.
                        mask = np.expand_dims(self._as_int8_mask(mask), axis=0)
                        prev_mask = np.expand_dims(self._as_int8_mask(prev_mask), axis=0)
                        logger.debug(f"Mask: {mask.shape} Target: {target_mask.shape}")
                        yield prev_mask, mask, target_mask, image, prev_image


def crop_center(img, cropx, cropy):
"""Crops center of image through timestack, fails if all the images are concatenated as channels"""
t, c, y, x = img.shape
Expand Down