diff --git a/lib/axon.ex b/lib/axon.ex
index 865cb5e0..737d7ca3 100644
--- a/lib/axon.ex
+++ b/lib/axon.ex
@@ -746,10 +746,11 @@ defmodule Axon do
   """
   @doc type: :special
   def block(fun, opts \\ []) when is_function(fun) do
+    {:arity, arity} = Function.info(fun, :arity)
     opts = Keyword.validate!(opts, [:name, :meta])
     block_id = System.unique_integer([:positive, :monotonic])
 
-    fn inputs ->
+    block_fun(arity, fn inputs ->
       layer(:block, List.wrap(inputs),
         op_name: :block,
         name: opts[:name],
@@ -757,6 +758,14 @@ defmodule Axon do
         block_fun: fun,
         block_id: block_id
       )
+    end)
+  end
+
+  for i <- 0..128 do
+    args = Macro.generate_arguments(i, __MODULE__)
+
+    defp block_fun(unquote(i), callback) do
+      fn unquote_splicing(args) -> callback.(unquote(args)) end
     end
   end
 
diff --git a/lib/axon/compiler.ex b/lib/axon/compiler.ex
index 372486b3..6c25e35c 100644
--- a/lib/axon/compiler.ex
+++ b/lib/axon/compiler.ex
@@ -606,7 +606,7 @@ defmodule Axon.Compiler do
         %Axon.Node{
           id: id,
           op: :block,
-          parent: [parent],
+          parent: parents,
           opts: [block_fun: block_fun, block_id: block_id],
           name: name_fn
         },
@@ -614,9 +614,9 @@ defmodule Axon.Compiler do
         cache_and_counts,
         config
       ) do
-    {[parent_id], {cache, op_counts, block_cache, model_state_meta}} =
+    {parent_ids, {cache, op_counts, block_cache, model_state_meta}} =
       Enum.map_reduce(
-        [parent],
+        parents,
         cache_and_counts,
         &to_model_funs(&1, nodes, &2, config)
       )
@@ -627,7 +628,8 @@ defmodule Axon.Compiler do
             {funs, name, block_cache, op_counts}
 
           %{} ->
-            funs = build(block_fun.(Axon.input("subgraph")), debug?: config.debug?)
+            inputs = Enum.with_index(parents, fn _, i -> Axon.input("subgraph#{i}") end)
+            funs = build(apply(block_fun, inputs), debug?: config.debug?)
             name = name_fn.(:block, op_counts)
             op_counts = Map.update(op_counts, :block, 1, fn x -> x + 1 end)
             {funs, name, Map.put(block_cache, block_id, {funs, name}), op_counts}
@@ -637,9 +638,9 @@ defmodule Axon.Compiler do
       # Recurse graph inputs and invoke cache to get parent results,
       # state, and result_cache and then apply dtype policy and hooks
      # to each input
-      {[layer_input], {state, result_cache, none?}} =
+      {layer_inputs, {state, result_cache, none?}} =
        Enum.map_reduce(
-          [parent_id],
+          parent_ids,
          {state, result_cache, false},
          fn parent_id, {state, result_cache, none?} ->
            {layer_input, {state, result_cache}} =
@@ -663,7 +664,13 @@ defmodule Axon.Compiler do
           {%Axon.None{}, {state, result_cache}}
         else
           block_params = params[block_name] || %{}
-          result = apply(block_predict_fun, [Axon.ModelState.new(block_params), layer_input])
+
+          inputs =
+            layer_inputs
+            |> Enum.with_index()
+            |> Map.new(fn {input, i} -> {"subgraph#{i}", input} end)
+
+          result = apply(block_predict_fun, [Axon.ModelState.new(block_params), inputs])
 
           {out_result, out_state} =
             case result do
@@ -685,8 +692,8 @@ defmodule Axon.Compiler do
     end
 
     init_fun = fn template, cache, result_cache, fn_stacktrace, keys ->
-      {[parent_shape], {parent_params, result_cache, none?}} =
-        Enum.map_reduce([parent_id], {%{}, result_cache, false}, fn
+      {parent_shapes, {parent_params, result_cache, none?}} =
+        Enum.map_reduce(parent_ids, {%{}, result_cache, false}, fn
          parent_id, {params, result_cache, none?} ->
            {parent_shape, {params, result_cache}} =
              call_init_cache(
@@ -706,8 +713,12 @@ defmodule Axon.Compiler do
         if none? do
           {%Axon.None{}, {parent_params, result_cache}}
         else
-          template = Nx.broadcast(0.0, parent_shape)
-          block_params = apply(block_init_fun, [template, Axon.ModelState.empty()])
+          templates =
+            parent_shapes
+            |> Enum.with_index()
+            |> Map.new(fn {shape, i} -> {"subgraph#{i}", Nx.broadcast(0.0, shape)} end)
+
+          block_params = apply(block_init_fun, [templates, Axon.ModelState.empty()])
 
           params =
             if block_params == %{} do
diff --git a/test/axon/compiler_test.exs b/test/axon/compiler_test.exs
index a22a49da..11fc9108 100644
--- a/test/axon/compiler_test.exs
+++ b/test/axon/compiler_test.exs
@@ -5330,6 +5330,43 @@ defmodule CompilerTest do
       input = random({1, 1})
       assert_equal(predict_fn.(params, input), actual_predict_fn.(input, k, b))
     end
+
+    test "works with multiple block inputs" do
+      block =
+        Axon.block(fn x, y ->
+          dense = Axon.block(&Axon.dense(&1, 4))
+          Axon.add(dense.(y), dense.(x))
+        end)
+
+      input1 = Axon.input("input1")
+      input2 = Axon.input("input2")
+
+      model = block.(input1, input2) |> Axon.dense(1)
+
+      {init_fn, predict_fn} = Axon.build(model)
+
+      actual_predict_fn = fn %{"input1" => x, "input2" => y}, k1, b1, k2, b2 ->
+        x = Axon.Layers.dense(x, k1, b1)
+        y = Axon.Layers.dense(y, k1, b1)
+
+        x
+        |> Nx.add(y)
+        |> Axon.Layers.dense(k2, b2)
+      end
+
+      input = %{"input1" => Nx.tensor([[0.5]]), "input2" => Nx.tensor([[0.75]])}
+
+      assert %ModelState{
+               data: %{
+                 "block_0" => %{
+                   "block_0" => %{"dense_0" => %{"kernel" => k1, "bias" => b1}}
+                 },
+                 "dense_0" => %{"kernel" => k2, "bias" => b2}
+               }
+             } = params = init_fn.(input, ModelState.empty())
+
+      assert_equal(predict_fn.(params, input), actual_predict_fn.(input, k1, b1, k2, b2))
+    end
   end
 
   describe "initializers" do