diff --git a/example/FlowOverCircle/src/FlowOverCircle.jl b/example/FlowOverCircle/src/FlowOverCircle.jl index 8a4354b3..e9abba02 100644 --- a/example/FlowOverCircle/src/FlowOverCircle.jl +++ b/example/FlowOverCircle/src/FlowOverCircle.jl @@ -36,12 +36,16 @@ function get_dataloader(; ts::AbstractRange = LinRange(100, 11000, 10000), data = gen_data(ts) 𝐱, 𝐲 = data[:, :, :, 1:(end - 1)], data[:, :, :, 2:end] n = length(ts) - 1 + coord = generate_coordinates(𝐱[1, :, :, 1]) + coord = repeat(coord, outer = (1, 1, 1, n)) if flatten 𝐱, 𝐲 = reshape(𝐱, 1, :, n), reshape(𝐲, 1, :, n) + coord = reshape(coord, size(coord, 1), :, n) end - data_train, data_test = splitobs(shuffleobs((𝐱, 𝐲)), at = ratio) + coord = vcat(𝐱, coord) + data_train, data_test = splitobs(shuffleobs((𝐱, 𝐲, coord)), at = ratio) loader_train = DataLoader(data_train, batchsize = batchsize, shuffle = true) loader_test = DataLoader(data_test, batchsize = batchsize, shuffle = false) @@ -49,6 +53,18 @@ function get_dataloader(; ts::AbstractRange = LinRange(100, 11000, 10000), return loader_train, loader_test end +function generate_coordinates(A::AbstractArray) + dims = size(A) + N = length(dims) + colons = ntuple(i -> Colon(), N) + coord = similar(A, N, dims...) + for i in 1:N + ones = ntuple(x -> 1, i - 1) + coord[i, colons...] .= reshape(1:dims[i], ones..., :) + end + return coord +end + function train(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 50) if cuda && CUDA.has_cuda() device = gpu @@ -84,12 +100,18 @@ function train_gno(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 50) @info "Training on CPU" end + coord_dim = 2 + edge_dim = 2(coord_dim + 1) featured_graph = FeaturedGraph(grid([96, 64])) model = Chain(Dense(1, 16), - WithGraph(featured_graph, GraphKernel(Dense(2 * 16, 16, gelu), 16)), - WithGraph(featured_graph, GraphKernel(Dense(2 * 16, 16, gelu), 16)), - WithGraph(featured_graph, GraphKernel(Dense(2 * 16, 16, gelu), 16)), - WithGraph(featured_graph, GraphKernel(Dense(2 * 16, 16, gelu), 16)), + WithGraph(featured_graph, + GraphKernel(Dense(edge_dim, abs2(16), gelu), 16)), + WithGraph(featured_graph, + GraphKernel(Dense(edge_dim, abs2(16), gelu), 16)), + WithGraph(featured_graph, + GraphKernel(Dense(edge_dim, abs2(16), gelu), 16)), + WithGraph(featured_graph, + GraphKernel(Dense(edge_dim, abs2(16), gelu), 16)), Dense(16, 1)) data = get_dataloader(batchsize = 16, flatten = true) optimiser = Flux.Optimiser(WeightDecay(λ), Flux.Adam(η₀))