diff --git a/ext/InterpolationsRegridderExt.jl b/ext/InterpolationsRegridderExt.jl index 3b24ceae..9b90dd1d 100644 --- a/ext/InterpolationsRegridderExt.jl +++ b/ext/InterpolationsRegridderExt.jl @@ -12,6 +12,7 @@ struct InterpolationsRegridder{ SPACE <: ClimaCore.Spaces.AbstractSpace, FIELD <: ClimaCore.Fields.Field, BC, + GITP, } <: Regridders.AbstractRegridder """ClimaCore.Space where the output Field will be defined""" @@ -22,6 +23,14 @@ struct InterpolationsRegridder{ """Tuple of extrapolation conditions as accepted by Interpolations.jl""" extrapolation_bc::BC + + # This is needed because Adapt moves from CPU to GPU and allocates new memory. + """Dictionary of preallocated areas of memory where to store the GPU interpolant (if + needed). Every time new data/dimensions are used in regrid, a new entry in the + dictionary is created. The keys of the dictionary a tuple of tuple + `(size(dimensions), size(data))`, with `dimensions` and `data` defined in `regrid`. + """ + _gpuitps::GITP end # Note, we swap Lat and Long! This is because according to the CF conventions longitude @@ -75,13 +84,38 @@ function Regridders.InterpolationsRegridder( "Number of boundary conditions does not match the number of dimensions", ) + # Let's figure out the type of _gpuitps by creating a simple spline + FT = ClimaCore.Spaces.undertype(target_space) + dimensions = ntuple(_ -> [zero(FT), one(FT)], num_dimensions) + data = zeros(FT, ntuple(_ -> 2, num_dimensions)) + itp = _create_linear_spline(FT, data, dimensions, extrapolation_bc) + fake_gpuitp = Adapt.adapt(ClimaComms.array_type(target_space), itp) + gpuitps = Dict((size.(dimensions), size(data)) => fake_gpuitp) + return InterpolationsRegridder( target_space, coordinates, - extrapolation_bc + extrapolation_bc, + gpuitps, + ) +end + +""" + _create_linear_spline(regridder::InterpolationsRegridder, data, dimensions) + +Create a linear spline for the given data on the given dimension (on the CPU). +""" +function _create_linear_spline(FT, data, dimensions, extrapolation_bc) + dimensions_FT = map(d -> FT.(d), dimensions) + + # Make a linear spline + return Intp.extrapolate( + Intp.interpolate(dimensions_FT, FT.(data), Intp.Gridded(Intp.Linear())), + extrapolation_bc, ) end + """ regrid(regridder::InterpolationsRegridder, data, dimensions)::Field @@ -91,16 +125,31 @@ This function is allocating. """ function Regridders.regrid(regridder::InterpolationsRegridder, data, dimensions) FT = ClimaCore.Spaces.undertype(regridder.target_space) - dimensions_FT = map(d -> FT.(d), dimensions) - - # Make a linear spline - itp = Intp.extrapolate( - Intp.interpolate(dimensions_FT, FT.(data), Intp.Gridded(Intp.Linear())), - regridder.extrapolation_bc, - ) + itp = + _create_linear_spline(FT, data, dimensions, regridder.extrapolation_bc) + + key = (size.(dimensions), size(data)) + + if haskey(regridder._gpuitps, key) + for (k, k_new) in zip( + regridder._gpuitps[key].itp.knots, + Adapt.adapt( + ClimaComms.array_type(regridder.target_space), + itp.itp.knots, + ), + ) + k .= k_new + end + regridder._gpuitps[key].itp.coefs .= Adapt.adapt( + ClimaComms.array_type(regridder.target_space), + itp.itp.coefs, + ) + else + regridder._gpuitps[key] = + Adapt.adapt(ClimaComms.array_type(regridder.target_space), itp) + end - # Move it to GPU (if needed) - gpuitp = Adapt.adapt(ClimaComms.array_type(regridder.target_space), itp) + gpuitp = regridder._gpuitps[key] return map(regridder.coordinates) do coord gpuitp(totuple(coord)...)