Skip to content

Commit

Permalink
Revert "Merge pull request #77 from HakaiInstitute/move-transforms"
Browse files Browse the repository at this point in the history
This reverts commit 85efc06, reversing
changes made to 0d90a83.
  • Loading branch information
tayden committed Dec 7, 2023
1 parent 85efc06 commit 2187d09
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 15 deletions.
26 changes: 18 additions & 8 deletions kelp_o_matic/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from rich import print
from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn
from torch.utils.data import DataLoader
from torchvision import transforms

from kelp_o_matic.geotiff_io import GeotiffReader, GeotiffWriter
from kelp_o_matic.models import _Model
Expand All @@ -17,13 +18,13 @@ class GeotiffSegmentationManager:
"""Class for configuring data io and efficient segmentation of Geotiff imagery."""

def __init__(
self,
model: "_Model",
input_path: Union[str, "Path"],
output_path: Union[str, "Path"],
crop_size: int = 512,
padding: int = 256,
batch_size: int = 1,
self,
model: "_Model",
input_path: Union[str, "Path"],
output_path: Union[str, "Path"],
crop_size: int = 512,
padding: int = 256,
batch_size: int = 1,
):
"""Create the segmentation object.
Expand All @@ -40,9 +41,18 @@ def __init__(
"""
self.model = model

tran = transforms.Compose(
[
transforms.ToTensor(),
transforms.Lambda(lambda img: img[:3, :, :]),
transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
),
]
)
self.reader = GeotiffReader(
Path(input_path).expanduser().resolve(),
transform=self.model.transform,
transform=tran,
crop_size=crop_size,
padding=padding,
filter_=self._should_keep,
Expand Down
7 changes: 0 additions & 7 deletions kelp_o_matic/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from abc import ABC, abstractmethod

import torch
import torchvision.transforms.functional as f

from kelp_o_matic.data import (
lraspp_kelp_presence_torchscript_path,
Expand All @@ -13,12 +12,6 @@


class _Model(ABC):
@staticmethod
def transform(x: torch.Tensor) -> torch.Tensor:
x = f.to_tensor(x)[:3, :, :] / 255.0
x = f.normalize(x, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
return x

def __init__(self, use_gpu: bool = True):
is_cuda = torch.cuda.is_available() and use_gpu
self.device = torch.device("cuda") if is_cuda else torch.device("cpu")
Expand Down

0 comments on commit 2187d09

Please sign in to comment.