diff --git a/src/KAN.jl b/src/KAN.jl index 485a1d8..6bf18d1 100644 --- a/src/KAN.jl +++ b/src/KAN.jl @@ -52,7 +52,7 @@ module KAN function KANLinear(in_features, out_features; grid_size=5, spline_order=3, scale_noise=0.1, scale_base=1.0, scale_spline=1.0, enable_standalone_scale_spline=true, base_activation=sigmoid, grid_eps=0.02, grid_range=(-1, 1)) h = (grid_range[2] - grid_range[1]) / grid_size - grid = [i * h + grid_range[1] for i in -spline_order:grid_size+spline_order+1] + grid = [i * h + grid_range[1] for i in -spline_order:grid_size+spline_order] grid = reshape(collect(repeat(grid, in_features)), in_features, :) base_weight = randn(out_features, in_features) * sqrt(5) * scale_base