diff --git a/satflow/baseline/optical_flow.py b/satflow/baseline/optical_flow.py
index 65da2fb8..76254d8c 100644
--- a/satflow/baseline/optical_flow.py
+++ b/satflow/baseline/optical_flow.py
@@ -1,5 +1,5 @@
 import cv2
-from satflow.data.datasets import OpticalFlowDataset
+from satflow.data.datasets import OpticalFlowDataset, SatFlowDataset
 import webdataset as wds
 import yaml
 import torch.nn.functional as F
@@ -16,7 +16,7 @@ def load_config(config_file):
 )
 
 dset = wds.WebDataset("/run/media/jacob/data/satflow-flow-144-tiled-{00001..00149}.tar")
-dataset = OpticalFlowDataset([dset], config=config)
+dataset = SatFlowDataset([dset], config=config)
 import matplotlib.pyplot as plt
 import torch
 
@@ -33,57 +33,59 @@ def warp_flow(img, flow):
 
 debug = False
 total_losses = np.array([0.0 for _ in range(48)])  # Want to break down loss by future timestep
+channel_total_losses = np.array([total_losses for _ in range(12)])
 count = 0
 baseline_losses = np.array([0.0 for _ in range(48)])  # Want to break down loss by future timestep
-overall_loss = 0.0
-overall_baseline = 0.0
+channel_baseline_losses = np.array([baseline_losses for _ in range(12)])
 for data in dataset:
     tmp_loss = 0
     tmp_base = 0
     count += 1
-    prev_frame, curr_frame, next_frames, image, prev_image = data
-    prev_frame = np.moveaxis(prev_frame, [0], [2])
-    curr_frame = np.moveaxis(curr_frame, [0], [2])
-    flow = cv2.calcOpticalFlowFarneback(prev_image, image, None, 0.5, 3, 15, 3, 5, 1.2, 0)
-    warped_frame = warp_flow(curr_frame.astype(np.float32), flow)
-    warped_frame = np.expand_dims(warped_frame, axis=-1)
-    loss = F.mse_loss(
-        torch.from_numpy(warped_frame), torch.from_numpy(np.expand_dims(next_frames[0], axis=-1))
-    )
-    total_losses[0] += loss.item()
-    tmp_loss += loss.item()
-    loss = F.mse_loss(
-        torch.from_numpy(curr_frame.astype(np.float32)),
-        torch.from_numpy(np.expand_dims(next_frames[0], axis=-1)),
-    )
-    baseline_losses[0] += loss.item()
-    tmp_base += loss.item()
-
-    for i in range(1, 48):
-        warped_frame = warp_flow(warped_frame.astype(np.float32), flow)
+    past_frames, next_frames = data
+    prev_frame = past_frames[1]
+    curr_frame = past_frames[0]
+    # Do it for each of the 12 channels
+    for ch in range(12):
+        # prev_frame = np.moveaxis(prev_frame, [0], [2])
+        # curr_frame = np.moveaxis(curr_frame, [0], [2])
+        flow = cv2.calcOpticalFlowFarneback(
+            past_frames[1][ch], past_frames[0][ch], None, 0.5, 3, 15, 3, 5, 1.2, 0
+        )
+        warped_frame = warp_flow(curr_frame[ch].astype(np.float32), flow)
         warped_frame = np.expand_dims(warped_frame, axis=-1)
         loss = F.mse_loss(
             torch.from_numpy(warped_frame),
-            torch.from_numpy(np.expand_dims(next_frames[i], axis=-1)),
+            torch.from_numpy(np.expand_dims(next_frames[0][ch], axis=-1)),
         )
-        total_losses[i] += loss.item()
-        tmp_loss += loss.item()
+        channel_total_losses[ch][0] += loss.item()
         loss = F.mse_loss(
-            torch.from_numpy(curr_frame.astype(np.float32)),
-            torch.from_numpy(np.expand_dims(next_frames[i], axis=-1)),
+            torch.from_numpy(curr_frame[ch].astype(np.float32)),
+            torch.from_numpy(next_frames[0][ch]),
         )
-        baseline_losses[i] += loss.item()
-        tmp_base += loss.item()
-    tmp_base /= 48
-    tmp_loss /= 48
-    overall_loss += tmp_loss
-    overall_baseline += tmp_base
+        channel_baseline_losses[ch][0] += loss.item()
+
+        for i in range(1, 48):
+            warped_frame = warp_flow(warped_frame.astype(np.float32), flow)
+            warped_frame = np.expand_dims(warped_frame, axis=-1)
+            loss = F.mse_loss(
+                torch.from_numpy(warped_frame),
+                torch.from_numpy(np.expand_dims(next_frames[i][ch], axis=-1)),
+            )
+            channel_total_losses[ch][i] += loss.item()
+            tmp_loss += loss.item()
+            loss = F.mse_loss(
+                torch.from_numpy(curr_frame[ch].astype(np.float32)),
+                torch.from_numpy(next_frames[i][ch]),
+            )
+            channel_baseline_losses[ch][i] += loss.item()
 
     print(
-        f"Avg Total Loss: {np.mean(total_losses) / count} Avg Baseline Loss: {np.mean(baseline_losses) / count} \n Overall Loss: {overall_loss / count} Baseline: {overall_baseline / count}"
+        f"Avg Total Loss: {np.mean(channel_total_losses) / count} Avg Baseline Loss: {np.mean(channel_baseline_losses) / count}"
     )
     if count % 100 == 0:
-        np.save("optical_flow_mse_loss.npy", total_losses / count)
-        np.save("baseline_current_image_mse_loss.npy", baseline_losses / count)
-np.save("optical_flow_mse_loss.npy", total_losses / count)
-np.save("baseline_current_image_mse_loss.npy", baseline_losses / count)
+        np.save("optical_flow_mse_loss_channels_reverse.npy", channel_total_losses / count)
+        np.save(
+            "baseline_current_image_mse_loss_channels_reverse.npy", channel_baseline_losses / count
+        )
+np.save("optical_flow_mse_loss_reverse.npy", channel_total_losses / count)
+np.save("baseline_current_image_mse_loss_reverse.npy", channel_baseline_losses / count)
diff --git a/satflow/configs/datamodule/optical_flow_datamodule.yaml b/satflow/configs/datamodule/optical_flow_datamodule.yaml
index 0a834047..98001b9e 100644
--- a/satflow/configs/datamodule/optical_flow_datamodule.yaml
+++ b/satflow/configs/datamodule/optical_flow_datamodule.yaml
@@ -21,15 +21,24 @@ config:
   use_latlon: False
   use_time: False
   time_aux: False
-  use_mask: True
-  use_image: False
+  use_mask: False
+  use_image: True
   add_pixel_coords: False
   time_as_channels: False
   # NIR1.6, VIS0.8 and VIS0.6 RGB for near normal view
-  bands: [
+  bands:
+    [
       "HRV",
-      #"IR016",
-      #"VIS006",
-      #"VIS008",
+      "IR016",
+      "IR039",
+      "IR087",
+      "IR097",
+      "IR108",
+      "IR120",
+      "IR134",
+      "VIS006",
+      "VIS008",
+      "WV062",
+      "WV073",
     ]
   transforms: {}
diff --git a/satflow/data/datasets.py b/satflow/data/datasets.py
index bb994916..a4c9fc79 100644
--- a/satflow/data/datasets.py
+++ b/satflow/data/datasets.py
@@ -493,7 +493,7 @@ def __iter__(self):
                 if not self.use_image:
                     yield image, target_mask
                 else:
-                    yield image, target_image, target_mask
+                    yield image, target_image
 
     def get_topo_latlon(self, sample: dict) -> None:
         if self.use_topo:
diff --git a/satflow/tests/test_dataset.py b/satflow/tests/test_dataset.py
index 36b37526..46f3766c 100644
--- a/satflow/tests/test_dataset.py
+++ b/satflow/tests/test_dataset.py
@@ -52,13 +52,11 @@ def test_satflow_all():
     config = load_config("satflow/tests/configs/satflow_all.yaml")
     cloudflow = SatFlowDataset([dataset], config)
     data = next(iter(cloudflow))
-    x, image, y = data
+    x, image = data
     channels = check_channels(config)
     assert x.shape == (13, channels, 128, 128)
-    assert y.shape == (24, 1, 128, 128)
     assert image.shape == (24, 12, 128, 128)
     assert not np.allclose(image[0], image[-1])
-    assert not np.allclose(y[0], y[-1])
 
 
 def test_satflow_large():
@@ -66,13 +64,11 @@ def test_satflow_large():
     config = load_config("satflow/tests/configs/satflow_large.yaml")
     cloudflow = SatFlowDataset([dataset], config)
     data = next(iter(cloudflow))
-    x, image, y = data
+    x, image = data
     channels = check_channels(config)
     assert x.shape == (13, channels, 256, 256)
-    assert y.shape == (24, 1, 256, 256)
     assert image.shape == (24, 12, 256, 256)
     assert not np.allclose(image[0], image[-1])
-    assert not np.allclose(y[0], y[-1])
 
 
 def test_satflow_crop():
@@ -80,10 +76,9 @@ def test_satflow_crop():
     config = load_config("satflow/tests/configs/satflow_crop.yaml")
load_config("satflow/tests/configs/satflow_crop.yaml") cloudflow = SatFlowDataset([dataset], config) data = next(iter(cloudflow)) - x, image, y = data + x, image = data channels = check_channels(config) assert x.shape == (13, channels, 256, 256) - assert y.shape == (24, 1, 64, 64) assert image.shape == (24, 12, 64, 64) assert not np.allclose(image[0], image[-1]) assert not np.allclose(x[0], x[-1]) @@ -117,13 +112,11 @@ def test_satflow_time_channels_all(): config = load_config("satflow/tests/configs/satflow_time_channels_all.yaml") cloudflow = SatFlowDataset([dataset], config) data = next(iter(cloudflow)) - x, image, y = data + x, image = data channels = check_channels(config) assert x.shape == (channels, 128, 128) - assert y.shape == (24, 128, 128) assert image.shape == (12 * 24, 128, 128) assert not np.allclose(image[0], image[-1]) - assert not np.allclose(y[0], y[-1]) def test_cloudflow(): @@ -141,16 +134,14 @@ def test_satflow_all_deterministic_validation(): config = load_config("satflow/tests/configs/satflow_all.yaml") cloudflow = SatFlowDataset([dataset], config, train=False) data = next(iter(cloudflow)) - x, image, y = data + x, image = data dataset2 = wds.WebDataset("datasets/satflow-test.tar") cloudflow2 = SatFlowDataset([dataset2], config, train=False) data = next(iter(cloudflow2)) - x2, image2, y2 = data + x2, image2 = data np.testing.assert_almost_equal(x, x2) np.testing.assert_almost_equal(image, image2) - np.testing.assert_almost_equal(y, y2) assert not np.allclose(image[0], image[-1]) - assert not np.allclose(y[0], y[-1]) def test_satflow_all_deterministic_validation_restart(): @@ -158,11 +149,9 @@ def test_satflow_all_deterministic_validation_restart(): config = load_config("satflow/tests/configs/satflow_all.yaml") cloudflow = SatFlowDataset([dataset], config, train=False) data = next(iter(cloudflow)) - x, image, y = data + x, image = data data = next(iter(cloudflow)) - x2, image2, y2 = data + x2, image2 = data np.testing.assert_almost_equal(x, x2) np.testing.assert_almost_equal(image, image2) - np.testing.assert_almost_equal(y, y2) assert not np.allclose(image[0], image[-1]) - assert not np.allclose(y[0], y[-1])