From c1c94bfaa47a23158234db10d3ce70534fce3a82 Mon Sep 17 00:00:00 2001 From: nialov Date: Fri, 1 Nov 2024 11:02:37 +0200 Subject: [PATCH] fix: remove pygeos references Also resulted in clean up of code in fractopo/ and in tests/. --- fractopo/analysis/contour_grid.py | 44 ++--- fractopo/analysis/network.py | 8 +- fractopo/branches_and_nodes.py | 150 +++--------------- fractopo/cli.py | 11 +- fractopo/fractopo_utils.py | 16 +- fractopo/general.py | 95 ++--------- fractopo/tval/proximal_traces.py | 15 +- fractopo/tval/trace_validation.py | 18 +-- fractopo/tval/trace_validation_utils.py | 6 +- tests/__init__.py | 15 +- tests/analysis/test_contour_grid.py | 10 +- .../test_branches_and_nodes.py | 21 +-- 12 files changed, 81 insertions(+), 328 deletions(-) diff --git a/fractopo/analysis/contour_grid.py b/fractopo/analysis/contour_grid.py index f4ca9787..1c784299 100755 --- a/fractopo/analysis/contour_grid.py +++ b/fractopo/analysis/contour_grid.py @@ -1,13 +1,14 @@ """ Scripts for creating sample grids for fracture trace, branch and node data. """ + import logging import platform -from typing import Dict, Optional +from typing import Any, Dict, Optional import geopandas as gpd import numpy as np -from geopandas.sindex import PyGEOSSTRTreeIndex +from geopandas.sindex import SpatialIndex from joblib import Parallel, delayed from shapely.geometry import LineString, Point, Polygon @@ -23,7 +24,6 @@ Param, crop_to_target_areas, geom_bounds, - pygeos_spatial_index, safe_buffer, spatial_index_intersection, ) @@ -49,11 +49,11 @@ def create_grid(cell_width: float, lines: gpd.GeoDataFrame) -> gpd.GeoDataFrame: ... ) >>> create_grid(cell_width=0.1, lines=lines).head(5) geometry - 0 POLYGON ((-2.00000 5.00000, -1.90000 5.00000, ... - 1 POLYGON ((-2.00000 4.90000, -1.90000 4.90000, ... - 2 POLYGON ((-2.00000 4.80000, -1.90000 4.80000, ... - 3 POLYGON ((-2.00000 4.70000, -1.90000 4.70000, ... - 4 POLYGON ((-2.00000 4.60000, -1.90000 4.60000, ... + 0 POLYGON ((-2 5, -1.9 5, -1.9 4.9, -2 4.9, -2 5)) + 1 POLYGON ((-2 4.9, -1.9 4.9, -1.9 4.8, -2 4.8, ... + 2 POLYGON ((-2 4.8, -1.9 4.8, -1.9 4.7, -2 4.7, ... + 3 POLYGON ((-2 4.7, -1.9 4.7, -1.9 4.6, -2 4.6, ... + 4 POLYGON ((-2 4.6, -1.9 4.6, -1.9 4.5, -2 4.5, ... """ assert cell_width > 0 assert len(lines) > 0 @@ -106,7 +106,7 @@ def populate_sample_cell( branches: gpd.GeoDataFrame, snap_threshold: float, resolve_branches_and_nodes: bool, - traces_sindex: Optional[PyGEOSSTRTreeIndex] = None, + traces_sindex: Optional[Any] = None, ) -> Dict[str, float]: """ Take a single grid polygon and populate it with parameters. @@ -152,7 +152,7 @@ def choose_geometries(sindex, sample_circle, geometries): assert sample_circle_area > 0 if traces_sindex is None: - traces_sindex = pygeos_spatial_index(traces) + traces_sindex: SpatialIndex = traces.sindex # Choose geometries that are either within the sample_circle or # intersect it @@ -186,13 +186,13 @@ def choose_geometries(sindex, sample_circle, geometries): is_topology_defined = branches.shape[0] > 0 if is_topology_defined: branch_candidates = choose_geometries( - sindex=pygeos_spatial_index(branches), + sindex=branches.sindex, sample_circle=sample_circle, geometries=branches, ) node_candidates = choose_geometries( - sindex=pygeos_spatial_index(nodes), + sindex=nodes.sindex, sample_circle=sample_circle, geometries=nodes, ) @@ -266,25 +266,7 @@ def sample_grid( assert isinstance(nodes_reset, gpd.GeoDataFrame) assert isinstance(branches_reset, gpd.GeoDataFrame) traces, nodes, branches = traces_reset, nodes_reset, branches_reset - # [gdf.reset_index(inplace=True, drop=True) for gdf in (traces, nodes)] - # traces_sindex = pygeos_spatial_index(traces) - # nodes_sindex = pygeos_spatial_index(nodes) - - # params_for_cells = list( - # map( - # lambda sample_cell: populate_sample_cell( - # sample_cell=sample_cell, - # sample_cell_area=sample_cell_area, - # traces_sindex=traces_sindex, - # traces=traces, - # nodes=nodes, - # branches=branches, - # snap_threshold=snap_threshold, - # resolve_branches_and_nodes=resolve_branches_and_nodes, - # ), - # grid.geometry.values, - # ) - # ) + # Use all CPUs with n_jobs=-1 # Use only one process on Windows params_for_cells = Parallel( diff --git a/fractopo/analysis/network.py b/fractopo/analysis/network.py index 68f894b1..113733fc 100644 --- a/fractopo/analysis/network.py +++ b/fractopo/analysis/network.py @@ -1,6 +1,7 @@ """ Analyse and plot trace map data with Network. """ + import logging from dataclasses import dataclass, field from functools import wraps @@ -13,6 +14,7 @@ import numpy as np import pandas as pd import powerlaw +from geopandas.sindex import SpatialIndex from matplotlib.axes import Axes from matplotlib.figure import Figure from matplotlib.projections import PolarAxes @@ -59,7 +61,6 @@ determine_boundary_intersecting_lines, focus_plot_to_bounds, numpy_to_python_type, - pygeos_spatial_index, raise_determination_error, remove_z_coordinates_from_geodata, sanitize_name, @@ -96,7 +97,6 @@ def wrapper(*args, **kwargs): @dataclass class Network: - """ Trace network. @@ -343,7 +343,7 @@ def __hash__(self) -> int: """ def convert_gdf( - gdf: Union[gpd.GeoDataFrame, gpd.GeoSeries, None, Polygon, MultiPolygon] + gdf: Union[gpd.GeoDataFrame, gpd.GeoSeries, None, Polygon, MultiPolygon], ) -> Optional[str]: """ Convert GeoDataFrame or geometry to (json) str. @@ -1218,7 +1218,7 @@ def estimate_censoring( # Use spatial index to filter censoring polygons that are not near the # network - sindex = pygeos_spatial_index(self.censoring_area) + sindex: SpatialIndex = self.censoring_area.sindex index_intersection = spatial_index_intersection( spatial_index=sindex, coordinates=network_area_bounds ) diff --git a/fractopo/branches_and_nodes.py b/fractopo/branches_and_nodes.py index cc44fc40..815f1574 100644 --- a/fractopo/branches_and_nodes.py +++ b/fractopo/branches_and_nodes.py @@ -3,14 +3,15 @@ branches_and_nodes is the main entrypoint. """ + import logging import math from itertools import chain, compress -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import geopandas as gpd import numpy as np -from geopandas.sindex import PyGEOSSTRTreeIndex +from geopandas.sindex import SpatialIndex from shapely.geometry import ( LineString, MultiLineString, @@ -19,8 +20,6 @@ Point, Polygon, ) -from shapely.geometry.base import BaseMultipartGeometry -from shapely.ops import unary_union from shapely.wkt import dumps from fractopo.general import ( @@ -47,7 +46,6 @@ line_intersection_to_points, numpy_to_python_type, point_to_point_unit_vector, - pygeos_spatial_index, remove_z_coordinates_from_geodata, spatial_index_intersection, ) @@ -134,7 +132,7 @@ def get_branch_identities( """ assert len(nodes) == len(node_identities) - node_spatial_index = pygeos_spatial_index(nodes) + node_spatial_index = nodes.sindex branch_identities = [] for branch in branches.geometry.values: assert isinstance(branch, LineString) @@ -292,8 +290,10 @@ def insert_point_to_linestring( t_coords = list(trace.coords) if not insert: t_coords.pop(idx) - t_coords.insert(idx, (point.x, point.y)) if not point.has_z else t_coords.insert( - idx, (point.x, point.y, point.z) + ( + t_coords.insert(idx, (point.x, point.y)) + if not point.has_z + else t_coords.insert(idx, (point.x, point.y, point.z)) ) # Closest points might not actually be the points which in between the # point is added. Have to use project and interpolate (?) @@ -421,7 +421,7 @@ def snap_traces( assert all(isinstance(trace, LineString) for trace in traces) # Spatial index for traces - traces_spatial_index = pygeos_spatial_index(geodataset=gpd.GeoSeries(traces)) + traces_spatial_index = gpd.GeoSeries(traces).sindex # Collect simply snapped (and non-snapped) traces to list simply_snapped_traces, simple_changes = zip( @@ -465,12 +465,12 @@ def snap_traces( def resolve_trace_candidates( trace: LineString, idx: int, - traces_spatial_index: PyGEOSSTRTreeIndex, + traces_spatial_index, traces: List[LineString], snap_threshold: float, ) -> List[LineString]: """ - Resolve PyGEOSSTRTreeIndex intersection to actual intersection candidates. + Resolve spatial index intersection to actual intersection candidates. """ assert isinstance(trace, LineString) @@ -517,7 +517,7 @@ def snap_trace_simple( trace: LineString, snap_threshold: float, traces: List[LineString], - traces_spatial_index: PyGEOSSTRTreeIndex, + traces_spatial_index: Any, final_allowed_loop: bool = False, ) -> Tuple[LineString, bool]: """ @@ -553,7 +553,7 @@ def snap_others_to_trace( trace: LineString, snap_threshold: float, traces: List[LineString], - traces_spatial_index: PyGEOSSTRTreeIndex, + traces_spatial_index: Any, areas: Optional[List[Union[Polygon, MultiPolygon]]], final_allowed_loop: bool = False, ) -> Tuple[LineString, bool]: @@ -568,7 +568,7 @@ def snap_others_to_trace( >>> trace = LineString([(0, 0), (1, 0), (2, 0), (3, 0)]) >>> snap_threshold = 0.001 >>> traces = [trace, LineString([(1.5, 3), (1.5, 0.00001)])] - >>> traces_spatial_index = pygeos_spatial_index(gpd.GeoSeries(traces)) + >>> traces_spatial_index = gpd.GeoSeries(traces).sindex >>> areas = None >>> snapped = snap_others_to_trace( ... idx, trace, snap_threshold, traces, traces_spatial_index, areas @@ -582,7 +582,7 @@ def snap_others_to_trace( >>> trace = LineString([(0, 0), (1, 0), (2, 0), (3, 0)]) >>> snap_threshold = 0.001 >>> traces = [trace, LineString([(3.0001, -3), (3.0001, 0), (3, 3)])] - >>> traces_spatial_index = pygeos_spatial_index(gpd.GeoSeries(traces)) + >>> traces_spatial_index = gpd.GeoSeries(traces).sindex >>> areas = None >>> snapped = snap_others_to_trace( ... idx, trace, snap_threshold, traces, traces_spatial_index, areas @@ -784,120 +784,6 @@ def filter_non_unique_traces( return unique_traces -def safer_unary_union( - traces_geosrs: gpd.GeoSeries, snap_threshold: float, size_threshold: int -) -> MultiLineString: - """ - Perform unary union to transform traces to branch segments. - - unary_union is not completely stable with large datasets but problem can be - alleviated by dividing analysis to parts. - - TODO: Usage is deprecated as unary_union seems to give consistent results. - """ - if traces_geosrs.empty: - return MultiLineString() - - # Get amount of traces - trace_count = traces_geosrs.shape[0] - - # Only one trace and self-intersects shouldn't occur -> Simply return the - # one LineString wrapped in MultiLineString - if trace_count == 1: - return MultiLineString(list(traces_geosrs.geometry.values)) - - # Try normal union without any funny business - # This will be compared to the split approach result and better will - # be returned - normal_full_union: MultiLineString = traces_geosrs.unary_union - - if isinstance(normal_full_union, LineString): - return MultiLineString([normal_full_union]) - - if trace_count < size_threshold: - if len(normal_full_union.geoms) > trace_count and isinstance( - normal_full_union, MultiLineString - ): - return normal_full_union - - # Debugging, fail safely - if size_threshold < UNARY_ERROR_SIZE_THRESHOLD: - log.critical( - "Expected size_threshold to be higher than 100. Union might be impossible." - ) - - # How many parts - div = int(np.ceil(trace_count / size_threshold)) - - # Divide with numpy - split_traces = np.array_split(traces_geosrs, div) - assert isinstance(split_traces, list) - - # How many in each pair - # part_count = int(np.ceil(trace_count / div)) - - assert div * sum(part.shape[0] for part in split_traces) >= trace_count - assert all(isinstance(val, gpd.GeoSeries) for val in split_traces) - assert isinstance(split_traces[0].iloc[0], LineString) - - # Do unary_union in parts - part_unions = part_unary_union( - split_traces=split_traces, - snap_threshold=snap_threshold, - size_threshold=size_threshold, - div=div, - ) - # Do full union of split unions - full_union = unary_union(MultiLineString(list(chain(*part_unions)))) - - # full_union should always be better or equivalent to normal unary_union. - # (better when unary_union fails silently) - if isinstance(full_union, MultiLineString): - assert isinstance(full_union, BaseMultipartGeometry) - if len(full_union.geoms) >= len(normal_full_union.geoms): - return full_union - - raise ValueError( - "Expected split union to give better results." - " Branches and nodes should be checked for inconsistencies." - ) - if isinstance(full_union, LineString): - return MultiLineString([full_union]) - raise TypeError( - f"Expected (Multi)LineString from unary_union. Got {full_union.wkt}" - ) - - -def part_unary_union( - split_traces: list, snap_threshold: float, size_threshold: int, div: int -): - """ - Conduct safer_unary_union in parts. - """ - # Collect partly done unary_unions to part_unions list - part_unions = [] - - # Iterate over list of split trace GeoSeries - for part in split_traces: - # Do unary_union to part - part_union = part.unary_union - - # Do naive check for if unary_union is successful - if ( - not isinstance(part_union, MultiLineString) - or len(part_union.geoms) < part.shape[0] - ): - # Still fails -> Try with lower threshold for part - part_union = safer_unary_union( - part, snap_threshold, size_threshold=size_threshold // 2 - ) - - # Collect - part_unions.append(part_union.geoms) - assert len(part_unions) == div - return part_unions - - def report_snapping_loop(loops: int, allowed_loops: int): """ Report snapping looping. @@ -1015,7 +901,7 @@ def branches_and_nodes( ] # Branches are determined with shapely/geopandas unary_union - unary_union_result = traces_geosrs.unary_union + unary_union_result = traces_geosrs.union_all() if isinstance(unary_union_result, MultiLineString): branches_all = list(unary_union_result.geoms) elif isinstance(unary_union_result, LineString): @@ -1118,7 +1004,7 @@ def node_identity( idx: int, areas: Union[gpd.GeoSeries, gpd.GeoDataFrame], endpoints_geoseries: gpd.GeoSeries, - endpoints_spatial_index: PyGEOSSTRTreeIndex, + endpoints_spatial_index: Any, snap_threshold: float, ) -> str: """ @@ -1201,7 +1087,7 @@ def node_identities_from_branches( all_endpoints_geoseries = gpd.GeoSeries(all_endpoints) # Get spatial index - endpoints_spatial_index = pygeos_spatial_index(all_endpoints_geoseries) + endpoints_spatial_index: SpatialIndex = all_endpoints_geoseries.sindex # Collect resolved nodes collected_nodes: Dict[str, Tuple[Point, str]] = dict() diff --git a/fractopo/cli.py b/fractopo/cli.py index 554dc502..efd6c73e 100644 --- a/fractopo/cli.py +++ b/fractopo/cli.py @@ -12,10 +12,9 @@ from typing import Dict, Optional, Tuple, Type import click -import fiona import geopandas as gpd import pandas as pd -import pygeos +import pyogrio import typer from rich.console import Console from rich.table import Table @@ -240,12 +239,11 @@ def tracevalidate( # Set same crs as input if input had crs if input_crs is not None: - validated_trace.crs = input_crs + validated_trace = validated_trace.set_crs(input_crs) # Get input driver to use as save driver - with fiona.open(trace_file) as open_trace_file: - assert open_trace_file is not None - save_driver = open_trace_file.driver + + save_driver = pyogrio.detect_write_driver(trace_file) # Resolve output if not explicitly given if output is None: @@ -481,7 +479,6 @@ def info(): information = dict( fractopo_version=__version__, geopandas_version=gpd.__version__, - pygeos_version=pygeos.__version__, package_location=str(Path(__file__).parent.absolute()), python_location=str(Path(sys.executable).absolute()), ) diff --git a/fractopo/fractopo_utils.py b/fractopo/fractopo_utils.py index 7ea9675c..6c27714b 100644 --- a/fractopo/fractopo_utils.py +++ b/fractopo/fractopo_utils.py @@ -1,10 +1,12 @@ """ Miscellaneous utilities and scripts of fractopo. """ + from itertools import count from typing import List, Tuple, Union import geopandas as gpd +from geopandas.sindex import SpatialIndex from shapely.geometry import LineString, Point from fractopo.general import ( @@ -12,14 +14,12 @@ create_unit_vector, geom_bounds, get_trace_endpoints, - pygeos_spatial_index, safe_buffer, spatial_index_intersection, ) class LineMerge: - """ Merge lines conditionally. """ @@ -137,7 +137,7 @@ def conditional_linemerge_collection( (['LINESTRING (0 0, 0 2, 0 4)'], [0, 1]) """ - spatial_index = pygeos_spatial_index(traces) + spatial_index: SpatialIndex = traces.sindex new_traces = [] modified_idx = [] @@ -189,8 +189,8 @@ def run_loop( >>> tolerance = 5 >>> buffer_value = 0.01 >>> LineMerge.run_loop(traces, tolerance, buffer_value) - geometry - 0 LINESTRING (0.00000 0.00000, 0.00000 2.00000, ... + geometry + 0 LINESTRING (0 0, 0 2, 0 4) """ loop_count = count() @@ -221,8 +221,8 @@ def integrate_replacements( >>> new_traces = [LineString([(0, 0), (0, 2), (0, 4)])] >>> modified_idx = [0, 1] >>> LineMerge.integrate_replacements(traces, new_traces, modified_idx) - geometry - 0 LINESTRING (0.00000 0.00000, 0.00000 2.00000, ... + geometry + 0 LINESTRING (0 0, 0 2, 0 4) """ unmod_traces = [ @@ -246,7 +246,7 @@ def remove_identical_sindex( geosrs_reset = geosrs.reset_index(inplace=False, drop=True) assert isinstance(geosrs_reset, gpd.GeoSeries) geosrs = geosrs_reset - spatial_index = geosrs.sindex + spatial_index: SpatialIndex = geosrs.sindex identical_idxs = [] point: Point for idx, point in enumerate(geosrs.geometry.values): diff --git a/fractopo/general.py b/fractopo/general.py index 4a82b508..c8c2a8aa 100644 --- a/fractopo/general.py +++ b/fractopo/general.py @@ -21,9 +21,8 @@ import geopandas as gpd import numpy as np import pandas as pd -import pygeos import sklearn.metrics as sklm -from geopandas.sindex import PyGEOSSTRTreeIndex +from geopandas.sindex import SpatialIndex from joblib import Memory from matplotlib import patheffects as path_effects from matplotlib.axes import Axes @@ -646,8 +645,8 @@ def match_crs( if len(all_crs) == 1: # One valid crs in inputs crs = all_crs[0] - first.crs = crs - second.crs = crs + first = first.set_crs(crs) + second = second.set_crs(crs) return first, second # Two crs that are not the same return first, second @@ -739,7 +738,7 @@ def determine_general_nodes( endpoint_nodes: List[Tuple[Point, ...]] = [] # spatial_index = traces.geometry.sindex try: - spatial_index = pygeos_spatial_index(traces.geometry) + spatial_index = traces.sindex except TypeError: spatial_index = None for idx, geom in enumerate(traces.geometry.values): @@ -977,7 +976,7 @@ def determine_node_junctions( # Create spatial index of nodes # nodes_geoseries_sindex = flattened_nodes_geoseries.sindex - nodes_geoseries_sindex = pygeos_spatial_index(flattened_nodes_geoseries) + nodes_geoseries_sindex: SpatialIndex = flattened_nodes_geoseries.sindex # Set collection for indexes with junctions indexes_with_junctions: Set[int] = set() @@ -1167,45 +1166,6 @@ def mls_to_ls(multilinestrings: List[MultiLineString]) -> List[LineString]: return linestrings -def efficient_clip( - traces: Union[gpd.GeoSeries, gpd.GeoDataFrame], - areas: Union[gpd.GeoSeries, gpd.GeoDataFrame], -) -> gpd.GeoDataFrame: - """ - Perform efficient clip of LineString geometries with a Polygon. - - :param traces: Trace data. - :param areas: Area data. - :return: Traces clipped with the area data. - """ - # Transform to pygeos types - pygeos_traces = pygeos.from_shapely(traces.geometry.values) - - # Convert MultiPolygon in area_gdf to Polygons and collect to list. - polygons = [] - for geom in areas.geometry.values: - if isinstance(geom, MultiPolygon): - polygons.extend(geom.geoms) - elif isinstance(geom, Polygon): - polygons.append(geom) - else: - raise TypeError( - f"Expected (Multi)Polygons in efficient_clip." - f" Got: {geom.wkt, type(geom)}." - ) - pygeos_polygons = pygeos.from_shapely(polygons) - pygeos_multipolygon = pygeos.multipolygons(pygeos_polygons) - - # Perform intersection - intersection = pygeos.intersection(pygeos_traces, pygeos_multipolygon) - assert isinstance(intersection, np.ndarray) - - # Collect into GeoDataFrame. - geodataframe = gpd.GeoDataFrame(geometry=intersection, crs=traces.crs) - assert "geometry" in geodataframe.columns - return geodataframe - - @JOBLIB_CACHE.cache def crop_to_target_areas( traces: Union[gpd.GeoSeries, gpd.GeoDataFrame], @@ -1255,7 +1215,8 @@ def crop_to_target_areas( if not is_filtered: traces.reset_index(drop=True, inplace=True) - spatial_index = pygeos_spatial_index(traces) + spatial_index = traces.sindex + assert isinstance(spatial_index, SpatialIndex), type(spatial_index) areas_bounds = total_bounds(areas) assert len(areas_bounds) == 4 @@ -1270,22 +1231,7 @@ def crop_to_target_areas( candidate_traces = traces # TODO: Remove environment check - if keep_column_data or os.environ.get("USE_PYGEOS") == "0": - assert all(area_geom.is_valid for area_geom in areas.geometry.values) - assert all(not area_geom.is_empty for area_geom in areas.geometry.values) - # geopandas.clip keeps the column data - try: - clipped_traces = gpd.clip(candidate_traces, areas) - except TypeError: - logging.error( - "Expected to be able to clip with geopandas. " - "Falling back to pygeos clip.", - exc_info=True, - ) - clipped_traces = efficient_clip(candidate_traces, areas) - else: - # pygeos.intersection does not - clipped_traces = efficient_clip(candidate_traces, areas) + clipped_traces = gpd.clip(candidate_traces, areas) assert hasattr(clipped_traces, "geometry") assert isinstance(clipped_traces, (gpd.GeoDataFrame, gpd.GeoSeries)) @@ -1415,7 +1361,7 @@ def is_empty_area(area: gpd.GeoDataFrame, traces: gpd.GeoDataFrame): Check if any traces intersect the area(s) in area GeoDataFrame. """ for area_polygon in area.geometry.values: - sindex = pygeos_spatial_index(traces) + sindex: SpatialIndex = traces.sindex intersection = spatial_index_intersection(sindex, geom_bounds(area_polygon)) potential_traces = traces.iloc[intersection] @@ -1484,7 +1430,7 @@ def random_points_within(poly: Polygon, num_points: int) -> List[Point]: def spatial_index_intersection( - spatial_index: PyGEOSSTRTreeIndex, coordinates: Union[BoundsTuple, PointTuple] + spatial_index: Any, coordinates: Union[BoundsTuple, PointTuple] ) -> List[int]: """ Type-checked spatial index intersection. @@ -1569,24 +1515,6 @@ def total_bounds( return bounds[0], bounds[1], bounds[2], bounds[3] -def pygeos_spatial_index( - geodataset: Union[gpd.GeoDataFrame, gpd.GeoSeries], -) -> PyGEOSSTRTreeIndex: - """ - Get PyGEOSSTRTreeIndex from geopandas dataset. - - :param geodataset: Geodataset of which - spatial index is wanted. - :return: ``pygeos`` spatial index. - :raises TypeError: If the geodataset ``sindex`` attribute was not a - ``pygeos`` spatial index object. - """ - sindex = geodataset.sindex - if not isinstance(sindex, PyGEOSSTRTreeIndex): - raise TypeError("Expected PyGEOSSTRTreeIndex as spatial index.") - return sindex - - def read_geofile(path: Path) -> gpd.GeoDataFrame: """ Read a filepath for a ``GeoDataFrame`` representable geo-object. @@ -1612,7 +1540,8 @@ def determine_boundary_intersecting_lines( assert isinstance(line_gdf, (gpd.GeoSeries, gpd.GeoDataFrame)) # line_gdf = line_gdf.reset_index(inplace=False, drop=True) # spatial_index = line_gdf.sindex - spatial_index = pygeos_spatial_index(line_gdf) + spatial_index = line_gdf.sindex + assert isinstance(spatial_index, SpatialIndex), type(spatial_index) intersecting_idxs = [] cuts_through_idxs = [] for target_area in area_gdf.geometry.values: diff --git a/fractopo/tval/proximal_traces.py b/fractopo/tval/proximal_traces.py index 456df2ee..9a27c3e0 100644 --- a/fractopo/tval/proximal_traces.py +++ b/fractopo/tval/proximal_traces.py @@ -6,9 +6,11 @@ column `Merge` which has values of True or False depending on if nearby proximal traces were found. """ + from typing import List, Union import geopandas as gpd +from geopandas.sindex import SpatialIndex from shapely.geometry import LineString from fractopo.general import ( @@ -16,7 +18,6 @@ determine_regression_azimuth, geom_bounds, is_azimuth_close, - pygeos_spatial_index, safe_buffer, spatial_index_intersection, ) @@ -112,11 +113,11 @@ def determine_proximal_traces( >>> buffer_value = 1.1 >>> azimuth_tolerance = 10 >>> determine_proximal_traces(traces, buffer_value, azimuth_tolerance) - geometry Merge - 0 LINESTRING (0.00000 0.00000, 0.00000 3.00000) True - 1 LINESTRING (1.00000 0.00000, 1.00000 3.00000) True - 2 LINESTRING (5.00000 0.00000, 5.00000 3.00000) False - 3 LINESTRING (0.00000 0.00000, -3.00000 -3.00000) False + geometry Merge + 0 LINESTRING (0 0, 0 3) True + 1 LINESTRING (1 0, 1 3) True + 2 LINESTRING (5 0, 5 3) False + 3 LINESTRING (0 0, -3 -3) False """ assert isinstance(traces, (gpd.GeoSeries, gpd.GeoDataFrame)) @@ -125,7 +126,7 @@ def determine_proximal_traces( else: traces_as_gdf = traces traces_as_gdf.reset_index(inplace=True, drop=True) - spatial_index = pygeos_spatial_index(traces_as_gdf) + spatial_index: SpatialIndex = traces_as_gdf.sindex trace: LineString proximal_traces: List[int] = [] for idx, trace in enumerate(traces_as_gdf.geometry.values): diff --git a/fractopo/tval/trace_validation.py b/fractopo/tval/trace_validation.py index e6e01a1a..aad6ed29 100644 --- a/fractopo/tval/trace_validation.py +++ b/fractopo/tval/trace_validation.py @@ -12,7 +12,7 @@ from typing import Any, List, Optional, Set, Tuple import geopandas as gpd -from geopandas.sindex import PyGEOSSTRTreeIndex +from geopandas.sindex import SpatialIndex from shapely.geometry import LineString, MultiLineString, Point from fractopo.general import ( @@ -71,7 +71,7 @@ def __post_init__(self): # Private caching attributes self._endpoint_nodes: Optional[List[Tuple[Point, ...]]] = None self._intersect_nodes: Optional[List[Tuple[Point, ...]]] = None - self._spatial_index: Optional[PyGEOSSTRTreeIndex] = None + self._spatial_index: Optional[Any] = None self._faulty_junctions: Optional[Set[int]] = None self._vnodes: Optional[Set[int]] = None @@ -136,20 +136,14 @@ def intersect_nodes(self) -> List[Tuple[Point, ...]]: raise TypeError("Expected self._intersect_nodes to not be None.") @property - def spatial_index(self) -> Optional[PyGEOSSTRTreeIndex]: + def spatial_index(self) -> Optional[Any]: """ Get geopandas/pygeos spatial_index of traces. """ if self._spatial_index is None: - spatial_index = self.traces.sindex - if ( - not isinstance(spatial_index, PyGEOSSTRTreeIndex) - or len(spatial_index) == 0 - ): - log.warning( - "Expected sindex property to be of type: PyGEOSSTRTreeIndex \n" - "and non-empty." - ) + spatial_index: SpatialIndex = self.traces.sindex + if len(spatial_index) == 0: + log.warning("Expected sindex property to be non-empty.") self._spatial_index = None return self._spatial_index self._spatial_index = spatial_index diff --git a/fractopo/tval/trace_validation_utils.py b/fractopo/tval/trace_validation_utils.py index f3e1403c..bb73387c 100644 --- a/fractopo/tval/trace_validation_utils.py +++ b/fractopo/tval/trace_validation_utils.py @@ -7,7 +7,7 @@ import geopandas as gpd import numpy as np -from geopandas.sindex import PyGEOSSTRTreeIndex +from geopandas.sindex import SpatialIndex from shapely.geometry import LineString, MultiLineString, Point, Polygon from shapely.ops import split @@ -237,7 +237,7 @@ def determine_trace_candidates( geom: LineString, idx: int, traces: gpd.GeoDataFrame, - spatial_index: Optional[PyGEOSSTRTreeIndex], + spatial_index: Optional[SpatialIndex], ) -> gpd.GeoSeries: """ Determine potentially intersecting traces with spatial index. @@ -246,7 +246,7 @@ def determine_trace_candidates( log.error("Expected spatial_index not be None.") return gpd.GeoSeries() assert isinstance(traces, (gpd.GeoSeries, gpd.GeoDataFrame)) - assert isinstance(spatial_index, PyGEOSSTRTreeIndex) + assert isinstance(spatial_index, SpatialIndex) candidate_idxs = spatial_index_intersection(spatial_index, geom_bounds(geom)) candidate_idxs.remove(idx) candidate_traces: gpd.GeoSeries = traces.geometry.iloc[candidate_idxs] diff --git a/tests/__init__.py b/tests/__init__.py index bf8cb8b6..ff7f79bb 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1382,18 +1382,6 @@ class ValidationParamType(NamedTuple): assert isinstance(unary_err_traces, gpd.GeoDataFrame) assert isinstance(unary_err_areas, gpd.GeoDataFrame) -test_safer_unary_union_params = [ - ( - unary_err_traces.geometry, # traces_geosrs - 0.001, # snap_threshold - 13000, # size_threshold - ), - ( - unary_err_traces.geometry, # traces_geosrs - 0.001, # snap_threshold - 50, # size_threshold - ), -] test_segment_within_buffer_params = [ (valid_geom, invalid_geom_multilinestring, 0.001, 1.1, 50, 5, True), @@ -2175,7 +2163,6 @@ def geodataframe_regression_check(file_regression, gdf: gpd.GeoDataFrame): Removes crs to avoid differences between module versions where some include it in the json and others do not. """ - gdf_copy = gdf.copy() - gdf_copy.crs = None + gdf_copy = gdf.copy().set_crs(None, allow_override=True) gdf_as_json = gdf_copy.to_json(indent=1, sort_keys=True) file_regression.check(gdf_as_json) diff --git a/tests/analysis/test_contour_grid.py b/tests/analysis/test_contour_grid.py index 5a13e578..095e47e1 100755 --- a/tests/analysis/test_contour_grid.py +++ b/tests/analysis/test_contour_grid.py @@ -13,13 +13,7 @@ import tests from fractopo.analysis import contour_grid from fractopo.analysis.network import Network -from fractopo.general import ( - CC_branch, - CI_branch, - II_branch, - Param, - pygeos_spatial_index, -) +from fractopo.general import CC_branch, CI_branch, II_branch, Param CELL_WIDTH = 0.10 BRANCHES = gpd.GeoDataFrame( @@ -160,7 +154,7 @@ def test_populate_sample_cell(sample_cell, traces, snap_threshold, branches): result = contour_grid.populate_sample_cell( sample_cell=sample_cell, sample_cell_area=sample_cell.area, - traces_sindex=pygeos_spatial_index(traces), + traces_sindex=traces.sindex, nodes=gpd.GeoDataFrame(), traces=traces, branches=branches, diff --git a/tests/branches_and_nodes/test_branches_and_nodes.py b/tests/branches_and_nodes/test_branches_and_nodes.py index 253f5f12..80073225 100644 --- a/tests/branches_and_nodes/test_branches_and_nodes.py +++ b/tests/branches_and_nodes/test_branches_and_nodes.py @@ -1,6 +1,7 @@ """ Tests for branch and node determination. """ + from typing import List import geopandas as gpd @@ -313,7 +314,7 @@ def test_snap_trace_simple( """ Test snap_trace_simple. """ - traces_spatial_index = general.pygeos_spatial_index(gpd.GeoSeries(traces)) + traces_spatial_index = gpd.GeoSeries(traces).sindex result, was_simple_snapped = branches_and_nodes.snap_trace_simple( idx, trace, snap_threshold, traces, traces_spatial_index ) @@ -322,24 +323,6 @@ def test_snap_trace_simple( assert result.intersects(traces[intersects_idx]) -@pytest.mark.parametrize( - "traces_geosrs,snap_threshold,size_threshold", tests.test_safer_unary_union_params -) -def test_safer_unary_union(traces_geosrs, snap_threshold, size_threshold): - """ - Test safer_unary_union. - """ - try: - result = branches_and_nodes.safer_unary_union( - traces_geosrs, snap_threshold, size_threshold - ) - except ValueError: - if size_threshold < branches_and_nodes.UNARY_ERROR_SIZE_THRESHOLD: - return - raise - assert len(list(result.geoms)) >= traces_geosrs.shape[0] - - @pytest.mark.parametrize( "loops,allowed_loops,will_error", tests.test_report_snapping_loop_params )