diff --git a/lib/axon/quantization.ex b/lib/axon/quantization.ex index b48d18ff..ed976b8d 100644 --- a/lib/axon/quantization.ex +++ b/lib/axon/quantization.ex @@ -132,7 +132,7 @@ defmodule Axon.Quantization do fun = case opts[:kernel_initializer] do init when is_atom(init) -> - apply(Axon.Initializers, []) + apply(Axon.Initializers, init, []) fun when is_function(fun) -> fun diff --git a/test/axon/quantization_test.exs b/test/axon/quantization_test.exs index 4a289ce0..3d728158 100644 --- a/test/axon/quantization_test.exs +++ b/test/axon/quantization_test.exs @@ -42,4 +42,18 @@ defmodule Axon.QuantizationTest do assert_equal(predict_fn.(quantized_model_state, inp), real_fn.(quantized_model_state, inp)) end end + + describe "weight_only_quantized_dense" do + test "inits and executes properly" do + model = + Axon.input("input") + |> Axon.Quantization.weight_only_quantized_dense(10) + + assert {init_fn, _} = Axon.build(model) + assert %ModelState{} = model_state = init_fn.(Nx.template({1, 1}, :f32), ModelState.empty()) + + assert {_, predict_fn} = Axon.build(model) + assert predict_fn.(model_state, Nx.broadcast(1.0, {1, 1})) + end + end end