Skip to content

Commit

Permalink
fix GNO example
Browse files Browse the repository at this point in the history
  • Loading branch information
yuehhua committed Aug 6, 2022
1 parent 54602e6 commit 1176c51
Showing 1 changed file with 27 additions and 5 deletions.
32 changes: 27 additions & 5 deletions example/FlowOverCircle/src/FlowOverCircle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,35 @@ function get_dataloader(; ts::AbstractRange = LinRange(100, 11000, 10000),
data = gen_data(ts)
𝐱, 𝐲 = data[:, :, :, 1:(end - 1)], data[:, :, :, 2:end]
n = length(ts) - 1
coord = generate_coordinates(𝐱[1, :, :, 1])
coord = repeat(coord, outer = (1, 1, 1, n))

if flatten
𝐱, 𝐲 = reshape(𝐱, 1, :, n), reshape(𝐲, 1, :, n)
coord = reshape(coord, size(coord, 1), :, n)
end

data_train, data_test = splitobs(shuffleobs((𝐱, 𝐲)), at = ratio)
coord = vcat(𝐱, coord)
data_train, data_test = splitobs(shuffleobs((𝐱, 𝐲, coord)), at = ratio)

loader_train = DataLoader(data_train, batchsize = batchsize, shuffle = true)
loader_test = DataLoader(data_test, batchsize = batchsize, shuffle = false)

return loader_train, loader_test
end

function generate_coordinates(A::AbstractArray)
dims = size(A)
N = length(dims)
colons = ntuple(i -> Colon(), N)
coord = similar(A, N, dims...)
for i in 1:N
ones = ntuple(x -> 1, i - 1)
coord[i, colons...] .= reshape(1:dims[i], ones..., :)
end
return coord
end

function train(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 50)
if cuda && CUDA.has_cuda()
device = gpu
Expand Down Expand Up @@ -84,12 +100,18 @@ function train_gno(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 50)
@info "Training on CPU"
end

coord_dim = 2
edge_dim = 2(coord_dim + 1)
featured_graph = FeaturedGraph(grid([96, 64]))
model = Chain(Dense(1, 16),
WithGraph(featured_graph, GraphKernel(Dense(2 * 16, 16, gelu), 16)),
WithGraph(featured_graph, GraphKernel(Dense(2 * 16, 16, gelu), 16)),
WithGraph(featured_graph, GraphKernel(Dense(2 * 16, 16, gelu), 16)),
WithGraph(featured_graph, GraphKernel(Dense(2 * 16, 16, gelu), 16)),
WithGraph(featured_graph,
GraphKernel(Dense(edge_dim, abs2(16), gelu), 16)),
WithGraph(featured_graph,
GraphKernel(Dense(edge_dim, abs2(16), gelu), 16)),
WithGraph(featured_graph,
GraphKernel(Dense(edge_dim, abs2(16), gelu), 16)),
WithGraph(featured_graph,
GraphKernel(Dense(edge_dim, abs2(16), gelu), 16)),
Dense(16, 1))
data = get_dataloader(batchsize = 16, flatten = true)
optimiser = Flux.Optimiser(WeightDecay(λ), Flux.Adam(η₀))
Expand Down

0 comments on commit 1176c51

Please sign in to comment.