Skip to content

Commit

Permalink
Raise on ambiguous inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmor5 committed Oct 16, 2024
1 parent 4cc474b commit 5f281ba
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 9 deletions.
33 changes: 25 additions & 8 deletions lib/axon/compiler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -486,15 +486,16 @@ defmodule Axon.Compiler do
name: name_fn,
opts: [shape: _input_shape, optional: optional?]
},
_nodes,
nodes,
{cache, op_counts, block_cache, model_state_meta},
%{mode: mode, print_values: print_values}
) do
name = name_fn.(:input, op_counts)
op_counts = Map.update(op_counts, :input, 1, fn x -> x + 1 end)
all_inputs = get_all_inputs(nodes)

predict_fun = fn _params, inputs, state, _cache, result_cache, _fn_stacktrace ->
value = get_input(inputs, name, optional?)
value = get_input(all_inputs, inputs, name, optional?)

# TODO: Add this back in
# validate_input_shape!(value, shape)
Expand All @@ -509,7 +510,7 @@ defmodule Axon.Compiler do
end

init_fun = fn template, _cache, result_cache, _fn_stacktrace, _keys ->
input = get_input(template, name, optional?)
input = get_input(all_inputs, template, name, optional?)
{Nx.to_template(input), {%{}, result_cache}}
end

Expand Down Expand Up @@ -889,16 +890,32 @@ defmodule Axon.Compiler do
{id, model_funs, cache, op_counts, block_cache, model_state_meta}
end

defp get_input(inputs, name, optional?) do
defp get_all_inputs(nodes) do
nodes
|> Enum.filter(fn {_, %{op: op}} -> op == :input end)
|> Enum.map(fn {_, %{name: name_fn}} ->
# inputs require a name, so we can just ignore op counts
name_fn.(:input, %{})
end)
|> Enum.uniq()
end

defp get_input(all_input_names, inputs, name, optional?) do
res =
case inputs do
%Nx.Tensor{} = inputs ->
case {all_input_names, inputs} do
{[^name], %Nx.Tensor{} = inputs} ->
inputs

%{} = inputs ->
{_, %Nx.Tensor{}} ->
raise ArgumentError,
"ambiguous input given to the model," <>
" expected inputs with names #{inspect(all_input_names)}" <>
" but received a single tensor as input"

{_, %{} = inputs} ->
inputs[name]

inputs when is_tuple(inputs) ->
{[^name], inputs} when is_tuple(inputs) ->
inputs

_ ->
Expand Down
14 changes: 14 additions & 0 deletions test/axon/compiler_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,20 @@ defmodule CompilerTest do
assert message =~ "exception found when compiling layer Axon.Layers.add/2 named add_0"
assert message =~ "cannot broadcast tensor of dimensions {1, 32} to {1, 64}"
end

test "raises if inputs are ambiguous" do
x = Axon.input("x")
y = Axon.input("y")
model = Axon.add(x, y)

{_, predict_fn} = Axon.build(model)

exception = assert_raise ArgumentError, fn ->
predict_fn.(ModelState.empty(), Nx.tensor([1]))
end

assert Exception.message(exception) =~ "ambiguous"
end
end

describe "optional" do
Expand Down
6 changes: 5 additions & 1 deletion test/axon/loop_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,11 @@ defmodule Axon.LoopTest do
Loop.trainer(model, [mean_squared_error: 0.5, mean_absolute_error: 0.5], :adam)

assert %{model_state: %{}} =
pstate = init_fn.({Nx.tensor([[2]]), Nx.tensor([[2]])}, Axon.ModelState.empty())
pstate =
init_fn.(
{%{"input_0" => Nx.tensor([[2]]), "input_1" => Nx.tensor([[2]])}, Nx.tensor(0)},
Axon.ModelState.empty()
)

state = %State{step_state: pstate}

Expand Down

0 comments on commit 5f281ba

Please sign in to comment.