KAG with learnable grids is added (experimental)
cometscome committed Jun 3, 2024
1 parent 8845d65 commit 5b37090
Showing 3 changed files with 155 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/FluxKAN.jl
@@ -7,6 +7,7 @@ using ChainRulesCore
include("./KALnet.jl")
include("./KACnet.jl")
include("./KAGnet.jl")
include("./KAGLnet.jl")
include("./examples.jl")

end
142 changes: 142 additions & 0 deletions src/KAGLnet.jl
@@ -0,0 +1,142 @@



mutable struct Radial_distribution_function_L
    grids::Vector{Float64}
    denominator::Float64
    num_grids::Int64
    grid_max::Float64
    grid_min::Float64
end

"""
    KAGLnet

Gaussian version of the KAN layer with learnable grids.
"""
mutable struct KAGLnet{in_dim,out_dim,num_grids}
    base_weight
    poly_weight
    layer_norm
    base_activation
    in_dim::Int64
    out_dim::Int64
    num_grids::Int64
    rdf::Radial_distribution_function_L
end



function Radial_distribution_function_L(num_grids, grid_min, grid_max)
    grids = range(grid_min, grid_max, length=num_grids)
    # Grid spacing; also used as the width of each Gaussian basis function.
    denominator = (grid_max - grid_min) / (num_grids - 1)
    return Radial_distribution_function_L(grids, denominator, num_grids, grid_max, grid_min)
end
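As a quick illustration (not part of the committed file) of what this constructor produces with the defaults used by `KAGLnet` further down (`num_grids=8`, `grid_min=-1`, `grid_max=1`):

```julia
# Illustrative only; mirrors the KAGLnet defaults below.
rdf = Radial_distribution_function_L(8, -1.0, 1.0)
rdf.grids        # 8 evenly spaced centers from -1.0 to 1.0
rdf.denominator  # (1 - (-1)) / (8 - 1) = 2/7 ≈ 0.286, the spacing, reused as the Gaussian width
```

Because `grids` is declared trainable in the `Flux.@layer` line below, these centers move during training; that is the "learnable grids" part of this commit.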
export Radial_distribution_function_L
Flux.@layer Radial_distribution_function_L trainable = (grids,)


function (m::Radial_distribution_function_L)(x)
    return rdf_foward_L(x, m.num_grids, m.grids, m.denominator)
end

function rdf_foward_L(x, num_grids, grids, denominator)
    # One Gaussian basis function per grid point:
    # y[n] = exp(-((x - grids[n]) / denominator)^2)
    y = []
    for n = 1:num_grids
        yn = exp.(-((x .- grids[n]) ./ denominator) .^ 2)
        push!(y, yn)
    end
    return y
end


function ChainRulesCore.rrule(::typeof(rdf_foward_L), x, num_grids, grids, denominator)
    y = []
    for n = 1:num_grids
        yn = exp.(-((x .- grids[n]) ./ denominator) .^ 2)
        push!(y, yn)
    end

    function pullback(ybar)
        sbar = NoTangent()

        # Cotangent with respect to the input x.
        dLdGdx = @thunk(begin
            dy = []
            for n = 1:num_grids
                # d y[n] / d x = -2 (x - grids[n]) / denominator^2 * y[n]
                dyn = (-2 .* (x .- grids[n]) ./ denominator^2) .* y[n]
                push!(dy, dyn)
            end
            acc = zero(x)
            for n = 1:length(ybar)
                acc .+= dy[n] .* ybar[n]
            end
            acc
        end)

        # Cotangent with respect to the grid positions (what makes the grids learnable).
        dLdGdg = @thunk(begin
            dy = []
            for n = 1:num_grids
                # d y[n] / d grids[n] = 2 (x - grids[n]) / denominator^2 * y[n]
                dyn = (2 .* (x .- grids[n]) ./ denominator^2) .* y[n]
                push!(dy, dyn)
            end
            acc = zero(grids)
            for n = 1:length(ybar)
                acc[n] = sum(dy[n] .* ybar[n])
            end
            acc
        end)

        # Tangents for (function, x, num_grids, grids, denominator).
        return sbar, dLdGdx, sbar, dLdGdg, sbar
    end
    return y, pullback
end
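A note on this hand-written rule: `rdf_foward_L` builds its output with `push!`, which Zygote does not differentiate through (array mutation), so the pullback is supplied explicitly. It returns cotangents only for `x` and `grids`, with `NoTangent()` for the function itself, `num_grids`, and `denominator`; the grid cotangent accumulates `sum((2 .* (x .- grids[n]) ./ denominator^2) .* y[n] .* ybar[n])` for each `n`, which is what lets the optimizer update the grid positions.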



function KAGLnet(in_dim, out_dim; num_grids=8, base_activation=SiLU, grid_max=1, grid_min=-1)
    base_weight = Dense(in_dim, out_dim; bias=false)
    poly_weight = Dense(in_dim * num_grids, out_dim; bias=false)
    if out_dim == 1
        layer_norm = Dense(out_dim, out_dim; bias=false)
    else
        layer_norm = LayerNorm(out_dim)
    end
    rdf = Radial_distribution_function_L(num_grids, grid_min, grid_max)
    return KAGLnet{in_dim,out_dim,num_grids}(base_weight,
        poly_weight, layer_norm, base_activation, in_dim, out_dim, num_grids, rdf)
end
function KAGLnet(base_weight, poly_weight, layer_norm, base_activation, in_dim, out_dim, num_grids, rdf)
    return KAGLnet{in_dim,out_dim,num_grids}(base_weight, poly_weight,
        layer_norm, base_activation,
        in_dim, out_dim, num_grids, rdf)
end

export KAGLnet
Flux.@layer KAGLnet

function (m::KAGLnet{in_dim,out_dim,num_grids})(x) where {in_dim,out_dim,num_grids}
    return KAGLnet_forward(x, m.base_weight, m.poly_weight, m.layer_norm, m.base_activation, m.rdf)
end


function KAGLnet_forward(x, base_weight, poly_weight, layer_norm, base_activation, rdf)
    # Apply the base activation to the input, then a linear transform with the base weights
    base_output = base_weight(base_activation.(x))
    # Normalize x to the range [-1, 1] so it matches the default grid range of the Gaussian basis
    xmin = minimum(x)
    xmax = maximum(x)
    dx = xmax - xmin
    x_normalized = normalize_x(x, xmin, dx)
    # Evaluate the Gaussian radial basis functions (one per grid point) on the normalized x
    basis_funcs = rdf(x_normalized)
    basis = cat(basis_funcs..., dims=1)
    # Project the stacked basis with the polynomial weights
    poly_output = poly_weight(basis)
    # Combine base and basis outputs, normalize, and activate
    y = base_activation.(layer_norm(base_output .+ poly_output))
    return y
end
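A minimal usage sketch, mirroring the new "GL" branch in the tests below; the data and hyperparameters here are made up for illustration:

```julia
using Flux
using FluxKAN

# Two stacked layers, as in test4("GL"): 2 inputs -> 10 hidden -> 1 output.
model = Chain(KAGLnet(2, 10), KAGLnet(10, 1))

# Toy regression data, illustrative only.
x = rand(Float32, 2, 64)
y = sum(abs2, x; dims=1)

opt_state = Flux.setup(Adam(), model)
for epoch in 1:100
    Flux.train!(model, [(x, y)], opt_state) do m, xb, yb
        Flux.Losses.mse(m(xb), yb)
    end
end
```

Flux.setup collects the Dense/LayerNorm parameters of each layer plus, thanks to the `trainable = (grids,)` declaration above, the Gaussian grid centers.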

13 changes: 12 additions & 1 deletion test/runtests.jl
@@ -63,6 +63,8 @@ function test3(method="L")
model = Chain(KACnet(k, 10), KACnet(10, 1))
elseif method == "G"
model = Chain(KAGnet(k, 10), KAGnet(10, 1))
elseif method == "GL"
model = Chain(KAGLnet(k, 10), KAGLnet(10, 1))
end
display(model)
#println("W = ", model[1].weight)
@@ -194,7 +196,10 @@ function test4(method="L")
model = Chain(KACnet(2, 10), KACnet(10, 1))
elseif method == "G"
model = Chain(KAGnet(2, 10), KAGnet(10, 1))
elseif method == "GL"
model = Chain(KAGLnet(2, 10), KAGLnet(10, 1))
end
display(model)

rule = Adam()
opt_state = Flux.setup(rule, model)
@@ -207,6 +212,7 @@ function test4(method="L")
j])[1] for i in x, j in y]'
#p = plot(x, y, [znn], st=:wireframe)
#savefig("dense.png")
display(model)

end
main(method)
@@ -218,18 +224,23 @@ end
test()
end
@testset "KAN" begin
# Write your tests here.
# Write your tests here.
test2()
#=
test3("L")
test3("C")
test3("G")
test3("GL")
println("test 4")
println("KAL")
test4("L")
println("KAC")
test4("C")
println("KAG")
test4("G")
=#
println("KAGL")
test4("GL")
end
end

