Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance functional dependencies customization in ReactiveMP #437

Merged
merged 2 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions docs/src/lib/nodes.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

# [Nodes implementation](@id lib-node)

In the message passing framework, one of the most important concepts is a factor node.
Expand Down Expand Up @@ -100,7 +99,11 @@ Here we see that in the standard setting for the belief-propagation message out
\mu(x) = \exp \int q(y) q(z) \log f(x, y, z) \mathrm{d}y \mathrm{d}z
```

We see that in this setting, we do not need messages $\mu(y)$ and $\mu(z)$, but only the marginals $q(y)$ and $q(z)$. The purpose of a __functional dependencies pipeline__ is to determine functional dependencies (a set of messages or marginals) that are needed to compute a single message. By default, `ReactiveMP.jl` uses so-called `DefaultFunctionalDependencies` that correctly implements belief-propagation and variational message passing schemes (including both mean-field and structured factorisations). The full list of built-in pipelines is presented below:
We see that in this setting, we do not need messages $\mu(y)$ and $\mu(z)$, but only the marginals $q(y)$ and $q(z)$.

## [List of functional dependencies pipelines](@id lib-node-functional-dependencies-pipelines)

The purpose of a __functional dependencies pipeline__ is to determine functional dependencies (a set of messages or marginals) that are needed to compute a single message. By default, `ReactiveMP.jl` uses so-called `DefaultFunctionalDependencies` that correctly implements belief-propagation and variational message passing schemes (including both mean-field and structured factorisations). The full list of built-in pipelines is presented below:

```@docs
ReactiveMP.DefaultFunctionalDependencies
Expand All @@ -109,6 +112,27 @@ ReactiveMP.RequireMarginalFunctionalDependencies
ReactiveMP.RequireEverythingFunctionalDependencies
```

## [Customizing Dependencies with Metadata](@id lib-node-metadata-dependencies)

The functional dependencies of a node can be customized at runtime using options during node activation. This allows for runtime customization of the functional dependencies, e.g. to test different message passing schemes or implement specialized behavior for specific instances of a node type:

```julia
# Define custom dependencies based on metadata
function ReactiveMP.collect_functional_dependencies(::Type{MyNode}, options::FactorNodeActivationOptions)
if some_condition(options) # a user can specify dependencies based, for example, on metadata
return CustomDependencies()
end
# Fall back to default dependencies
return ReactiveMP.collect_functional_dependencies(MyNode, getdependecies(options))
end

# Use custom dependencies during activation
node = factornode(MyNode, ...)
activate!(node, FactorNodeActivationOptions(:custom_behavior, ...))
```

This feature is particularly useful for testing different message passing schemes or implementing specialized behavior for specific instances of a node type.

## [Node traits](@id lib-node-traits)

Each factor node has to define the [`ReactiveMP.is_predefined_node`](@ref) trait function and to specify a [`ReactiveMP.PredefinedNodeFunctionalForm`](@ref)
Expand Down
5 changes: 4 additions & 1 deletion src/nodes/nodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,11 @@
getscheduler(options::FactorNodeActivationOptions) = options.scheduler
getrulefallback(options::FactorNodeActivationOptions) = options.rulefallback

# Users can override the dependencies if they want to
collect_functional_dependencies(fform::F, options::FactorNodeActivationOptions) where {F} = collect_functional_dependencies(fform, getdependecies(options))

Check warning on line 264 in src/nodes/nodes.jl

View check run for this annotation

Codecov / codecov/patch

src/nodes/nodes.jl#L264

Added line #L264 was not covered by tests

function activate!(factornode::FactorNode, options::FactorNodeActivationOptions)
dependencies = collect_functional_dependencies(functionalform(factornode), getdependecies(options))
dependencies = collect_functional_dependencies(functionalform(factornode), options)
initialize_clusters!(getlocalclusters(factornode), dependencies, factornode, options)
return activate!(dependencies, factornode, options)
end
Expand Down
98 changes: 98 additions & 0 deletions test/nodes/dependencies_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -415,3 +415,101 @@ end
end
end
end

@testitem "Functional dependencies may change depending on the metadata from options" begin
# This test demonstrates how functional dependencies can be customized based on metadata
# passed during node activation. This is useful when:
# 1. The same node type needs different message passing behaviors in different contexts
# 2. Users want to override the default message passing behavior without creating a new node type
#
# The test creates a custom node with 3 interfaces (out, in1, in2) and shows how:
# - With meta = :use_a, the output message depends only on in1
# - With meta = :use_b, the output message depends only on in2
# - With meta = nothing, it falls back to default dependencies
#
# This pattern allows for flexible message passing schemes that can be configured at runtime
# rather than being hardcoded into the node type.

include("../testutilities.jl")
using BayesBase, Rocket

import ReactiveMP: NodeInterface, collect_functional_dependencies, getdata, getrecent, activate!, getmetadata, name, getinterface
import ReactiveMP: FunctionalDependencies, functional_dependencies, DefaultFunctionalDependencies, FactorNodeActivationOptions, getdependecies

# Define a custom node for testing
struct CustomMetaNode end

@node CustomMetaNode Stochastic [out, in1, in2]

# Define custom functional dependencies that we'll use based on meta
struct CustomDependencyA <: FunctionalDependencies end
struct CustomDependencyB <: FunctionalDependencies end

# Define how meta affects functional dependencies
ReactiveMP.collect_functional_dependencies(::Type{CustomMetaNode}, options::FactorNodeActivationOptions) =
ReactiveMP.collect_functional_dependencies(CustomMetaNode, options, getmetadata(options))

# Mock different behavior for our custom dependencies
ReactiveMP.collect_functional_dependencies(::Type{CustomMetaNode}, ::FactorNodeActivationOptions, meta::Symbol) = meta === :use_a ? CustomDependencyA() : CustomDependencyB()
ReactiveMP.collect_functional_dependencies(::Type{CustomMetaNode}, options::FactorNodeActivationOptions, meta::Nothing) =
ReactiveMP.collect_functional_dependencies(CustomMetaNode, getdependecies(options))

# Mock different behavior for our custom dependencies
function ReactiveMP.functional_dependencies(::CustomDependencyA, factornode, interface, iindex)
# CustomDependencyA only depends on in1
msg_deps = name(interface) === :out ? (getinterface(factornode, 2),) : () # only in1
return (msg_deps, ())
end

function ReactiveMP.functional_dependencies(::CustomDependencyB, factornode, interface, iindex)
# CustomDependencyB only depends on in2
msg_deps = name(interface) === :out ? (getinterface(factornode, 3),) : () # only in2
return (msg_deps, ())
end

# Create test variables
out_v = randomvar()
in1_v = ConstVariable(1.0)
in2_v = ConstVariable(2.0)

@testset "use_a metadata results in CustomDependencyA" begin
options_a = FactorNodeActivationOptions(:use_a, nothing, nothing, nothing, AsapScheduler(), nothing)
deps = collect_functional_dependencies(CustomMetaNode, options_a)
@test deps isa CustomDependencyA
end

@testset "use_b metadata results in CustomDependencyB" begin
options_b = FactorNodeActivationOptions(:use_b, nothing, nothing, nothing, AsapScheduler(), nothing)
deps = collect_functional_dependencies(CustomMetaNode, options_b)
@test deps isa CustomDependencyB
end

@testset "no metadata falls back to default dependencies" begin
options_default = FactorNodeActivationOptions(nothing, nothing, nothing, nothing, AsapScheduler(), nothing)
deps = collect_functional_dependencies(CustomMetaNode, options_default)
@test deps isa DefaultFunctionalDependencies
end

@testset "Dependencies change based on meta" begin
# Create node with meta :use_a
node_a = factornode(CustomMetaNode, [(:out, out_v), (:in1, in1_v), (:in2, in2_v)], ((1,),))
options_a = FactorNodeActivationOptions(:use_a, nothing, nothing, nothing, AsapScheduler(), nothing)
deps_a = collect_functional_dependencies(CustomMetaNode, options_a)
activate!(node_a, options_a)

out_interface_a = getinterface(node_a, 1)
msg_deps_a, marg_deps_a = functional_dependencies(deps_a, node_a, out_interface_a, 1)
@test length(msg_deps_a) == 1
@test name(first(msg_deps_a)) === :in1

# Test that functional dependencies are different with meta :use_b
node_b = factornode(CustomMetaNode, [(:out, out_v), (:in1, in1_v), (:in2, in2_v)], ((1,),))
options_b = FactorNodeActivationOptions(:use_b, nothing, nothing, nothing, AsapScheduler(), nothing)
deps_b = collect_functional_dependencies(CustomMetaNode, options_b)
activate!(node_b, options_b)
out_interface_b = getinterface(node_b, 1)
msg_deps_b, marg_deps_b = functional_dependencies(deps_b, node_b, out_interface_b, 1)
@test length(msg_deps_b) == 1
@test name(first(msg_deps_b)) === :in2
end
end
Loading