Skip to content

Commit

Permalink
Preallocate RPG search state to reduce allocs.
Browse files Browse the repository at this point in the history
  • Loading branch information
ztangent committed Sep 9, 2024
1 parent 1af0581 commit 982df9a
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 116 deletions.
10 changes: 8 additions & 2 deletions src/heuristics/ff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ mutable struct FFHeuristic <: Heuristic
goal_hash::Union{Nothing,UInt} # Hash of most recently pre-computed goal
statics::Vector{Symbol} # Static domain fluents
graph::PlanningGraph # Precomputed planning graph
search_state::PlanningGraphSearchState # Preallocated search state
FFHeuristic() = new()
end

Expand All @@ -46,6 +47,7 @@ function precompute!(h::FFHeuristic,
h.goal_hash = nothing
h.statics = infer_static_fluents(domain)
h.graph = build_planning_graph(domain, state; statics=h.statics)
h.search_state = PlanningGraphSearchState(h.graph)
return h
end

Expand All @@ -56,6 +58,7 @@ function precompute!(h::FFHeuristic,
h.goal_hash = hash(get_goal_terms(spec))
h.statics = infer_static_fluents(domain)
h.graph = build_planning_graph(domain, state, spec; statics=h.statics)
h.search_state = PlanningGraphSearchState(h.graph)
return h
end

Expand All @@ -68,8 +71,11 @@ function compute(h::FFHeuristic,
h.goal_hash = hash(get_goal_terms(spec))
end
# Compute achievers to each condition node of the relaxed planning graph
costs, achievers, goal_idx, _ =
relaxed_pgraph_search(domain, state, spec, maximum, h.graph)
init_pgraph_search!(h.search_state, h.graph, domain, state)
search_state, goal_idx, _ =
run_pgraph_search!(h.search_state, h.graph, spec)
costs = search_state.cond_costs
achievers = search_state.cond_achievers
# Return infinity if goal is not reached
if isnothing(goal_idx) return Inf32 end
# Initialize queue
Expand Down
10 changes: 7 additions & 3 deletions src/heuristics/hsp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ mutable struct HSPHeuristic{F <: Function} <: Heuristic
goal_hash::Union{Nothing,UInt} # Hash of most recently pre-computed goal
statics::Vector{Symbol} # Static domain fluents
graph::PlanningGraph # Precomputed planning graph
search_state::PlanningGraphSearchState # Preallocated search state
HSPHeuristic{F}() where {F <: Function} = new{F}(F.instance)
HSPHeuristic(op::F) where {F <: Function} = new{F}(op)
end
Expand Down Expand Up @@ -67,6 +68,7 @@ function precompute!(h::HSPHeuristic,
h.goal_hash = nothing
h.statics = infer_static_fluents(domain)
h.graph = build_planning_graph(domain, state; statics=h.statics)
h.search_state = PlanningGraphSearchState(h.graph)
return h
end

Expand All @@ -77,6 +79,7 @@ function precompute!(h::HSPHeuristic,
h.goal_hash = hash(get_goal_terms(spec))
h.statics = infer_static_fluents(domain)
h.graph = build_planning_graph(domain, state, spec; statics=h.statics)
h.search_state = PlanningGraphSearchState(h.graph)
return h
end

Expand All @@ -89,7 +92,8 @@ function compute(h::HSPHeuristic,
h.goal_hash = hash(get_goal_terms(spec))
end
# Compute relaxed costs to goal nodes of the planning graph
_, _, _, goal_cost = relaxed_pgraph_search(domain, state, spec, h.op, h.graph)
init_pgraph_search!(h.search_state, h.graph, domain, state)
_, _, goal_cost = run_pgraph_search!(h.search_state, h.graph, spec, h.op)
# Return goal cost (may be infinite if unreachable)
return goal_cost
end
Expand Down Expand Up @@ -143,9 +147,9 @@ function precompute!(h::HSPRHeuristic,
domain::Domain, state::State, spec::Specification)
# Construct and compute fact costs from planning graph
graph = build_planning_graph(domain, state)
costs, _, _, _ = relaxed_pgraph_search(domain, state, spec, h.op, graph)
search_state, _, _ = run_pgraph_search(graph, domain, state, spec, h.op)
# Convert costs to dictionary for fast look-up
h.costs = Dict{Term,Float64}(c => v for (c, v) in zip(graph.conditions, costs))
h.costs = Dict{Term,Float64}(zip(graph.conditions, search_state.cond_costs))
return h
end

Expand Down
32 changes: 19 additions & 13 deletions src/heuristics/lmcut.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ mutable struct LMCut <: Heuristic
goal_hash::Union{Nothing,UInt} # Hash of most recently pre-computed goal
statics::Vector{Symbol} # Static domain fluents
graph::PlanningGraph # Precomputed planning graph
search_state::PlanningGraphSearchState # Preallocated search state
action_costs::Vector{Float32} # Precomputed action costs
LMCut() = new()
end
Expand All @@ -39,6 +40,7 @@ function precompute!(h::LMCut, domain::Domain, state::State)
# Precompute static domain fluents and planning graph
h.statics = infer_static_fluents(domain)
h.graph = build_planning_graph(domain, state; statics=h.statics)
h.search_state = PlanningGraphSearchState(h.graph)
# Precompute cost of each action
n_actions = length(h.graph.actions)
h.action_costs = map(eachindex(h.graph.actions)) do act_idx
Expand All @@ -58,6 +60,7 @@ function precompute!(h::LMCut, domain::Domain, state::State, spec::Specification
# Precompute static domain fluents and planning graph
h.statics = infer_static_fluents(domain)
h.graph = build_planning_graph(domain, state, spec; statics=h.statics)
h.search_state = PlanningGraphSearchState(h.graph)
# Precompute cost of each action
n_actions = length(h.graph.actions)
h.action_costs = map(enumerate(h.graph.actions)) do (act_idx, act)
Expand Down Expand Up @@ -88,37 +91,40 @@ function compute(h::LMCut, domain::Domain, state::State, spec::Specification)
end
end
end
# Compute set of initial facts
init_idxs = pgraph_init_idxs(h.graph, domain, state)
# Initialize planning graph search state
init_pgraph_search!(h.search_state, h.graph, domain, state)
# Calculate relaxed costs of facts and the h-max value
cond_costs, _, goal_idx, goal_cost =
relaxed_pgraph_search(domain, state, spec, maximum, h.graph;
action_costs = h.action_costs)
search_state, goal_idx, goal_cost =
run_pgraph_search!(h.search_state, h.graph, spec, maximum;
action_costs = h.action_costs)
# Terminate early if goal is unreachable
goal_cost == Inf32 && return goal_cost
# Iteratively find landmark cuts and sum their costs
hval = 0.0f0
init_conds = search_state.init_conds
action_costs = copy(h.action_costs)
for _ in 1:length(h.graph.actions)
# Find the supporters for each action
supporters = find_supporters(h.graph, cond_costs)
supporters = find_supporters(h.graph, search_state.cond_costs)
# Construct the justification graph
jgraph = build_justification_graph(h.graph, supporters, action_costs)
# Extract the goal zone
goal_zone = extract_goal_zone(jgraph)
# Extract the pregoal zone, landmarks, and their cost
pregoal_zone, landmark_idxs, landmark_cost =
extract_pregoal_zone_and_landmarks(jgraph, goal_zone,
init_idxs, action_costs)
# Update heuristic value and action costs
init_conds, action_costs)
# Update heuristic value, action costs and search queue
hval += landmark_cost
for idx in landmark_idxs
action_costs[idx] -= landmark_cost
for act_idx in landmark_idxs
action_costs[act_idx] -= landmark_cost
end
# Re-calculate relaxed costs to each fact
cond_costs, _, goal_idx, goal_cost =
relaxed_pgraph_search(domain, state, spec, maximum, h.graph;
action_costs = action_costs)
init_pgraph_search!(h.search_state, h.graph, domain, state,
compute_init_conds = false)
search_state, goal_idx, goal_cost =
run_pgraph_search!(h.search_state, h.graph, spec, maximum;
action_costs = action_costs)
# Terminate once goal cost has been reduced to zero
iszero(goal_cost) && break
end
Expand Down
Loading

0 comments on commit 982df9a

Please sign in to comment.