From e82b9aca6d6736a92f8b085c8afb53e06fd33854 Mon Sep 17 00:00:00 2001
From: Sean Moriarity
Date: Tue, 14 May 2024 10:02:36 -0400
Subject: [PATCH 1/4] Support blocks with multiple inputs

---
 lib/axon.ex                 | 29 ++++++++++++++++++-----------
 lib/axon/compiler.ex        | 33 ++++++++++++++++++++++-----------
 test/axon/compiler_test.exs | 37 +++++++++++++++++++++++++++++++++++++
 3 files changed, 77 insertions(+), 22 deletions(-)

diff --git a/lib/axon.ex b/lib/axon.ex
index 865cb5e0..7d1d97ee 100644
--- a/lib/axon.ex
+++ b/lib/axon.ex
@@ -745,18 +745,25 @@ defmodule Axon do
   generated.
   """
   @doc type: :special
-  def block(fun, opts \\ []) when is_function(fun) do
-    opts = Keyword.validate!(opts, [:name, :meta])
-    block_id = System.unique_integer([:positive, :monotonic])
+  def block(fun, opts \\ [])
 
-    fn inputs ->
-      layer(:block, List.wrap(inputs),
-        op_name: :block,
-        name: opts[:name],
-        meta: opts[:meta],
-        block_fun: fun,
-        block_id: block_id
-      )
+  for i <- 1..128 do
+    args = Macro.generate_arguments(i, __MODULE__)
+
+    @doc false
+    def block(fun, opts) when is_function(fun, unquote(i)) do
+      opts = Keyword.validate!(opts, [:name, :meta])
+      block_id = System.unique_integer([:positive, :monotonic])
+
+      fn unquote_splicing(args) ->
+        layer(:block, List.wrap(unquote(args)),
+          op_name: :block,
+          name: opts[:name],
+          meta: opts[:meta],
+          block_fun: fun,
+          block_id: block_id
+        )
+      end
     end
   end
 
diff --git a/lib/axon/compiler.ex b/lib/axon/compiler.ex
index 372486b3..6c25e35c 100644
--- a/lib/axon/compiler.ex
+++ b/lib/axon/compiler.ex
@@ -606,7 +606,7 @@ defmodule Axon.Compiler do
          %Axon.Node{
            id: id,
            op: :block,
-           parent: [parent],
+           parent: parents,
            opts: [block_fun: block_fun, block_id: block_id],
            name: name_fn
          },
@@ -614,9 +614,9 @@
          cache_and_counts,
          config
        ) do
-    {[parent_id], {cache, op_counts, block_cache, model_state_meta}} =
+    {parent_ids, {cache, op_counts, block_cache, model_state_meta}} =
       Enum.map_reduce(
-        [parent],
+        parents,
         cache_and_counts,
         &to_model_funs(&1, nodes, &2, config)
       )
@@ -627,7 +627,8 @@
           {funs, name, block_cache, op_counts}
 
         %{} ->
-          funs = build(block_fun.(Axon.input("subgraph")), debug?: config.debug?)
+          inputs = Enum.with_index(parents, fn _, i -> Axon.input("subgraph#{i}") end)
+          funs = build(apply(block_fun, inputs), debug?: config.debug?)
          name = name_fn.(:block, op_counts)
          op_counts = Map.update(op_counts, :block, 1, fn x -> x + 1 end)
          {funs, name, Map.put(block_cache, block_id, {funs, name}), op_counts}
@@ -637,9 +638,9 @@
      # Recurse graph inputs and invoke cache to get parent results,
      # state, and result_cache and then apply dtype policy and hooks
      # to each input
-      {[layer_input], {state, result_cache, none?}} =
+      {layer_inputs, {state, result_cache, none?}} =
        Enum.map_reduce(
-          [parent_id],
+          parent_ids,
          {state, result_cache, false},
          fn parent_id, {state, result_cache, none?} ->
            {layer_input, {state, result_cache}} =
@@ -663,7 +664,13 @@
          {%Axon.None{}, {state, result_cache}}
        else
          block_params = params[block_name] || %{}
-          result = apply(block_predict_fun, [Axon.ModelState.new(block_params), layer_input])
+
+          inputs =
+            layer_inputs
+            |> Enum.with_index()
+            |> Map.new(fn {input, i} -> {"subgraph#{i}", input} end)
+
+          result = apply(block_predict_fun, [Axon.ModelState.new(block_params), inputs])
 
          {out_result, out_state} =
            case result do
@@ -685,8 +692,8 @@
    end
 
    init_fun = fn template, cache, result_cache, fn_stacktrace, keys ->
-      {[parent_shape], {parent_params, result_cache, none?}} =
-        Enum.map_reduce([parent_id], {%{}, result_cache, false}, fn
+      {parent_shapes, {parent_params, result_cache, none?}} =
+        Enum.map_reduce(parent_ids, {%{}, result_cache, false}, fn
          parent_id, {params, result_cache, none?} ->
            {parent_shape, {params, result_cache}} =
              call_init_cache(
@@ -706,8 +713,12 @@
      if none? do
        {%Axon.None{}, {parent_params, result_cache}}
      else
-        template = Nx.broadcast(0.0, parent_shape)
-        block_params = apply(block_init_fun, [template, Axon.ModelState.empty()])
+        templates =
+          parent_shapes
+          |> Enum.with_index()
+          |> Map.new(fn {shape, i} -> {"subgraph#{i}", Nx.broadcast(0.0, shape)} end)
+
+        block_params = apply(block_init_fun, [templates, Axon.ModelState.empty()])
 
        params =
          if block_params == %{} do
diff --git a/test/axon/compiler_test.exs b/test/axon/compiler_test.exs
index a22a49da..11fc9108 100644
--- a/test/axon/compiler_test.exs
+++ b/test/axon/compiler_test.exs
@@ -5330,6 +5330,43 @@ defmodule CompilerTest do
      input = random({1, 1})
      assert_equal(predict_fn.(params, input), actual_predict_fn.(input, k, b))
    end
+
+    test "works with multiple block inputs" do
+      block =
+        Axon.block(fn x, y ->
+          dense = Axon.block(&Axon.dense(&1, 4))
+          Axon.add(dense.(y), dense.(x))
+        end)
+
+      input1 = Axon.input("input1")
+      input2 = Axon.input("input2")
+
+      model = block.(input1, input2) |> Axon.dense(1)
+
+      {init_fn, predict_fn} = Axon.build(model)
+
+      actual_predict_fn = fn %{"input1" => x, "input2" => y}, k1, b1, k2, b2 ->
+        x = Axon.Layers.dense(x, k1, b1)
+        y = Axon.Layers.dense(y, k1, b1)
+
+        x
+        |> Nx.add(y)
+        |> Axon.Layers.dense(k2, b2)
+      end
+
+      input = %{"input1" => Nx.tensor([[0.5]]), "input2" => Nx.tensor([[0.75]])}
+
+      assert %ModelState{
+               data: %{
+                 "block_0" => %{
+                   "block_0" => %{"dense_0" => %{"kernel" => k1, "bias" => b1}}
+                 },
+                 "dense_0" => %{"kernel" => k2, "bias" => b2}
+               }
+             } = params = init_fn.(input, ModelState.empty())
+
+      assert_equal(predict_fn.(params, input), actual_predict_fn.(input, k1, b1, k2, b2))
+    end
  end
 
  describe "initializers" do

From 8e7f7b9fdc431a4449defab83340a8c8f85b790d Mon Sep 17 00:00:00 2001
From: Sean Moriarity
Date: Tue, 14 May 2024 10:12:09 -0400
Subject: [PATCH 2/4] Do not duplicate so much code

---
 lib/axon.ex | 35 +++++++++++++++++++----------------
 1 file changed, 19 insertions(+), 16 deletions(-)

diff --git a/lib/axon.ex b/lib/axon.ex
index 7d1d97ee..995650c8 100644
--- a/lib/axon.ex
+++ b/lib/axon.ex
@@ -745,25 +745,28 @@ defmodule Axon do
   generated.
   """
   @doc type: :special
-  def block(fun, opts \\ [])
+  def block(fun, opts \\ []) when is_function(fun) do
+    {:arity, arity} = Function.info(fun, :arity)
+    opts = Keyword.validate!(opts, [:name, :meta])
+    block_id = System.unique_integer([:positive, :monotonic])
 
-  for i <- 1..128 do
-    args = Macro.generate_arguments(i, __MODULE__)
+    block_fun(arity, fn inputs ->
+      layer(:block, List.wrap(inputs),
+        op_name: :block,
+        name: opts[:name],
+        meta: opts[:meta],
+        block_fun: fun,
+        block_id: block_id
+      )
+    end)
+  end
 
-    @doc false
-    def block(fun, opts) when is_function(fun, unquote(i)) do
-      opts = Keyword.validate!(opts, [:name, :meta])
-      block_id = System.unique_integer([:positive, :monotonic])
+  @doc false
+  for i <- 0..128 do
+    args = Macro.generate_arguments(i, __MODULE__)
 
-      fn unquote_splicing(args) ->
-        layer(:block, List.wrap(unquote(args)),
-          op_name: :block,
-          name: opts[:name],
-          meta: opts[:meta],
-          block_fun: fun,
-          block_id: block_id
-        )
-      end
+    def block_fun(unquote(i), callback) do
+      fn unquote_splicing(args) -> callback.(unquote(args)) end
    end
  end

From 977a4770c39f5e9cdfa9e9c53e525920f24b6d86 Mon Sep 17 00:00:00 2001
From: Sean Moriarity
Date: Tue, 14 May 2024 10:14:14 -0400
Subject: [PATCH 3/4] Make block fun private

---
 lib/axon.ex | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lib/axon.ex b/lib/axon.ex
index 995650c8..4173f6c7 100644
--- a/lib/axon.ex
+++ b/lib/axon.ex
@@ -765,7 +765,7 @@ defmodule Axon do
  for i <- 0..128 do
    args = Macro.generate_arguments(i, __MODULE__)
 
-    def block_fun(unquote(i), callback) do
+    defp block_fun(unquote(i), callback) do
      fn unquote_splicing(args) -> callback.(unquote(args)) end
    end
  end

From 893c38cbf97836960646265589839125b9d11f32 Mon Sep 17 00:00:00 2001
From: Sean Moriarity
Date: Tue, 14 May 2024 10:14:28 -0400
Subject: [PATCH 4/4] Remove doc false

---
 lib/axon.ex | 1 -
 1 file changed, 1 deletion(-)

diff --git a/lib/axon.ex b/lib/axon.ex
index 4173f6c7..737d7ca3 100644
--- a/lib/axon.ex
+++ b/lib/axon.ex
@@ -761,7 +761,6 @@ defmodule Axon do
    end)
  end
 
-  @doc false
  for i <- 0..128 do
    args = Macro.generate_arguments(i, __MODULE__)
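
Usage note (illustrative, not part of the patch series): the sketch below shows how the multi-arity block API introduced in PATCH 1/4 would be used, closely mirroring the new compiler test. The input names, shapes, and layer sizes here are made up for the example.

# Because the anonymous function takes two arguments, Axon.block/2 returns a
# block that is called with two Axon graphs. The inner block shares a single
# dense layer across both inputs.
shared = Axon.block(fn left, right ->
  dense = Axon.block(&Axon.dense(&1, 4))
  Axon.add(dense.(left), dense.(right))
end)

left = Axon.input("left")
right = Axon.input("right")
model = shared.(left, right) |> Axon.dense(1)

{init_fn, predict_fn} = Axon.build(model)

# Shapes are inferred from the data passed at init time.
inputs = %{"left" => Nx.tensor([[1.0, 2.0]]), "right" => Nx.tensor([[3.0, 4.0]])}
params = init_fn.(inputs, Axon.ModelState.empty())
output = predict_fn.(params, inputs)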