diff --git a/src/type_inf.jl b/src/type_inf.jl index 6080dc8c..5fd052c7 100644 --- a/src/type_inf.jl +++ b/src/type_inf.jl @@ -33,6 +33,8 @@ function infer_type(binding::Binding, scope, state) end function infer_type_assignment_rhs(binding, state, scope) + is_destructuring = false + lhs = binding.val.args[1] rhs = binding.val.args[2] if is_loop_iter_assignment(binding.val) settype!(binding, infer_eltype(rhs)) @@ -43,13 +45,24 @@ function infer_type_assignment_rhs(binding, state, scope) end else if CSTParser.is_func_call(rhs) + if CSTParser.istuple(lhs) + if CSTParser.isparameters(lhs.args[1]) + is_destructuring = true + else + return + end + end callname = CSTParser.get_name(rhs) if isidentifier(callname) resolve_ref(callname, scope, state) if hasref(callname) rb = get_root_method(refof(callname), state.server) if (rb isa Binding && (CoreTypes.isdatatype(rb.type) || rb.val isa SymbolServer.DataTypeStore)) || rb isa SymbolServer.DataTypeStore - settype!(binding, rb) + if is_destructuring + infer_destructuring_type(binding, rb) + else + settype!(binding, rb) + end end end end @@ -94,6 +107,26 @@ function infer_type_assignment_rhs(binding, state, scope) end end +function infer_destructuring_type(binding, rb::SymbolServer.DataTypeStore) + assigned_name = CSTParser.get_name(binding.val) + for (fieldname, fieldtype) in zip(rb.val.fieldnames, rb.val.types) + if fieldname == assigned_name + settype!(binding, fieldtype) + return + end + end +end +function infer_destructuring_type(binding::Binding, rb::EXPR) + assigned_name = string(to_codeobject(binding.name)) + scope = scopeof(rb) + names = scope.names + if haskey(names, assigned_name) + b = names[assigned_name] + settype!(binding, b.type) + end +end +infer_destructuring_type(binding, rb::Binding) = infer_destructuring_type(binding, rb.val) + function infer_type_decl(binding, state, scope) t = binding.val.args[2] if isidentifier(t) diff --git a/test/runtests.jl b/test/runtests.jl index 4a930698..be9a9bc3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -190,7 +190,8 @@ f(arg) = arg # @test parse_and_pass("function f(x::Int) x end")[1][2][3].binding.t == StaticLint.getsymbolserver(server)["Core"].vals["Function"] let cst = parse_and_pass(""" struct T end - function f(x::T) x end""") + function f(x::T) x end + """) @test StaticLint.CoreTypes.isdatatype(bindingof(cst.args[1]).type) @test StaticLint.CoreTypes.isfunction(bindingof(cst.args[2]).type) @test bindingof(cst.args[2].args[1].args[2]).type == bindingof(cst.args[1]) @@ -199,7 +200,8 @@ f(arg) = arg let cst = parse_and_pass(""" struct T end T() = 1 - function f(x::T) x end""") + function f(x::T) x end + """) @test StaticLint.CoreTypes.isdatatype(bindingof(cst.args[1]).type) @test StaticLint.CoreTypes.isfunction(bindingof(cst.args[3]).type) @test bindingof(cst.args[3].args[1].args[2]).type == bindingof(cst.args[1]) @@ -208,7 +210,8 @@ f(arg) = arg let cst = parse_and_pass(""" struct T end - t = T()""") + t = T() + """) @test StaticLint.CoreTypes.isdatatype(bindingof(cst.args[1]).type) @test bindingof(cst.args[2].args[1]).type == bindingof(cst.args[1]) end @@ -222,7 +225,8 @@ f(arg) = arg import ..B B.x end - end""") + end + """) @test refof(cst.args[1].args[3].args[2].args[3].args[2].args[2].args[1]) == bindingof(cst[1].args[3].args[1].args[3].args[1].args[1]) end @@ -235,7 +239,8 @@ f(arg) = arg end function f(arg::T1) arg.field.x - end"""); + end + """); @test refof(cst.args[3].args[2].args[1].args[1].args[1]) == bindingof(cst.args[3].args[1].args[2]) @test refof(cst.args[3].args[2].args[1].args[1].args[2].args[1]) == bindingof(cst.args[2].args[3].args[1]) @test refof(cst.args[3].args[2].args[1].args[2].args[1]) == bindingof(cst.args[1].args[3].args[1]) @@ -342,6 +347,21 @@ f(arg) = arg @test refof(cst[3][3][1]) !== nothing @test refof(cst[3][3][2]) !== nothing end + + let cst = parse_and_pass(""" + struct Foo + x::DataType + y::Float64 + end + (;x, y) = Foo(1,2) + x + y + """) + mx = cst.args[3].meta + @test mx.ref.type.name.name.name == :DataType + my = cst.args[4].meta + @test my.ref.type.name.name.name == :Float64 + end end @testset "macros" begin