Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DConv layer #441

Merged
merged 15 commits into from
Jul 18, 2024
Merged

Add DConv layer #441

merged 15 commits into from
Jul 18, 2024

Conversation

aurorarossi
Copy link
Member

This PR adds the Diffusion Convolutional Layer of the following paper https://arxiv.org/pdf/1707.01926

@aurorarossi aurorarossi marked this pull request as ready for review June 6, 2024 10:37
src/layers/conv.jl Outdated Show resolved Hide resolved
h = l.weights[1,1,:,:] * x .+ l.weights[2,1,:,:] * x

T0 = x
if l.K > 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For K=1 this layer is the identity, is that correct?

end

function (l::DConv)(g::GNNGraph, x::AbstractMatrix)
A = adjacency_matrix(g, weighted = true)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of materializing the adjacency matrix, which currently gives a dense matrix on gpu therefore inconvenient for large graphs, the operations should be expressed through the propagate framework that relies on gather/scatter.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have not found a way to use propagate on the transpose adjacency matrix. Should I implement the Graphs.reverse function for GNNGraph to do this?(but probably will not be GPU compatible)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I manage to create the reverse of the graph and it is GPU compatible.

Co-authored-by: Carlo Lucibello <[email protected]>
@aurorarossi aurorarossi marked this pull request as draft June 20, 2024 10:37
@aurorarossi aurorarossi marked this pull request as ready for review July 14, 2024 18:11
src/layers/conv.jl Outdated Show resolved Hide resolved
src/layers/conv.jl Outdated Show resolved Hide resolved
src/layers/conv.jl Outdated Show resolved Hide resolved
aurorarossi and others added 3 commits July 16, 2024 10:08
Co-authored-by: Carlo Lucibello <[email protected]>
Co-authored-by: Carlo Lucibello <[email protected]>
Co-authored-by: Carlo Lucibello <[email protected]>
@CarloLucibello CarloLucibello merged commit df56b7e into JuliaGraphs:master Jul 18, 2024
3 of 5 checks passed
@aurorarossi aurorarossi deleted the add-DConv branch July 18, 2024 12:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants