Skip to content

Commit

Permalink
Thread for transformations like reproject, simplify, etc (#24)
Browse files Browse the repository at this point in the history
* threaded apply

* apply

* test threading a little

* tweaks

* clean up
  • Loading branch information
rafaqz authored Nov 2, 2023
1 parent 8851c23 commit 049c151
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 31 deletions.
70 changes: 57 additions & 13 deletions src/primitives.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

# # Primitive functions

# This file mainly defines the [`apply`](@ref) function.
Expand Down Expand Up @@ -27,15 +28,24 @@ apply(f, ::Type{Target}, geom; kw...) where Target = _apply(f, Target, geom; kw.

_apply(f, ::Type{Target}, geom; kw...) where Target =
_apply(f, Target, GI.trait(geom), geom; kw...)
# Arrays whose trait is `nothing`: run `_apply` on each element by index,
# optionally spread over several tasks. `threaded` is consumed at this level;
# only the remaining keywords are forwarded to the elements.
function _apply(
    f, ::Type{Target}, ::Nothing, A::AbstractArray; threaded=false, kw...
) where Target
    apply_at(idx) = _apply(f, Target, A[idx]; kw...)
    return _maptasks(apply_at, eachindex(A); threaded=threaded)
end
# Anything else whose trait is `nothing` is treated as an iterable:
# `_apply` is mapped over its items with the same target and keywords.
function _apply(f, ::Type{Target}, ::Nothing, iterable; kw...) where Target
    return map(iterable) do item
        _apply(f, Target, item; kw...)
    end
end
# Rewrap feature collections
function _apply(f, ::Type{Target}, ::GI.FeatureCollectionTrait, fc; crs=GI.crs(fc), calc_extent=false) where Target
applicator(feature) = _apply(f, Target, feature; crs, calc_extent)::GI.Feature
features = map(applicator, GI.getfeature(fc))
function _apply(f, ::Type{Target}, ::GI.FeatureCollectionTrait, fc;
crs=GI.crs(fc), calc_extent=false, threaded=false
) where Target
features = _maptasks(1:GI.nfeature(fc); threaded) do i
feature = GI.getfeature(fc, i)
_apply(f, Target, feature; crs, calc_extent)::GI.Feature
end
if calc_extent
extent = rebuce(features; init=GI.extent(first(features))) do (acc, f)
extent = reduce(features; init=GI.extent(first(features))) do acc, f
Extents.union(acc, Extents.extent(f))
end
return GI.FeatureCollection(features; crs, extent)
Expand All @@ -44,7 +54,9 @@ function _apply(f, ::Type{Target}, ::GI.FeatureCollectionTrait, fc; crs=GI.crs(f
end
end
# Rewrap features
function _apply(f, ::Type{Target}, ::GI.FeatureTrait, feature; crs=GI.crs(feature), calc_extent=false) where Target
function _apply(f, ::Type{Target}, ::GI.FeatureTrait, feature;
crs=GI.crs(feature), calc_extent=false, threaded=false
) where Target
properties = GI.properties(feature)
geometry = _apply(f, Target, GI.geometry(feature); crs, calc_extent)
if calc_extent
Expand All @@ -56,11 +68,12 @@ function _apply(f, ::Type{Target}, ::GI.FeatureTrait, feature; crs=GI.crs(featur
end
# Reconstruct nested geometries
function _apply(f, ::Type{Target}, trait, geom;
crs=GI.crs(geom), calc_extent=false
crs=GI.crs(geom), calc_extent=false, threaded=false
)::(GI.geointerface_geomtype(trait)) where Target
# TODO handle zero length...
applicator(g) = _apply(f, Target, g; crs, calc_extent)
geoms = map(applicator, GI.getgeom(geom))
geoms = _maptasks(1:GI.ngeom(geom); threaded) do i
_apply(f, Target, GI.getgeom(geom, i); crs, calc_extent)
end
if calc_extent
extent = GI.extent(first(geoms))
for g in geoms
Expand All @@ -72,14 +85,14 @@ function _apply(f, ::Type{Target}, trait, geom;
end
end
# Apply f to the target geometry
_apply(f, ::Type{Target}, ::Trait, geom; crs=GI.crs(geom), calc_extent=false) where {Target,Trait<:Target} = f(geom)
# The object's trait is a subtype of the requested target: call `f` on it.
# `crs` and any other keywords are accepted but not used here.
function _apply(
    f, ::Type{Target}, ::Trait, geom; crs=GI.crs(geom), kw...
) where {Target,Trait<:Target}
    return f(geom)
end
# Fail if we hit PointTrait without running `f`
_apply(f, ::Type{Target}, trait::GI.PointTrait, geom; crs=nothing, calc_extent=false) where Target =
# Reaching a point without having matched `Target` means the target trait
# does not occur in this object, so raise an `ArgumentError`.
function _apply(f, ::Type{Target}, ::GI.PointTrait, geom; crs=nothing, kw...) where Target
    msg = "target $Target not found, but reached a `PointTrait` leaf"
    throw(ArgumentError(msg))
end
# Specific cases to avoid method ambiguity
_apply(f, ::Type{GI.PointTrait}, trait::GI.PointTrait, geom; crs=nothing, calc_extent=false) = f(geom)
_apply(f, ::Type{GI.FeatureTrait}, ::GI.FeatureTrait, feature; crs=GI.crs(feature), calc_extent=false) = f(feature)
_apply(f, ::Type{GI.FeatureCollectionTrait}, ::GI.FeatureCollectionTrait, fc; crs=GI.crs(fc)) = f(fc)
# Target and trait are the same concrete trait: call `f` directly.
# All keywords are accepted and ignored.
function _apply(f, ::Type{GI.PointTrait}, ::GI.PointTrait, geom; kw...)
    return f(geom)
end
function _apply(f, ::Type{GI.FeatureTrait}, ::GI.FeatureTrait, feature; kw...)
    return f(feature)
end
function _apply(
    f, ::Type{GI.FeatureCollectionTrait}, ::GI.FeatureCollectionTrait, fc; kw...
)
    return f(fc)
end

"""
unwrap(target::Type{<:AbstractTrait}, obj)
Expand Down Expand Up @@ -234,3 +247,34 @@ end
# Rebuild a `GB.Polygon` from its child geometries: the first child and the
# remaining children are passed to `Polygon` as two separate arguments
# (presumably exterior ring and interior rings -- confirm against `Polygon`).
# The `crs` keyword is accepted but not used.
function rebuild(::GI.PolygonTrait, ::GB.Polygon, child_geoms; crs=nothing)
    head = child_geoms[1]
    rest = child_geoms[2:end]
    return Polygon(head, rest)
end

using Base.Threads: nthreads, @threads, @spawn


# Threading utility, modified from Mason Protter's threading PSA.
"""
    _maptasks(f, taskrange; threaded=false)

Call `f` on every element of `taskrange` and return the results in order.

With `threaded=false`, or when `taskrange` is empty, this is exactly
`map(f, taskrange)`. With `threaded=true`, `taskrange` is partitioned into
contiguous chunks, one task is spawned per chunk, `f` is mapped over the
indices of each chunk, and the per-chunk results are concatenated with `vcat`.
`f` therefore runs concurrently on several tasks in the threaded case.
"""
function _maptasks(f, taskrange; threaded=false)
    # An empty range yields no chunks, and `reduce(vcat, ...)` over zero
    # results throws, so let the serial `map` handle that case.
    if !threaded || isempty(taskrange)
        return map(f, taskrange)
    end

    ntasks = length(taskrange)
    # Customize this as needed.
    # More tasks have more overhead, but better load balancing
    tasks_per_thread = 2
    chunk_size = max(1, ntasks ÷ (tasks_per_thread * nthreads()))
    # Partition the range into chunks of at most `chunk_size` indices
    task_chunks = Iterators.partition(taskrange, chunk_size)
    # Spawn one task per chunk; each task maps `f` over its chunk's indices
    tasks = map(task_chunks) do chunk
        @spawn map(f, chunk)
    end

    # Fetch in spawn order so the joined vector keeps the input order
    return reduce(vcat, map(fetch, tasks))
end
2 changes: 1 addition & 1 deletion src/transformations/extent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ calculating and adding an `Extents.Extent` to all objects.
This can improve performance when extents need to be checked multiple times.
"""
embed_extent(x) = apply(extent_applicator, AbstractTrait, x)
# `apply` takes the function to run as its first positional argument, so
# `extent_applicator` must be passed explicitly; keywords (e.g. `threaded`)
# are forwarded unchanged.
embed_extent(x; kw...) = apply(extent_applicator, AbstractTrait, x; kw...)

extent_applicator(x) = extent_applicator(trait(x), x)
extent_applicator(::Nothing, xs::AbstractArray) = embed_extent.(xs)
Expand Down
21 changes: 16 additions & 5 deletions src/transformations/reproject.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,18 @@ function reproject(geom;
source_crs=nothing, target_crs=nothing, transform=nothing, kw...
)
if isnothing(transform)
source_crs = isnothing(source_crs) ? GeoInterface.crs(geom) : source_crs
if isnothing(source_crs)
source_crs = if GI.trait(geom) isa Nothing && geom isa AbstractArray
GeoInterface.crs(first(geom))
else
GeoInterface.crs(geom)
end
end

# If its still nothing, error
isnothing(source_crs) && throw(ArgumentError("geom has no crs attatched. Pass a `source_crs` keyword"))

# Otherwise reproject
reproject(geom, source_crs, target_crs; kw...)
else
reproject(geom, transform; kw...)
Expand All @@ -49,16 +59,17 @@ function reproject(geom, source_crs, target_crs;
time=Inf,
always_xy=true,
transform=Proj.Transformation(Proj.CRS(source_crs), Proj.CRS(target_crs); always_xy),
kw...
)
reproject(geom, transform; time, target_crs)
reproject(geom, transform; time, target_crs, kw...)
end
function reproject(geom, transform::Proj.Transformation; time=Inf, target_crs=nothing)
function reproject(geom, transform::Proj.Transformation; time=Inf, target_crs=nothing, kw...)
if _is3d(geom)
return apply(PointTrait, geom; crs=target_crs) do p
return apply(PointTrait, geom; crs=target_crs, kw...) do p
transform(GI.x(p), GI.y(p), GI.z(p))
end
else
return apply(PointTrait, geom; crs=target_crs) do p
return apply(PointTrait, geom; crs=target_crs, kw...) do p
transform(GI.x(p), GI.y(p))
end
end
Expand Down
3 changes: 2 additions & 1 deletion src/transformations/simplify.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ GI.npoint(simple)
6
```
"""
simplify(data; calc_extent=false, kw...) = _simplify(DouglasPeucker(; kw...), data; calc_extent)
# Default entry point: build a `DouglasPeucker` algorithm from the remaining
# keywords and pass `calc_extent` / `threaded` on to `_simplify`.
function simplify(data; calc_extent=false, threaded=false, kw...)
    alg = DouglasPeucker(; kw...)
    return _simplify(alg, data; calc_extent=calc_extent, threaded=threaded)
end
# Entry point with an explicit algorithm: forward everything to `_simplify`.
function simplify(alg::SimplifyAlg, data; kw...)
    return _simplify(alg, data; kw...)
end

function _simplify(alg::SimplifyAlg, data; kw...)
Expand Down
3 changes: 3 additions & 0 deletions test/transformations/reproject.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,8 @@ import Proj
# Embedded crs check
@test GI.crs(multipolygon3857) == EPSG(3857)
@test GI.crs(multipolygon4326) == EPSG(4326)

# Run it threaded over 100 replicates
GO.reproject([multipolygon3857 for _ in 1:100]; target_crs=EPSG(4326), threaded=true, calc_extent=true)
end

24 changes: 13 additions & 11 deletions test/transformations/simplify.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@ import GeoInterface as GI
import GeometryOps as GO
import GeoJSON

# Unncomment when JSON3 bumps a patch version
# @testset "simplify" begin
# datadir = realpath(joinpath(dirname(pathof(GO)), "../test/data"))
# fc = GeoJSON.read(joinpath(datadir, "simplify.geojson"))
# fc2 = GeoJSON.read(joinpath(datadir, "simplify2.geojson"))
# Exercise `GO.simplify` with each algorithm on a GeoJSON fixture, and run the
# threaded code path on both a single feature collection and a vector of them.
@testset "simplify" begin
# Fixture files live in the package's `test/data` directory
datadir = realpath(joinpath(dirname(pathof(GO)), "../test/data"))
fc = GeoJSON.read(joinpath(datadir, "simplify.geojson"))
# NOTE(review): `fc2` is read but not referenced again in this testset
fc2 = GeoJSON.read(joinpath(datadir, "simplify2.geojson"))
# NOTE(review): this binding is not used; the `for` loop below introduces its own `T`
T = GO.RadialDistance
# 100 references to the same collection, used as input for the threaded array path
fcs = [fc for i in 1:100]

for T in (GO.RadialDistance, GO.VisvalingamWhyatt, GO.DouglasPeucker)
# Point counts of the flattened result for the `number` and `ratio` options
@test length(collect(GO.flatten(GI.PointTrait, GO.simplify(T(number=10), fc)))) == 10
@test length(collect(GO.flatten(GI.PointTrait, GO.simplify(T(ratio=0.5), fc)))) == 39 # Half of 78
# Smoke tests only: these calls must not throw, their results are not asserted
GO.simplify(T(tol=0.001), fc; threaded=true, calc_extent=true)
GO.simplify(T(tol=0.001), fcs; threaded=true, calc_extent=true)
end
end

0 comments on commit 049c151

Please sign in to comment.