From 8e591df96ae16677b0baab1a92c310561eb840e8 Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Mon, 11 Oct 2021 11:12:20 -0600 Subject: [PATCH] backing off for now --- Project.toml | 1 - src/QuickPOMDPs.jl | 1 - src/quick.jl | 79 ++++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 73 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index 2fb01fa..52ad45c 100644 --- a/Project.toml +++ b/Project.toml @@ -19,7 +19,6 @@ NamedTupleTools = "0.11, 0.12, 0.13" POMDPModelTools = "0.3.1" POMDPTesting = "0.2.1" POMDPs = "0.9" -Tricks = "0.1" julia = "1" [extras] diff --git a/src/QuickPOMDPs.jl b/src/QuickPOMDPs.jl index 991ef84..759856c 100644 --- a/src/QuickPOMDPs.jl +++ b/src/QuickPOMDPs.jl @@ -7,7 +7,6 @@ using POMDPTesting using UUIDs using NamedTupleTools using Random -using Tricks: static_hasmethod export DiscreteExplicitPOMDP, diff --git a/src/quick.jl b/src/quick.jl index b6e6124..c62dd85 100644 --- a/src/quick.jl +++ b/src/quick.jl @@ -28,6 +28,9 @@ function QuickMDP(id=uuid4(); kwargs...) S = infer_statetype(kwd) A = infer_actiontype(kwd) + + kwd = + d = namedtuple(keys(kwd)...)(values(kwd)...) qm = QuickMDP{id, S, A, typeof(d)}(d) return qm @@ -223,9 +226,9 @@ end function POMDPs.observation(m::QuickPOMDP, args...) if haskey(m.data, :observation) obs = m.data[:observation] - if static_hasmethod(obs, typeof(args)) + if hasmethod(obs, typeof(args)) return obs(args...) - elseif length(args) == 3 && static_hasmethod(obs, typeof(args[2:3])) + elseif length(args) == 3 && hasmethod(obs, typeof(args[2:3])) return obs(args[2:3]...) else return obs(args...) @@ -239,15 +242,15 @@ end function POMDPs.reward(m::QuickModel, args...) if haskey(m.data, :reward) r = m.data[:reward] - if static_hasmethod(r, typeof(args)) # static_hasmethod could cause issues, but I think it is worth doing in this single spot + if hasmethod(r, typeof(args)) # static_hasmethod could cause issues, but I think it is worth doing in this single spot return r(args...) elseif m isa POMDP && length(args) == 4 - if static_hasmethod(r, typeof(args[1:3])) # (s, a, sp, o) -> (s, a, sp) + if hasmethod(r, typeof(args[1:3])) # (s, a, sp, o) -> (s, a, sp) return r(args[1:3]...) - elseif static_hasmethod(r, typeof(args[1:2])) # (s, a, sp, o) -> (s, a) + elseif hasmethod(r, typeof(args[1:2])) # (s, a, sp, o) -> (s, a) return r(args[1:2]...) end - elseif length(args) == 3 && static_hasmethod(r, typeof(args[1:2])) # (s, a, sp) -> (s, a) + elseif length(args) == 3 && hasmethod(r, typeof(args[1:2])) # (s, a, sp) -> (s, a) return r(args[1:2]...) else return r(args...) @@ -257,6 +260,70 @@ function POMDPs.reward(m::QuickModel, args...) end end +struct QuickRewardModel{ArgNums, F} <: Function + f::F + hasmethod_fallback::Bool +end + +QuickRewardModel(f::Function, S, A; hasmethod_fallback::Bool=true) = QuickRewardModel{reward_argnums(f, S, A), typeof(f)}(f, hasmethod_fallback) +QuickRewardModel(f::Function, S, A, O; hasmethod_fallback::Bool=true) = QuickRewardModel{reward_argnums(f, S, A, O), typeof(f)}(f, hasmethod_fallback) +QuickRewardModel(r::QuickRewardModel, args...) = r + +function reward_argnums(f, S, A) + ans = [] + if hasmethod(f, Tuple{S,A}) + push!(ans, 2) + end + if hasmethod(f, Tuple{S,A,S}) + push!(ans, 3) + end + return (ans...,) # convert to tuple +end + +function reward_argnums(f, S, A, O) + if hasmethod(f, Tuple{S, A, S, O}) + return (reward_argnums(f, S, A)..., 4) + else + return reward_argnums(f, S, A) + end +end + +function (r::QuickRewardModel{ArgNums})(args...) where ArgNums + if length(args) in ArgNums + return r.f(args...) + elseif maximum(ArgNums) < length(args) + return r.f(args[1:maximum(ArgNums)]...) + elseif r.f.hasmethod_fallback + if hasmethod(r.f, typeof(args)) + found = r.f(args...) + elseif m isa POMDP && length(args) == 4 + if hasmethod(r.f, typeof(args[1:3])) # (s, a, sp, o) -> (s, a, sp) + found = r.f(args[1:3]...) + elseif hasmethod(r.f, typeof(args[1:2])) # (s, a, sp, o) -> (s, a) + found = r.f(args[1:2]...) + end + elseif length(args) == 3 && hasmethod(r.f, typeof(args[1:2])) # (s, a, sp) -> (s, a) + found = r.f(args[1:2]...) + else + return r.f(args...) + end + @warn("""A Quick(PO)MDP had to use hasmethod as a fallback to find the correct method of + the reward function to use. + + This may be caused by adding new methods to the reward function after creating + the Quick(PO)MDP and can cause significant perfromance degredation. Originally, + the Quick(PO)MDP found reward methods with the following numbers of arguments: + + $(ArgNums) + + Recommend adding all methods to the reward function before creaing the + Quick(PO)MDP.""", current_methods=methods(r.f)) + return found + else + return r.f(args...) + end +end + @forward_to_data POMDPs.initialstate @forward_to_data POMDPs.initialobs