Skip to content

Commit

Permalink
implement named tuples
Browse files Browse the repository at this point in the history
Based on #16580, also much work done by quinnj.

`(a=1, ...)` syntax is implemented, and `(; ...)` syntax is
implemented but not yet enabled.
  • Loading branch information
JeffBezanson committed Oct 17, 2017
1 parent 26c7ecb commit ab88bc9
Show file tree
Hide file tree
Showing 18 changed files with 520 additions and 38 deletions.
2 changes: 1 addition & 1 deletion base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ export
# key types
Any, DataType, Vararg, ANY, NTuple,
Tuple, Type, UnionAll, TypeName, TypeVar, Union, Void,
SimpleVector, AbstractArray, DenseArray,
SimpleVector, AbstractArray, DenseArray, NamedTuple,
# special objects
Function, CodeInfo, Method, MethodTable, TypeMapEntry, TypeMapLevel,
Module, Symbol, Task, Array, WeakRef, VecElement,
Expand Down
17 changes: 16 additions & 1 deletion base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,8 @@ end
const _Type_name = Type.body.name
isType(@nospecialize t) = isa(t, DataType) && (t::DataType).name === _Type_name

const _NamedTuple_name = NamedTuple.body.body.name

# true if Type is inlineable as constant (is a singleton)
function isconstType(@nospecialize t)
isType(t) || return false
Expand Down Expand Up @@ -725,6 +727,10 @@ function isdefined_tfunc(args...)
end
if 1 <= idx <= a1.ninitialized
return Const(true)
elseif a1.name === _NamedTuple_name
if isleaftype(a1)
return Const(false)
end
elseif idx <= 0 || (!isvatuple(a1) && idx > fieldcount(a1))
return Const(false)
elseif !isvatuple(a1) && isbits(fieldtype(a1, idx))
Expand Down Expand Up @@ -762,7 +768,9 @@ add_tfunc(nfields, 1, 1,
# TODO: remove with deprecation in builtins.c for nfields(::Type)
isleaftype(x.parameters[1]) && return Const(old_nfields(x.parameters[1]))
elseif isa(x,DataType) && !x.abstract && !(x.name === Tuple.name && isvatuple(x)) && x !== DataType
return Const(length(x.types))
if !(x.name === _NamedTuple_name && !isleaftype(x))
return Const(length(x.types))
end
end
return Int
end, 0)
Expand Down Expand Up @@ -1324,6 +1332,10 @@ function getfield_tfunc(@nospecialize(s00), @nospecialize(name))
end
return Any
end
if s.name === _NamedTuple_name && !isleaftype(s)
# TODO: better approximate inference
return Any
end
if isempty(s.types)
return Bottom
end
Expand Down Expand Up @@ -1407,6 +1419,9 @@ function fieldtype_tfunc(@nospecialize(s0), @nospecialize(name))
if !isa(u,DataType) || u.abstract
return Type
end
if u.name === _NamedTuple_name && !isleaftype(u)
return Type
end
ftypes = u.types
if isempty(ftypes)
return Bottom
Expand Down
163 changes: 163 additions & 0 deletions base/namedtuple.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

@generated function NamedTuple{names}(args...) where names
N = length(names)
if length(args) == N
Expr(:new, :(NamedTuple{names,$(Tuple{args...})}), Any[ :(args[$i]) for i in 1:N ]...)
else
:(throw(ArgumentError("wrong number of arguments to named tuple constructor")))
end
end

@generated function NamedTuple{names,T}(args...) where {names, T <: Tuple}
N = length(names)
types = T.parameters
if length(args) == N
Expr(:new, :(NamedTuple{names,T}), Any[ :(convert($(types[i]), args[$i])) for i in 1:N ]...)
else
:(throw(ArgumentError("wrong number of arguments to named tuple constructor")))
end
end

NamedTuple() = NamedTuple{()}()

length(t::NamedTuple) = nfields(t)
start(t::NamedTuple) = 1
done(t::NamedTuple, iter) = iter > nfields(t)
next(t::NamedTuple, iter) = (getfield(t, iter), iter + 1)
endof(t::NamedTuple) = nfields(t)
getindex(t::NamedTuple, i::Int) = getfield(t, i)
getindex(t::NamedTuple, i::Symbol) = getfield(t, i)

function getindex(t::NamedTuple, I::AbstractVector)
idxs = unique( Int[ isa(i, Symbol) ? fieldindex(typeof(t), i) : i for i in I ] )
names = keys(t)[idxs]
NamedTuple{names}([ getfield( t, i ) for i in idxs ]...)
end

convert(::Type{NamedTuple{names,T}}, nt::NamedTuple{names,T}) where {names,T} = nt
convert(::Type{NamedTuple{names}}, nt::NamedTuple{names}) where {names} = nt

function convert(::Type{NamedTuple{names,T}}, nt::NamedTuple{names}) where {names,T}
NamedTuple{names,T}(nt...)
end

function show(io::IO, t::NamedTuple)
n = nfields(t)
if n == 0
print(io, "NamedTuple()")
else
print(io, "(")
for i = 1:n
print(io, fieldname(typeof(t),i), " = "); show(io, getfield(t,i))
if n == 1
print(io, ",")
elseif i < n
print(io, ", ")
end
end
print(io, ")")
end
end

eltype(::Type{NamedTuple{names,T}}) where {names,T} = eltype(T)

==(a::NamedTuple{n}, b::NamedTuple{n}) where {n} = Tuple(a) == Tuple(b)
==(a::NamedTuple, b::NamedTuple) = false

isequal(a::NamedTuple{n}, b::NamedTuple{n}) where {n} = isequal(Tuple(a), Tuple(b))
isequal(a::NamedTuple, b::NamedTuple) = false

_nt_names(::NamedTuple{names}) where {names} = names
_nt_names(::Type{T}) where {names,T<:NamedTuple{names}} = names

hash(x::NamedTuple, h::UInt) = xor(object_id(_nt_names(x)), hash(Tuple(x), h))

isless(a::NamedTuple{n}, b::NamedTuple{n}) where {n} = isless(Tuple(a), Tuple(b))
# TODO: case where one argument's names are a prefix of the other's

function map(f, nt::NamedTuple, nts::NamedTuple...)
# this method makes sure we don't define a map(f) method
_nt_map(f, nt, nts...)
end

@generated function _nt_map(f, nts::NamedTuple...)
fields = _nt_names(nts[1])
for x in nts[2:end]
if _nt_names(x) != fields
throw(ArgumentError("All NamedTuple arguments to map must have the same fields in the same order"))
end
end
N = fieldcount(nts[1])
M = length(nts)

NT = NamedTuple{fields}
args = Expr[:(f($(Expr[:(getfield(nts[$i], $j)) for i = 1:M]...))) for j = 1:N]
:( $NT($(args...)) )
end

# a version of `in` for the older world these generated functions run in
function sym_in(x, itr)
for y in itr
y === x && return true
end
return false
end

"""
merge(a::NamedTuple, b::NamedTuple)
Construct a new named tuple by merging two existing ones.
The order of fields in `a` is preserved, but values are taken from matching
fields in `b`. Fields present only in `b` are appended at the end.
```jldoctest
julia> merge((a=1, b=2, c=3), (b=4, d=5))
(a = 1, b = 4, c = 3, d = 5)
```
"""
@generated function merge(a::NamedTuple{an}, b::NamedTuple{bn}) where {an, bn}
names = Symbol[an...]
for n in bn
if !sym_in(n, an)
push!(names, n)
end
end
vals = map(names) do n
if sym_in(n, bn)
:(getfield(b, $(Expr(:quote, n))))
else
:(getfield(a, $(Expr(:quote, n))))
end
end
names = (names...,)
:( NamedTuple{$names}($(vals...)) )
end

merge(a::NamedTuple{()}, b::NamedTuple) = b

"""
merge(a::NamedTuple, iterable)
Interpret an iterable of key-value pairs as a named tuple, and perform a merge.
```jldoctest
julia> merge((a=1, b=2, c=3), [:b=>4, :d=>5])
(a = 1, b = 4, c = 3, d = 5)
```
"""
function merge(a::NamedTuple, itr)
names = Symbol[]
vals = Any[]
for (k,v) in itr
push!(names, k)
push!(vals, v)
end
merge(a, NamedTuple{(names...)}(vals...))
end

keys(nt::NamedTuple{names}) where {names} = names
values(nt::NamedTuple) = Tuple(nt)
haskey(nt::NamedTuple, key::Union{Integer, Symbol}) = isdefined(nt, key)
get(nt::NamedTuple, key::Union{Integer, Symbol}, default) = haskey(nt, key) ? getfield(nt, key) : default
get(f::Callable, nt::NamedTuple, key::Union{Integer, Symbol}) = haskey(nt, key) ? getfield(nt, key) : f()
6 changes: 4 additions & 2 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,14 @@ julia> fieldname(SparseMatrixCSC, 5)
```
"""
function fieldname(t::DataType, i::Integer)
n_fields = length(t.name.names)
names = isdefined(t, :names) ? t.names : t.name.names
n_fields = length(names)
field_label = n_fields == 1 ? "field" : "fields"
i > n_fields && throw(ArgumentError("Cannot access field $i since type $t only has $n_fields $field_label."))
i < 1 && throw(ArgumentError("Field numbers must be positive integers. $i is invalid."))
return t.name.names[i]::Symbol
return names[i]::Symbol
end

fieldname(t::UnionAll, i::Integer) = fieldname(unwrap_unionall(t), i)
fieldname(t::Type{<:Tuple}, i::Integer) =
i < 1 || i > fieldcount(t) ? throw(BoundsError(t, i)) : Int(i)
Expand Down
2 changes: 2 additions & 0 deletions base/sysimg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ Vector(m::Integer) = Array{Any,1}(Int(m))
Matrix{T}(m::Integer, n::Integer) where {T} = Matrix{T}(Int(m), Int(n))
Matrix(m::Integer, n::Integer) = Matrix{Any}(Int(m), Int(n))

include("namedtuple.jl")

# numeric operations
include("hashing.jl")
include("rounding.jl")
Expand Down
1 change: 1 addition & 0 deletions src/ast.scm
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@
;; predicates and accessors

(define (quoted? e) (memq (car e) '(quote top core globalref outerref line break inert meta)))
(define (quotify e) `',e)

(define (lam:args x) (cadr x))
(define (lam:vars x) (llist-vars (lam:args x)))
Expand Down
1 change: 1 addition & 0 deletions src/builtins.c
Original file line number Diff line number Diff line change
Expand Up @@ -1295,6 +1295,7 @@ void jl_init_primitives(void)
add_builtin("QuoteNode", (jl_value_t*)jl_quotenode_type);
add_builtin("NewvarNode", (jl_value_t*)jl_newvarnode_type);
add_builtin("GlobalRef", (jl_value_t*)jl_globalref_type);
add_builtin("NamedTuple", (jl_value_t*)jl_namedtuple_type);

add_builtin("Bool", (jl_value_t*)jl_bool_type);
add_builtin("UInt8", (jl_value_t*)jl_uint8_type);
Expand Down
3 changes: 2 additions & 1 deletion src/datatype.c
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ jl_datatype_t *jl_new_uninitialized_datatype(void)
t->hasfreetypevars = 0;
t->isleaftype = 1;
t->layout = NULL;
t->names = NULL;
return t;
}

Expand Down Expand Up @@ -288,7 +289,7 @@ void jl_compute_field_offsets(jl_datatype_t *st)
return;
}
}
if (st->types == NULL)
if (st->types == NULL || (jl_is_namedtuple_type(st) && !jl_is_leaf_type((jl_value_t*)st)))
return;
uint32_t nfields = jl_svec_len(st->types);
if (nfields == 0) {
Expand Down
9 changes: 6 additions & 3 deletions src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ static void jl_serialize_datatype(jl_serializer_state *s, jl_datatype_t *dt)
if (has_instance)
jl_serialize_value(s, dt->instance);
jl_serialize_value(s, dt->name);
jl_serialize_value(s, dt->names);
jl_serialize_value(s, dt->parameters);
jl_serialize_value(s, dt->super);
jl_serialize_value(s, dt->types);
Expand Down Expand Up @@ -1245,6 +1246,8 @@ static jl_value_t *jl_deserialize_datatype(jl_serializer_state *s, int pos, jl_v
}
dt->name = (jl_typename_t*)jl_deserialize_value(s, (jl_value_t**)&dt->name);
jl_gc_wb(dt, dt->name);
dt->names = (jl_svec_t*)jl_deserialize_value(s, (jl_value_t**)&dt->names);
jl_gc_wb(dt, dt->names);
dt->parameters = (jl_svec_t*)jl_deserialize_value(s, (jl_value_t**)&dt->parameters);
jl_gc_wb(dt, dt->parameters);
dt->super = (jl_datatype_t*)jl_deserialize_value(s, (jl_value_t**)&dt->super);
Expand Down Expand Up @@ -2803,7 +2806,6 @@ void jl_init_serializer(void)
jl_box_int32(30), jl_box_int32(31), jl_box_int32(32),
#ifndef _P64
jl_box_int32(33), jl_box_int32(34), jl_box_int32(35),
jl_box_int32(36), jl_box_int32(37),
#endif
jl_box_int64(0), jl_box_int64(1), jl_box_int64(2),
jl_box_int64(3), jl_box_int64(4), jl_box_int64(5),
Expand All @@ -2818,7 +2820,6 @@ void jl_init_serializer(void)
jl_box_int64(30), jl_box_int64(31), jl_box_int64(32),
#ifdef _P64
jl_box_int64(33), jl_box_int64(34), jl_box_int64(35),
jl_box_int64(36), jl_box_int64(37),
#endif
jl_labelnode_type, jl_linenumbernode_type, jl_gotonode_type,
jl_quotenode_type, jl_type_type, jl_bottom_type, jl_ref_type,
Expand All @@ -2844,7 +2845,8 @@ void jl_init_serializer(void)
jl_intrinsic_type->name, jl_task_type->name, jl_labelnode_type->name,
jl_linenumbernode_type->name, jl_builtin_type->name, jl_gotonode_type->name,
jl_quotenode_type->name, jl_globalref_type->name, jl_typeofbottom_type->name,
jl_string_type->name, jl_abstractstring_type->name,
jl_string_type->name, jl_abstractstring_type->name, jl_namedtuple_type,
jl_namedtuple_typename,

ptls->root_task,

Expand Down Expand Up @@ -2882,6 +2884,7 @@ void jl_init_serializer(void)
arraylist_push(&builtin_typenames, ((jl_datatype_t*)jl_unwrap_unionall((jl_value_t*)jl_densearray_type))->name);
arraylist_push(&builtin_typenames, jl_tuple_typename);
arraylist_push(&builtin_typenames, jl_vararg_typename);
arraylist_push(&builtin_typenames, jl_namedtuple_typename);
}

#ifdef __cplusplus
Expand Down
2 changes: 1 addition & 1 deletion src/interpreter.c
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ void jl_set_datatype_super(jl_datatype_t *tt, jl_value_t *super)
if (!jl_is_datatype(super) || !jl_is_abstracttype(super) ||
tt->name == ((jl_datatype_t*)super)->name ||
jl_subtype(super,(jl_value_t*)jl_vararg_type) ||
jl_is_tuple_type(super) ||
jl_is_tuple_type(super) || jl_is_namedtuple_type(super) ||
jl_subtype(super,(jl_value_t*)jl_type_type) ||
super == (jl_value_t*)jl_builtin_type) {
jl_errorf("invalid subtyping in definition of %s",
Expand Down
Loading

0 comments on commit ab88bc9

Please sign in to comment.