0.1.1 #74

Merged
merged 21 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/model-tests.yml
@@ -85,8 +85,8 @@ jobs:
         working-directory: ${{ github.workspace }}/tests
         run: |
           git config --global --add safe.directory /__w/FRDC-ML/FRDC-ML
-          python3 -m model_tests.chestnut_dec_may.train_mixmatch
+          python3 -m model_tests.chestnut_dec_may.train_fixmatch
           python3 -m model_tests.chestnut_dec_may.train_mixmatch

       - name: Comment results via CML
         run: |
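
If these model-test modules need to be debugged locally, the same entry points can be invoked from Python via runpy (a sketch; the module name is taken from the workflow above, and it assumes the tests directory is on sys.path):

    import runpy

    # Equivalent to `python3 -m model_tests.chestnut_dec_may.train_fixmatch`.
    runpy.run_module(
        "model_tests.chestnut_dec_may.train_fixmatch", run_name="__main__"
    )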
84 changes: 76 additions & 8 deletions src/frdc/load/dataset.py
@@ -9,7 +9,10 @@
 import numpy as np
 import pandas as pd
 from PIL import Image
+from sklearn.preprocessing import StandardScaler
+from torch import rot90
 from torch.utils.data import Dataset, ConcatDataset
+from torchvision.transforms.v2.functional import hflip

 from frdc.conf import (
     BAND_CONFIG,
@@ -71,8 +74,9 @@ def __init__(
         site: str,
         date: str,
         version: str | None,
-        transform: Callable[[list[np.ndarray]], Any] = None,
-        target_transform: Callable[[list[str]], list[str]] = None,
+        transform: Callable[[np.ndarray], Any] = lambda x: x,
+        transform_scale: bool | StandardScaler = True,
+        target_transform: Callable[[str], str] = lambda x: x,
         use_legacy_bounds: bool = False,
         polycrop: bool = False,
         polycrop_value: Any = np.nan,
@@ -95,13 +99,17 @@
             date: The date of the dataset, e.g. "20201218".
             version: The version of the dataset, e.g. "183deg".
             transform: The transform to apply to each segment.
+            transform_scale: Whether to scale the data. If True, it will fit
+                a StandardScaler to the data. If a StandardScaler is passed,
+                it will use that instead. If False, it will not scale the data.
             target_transform: The transform to apply to each label.
             use_legacy_bounds: Whether to use the legacy bounds.csv file.
                 This will automatically be set to True if LABEL_STUDIO_CLIENT
                 is None, which happens when Label Studio cannot be connected
                 to.
             polycrop: Whether to further crop the segments via their polygon
                 bounds. The cropped area will be padded with polycrop_value.
+            polycrop_value: The value to pad the cropped area with.
         """
         self.site = site
         self.date = date
@@ -125,17 +133,40 @@
         self.transform = transform
         self.target_transform = target_transform

+        if transform_scale is True:
+            self.x_scaler = StandardScaler()
+            self.x_scaler.fit(
+                np.concatenate(
+                    [
+                        # Segments: [H x W x C] -> [H*W, C]
+                        # Reshaping is necessary for StandardScaler
+                        segm.reshape(-1, segm.shape[-1])
+                        for segm in self.ar_segments
+                    ]
+                )
+            )
+            self.transform = lambda x: transform(
+                self.x_scaler.transform(x.reshape(-1, x.shape[-1])).reshape(
+                    x.shape
+                )
+            )
+        elif isinstance(transform_scale, StandardScaler):
+            self.x_scaler = transform_scale
+            self.transform = lambda x: transform(
+                self.x_scaler.transform(x.reshape(-1, x.shape[-1])).reshape(
+                    x.shape
+                )
+            )
+        else:
+            self.x_scaler = None
+
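
A minimal sketch of the scaling round trip implemented above, assuming a segment is a float array of shape [H, W, C] (the array values here are illustrative, not from the dataset):

    import numpy as np
    from sklearn.preprocessing import StandardScaler

    seg = np.random.rand(4, 4, 3)               # one [H, W, C] segment
    scaler = StandardScaler()
    scaler.fit(seg.reshape(-1, seg.shape[-1]))  # StandardScaler expects [N, C]
    out = scaler.transform(seg.reshape(-1, seg.shape[-1])).reshape(seg.shape)
    assert out.shape == seg.shape               # per-band z-score, shape kept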
     def __len__(self):
         return len(self.ar_segments)

     def __getitem__(self, idx):
         return (
-            self.transform(self.ar_segments[idx])
-            if self.transform
-            else self.ar_segments[idx],
-            self.target_transform(self.targets[idx])
-            if self.target_transform
-            else self.targets[idx],
+            self.transform(self.ar_segments[idx]),
+            self.target_transform(self.targets[idx]),
         )

     @property
@@ -305,3 +336,40 @@ def __getitem__(self, item):
             if self.transform
             else self.ar_segments[item]
         )


+class FRDCConstRotatedDataset(FRDCDataset):
+    def __len__(self):
+        """Assume that the dataset is 8x larger than it actually is.
+
+        There are 8 possible orientations for each image.
+        1. As-is
+        2, 3, 4. Rotated 90, 180, 270 degrees
+        5. Horizontally flipped
+        6, 7, 8. Horizontally flipped and rotated 90, 180, 270 degrees
+        """
+        return super().__len__() * 8
+
+    def __getitem__(self, idx):
+        """Map the expanded index to a base segment and an orientation."""
+        x, y = super().__getitem__(int(idx // 8))
+        assert x.ndim == 3, "x must be a 3D tensor"
+        x_ = None
+        if idx % 8 == 0:
+            x_ = x
+        elif idx % 8 == 1:
+            x_ = rot90(x, 1, (1, 2))
+        elif idx % 8 == 2:
+            x_ = rot90(x, 2, (1, 2))
+        elif idx % 8 == 3:
+            x_ = rot90(x, 3, (1, 2))
+        elif idx % 8 == 4:
+            x_ = hflip(x)
+        elif idx % 8 == 5:
+            x_ = hflip(rot90(x, 1, (1, 2)))
+        elif idx % 8 == 6:
+            x_ = hflip(rot90(x, 2, (1, 2)))
+        elif idx % 8 == 7:
+            x_ = hflip(rot90(x, 3, (1, 2)))
+
+        return x_, y
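
A table-driven equivalent of the branching above, as a sanity check (a sketch, not the merged implementation; it assumes x is a [C, H, W] tensor, which rotating over dims (1, 2) implies):

    import torch
    from torch import rot90
    from torchvision.transforms.v2.functional import hflip

    # The same 8 dihedral orientations, indexed the way __getitem__ does.
    ops = [
        lambda t: t,
        lambda t: rot90(t, 1, (1, 2)),
        lambda t: rot90(t, 2, (1, 2)),
        lambda t: rot90(t, 3, (1, 2)),
        lambda t: hflip(t),
        lambda t: hflip(rot90(t, 1, (1, 2))),
        lambda t: hflip(rot90(t, 2, (1, 2))),
        lambda t: hflip(rot90(t, 3, (1, 2))),
    ]
    x = torch.rand(3, 8, 8)
    views = [op(x) for op in ops]
    # All 8 views of a random square image should be pairwise distinct.
    assert all(
        not torch.equal(a, b)
        for i, a in enumerate(views)
        for b in views[i + 1:]
    )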
80 changes: 44 additions & 36 deletions src/frdc/load/label_studio.py
@@ -16,41 +16,58 @@ def get_bounds_and_labels(self) -> tuple[list[tuple[int, int]], list[str]]:
         bounds = []
         labels = []

-        # for ann_ix, ann in enumerate(self["annotations"]):
-
-        ann = self["annotations"][0]
-        results = ann["result"]
-        for r_ix, r in enumerate(results):
-            r: dict
+        # Each annotation is an entire image labelled by a single person.
+        # By selecting the 0th annotation, we are usually selecting the main
+        # annotation.
+        annotation = self["annotations"][0]
+
+        # There is some metadata in `annotation`, but we just want the results
+        results = annotation["result"]
+
+        for bbox_ix, bbox in enumerate(results):
+            # 'id' = {str} 'jr4EXAKAV8'
+            # 'type' = {str} 'polygonlabels'
+            # 'value' = {dict: 3} {
+            #     'closed': True,
+            #     'points': [[x0, y0], [x1, y1], ... [xn, yn]],
+            #     'polygonlabels': ['label']
+            # }
+            # 'origin' = {str} 'manual'
+            # 'to_name' = {str} 'image'
+            # 'from_name' = {str} 'label'
+            # 'image_rotation' = {int} 0
+            # 'original_width' = {int} 450
+            # 'original_height' = {int} 600
+            bbox: dict

             # See Issue FRML-78: Somehow some labels are actually just metadata
-            if r["from_name"] != "label":
+            if bbox["from_name"] != "label":
                 continue

             # We flatten the value dict into the result dict
-            v = r.pop("value")
-            r = {**r, **v}
+            v = bbox.pop("value")
+            bbox = {**bbox, **v}

             # Points are in percentages; we need to convert them to pixels
-            r["points"] = [
+            bbox["points"] = [
                 (
-                    int(x * r["original_width"] / 100),
-                    int(y * r["original_height"] / 100),
+                    int(x * bbox["original_width"] / 100),
+                    int(y * bbox["original_height"] / 100),
                 )
-                for x, y in r["points"]
+                for x, y in bbox["points"]
             ]

             # Only take the first label as this is not a multi-label task
-            r["label"] = r.pop("polygonlabels")[0]
-            if not r["closed"]:
+            bbox["label"] = bbox.pop("polygonlabels")[0]
+            if not bbox["closed"]:
                 logger.warning(
-                    f"Label for {r['label']} @ {r['points']} not closed. "
+                    f"Label for {bbox['label']} @ {bbox['points']} not closed. "
                     f"Skipping"
                 )
                 continue

-            bounds.append(r["points"])
-            labels.append(r["label"])
+            bounds.append(bbox["points"])
+            labels.append(bbox["label"])

         return bounds, labels
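
For instance, using the annotated values above (original_width 450, original_height 600), a polygon point at (50, 50) percent maps to pixel (225, 300):

    # Hypothetical values; mirrors the int(x * width / 100) conversion above.
    w, h = 450, 600
    x_pct, y_pct = 50.0, 50.0
    point_px = (int(x_pct * w / 100), int(y_pct * h / 100))
    assert point_px == (225, 300)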

@@ -60,24 +77,15 @@ def get_task(
     project_id: int = 1,
 ):
     proj = LABEL_STUDIO_CLIENT.get_project(project_id)
-    # Get the task that has the file name
-    filter = Filters.create(
-        Filters.AND,
-        [
-            Filters.item(
-                # The GS path is in the image column, so we can just filter on that
-                Column.data("image"),
-                Operator.CONTAINS,
-                Type.String,
-                Path(file_name).as_posix(),
-            )
-        ],
-    )
-    tasks = proj.get_tasks(filter)
-
-    if len(tasks) > 1:
+    task_ids = [
+        task["id"]
+        for task in proj.get_tasks()
+        if file_name.as_posix() in task["storage_filename"]
+    ]
+
+    if len(task_ids) > 1:
         warn(f"More than 1 task found for {file_name}, using the first one")
-    elif len(tasks) == 0:
+    elif len(task_ids) == 0:
         raise ValueError(f"No task found for {file_name}")

-    return Task(tasks[0])
+    return Task(proj.get_task(task_ids[0]))
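
A hedged usage sketch of the new lookup (the path and project id are illustrative; it assumes LABEL_STUDIO_CLIENT is configured and that each task exposes its GCS path under "storage_filename", as the filter above relies on):

    from pathlib import Path

    task = get_task(
        Path("chestnut_nature_park/20201218/result.jpg"), project_id=1
    )
    bounds, labels = task.get_bounds_and_labels()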