From dd2ea98f2a15206fc88d57afe04616058ed5abd5 Mon Sep 17 00:00:00 2001
From: Anshul Singhvi
Date: Sat, 8 Jun 2024 22:44:31 -0400
Subject: [PATCH] Try to fix FlexiJoins with other predicates

---
Review notes (this section is not part of the commit message):

* `swap_sides(f)` must return the predicate `g` with `g(b, a) == f(a, b)`.
  The symmetric predicates (intersects, disjoint, touches, crosses, overlaps,
  equals) therefore map to themselves. Mapping `touches` to `!touches` or
  `crosses` to `!crosses` inverted the join result whenever the sides were swapped.
* The tree join only tests the candidates returned by the STRtree query, so it
  is restricted to predicates that cannot hold when the extents do not
  intersect. `disjoint` and the negated predicates (other than `!disjoint`)
  use the nested loop only.
* `swap_sides` is now defined for every predicate in `GO_DE9IM_FUNC_TYPES`.
  The negated methods delegate to the direct ones, so they assume that
  `swap_sides(::typeof(GO.contains))` exists above the last hunk -- TODO confirm.
* Blank context lines and indentation were reconstructed from a
  whitespace-collapsed copy of this patch. The `index` lines still carry the
  original blob ids. Check the context against the target files before applying.

 .../GeometryOpsFlexiJoinsExt.jl               | 29 +++++++++++++++----
 test/extensions/flexijoins.jl                 |  6 ++++
 2 files changed, 30 insertions(+), 5 deletions(-)

diff --git a/ext/GeometryOpsFlexiJoinsExt/GeometryOpsFlexiJoinsExt.jl b/ext/GeometryOpsFlexiJoinsExt/GeometryOpsFlexiJoinsExt.jl
index 165a52495..c390cf7c0 100644
--- a/ext/GeometryOpsFlexiJoinsExt/GeometryOpsFlexiJoinsExt.jl
+++ b/ext/GeometryOpsFlexiJoinsExt/GeometryOpsFlexiJoinsExt.jl
@@ -10,13 +10,19 @@ using SortTileRecursiveTree, Tables
 # This module defines the FlexiJoins APIs for GeometryOps' boolean comparison functions, taken from DE-9IM.
 
 # First, we define the joining modes (Tree, NestedLoopFast) that the GO DE-9IM functions support.
-const GO_DE9IM_FUNCS = Union{typeof(GO.contains), typeof(GO.within), typeof(GO.intersects), typeof(GO.disjoint), typeof(GO.touches), typeof(GO.crosses), typeof(GO.overlaps), typeof(GO.covers), typeof(GO.coveredby), typeof(GO.equals)}
+const GO_DE9IM_DIRECT_FUNCS = (GO.contains, GO.within, GO.intersects, GO.disjoint, GO.touches, GO.crosses, GO.overlaps, GO.covers, GO.coveredby, GO.equals)
+# Every direct predicate and its negation can be evaluated pairwise in a nested loop.
+const GO_DE9IM_FUNC_TYPES = Union{typeof.(GO_DE9IM_DIRECT_FUNCS)..., typeof.((!).(GO_DE9IM_DIRECT_FUNCS))...}
+# The tree join only applies the predicate to the candidates returned by the STRtree query (see `findmatchix` below),
+# so it is only valid for predicates that cannot hold between geometries whose extents do not intersect.
+# That rules out `disjoint` and every negated predicate except `!disjoint`.
+const GO_DE9IM_TREE_FUNC_TYPES = Union{typeof.(filter(!=(GO.disjoint), GO_DE9IM_DIRECT_FUNCS))..., typeof(!(GO.disjoint))}
 # NestedLoopFast is the naive fallback method
-FlexiJoins.supports_mode(::FlexiJoins.Mode.NestedLoopFast, ::FlexiJoins.ByPred{F}, datas) where F <: GO_DE9IM_FUNCS = true
+FlexiJoins.supports_mode(::FlexiJoins.Mode.NestedLoopFast, ::FlexiJoins.ByPred{F}, datas) where F <: GO_DE9IM_FUNC_TYPES = true
 # This method allows you to cache a tree, which we do by using an STRtree.
 # TODO: wrap GO predicate functions in a `TreeJoiner` struct or something, to indicate that we want to use trees,
 # since they can be slower in some situations.
-FlexiJoins.supports_mode(::FlexiJoins.Mode.Tree, ::FlexiJoins.ByPred{F}, datas) where F <: GO_DE9IM_FUNCS = true
+FlexiJoins.supports_mode(::FlexiJoins.Mode.Tree, ::FlexiJoins.ByPred{F}, datas) where F <: GO_DE9IM_TREE_FUNC_TYPES = true
 
 # Nested loop support is simple, and needs no further support.
 # However, for trees, we need to define how the tree is prepared and how it is used.
@@ -26,8 +32,8 @@ FlexiJoins.supports_mode(::FlexiJoins.Mode.Tree, ::FlexiJoins.ByPred{F}, datas)
 
 # In theory, one could extract the tree from e.g a GeoPackage or some future GeoDataFrame.
 
-FlexiJoins.prepare_for_join(::FlexiJoins.Mode.Tree, X, cond::FlexiJoins.ByPred{<: GO_DE9IM_FUNCS}) = (X, SortTileRecursiveTree.STRtree(map(cond.Rf, X)))
-function FlexiJoins.findmatchix(::FlexiJoins.Mode.Tree, cond::FlexiJoins.ByPred{F}, ix_a, a, (B, tree)::Tuple, multi::typeof(identity)) where F <: GO_DE9IM_FUNCS
+FlexiJoins.prepare_for_join(::FlexiJoins.Mode.Tree, X, cond::FlexiJoins.ByPred{<: GO_DE9IM_TREE_FUNC_TYPES}) = (X, SortTileRecursiveTree.STRtree(map(cond.Rf, X)))
+function FlexiJoins.findmatchix(::FlexiJoins.Mode.Tree, cond::FlexiJoins.ByPred{F}, ix_a, a, (B, tree)::Tuple, multi::typeof(identity)) where F <: GO_DE9IM_TREE_FUNC_TYPES
     idxs = SortTileRecursiveTree.query(tree, cond.Lf(a))
     intersecting_idxs = filter!(idxs) do idx
         cond.pred(a, cond.Rf(B[idx]))
@@ -42,6 +48,19 @@ FlexiJoins.swap_sides(::typeof(GO.within)) = GO.contains
 FlexiJoins.swap_sides(::typeof(GO.coveredby)) = GO.covers
 FlexiJoins.swap_sides(::typeof(GO.covers)) = GO.coveredby
 
+# The remaining predicates are symmetric, so swapping the sides leaves them unchanged.
+FlexiJoins.swap_sides(::typeof(GO.intersects)) = GO.intersects
+FlexiJoins.swap_sides(::typeof(GO.disjoint)) = GO.disjoint
+FlexiJoins.swap_sides(::typeof(GO.touches)) = GO.touches
+FlexiJoins.swap_sides(::typeof(GO.crosses)) = GO.crosses
+FlexiJoins.swap_sides(::typeof(GO.overlaps)) = GO.overlaps
+FlexiJoins.swap_sides(::typeof(GO.equals)) = GO.equals
+
+# Swapping the sides of a negated predicate is the negation of the swapped predicate.
+for pred in GO_DE9IM_DIRECT_FUNCS
+    @eval FlexiJoins.swap_sides(::typeof(!($pred))) = !(FlexiJoins.swap_sides($pred))
+end
+
 # That's a wrap, folks!
 
 end
diff --git a/test/extensions/flexijoins.jl b/test/extensions/flexijoins.jl
index 80e0432ac..adf1fb519 100644
--- a/test/extensions/flexijoins.jl
+++ b/test/extensions/flexijoins.jl
@@ -20,3 +20,9 @@ points_df = DataFrame(geometry = points)
 
 end
 
+@testset "All Predicates" begin
+    # NOTE(review): GO.crosses is not exercised here; confirm whether it supports polygon/point pairs before adding it.
+    @testset "$func" for func in [GO.contains, GO.within, GO.intersects, GO.disjoint, GO.touches, GO.overlaps, GO.covers, GO.coveredby, GO.equals]
+        @test_nowarn FlexiJoins.innerjoin((poly_df, points_df), by_pred(:geometry, func, :geometry))
+    end
+end