diff --git a/lib/construct.ex b/lib/construct.ex index 176b49b..f3ee98f 100644 --- a/lib/construct.ex +++ b/lib/construct.ex @@ -271,12 +271,14 @@ defmodule Construct do type = case Keyword.fetch(opts, :default) do - {:ok, nil} -> - quote do: unquote(type) | nil - - # TODO recognize type of default value - {:ok, _default} -> - type + {:ok, default} -> + typeof_default = Construct.Type.typeof(default) + + if type == typeof_default do + type + else + quote do: unquote(type) | unquote(typeof_default) + end :error -> type diff --git a/lib/construct/type.ex b/lib/construct/type.ex index 00af048..2d5b24c 100644 --- a/lib/construct/type.ex +++ b/lib/construct/type.ex @@ -376,6 +376,26 @@ defmodule Construct.Type do ## Typespecs + @doc """ + Returns typespec AST for given type + + iex> spec([CommaList, {:array, :integer}]) |> Macro.to_string() + "list(:integer)" + + iex> spec({:array, :string}) |> Macro.to_string() + "list(String.t())" + + iex> spec({:map, CustomType}) |> Macro.to_string() + "%{optional(term) => CustomType.t()}" + + iex> spec(:string) |> Macro.to_string() + "String.t()" + + iex> spec(CustomType) |> Macro.to_string() + "CustomType.t()" + """ + @spec spec(t) :: Macro.t() + def spec(type) when is_list(type) do type |> List.last() |> spec() end @@ -442,6 +462,84 @@ defmodule Construct.Type do type end + @doc """ + Returns typespec AST for given term + + iex> typeof(nil) |> Macro.to_string() + "nil" + + iex> typeof(1.42) |> Macro.to_string() + "float()" + + iex> typeof("string") |> Macro.to_string() + "String.t()" + + iex> typeof(CustomType) |> Macro.to_string() + "CustomType.t()" + + iex> typeof(&NaiveDateTime.utc_now/0) |> Macro.to_string() + "NaiveDateTime.t()" + """ + @spec spec(t) :: Macro.t() + + def typeof(term) when is_nil(term) do + nil + end + + def typeof(term) when is_integer(term) do + {:integer, [], []} + end + + def typeof(term) when is_float(term) do + {:float, [], []} + end + + def typeof(term) when is_boolean(term) do + {:boolean, [], []} + end + + def typeof(term) when is_binary(term) do + quote do + String.t() + end + end + + def typeof(term) when is_pid(term) do + {:pid, [], []} + end + + def typeof(term) when is_reference(term) do + {:reference, [], []} + end + + def typeof(%{__struct__: struct}) when is_atom(struct) do + quote do + unquote(struct).t() + end + end + + def typeof(term) when is_map(term) do + {:map, [], []} + end + + def typeof(term) when is_atom(term) do + quote do + unquote(term).t() + end + end + + def typeof(term) when is_list(term) do + {:list, [], []} + end + + def typeof(term) when is_function(term, 0) do + term.() |> typeof() + end + + def typeof(_) do + {:term, [], []} + end + ## Helpers defp validate_decimal({:ok, %{__struct__: Decimal, coef: coef}}) when coef in [:inf, :qNaN, :sNaN],