Skip to content

Commit

Permalink
Replace _type with widenconst (#429)
Browse files Browse the repository at this point in the history
* Replace _type with widenconst

* Excise _type

* Bump patch version

* Actually save the file
  • Loading branch information
willtebbutt authored Dec 20, 2024
1 parent b7455e1 commit 8c478a8
Show file tree
Hide file tree
Showing 7 changed files with 14 additions and 24 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Mooncake"
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.4.65"
version = "0.4.66"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
6 changes: 0 additions & 6 deletions src/interpreter/abstract_interpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,6 @@ else
get_inference_world(interp::CC.AbstractInterpreter) = CC.get_inference_world(interp)
end

_type(x::Type) = x
_type(x::CC.Const) = _typeof(x.val)
_type(x::CC.PartialStruct) = x.typ
_type(x::CC.Conditional) = Union{_type(x.thentype),_type(x.elsetype)}
_type(::CC.PartialTypeVar) = TypeVar

struct NoInlineCallInfo <: CC.CallInfo
info::CC.CallInfo # wrapped call
tt::Any # signature
Expand Down
2 changes: 1 addition & 1 deletion src/interpreter/ir_normalisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ to be the return type given by the code cache.
function fix_up_invoke_inference!(ir::IRCode)::IRCode
stmts = ir.stmts
for n in 1:length(stmts)
if Meta.isexpr(stmt(stmts)[n], :invoke) && _type(stmts.type[n]) == Any
if Meta.isexpr(stmt(stmts)[n], :invoke) && CC.widenconst(stmts.type[n]) == Any
mi = stmt(stmts)[n].args[1]::Core.MethodInstance
R = isdefined(mi, :cache) ? mi.cache.rettype : CC.return_type(mi.specTypes)
stmts.type[n] = R
Expand Down
2 changes: 1 addition & 1 deletion src/interpreter/ir_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ end
# https://gist.github.com/oxinabox/cdcffc1392f91a2f6d80b2524726d802#file-example-jl-L54
function __get_toplevel_mi_from_ir(ir, _module::Module)
mi = ccall(:jl_new_method_instance_uninit, Ref{Core.MethodInstance}, ())
mi.specTypes = Tuple{map(_type, ir.argtypes)...}
mi.specTypes = Tuple{map(CC.widenconst, ir.argtypes)...}
mi.def = _module
return mi
end
Expand Down
19 changes: 10 additions & 9 deletions src/interpreter/s2s_reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,13 @@ end
# ADInfo struct for information regarding `interp` and `debug_mode`.
function ADInfo(interp::MooncakeInterpreter, ir::BBCode, debug_mode::Bool)
arg_types = Dict{Argument,Any}(
map(((n, t),) -> (Argument(n) => _type(t)), enumerate(ir.argtypes))
map(((n, t),) -> (Argument(n) => CC.widenconst(t)), enumerate(ir.argtypes))
)
stmts = collect_stmts(ir)
ssa_insts = Dict{ID,NewInstruction}(stmts)
is_used_dict = characterise_used_ids(stmts)
zero_lazy_rdata_ref = Ref{Tuple{map(lazy_zero_rdata_type ∘ _type, ir.argtypes)...}}()
Tlazy_rdata_ref = Tuple{map(lazy_zero_rdata_type CC.widenconst, ir.argtypes)...}
zero_lazy_rdata_ref = Ref{Tlazy_rdata_ref}()
return ADInfo(
interp, arg_types, ssa_insts, is_used_dict, debug_mode, zero_lazy_rdata_ref
)
Expand Down Expand Up @@ -209,7 +210,7 @@ is_used(info::ADInfo, id::ID)::Bool = info.is_used_dict[id]
Returns the static / inferred type associated to `x`.
"""
get_primal_type(info::ADInfo, x::Argument) = info.arg_types[x]
get_primal_type(info::ADInfo, x::ID) = _type(info.ssa_insts[x].type)
get_primal_type(info::ADInfo, x::ID) = CC.widenconst(info.ssa_insts[x].type)
get_primal_type(::ADInfo, x::QuoteNode) = _typeof(x.value)
get_primal_type(::ADInfo, x) = _typeof(x)
function get_primal_type(::ADInfo, x::GlobalRef)
Expand Down Expand Up @@ -238,10 +239,10 @@ Create the statements which initialise the reverse-data `Ref`s.
function reverse_data_ref_stmts(info::ADInfo)
return vcat(
map(collect(info.arg_rdata_ref_ids)) do (k, id)
(id, new_inst(Expr(:call, __make_ref, _type(info.arg_types[k]))))
(id, new_inst(Expr(:call, __make_ref, CC.widenconst(info.arg_types[k]))))
end,
map(collect(info.ssa_rdata_ref_ids)) do (k, id)
(id, new_inst(Expr(:call, __make_ref, _type(info.ssa_insts[k].type))))
(id, new_inst(Expr(:call, __make_ref, CC.widenconst(info.ssa_insts[k].type))))
end,
)
end
Expand Down Expand Up @@ -462,15 +463,15 @@ function make_ad_stmts!(stmt::PiNode, line::ID, info::ADInfo)
P = get_primal_type(info, line)
val_rdata_ref_id = get_rev_data_id(info, stmt.val)
output_rdata_ref_id = get_rev_data_id(info, line)
fwds = PiNode(__inc(stmt.val), fcodual_type(_type(stmt.typ)))
fwds = PiNode(__inc(stmt.val), fcodual_type(CC.widenconst(stmt.typ)))
rvs = Expr(:call, __pi_rvs!, P, val_rdata_ref_id, output_rdata_ref_id)
else
# If the value of the PiNode is a constant / QuoteNode etc, then there is nothing to
# do on the reverse-pass.
const_id = ID()
fwds = [
(const_id, new_inst(const_codual_stmt(stmt.val, info))),
(line, new_inst(PiNode(const_id, fcodual_type(_type(stmt.typ))))),
(line, new_inst(PiNode(const_id, fcodual_type(CC.widenconst(stmt.typ))))),
]
rvs = nothing
end
Expand Down Expand Up @@ -935,7 +936,7 @@ function rule_type(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode) where
Treturn = Base.Experimental.compute_ir_rettype(ir)
isva, _ = is_vararg_and_sparam_names(sig_or_mi)

arg_types = map(_type, ir.argtypes)
arg_types = map(CC.widenconst, ir.argtypes)
sig = Tuple{arg_types...}
arg_fwds_types = Tuple{map(fcodual_type, arg_types)...}
arg_rvs_types = Tuple{map(rdata_type tangent_type, arg_types)...}
Expand Down Expand Up @@ -1281,7 +1282,7 @@ function forwards_pass_ir(
end

# Create and return the `BBCode` for the forwards-pass.
arg_types = vcat(Tshared_data, map(fcodual_type _type, ir.argtypes))
arg_types = vcat(Tshared_data, map(fcodual_type CC.widenconst, ir.argtypes))
ir = BBCode(vcat(entry_block, blocks), arg_types, ir.sptypes, ir.linetable, ir.meta)
return remove_unreachable_blocks!(ir)
end
Expand Down
3 changes: 1 addition & 2 deletions test/front_matter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ using Mooncake:
verify_rdata_value,
is_primitive,
MinimalCtx,
stmt,
_type
stmt

using .TestUtils:
test_rule,
Expand Down
4 changes: 0 additions & 4 deletions test/interpreter/abstract_interpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,4 @@ contains_primitive_behind_call(x) = @inline contains_primitive(x)
@test stmt(ad_ir.stmts)[invoke_line].args[2] == GlobalRef(Main, :a_primitive)
end
end
@testset "_type" begin
@test _type(CC.Const(5.0)) === Float64
@test _type(CC.PartialTypeVar(TypeVar(:a, Union{}, Any), true, true)) === TypeVar
end
end

2 comments on commit 8c478a8

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/121722

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.66 -m "<description of version>" 8c478a88bbd89d12507441314f21fb193e4eceef
git push origin v0.4.66

Please sign in to comment.