Skip to content

Commit

Permalink
Merge pull request #4 from EBothereau/patch-2
Browse files Browse the repository at this point in the history
Functors for back propagation KAN.jl
  • Loading branch information
rbSparky authored May 19, 2024
2 parents 6958c05 + 11579e6 commit 694feca
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/KAN.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module KAN
export KANLinear, KAN, update_grid!, regularization_loss

using Flux
using Flux: sigmoid
using Statistics
using LinearAlgebra
Expand Down Expand Up @@ -49,6 +49,8 @@ module KAN
enable_standalone_scale_spline::Bool
end

Flux.@functor KANLinear base_weight,spline_weight

function KANLinear(in_features, out_features; grid_size=5, spline_order=3, scale_noise=0.1, scale_base=1.0,
scale_spline=1.0, enable_standalone_scale_spline=true, base_activation=sigmoid, grid_eps=0.02, grid_range=(-1, 1))
h = (grid_range[2] - grid_range[1]) / grid_size
Expand Down Expand Up @@ -89,6 +91,8 @@ module KAN
layers::Vector{KANLinear}
end

Flux.@functor KAN

function KAN(layers_hidden; grid_size=5, spline_order=3, scale_noise=0.1, scale_base=1.0, scale_spline=1.0, base_activation=sigmoid, grid_eps=0.02, grid_range=(-1, 1))
layers = [KANLinear(inf, outf; grid_size=grid_size, spline_order=spline_order, scale_noise=scale_noise, scale_base=scale_base, scale_spline=scale_spline, base_activation=base_activation, grid_eps=grid_eps, grid_range=grid_range) for (inf, outf) in zip(layers_hidden[1:end-1], layers_hidden[2:end])]
KAN(layers)
Expand Down

0 comments on commit 694feca

Please sign in to comment.