Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

copy files over ocf_datapipes #57

Merged
merged 9 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions ocf_data_sampler/select/geospatial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""Geospatial functions"""

from numbers import Number
from typing import Union

import numpy as np
import pyproj
import xarray as xr

# EPSG code for OSGB, formally "OSGB 1936 / British National Grid --
# United Kingdom Ordnance Survey". It is a Transverse Mercator
# projection whose 'easting' and 'northing' coordinates are in meters.
# Many UK electricity system maps and the UK Met Office UKV model use
# it. See https://epsg.io/27700
OSGB36 = 27700

# EPSG code for WGS84 ("World Geodetic System 1984"), the latitude /
# longitude system used in GPS.
WGS84 = 4326


# Transform callables are built once at import time and reused by the
# public wrappers. `always_xy=True` fixes the axis order to (x, y), i.e.
# (easting, northing) for OSGB and (longitude, latitude) for WGS84.
_osgb_to_lon_lat = pyproj.Transformer.from_crs(OSGB36, WGS84, always_xy=True).transform
_lon_lat_to_osgb = pyproj.Transformer.from_crs(WGS84, OSGB36, always_xy=True).transform

# Geodesic helper on the WGS84 ellipsoid.
_geod = pyproj.Geod(ellps="WGS84")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think `_geod` is used anywhere. If that's right, can we strip it out?



def osgb_to_lon_lat(
    x: Union[Number, np.ndarray], y: Union[Number, np.ndarray]
) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]:
    """Convert OSGB coordinates into longitude and latitude.

    Args:
        x: osgb east-west
        y: osgb north-south

    Returns:
        2-tuple of longitude (east-west), latitude (north-south)
    """
    # Delegates to the module-level OSGB -> WGS84 transform.
    longitude, latitude = _osgb_to_lon_lat(xx=x, yy=y)
    return longitude, latitude


def lon_lat_to_osgb(
    x: Union[Number, np.ndarray],
    y: Union[Number, np.ndarray],
) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]:
    """Convert longitude and latitude into OSGB coordinates.

    Args:
        x: longitude east-west
        y: latitude north-south

    Returns:
        2-tuple of OSGB x, y
    """
    # Delegates to the module-level WGS84 -> OSGB transform.
    x_osgb, y_osgb = _lon_lat_to_osgb(xx=x, yy=y)
    return x_osgb, y_osgb


def osgb_to_geostationary_area_coords(
    x: Union[Number, np.ndarray],
    y: Union[Number, np.ndarray],
    xr_data: Union[xr.Dataset, xr.DataArray],
) -> tuple[Union[Number, np.ndarray], Union[Number, np.ndarray]]:
    """Transform OSGB coordinates into the geostationary CRS of `xr_data`.

    The geostationary area definition is read from the "area" attribute of
    the xarray object and parsed with pyresample; its CRS is the target of
    the transformation.

    Args:
        x: osgb east-west
        y: osgb north-south
        xr_data: xarray object with geostationary area

    Returns:
        Geostationary coords: x, y
    """
    # Imported lazily so pyresample is only needed for geostationary data.
    import pyresample

    try:
        area_yaml = xr_data.attrs["area"]
    except KeyError:
        # No "area" on the object itself: look on its `data` member instead.
        area_yaml = xr_data.data.attrs["area"]

    area_definition = pyresample.area_config.load_area_from_string(area_yaml)

    # A new transformer is built on every call because the target CRS
    # comes from the data that was passed in.
    transformer = pyproj.Transformer.from_crs(
        crs_from=OSGB36, crs_to=area_definition.crs, always_xy=True
    )
    return transformer.transform(xx=x, yy=y)
AUdaltsova marked this conversation as resolved.
Show resolved Hide resolved


def _coord_priority(available_coords):
if "longitude" in available_coords:
return "lon_lat", "longitude", "latitude"
elif "x_geostationary" in available_coords:
return "geostationary", "x_geostationary", "y_geostationary"
elif "x_osgb" in available_coords:
return "osgb", "x_osgb", "y_osgb"
else:
raise ValueError(f"Unrecognized coordinate system: {available_coords}")


def spatial_coord_type(ds: xr.DataArray):
    """Determine which kind of spatial coordinates a data array carries.

    Only the indexed coordinates of the xarray object (`ds.xindexes`) are
    searched, so dimension coordinates take preference.

    Args:
        ds: DataArray with spatial coords

    Returns:
        str: The kind of the coordinate system
        x_coord: Name of the x-coordinate
        y_coord: Name of the y-coordinate

    Raises:
        ValueError: If `ds` is not an `xr.DataArray`, or if no known
            spatial coordinate is found among its indexes
    """
    if not isinstance(ds, xr.DataArray):
        raise ValueError(f"Unrecognized input type: {type(ds)}")

    # Search dimension coords of dataarray
    return _coord_priority(ds.xindexes)
62 changes: 62 additions & 0 deletions ocf_data_sampler/select/location.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""location"""

from typing import Optional

import numpy as np
from pydantic import BaseModel, Field, model_validator


# Names of the coordinate systems that a Location is allowed to use.
allowed_coordinate_systems = ["osgb", "lon_lat", "geostationary", "idx"]

class Location(BaseModel):
    """Represent a spatial location.

    After the fields are populated, three validators run in definition
    order: the coordinate system must be one of
    `allowed_coordinate_systems`, and `x` and `y` must each lie within the
    bounds defined for that coordinate system.
    """

    coordinate_system: Optional[str] = "osgb"  # ["osgb", "lon_lat", "geostationary", "idx"]
    x: float
    y: float
    id: Optional[int] = Field(None)

    @model_validator(mode="after")
    def validate_coordinate_system(self):
        """Validate 'coordinate_system'"""
        if self.coordinate_system not in allowed_coordinate_systems:
            raise ValueError(
                f"coordinate_system = {self.coordinate_system} is not in {allowed_coordinate_systems}"
            )
        return self

    @model_validator(mode="after")
    def validate_x(self):
        """Validate 'x' against the bounds of the coordinate system.

        Raises:
            ValueError: If no bounds are defined for the coordinate system,
                or if 'x' lies outside them
        """
        co = self.coordinate_system
        # Inclusive (min, max) limits of 'x' for each coordinate system.
        x_bounds = {
            "osgb": (-103976.3, 652897.98),
            "lon_lat": (-180, 180),
            "geostationary": (-5568748.275756836, 5567248.074173927),
            "idx": (0, np.inf),
        }
        if co not in x_bounds:
            # Fail with a clear validation error instead of an unbound
            # local, without relying on an earlier validator to reject it.
            raise ValueError(f"no x bounds are defined for {co} coordinate system")

        min_x, max_x = x_bounds[co]
        if self.x < min_x or self.x > max_x:
            raise ValueError(
                f"x = {self.x} must be within {[min_x, max_x]} for {co} coordinate system"
            )
        return self

    @model_validator(mode="after")
    def validate_y(self):
        """Validate 'y' against the bounds of the coordinate system.

        Raises:
            ValueError: If no bounds are defined for the coordinate system,
                or if 'y' lies outside them
        """
        co = self.coordinate_system
        # Inclusive (min, max) limits of 'y' for each coordinate system.
        y_bounds = {
            "osgb": (-16703.87, 1199851.44),
            "lon_lat": (-90, 90),
            "geostationary": (1393687.2151494026, 5570748.323202133),
            "idx": (0, np.inf),
        }
        if co not in y_bounds:
            # Fail with a clear validation error instead of an unbound
            # local, without relying on an earlier validator to reject it.
            raise ValueError(f"no y bounds are defined for {co} coordinate system")

        min_y, max_y = y_bounds[co]
        if self.y < min_y or self.y > max_y:
            raise ValueError(
                f"y = {self.y} must be within {[min_y, max_y]} for {co} coordinate system"
            )
        return self
23 changes: 7 additions & 16 deletions ocf_data_sampler/select/select_spatial_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,14 @@
import numpy as np
import xarray as xr

from ocf_datapipes.utils import Location
from ocf_datapipes.utils.geospatial import (
lon_lat_to_geostationary_area_coords,
from ocf_data_sampler.select.location import Location
from ocf_data_sampler.select.geospatial import (
lon_lat_to_osgb,
osgb_to_geostationary_area_coords,
osgb_to_lon_lat,
spatial_coord_type,
)
from ocf_datapipes.utils.utils import searchsorted


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -45,9 +44,6 @@ def convert_coords_to_match_xarray(
if from_coords == "osgb":
x, y = osgb_to_geostationary_area_coords(x, y, da)

elif from_coords == "lon_lat":
x, y = lon_lat_to_geostationary_area_coords(x, y, da)

AUdaltsova marked this conversation as resolved.
Show resolved Hide resolved
elif target_coords == "lon_lat":
if from_coords == "osgb":
x, y = osgb_to_lon_lat(x, y)
Expand Down Expand Up @@ -97,8 +93,8 @@ def _get_idx_of_pixel_closest_to_poi(
x_index = da.get_index(x_dim)
y_index = da.get_index(y_dim)

closest_x = x_index.get_indexer([x], method="nearest")[0]
closest_y = y_index.get_indexer([y], method="nearest")[0]
closest_x = float(x_index.get_indexer([x], method="nearest")[0])
closest_y = float(y_index.get_indexer([y], method="nearest")[0])
peterdudfield marked this conversation as resolved.
Show resolved Hide resolved

return Location(x=closest_x, y=closest_y, coordinate_system="idx")

Expand Down Expand Up @@ -130,13 +126,8 @@ def _get_idx_of_pixel_closest_to_poi_geostationary(
f"{y} is not in the interval {da[y_dim].min().values}: {da[y_dim].max().values}"

# Get the index into x and y nearest to x_center_geostationary and y_center_geostationary:
x_index_at_center = searchsorted(
da[x_dim].values, center_geostationary.x, assume_ascending=True
)

y_index_at_center = searchsorted(
da[y_dim].values, center_geostationary.y, assume_ascending=True
)
x_index_at_center = np.searchsorted(da[x_dim].values, center_geostationary.x)
y_index_at_center = np.searchsorted(da[y_dim].values, center_geostationary.y)
peterdudfield marked this conversation as resolved.
Show resolved Hide resolved

return Location(x=x_index_at_center, y=y_index_at_center, coordinate_system="idx")

Expand Down
67 changes: 67 additions & 0 deletions tests/select/test_location.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from ocf_data_sampler.select.location import Location
import pytest


def test_make_valid_location_object_with_default_coordinate_system():
    """A Location built without a coordinate system keeps x and y and uses "osgb"."""
    x, y = -1000.5, 50000
    location = Location(x=x, y=y)
    assert location.x == x, "location.x value not set correctly"
    # The message names 'y' so a failure points at the right field.
    assert location.y == y, "location.y value not set correctly"
    assert (
        location.coordinate_system == "osgb"
    ), "location.coordinate_system value not set correctly"


def test_make_valid_location_object_with_osgb_coordinate_system():
    """A Location with in-range values and an explicit "osgb" system is accepted."""
    x, y, coordinate_system = 1.2, 22.9, "osgb"
    location = Location(x=x, y=y, coordinate_system=coordinate_system)
    assert location.x == x, "location.x value not set correctly"
    # The message names 'y' so a failure points at the right field.
    assert location.y == y, "location.y value not set correctly"
    assert (
        location.coordinate_system == coordinate_system
    ), "location.coordinate_system value not set correctly"


def test_make_valid_location_object_with_lon_lat_coordinate_system():
    """A Location with in-range values and the "lon_lat" system is accepted."""
    x, y, coordinate_system = 1.2, 1.2, "lon_lat"
    location = Location(x=x, y=y, coordinate_system=coordinate_system)
    assert location.x == x, "location.x value not set correctly"
    # The message names 'y' so a failure points at the right field.
    assert location.y == y, "location.y value not set correctly"
    assert (
        location.coordinate_system == coordinate_system
    ), "location.coordinate_system value not set correctly"


def test_make_invalid_location_object_with_invalid_osgb_x():
    """Building an "osgb" Location with x = 10000000 fails with a ValidationError."""
    with pytest.raises(ValueError) as excinfo:
        Location(x=10000000, y=1.2, coordinate_system="osgb")
    assert excinfo.typename == "ValidationError"


def test_make_invalid_location_object_with_invalid_osgb_y():
    """Building an "osgb" Location with y = 10000000 fails with a ValidationError."""
    with pytest.raises(ValueError) as excinfo:
        Location(x=2.5, y=10000000, coordinate_system="osgb")
    assert excinfo.typename == "ValidationError"


def test_make_invalid_location_object_with_invalid_lon_lat_x():
    """Building a "lon_lat" Location with x = 200 fails with a ValidationError."""
    with pytest.raises(ValueError) as excinfo:
        Location(x=200, y=1.2, coordinate_system="lon_lat")
    assert excinfo.typename == "ValidationError"


def test_make_invalid_location_object_with_invalid_lon_lat_y():
    """Building a "lon_lat" Location with y = -200 fails with a ValidationError."""
    with pytest.raises(ValueError) as excinfo:
        Location(x=2.5, y=-200, coordinate_system="lon_lat")
    assert excinfo.typename == "ValidationError"


def test_make_invalid_location_object_with_invalid_coordinate_system():
    """Building a Location with the coordinate system "abcd" fails with a ValidationError."""
    with pytest.raises(ValueError) as excinfo:
        Location(x=2.5, y=1000, coordinate_system="abcd")
    assert excinfo.typename == "ValidationError"
Loading