This repository has been archived by the owner on Nov 29, 2023. It is now read-only.

Update optical flow to look at all sat channels #70

Merged
merged 3 commits on Jul 29, 2021
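The substance of the change: instead of computing a single optical-flow field from one channel, the baseline now computes a Farneback flow field per satellite channel and scores the warped prediction against that channel's future frames. Below is a minimal sketch of the per-channel step, assuming frames shaped (time, channels, H, W) as in the updated baseline; the helper name and the uint8 conversion are mine, the latter because OpenCV's Farneback routine expects 8-bit single-channel inputs.

import cv2
import numpy as np

def per_channel_flow(past_frames: np.ndarray, n_channels: int = 12) -> list:
    """Compute one dense Farneback flow field per satellite channel.

    Mirrors the updated baseline: past_frames[1] is the earlier frame,
    past_frames[0] the most recent one, each indexed as [channel, H, W].
    """
    flows = []
    for ch in range(n_channels):
        prev_ch = past_frames[1][ch].astype(np.uint8)  # Farneback wants 8-bit, single-channel input
        curr_ch = past_frames[0][ch].astype(np.uint8)
        flow = cv2.calcOpticalFlowFarneback(prev_ch, curr_ch, None, 0.5, 3, 15, 3, 5, 1.2, 0)
        flows.append(flow)  # (H, W, 2) displacement field for this channel
    return flows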
82 changes: 42 additions & 40 deletions satflow/baseline/optical_flow.py
@@ -1,5 +1,5 @@
import cv2
from satflow.data.datasets import OpticalFlowDataset
from satflow.data.datasets import OpticalFlowDataset, SatFlowDataset
import webdataset as wds
import yaml
import torch.nn.functional as F
@@ -16,7 +16,7 @@ def load_config(config_file):
)
dset = wds.WebDataset("/run/media/jacob/data/satflow-flow-144-tiled-{00001..00149}.tar")

dataset = OpticalFlowDataset([dset], config=config)
dataset = SatFlowDataset([dset], config=config)

import matplotlib.pyplot as plt
import torch
@@ -33,57 +33,59 @@ def warp_flow(img, flow):

debug = False
total_losses = np.array([0.0 for _ in range(48)]) # Want to break down loss by future timestep
channel_total_losses = np.array([total_losses for _ in range(12)])
count = 0
baseline_losses = np.array([0.0 for _ in range(48)]) # Want to break down loss by future timestep
overall_loss = 0.0
overall_baseline = 0.0
channel_baseline_losses = np.array([baseline_losses for _ in range(12)])

for data in dataset:
tmp_loss = 0
tmp_base = 0
count += 1
prev_frame, curr_frame, next_frames, image, prev_image = data
prev_frame = np.moveaxis(prev_frame, [0], [2])
curr_frame = np.moveaxis(curr_frame, [0], [2])
flow = cv2.calcOpticalFlowFarneback(prev_image, image, None, 0.5, 3, 15, 3, 5, 1.2, 0)
warped_frame = warp_flow(curr_frame.astype(np.float32), flow)
warped_frame = np.expand_dims(warped_frame, axis=-1)
loss = F.mse_loss(
torch.from_numpy(warped_frame), torch.from_numpy(np.expand_dims(next_frames[0], axis=-1))
)
total_losses[0] += loss.item()
tmp_loss += loss.item()
loss = F.mse_loss(
torch.from_numpy(curr_frame.astype(np.float32)),
torch.from_numpy(np.expand_dims(next_frames[0], axis=-1)),
)
baseline_losses[0] += loss.item()
tmp_base += loss.item()

for i in range(1, 48):
warped_frame = warp_flow(warped_frame.astype(np.float32), flow)
past_frames, next_frames = data
prev_frame = past_frames[1]
curr_frame = past_frames[0]
# Do it for each of the 12 channels
for ch in range(12):
# prev_frame = np.moveaxis(prev_frame, [0], [2])
# curr_frame = np.moveaxis(curr_frame, [0], [2])
flow = cv2.calcOpticalFlowFarneback(
past_frames[1][ch], past_frames[0][ch], None, 0.5, 3, 15, 3, 5, 1.2, 0
)
warped_frame = warp_flow(curr_frame[ch].astype(np.float32), flow)
warped_frame = np.expand_dims(warped_frame, axis=-1)
loss = F.mse_loss(
torch.from_numpy(warped_frame),
torch.from_numpy(np.expand_dims(next_frames[i], axis=-1)),
torch.from_numpy(np.expand_dims(next_frames[0][ch], axis=-1)),
)
total_losses[i] += loss.item()
tmp_loss += loss.item()
channel_total_losses[ch][0] += loss.item()
loss = F.mse_loss(
torch.from_numpy(curr_frame.astype(np.float32)),
torch.from_numpy(np.expand_dims(next_frames[i], axis=-1)),
torch.from_numpy(curr_frame[ch].astype(np.float32)),
torch.from_numpy(next_frames[0][ch]),
)
baseline_losses[i] += loss.item()
tmp_base += loss.item()
tmp_base /= 48
tmp_loss /= 48
overall_loss += tmp_loss
overall_baseline += tmp_base
channel_baseline_losses[ch][0] += loss.item()

for i in range(1, 48):
warped_frame = warp_flow(warped_frame.astype(np.float32), flow)
warped_frame = np.expand_dims(warped_frame, axis=-1)
loss = F.mse_loss(
torch.from_numpy(warped_frame),
torch.from_numpy(np.expand_dims(next_frames[i][ch], axis=-1)),
)
channel_total_losses[ch][i] += loss.item()
tmp_loss += loss.item()
loss = F.mse_loss(
torch.from_numpy(curr_frame[ch].astype(np.float32)),
torch.from_numpy(next_frames[i][ch]),
)
channel_baseline_losses[ch][i] += loss.item()
print(
f"Avg Total Loss: {np.mean(total_losses) / count} Avg Baseline Loss: {np.mean(baseline_losses) / count} \n Overall Loss: {overall_loss / count} Baseline: {overall_baseline / count}"
f"Avg Total Loss: {np.mean(channel_total_losses) / count} Avg Baseline Loss: {np.mean(channel_baseline_losses) / count}"
)
if count % 100 == 0:
np.save("optical_flow_mse_loss.npy", total_losses / count)
np.save("baseline_current_image_mse_loss.npy", baseline_losses / count)
np.save("optical_flow_mse_loss.npy", total_losses / count)
np.save("baseline_current_image_mse_loss.npy", baseline_losses / count)
np.save("optical_flow_mse_loss_channels_reverse.npy", channel_total_losses / count)
np.save(
"baseline_current_image_mse_loss_channels_reverse.npy", channel_baseline_losses / count
)
np.save("optical_flow_mse_loss_reverse.npy", channel_total_losses / count)
np.save("baseline_current_image_mse_loss_reverse.npy", channel_baseline_losses / count)
21 changes: 15 additions & 6 deletions satflow/configs/datamodule/optical_flow_datamodule.yaml
@@ -21,15 +21,24 @@ config:
use_latlon: False
use_time: False
time_aux: False
use_mask: True
use_image: False
use_mask: False
use_image: True
add_pixel_coords: False
time_as_channels: False
# NIR1.6, VIS0.8 and VIS0.6 RGB for near normal view
bands: [
bands:
[
"HRV",
#"IR016",
#"VIS006",
#"VIS008",
"IR016",
"IR039",
"IR087",
"IR097",
"IR108",
"IR120",
"IR134",
"VIS006",
"VIS008",
"WV062",
"WV073",
]
transforms: {}
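With the commented-out entries removed, the datamodule now lists all 12 bands, which lines up with the for ch in range(12) loop in the baseline and the (24, 12, H, W) image shapes asserted in the tests below. A quick sanity check is sketched here; the exact nesting under config: is an assumption about this repo's config layout.

import yaml

with open("satflow/configs/datamodule/optical_flow_datamodule.yaml") as f:
    datamodule_cfg = yaml.safe_load(f)

bands = datamodule_cfg["config"]["bands"]
assert len(bands) == 12, f"expected 12 bands, got {len(bands)}: {bands}"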
2 changes: 1 addition & 1 deletion satflow/data/datasets.py
@@ -493,7 +493,7 @@ def __iter__(self):
if not self.use_image:
yield image, target_mask
else:
yield image, target_image, target_mask
yield image, target_image

def get_topo_latlon(self, sample: dict) -> None:
if self.use_topo:
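With use_image enabled, SatFlowDataset.__iter__ now yields a two-tuple of input imagery and future imagery instead of also returning target_mask, so consumers unpack two values. A sketch of the new calling convention, with shapes taken from the test assertions that follow:

for past_frames, next_frames in dataset:
    # past_frames: (13, channels, 128, 128); next_frames: (24, 12, 128, 128) per the tests
    ...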
27 changes: 8 additions & 19 deletions satflow/tests/test_dataset.py
@@ -52,38 +52,33 @@ def test_satflow_all():
config = load_config("satflow/tests/configs/satflow_all.yaml")
cloudflow = SatFlowDataset([dataset], config)
data = next(iter(cloudflow))
x, image, y = data
x, image = data
channels = check_channels(config)
assert x.shape == (13, channels, 128, 128)
assert y.shape == (24, 1, 128, 128)
assert image.shape == (24, 12, 128, 128)
assert not np.allclose(image[0], image[-1])
assert not np.allclose(y[0], y[-1])


def test_satflow_large():
dataset = wds.WebDataset("datasets/satflow-test.tar")
config = load_config("satflow/tests/configs/satflow_large.yaml")
cloudflow = SatFlowDataset([dataset], config)
data = next(iter(cloudflow))
x, image, y = data
x, image = data
channels = check_channels(config)
assert x.shape == (13, channels, 256, 256)
assert y.shape == (24, 1, 256, 256)
assert image.shape == (24, 12, 256, 256)
assert not np.allclose(image[0], image[-1])
assert not np.allclose(y[0], y[-1])


def test_satflow_crop():
dataset = wds.WebDataset("datasets/satflow-test.tar")
config = load_config("satflow/tests/configs/satflow_crop.yaml")
cloudflow = SatFlowDataset([dataset], config)
data = next(iter(cloudflow))
x, image, y = data
x, image = data
channels = check_channels(config)
assert x.shape == (13, channels, 256, 256)
assert y.shape == (24, 1, 64, 64)
assert image.shape == (24, 12, 64, 64)
assert not np.allclose(image[0], image[-1])
assert not np.allclose(x[0], x[-1])
@@ -117,13 +112,11 @@ def test_satflow_time_channels_all():
config = load_config("satflow/tests/configs/satflow_time_channels_all.yaml")
cloudflow = SatFlowDataset([dataset], config)
data = next(iter(cloudflow))
x, image, y = data
x, image = data
channels = check_channels(config)
assert x.shape == (channels, 128, 128)
assert y.shape == (24, 128, 128)
assert image.shape == (12 * 24, 128, 128)
assert not np.allclose(image[0], image[-1])
assert not np.allclose(y[0], y[-1])


def test_cloudflow():
@@ -141,28 +134,24 @@ def test_satflow_all_deterministic_validation():
config = load_config("satflow/tests/configs/satflow_all.yaml")
cloudflow = SatFlowDataset([dataset], config, train=False)
data = next(iter(cloudflow))
x, image, y = data
x, image = data
dataset2 = wds.WebDataset("datasets/satflow-test.tar")
cloudflow2 = SatFlowDataset([dataset2], config, train=False)
data = next(iter(cloudflow2))
x2, image2, y2 = data
x2, image2 = data
np.testing.assert_almost_equal(x, x2)
np.testing.assert_almost_equal(image, image2)
np.testing.assert_almost_equal(y, y2)
assert not np.allclose(image[0], image[-1])
assert not np.allclose(y[0], y[-1])


def test_satflow_all_deterministic_validation_restart():
dataset = wds.WebDataset("datasets/satflow-test.tar")
config = load_config("satflow/tests/configs/satflow_all.yaml")
cloudflow = SatFlowDataset([dataset], config, train=False)
data = next(iter(cloudflow))
x, image, y = data
x, image = data
data = next(iter(cloudflow))
x2, image2, y2 = data
x2, image2 = data
np.testing.assert_almost_equal(x, x2)
np.testing.assert_almost_equal(image, image2)
np.testing.assert_almost_equal(y, y2)
assert not np.allclose(image[0], image[-1])
assert not np.allclose(y[0], y[-1])