diff --git a/src/PyramidScheme.jl b/src/PyramidScheme.jl index 10528aa..655e937 100644 --- a/src/PyramidScheme.jl +++ b/src/PyramidScheme.jl @@ -207,13 +207,40 @@ Fill the pyramids generated from the `data` with the aggregation function `func` `recursive` indicates whether higher tiles are computed from lower tiles or directly from the original data. This is an optimization which for functions like median might lead to misleading results. """ -function fill_pyramids(data, outputs,func,recursive;runner=LocalRunner, kwargs...) +function fill_pyramids(data, outputs,func,recursive;runner=LocalRunner, verbose=false, outtype=:mem, kwargs...) + t = typeof(func(zeros(eltype(data), 2,2))) + n_level = compute_nlevels(data) + input_axes = pyramidedaxes(data) + nonpyramiddims = DD.otherdims(data, input_axes) + @show nonpyramiddims + if length(input_axes) != 2 + throw(ArgumentError("Expected two spatial dimensions got $input_axes")) + end + verbose && println("Constructing output arrays") + spatialsize = size(data)[collect(DD.dimnum(data, input_axes))] + pyramid_sizes = [ceil.(Int, spatialsize ./ 2^i) for i in 1:n_level] + allsizes = [spatialsize..., [1 for o in nonpyramiddims]...] + sizeperm = [DD.dimnum(data, input_axes)..., DD.dimnum(data, nonpyramiddims)...] + permute!(allsizes, sizeperm) + @show allsizes + outputs = if outtype == :zarr + [output_zarr(n, input_axes, t, joinpath(path, string(n))) for n in 1:n_level] + elseif outtype == :mem + outmin = output_arrays(pyramid_sizes, t) + else + throw(ArgumentError("Output type not valied got $outtype expected :mem or :zarr")) + end + + verbose && println("Start computation") n_level = length(outputs) + @show typeof(data) + @show n_level + @show size.(outputs) pixel_base_size = 2^n_level pyramid_sizes = size.(outputs) tmp_sizes = [ceil(Int,pixel_base_size / 2^i) for i in 1:n_level] - - ia = InputArray(data, windows = arraywindows(size(data),pixel_base_size)) + windows = arraywindows(allsizes,pixel_base_size) + ia = InputArray(data;windows) oa = ntuple(i->create_outwindows(pyramid_sizes[i],windows = arraywindows(pyramid_sizes[i],tmp_sizes[i])),n_level) @@ -255,6 +282,8 @@ Construct a list of `RegularWindows` for the size list in `s` for windows `w`. ?? """ function arraywindows(s,w) + @show s + @show w map(s) do l RegularWindows(1,l,window=w) end @@ -308,6 +337,9 @@ Union of Dimensions which are assumed to be in space and are therefore used in t """ SpatialDim = Union{DD.Dimensions.XDim, DD.Dimensions.YDim} +pyramidedaxes(input) = filter(x-> x isa SpatialDim, DD.dims(input)) + + """ buildpyramids(path; resampling_method=mean) Build the pyramids for the zarr dataset at `path` and write the pyramid layers into the zarr folder. @@ -324,15 +356,6 @@ function buildpyramids(path; resampling_method=mean, recursive=true, runner=Loca # Build a loop for all variables in a dataset? org = Cube(path) # We run the method once to derive the output type - t = typeof(resampling_method(zeros(eltype(org), 2,2))) - n_level = compute_nlevels(org) - input_axes = filter(x-> x isa SpatialDim, DD.dims(org)) - if length(input_axes) != 2 - throw(ArgumentError("Expected two spatial dimensions got $input_axes")) - end - verbose && println("Constructing output arrays") - outarrs = [output_zarr(n, input_axes, t, joinpath(path, string(n))) for n in 1:n_level] - verbose && println("Start computation") fill_pyramids(org, outarrs, resampling_method, recursive;runner) pyraxs = [agg_axis.(input_axes, 2^n) for n in 1:n_level] pyrlevels = DD.DimArray.(outarrs, pyraxs) @@ -376,7 +399,7 @@ Compute the data of the pyramids of a given data cube `ras`. This returns the data of the pyramids and the dimension values of the aggregated axes. """ function getpyramids(reducefunc, ras;recursive=true) - input_axes = DD.dims(ras) + input_axes = pyramidedaxes(ras) n_level = compute_nlevels(ras) if iszero(n_level) @info "Array is smaller than the tilesize no pyramids are computed"