From 061acbe7fab5fac3e1af6fb43fd69938b9e8a5a4 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 5 Mar 2025 10:34:04 +0000 Subject: [PATCH 01/48] Release 0.36 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index a9463a821..a1bf65fd5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.35.0" +version = "0.36.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From fc323985ee744a38ce3013915150f0891d4af3ab Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 28 Mar 2025 17:01:39 +0000 Subject: [PATCH 02/48] AbstractPPL 0.11 + change prefixing behaviour (#830) * AbstractPPL 0.11; change prefixing behaviour * Use DynamicPPL.prefix rather than overloading --- HISTORY.md | 49 ++++++++++++++++ Project.toml | 2 +- docs/Project.toml | 1 + docs/src/api.md | 2 +- src/DynamicPPL.jl | 5 +- src/contexts.jl | 26 ++++----- src/debug_utils.jl | 2 +- src/model.jl | 58 ++++++------------- src/submodel_macro.jl | 16 +++--- src/utils.jl | 8 ++- test/Project.toml | 2 +- test/compiler.jl | 11 ++-- test/contexts.jl | 126 +++++++++++++++++++++++------------------- test/debug_utils.jl | 4 +- test/deprecated.jl | 2 +- test/model.jl | 6 +- 16 files changed, 180 insertions(+), 140 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 3ea8071f3..cd2757edc 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,54 @@ # DynamicPPL Changelog +## 0.36.0 + +**Breaking changes** + +### VarName prefixing behaviour + +The way in which VarNames in submodels are prefixed has been changed. +This is best explained through an example. +Consider this model and submodel: + +```julia +using DynamicPPL, Distributions +@model inner() = x ~ Normal() +@model outer() = a ~ to_submodel(inner()) +``` + +In previous versions, the inner variable `x` would be saved as `a.x`. +However, this was represented as a single symbol `Symbol("a.x")`: + +```julia +julia> dump(keys(VarInfo(outer()))[1]) +VarName{Symbol("a.x"), typeof(identity)} + optic: identity (function of type typeof(identity)) +``` + +Now, the inner variable is stored as a field `x` on the VarName `a`: + +```julia +julia> dump(keys(VarInfo(outer()))[1]) +VarName{:a, Accessors.PropertyLens{:x}} + optic: Accessors.PropertyLens{:x} (@o _.x) +``` + +In practice, this means that if you are trying to condition a variable in the submodel, you now need to use + +```julia +outer() | (@varname(a.x) => 1.0,) +``` + +instead of either of these (which would have worked previously) + +```julia +outer() | (@varname(var"a.x") => 1.0,) +outer() | (a.x=1.0,) +``` + +If you are sampling from a model with submodels, this doesn't affect the way you interact with the `MCMCChains.Chains` object, because VarNames are converted into Symbols when stored in the chain. +(This behaviour will likely be changed in the future, in that Chains should be indexable by VarNames and not just Symbols, but that has not been implemented yet.) 
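For a concrete sketch of the new behaviour, reusing the `inner`/`outer` models defined above (indexing and conditioning both use the structured `VarName`; exact sampled values will of course vary):

```julia
using DynamicPPL, Distributions

@model inner() = x ~ Normal()
@model outer() = a ~ to_submodel(inner())

vi = VarInfo(outer())

@varname(a.x) in keys(vi)   # true: the submodel variable is keyed by a structured VarName
vi[@varname(a.x)]           # the value drawn for `x` inside the submodel

# Conditioning uses the same structured VarName:
(outer() | (@varname(a.x) => 1.0,))()   # 1.0, since `a.x` is now fixed
```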
+ ## 0.35.5 Several internal methods have been removed: diff --git a/Project.toml b/Project.toml index d5185d727..516dee26e 100644 --- a/Project.toml +++ b/Project.toml @@ -44,7 +44,7 @@ DynamicPPLMooncakeExt = ["Mooncake"] [compat] ADTypes = "1" AbstractMCMC = "5" -AbstractPPL = "0.10.1" +AbstractPPL = "0.11" Accessors = "0.1" BangBang = "0.4.1" Bijectors = "0.13.18, 0.14, 0.15" diff --git a/docs/Project.toml b/docs/Project.toml index fa57f2c1c..40a719e03 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,6 +4,7 @@ DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterMermaid = "a078cd44-4d9c-4618-b545-3ab9d77f9177" +DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" diff --git a/docs/src/api.md b/docs/src/api.md index 9c8249c97..2f6376f5d 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -149,7 +149,7 @@ In the past, one would instead embed sub-models using [`@submodel`](@ref), which In the context of including models within models, it's also useful to prefix the variables in sub-models to avoid variable names clashing: ```@docs -prefix +DynamicPPL.prefix ``` Under the hood, [`to_submodel`](@ref) makes use of the following method to indicate that the model it's wrapping is a model over its return-values rather than something else diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 50fe0edc7..9f45718c5 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -21,6 +21,9 @@ using DocStringExtensions using Random: Random +# For extending +import AbstractPPL: predict + # TODO: Remove these when it's possible. import Bijectors: link, invlink @@ -39,8 +42,6 @@ import Base: keys, haskey -import AbstractPPL: predict - # VarInfo export AbstractVarInfo, VarInfo, diff --git a/src/contexts.jl b/src/contexts.jl index a54c60374..58ac612b8 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -260,25 +260,21 @@ function setchildcontext(::PrefixContext{Prefix}, child) where {Prefix} return PrefixContext{Prefix}(child) end -const PREFIX_SEPARATOR = Symbol(".") - -@generated function PrefixContext{PrefixOuter}( - context::PrefixContext{PrefixInner} -) where {PrefixOuter,PrefixInner} - return :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)))}( - context.context - )) -end +""" + prefix(ctx::AbstractContext, vn::VarName) +Apply the prefixes in the context `ctx` to the variable name `vn`. 
+""" function prefix(ctx::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} - vn_prefixed_inner = prefix(childcontext(ctx), vn) - return VarName{Symbol(Prefix, PREFIX_SEPARATOR, getsym(vn_prefixed_inner))}( - getoptic(vn_prefixed_inner) - ) + return AbstractPPL.prefix(prefix(childcontext(ctx), vn), VarName{Symbol(Prefix)}()) +end +function prefix(ctx::AbstractContext, vn::VarName) + return prefix(NodeTrait(ctx), ctx, vn) end -prefix(ctx::AbstractContext, vn::VarName) = prefix(NodeTrait(ctx), ctx, vn) prefix(::IsLeaf, ::AbstractContext, vn::VarName) = vn -prefix(::IsParent, ctx::AbstractContext, vn::VarName) = prefix(childcontext(ctx), vn) +function prefix(::IsParent, ctx::AbstractContext, vn::VarName) + return prefix(childcontext(ctx), vn) +end """ prefix(model::Model, x) diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 328fe6983..529092e8e 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -183,7 +183,7 @@ function DynamicPPL.setchildcontext(context::DebugContext, child) end function record_varname!(context::DebugContext, varname::VarName, dist) - prefixed_varname = prefix(context, varname) + prefixed_varname = DynamicPPL.prefix(context, varname) if haskey(context.varnames_seen, prefixed_varname) if context.error_on_failure error("varname $prefixed_varname used multiple times in model") diff --git a/src/model.jl b/src/model.jl index a0451b1b6..b4d5f6bb7 100644 --- a/src/model.jl +++ b/src/model.jl @@ -243,7 +243,7 @@ julia> model() ≠ 1.0 true julia> # To condition the variable inside `demo_inner` we need to refer to it as `inner.m`. - conditioned_model = model | (var"inner.m" = 1.0, ); + conditioned_model = model | (@varname(inner.m) => 1.0, ); julia> conditioned_model() 1.0 @@ -255,15 +255,6 @@ julia> conditioned_model_fail() ERROR: ArgumentError: `~` with a model on the right-hand side of an observe statement is not supported [...] ``` - -And similarly when using `Dict`: - -```jldoctest condition -julia> conditioned_model_dict = model | (@varname(var"inner.m") => 1.0); - -julia> conditioned_model_dict() -1.0 -``` """ function AbstractPPL.condition(model::Model, values...) # Positional arguments - need to handle cases carefully @@ -443,16 +434,16 @@ julia> conditioned(cm) julia> # Since we conditioned on `m`, not `a.m` as it will appear after prefixed, # `a.m` is treated as a random variable. keys(VarInfo(cm)) -1-element Vector{VarName{Symbol("a.m"), typeof(identity)}}: +1-element Vector{VarName{:a, Accessors.PropertyLens{:m}}}: a.m julia> # If we instead condition on `a.m`, `m` in the model will be considered an observation. 
- cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext((var"a.m"=1.0,)))), x=100.0); + cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext(Dict(@varname(a.m) => 1.0)))), x=100.0); -julia> conditioned(cm).x +julia> conditioned(cm)[@varname(x)] 100.0 -julia> conditioned(cm).var"a.m" +julia> conditioned(cm)[@varname(a.m)] 1.0 julia> keys(VarInfo(cm)) # No variables are sampled @@ -583,7 +574,7 @@ julia> model = demo_outer(); julia> model() ≠ 1.0 true -julia> fixed_model = fix(model, var"inner.m" = 1.0, ); +julia> fixed_model = fix(model, (@varname(inner.m) => 1.0, )); julia> fixed_model() 1.0 @@ -599,24 +590,9 @@ julia> fixed_model() 2.0 ``` -And similarly when using `Dict`: - -```jldoctest fix -julia> fixed_model_dict = fix(model, @varname(var"inner.m") => 1.0); - -julia> fixed_model_dict() -1.0 - -julia> fixed_model_dict = fix(model, @varname(inner) => 2.0); - -julia> fixed_model_dict() -2.0 -``` - ## Difference from `condition` -A very similar functionality is also provided by [`condition`](@ref) which, -not surprisingly, _conditions_ variables instead of fixing them. The only +A very similar functionality is also provided by [`condition`](@ref). The only difference between fixing and conditioning is as follows: - `condition`ed variables are considered to be observations, and are thus included in the computation [`logjoint`](@ref) and [`loglikelihood`](@ref), @@ -798,16 +774,16 @@ julia> fixed(cm) julia> # Since we fixed on `m`, not `a.m` as it will appear after prefixed, # `a.m` is treated as a random variable. keys(VarInfo(cm)) -1-element Vector{VarName{Symbol("a.m"), typeof(identity)}}: +1-element Vector{VarName{:a, Accessors.PropertyLens{:m}}}: a.m julia> # If we instead fix on `a.m`, `m` in the model will be considered an observation. 
- cm = fix(contextualize(m, PrefixContext{:a}(fix(var"a.m"=1.0))), x=100.0); + cm = fix(contextualize(m, PrefixContext{:a}(fix(@varname(a.m) => 1.0,))), x=100.0); -julia> fixed(cm).x +julia> fixed(cm)[@varname(x)] 100.0 -julia> fixed(cm).var"a.m" +julia> fixed(cm)[@varname(a.m)] 1.0 julia> keys(VarInfo(cm)) # <= no variables are sampled @@ -1365,7 +1341,7 @@ When we sample from the model `demo2(missing, 0.4)` random variable `x` will be ```jldoctest submodel-to_submodel julia> vi = VarInfo(demo2(missing, 0.4)); -julia> @varname(var\"a.x\") in keys(vi) +julia> @varname(a.x) in keys(vi) true ``` @@ -1379,7 +1355,7 @@ false We can check that the log joint probability of the model accumulated in `vi` is correct: ```jldoctest submodel-to_submodel -julia> x = vi[@varname(var\"a.x\")]; +julia> x = vi[@varname(a.x)]; julia> getlogp(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) true @@ -1417,10 +1393,10 @@ julia> @model function demo2(x, y, z) julia> vi = VarInfo(demo2(missing, missing, 0.4)); -julia> @varname(var"sub1.x") in keys(vi) +julia> @varname(sub1.x) in keys(vi) true -julia> @varname(var"sub2.x") in keys(vi) +julia> @varname(sub2.x) in keys(vi) true ``` @@ -1437,9 +1413,9 @@ false We can check that the log joint probability of the model accumulated in `vi` is correct: ```jldoctest submodel-to_submodel-prefix -julia> sub1_x = vi[@varname(var"sub1.x")]; +julia> sub1_x = vi[@varname(sub1.x)]; -julia> sub2_x = vi[@varname(var"sub2.x")]; +julia> sub2_x = vi[@varname(sub2.x)]; julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x); diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index e5a8e0617..f6b9c4479 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -96,10 +96,10 @@ julia> vi = VarInfo(demo2(missing, missing, 0.4)); │ caller = ip:0x0 └ @ Core :-1 -julia> @varname(var"sub1.x") in keys(vi) +julia> @varname(sub1.x) in keys(vi) true -julia> @varname(var"sub2.x") in keys(vi) +julia> @varname(sub2.x) in keys(vi) true ``` @@ -116,9 +116,9 @@ false We can check that the log joint probability of the model accumulated in `vi` is correct: ```jldoctest submodelprefix -julia> sub1_x = vi[@varname(var"sub1.x")]; +julia> sub1_x = vi[@varname(sub1.x)]; -julia> sub2_x = vi[@varname(var"sub2.x")]; +julia> sub2_x = vi[@varname(sub2.x)]; julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x); @@ -157,7 +157,7 @@ julia> # Automatically determined from `a`. @model submodel_prefix_true() = @submodel prefix=true a = inner() submodel_prefix_true (generic function with 2 methods) -julia> @varname(var"a.x") in keys(VarInfo(submodel_prefix_true())) +julia> @varname(a.x) in keys(VarInfo(submodel_prefix_true())) ┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. │ caller = ip:0x0 └ @ Core :-1 @@ -167,7 +167,7 @@ julia> # Using a static string. @model submodel_prefix_string() = @submodel prefix="my prefix" a = inner() submodel_prefix_string (generic function with 2 methods) -julia> @varname(var"my prefix.x") in keys(VarInfo(submodel_prefix_string())) +julia> @varname(var"my prefix".x) in keys(VarInfo(submodel_prefix_string())) ┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. │ caller = ip:0x0 └ @ Core :-1 @@ -177,7 +177,7 @@ julia> # Using string interpolation. 
@model submodel_prefix_interpolation() = @submodel prefix="\$(nameof(inner()))" a = inner() submodel_prefix_interpolation (generic function with 2 methods) -julia> @varname(var"inner.x") in keys(VarInfo(submodel_prefix_interpolation())) +julia> @varname(inner.x) in keys(VarInfo(submodel_prefix_interpolation())) ┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. │ caller = ip:0x0 └ @ Core :-1 @@ -187,7 +187,7 @@ julia> # Or using some arbitrary expression. @model submodel_prefix_expr() = @submodel prefix=1 + 2 a = inner() submodel_prefix_expr (generic function with 2 methods) -julia> @varname(var"3.x") in keys(VarInfo(submodel_prefix_expr())) +julia> @varname(var"3".x) in keys(VarInfo(submodel_prefix_expr())) ┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. │ caller = ip:0x0 └ @ Core :-1 diff --git a/src/utils.jl b/src/utils.jl index 50f9baf61..56c3d70af 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1285,14 +1285,18 @@ broadcast_safe(x) = x broadcast_safe(x::Distribution) = (x,) broadcast_safe(x::AbstractContext) = (x,) +# Convert (x=1,) to Dict(@varname(x) => 1) +_nt_to_varname_dict(nt) = Dict(VarName{k}() => v for (k, v) in pairs(nt)) # Version of `merge` used by `conditioned` and `fixed` to handle # the scenario where we might try to merge a dict with an empty # tuple. # TODO: Maybe replace the default of returning `NamedTuple` with `nothing`? _merge(left::NamedTuple, right::NamedTuple) = merge(left, right) _merge(left::AbstractDict, right::AbstractDict) = merge(left, right) -_merge(left::AbstractDict, right::NamedTuple{()}) = left -_merge(left::NamedTuple{()}, right::AbstractDict) = right +_merge(left::AbstractDict, ::NamedTuple{()}) = left +_merge(left::AbstractDict, right::NamedTuple) = merge(left, _nt_to_varname_dict(right)) +_merge(::NamedTuple{()}, right::AbstractDict) = right +_merge(left::NamedTuple, right::AbstractDict) = merge(_nt_to_varname_dict(left), right) """ unique_syms(vns::T) where {T<:NTuple{N,VarName}} diff --git a/test/Project.toml b/test/Project.toml index 9fa3fd872..79e6d129b 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -32,7 +32,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "1" AbstractMCMC = "5" -AbstractPPL = "0.10.1" +AbstractPPL = "0.11" Accessors = "0.1" Aqua = "0.8" Bijectors = "0.15.1" diff --git a/test/compiler.jl b/test/compiler.jl index 3d3c6d9e3..a0286d405 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -481,8 +481,8 @@ module Issue537 end m = demo_useval(missing, missing) vi = VarInfo(m) ks = keys(vi) - @test VarName{Symbol("sub1.x")}() ∈ ks - @test VarName{Symbol("sub2.x")}() ∈ ks + @test @varname(sub1.x) ∈ ks + @test @varname(sub2.x) ∈ ks @test @varname(z) ∈ ks @test abs(mean([VarInfo(m)[@varname(z)] for i in 1:10]) - 100) ≤ 10 @@ -505,7 +505,7 @@ module Issue537 end num_steps = length(y[1]) num_obs = length(y) @inbounds for i in 1:num_obs - x ~ to_submodel(prefix(AR1(num_steps, α, μ, σ), "ar1_$i"), false) + x ~ to_submodel(DynamicPPL.prefix(AR1(num_steps, α, μ, σ), "ar1_$i"), false) y[i] ~ MvNormal(x, 0.01 * I) end end @@ -514,8 +514,9 @@ module Issue537 end m = demo(ys) vi = VarInfo(m) - for k in [:α, :μ, :σ, Symbol("ar1_1.η"), Symbol("ar1_2.η")] - @test VarName{k}() ∈ keys(vi) + for vn in + [@varname(α), @varname(μ), @varname(σ), @varname(ar1_1.η), @varname(ar1_2.η)] + @test vn ∈ keys(vi) end end diff --git a/test/contexts.jl b/test/contexts.jl index 
faa831cc1..11e591f8f 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -39,44 +39,39 @@ end Base.IteratorSize(::Type{<:AbstractContext}) = Base.SizeUnknown() Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() -""" - remove_prefix(vn::VarName) - -Return `vn` but now with the prefix removed. -""" -function remove_prefix(vn::VarName) - return VarName{Symbol(split(string(vn), string(DynamicPPL.PREFIX_SEPARATOR))[end])}( - getoptic(vn) +@testset "contexts.jl" begin + child_contexts = Dict( + :default => DefaultContext(), + :prior => PriorContext(), + :likelihood => LikelihoodContext(), ) -end -@testset "contexts.jl" begin - child_contexts = [DefaultContext(), PriorContext(), LikelihoodContext()] - - parent_contexts = [ - DynamicPPL.TestUtils.TestParentContext(DefaultContext()), - SamplingContext(), - MiniBatchContext(DefaultContext(), 0.0), - PrefixContext{:x}(DefaultContext()), - PointwiseLogdensityContext(), - ConditionContext((x=1.0,)), - ConditionContext( + parent_contexts = Dict( + :testparent => DynamicPPL.TestUtils.TestParentContext(DefaultContext()), + :sampling => SamplingContext(), + :minibatch => MiniBatchContext(DefaultContext(), 0.0), + :prefix => PrefixContext{:x}(DefaultContext()), + :pointwiselogdensity => PointwiseLogdensityContext(), + :condition1 => ConditionContext((x=1.0,)), + :condition2 => ConditionContext( (x=1.0,), DynamicPPL.TestUtils.TestParentContext(ConditionContext((y=2.0,))) ), - ConditionContext((x=1.0,), PrefixContext{:a}(ConditionContext((var"a.y"=2.0,)))), - ConditionContext((x=[1.0, missing],)), - ] + :condition3 => ConditionContext( + (x=1.0,), PrefixContext{:a}(ConditionContext(Dict(@varname(a.y) => 2.0))) + ), + :condition4 => ConditionContext((x=[1.0, missing],)), + ) - contexts = vcat(child_contexts, parent_contexts) + contexts = merge(child_contexts, parent_contexts) - @testset "$(context)" for context in contexts + @testset "$(name)" for (name, context) in contexts @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS DynamicPPL.TestUtils.test_context(context, model) end end @testset "contextual_isassumption" begin - @testset "$context" for context in contexts + @testset "$(name)" for (name, context) in contexts # Any `context` should return `true` by default. @test contextual_isassumption(context, VarName{gensym(:x)}()) @@ -85,14 +80,28 @@ end # Let's first extract the conditioned variables. conditioned_values = DynamicPPL.conditioned(context) - for (sym, val) in pairs(conditioned_values) - vn = VarName{sym}() + # The conditioned values might be a NamedTuple, or a Dict. + # We convert to a Dict for consistency + if conditioned_values isa NamedTuple + conditioned_values = Dict( + VarName{sym}() => val for (sym, val) in pairs(conditioned_values) + ) + end + for (vn, val) in pairs(conditioned_values) # We need to drop the prefix of `var` since in `contextual_isassumption` # it will be threaded through the `PrefixContext` before it reaches # `ConditionContext` with the conditioned variable. - vn_without_prefix = remove_prefix(vn) + vn_without_prefix = if getoptic(vn) isa PropertyLens + # Hacky: This assumes that there is exactly one level of prefixing + # that we need to undo. This is appropriate for the :condition3 + # test case above, but is not generally correct. + AbstractPPL.unprefix(vn, VarName{getsym(vn)}()) + else + vn + end + @show DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) # Let's check elementwise. 
for vn_child in DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) @@ -108,7 +117,7 @@ end end @testset "getconditioned_nested & hasconditioned_nested" begin - @testset "$context" for context in contexts + @testset "$name" for (name, context) in contexts fake_vn = VarName{gensym(:x)}() @test !hasconditioned_nested(context, fake_vn) @test_throws ErrorException getconditioned_nested(context, fake_vn) @@ -118,14 +127,26 @@ end # Let's first extract the conditioned variables. conditioned_values = DynamicPPL.conditioned(context) + # The conditioned values might be a NamedTuple, or a Dict. + # We convert to a Dict for consistency + if conditioned_values isa NamedTuple + conditioned_values = Dict( + VarName{sym}() => val for (sym, val) in pairs(conditioned_values) + ) + end - for (sym, val) in pairs(conditioned_values) - vn = VarName{sym}() - + for (vn, val) in pairs(conditioned_values) # We need to drop the prefix of `var` since in `contextual_isassumption` # it will be threaded through the `PrefixContext` before it reaches # `ConditionContext` with the conditioned variable. - vn_without_prefix = remove_prefix(vn) + vn_without_prefix = if getoptic(vn) isa PropertyLens + # Hacky: This assumes that there is exactly one level of prefixing + # that we need to undo. This is appropriate for the :condition3 + # test case above, but is not generally correct. + AbstractPPL.unprefix(vn, VarName{getsym(vn)}()) + else + vn + end for vn_child in DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) @@ -153,51 +174,42 @@ end ) vn = VarName{:x}() vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) - @test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x") - @test getoptic(vn_prefixed) === getoptic(vn) + @test vn_prefixed == @varname(a.b.c.d.e.f.x) vn = VarName{:x}(((1,),)) vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) - @test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x") - @test getoptic(vn_prefixed) === getoptic(vn) + @test vn_prefixed == @varname(a.b.c.d.e.f.x[1]) end @testset "nested within arbitrary context stacks" begin vn = @varname(x[1]) ctx1 = PrefixContext{:a}(DefaultContext()) + @test DynamicPPL.prefix(ctx1, vn) == @varname(a.x[1]) ctx2 = SamplingContext(ctx1) + @test DynamicPPL.prefix(ctx2, vn) == @varname(a.x[1]) ctx3 = PrefixContext{:b}(ctx2) + @test DynamicPPL.prefix(ctx3, vn) == @varname(b.a.x[1]) ctx4 = DynamicPPL.ValuesAsInModelContext(OrderedDict(), false, ctx3) - vn_prefixed1 = prefix(ctx1, vn) - vn_prefixed2 = prefix(ctx2, vn) - vn_prefixed3 = prefix(ctx3, vn) - vn_prefixed4 = prefix(ctx4, vn) - @test DynamicPPL.getsym(vn_prefixed1) == Symbol("a.x") - @test DynamicPPL.getsym(vn_prefixed2) == Symbol("a.x") - @test DynamicPPL.getsym(vn_prefixed3) == Symbol("b.a.x") - @test DynamicPPL.getsym(vn_prefixed4) == Symbol("b.a.x") - @test DynamicPPL.getoptic(vn_prefixed1) === DynamicPPL.getoptic(vn) - @test DynamicPPL.getoptic(vn_prefixed2) === DynamicPPL.getoptic(vn) - @test DynamicPPL.getoptic(vn_prefixed3) === DynamicPPL.getoptic(vn) - @test DynamicPPL.getoptic(vn_prefixed4) === DynamicPPL.getoptic(vn) + @test DynamicPPL.prefix(ctx4, vn) == @varname(b.a.x[1]) end - context = DynamicPPL.PrefixContext{:prefix}(SamplingContext()) @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + prefix = :my_prefix + context = DynamicPPL.PrefixContext{prefix}(SamplingContext()) # Sample with the context. varinfo = DynamicPPL.VarInfo() DynamicPPL.evaluate!!(model, varinfo, context) - # Extract the resulting symbols. 
- vns_varinfo_syms = Set(map(DynamicPPL.getsym, keys(varinfo))) + # Extract the resulting varnames + vns_actual = Set(keys(varinfo)) - # Extract the ground truth symbols. - vns_syms = Set([ - Symbol("prefix", DynamicPPL.PREFIX_SEPARATOR, DynamicPPL.getsym(vn)) for + # Extract the ground truth varnames + vns_expected = Set([ + AbstractPPL.prefix(vn, VarName{prefix}()) for vn in DynamicPPL.TestUtils.varnames(model) ]) # Check that all variables are prefixed correctly. - @test vns_syms == vns_varinfo_syms + @test vns_actual == vns_expected end end diff --git a/test/debug_utils.jl b/test/debug_utils.jl index d4f6601f5..cac52693e 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -63,8 +63,8 @@ # With manual prefixing, https://github.com/TuringLang/DynamicPPL.jl/issues/785 @model function ModelOuterWorking2() - x1 ~ to_submodel(prefix(ModelInner(), :a), false) - x2 ~ to_submodel(prefix(ModelInner(), :b), false) + x1 ~ to_submodel(DynamicPPL.prefix(ModelInner(), :a), false) + x2 ~ to_submodel(DynamicPPL.prefix(ModelInner(), :b), false) return (x1, x2) end model = ModelOuterWorking2() diff --git a/test/deprecated.jl b/test/deprecated.jl index f12217983..500d3eb7f 100644 --- a/test/deprecated.jl +++ b/test/deprecated.jl @@ -31,7 +31,7 @@ @test outer()() isa Tuple{Float64,Float64} vi = VarInfo(outer()) @test @varname(x) in keys(vi) - @test @varname(var"sub.x") in keys(vi) + @test @varname(sub.x) in keys(vi) end @testset "logp is still accumulated properly" begin diff --git a/test/model.jl b/test/model.jl index a863b6596..447a9ecaa 100644 --- a/test/model.jl +++ b/test/model.jl @@ -448,15 +448,15 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() return nothing end @model function outer_manual_prefix() - a ~ to_submodel(prefix(inner(), :a), false) - b ~ to_submodel(prefix(inner(), :b), false) + a ~ to_submodel(DynamicPPL.prefix(inner(), :a), false) + b ~ to_submodel(DynamicPPL.prefix(inner(), :b), false) return nothing end for model in (outer_auto_prefix(), outer_manual_prefix()) vi = VarInfo(model) vns = Set(keys(values_as_in_model(model, false, vi))) - @test vns == Set([@varname(var"a.x"), @varname(var"b.x")]) + @test vns == Set([@varname(a.x), @varname(b.x)]) end end end From cc5e581ee00424e88dfb25873ce54e69f8f65d8c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 28 Mar 2025 17:04:32 +0000 Subject: [PATCH 03/48] Remove VarInfo(VarInfo, params) (#870) --- HISTORY.md | 4 ++++ src/varinfo.jl | 2 -- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index cd2757edc..a956bd188 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,6 +4,10 @@ **Breaking changes** +### VarInfo constructor + +`VarInfo(vi::VarInfo, values)` has been removed. You can replace this directly with `unflatten(vi, values)` instead. + ### VarName prefixing behaviour The way in which VarNames in submodels are prefixed has been changed. 
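A minimal migration sketch for the `VarInfo(vi::VarInfo, values)` removal described above (the `demo` model and parameter vector `θ` are illustrative only):

```julia
using DynamicPPL, Distributions

@model demo() = x ~ Normal()

vi = VarInfo(demo())
θ = [0.5]

# Previously: vi_new = VarInfo(vi, θ)
# Now:
vi_new = DynamicPPL.unflatten(vi, θ)
vi_new[@varname(x)]   # 0.5, since `x` is unconstrained and `vi` is not linked
```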
diff --git a/src/varinfo.jl b/src/varinfo.jl index 0c033e504..94b1f1c07 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -100,8 +100,6 @@ const TypedVarInfo = VarInfo{<:NamedTuple} const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}} } -# TODO: Remove this -@deprecate VarInfo(vi::VarInfo, x::AbstractVector) unflatten(vi, x) # NOTE: This is kind of weird, but it effectively preserves the "old" # behavior where we're allowed to call `link!` on the same `VarInfo` From b9c368b500ed1e5904f2229e915d3cefddd45171 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 9 Apr 2025 11:35:53 +0100 Subject: [PATCH 04/48] Unify `{untyped,typed}_{vector_,}varinfo` constructor functions (#879) * Unify {Untyped,Typed}{Vector,}VarInfo constructors * Update invocations * NTVarInfo * Fix tests * More fixes * Fixes * Fixes * Fixes * Use lowercase functions, don't deprecate VarInfo * Rewrite VarInfo docstring * Fix methods * Fix methods (really) --- HISTORY.md | 31 +- benchmarks/src/DynamicPPLBenchmarks.jl | 10 +- docs/src/api.md | 13 +- docs/src/internals/varinfo.md | 4 +- src/DynamicPPL.jl | 2 - src/abstract_varinfo.jl | 8 +- src/sampler.jl | 2 +- src/simple_varinfo.jl | 4 +- src/test_utils/contexts.jl | 2 +- src/test_utils/varinfo.jl | 10 +- src/varinfo.jl | 510 ++++++++++++++++--------- test/ext/DynamicPPLJETExt.jl | 8 +- test/model.jl | 18 +- test/simple_varinfo.jl | 2 +- test/test_util.jl | 18 +- test/varinfo.jl | 49 +-- 16 files changed, 436 insertions(+), 255 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index a956bd188..1af5c2ca3 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,10 +4,25 @@ **Breaking changes** -### VarInfo constructor +### VarInfo constructors `VarInfo(vi::VarInfo, values)` has been removed. You can replace this directly with `unflatten(vi, values)` instead. +The `metadata` argument to `VarInfo([rng, ]model[, sampler, context, metadata])` has been removed. +If you were not using this argument (most likely), then there is no change needed. +If you were using the `metadata` argument to specify a blank `VarNamedVector`, then you should replace calls to `VarInfo` with `DynamicPPL.typed_vector_varinfo` instead (see 'Other changes' below). + +The `UntypedVarInfo` constructor and type is no longer exported. +If you needed to construct one, you should now use `DynamicPPL.untyped_varinfo` instead. + +The `TypedVarInfo` constructor and type is no longer exported. +The _type_ has been replaced with `DynamicPPL.NTVarInfo`. +The _constructor_ has been replaced with `DynamicPPL.typed_varinfo`. + +Note that the exact kind of VarInfo returned by `VarInfo(rng, model, ...)` is an implementation detail. +Previously, it was guaranteed that this would always be a VarInfo whose metadata was a `NamedTuple` containing `Metadata` structs. +Going forward, this is no longer the case, and you should only assume that the returned object obeys the `AbstractVarInfo` interface. + ### VarName prefixing behaviour The way in which VarNames in submodels are prefixed has been changed. @@ -53,6 +68,20 @@ outer() | (a.x=1.0,) If you are sampling from a model with submodels, this doesn't affect the way you interact with the `MCMCChains.Chains` object, because VarNames are converted into Symbols when stored in the chain. (This behaviour will likely be changed in the future, in that Chains should be indexable by VarNames and not just Symbols, but that has not been implemented yet.) 
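To illustrate why chain indexing is unaffected: when stored in a chain, a `VarName` is converted to the `Symbol` of its string form, which is the same symbol as before (a sketch; the last line assumes the usual `Symbol`-based `getindex` of `MCMCChains.Chains`):

```julia
using DynamicPPL

vn = @varname(a.x)
Symbol(vn)   # Symbol("a.x"), identical to the old flat representation
# hence e.g. `chn[Symbol("a.x")]` on an MCMCChains.Chains object continues to work
```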
+**Other changes** + +While these are technically breaking, they are only internal changes and do not affect the public API. +The following four functions have been added and/or reworked to make it easier to construct VarInfos with different types of metadata: + + 1. `DynamicPPL.untyped_varinfo([rng, ]model[, sampler, context])` + 2. `DynamicPPL.typed_varinfo([rng, ]model[, sampler, context])` + 3. `DynamicPPL.untyped_vector_varinfo([rng, ]model[, sampler, context])` + 4. `DynamicPPL.typed_vector_varinfo([rng, ]model[, sampler, context])` + +The reason for this change is that there were several flavours of VarInfo. +Some, like `typed_varinfo`, were easy to construct because we had convenience methods for them; however, the others were more difficult. +This change makes it easier to access different VarInfo types, and also makes it more explicit which one you are constructing. + ## 0.35.5 Several internal methods have been removed: diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 4c73bf355..16338de2f 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -52,8 +52,8 @@ end Create a benchmark suite for `model` using the selected varinfo type and AD backend. Available varinfo choices: - • `:untyped` → uses `VarInfo()` - • `:typed` → uses `VarInfo(model)` + • `:untyped` → uses `DynamicPPL.untyped_varinfo(model)` + • `:typed` → uses `DynamicPPL.typed_varinfo(model)` • `:simple_namedtuple` → uses `SimpleVarInfo{Float64}(model())` • `:simple_dict` → builds a `SimpleVarInfo{Float64}` from a Dict (pre-populated with the model’s outputs) @@ -67,11 +67,9 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked:: suite = BenchmarkGroup() vi = if varinfo_choice == :untyped - vi = VarInfo() - model(rng, vi) - vi + DynamicPPL.untyped_varinfo(rng, model) elseif varinfo_choice == :typed - VarInfo(rng, model) + DynamicPPL.typed_varinfo(rng, model) elseif varinfo_choice == :simple_namedtuple SimpleVarInfo{Float64}(model(rng)) elseif varinfo_choice == :simple_dict diff --git a/docs/src/api.md b/docs/src/api.md index 2f6376f5d..f83a96886 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -291,18 +291,17 @@ AbstractVarInfo But exactly how a [`AbstractVarInfo`](@ref) stores this information can vary. -For constructing the "default" typed and untyped varinfo types used in DynamicPPL (see [the section on varinfo design](@ref "Design of `VarInfo`") for more on this), we have the following two methods: +#### `VarInfo` ```@docs -DynamicPPL.untyped_varinfo -DynamicPPL.typed_varinfo +VarInfo ``` -#### `VarInfo` - ```@docs -VarInfo -TypedVarInfo +DynamicPPL.untyped_varinfo +DynamicPPL.typed_varinfo +DynamicPPL.untyped_vector_varinfo +DynamicPPL.typed_vector_varinfo ``` One main characteristic of [`VarInfo`](@ref) is that samples are transformed to unconstrained Euclidean space and stored in a linearized form, as described in the [main Turing documentation](https://turinglang.org/docs/developers/transforms/dynamicppl/). 
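A short sketch of the four flavours listed above (internal API; signatures as introduced here, with the `demo` model purely illustrative):

```julia
using DynamicPPL, Distributions

@model demo() = x ~ Normal()
m = demo()

vi1 = DynamicPPL.untyped_varinfo(m)         # single Metadata
vi2 = DynamicPPL.typed_varinfo(m)           # NamedTuple of Metadata (what `VarInfo(m)` currently returns)
vi3 = DynamicPPL.untyped_vector_varinfo(m)  # single VarNamedVector
vi4 = DynamicPPL.typed_vector_varinfo(m)    # NamedTuple of VarNamedVector

# All of these satisfy the AbstractVarInfo interface, e.g.
all(vi -> haskey(vi, @varname(x)), (vi1, vi2, vi3, vi4))   # true
```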
diff --git a/docs/src/internals/varinfo.md b/docs/src/internals/varinfo.md index e6e1f2619..b04913aaf 100644 --- a/docs/src/internals/varinfo.md +++ b/docs/src/internals/varinfo.md @@ -227,13 +227,13 @@ Continuing from the example from the previous section, we can use a `VarInfo` wi ```@example varinfo-design # Type-unstable -varinfo_untyped_vnv = DynamicPPL.VectorVarInfo(varinfo_untyped) +varinfo_untyped_vnv = DynamicPPL.untyped_vector_varinfo(varinfo_untyped) varinfo_untyped_vnv[@varname(x)], varinfo_untyped_vnv[@varname(y)] ``` ```@example varinfo-design # Type-stable -varinfo_typed_vnv = DynamicPPL.VectorVarInfo(varinfo_typed) +varinfo_typed_vnv = DynamicPPL.typed_vector_varinfo(varinfo_typed) varinfo_typed_vnv[@varname(x)], varinfo_typed_vnv[@varname(y)] ``` diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 9f45718c5..51fa53079 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -45,8 +45,6 @@ import Base: # VarInfo export AbstractVarInfo, VarInfo, - UntypedVarInfo, - TypedVarInfo, SimpleVarInfo, push!!, empty!!, diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 44edaa4e9..f11b8a3ec 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -247,11 +247,11 @@ julia> values_as(SimpleVarInfo(data), Vector) 2.0 ``` -`TypedVarInfo`: +`VarInfo` with `NamedTuple` of `Metadata`: ```jldoctest julia> # Just use an example model to construct the `VarInfo` because we're lazy. - vi = VarInfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); + vi = DynamicPPL.typed_varinfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; @@ -273,11 +273,11 @@ julia> values_as(vi, Vector) 2.0 ``` -`UntypedVarInfo`: +`VarInfo` with `Metadata`: ```jldoctest julia> # Just use an example model to construct the `VarInfo` because we're lazy. - vi = VarInfo(); DynamicPPL.TestUtils.demo_assume_dot_observe()(vi); + vi = DynamicPPL.untyped_varinfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; diff --git a/src/sampler.jl b/src/sampler.jl index ff008cc93..49d910fec 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -86,7 +86,7 @@ function default_varinfo( context::AbstractContext, ) init_sampler = initialsampler(sampler) - return VarInfo(rng, model, init_sampler, context) + return typed_varinfo(rng, model, init_sampler, context) end function AbstractMCMC.sample( diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 064483ddd..abf14b8fc 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -10,7 +10,7 @@ Currently only implemented for `NT<:NamedTuple` and `NT<:AbstractDict`. $(FIELDS) # Notes -The major differences between this and `TypedVarInfo` are: +The major differences between this and `NTVarInfo` are: 1. `SimpleVarInfo` does not require linearization. 2. `SimpleVarInfo` can use more efficient bijectors. 3. `SimpleVarInfo` is only type-stable if `NT<:NamedTuple` and either @@ -244,7 +244,7 @@ function SimpleVarInfo{T}( end # Constructor from `VarInfo`. -function SimpleVarInfo(vi::TypedVarInfo, ::Type{D}=NamedTuple; kwargs...) where {D} +function SimpleVarInfo(vi::NTVarInfo, (::Type{D})=NamedTuple; kwargs...) where {D} return SimpleVarInfo{eltype(getlogp(vi))}(vi, D; kwargs...) 
end function SimpleVarInfo{T}( diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index 5150be64b..7404a9af7 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -94,7 +94,7 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod @test (DynamicPPL.evaluate!!(model, varinfo_untyped, SamplingContext(context)); true) @test (DynamicPPL.evaluate!!(model, varinfo_untyped, context); true) # Typed varinfo. - varinfo_typed = DynamicPPL.TypedVarInfo(varinfo_untyped) + varinfo_typed = DynamicPPL.typed_varinfo(varinfo_untyped) @test (DynamicPPL.evaluate!!(model, varinfo_typed, SamplingContext(context)); true) @test (DynamicPPL.evaluate!!(model, varinfo_typed, context); true) end diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl index 6a655ded4..539872143 100644 --- a/src/test_utils/varinfo.jl +++ b/src/test_utils/varinfo.jl @@ -27,12 +27,10 @@ function setup_varinfos( model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false ) # VarInfo - vi_untyped_metadata = VarInfo(DynamicPPL.Metadata()) - vi_untyped_vnv = VarInfo(DynamicPPL.VarNamedVector()) - model(vi_untyped_metadata) - model(vi_untyped_vnv) - vi_typed_metadata = DynamicPPL.TypedVarInfo(vi_untyped_metadata) - vi_typed_vnv = DynamicPPL.TypedVarInfo(vi_untyped_vnv) + vi_untyped_metadata = DynamicPPL.untyped_varinfo(model) + vi_untyped_vnv = DynamicPPL.untyped_vector_varinfo(model) + vi_typed_metadata = DynamicPPL.typed_varinfo(model) + vi_typed_vnv = DynamicPPL.typed_vector_varinfo(model) # SimpleVarInfo svi_typed = SimpleVarInfo(example_values) diff --git a/src/varinfo.jl b/src/varinfo.jl index 94b1f1c07..360857ef7 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -69,34 +69,91 @@ end ########### """ -``` -struct VarInfo{Tmeta, Tlogp} <: AbstractVarInfo - metadata::Tmeta - logp::Base.RefValue{Tlogp} - num_produce::Base.RefValue{Int} -end -``` + struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo + metadata::Tmeta + logp::Base.RefValue{Tlogp} + num_produce::Base.RefValue{Int} + end + +A light wrapper over some kind of metadata. -A light wrapper over one or more instances of `Metadata`. Let `vi` be an instance of -`VarInfo`. If `vi isa VarInfo{<:Metadata}`, then only one `Metadata` instance is used -for all the sybmols. `VarInfo{<:Metadata}` is aliased `UntypedVarInfo`. If -`vi isa VarInfo{<:NamedTuple}`, then `vi.metadata` is a `NamedTuple` that maps each -symbol used on the LHS of `~` in the model to its `Metadata` instance. The latter allows -for the type specialization of `vi` after the first sampling iteration when all the -symbols have been observed. `VarInfo{<:NamedTuple}` is aliased `TypedVarInfo`. +The type of the metadata can be one of a number of options. It may either be a +`Metadata` or a `VarNamedVector`, _or_, it may be a `NamedTuple` which maps +symbols to `Metadata` or `VarNamedVector` instances. Here, a _symbol_ refers +to a Julia variable and may consist of one or more `VarName`s which appear on +the left-hand side of tilde statements. For example, `x[1]` and `x[2]` both +have the same symbol `x`. -Note: It is the user's responsibility to ensure that each "symbol" is visited at least -once whenever the model is called, regardless of any stochastic branching. Each symbol -refers to a Julia variable and can be a hierarchical array of many random variables, e.g. `x[1] ~ ...` and `x[2] ~ ...` both have the same symbol `x`. 
+Several type aliases are provided for these forms of VarInfos: +- `VarInfo{<:Metadata}` is `UntypedVarInfo` +- `VarInfo{<:VarNamedVector}` is `UntypedVectorVarInfo` +- `VarInfo{<:NamedTuple}` is `NTVarInfo` + +The NamedTuple form, i.e. `NTVarInfo`, is useful for maintaining type stability +of model evaluation. However, the element type of NamedTuples are not contained +in its type itself: thus, there is no way to use the type system to determine +whether the elements of the NamedTuple are `Metadata` or `VarNamedVector`. + +Note that for NTVarInfo, it is the user's responsibility to ensure that each +symbol is visited at least once during model evaluation, regardless of any +stochastic branching. """ struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo metadata::Tmeta logp::Base.RefValue{Tlogp} num_produce::Base.RefValue{Int} end -const VectorVarInfo = VarInfo{<:VarNamedVector} +VarInfo(meta=Metadata()) = VarInfo(meta, Ref{LogProbType}(0.0), Ref(0)) +""" + VarInfo([rng, ]model[, sampler, context]) + +Generate a `VarInfo` object for the given `model`, by evaluating it once using +the given `rng`, `sampler`, and `context`. + +!!! warning + + This function currently returns a `VarInfo` with its metadata field set to + a `NamedTuple` of `Metadata`. This is an implementation detail. In general, + this function may return any kind of object that satisfies the + `AbstractVarInfo` interface. If you require precise control over the type + of `VarInfo` returned, use the internal functions `untyped_varinfo`, + `typed_varinfo`, `untyped_vector_varinfo`, or `typed_vector_varinfo` + instead. +""" +function VarInfo( + rng::Random.AbstractRNG, + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + return typed_varinfo(rng, model, sampler, context) +end +function VarInfo( + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + # No rng + return VarInfo(Random.default_rng(), model, sampler, context) +end +function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) + # No sampler + return VarInfo(rng, model, SampleFromPrior(), context) +end +function VarInfo(model::Model, context::AbstractContext) + # No sampler, no rng + return VarInfo(Random.default_rng(), model, SampleFromPrior(), context) +end + +const UntypedVectorVarInfo = VarInfo{<:VarNamedVector} const UntypedVarInfo = VarInfo{<:Metadata} -const TypedVarInfo = VarInfo{<:NamedTuple} +# TODO: NTVarInfo carries no information about the type of the actual metadata +# i.e. the elements of the NamedTuple. It could be Metadata or it could be +# VarNamedVector. +# Resolving this ambiguity would likely require us to replace NamedTuple with +# something which carried both its keys as well as its values' types as type +# parameters. 
+const NTVarInfo = VarInfo{<:NamedTuple} const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}} } @@ -132,70 +189,245 @@ function metadata_to_varnamedvector(md::Metadata) ) end -function VectorVarInfo(vi::UntypedVarInfo) - md = metadata_to_varnamedvector(vi.metadata) - lp = getlogp(vi) - return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) -end - -function VectorVarInfo(vi::TypedVarInfo) - md = map(metadata_to_varnamedvector, vi.metadata) - lp = getlogp(vi) - return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) -end - function has_varnamedvector(vi::VarInfo) return vi.metadata isa VarNamedVector || - (vi isa TypedVarInfo && any(Base.Fix2(isa, VarNamedVector), values(vi.metadata))) + (vi isa NTVarInfo && any(Base.Fix2(isa, VarNamedVector), values(vi.metadata))) end +######################## +# VarInfo constructors # +######################## + """ - untyped_varinfo(model[, context, metadata]) + untyped_varinfo([rng, ]model[, sampler, context, metadata]) -Return an untyped varinfo object for the given `model` and `context`. +Return a VarInfo object for the given `model` and `context`, which has just a +single `Metadata` as its metadata field. # Arguments -- `model::Model`: The model for which to create the varinfo object. -- `context::AbstractContext`: The context in which to evaluate the model. Default: `SamplingContext()`. -- `metadata::Union{Metadata,VarNamedVector}`: The metadata to use for the varinfo object. - Default: `Metadata()`. +- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation +- `model::Model`: The model for which to create the varinfo object +- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. """ +function untyped_varinfo( + rng::Random.AbstractRNG, + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + varinfo = VarInfo(Metadata()) + context = SamplingContext(rng, sampler, context) + return last(evaluate!!(model, varinfo, context)) +end function untyped_varinfo( model::Model, - context::AbstractContext=SamplingContext(), - metadata::Union{Metadata,VarNamedVector}=Metadata(), + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), ) - varinfo = VarInfo(metadata) - return last( - evaluate!!(model, varinfo, hassampler(context) ? context : SamplingContext(context)) + # No rng + return untyped_varinfo(Random.default_rng(), model, sampler, context) +end +function untyped_varinfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) + # No sampler + return untyped_varinfo(rng, model, SampleFromPrior(), context) +end +function untyped_varinfo(model::Model, context::AbstractContext) + # No sampler, no rng + return untyped_varinfo(model, SampleFromPrior(), context) +end + +""" + typed_varinfo(vi::UntypedVarInfo) + +This function finds all the unique `sym`s from the instances of `VarName{sym}` found in +`vi.metadata.vns`. It then extracts the metadata associated with each symbol from the +global `vi.metadata` field. Finally, a new `VarInfo` is created with a new `metadata` as +a `NamedTuple` mapping from symbols to type-stable `Metadata` instances, one for each +symbol. 
+""" +function typed_varinfo(vi::UntypedVarInfo) + meta = vi.metadata + new_metas = Metadata[] + # Symbols of all instances of `VarName{sym}` in `vi.vns` + syms_tuple = Tuple(syms(vi)) + for s in syms_tuple + # Find all indices in `vns` with symbol `s` + inds = findall(vn -> getsym(vn) === s, meta.vns) + n = length(inds) + # New `vns` + sym_vns = getindex.((meta.vns,), inds) + # New idcs + sym_idcs = Dict(a => i for (i, a) in enumerate(sym_vns)) + # New dists + sym_dists = getindex.((meta.dists,), inds) + # New orders + sym_orders = getindex.((meta.orders,), inds) + # New flags + sym_flags = Dict(a => meta.flags[a][inds] for a in keys(meta.flags)) + + # Extract new ranges and vals + _ranges = getindex.((meta.ranges,), inds) + # `copy.()` is a workaround to reduce the eltype from Real to Int or Float64 + _vals = [copy.(meta.vals[_ranges[i]]) for i in 1:n] + sym_ranges = Vector{eltype(_ranges)}(undef, n) + start = 0 + for i in 1:n + sym_ranges[i] = (start + 1):(start + length(_vals[i])) + start += length(_vals[i]) + end + sym_vals = foldl(vcat, _vals) + + push!( + new_metas, + Metadata( + sym_idcs, sym_vns, sym_ranges, sym_vals, sym_dists, sym_orders, sym_flags + ), + ) + end + logp = getlogp(vi) + num_produce = get_num_produce(vi) + nt = NamedTuple{syms_tuple}(Tuple(new_metas)) + return VarInfo(nt, Ref(logp), Ref(num_produce)) +end +function typed_varinfo(vi::NTVarInfo) + # This function preserves the behaviour of typed_varinfo(vi) where vi is + # already a NTVarInfo + has_varnamedvector(vi) && error( + "Cannot convert VarInfo with NamedTuple of VarNamedVector to VarInfo with NamedTuple of Metadata", ) + return vi +end +""" + typed_varinfo([rng, ]model[, sampler, context, metadata]) + +Return a VarInfo object for the given `model` and `context`, which has a NamedTuple of +`Metadata` structs as its metadata field. + +# Arguments +- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation +- `model::Model`: The model for which to create the varinfo object +- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. +""" +function typed_varinfo( + rng::Random.AbstractRNG, + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + return typed_varinfo(untyped_varinfo(rng, model, sampler, context)) +end +function typed_varinfo( + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + # No rng + return typed_varinfo(Random.default_rng(), model, sampler, context) +end +function typed_varinfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) + # No sampler + return typed_varinfo(rng, model, SampleFromPrior(), context) +end +function typed_varinfo(model::Model, context::AbstractContext) + # No sampler, no rng + return typed_varinfo(model, SampleFromPrior(), context) end """ - typed_varinfo(model[, context, metadata]) + untyped_vector_varinfo([rng, ]model[, sampler, context, metadata]) -Return a typed varinfo object for the given `model`, `sampler` and `context`. +Return a VarInfo object for the given `model` and `context`, which has just a +single `VarNamedVector` as its metadata field. -This simply calls [`DynamicPPL.untyped_varinfo`](@ref) and converts the resulting -varinfo object to a typed varinfo object. 
+# Arguments +- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation +- `model::Model`: The model for which to create the varinfo object +- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. +""" +function untyped_vector_varinfo(vi::UntypedVarInfo) + md = metadata_to_varnamedvector(vi.metadata) + lp = getlogp(vi) + return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) +end +function untyped_vector_varinfo( + rng::Random.AbstractRNG, + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + return untyped_vector_varinfo(untyped_varinfo(rng, model, sampler, context)) +end +function untyped_vector_varinfo( + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + # No rng + return untyped_vector_varinfo(Random.default_rng(), model, sampler, context) +end +function untyped_vector_varinfo( + rng::Random.AbstractRNG, model::Model, context::AbstractContext +) + # No sampler + return untyped_vector_varinfo(rng, model, SampleFromPrior(), context) +end +function untyped_vector_varinfo(model::Model, context::AbstractContext) + # No sampler, no rng + return untyped_vector_varinfo(model, SampleFromPrior(), context) +end -See also: [`DynamicPPL.untyped_varinfo`](@ref) """ -typed_varinfo(args...) = TypedVarInfo(untyped_varinfo(args...)) + typed_vector_varinfo([rng, ]model[, sampler, context, metadata]) -function VarInfo( +Return a VarInfo object for the given `model` and `context`, which has a +NamedTuple of `VarNamedVector`s as its metadata field. + +# Arguments +- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation +- `model::Model`: The model for which to create the varinfo object +- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. +""" +function typed_vector_varinfo(vi::NTVarInfo) + md = map(metadata_to_varnamedvector, vi.metadata) + lp = getlogp(vi) + return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) +end +function typed_vector_varinfo(vi::UntypedVectorVarInfo) + new_metas = group_by_symbol(vi.metadata) + logp = getlogp(vi) + num_produce = get_num_produce(vi) + nt = NamedTuple(new_metas) + return VarInfo(nt, Ref(logp), Ref(num_produce)) +end +function typed_vector_varinfo( rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), - metadata::Union{Metadata,VarNamedVector}=Metadata(), ) - return typed_varinfo(model, SamplingContext(rng, sampler, context), metadata) + return typed_vector_varinfo(untyped_vector_varinfo(rng, model, sampler, context)) end -function VarInfo( - model::Model, args::Union{AbstractSampler,AbstractContext,Metadata,VarNamedVector}... +function typed_vector_varinfo( + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), ) - return VarInfo(Random.default_rng(), model, args...) 
+ # No rng + return typed_vector_varinfo(Random.default_rng(), model, sampler, context) +end +function typed_vector_varinfo( + rng::Random.AbstractRNG, model::Model, context::AbstractContext +) + # No sampler + return typed_vector_varinfo(rng, model, SampleFromPrior(), context) +end +function typed_vector_varinfo(model::Model, context::AbstractContext) + # No sampler, no rng + return typed_vector_varinfo(model, SampleFromPrior(), context) end """ @@ -204,7 +436,7 @@ end Return the length of the vector representation of `varinfo`. """ vector_length(varinfo::VarInfo) = length(varinfo.metadata) -vector_length(varinfo::TypedVarInfo) = sum(length, varinfo.metadata) +vector_length(varinfo::NTVarInfo) = sum(length, varinfo.metadata) vector_length(md::Metadata) = sum(length, md.ranges) function unflatten(vi::VarInfo, x::AbstractVector) @@ -241,11 +473,6 @@ end unflatten_metadata(vnv::VarNamedVector, x::AbstractVector) = unflatten(vnv, x) -# without AbstractSampler -function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) - return VarInfo(rng, model, SampleFromPrior(), context) -end - #### #### Internal functions #### @@ -500,7 +727,7 @@ setval!(vi::UntypedVarInfo, val, vview::VarView) = vi.metadata.vals[vview] = val Return the metadata in `vi` that belongs to `vn`. """ getmetadata(vi::VarInfo, vn::VarName) = vi.metadata -getmetadata(vi::TypedVarInfo, vn::VarName) = getfield(vi.metadata, getsym(vn)) +getmetadata(vi::NTVarInfo, vn::VarName) = getfield(vi.metadata, getsym(vn)) """ getidx(vi::VarInfo, vn::VarName) @@ -541,7 +768,7 @@ end Return the range corresponding to `varname` in the vector representation of `varinfo`. """ vector_getrange(vi::VarInfo, vn::VarName) = getrange(vi.metadata, vn) -function vector_getrange(vi::TypedVarInfo, vn::VarName) +function vector_getrange(vi::NTVarInfo, vn::VarName) offset = 0 for md in values(vi.metadata) # First, we need to check if `vn` is in `md`. @@ -563,8 +790,8 @@ Return the range corresponding to `varname` in the vector representation of `var function vector_getranges(varinfo::VarInfo, varname::Vector{<:VarName}) return map(Base.Fix1(vector_getrange, varinfo), varname) end -# Specialized version for `TypedVarInfo`. -function vector_getranges(varinfo::TypedVarInfo, vns::Vector{<:VarName}) +# Specialized version for `NTVarInfo`. +function vector_getranges(varinfo::NTVarInfo, vns::Vector{<:VarName}) # TODO: Does it help if we _don't_ convert to a vector here? metadatas = collect(values(varinfo.metadata)) # Extract the offsets. @@ -624,7 +851,7 @@ end getindex_internal(vi::VarInfo, ::Colon) = getindex_internal(vi.metadata, Colon()) # NOTE: `mapreduce` over `NamedTuple` results in worse type-inference. # See for example https://github.com/JuliaLang/julia/pull/46381. -function getindex_internal(vi::TypedVarInfo, ::Colon) +function getindex_internal(vi::NTVarInfo, ::Colon) return reduce(vcat, map(Base.Fix2(getindex_internal, Colon()), vi.metadata)) end function getindex_internal(md::Metadata, ::Colon) @@ -684,10 +911,10 @@ settrans!!(vi::VarInfo, trans::AbstractTransformation) = settrans!!(vi, true) Returns a tuple of the unique symbols of random variables in `vi`. 
""" syms(vi::UntypedVarInfo) = Tuple(unique!(map(getsym, vi.metadata.vns))) # get all symbols -syms(vi::TypedVarInfo) = keys(vi.metadata) +syms(vi::NTVarInfo) = keys(vi.metadata) _getidcs(vi::UntypedVarInfo) = 1:length(vi.metadata.idcs) -_getidcs(vi::TypedVarInfo) = _getidcs(vi.metadata) +_getidcs(vi::NTVarInfo) = _getidcs(vi.metadata) @generated function _getidcs(metadata::NamedTuple{names}) where {names} exprs = [] @@ -702,12 +929,11 @@ end findinds(vnv::VarNamedVector) = 1:length(vnv.varnames) """ - all_varnames_grouped_by_symbol(vi::TypedVarInfo) + all_varnames_grouped_by_symbol(vi::NTVarInfo) Return a `NamedTuple` of the variables in `vi` grouped by symbol. """ -all_varnames_grouped_by_symbol(vi::TypedVarInfo) = - all_varnames_grouped_by_symbol(vi.metadata) +all_varnames_grouped_by_symbol(vi::NTVarInfo) = all_varnames_grouped_by_symbol(vi.metadata) @generated function all_varnames_grouped_by_symbol(md::NamedTuple{names}) where {names} expr = Expr(:tuple) @@ -745,73 +971,6 @@ end #### APIs for typed and untyped VarInfo #### -# VarInfo - -VarInfo(meta=Metadata()) = VarInfo(meta, Ref{LogProbType}(0.0), Ref(0)) - -function TypedVarInfo(vi::VectorVarInfo) - new_metas = group_by_symbol(vi.metadata) - logp = getlogp(vi) - num_produce = get_num_produce(vi) - nt = NamedTuple(new_metas) - return VarInfo(nt, Ref(logp), Ref(num_produce)) -end - -""" - TypedVarInfo(vi::UntypedVarInfo) - -This function finds all the unique `sym`s from the instances of `VarName{sym}` found in -`vi.metadata.vns`. It then extracts the metadata associated with each symbol from the -global `vi.metadata` field. Finally, a new `VarInfo` is created with a new `metadata` as -a `NamedTuple` mapping from symbols to type-stable `Metadata` instances, one for each -symbol. -""" -function TypedVarInfo(vi::UntypedVarInfo) - meta = vi.metadata - new_metas = Metadata[] - # Symbols of all instances of `VarName{sym}` in `vi.vns` - syms_tuple = Tuple(syms(vi)) - for s in syms_tuple - # Find all indices in `vns` with symbol `s` - inds = findall(vn -> getsym(vn) === s, meta.vns) - n = length(inds) - # New `vns` - sym_vns = getindex.((meta.vns,), inds) - # New idcs - sym_idcs = Dict(a => i for (i, a) in enumerate(sym_vns)) - # New dists - sym_dists = getindex.((meta.dists,), inds) - # New orders - sym_orders = getindex.((meta.orders,), inds) - # New flags - sym_flags = Dict(a => meta.flags[a][inds] for a in keys(meta.flags)) - - # Extract new ranges and vals - _ranges = getindex.((meta.ranges,), inds) - # `copy.()` is a workaround to reduce the eltype from Real to Int or Float64 - _vals = [copy.(meta.vals[_ranges[i]]) for i in 1:n] - sym_ranges = Vector{eltype(_ranges)}(undef, n) - start = 0 - for i in 1:n - sym_ranges[i] = (start + 1):(start + length(_vals[i])) - start += length(_vals[i]) - end - sym_vals = foldl(vcat, _vals) - - push!( - new_metas, - Metadata( - sym_idcs, sym_vns, sym_ranges, sym_vals, sym_dists, sym_orders, sym_flags - ), - ) - end - logp = getlogp(vi) - num_produce = get_num_produce(vi) - nt = NamedTuple{syms_tuple}(Tuple(new_metas)) - return VarInfo(nt, Ref(logp), Ref(num_produce)) -end -TypedVarInfo(vi::TypedVarInfo) = vi - function BangBang.empty!!(vi::VarInfo) _empty!(vi.metadata) resetlogp!!(vi) @@ -834,8 +993,8 @@ Base.keys(vi::VarInfo) = Base.keys(vi.metadata) # HACK: Necessary to avoid returning `Any[]` which won't dispatch correctly # on other methods in the codebase which requires `Vector{<:VarName}`. 
-Base.keys(vi::TypedVarInfo{<:NamedTuple{()}}) = VarName[] -@generated function Base.keys(vi::TypedVarInfo{<:NamedTuple{names}}) where {names} +Base.keys(vi::NTVarInfo{<:NamedTuple{()}}) = VarName[] +@generated function Base.keys(vi::NTVarInfo{<:NamedTuple{names}}) where {names} expr = Expr(:call) push!(expr.args, :vcat) @@ -898,7 +1057,7 @@ _isempty(vnv::VarNamedVector) = isempty(vnv) return Expr(:&&, (:(_isempty(metadata.$f)) for f in names)...) end -function link!!(::DynamicTransformation, vi::TypedVarInfo, model::Model) +function link!!(::DynamicTransformation, vi::NTVarInfo, model::Model) vns = all_varnames_grouped_by_symbol(vi) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _link(model, vi, vns) @@ -952,13 +1111,13 @@ function _link!(vi::UntypedVarInfo, vns) end end -# If we try to _link! a TypedVarInfo with a Tuple of VarNames, first convert it to a -# NamedTuple that matches the structure of the TypedVarInfo. -function _link!(vi::TypedVarInfo, vns::VarNameTuple) +# If we try to _link! a NTVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the NTVarInfo. +function _link!(vi::NTVarInfo, vns::VarNameTuple) return _link!(vi, group_varnames_by_symbol(vns)) end -function _link!(vi::TypedVarInfo, vns::NamedTuple) +function _link!(vi::NTVarInfo, vns::NamedTuple) return _link!(vi.metadata, vi, vns) end @@ -1002,7 +1161,7 @@ end return expr end -function invlink!!(::DynamicTransformation, vi::TypedVarInfo, model::Model) +function invlink!!(::DynamicTransformation, vi::NTVarInfo, model::Model) vns = all_varnames_grouped_by_symbol(vi) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _invlink(model, vi, vns) @@ -1064,13 +1223,13 @@ function _invlink!(vi::UntypedVarInfo, vns) end end -# If we try to _invlink! a TypedVarInfo with a Tuple of VarNames, first convert it to a -# NamedTuple that matches the structure of the TypedVarInfo. -function _invlink!(vi::TypedVarInfo, vns::VarNameTuple) +# If we try to _invlink! a NTVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the NTVarInfo. +function _invlink!(vi::NTVarInfo, vns::VarNameTuple) return _invlink!(vi.metadata, vi, group_varnames_by_symbol(vns)) end -function _invlink!(vi::TypedVarInfo, vns::NamedTuple) +function _invlink!(vi::NTVarInfo, vns::NamedTuple) return _invlink!(vi.metadata, vi, vns) end @@ -1121,7 +1280,7 @@ function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f) return vi end -function link(::DynamicTransformation, vi::TypedVarInfo, model::Model) +function link(::DynamicTransformation, vi::NTVarInfo, model::Model) return _link(model, vi, all_varnames_grouped_by_symbol(vi)) end @@ -1156,13 +1315,13 @@ function _link(model::Model, varinfo::VarInfo, vns) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end -# If we try to _link a TypedVarInfo with a Tuple of VarNames, first convert it to a -# NamedTuple that matches the structure of the TypedVarInfo. -function _link(model::Model, varinfo::TypedVarInfo, vns::VarNameTuple) +# If we try to _link a NTVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the NTVarInfo. 
+function _link(model::Model, varinfo::NTVarInfo, vns::VarNameTuple) return _link(model, varinfo, group_varnames_by_symbol(vns)) end -function _link(model::Model, varinfo::TypedVarInfo, vns::NamedTuple) +function _link(model::Model, varinfo::NTVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) md = _link_metadata!(model, varinfo, varinfo.metadata, vns) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) @@ -1257,7 +1416,7 @@ function _link_metadata!!( return metadata end -function invlink(::DynamicTransformation, vi::TypedVarInfo, model::Model) +function invlink(::DynamicTransformation, vi::NTVarInfo, model::Model) return _invlink(model, vi, all_varnames_grouped_by_symbol(vi)) end @@ -1297,13 +1456,13 @@ function _invlink(model::Model, varinfo::VarInfo, vns) ) end -# If we try to _invlink a TypedVarInfo with a Tuple of VarNames, first convert it to a -# NamedTuple that matches the structure of the TypedVarInfo. -function _invlink(model::Model, varinfo::TypedVarInfo, vns::VarNameTuple) +# If we try to _invlink a NTVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the NTVarInfo. +function _invlink(model::Model, varinfo::NTVarInfo, vns::VarNameTuple) return _invlink(model, varinfo, group_varnames_by_symbol(vns)) end -function _invlink(model::Model, varinfo::TypedVarInfo, vns::NamedTuple) +function _invlink(model::Model, varinfo::NTVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) md = _invlink_metadata!(model, varinfo, varinfo.metadata, vns) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) @@ -1394,7 +1553,7 @@ end # TODO(mhauru) The treatment of the case when some variables are linked and others are not # should be revised. It used to be the case that for UntypedVarInfo `islinked` returned -# whether the first variable was linked. For TypedVarInfo we did an OR over the first +# whether the first variable was linked. For NTVarInfo we did an OR over the first # variables under each symbol. We now more consistently use OR, but I'm not convinced this # is really the right thing to do. """ @@ -1538,7 +1697,7 @@ Base.haskey(metadata::Metadata, vn::VarName) = haskey(metadata.idcs, vn) Check whether `vn` has a value in `vi`. """ Base.haskey(vi::VarInfo, vn::VarName) = haskey(getmetadata(vi, vn), vn) -function Base.haskey(vi::TypedVarInfo, vn::VarName) +function Base.haskey(vi::NTVarInfo, vn::VarName) md_haskey = map(vi.metadata) do metadata haskey(metadata, vn) end @@ -1601,12 +1760,12 @@ the `VarInfo` `vi`, mutating if it makes sense. function BangBang.push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) if vi isa UntypedVarInfo @assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist" - elseif vi isa TypedVarInfo - @assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist" + elseif vi isa NTVarInfo + @assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to NTVarInfo of syms $(syms(vi)) with dist=$dist" end sym = getsym(vn) - if vi isa TypedVarInfo && ~haskey(vi.metadata, sym) + if vi isa NTVarInfo && ~haskey(vi.metadata, sym) # The NamedTuple doesn't have an entry for this variable, let's add one. 
val = tovec(r) md = Metadata( @@ -1627,18 +1786,18 @@ function BangBang.push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) return vi end -function Base.push!(vi::VectorVarInfo, vn::VarName, val, args...) +function Base.push!(vi::UntypedVectorVarInfo, vn::VarName, val, args...) push!(getmetadata(vi, vn), vn, val, args...) return vi end -function Base.push!(vi::VectorVarInfo, pair::Pair, args...) +function Base.push!(vi::UntypedVectorVarInfo, pair::Pair, args...) vn, val = pair return push!(vi, vn, val, args...) end -# TODO(mhauru) push! can't be implemented in-place for TypedVarInfo if the symbol doesn't -# exist in the TypedVarInfo already. We could implement it in the cases where it it does +# TODO(mhauru) push! can't be implemented in-place for NTVarInfo if the symbol doesn't +# exist in the NTVarInfo already. We could implement it in the cases where it it does # exist, but that feels a bit pointless. I think we should rather rely on `push!!`. function Base.push!(meta::Metadata, vn, r, dist, num_produce) @@ -1760,7 +1919,7 @@ function set_retained_vns_del!(vi::UntypedVarInfo) end return nothing end -function set_retained_vns_del!(vi::TypedVarInfo) +function set_retained_vns_del!(vi::NTVarInfo) idcs = _getidcs(vi) return _set_retained_vns_del!(vi.metadata, idcs, get_num_produce(vi)) end @@ -1821,12 +1980,12 @@ function _apply!(kernel!, vi::VarInfoOrThreadSafeVarInfo, values, keys) return vi end -function _apply!(kernel!, vi::TypedVarInfo, values, keys) +function _apply!(kernel!, vi::NTVarInfo, values, keys) return _typed_apply!(kernel!, vi, vi.metadata, values, collect_maybe(keys)) end @generated function _typed_apply!( - kernel!, vi::TypedVarInfo, metadata::NamedTuple{names}, values, keys + kernel!, vi::NTVarInfo, metadata::NamedTuple{names}, values, keys ) where {names} updates = map(names) do n quote @@ -1963,7 +2122,8 @@ julia> rng = StableRNG(42); julia> m = demo([missing]); -julia> var_info = DynamicPPL.VarInfo(rng, m, SampleFromPrior(), DefaultContext(), DynamicPPL.Metadata()); # Checking the setting of "del" flags only makes sense for VarInfo{<:Metadata}. For VarInfo{<:VarNamedVector} the flag is effectively always set. +julia> var_info = DynamicPPL.VarInfo(rng, m); + # Checking the setting of "del" flags only makes sense for VarInfo{<:Metadata}. For VarInfo{<:VarNamedVector} the flag is effectively always set. julia> var_info[@varname(m)] -0.6702516921145671 @@ -2061,8 +2221,8 @@ function values_as( return ConstructionBase.constructorof(D)(iter) end -values_as(vi::VectorVarInfo, args...) = values_as(vi.metadata, args...) -values_as(vi::VectorVarInfo, T::Type{Vector}) = values_as(vi.metadata, T) +values_as(vi::UntypedVectorVarInfo, args...) = values_as(vi.metadata, args...) +values_as(vi::UntypedVectorVarInfo, T::Type{Vector}) = values_as(vi.metadata, T) function values_from_metadata(md::Metadata) return ( diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl index 933bfb1d1..86329a51d 100644 --- a/test/ext/DynamicPPLJETExt.jl +++ b/test/ext/DynamicPPLJETExt.jl @@ -14,7 +14,7 @@ @model demo2() = x ~ Normal() @test DynamicPPL.Experimental.determine_suitable_varinfo(demo2()) isa - DynamicPPL.TypedVarInfo + DynamicPPL.NTVarInfo @model function demo3() # Just making sure that nothing strange happens when type inference fails. @@ -53,7 +53,7 @@ end # Should pass if we're only checking the tilde statements. 
@test DynamicPPL.Experimental.determine_suitable_varinfo(demo5()) isa - DynamicPPL.TypedVarInfo + DynamicPPL.NTVarInfo # Should fail if we're including errors in the model body. @test DynamicPPL.Experimental.determine_suitable_varinfo( demo5(); only_ddpl=false @@ -75,11 +75,11 @@ ) JET.test_call(f_sample, argtypes_sample) # For our demo models, they should all result in typed. - is_typed = varinfo isa DynamicPPL.TypedVarInfo + is_typed = varinfo isa DynamicPPL.NTVarInfo @test is_typed # If the test failed, check why it didn't infer a typed varinfo if !is_typed - typed_vi = VarInfo(model) + typed_vi = DynamicPPL.typed_varinfo(model) f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( model, typed_vi ) diff --git a/test/model.jl b/test/model.jl index 447a9ecaa..dd5a35fe6 100644 --- a/test/model.jl +++ b/test/model.jl @@ -25,9 +25,9 @@ function innermost_distribution_type(d::Distributions.Product) return dists[1] end -is_typed_varinfo(::DynamicPPL.AbstractVarInfo) = false -is_typed_varinfo(varinfo::DynamicPPL.TypedVarInfo) = true -is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true +is_type_stable_varinfo(::DynamicPPL.AbstractVarInfo) = false +is_type_stable_varinfo(varinfo::DynamicPPL.NTVarInfo) = true +is_type_stable_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @@ -233,8 +233,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @testset "Dynamic constraints, Metadata" begin model = DynamicPPL.TestUtils.demo_dynamic_constraint() - spl = SampleFromPrior() - vi = VarInfo(model, spl, DefaultContext(), DynamicPPL.Metadata()) + vi = VarInfo(model) vi = link!!(vi, model) for i in 1:10 @@ -250,8 +249,11 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @testset "Dynamic constraints, VectorVarInfo" begin model = DynamicPPL.TestUtils.demo_dynamic_constraint() for i in 1:10 - vi = VarInfo(model) - @test vi[@varname(x)] >= vi[@varname(m)] + for vi_constructor in + [DynamicPPL.typed_vector_varinfo, DynamicPPL.untyped_vector_varinfo] + vi = vi_constructor(model) + @test vi[@varname(x)] >= vi[@varname(m)] + end end end @@ -400,7 +402,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() vns = DynamicPPL.TestUtils.varnames(model) example_values = DynamicPPL.TestUtils.rand_prior_true(model) varinfos = filter( - is_typed_varinfo, + is_type_stable_varinfo, DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns), ) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 8e48814a4..aa3b592f7 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -92,7 +92,7 @@ SimpleVarInfo(Dict()), SimpleVarInfo(values_constrained), SimpleVarInfo(DynamicPPL.VarNamedVector()), - VarInfo(model), + DynamicPPL.typed_varinfo(model), ) for vn in DynamicPPL.TestUtils.varnames(model) vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn) diff --git a/test/test_util.jl b/test/test_util.jl index 87c69b5fe..902dd7230 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -33,14 +33,18 @@ end Return string representing a short description of `vi`. 
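For instance (illustrative only; `demo` is a hypothetical model, and the return values follow the methods defined below):

```julia
using DynamicPPL, Distributions

@model demo() = x ~ Normal()   # hypothetical model used only for illustration
model = demo()

short_varinfo_name(DynamicPPL.typed_varinfo(model))           # "TypedVarInfo"
short_varinfo_name(DynamicPPL.untyped_vector_varinfo(model))  # "UntypedVectorVarInfo"
```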
""" -short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) = - "threadsafe($(short_varinfo_name(vi.varinfo)))" -function short_varinfo_name(vi::TypedVarInfo) - DynamicPPL.has_varnamedvector(vi) && return "TypedVarInfo with VarNamedVector" - return "TypedVarInfo" +function short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) + return "threadsafe($(short_varinfo_name(vi.varinfo)))" end -short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo" -short_varinfo_name(::DynamicPPL.VectorVarInfo) = "VectorVarInfo" +function short_varinfo_name(vi::DynamicPPL.NTVarInfo) + return if DynamicPPL.has_varnamedvector(vi) + "TypedVectorVarInfo" + else + "TypedVarInfo" + end +end +short_varinfo_name(::DynamicPPL.UntypedVarInfo) = "UntypedVarInfo" +short_varinfo_name(::DynamicPPL.UntypedVectorVarInfo) = "UntypedVectorVarInfo" function short_varinfo_name(::SimpleVarInfo{<:NamedTuple,<:Ref}) return "SimpleVarInfo{<:NamedTuple,<:Ref}" end diff --git a/test/varinfo.jl b/test/varinfo.jl index 74feb42f6..777917aa6 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -34,7 +34,7 @@ function randr(vi::DynamicPPL.VarInfo, vn::VarName, dist::Distribution) end @testset "varinfo.jl" begin - @testset "TypedVarInfo with Metadata" begin + @testset "VarInfo with NT of Metadata" begin @model gdemo(x, y) = begin s ~ InverseGamma(2, 3) m ~ truncated(Normal(0.0, sqrt(s)), 0.0, 2.0) @@ -43,9 +43,8 @@ end end model = gdemo(1.0, 2.0) - vi = VarInfo(DynamicPPL.Metadata()) - model(vi, SampleFromUniform()) - tvi = TypedVarInfo(vi) + vi = DynamicPPL.untyped_varinfo(model, SampleFromUniform()) + tvi = DynamicPPL.typed_varinfo(vi) meta = vi.metadata for f in fieldnames(typeof(tvi.metadata)) @@ -102,7 +101,7 @@ end @test vi[vn] == 2 * r # TODO(mhauru) Implement these functions for other VarInfo types too. - if vi isa DynamicPPL.VectorVarInfo + if vi isa DynamicPPL.UntypedVectorVarInfo delete!(vi, vn) @test isempty(vi) vi = push!!(vi, vn, r, dist) @@ -116,7 +115,7 @@ end vi = VarInfo() test_base!!(vi) - test_base!!(TypedVarInfo(vi)) + test_base!!(DynamicPPL.typed_varinfo(vi)) test_base!!(SimpleVarInfo()) test_base!!(SimpleVarInfo(Dict())) test_base!!(SimpleVarInfo(DynamicPPL.VarNamedVector())) @@ -135,7 +134,7 @@ end vi = VarInfo() test_varinfo_logp!(vi) - test_varinfo_logp!(TypedVarInfo(vi)) + test_varinfo_logp!(DynamicPPL.typed_varinfo(vi)) test_varinfo_logp!(SimpleVarInfo()) test_varinfo_logp!(SimpleVarInfo(Dict())) test_varinfo_logp!(SimpleVarInfo(DynamicPPL.VarNamedVector())) @@ -160,17 +159,17 @@ end unset_flag!(vi, vn_x, "del") @test !is_flagged(vi, vn_x, "del") end - vi = VarInfo(DynamicPPL.Metadata()) + vi = VarInfo() test_varinfo!(vi) - test_varinfo!(empty!!(TypedVarInfo(vi))) + test_varinfo!(empty!!(DynamicPPL.typed_varinfo(vi))) end - @testset "push!! to TypedVarInfo" begin + @testset "push!! to VarInfo with NT of Metadata" begin vn_x = @varname x vn_y = @varname y untyped_vi = VarInfo() untyped_vi = push!!(untyped_vi, vn_x, 1.0, Normal(0, 1)) - typed_vi = TypedVarInfo(untyped_vi) + typed_vi = DynamicPPL.typed_varinfo(untyped_vi) typed_vi = push!!(typed_vi, vn_y, 2.0, Normal(0, 1)) @test typed_vi[vn_x] == 1.0 @test typed_vi[vn_y] == 2.0 @@ -206,16 +205,10 @@ end m_vns = model == model_uv ? 
[@varname(m[i]) for i in 1:5] : @varname(m) s_vns = @varname(s) - vi_typed = VarInfo( - model, SampleFromPrior(), DefaultContext(), DynamicPPL.Metadata() - ) - vi_untyped = VarInfo(DynamicPPL.Metadata()) - vi_vnv = VarInfo(DynamicPPL.VarNamedVector()) - vi_vnv_typed = VarInfo( - model, SampleFromPrior(), DefaultContext(), DynamicPPL.VarNamedVector() - ) - model(vi_untyped, SampleFromPrior()) - model(vi_vnv, SampleFromPrior()) + vi_typed = DynamicPPL.typed_varinfo(model) + vi_untyped = DynamicPPL.untyped_varinfo(model) + vi_vnv = DynamicPPL.untyped_vector_varinfo(model) + vi_vnv_typed = DynamicPPL.typed_vector_varinfo(model) model_name = model == model_uv ? "univariate" : "multivariate" @testset "$(model_name), $(short_varinfo_name(vi))" for vi in [ @@ -405,7 +398,7 @@ end @test meta.vals ≈ v atol = 1e-10 # Check that linking and invlinking preserves the values - vi = TypedVarInfo(vi) + vi = DynamicPPL.typed_varinfo(vi) meta = vi.metadata v_s = copy(meta.s.vals) v_m = copy(meta.m.vals) @@ -459,9 +452,9 @@ end # Need to run once since we can't specify that we want to _sample_ # in the unconstrained space for `VarInfo` without having `vn` # present in the `varinfo`. - ## `UntypedVarInfo` - vi = VarInfo() - vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + + ## `untyped_varinfo` + vi = DynamicPPL.untyped_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) # Sample in unconstrained space. vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) @@ -469,8 +462,8 @@ end x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) - ## `TypedVarInfo` - vi = VarInfo(model) + ## `typed_varinfo` + vi = DynamicPPL.typed_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) # Sample in unconstrained space. 
vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) @@ -979,7 +972,7 @@ end @test vi.metadata.orders == [1, 1, 2, 2, 3, 3] @test DynamicPPL.get_num_produce(vi) == 3 - vi = empty!!(DynamicPPL.TypedVarInfo(vi)) + vi = empty!!(DynamicPPL.typed_varinfo(vi)) # First iteration, variables are added to vi # variables samples in order: z1,a1,z2,a2,z3 DynamicPPL.increment_num_produce!(vi) From ae9b1cdaad888f552a92c73aecfeedc12aa0bd80 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 24 Feb 2025 11:00:58 +0000 Subject: [PATCH 05/48] Draft of accumulators --- docs/src/api.md | 1 - src/DynamicPPL.jl | 10 +- src/abstract_varinfo.jl | 62 ++++++-- src/accumulators.jl | 99 +++++++++++++ src/context_implementations.jl | 96 +++---------- src/contexts.jl | 46 +----- src/debug_utils.jl | 36 ++--- src/model.jl | 4 +- src/pointwise_logdensities.jl | 6 +- src/simple_varinfo.jl | 102 ++++---------- src/test_utils/contexts.jl | 2 +- src/test_utils/varinfo.jl | 14 +- src/threadsafe.jl | 98 +++++++------ src/transforming.jl | 7 +- src/utils.jl | 4 +- src/values_as_in_model.jl | 13 +- src/varinfo.jl | 241 ++++++++++++++++++-------------- test/context_implementations.jl | 2 +- test/contexts.jl | 5 +- test/independence.jl | 11 -- test/runtests.jl | 1 - test/simple_varinfo.jl | 6 +- test/threadsafe.jl | 23 +-- test/varinfo.jl | 28 ++-- 24 files changed, 464 insertions(+), 453 deletions(-) create mode 100644 src/accumulators.jl delete mode 100644 test/independence.jl diff --git a/docs/src/api.md b/docs/src/api.md index f83a96886..abc2e3016 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -419,7 +419,6 @@ SamplingContext DefaultContext LikelihoodContext PriorContext -MiniBatchContext PrefixContext ConditionContext ``` diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 51fa53079..ed389be7a 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -54,9 +54,9 @@ export AbstractVarInfo, acclogp!!, resetlogp!!, get_num_produce, - set_num_produce!, - reset_num_produce!, - increment_num_produce!, + set_num_produce!!, + reset_num_produce!!, + increment_num_produce!!, set_retained_vns_del!, is_flagged, set_flag!, @@ -92,9 +92,6 @@ export AbstractVarInfo, # Contexts SamplingContext, DefaultContext, - LikelihoodContext, - PriorContext, - MiniBatchContext, PrefixContext, ConditionContext, assume, @@ -166,6 +163,7 @@ include("varname.jl") include("distribution_wrappers.jl") include("contexts.jl") include("varnamedvector.jl") +include("accumulators.jl") include("abstract_varinfo.jl") include("threadsafe.jl") include("varinfo.jl") diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index f11b8a3ec..4aa3e402e 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -91,12 +91,26 @@ function transformation end # Accumulation of log-probabilities. """ - getlogp(vi::AbstractVarInfo) + getlogjoint(vi::AbstractVarInfo) Return the log of the joint probability of the observed data and parameters sampled in `vi`. """ -function getlogp end +getlogjoint(vi::AbstractVarInfo) = getlogprior(vi) + getloglikelihood(vi) +getlogp(vi::AbstractVarInfo) = getlogjoint(vi) + +function setaccs!! 
end +function getaccs end + +getlogprior(vi::AbstractVarInfo) = getacc(vi, LogPrior).logp +getloglikelihood(vi::AbstractVarInfo) = getacc(vi, LogLikelihood).logp + +function setacc!!(vi::AbstractVarInfo, acc::AbstractAccumulator) + return setaccs!!(vi, setacc!!(getaccs(vi), acc)) +end + +setlogprior!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogPrior(logp)) +setloglikelihood!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogLikelihood(logp)) """ setlogp!!(vi::AbstractVarInfo, logp) @@ -104,23 +118,43 @@ function getlogp end Set the log of the joint probability of the observed data and parameters sampled in `vi` to `logp`, mutating if it makes sense. """ -function setlogp!! end +function setlogp!!(vi::AbstractVarInfo, logp) + vi = setlogprior!!(vi, zero(logp)) + vi = setloglikelihood!!(vi, logp) + return vi +end + +function getacc(vi::AbstractVarInfo, ::Type{AccType}) where {AccType} + return getacc(getaccs(vi), AccType) +end + +function accumulate_assume!!(vi::AbstractVarInfo, r, logp, vn, right) + return setaccs!!(vi, accumulate_assume!!(getaccs(vi), r, logp, vn, right)) +end + +function accumulate_observe!!(vi::AbstractVarInfo, left, right) + return setaccs!!(vi, accumulate_observe!!(getaccs(vi), left, right)) +end + +function acc!!(vi::AbstractVarInfo, ::Type{AccType}, args...) where {AccType} + return setaccs!!(vi, acc!!(getaccs(vi), AccType, args...)) +end + +function acclogprior!!(vi::AbstractVarInfo, logp) + return acc!!(vi, LogPrior, logp) +end + +function accloglikelihood!!(vi::AbstractVarInfo, logp) + return acc!!(vi, LogLikelihood, logp) +end """ - acclogp!!([context::AbstractContext, ]vi::AbstractVarInfo, logp) + acclogp!!(vi::AbstractVarInfo, logp) Add `logp` to the value of the log of the joint probability of the observed data and parameters sampled in `vi`, mutating if it makes sense. """ -function acclogp!!(context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp!!(NodeTrait(context), context, vi, logp) -end -function acclogp!!(::IsLeaf, context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp!!(vi, logp) -end -function acclogp!!(::IsParent, context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp!!(childcontext(context), vi, logp) -end +acclogp!!(vi::AbstractVarInfo, logp) = accloglikelihood!!(vi, logp) """ resetlogp!!(vi::AbstractVarInfo) @@ -725,7 +759,7 @@ end # Legacy code that is currently overloaded for the sake of simplicity. # TODO: Remove when possible. -increment_num_produce!(::AbstractVarInfo) = nothing +increment_num_produce!!(::AbstractVarInfo) = nothing """ from_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist]) diff --git a/src/accumulators.jl b/src/accumulators.jl new file mode 100644 index 000000000..5dd0660ca --- /dev/null +++ b/src/accumulators.jl @@ -0,0 +1,99 @@ +abstract type AbstractAccumulator end + +accumulator_name(acc::AbstractAccumulator) = accumulator_name(typeof(acc)) + +""" + AccumulatorTuple{N,T<:NamedTuple} + +A collection of accumulators, stored as a `NamedTuple`. + +This is defined as a separate type to be able to dispatch on it cleanly and without method +ambiguities or conflicts with other `NamedTuple` types. We also use this type to enforce the +constraint the name in the tuple for each accumulator `acc` must be `accumulator_name(acc)`. 
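A minimal sketch of how this might be used (names follow the definitions below; not a doctest, and the exact API may still change while this is a draft):

```julia
using DynamicPPL

# Build a collection holding a log-prior and a log-likelihood accumulator.
accs = DynamicPPL.AccumulatorTuple(DynamicPPL.LogPrior(0.0), DynamicPPL.LogLikelihood(0.0))

# Accumulate a value into the log prior; `acc!!` returns an updated AccumulatorTuple.
accs = DynamicPPL.acc!!(accs, DynamicPPL.LogPrior, -1.2)

DynamicPPL.getacc(accs, DynamicPPL.LogPrior).logp  # expected to be -1.2
```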
+""" +struct AccumulatorTuple{N,T<:NamedTuple} + nt::T + + function AccumulatorTuple(t::T) where {N,T<:NTuple{N,AbstractAccumulator}} + names = accumulator_name.(t) + nt = NamedTuple{names}(t) + return new{N,typeof(nt)}(nt) + end +end + +AccumulatorTuple(accs::Vararg{AbstractAccumulator}) = AccumulatorTuple(accs) +AccumulatorTuple(nt::NamedTuple) = AccumulatorTuple(tuple(nt...)) + +Base.getindex(at::AccumulatorTuple, idx) = at.nt[idx] +Base.length(::AccumulatorTuple{N}) where {N} = N +Base.iterate(at::AccumulatorTuple, args...) = iterate(at.nt, args...) + +function setacc!!(at::AccumulatorTuple, acc::AbstractAccumulator) + return Accessors.@set at.nt[accumulator_name(acc)] = acc +end + +function getacc(at::AccumulatorTuple, ::Type{AccType}) where {AccType} + return at[accumulator_name(AccType)] +end + +function accumulate_assume!!(at::AccumulatorTuple, r, logp, vn, right) + return AccumulatorTuple(map(acc -> accumulate_assume!!(acc, r, logp, vn, right), at.nt)) +end + +function accumulate_observe!!(at::AccumulatorTuple, left, right) + return AccumulatorTuple(map(acc -> accumulate_observe!!(acc, left, right), at.nt)) +end + +function acc!!(at::AccumulatorTuple, ::Type{AccType}, args...) where {AccType} + accname = accumulator_name(AccType) + return Accessors.@set at.nt[accname] = acc!!(at[accname], args...) +end + +struct LogPrior{T} <: AbstractAccumulator + logp::T +end + +LogPrior{T}() where {T} = LogPrior(zero(T)) + +struct LogLikelihood{T} <: AbstractAccumulator + logp::T +end + +LogLikelihood{T}() where {T} = LogLikelihood(zero(T)) + +struct NumProduce{T<:Integer} <: AbstractAccumulator + num::T +end + +NumProduce{T}() where {T} = NumProduce(zero(T)) + +accumulator_name(::Type{<:LogPrior}) = :LogPrior +accumulator_name(::Type{<:LogLikelihood}) = :LogLikelihood +accumulator_name(::Type{<:NumProduce}) = :NumProduce + +split(::LogPrior{T}) where {T} = LogPrior(zero(T)) +split(::LogLikelihood{T}) where {T} = LogLikelihood(zero(T)) +split(acc::NumProduce) = acc + +combine(acc::LogPrior, acc2::LogPrior) = LogPrior(acc.logp + acc2.logp) +combine(acc::LogLikelihood, acc2::LogLikelihood) = LogLikelihood(acc.logp + acc2.logp) +function combine(acc::NumProduce, acc2::NumProduce) + return NumProduce(max(acc.num, acc2.num)) +end + +acc!!(acc::LogPrior, logp) = LogPrior(acc.logp + logp) +acc!!(acc::LogLikelihood, logp) = LogLikelihood(acc.logp + logp) +acc!!(acc::NumProduce, n) = NumProduce(acc.num + n) + +function accumulate_assume!!(acc::LogPrior, val, logjac, vn, right) + return acc!!(acc, logpdf(right, val) + logjac) +end +accumulate_observe!!(acc::LogPrior, left, right) = acc + +accumulate_assume!!(acc::LogLikelihood, val, logjac, vn, right) = acc +function accumulate_observe!!(acc::LogLikelihood, left, right) + return acc!!(acc, logpdf(right, left)) +end + +accumulate_assume!!(acc::NumProduce, val, logjac, vn, right) = acc +accumulate_observe!!(acc::NumProduce, left, right) = acc!!(acc, 1) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index e4ba5d252..ce86f8cbd 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -14,27 +14,6 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg))) require_gradient(spl::Sampler) = false require_particles(spl::Sampler) = false -# Allows samplers, etc. to hook into the final logp accumulation in the tilde-pipeline. 
-function acclogp_assume!!(context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp_assume!!(NodeTrait(acclogp_assume!!, context), context, vi, logp) -end -function acclogp_assume!!(::IsParent, context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp_assume!!(childcontext(context), vi, logp) -end -function acclogp_assume!!(::IsLeaf, context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp!!(context, vi, logp) -end - -function acclogp_observe!!(context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp_observe!!(NodeTrait(acclogp_observe!!, context), context, vi, logp) -end -function acclogp_observe!!(::IsParent, context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp_observe!!(childcontext(context), vi, logp) -end -function acclogp_observe!!(::IsLeaf, context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp!!(context, vi, logp) -end - # assume """ tilde_assume(context::SamplingContext, right, vn, vi) @@ -52,36 +31,18 @@ function tilde_assume(context::SamplingContext, right, vn, vi) return tilde_assume(context.rng, context.context, context.sampler, right, vn, vi) end -# Leaf contexts function tilde_assume(context::AbstractContext, args...) - return tilde_assume(NodeTrait(tilde_assume, context), context, args...) + return tilde_assume(childcontext(context), args...) end -function tilde_assume(::IsLeaf, context::AbstractContext, right, vn, vi) +function tilde_assume(::DefaultContext, right, vn, vi) return assume(right, vn, vi) end -function tilde_assume(::IsParent, context::AbstractContext, args...) - return tilde_assume(childcontext(context), args...) -end function tilde_assume(rng::Random.AbstractRNG, context::AbstractContext, args...) - return tilde_assume(NodeTrait(tilde_assume, context), rng, context, args...) -end -function tilde_assume( - ::IsLeaf, rng::Random.AbstractRNG, context::AbstractContext, sampler, right, vn, vi -) - return assume(rng, sampler, right, vn, vi) -end -function tilde_assume( - ::IsParent, rng::Random.AbstractRNG, context::AbstractContext, args... -) return tilde_assume(rng, childcontext(context), args...) end - -function tilde_assume(::LikelihoodContext, right, vn, vi) - return assume(nodist(right), vn, vi) -end -function tilde_assume(rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, vi) - return assume(rng, sampler, nodist(right), vn, vi) +function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, vi) + return assume(rng, sampler, right, vn, vi) end function tilde_assume(context::PrefixContext, right, vn, vi) @@ -111,8 +72,8 @@ function tilde_assume!!(context, right, vn, vi) vi, ) else - value, logp, vi = tilde_assume(context, right, vn, vi) - value, acclogp_assume!!(context, vi, logp) + value, vi = tilde_assume(context, right, vn, vi) + return value, vi end end @@ -128,26 +89,12 @@ function tilde_observe(context::SamplingContext, right, left, vi) return tilde_observe(context.context, context.sampler, right, left, vi) end -# Leaf contexts function tilde_observe(context::AbstractContext, args...) - return tilde_observe(NodeTrait(tilde_observe, context), context, args...) -end -tilde_observe(::IsLeaf, context::AbstractContext, args...) = observe(args...) -function tilde_observe(::IsParent, context::AbstractContext, args...) return tilde_observe(childcontext(context), args...) 
end -tilde_observe(::PriorContext, right, left, vi) = 0, vi -tilde_observe(::PriorContext, sampler, right, left, vi) = 0, vi - -# `MiniBatchContext` -function tilde_observe(context::MiniBatchContext, right, left, vi) - logp, vi = tilde_observe(context.context, right, left, vi) - return context.loglike_scalar * logp, vi -end -function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) - logp, vi = tilde_observe(context.context, sampler, right, left, vi) - return context.loglike_scalar * logp, vi +function tilde_observe(::DefaultContext, args...) + return observe(args...) end # `PrefixContext` @@ -191,8 +138,8 @@ function tilde_observe!!(context, right, left, vi) "`~` with a model on the right-hand side of an observe statement is not supported", ), ) - logp, vi = tilde_observe(context, right, left, vi) - return left, acclogp_observe!!(context, vi, logp) + vi = tilde_observe(context, right, left, vi) + return left, vi end function assume(rng::Random.AbstractRNG, spl::Sampler, dist) @@ -205,8 +152,11 @@ end # fallback without sampler function assume(dist::Distribution, vn::VarName, vi) - r, logp = invlink_with_logpdf(vi, vn, dist) - return r, logp, vi + y = getindex_internal(vi, vn) + f = from_maybe_linked_internal_transform(vi, vn, dist) + x, logjac = with_logabsdet_jacobian(f, y) + vi = accumulate_assume!!(vi, x, logjac, vn, dist) + return x, vi end # TODO: Remove this thing. @@ -228,8 +178,7 @@ function assume( r = init(rng, dist, sampler) f = to_maybe_linked_internal_transform(vi, vn, dist) # TODO(mhauru) This should probably be call a function called setindex_internal! - # Also, if we use !! we shouldn't ignore the return value. - BangBang.setindex!!(vi, f(r), vn) + vi = BangBang.setindex!!(vi, f(r), vn) setorder!(vi, vn, get_num_produce(vi)) else # Otherwise we just extract it. @@ -239,22 +188,23 @@ function assume( r = init(rng, dist, sampler) if istrans(vi) f = to_linked_internal_transform(vi, vn, dist) - push!!(vi, vn, f(r), dist) + vi = push!!(vi, vn, f(r), dist) # By default `push!!` sets the transformed flag to `false`. - settrans!!(vi, true, vn) + vi = settrans!!(vi, true, vn) else - push!!(vi, vn, r, dist) + vi = push!!(vi, vn, r, dist) end end # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct. logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r) - return r, logpdf(dist, r) - logjac, vi + vi = accumulate_assume!!(vi, r, -logjac, vn, dist) + return r, vi end # default fallback (used e.g. by `SampleFromPrior` and `SampleUniform`) observe(sampler::AbstractSampler, right, left, vi) = observe(right, left, vi) + function observe(right::Distribution, left, vi) - increment_num_produce!(vi) - return Distributions.loglikelihood(right, left), vi + return accumulate_observe!!(vi, left, right) end diff --git a/src/contexts.jl b/src/contexts.jl index 58ac612b8..941ccd75e 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -37,6 +37,7 @@ Return the descendant context of `context`. """ childcontext +# TODO(mhauru) Rework the below docstring to not use PriorContext. """ setchildcontext(parent::AbstractContext, child::AbstractContext) @@ -129,7 +130,7 @@ setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right Create a context that allows you to sample parameters with the `sampler` when running the model. The `context` determines how the returned log density is computed when running the model. 
-See also: [`DefaultContext`](@ref), [`LikelihoodContext`](@ref), [`PriorContext`](@ref) +See also: [`DefaultContext`](@ref) """ struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext rng::R @@ -193,48 +194,7 @@ The `DefaultContext` is used by default to compute the log joint probability of and parameters when running the model. """ struct DefaultContext <: AbstractContext end -NodeTrait(context::DefaultContext) = IsLeaf() - -""" - PriorContext <: AbstractContext - -A leaf context resulting in the exclusion of likelihood terms when running the model. -""" -struct PriorContext <: AbstractContext end -NodeTrait(context::PriorContext) = IsLeaf() - -""" - LikelihoodContext <: AbstractContext - -A leaf context resulting in the exclusion of prior terms when running the model. -""" -struct LikelihoodContext <: AbstractContext end -NodeTrait(context::LikelihoodContext) = IsLeaf() - -""" - struct MiniBatchContext{Tctx, T} <: AbstractContext - context::Tctx - loglike_scalar::T - end - -The `MiniBatchContext` enables the computation of -`log(prior) + s * log(likelihood of a batch)` when running the model, where `s` is the -`loglike_scalar` field, typically equal to `the number of data points / batch size`. -This is useful in batch-based stochastic gradient descent algorithms to be optimizing -`log(prior) + log(likelihood of all the data points)` in the expectation. -""" -struct MiniBatchContext{Tctx,T} <: AbstractContext - context::Tctx - loglike_scalar::T -end -function MiniBatchContext(context=DefaultContext(); batch_size, npoints) - return MiniBatchContext(context, npoints / batch_size) -end -NodeTrait(context::MiniBatchContext) = IsParent() -childcontext(context::MiniBatchContext) = context.context -function setchildcontext(parent::MiniBatchContext, child) - return MiniBatchContext(child, parent.loglike_scalar) -end +NodeTrait(::DefaultContext) = IsLeaf() """ PrefixContext{Prefix}(context) diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 529092e8e..bff03386c 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -76,7 +76,6 @@ Base.@kwdef struct AssumeStmt <: Stmt varname right value - logp varinfo = nothing end @@ -90,15 +89,12 @@ function Base.show(io::IO, stmt::AssumeStmt) print(io, RESULT_SYMBOL) print(io, " ") print(io, stmt.value) - print(io, " (logprob = ") - print(io, stmt.logp) return print(io, ")") end Base.@kwdef struct ObserveStmt <: Stmt left right - logp varinfo = nothing end @@ -108,8 +104,6 @@ function Base.show(io::IO, stmt::ObserveStmt) show_right(io, stmt.left) print(io, " ~ ") show_right(io, stmt.right) - print(io, " (logprob = ") - print(io, stmt.logp) return print(io, ")") end @@ -257,12 +251,11 @@ function record_pre_tilde_assume!(context::DebugContext, vn, dist, varinfo) return nothing end -function record_post_tilde_assume!(context::DebugContext, vn, dist, value, logp, varinfo) +function record_post_tilde_assume!(context::DebugContext, vn, dist, value, varinfo) stmt = AssumeStmt(; varname=vn, right=dist, value=value, - logp=logp, varinfo=context.record_varinfo ? 
varinfo : nothing, ) if context.record_statements @@ -273,19 +266,19 @@ end function DynamicPPL.tilde_assume(context::DebugContext, right, vn, vi) record_pre_tilde_assume!(context, vn, right, vi) - value, logp, vi = DynamicPPL.tilde_assume(childcontext(context), right, vn, vi) - record_post_tilde_assume!(context, vn, right, value, logp, vi) - return value, logp, vi + value, vi = DynamicPPL.tilde_assume(childcontext(context), right, vn, vi) + record_post_tilde_assume!(context, vn, right, value, vi) + return value, vi end function DynamicPPL.tilde_assume( rng::Random.AbstractRNG, context::DebugContext, sampler, right, vn, vi ) record_pre_tilde_assume!(context, vn, right, vi) - value, logp, vi = DynamicPPL.tilde_assume( + value, vi = DynamicPPL.tilde_assume( rng, childcontext(context), sampler, right, vn, vi ) - record_post_tilde_assume!(context, vn, right, value, logp, vi) - return value, logp, vi + record_post_tilde_assume!(context, vn, right, value, vi) + return value, vi end # observe @@ -300,11 +293,10 @@ function record_pre_tilde_observe!(context::DebugContext, left, dist, varinfo) end end -function record_post_tilde_observe!(context::DebugContext, left, right, logp, varinfo) +function record_post_tilde_observe!(context::DebugContext, left, right, varinfo) stmt = ObserveStmt(; left=left, right=right, - logp=logp, varinfo=context.record_varinfo ? varinfo : nothing, ) if context.record_statements @@ -315,15 +307,15 @@ end function DynamicPPL.tilde_observe(context::DebugContext, right, left, vi) record_pre_tilde_observe!(context, left, right, vi) - logp, vi = DynamicPPL.tilde_observe(childcontext(context), right, left, vi) - record_post_tilde_observe!(context, left, right, logp, vi) - return logp, vi + vi = DynamicPPL.tilde_observe(childcontext(context), right, left, vi) + record_post_tilde_observe!(context, left, right, vi) + return vi end function DynamicPPL.tilde_observe(context::DebugContext, sampler, right, left, vi) record_pre_tilde_observe!(context, left, right, vi) - logp, vi = DynamicPPL.tilde_observe(childcontext(context), sampler, right, left, vi) - record_post_tilde_observe!(context, left, right, logp, vi) - return logp, vi + vi = DynamicPPL.tilde_observe(childcontext(context), sampler, right, left, vi) + record_post_tilde_observe!(context, left, right, vi) + return vi end _conditioned_varnames(d::AbstractDict) = keys(d) diff --git a/src/model.jl b/src/model.jl index b4d5f6bb7..2431935f0 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1056,7 +1056,7 @@ Return the log prior probability of variables `varinfo` for the probabilistic `m See also [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logprior(model::Model, varinfo::AbstractVarInfo) - return getlogp(last(evaluate!!(model, varinfo, PriorContext()))) + return getlogprior(last(evaluate!!(model, varinfo, DefaultContext()))) end """ @@ -1103,7 +1103,7 @@ Return the log likelihood of variables `varinfo` for the probabilistic `model`. See also [`logjoint`](@ref) and [`logprior`](@ref). 
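For example, the three quantities are expected to satisfy the usual identity that the accumulator split is meant to preserve (an illustrative sketch; `demo` is a hypothetical model):

```julia
using DynamicPPL, Distributions

@model function demo(x)
    m ~ Normal()
    x ~ Normal(m, 1)
end

model = demo(0.5)
vi = VarInfo(model)

# Log prior plus log likelihood should recover the log joint.
logprior(model, vi) + loglikelihood(model, vi) ≈ logjoint(model, vi)  # expected to be true
```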
""" function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) - return getlogp(last(evaluate!!(model, varinfo, LikelihoodContext()))) + return getloglikelihood(last(evaluate!!(model, varinfo, DefaultContext()))) end """ diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index cb9ea4894..83b0a7476 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -122,8 +122,12 @@ end function tilde_assume!!(context::PointwiseLogdensityContext, right, vn, vi) !_include_prior(context) && return (tilde_assume!!(context.context, right, vn, vi)) - value, logp, vi = tilde_assume(context.context, right, vn, vi) + value, vi = tilde_assume(context.context, right, vn, vi) # Track loglikelihood value. + # TODO(mhauru) logp here should be the logp that resulted from this tilde call. + # Implement this with a suitable accumulator. The current setting to zero is just to + # make this run, it produces nonsense results. + logp = zero(getlogjoint(vi)) push!(context, vn, logp) return value, acclogp!!(vi, logp) end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index abf14b8fc..fce1c7d44 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -188,41 +188,28 @@ ERROR: type NamedTuple has no field b [...] ``` """ -struct SimpleVarInfo{NT,T,C<:AbstractTransformation} <: AbstractVarInfo +struct SimpleVarInfo{NT,Accs<:AccumulatorTuple where {N},C<:AbstractTransformation} <: + AbstractVarInfo "underlying representation of the realization represented" values::NT - "holds the accumulated log-probability" - logp::T + "tuple of accumulators for things like log prior and log likelihood" + accs::Accs "represents whether it assumes variables to be transformed" transformation::C end transformation(vi::SimpleVarInfo) = vi.transformation -# Makes things a bit more readable vs. putting `Float64` everywhere. -const SIMPLEVARINFO_DEFAULT_ELTYPE = Float64 - -function SimpleVarInfo{NT,T}(values, logp) where {NT,T} - return SimpleVarInfo{NT,T,NoTransformation}(values, logp, NoTransformation()) +function SimpleVarInfo(values, accs) + return SimpleVarInfo(values, accs, NoTransformation()) end -function SimpleVarInfo{T}(θ) where {T<:Real} - return SimpleVarInfo{typeof(θ),T}(θ, zero(T)) +function SimpleVarInfo{T}(values) where {T<:Real} + return SimpleVarInfo(values, AccumulatorTuple(LogLikelihood{T}(), LogPrior{T}())) end - -# Constructors without type-specification. -SimpleVarInfo(θ) = SimpleVarInfo{SIMPLEVARINFO_DEFAULT_ELTYPE}(θ) -function SimpleVarInfo(θ::Union{<:NamedTuple,<:AbstractDict}) - return if isempty(θ) - # Can't infer from values, so we just use default. - SimpleVarInfo{SIMPLEVARINFO_DEFAULT_ELTYPE}(θ) - else - # Infer from `values`. - SimpleVarInfo{float_type_with_fallback(infer_nested_eltype(typeof(θ)))}(θ) - end +function SimpleVarInfo(values) + return SimpleVarInfo{LogProbType}(values) end -SimpleVarInfo(values, logp) = SimpleVarInfo{typeof(values),typeof(logp)}(values, logp) - # Using `kwargs` to specify the values. function SimpleVarInfo{T}(; kwargs...) where {T<:Real} return SimpleVarInfo{T}(NamedTuple(kwargs)) @@ -235,7 +222,7 @@ end function SimpleVarInfo( model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}... ) - return SimpleVarInfo{Float64}(model, args...) + return SimpleVarInfo{LogProbType}(model, args...) end function SimpleVarInfo{T}( model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}... @@ -244,14 +231,9 @@ function SimpleVarInfo{T}( end # Constructor from `VarInfo`. 
-function SimpleVarInfo(vi::NTVarInfo, (::Type{D})=NamedTuple; kwargs...) where {D} - return SimpleVarInfo{eltype(getlogp(vi))}(vi, D; kwargs...) -end -function SimpleVarInfo{T}( - vi::VarInfo{<:NamedTuple{names}}, ::Type{D} -) where {T<:Real,names,D} +function SimpleVarInfo(vi::VarInfo{<:NamedTuple{names}}, ::Type{D}) where {names,D} values = values_as(vi, D) - return SimpleVarInfo(values, convert(T, getlogp(vi))) + return SimpleVarInfo(values, vi.accs) end function untyped_simple_varinfo(model::Model) @@ -265,12 +247,8 @@ function typed_simple_varinfo(model::Model) end function unflatten(svi::SimpleVarInfo, x::AbstractVector) - logp = getlogp(svi) vals = unflatten(svi.values, x) - T = eltype(x) - return SimpleVarInfo{typeof(vals),T,typeof(svi.transformation)}( - vals, T(logp), svi.transformation - ) + return SimpleVarInfo(vals, svi.accs, svi.transformation) end function BangBang.empty!!(vi::SimpleVarInfo) @@ -278,21 +256,8 @@ function BangBang.empty!!(vi::SimpleVarInfo) end Base.isempty(vi::SimpleVarInfo) = isempty(vi.values) -getlogp(vi::SimpleVarInfo) = vi.logp -getlogp(vi::SimpleVarInfo{<:Any,<:Ref}) = vi.logp[] - -setlogp!!(vi::SimpleVarInfo, logp) = Accessors.@set vi.logp = logp -acclogp!!(vi::SimpleVarInfo, logp) = Accessors.@set vi.logp = getlogp(vi) + logp - -function setlogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) - vi.logp[] = logp - return vi -end - -function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) - vi.logp[] += logp - return vi -end +getaccs(vi::SimpleVarInfo) = vi.accs +setaccs!!(vi::SimpleVarInfo, accs) = Accessors.@set vi.accs = accs """ keys(vi::SimpleVarInfo) @@ -307,7 +272,7 @@ function Base.show(io::IO, ::MIME"text/plain", svi::SimpleVarInfo) print(io, "Transformed ") end - return print(io, "SimpleVarInfo(", svi.values, ", ", svi.logp, ")") + return print(io, "SimpleVarInfo(", svi.values, ", ", getaccs(svi), ")") end function Base.getindex(vi::SimpleVarInfo, vn::VarName, dist::Distribution) @@ -454,11 +419,11 @@ _subset(x::VarNamedVector, vns) = subset(x, vns) # `merge` function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo) values = merge(varinfo_left.values, varinfo_right.values) - logp = getlogp(varinfo_right) + accs = getaccs(varinfo_right) transformation = merge_transformations( varinfo_left.transformation, varinfo_right.transformation ) - return SimpleVarInfo(values, logp, transformation) + return SimpleVarInfo(values, accs, transformation) end # Context implementations @@ -473,9 +438,11 @@ function assume( ) value = init(rng, dist, sampler) # Transform if we're working in unconstrained space. - value_raw = to_maybe_linked_internal(vi, vn, dist, value) + f = to_maybe_linked_internal_transform(vi, vn, dist) + value_raw, logjac = with_logabsdet_jacobian(f, value) vi = BangBang.push!!(vi, vn, value_raw, dist) - return value, Bijectors.logpdf_with_trans(dist, value, istrans(vi, vn)), vi + vi = accumulate_assume!!(vi, value, -logjac, vn, dist) + return value, vi end # NOTE: We don't implement `settrans!!(vi, trans, vn)`. 
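A rough sketch of how the accumulator-based `assume` above might surface to users once this draft settles (illustrative only; `coinflip` is a hypothetical model, and `getlogprior`/`getloglikelihood` are the accessors introduced in this patch):

```julia
using DynamicPPL, Distributions

# Hypothetical model: one latent variable and one observation.
@model function coinflip(y)
    p ~ Beta(1, 1)
    y ~ Bernoulli(p)
end

svi = SimpleVarInfo((p=0.5,))
_, svi = DynamicPPL.evaluate!!(coinflip(true), svi, DefaultContext())

# The prior and likelihood contributions are now tracked by separate accumulators;
# expected to match logpdf(Beta(1, 1), 0.5) and logpdf(Bernoulli(0.5), true), respectively.
DynamicPPL.getlogprior(svi), DynamicPPL.getloglikelihood(svi)
```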
@@ -497,8 +464,8 @@ islinked(vi::SimpleVarInfo) = istrans(vi) values_as(vi::SimpleVarInfo) = vi.values values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values -function values_as(vi::SimpleVarInfo{<:Any,T}, ::Type{Vector}) where {T} - isempty(vi) && return T[] +function values_as(vi::SimpleVarInfo, ::Type{Vector}) + isempty(vi) && return Any[] return mapreduce(tovec, vcat, values(vi.values)) end function values_as(vi::SimpleVarInfo, ::Type{D}) where {D<:AbstractDict} @@ -613,12 +580,11 @@ function link!!( vi::SimpleVarInfo{<:NamedTuple}, ::Model, ) - # TODO: Make sure that `spl` is respected. b = inverse(t.bijector) x = vi.values y, logjac = with_logabsdet_jacobian(b, x) - lp_new = getlogp(vi) - logjac - vi_new = setlogp!!(Accessors.@set(vi.values = y), lp_new) + vi_new = Accessors.@set(vi.values = y) + vi_new = acclogprior!!(vi_new, -logjac) return settrans!!(vi_new, t) end @@ -627,12 +593,11 @@ function invlink!!( vi::SimpleVarInfo{<:NamedTuple}, ::Model, ) - # TODO: Make sure that `spl` is respected. b = t.bijector y = vi.values x, logjac = with_logabsdet_jacobian(b, y) - lp_new = getlogp(vi) + logjac - vi_new = setlogp!!(Accessors.@set(vi.values = x), lp_new) + vi_new = Accessors.@set(vi.values = x) + vi_new = acclogprior!!(vi_new, logjac) return settrans!!(vi_new, NoTransformation()) end @@ -645,13 +610,4 @@ function from_linked_internal_transform(vi::SimpleVarInfo, ::VarName, dist) return invlink_transform(dist) end -# Threadsafe stuff. -# For `SimpleVarInfo` we don't really need `Ref` so let's not use it. -function ThreadSafeVarInfo(vi::SimpleVarInfo) - return ThreadSafeVarInfo(vi, zeros(typeof(getlogp(vi)), Threads.nthreads())) -end -function ThreadSafeVarInfo(vi::SimpleVarInfo{<:Any,<:Ref}) - return ThreadSafeVarInfo(vi, [Ref(zero(getlogp(vi))) for _ in 1:Threads.nthreads()]) -end - has_varnamedvector(vi::SimpleVarInfo) = vi.values isa VarNamedVector diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index 7404a9af7..17c9a08fe 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -61,7 +61,7 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod # To see change, let's make sure we're using a different leaf context than the current. leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext - PriorContext() + DynamicPPL.DynamicTransformationContext{false}() else DefaultContext() end diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl index 539872143..07a308c7a 100644 --- a/src/test_utils/varinfo.jl +++ b/src/test_utils/varinfo.jl @@ -37,12 +37,6 @@ function setup_varinfos( svi_untyped = SimpleVarInfo(OrderedDict()) svi_vnv = SimpleVarInfo(DynamicPPL.VarNamedVector()) - # SimpleVarInfo{<:Any,<:Ref} - svi_typed_ref = SimpleVarInfo(example_values, Ref(getlogp(svi_typed))) - svi_untyped_ref = SimpleVarInfo(OrderedDict(), Ref(getlogp(svi_untyped))) - svi_vnv_ref = SimpleVarInfo(DynamicPPL.VarNamedVector(), Ref(getlogp(svi_vnv))) - - lp = getlogp(vi_typed_metadata) varinfos = map(( vi_untyped_metadata, vi_untyped_vnv, @@ -51,12 +45,10 @@ function setup_varinfos( svi_typed, svi_untyped, svi_vnv, - svi_typed_ref, - svi_untyped_ref, - svi_vnv_ref, )) do vi - # Set them all to the same values. - DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp) + # Set them all to the same values and evaluate logp. 
+ vi = update_values!!(vi, example_values, varnames) + last(DynamicPPL.evaluate!!(model, vi, DefaultContext())) end if include_threadsafe diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 2dc2645de..bed865526 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -2,69 +2,74 @@ ThreadSafeVarInfo A `ThreadSafeVarInfo` object wraps an [`AbstractVarInfo`](@ref) object and an -array of log probabilities for thread-safe execution of a probabilistic model. +array of accumulators for thread-safe execution of a probabilistic model. """ -struct ThreadSafeVarInfo{V<:AbstractVarInfo,L} <: AbstractVarInfo +struct ThreadSafeVarInfo{V<:AbstractVarInfo,L<:AccumulatorTuple} <: AbstractVarInfo varinfo::V - logps::L + accs_by_thread::Vector{L} end function ThreadSafeVarInfo(vi::AbstractVarInfo) - return ThreadSafeVarInfo(vi, [Ref(zero(getlogp(vi))) for _ in 1:Threads.nthreads()]) + accs_by_thread = [ + AccumulatorTuple(map(split, vi.accs.nt)) for _ in 1:Threads.nthreads() + ] + return ThreadSafeVarInfo(vi, accs_by_thread) end ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi -const ThreadSafeVarInfoWithRef{V<:AbstractVarInfo} = ThreadSafeVarInfo{ - V,<:AbstractArray{<:Ref} -} - transformation(vi::ThreadSafeVarInfo) = transformation(vi.varinfo) -# Instead of updating the log probability of the underlying variables we -# just update the array of log probabilities. -function acclogp!!(vi::ThreadSafeVarInfo, logp) - vi.logps[Threads.threadid()] += logp - return vi -end -function acclogp!!(vi::ThreadSafeVarInfoWithRef, logp) - vi.logps[Threads.threadid()][] += logp - return vi +# Set the accumulator in question in vi.varinfo, and set the thread-specific +# accumulators of the same type to be empty. +function setacc!!(vi::ThreadSafeVarInfo, acc::AbstractAccumulator) + inner_vi = setacc!!(vi.varinfo, acc) + news_accs_by_thread = map(accs -> setacc!!(accs, split(acc)), vi.accs_by_thread) + return ThreadSafeVarInfo(inner_vi, news_accs_by_thread) end -# The current log probability of the variables has to be computed from -# both the wrapped variables and the thread-specific log probabilities. -getlogp(vi::ThreadSafeVarInfo) = getlogp(vi.varinfo) + sum(vi.logps) -getlogp(vi::ThreadSafeVarInfoWithRef) = getlogp(vi.varinfo) + sum(getindex, vi.logps) - -# TODO: Make remaining methods thread-safe. -function resetlogp!!(vi::ThreadSafeVarInfo) - return ThreadSafeVarInfo(resetlogp!!(vi.varinfo), zero(vi.logps)) +# Get both the main accumulator and the thread-specific accumulators of the same type and +# combine them. +function getacc(vi::ThreadSafeVarInfo, ::Type{AccType}) where {AccType} + main_acc = getacc(vi.varinfo, AccType) + other_accs = map(accs -> getacc(accs, AccType), vi.accs_by_thread) + return foldl(combine, other_accs; init=main_acc) end -function resetlogp!!(vi::ThreadSafeVarInfoWithRef) - for x in vi.logps - x[] = zero(x[]) - end - return ThreadSafeVarInfo(resetlogp!!(vi.varinfo), vi.logps) + +# Calls to accumulate_assume!!, accumulate_observe!!, and acc!! are thread-specific. 
+function accumulate_assume!!(vi::ThreadSafeVarInfo, r, logp, vn, right) + tid = Threads.threadid() + vi.accs_by_thread[tid] = accumulate_assume!!(vi.accs_by_thread[tid], r, logp, vn, right) + return vi end -function setlogp!!(vi::ThreadSafeVarInfo, logp) - return ThreadSafeVarInfo(setlogp!!(vi.varinfo, logp), zero(vi.logps)) + +function accumulate_observe!!(vi::ThreadSafeVarInfo, left, right) + tid = Threads.threadid() + vi.accs_by_thread[tid] = accumulate_observe!!(vi.accs_by_thread[tid], left, right) + return vi end -function setlogp!!(vi::ThreadSafeVarInfoWithRef, logp) - for x in vi.logps - x[] = zero(x[]) - end - return ThreadSafeVarInfo(setlogp!!(vi.varinfo, logp), vi.logps) + +function acc!!(vi::ThreadSafeVarInfo, ::Type{AccType}, args...) where {AccType} + tid = Threads.threadid() + vi.accs_by_thread[tid] = acc!!(vi.accs_by_thread[tid], AccType, args...) + return vi end -has_varnamedvector(vi::DynamicPPL.ThreadSafeVarInfo) = has_varnamedvector(vi.varinfo) +has_varnamedvector(vi::ThreadSafeVarInfo) = has_varnamedvector(vi.varinfo) function BangBang.push!!(vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution) return Accessors.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist) end +# TODO(mhauru) Why these short-circuits? Why not use the thread-specific ones? get_num_produce(vi::ThreadSafeVarInfo) = get_num_produce(vi.varinfo) -increment_num_produce!(vi::ThreadSafeVarInfo) = increment_num_produce!(vi.varinfo) -reset_num_produce!(vi::ThreadSafeVarInfo) = reset_num_produce!(vi.varinfo) -set_num_produce!(vi::ThreadSafeVarInfo, n::Int) = set_num_produce!(vi.varinfo, n) +function increment_num_produce!!(vi::ThreadSafeVarInfo) + return ThreadSafeVarInfo(increment_num_produce!!(vi.varinfo), vi.accs_by_thread) +end +function reset_num_produce!!(vi::ThreadSafeVarInfo) + return ThreadSafeVarInfo(reset_num_produce!!(vi.varinfo), vi.accs_by_thread) +end +function set_num_produce!!(vi::ThreadSafeVarInfo, n::Int) + return ThreadSafeVarInfo(set_num_produce!!(vi.varinfo, n), vi.accs_by_thread) +end syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo) @@ -169,6 +174,17 @@ function BangBang.empty!!(vi::ThreadSafeVarInfo) return resetlogp!!(Accessors.@set(vi.varinfo = empty!!(vi.varinfo))) end +function resetlogp!!(vi::ThreadSafeVarInfo) + vi = Accessors.@set vi.varinfo = resetlogp!!(vi.varinfo) + logprior = split(getacc(vi.varinfo, LogPrior)) + loglikelihood = split(getacc(vi.varinfo, LogLikelihood)) + for i in eachindex(vi.accs_by_thread) + vi.accs_by_thread[i] = setacc!!(vi.accs_by_thread[i], logprior) + vi.accs_by_thread[i] = setacc!!(vi.accs_by_thread[i], loglikelihood) + end + return vi +end + values_as(vi::ThreadSafeVarInfo) = values_as(vi.varinfo) values_as(vi::ThreadSafeVarInfo, ::Type{T}) where {T} = values_as(vi.varinfo, T) diff --git a/src/transforming.jl b/src/transforming.jl index 0239725ae..9d7e9e587 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -27,7 +27,12 @@ function tilde_assume( # Only transform if `!isinverse` since `vi[vn, right]` # already performs the inverse transformation if it's transformed. r_transformed = isinverse ? 
r : link_transform(right)(r) - return r, lp, setindex!!(vi, r_transformed, vn) + vi = acclogprior!!(vi, lp) + return r, setindex!!(vi, r_transformed, vn) +end + +function tilde_observe(::DynamicTransformationContext, right, vn, vi) + return observe(right, vn, vi) end function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) diff --git a/src/utils.jl b/src/utils.jl index 56c3d70af..398fd499b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -83,9 +83,7 @@ true """ macro addlogprob!(ex) return quote - $(esc(:(__varinfo__))) = acclogp!!( - $(esc(:(__context__))), $(esc(:(__varinfo__))), $(esc(ex)) - ) + $(esc(:(__varinfo__))) = acclogp!!($(esc(:(__varinfo__))), $(esc(ex))) end end diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index d3bfd697a..3ec474940 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -65,29 +65,24 @@ end function tilde_assume(context::ValuesAsInModelContext, right, vn, vi) if is_tracked_value(right) value = right.value - logp = zero(getlogp(vi)) else - value, logp, vi = tilde_assume(childcontext(context), right, vn, vi) + value, vi = tilde_assume(childcontext(context), right, vn, vi) end - # Save the value. push!(context, vn, value) - # Save the value. - # Pass on. - return value, logp, vi + return value, vi end function tilde_assume( rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, vn, vi ) if is_tracked_value(right) value = right.value - logp = zero(getlogp(vi)) else - value, logp, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi) + value, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi) end # Save the value. push!(context, vn, value) # Pass on. - return value, logp, vi + return value, vi end """ diff --git a/src/varinfo.jl b/src/varinfo.jl index 360857ef7..58523ca2a 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -69,10 +69,9 @@ end ########### """ - struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo + struct VarInfo{Tmeta,Accs<:AccumulatorTuple} <: AbstractVarInfo metadata::Tmeta - logp::Base.RefValue{Tlogp} - num_produce::Base.RefValue{Int} + accs::Accs end A light wrapper over some kind of metadata. @@ -98,12 +97,13 @@ Note that for NTVarInfo, it is the user's responsibility to ensure that each symbol is visited at least once during model evaluation, regardless of any stochastic branching. """ -struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo +struct VarInfo{Tmeta,Accs<:AccumulatorTuple} <: AbstractVarInfo metadata::Tmeta - logp::Base.RefValue{Tlogp} - num_produce::Base.RefValue{Int} + accs::Accs end -VarInfo(meta=Metadata()) = VarInfo(meta, Ref{LogProbType}(0.0), Ref(0)) +VarInfo(meta=Metadata()) = VarInfo(meta, AccumulatorTuple(LogPrior{LogProbType}(), + LogLikelihood{LogProbType}(), NumProduce{Int}())) + """ VarInfo([rng, ]model[, sampler, context]) @@ -285,10 +285,8 @@ function typed_varinfo(vi::UntypedVarInfo) ), ) end - logp = getlogp(vi) - num_produce = get_num_produce(vi) nt = NamedTuple{syms_tuple}(Tuple(new_metas)) - return VarInfo(nt, Ref(logp), Ref(num_produce)) + return VarInfo(nt, vi.accs) end function typed_varinfo(vi::NTVarInfo) # This function preserves the behaviour of typed_varinfo(vi) where vi is @@ -349,8 +347,7 @@ single `VarNamedVector` as its metadata field. 
""" function untyped_vector_varinfo(vi::UntypedVarInfo) md = metadata_to_varnamedvector(vi.metadata) - lp = getlogp(vi) - return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) + return VarInfo(md, vi.accs) end function untyped_vector_varinfo( rng::Random.AbstractRNG, @@ -393,15 +390,12 @@ NamedTuple of `VarNamedVector`s as its metadata field. """ function typed_vector_varinfo(vi::NTVarInfo) md = map(metadata_to_varnamedvector, vi.metadata) - lp = getlogp(vi) - return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) + return VarInfo(md, vi.accs) end function typed_vector_varinfo(vi::UntypedVectorVarInfo) new_metas = group_by_symbol(vi.metadata) - logp = getlogp(vi) - num_produce = get_num_produce(vi) nt = NamedTuple(new_metas) - return VarInfo(nt, Ref(logp), Ref(num_produce)) + return VarInfo(nt, vi.accs) end function typed_vector_varinfo( rng::Random.AbstractRNG, @@ -441,13 +435,7 @@ vector_length(md::Metadata) = sum(length, md.ranges) function unflatten(vi::VarInfo, x::AbstractVector) md = unflatten_metadata(vi.metadata, x) - # Note that use of RefValue{eltype(x)} rather than Ref is necessary to deal with cases - # where e.g. x is a type gradient of some AD backend. - return VarInfo( - md, - Base.RefValue{float_type_with_fallback(eltype(x))}(getlogp(vi)), - Ref(get_num_produce(vi)), - ) + return VarInfo(md, vi.accs) end # We would call this `unflatten` if not for `unflatten` having a method for NamedTuples in @@ -529,7 +517,7 @@ end function subset(varinfo::VarInfo, vns::AbstractVector{<:VarName}) metadata = subset(varinfo.metadata, vns) - return VarInfo(metadata, deepcopy(varinfo.logp), deepcopy(varinfo.num_produce)) + return VarInfo(metadata, varinfo.accs) end function subset(metadata::NamedTuple, vns::AbstractVector{<:VarName}) @@ -618,9 +606,7 @@ end function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo) metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata) - return VarInfo( - metadata, Ref(getlogp(varinfo_right)), Ref(get_num_produce(varinfo_right)) - ) + return VarInfo(metadata, varinfo_right.accs) end function merge_metadata(vnv_left::VarNamedVector, vnv_right::VarNamedVector) @@ -973,8 +959,8 @@ end function BangBang.empty!!(vi::VarInfo) _empty!(vi.metadata) - resetlogp!!(vi) - reset_num_produce!(vi) + vi = resetlogp!!(vi) + vi = reset_num_produce!!(vi) return vi end @@ -1008,46 +994,37 @@ end istrans(vi::VarInfo, vn::VarName) = istrans(getmetadata(vi, vn), vn) istrans(md::Metadata, vn::VarName) = is_flagged(md, vn, "trans") -getlogp(vi::VarInfo) = vi.logp[] - -function setlogp!!(vi::VarInfo, logp) - vi.logp[] = logp - return vi -end - -function acclogp!!(vi::VarInfo, logp) - vi.logp[] += logp - return vi -end +getaccs(vi::VarInfo) = vi.accs +setaccs!!(vi::VarInfo, accs) = Accessors.@set vi.accs = accs """ get_num_produce(vi::VarInfo) Return the `num_produce` of `vi`. """ -get_num_produce(vi::VarInfo) = vi.num_produce[] +get_num_produce(vi::VarInfo) = getacc(vi, NumProduce).num """ - set_num_produce!(vi::VarInfo, n::Int) + set_num_produce!!(vi::VarInfo, n::Int) Set the `num_produce` field of `vi` to `n`. """ -set_num_produce!(vi::VarInfo, n::Int) = vi.num_produce[] = n +set_num_produce!!(vi::VarInfo, n::Int) = setacc!!(vi, NumProduce(n)) """ - increment_num_produce!(vi::VarInfo) + increment_num_produce!!(vi::VarInfo) Add 1 to `num_produce` in `vi`. 
""" -increment_num_produce!(vi::VarInfo) = vi.num_produce[] += 1 +increment_num_produce!!(vi::VarInfo) = set_num_produce!!(vi, get_num_produce(vi) + 1) """ - reset_num_produce!(vi::VarInfo) + reset_num_produce!!(vi::VarInfo) Reset the value of `num_produce` the log of the joint probability of the observed data and parameters sampled in `vi` to 0. """ -reset_num_produce!(vi::VarInfo) = set_num_produce!(vi, 0) +reset_num_produce!!(vi::VarInfo) = set_num_produce!!(vi, 0) # Need to introduce the _isempty to avoid type piracy of isempty(::NamedTuple). isempty(vi::VarInfo) = _isempty(vi.metadata) @@ -1061,7 +1038,7 @@ function link!!(::DynamicTransformation, vi::NTVarInfo, model::Model) vns = all_varnames_grouped_by_symbol(vi) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _link(model, vi, vns) - _link!(vi, vns) + vi = _link!!(vi, vns) return vi end @@ -1069,7 +1046,7 @@ function link!!(::DynamicTransformation, vi::VarInfo, model::Model) vns = keys(vi) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _link(model, vi, vns) - _link!(vi, vns) + vi = _link!!(vi, vns) return vi end @@ -1082,8 +1059,7 @@ end function link!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _link(model, vi, vns) - # Call `_link!` instead of `link!` to avoid deprecation warning. - _link!(vi, vns) + vi = _link!!(vi, vns) return vi end @@ -1098,27 +1074,28 @@ function link!!( return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, vns, model) end -function _link!(vi::UntypedVarInfo, vns) +function _link!!(vi::UntypedVarInfo, vns) # TODO: Change to a lazy iterator over `vns` if ~istrans(vi, vns[1]) for vn in vns f = internal_to_linked_internal_transform(vi, vn) - _inner_transform!(vi, vn, f) - settrans!!(vi, true, vn) + vi = _inner_transform!(vi, vn, f) + vi = settrans!!(vi, true, vn) end + return vi else @warn("[DynamicPPL] attempt to link a linked vi") end end -# If we try to _link! a NTVarInfo with a Tuple of VarNames, first convert it to a +# If we try to _link!! a NTVarInfo with a Tuple of VarNames, first convert it to a # NamedTuple that matches the structure of the NTVarInfo. 
-function _link!(vi::NTVarInfo, vns::VarNameTuple) - return _link!(vi, group_varnames_by_symbol(vns)) +function _link!!(vi::NTVarInfo, vns::VarNameTuple) + return _link!!(vi, group_varnames_by_symbol(vns)) end -function _link!(vi::NTVarInfo, vns::NamedTuple) - return _link!(vi.metadata, vi, vns) +function _link!!(vi::NTVarInfo, vns::NamedTuple) + return _link!!(vi.metadata, vi, vns) end """ @@ -1130,7 +1107,7 @@ function filter_subsumed(filter_vns, filtered_vns) return filter(x -> any(subsumes(y, x) for y in filter_vns), filtered_vns) end -@generated function _link!( +@generated function _link!!( ::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names} ) where {metadata_names,vns_names} expr = Expr(:block) @@ -1148,8 +1125,8 @@ end # Iterate over all `f_vns` and transform for vn in f_vns f = internal_to_linked_internal_transform(vi, vn) - _inner_transform!(vi, vn, f) - settrans!!(vi, true, vn) + vi = _inner_transform!(vi, vn, f) + vi = settrans!!(vi, true, vn) end else @warn("[DynamicPPL] attempt to link a linked vi") @@ -1158,6 +1135,7 @@ end end, ) end + push!(expr.args, :(return vi)) return expr end @@ -1165,8 +1143,7 @@ function invlink!!(::DynamicTransformation, vi::NTVarInfo, model::Model) vns = all_varnames_grouped_by_symbol(vi) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _invlink(model, vi, vns) - # Call `_invlink!` instead of `invlink!` to avoid deprecation warning. - _invlink!(vi, vns) + vi = _invlink!!(vi, vns) return vi end @@ -1174,7 +1151,7 @@ function invlink!!(::DynamicTransformation, vi::VarInfo, model::Model) vns = keys(vi) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _invlink(model, vi, vns) - _invlink!(vi, vns) + vi = _invlink!!(vi, vns) return vi end @@ -1187,8 +1164,7 @@ end function invlink!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _invlink(model, vi, vns) - # Call `_invlink!` instead of `invlink!` to avoid deprecation warning. - _invlink!(vi, vns) + vi = _invlink!!(vi, vns) return vi end @@ -1211,29 +1187,30 @@ function maybe_invlink_before_eval!!(vi::VarInfo, model::Model) return maybe_invlink_before_eval!!(t, vi, model) end -function _invlink!(vi::UntypedVarInfo, vns) +function _invlink!!(vi::UntypedVarInfo, vns) if istrans(vi, vns[1]) for vn in vns f = linked_internal_to_internal_transform(vi, vn) - _inner_transform!(vi, vn, f) - settrans!!(vi, false, vn) + vi = _inner_transform!(vi, vn, f) + vi = settrans!!(vi, false, vn) end + return vi else @warn("[DynamicPPL] attempt to invlink an invlinked vi") end end -# If we try to _invlink! a NTVarInfo with a Tuple of VarNames, first convert it to a +# If we try to _invlink!! a NTVarInfo with a Tuple of VarNames, first convert it to a # NamedTuple that matches the structure of the NTVarInfo. 
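On the sign convention threaded through these hunks: linking subtracts the log-abs-det Jacobian of the link transform from the prior accumulator, and invlinking subtracts the Jacobian of the inverse, so a link/invlink round trip leaves the accumulated log-prior unchanged. A plain change-of-variables check with `log`/`exp`, nothing DynamicPPL-specific:

```julia
# Positivity-constrained value, linked with log and unlinked with exp.
x = 2.5
y = log(x)

logjac_link = -log(x)   # log|d/dx log(x)| = log(1/x)
logjac_invlink = y      # log|d/dy exp(y)| = y

# The two corrections cancel, so link followed by invlink is a no-op
# for the accumulated log-prior.
@assert isapprox(logjac_link + logjac_invlink, 0; atol=1e-12)
```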
-function _invlink!(vi::NTVarInfo, vns::VarNameTuple) - return _invlink!(vi.metadata, vi, group_varnames_by_symbol(vns)) +function _invlink!!(vi::NTVarInfo, vns::VarNameTuple) + return _invlink!!(vi.metadata, vi, group_varnames_by_symbol(vns)) end -function _invlink!(vi::NTVarInfo, vns::NamedTuple) - return _invlink!(vi.metadata, vi, vns) +function _invlink!!(vi::NTVarInfo, vns::NamedTuple) + return _invlink!!(vi.metadata, vi, vns) end -@generated function _invlink!( +@generated function _invlink!!( ::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names} ) where {metadata_names,vns_names} expr = Expr(:block) @@ -1251,8 +1228,8 @@ end # Iterate over all `f_vns` and transform for vn in f_vns f = linked_internal_to_internal_transform(vi, vn) - _inner_transform!(vi, vn, f) - settrans!!(vi, false, vn) + vi = _inner_transform!(vi, vn, f) + vi = settrans!!(vi, false, vn) end else @warn("[DynamicPPL] attempt to invlink an invlinked vi") @@ -1260,6 +1237,7 @@ end end, ) end + push!(expr.args, :(return vi)) return expr end @@ -1276,7 +1254,7 @@ function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f) setrange!(md, vn, start:(start + length(yvec) - 1)) # Set the new value. setval!(md, yvec, vn) - acclogp!!(vi, -logjac) + vi = acclogprior!!(vi, -logjac) return vi end @@ -1311,8 +1289,10 @@ end function _link(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) - md = _link_metadata!!(model, varinfo, varinfo.metadata, vns) - return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) + md, logjac = _link_metadata!!(model, varinfo, varinfo.metadata, vns) + new_varinfo = VarInfo(md, varinfo.accs) + new_varinfo = acclogprior!!(new_varinfo, -logjac) + return new_varinfo end # If we try to _link a NTVarInfo with a Tuple of VarNames, first convert it to a @@ -1323,8 +1303,10 @@ end function _link(model::Model, varinfo::NTVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) - md = _link_metadata!(model, varinfo, varinfo.metadata, vns) - return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) + md, logjac = _link_metadata!(model, varinfo, varinfo.metadata, vns) + new_varinfo = VarInfo(md, varinfo.accs) + new_varinfo = acclogprior!!(new_varinfo, -logjac) + return new_varinfo end @generated function _link_metadata!( @@ -1333,20 +1315,39 @@ end metadata::NamedTuple{metadata_names}, vns::NamedTuple{vns_names}, ) where {metadata_names,vns_names} - vals = Expr(:tuple) + expr = quote + cumulative_logjac = zero(LogProbType) + end + mds = Expr(:tuple) for f in metadata_names if f in vns_names - push!(vals.args, :(_link_metadata!!(model, varinfo, metadata.$f, vns.$f))) + push!( + mds.args, + quote + begin + md, logjac = _link_metadata!!(model, varinfo, metadata.$f, vns.$f) + cumulative_logjac += logjac + md + end + end, + ) else - push!(vals.args, :(metadata.$f)) + push!(mds.args, :(metadata.$f)) end end - return :(NamedTuple{$metadata_names}($vals)) + push!( + expr.args, + quote + NamedTuple{$metadata_names}($mds), cumulative_logjac + end, + ) + return expr end function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns + cumulative_logjac = zero(LogProbType) # Construct the new transformed values, and keep track of their lengths. vals_new = map(vns) do vn @@ -1364,7 +1365,7 @@ function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_ # Vectorize value. yvec = tovec(y) # Accumulate the log-abs-det jacobian correction. 
- acclogp!!(varinfo, -logjac) + cumulative_logjac += logjac # Mark as transformed. settrans!!(varinfo, true, vn) # Return the vectorized transformed value. @@ -1389,7 +1390,8 @@ function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_ metadata.dists, metadata.orders, metadata.flags, - ) + ), + cumulative_logjac end function _link_metadata!!( @@ -1397,6 +1399,7 @@ function _link_metadata!!( ) vns = target_vns === nothing ? keys(metadata) : target_vns dists = extract_priors(model, varinfo) + cumulative_logjac = zero(LogProbType) for vn in vns # First transform from however the variable is stored in vnv to the model # representation. @@ -1409,11 +1412,11 @@ function _link_metadata!!( val_new, logjac2 = with_logabsdet_jacobian(transform_to_linked, val_orig) # TODO(mhauru) We are calling a !! function but ignoring the return value. # Fix this when attending to issue #653. - acclogp!!(varinfo, -logjac1 - logjac2) + cumulative_logjac += logjac1 + logjac2 metadata = setindex_internal!!(metadata, val_new, vn, transform_from_linked) settrans!(metadata, true, vn) end - return metadata + return metadata, cumulative_logjac end function invlink(::DynamicTransformation, vi::NTVarInfo, model::Model) @@ -1449,11 +1452,10 @@ end function _invlink(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) - return VarInfo( - _invlink_metadata!!(model, varinfo, varinfo.metadata, vns), - Base.Ref(getlogp(varinfo)), - Ref(get_num_produce(varinfo)), - ) + md, logjac = _invlink_metadata!!(model, varinfo, varinfo.metadata, vns) + new_varinfo = VarInfo(md, varinfo.accs) + new_varinfo = acclogprior!!(new_varinfo, -logjac) + return new_varinfo end # If we try to _invlink a NTVarInfo with a Tuple of VarNames, first convert it to a @@ -1464,8 +1466,10 @@ end function _invlink(model::Model, varinfo::NTVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) - md = _invlink_metadata!(model, varinfo, varinfo.metadata, vns) - return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) + md, logjac = _invlink_metadata!(model, varinfo, varinfo.metadata, vns) + new_varinfo = VarInfo(md, varinfo.accs) + new_varinfo = acclogprior!!(new_varinfo, -logjac) + return new_varinfo end @generated function _invlink_metadata!( @@ -1474,20 +1478,41 @@ end metadata::NamedTuple{metadata_names}, vns::NamedTuple{vns_names}, ) where {metadata_names,vns_names} - vals = Expr(:tuple) + expr = quote + cumulative_logjac = zero(LogProbType) + end + mds = Expr(:tuple) for f in metadata_names if (f in vns_names) - push!(vals.args, :(_invlink_metadata!!(model, varinfo, metadata.$f, vns.$f))) + push!( + mds.args, + quote + begin + md, logjac = _invlink_metadata!!( + model, varinfo, metadata.$f, vns.$f + ) + cumulative_logjac += logjac + md + end + end, + ) else - push!(vals.args, :(metadata.$f)) + push!(mds.args, :(metadata.$f)) end end - return :(NamedTuple{$metadata_names}($vals)) + push!( + expr.args, + quote + (NamedTuple{$metadata_names}($mds), cumulative_logjac) + end, + ) + return expr end function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns + cumulative_logjac = zero(LogProbType) # Construct the new transformed values, and keep track of their lengths. vals_new = map(vns) do vn @@ -1506,7 +1531,7 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ # Vectorize value. xvec = tovec(x) # Accumulate the log-abs-det jacobian correction. 
- acclogp!!(varinfo, -logjac) + cumulative_logjac += logjac # Mark as no longer transformed. settrans!!(varinfo, false, vn) # Return the vectorized transformed value. @@ -1531,24 +1556,26 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ metadata.dists, metadata.orders, metadata.flags, - ) + ), + cumulative_logjac end function _invlink_metadata!!( ::Model, varinfo::VarInfo, metadata::VarNamedVector, target_vns ) vns = target_vns === nothing ? keys(metadata) : target_vns + cumulative_logjac = zero(LogProbType) for vn in vns transform = gettransform(metadata, vn) old_val = getindex_internal(metadata, vn) new_val, logjac = with_logabsdet_jacobian(transform, old_val) # TODO(mhauru) We are calling a !! function but ignoring the return value. - acclogp!!(varinfo, -logjac) + cumulative_logjac += logjac new_transform = from_vec_transform(new_val) metadata = setindex_internal!!(metadata, tovec(new_val), vn, new_transform) settrans!(metadata, false, vn) end - return metadata + return metadata, cumulative_logjac end # TODO(mhauru) The treatment of the case when some variables are linked and others are not diff --git a/test/context_implementations.jl b/test/context_implementations.jl index 0ec88c07c..7e1397ae3 100644 --- a/test/context_implementations.jl +++ b/test/context_implementations.jl @@ -10,7 +10,7 @@ end end - test([1, 1, -1])(VarInfo(), SampleFromPrior(), LikelihoodContext()) + test([1, 1, -1])(VarInfo(), SampleFromPrior(), DefaultContext()) end @testset "dot tilde with varying sizes" begin diff --git a/test/contexts.jl b/test/contexts.jl index 11e591f8f..fdff5f6f7 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -42,14 +42,11 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @testset "contexts.jl" begin child_contexts = Dict( :default => DefaultContext(), - :prior => PriorContext(), - :likelihood => LikelihoodContext(), ) parent_contexts = Dict( :testparent => DynamicPPL.TestUtils.TestParentContext(DefaultContext()), :sampling => SamplingContext(), - :minibatch => MiniBatchContext(DefaultContext(), 0.0), :prefix => PrefixContext{:x}(DefaultContext()), :pointwiselogdensity => PointwiseLogdensityContext(), :condition1 => ConditionContext((x=1.0,)), @@ -237,7 +234,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() # Values from outer context should override inner one ctx1 = ConditionContext(n1, ConditionContext(n2)) @test ctx1.values == (x=1, y=2) - # Check that the two ConditionContexts are collapsed + # Check that the two ConditionContexts are collapsed @test childcontext(ctx1) isa DefaultContext # Then test the nesting the other way round ctx2 = ConditionContext(n2, ConditionContext(n1)) diff --git a/test/independence.jl b/test/independence.jl deleted file mode 100644 index a4a834a61..000000000 --- a/test/independence.jl +++ /dev/null @@ -1,11 +0,0 @@ -@testset "Turing independence" begin - @model coinflip(y) = begin - p ~ Beta(1, 1) - N = length(y) - for i in 1:N - y[i] ~ Bernoulli(p) - end - end - model = coinflip([1, 1, 0]) - model(SampleFromPrior(), LikelihoodContext()) -end diff --git a/test/runtests.jl b/test/runtests.jl index 3473d5594..cf9e0e79e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -55,7 +55,6 @@ include("test_util.jl") include("simple_varinfo.jl") include("model.jl") include("sampler.jl") - include("independence.jl") include("distribution_wrappers.jl") include("logdensityfunction.jl") include("linking.jl") diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 
aa3b592f7..c8994c5a8 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -23,7 +23,8 @@ svi = SimpleVarInfo{Float32}(; m=1.0) @test getlogp(svi) isa Float32 - svi = SimpleVarInfo((m=1.0,), 1.0) + svi = SimpleVarInfo((m=1.0,)) + svi = acclogp!!(svi, 1.0) @test getlogp(svi) == 1.0 end @@ -98,7 +99,6 @@ vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn) end vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext())) - lp_orig = getlogp(vi) # `link!!` vi_linked = link!!(deepcopy(vi), model) @@ -158,7 +158,7 @@ # DynamicPPL.settrans!!(deepcopy(svi_dict), true), # DynamicPPL.settrans!!(deepcopy(svi_vnv), true), ) - # RandOM seed is set in each `@testset`, so we need to sample + # Random seed is set in each `@testset`, so we need to sample # a new realization for `m` here. retval = model() diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 72c439db8..625b94c33 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -4,9 +4,10 @@ threadsafe_vi = @inferred DynamicPPL.ThreadSafeVarInfo(vi) @test threadsafe_vi.varinfo === vi - @test threadsafe_vi.logps isa Vector{typeof(Ref(getlogp(vi)))} - @test length(threadsafe_vi.logps) == Threads.nthreads() - @test all(iszero(x[]) for x in threadsafe_vi.logps) + @test threadsafe_vi.accs_by_thread isa Vector{<:DynamicPPL.AccumulatorTuple} + @test length(threadsafe_vi.accs_by_thread) == Threads.nthreads() + expected_accs = DynamicPPL.AccumulatorTuple((DynamicPPL.split(acc) for acc in vi.accs)...) + @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) end # TODO: Add more tests of the public API @@ -17,20 +18,20 @@ lp = getlogp(vi) @test getlogp(threadsafe_vi) == lp - acclogp!!(threadsafe_vi, 42) - @test threadsafe_vi.logps[Threads.threadid()][] == 42 + threadsafe_vi = DynamicPPL.acclogprior!!(threadsafe_vi, 42) + @test threadsafe_vi.accs_by_thread[Threads.threadid()][:LogPrior].logp == 42 @test getlogp(vi) == lp @test getlogp(threadsafe_vi) == lp + 42 - resetlogp!!(threadsafe_vi) - @test iszero(getlogp(vi)) + threadsafe_vi = resetlogp!!(threadsafe_vi) @test iszero(getlogp(threadsafe_vi)) - @test all(iszero(x[]) for x in threadsafe_vi.logps) + expected_accs = DynamicPPL.AccumulatorTuple((DynamicPPL.split(acc) for acc in threadsafe_vi.varinfo.accs)...) + @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) - setlogp!!(threadsafe_vi, 42) - @test getlogp(vi) == 42 + threadsafe_vi = setlogp!!(threadsafe_vi, 42) @test getlogp(threadsafe_vi) == 42 - @test all(iszero(x[]) for x in threadsafe_vi.logps) + expected_accs = DynamicPPL.AccumulatorTuple((DynamicPPL.split(acc) for acc in threadsafe_vi.varinfo.accs)...) 
+ @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) end @testset "model" begin diff --git a/test/varinfo.jl b/test/varinfo.jl index 777917aa6..1ca3308aa 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -941,19 +941,19 @@ end # First iteration, variables are added to vi # variables samples in order: z1,a1,z2,a2,z3 - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z1, dists[1]) randr(vi, vn_a1, dists[2]) - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_b, dists[2]) randr(vi, vn_z2, dists[1]) randr(vi, vn_a2, dists[2]) - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z3, dists[1]) @test vi.metadata.orders == [1, 1, 2, 2, 2, 3] @test DynamicPPL.get_num_produce(vi) == 3 - DynamicPPL.reset_num_produce!(vi) + vi = DynamicPPL.reset_num_produce!!(vi) DynamicPPL.set_retained_vns_del!(vi) @test DynamicPPL.is_flagged(vi, vn_z1, "del") @test DynamicPPL.is_flagged(vi, vn_a1, "del") @@ -961,12 +961,12 @@ end @test DynamicPPL.is_flagged(vi, vn_a2, "del") @test DynamicPPL.is_flagged(vi, vn_z3, "del") - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z1, dists[1]) randr(vi, vn_a1, dists[2]) - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z2, dists[1]) - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z3, dists[1]) randr(vi, vn_a2, dists[2]) @test vi.metadata.orders == [1, 1, 2, 2, 3, 3] @@ -975,21 +975,21 @@ end vi = empty!!(DynamicPPL.typed_varinfo(vi)) # First iteration, variables are added to vi # variables samples in order: z1,a1,z2,a2,z3 - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z1, dists[1]) randr(vi, vn_a1, dists[2]) - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_b, dists[2]) randr(vi, vn_z2, dists[1]) randr(vi, vn_a2, dists[2]) - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z3, dists[1]) @test vi.metadata.z.orders == [1, 2, 3] @test vi.metadata.a.orders == [1, 2] @test vi.metadata.b.orders == [2] @test DynamicPPL.get_num_produce(vi) == 3 - DynamicPPL.reset_num_produce!(vi) + vi = DynamicPPL.reset_num_produce!!(vi) DynamicPPL.set_retained_vns_del!(vi) @test DynamicPPL.is_flagged(vi, vn_z1, "del") @test DynamicPPL.is_flagged(vi, vn_a1, "del") @@ -997,12 +997,12 @@ end @test DynamicPPL.is_flagged(vi, vn_a2, "del") @test DynamicPPL.is_flagged(vi, vn_z3, "del") - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z1, dists[1]) randr(vi, vn_a1, dists[2]) - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z2, dists[1]) - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z3, dists[1]) randr(vi, vn_a2, dists[2]) @test vi.metadata.z.orders == [1, 2, 3] From 4fb0bf40042d8fa7582dd60593d6efc2cc10f96a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 10 Apr 2025 09:45:58 +0100 Subject: [PATCH 06/48] Fix some variable names --- src/abstract_varinfo.jl | 4 ++-- src/accumulators.jl | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 4aa3e402e..963c41513 100644 --- a/src/abstract_varinfo.jl +++ 
b/src/abstract_varinfo.jl @@ -128,8 +128,8 @@ function getacc(vi::AbstractVarInfo, ::Type{AccType}) where {AccType} return getacc(getaccs(vi), AccType) end -function accumulate_assume!!(vi::AbstractVarInfo, r, logp, vn, right) - return setaccs!!(vi, accumulate_assume!!(getaccs(vi), r, logp, vn, right)) +function accumulate_assume!!(vi::AbstractVarInfo, r, logjac, vn, right) + return setaccs!!(vi, accumulate_assume!!(getaccs(vi), r, logjac, vn, right)) end function accumulate_observe!!(vi::AbstractVarInfo, left, right) diff --git a/src/accumulators.jl b/src/accumulators.jl index 5dd0660ca..293f03e77 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -9,7 +9,8 @@ A collection of accumulators, stored as a `NamedTuple`. This is defined as a separate type to be able to dispatch on it cleanly and without method ambiguities or conflicts with other `NamedTuple` types. We also use this type to enforce the -constraint the name in the tuple for each accumulator `acc` must be `accumulator_name(acc)`. +constraint that the name in the tuple for each accumulator `acc` must be +`accumulator_name(acc)`. """ struct AccumulatorTuple{N,T<:NamedTuple} nt::T @@ -36,8 +37,8 @@ function getacc(at::AccumulatorTuple, ::Type{AccType}) where {AccType} return at[accumulator_name(AccType)] end -function accumulate_assume!!(at::AccumulatorTuple, r, logp, vn, right) - return AccumulatorTuple(map(acc -> accumulate_assume!!(acc, r, logp, vn, right), at.nt)) +function accumulate_assume!!(at::AccumulatorTuple, r, logjac, vn, right) + return AccumulatorTuple(map(acc -> accumulate_assume!!(acc, r, logjac, vn, right), at.nt)) end function accumulate_observe!!(at::AccumulatorTuple, left, right) From 97788bdfc34b6def29a2c1f85c911088d57e96c3 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 11 Apr 2025 18:17:59 +0100 Subject: [PATCH 07/48] Fix pointwise_logdensities, gut tilde_observe, remove resetlogp!! 
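The user-facing entry points keep their shape; only the bookkeeping moves into a pointwise accumulator. A usage sketch consistent with the docstrings updated below — the `demo` model here is illustrative and not part of the package:

```julia
using DynamicPPL, Distributions

@model function demo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    for i in eachindex(x)
        x[i] ~ Normal(m, sqrt(s))
    end
end

model = demo([1.0, 1.0])
vi = VarInfo(model)

# Per-variable log densities (prior and likelihood terms), keyed by VarName.
ℓ = pointwise_logdensities(model, vi)
ℓ[@varname(x[1])]

# Likelihood-only and prior-only variants reuse the same accumulator machinery.
pointwise_loglikelihoods(model, vi)
pointwise_prior_logdensities(model, vi)
```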
--- src/DynamicPPL.jl | 4 +- src/abstract_varinfo.jl | 14 +- src/accumulators.jl | 20 ++- src/compiler.jl | 2 +- src/context_implementations.jl | 61 ++------- src/debug_utils.jl | 16 +-- src/model.jl | 4 +- src/pointwise_logdensities.jl | 232 +++++++++++---------------------- src/simple_varinfo.jl | 2 +- src/test_utils/contexts.jl | 14 +- src/test_utils/models.jl | 28 ++-- src/threadsafe.jl | 17 ++- src/transforming.jl | 4 +- src/varinfo.jl | 12 +- test/contexts.jl | 6 +- test/pointwise_logdensities.jl | 2 +- test/simple_varinfo.jl | 4 +- test/threadsafe.jl | 12 +- 18 files changed, 181 insertions(+), 273 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index ed389be7a..e743bbd8b 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -52,7 +52,7 @@ export AbstractVarInfo, getlogp, setlogp!!, acclogp!!, - resetlogp!!, + resetaccs!!, get_num_produce, set_num_produce!!, reset_num_produce!!, @@ -95,9 +95,7 @@ export AbstractVarInfo, PrefixContext, ConditionContext, assume, - observe, tilde_assume, - tilde_observe, # Pseudo distributions NamedDist, NoDist, diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 963c41513..9cc51c916 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -132,8 +132,8 @@ function accumulate_assume!!(vi::AbstractVarInfo, r, logjac, vn, right) return setaccs!!(vi, accumulate_assume!!(getaccs(vi), r, logjac, vn, right)) end -function accumulate_observe!!(vi::AbstractVarInfo, left, right) - return setaccs!!(vi, accumulate_observe!!(getaccs(vi), left, right)) +function accumulate_observe!!(vi::AbstractVarInfo, right, left, vn) + return setaccs!!(vi, accumulate_observe!!(getaccs(vi), right, left, vn)) end function acc!!(vi::AbstractVarInfo, ::Type{AccType}, args...) where {AccType} @@ -162,7 +162,15 @@ acclogp!!(vi::AbstractVarInfo, logp) = accloglikelihood!!(vi, logp) Reset the value of the log of the joint probability of the observed data and parameters sampled in `vi` to 0, mutating if it makes sense. """ -resetlogp!!(vi::AbstractVarInfo) = setlogp!!(vi, zero(getlogp(vi))) +function resetaccs!!(vi::AbstractVarInfo) + accs = getaccs(vi) + for acc in accs + accs = setacc!!(accs, resetacc!!(acc)) + end + return setaccs!!(vi, accs) +end + +haslogp(vi::AbstractVarInfo) = hasacc(vi, LogPrior) || hasacc(vi, LogLikelihood) # Variables and their realizations. @doc """ diff --git a/src/accumulators.jl b/src/accumulators.jl index 293f03e77..4a94ca4dc 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -38,11 +38,13 @@ function getacc(at::AccumulatorTuple, ::Type{AccType}) where {AccType} end function accumulate_assume!!(at::AccumulatorTuple, r, logjac, vn, right) - return AccumulatorTuple(map(acc -> accumulate_assume!!(acc, r, logjac, vn, right), at.nt)) + return AccumulatorTuple( + map(acc -> accumulate_assume!!(acc, r, logjac, vn, right), at.nt) + ) end -function accumulate_observe!!(at::AccumulatorTuple, left, right) - return AccumulatorTuple(map(acc -> accumulate_observe!!(acc, left, right), at.nt)) +function accumulate_observe!!(at::AccumulatorTuple, right, left, vn) + return AccumulatorTuple(map(acc -> accumulate_observe!!(acc, right, left, vn), at.nt)) end function acc!!(at::AccumulatorTuple, ::Type{AccType}, args...) 
where {AccType} @@ -72,6 +74,12 @@ accumulator_name(::Type{<:LogPrior}) = :LogPrior accumulator_name(::Type{<:LogLikelihood}) = :LogLikelihood accumulator_name(::Type{<:NumProduce}) = :NumProduce +resetacc!!(acc::LogPrior) = LogPrior(zero(acc.logp)) +resetacc!!(acc::LogLikelihood) = LogLikelihood(zero(acc.logp)) +# TODO(mhauru) How to handle reset for NumProduce? Do we need to define different types of +# resets? +resetacc!!(acc::NumProduce) = acc + split(::LogPrior{T}) where {T} = LogPrior(zero(T)) split(::LogLikelihood{T}) where {T} = LogLikelihood(zero(T)) split(acc::NumProduce) = acc @@ -89,12 +97,12 @@ acc!!(acc::NumProduce, n) = NumProduce(acc.num + n) function accumulate_assume!!(acc::LogPrior, val, logjac, vn, right) return acc!!(acc, logpdf(right, val) + logjac) end -accumulate_observe!!(acc::LogPrior, left, right) = acc +accumulate_observe!!(acc::LogPrior, right, left, vn) = acc accumulate_assume!!(acc::LogLikelihood, val, logjac, vn, right) = acc -function accumulate_observe!!(acc::LogLikelihood, left, right) +function accumulate_observe!!(acc::LogLikelihood, right, left, vn) return acc!!(acc, logpdf(right, left)) end accumulate_assume!!(acc::NumProduce, val, logjac, vn, right) = acc -accumulate_observe!!(acc::NumProduce, left, right) = acc!!(acc, 1) +accumulate_observe!!(acc::NumProduce, right, left, vn) = acc!!(acc, 1) diff --git a/src/compiler.jl b/src/compiler.jl index 4771b0171..c9a79d715 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -438,7 +438,7 @@ function generate_tilde_literal(left, right) @gensym value return quote $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( - __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ + __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, nothing, __varinfo__ ) $value end diff --git a/src/context_implementations.jl b/src/context_implementations.jl index ce86f8cbd..f8efa87c1 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -79,34 +79,30 @@ end # observe """ - tilde_observe(context::SamplingContext, right, left, vi) + tilde_observe!!(context::SamplingContext, right, left, vi) Handle observed constants with a `context` associated with a sampler. -Falls back to `tilde_observe(context.context, context.sampler, right, left, vi)`. +Falls back to `tilde_observe!!(context.context, right, left, vi)`. """ -function tilde_observe(context::SamplingContext, right, left, vi) - return tilde_observe(context.context, context.sampler, right, left, vi) +function tilde_observe!!(context::SamplingContext, right, left, vn, vi) + return tilde_observe!!(context.context, right, left, vn, vi) end -function tilde_observe(context::AbstractContext, args...) - return tilde_observe(childcontext(context), args...) -end - -function tilde_observe(::DefaultContext, args...) - return observe(args...) +function tilde_observe!!(context::AbstractContext, right, left, vn, vi) + return tilde_observe!!(childcontext(context), right, left, vn, vi) end # `PrefixContext` -function tilde_observe(context::PrefixContext, right, left, vi) - return tilde_observe(context.context, right, left, vi) -end -function tilde_observe(context::PrefixContext, sampler, right, left, vi) - return tilde_observe(context.context, sampler, right, left, vi) +function tilde_observe!!(context::PrefixContext, right, left, vn, vi) + # In the observe case, unlike assume, `vn` may be `nothing` if the LHS is a literal + # value. + prefixed_varname = vn !== nothing ? 
prefix(context, vn) : vn + return tilde_observe!!(context.context, right, left, prefixed_varname, vi) end """ - tilde_observe!!(context, right, left, vname, vi) + tilde_observe!!(context, right, left, vn, vi) Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value and updated `vi`. @@ -114,31 +110,13 @@ accumulate the log probability, and return the observed value and updated `vi`. Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function tilde_observe!!(context, right, left, vname, vi) - is_rhs_model(right) && throw( - ArgumentError( - "`~` with a model on the right-hand side of an observe statement is not supported", - ), - ) - return tilde_observe!!(context, right, left, vi) -end - -""" - tilde_observe(context, right, left, vi) - -Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and -return the observed value. - -By default, calls `tilde_observe(context, right, left, vi)` and accumulates the log -probability of `vi` with the returned value. -""" -function tilde_observe!!(context, right, left, vi) +function tilde_observe!!(context::DefaultContext, right, left, vn, vi) is_rhs_model(right) && throw( ArgumentError( "`~` with a model on the right-hand side of an observe statement is not supported", ), ) - vi = tilde_observe(context, right, left, vi) + vi = accumulate_observe!!(vi, right, left, vn) return left, vi end @@ -146,10 +124,6 @@ function assume(rng::Random.AbstractRNG, spl::Sampler, dist) return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") end -function observe(spl::Sampler, weight) - return error("DynamicPPL.observe: unmanaged inference algorithm: $(typeof(spl))") -end - # fallback without sampler function assume(dist::Distribution, vn::VarName, vi) y = getindex_internal(vi, vn) @@ -201,10 +175,3 @@ function assume( vi = accumulate_assume!!(vi, r, -logjac, vn, dist) return r, vi end - -# default fallback (used e.g. by `SampleFromPrior` and `SampleUniform`) -observe(sampler::AbstractSampler, right, left, vi) = observe(right, left, vi) - -function observe(right::Distribution, left, vi) - return accumulate_observe!!(vi, left, right) -end diff --git a/src/debug_utils.jl b/src/debug_utils.jl index bff03386c..505383e24 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -274,9 +274,7 @@ function DynamicPPL.tilde_assume( rng::Random.AbstractRNG, context::DebugContext, sampler, right, vn, vi ) record_pre_tilde_assume!(context, vn, right, vi) - value, vi = DynamicPPL.tilde_assume( - rng, childcontext(context), sampler, right, vn, vi - ) + value, vi = DynamicPPL.tilde_assume(rng, childcontext(context), sampler, right, vn, vi) record_post_tilde_assume!(context, vn, right, value, vi) return value, vi end @@ -295,9 +293,7 @@ end function record_post_tilde_observe!(context::DebugContext, left, right, varinfo) stmt = ObserveStmt(; - left=left, - right=right, - varinfo=context.record_varinfo ? varinfo : nothing, + left=left, right=right, varinfo=context.record_varinfo ? 
varinfo : nothing ) if context.record_statements push!(context.statements, stmt) @@ -305,15 +301,15 @@ function record_post_tilde_observe!(context::DebugContext, left, right, varinfo) return nothing end -function DynamicPPL.tilde_observe(context::DebugContext, right, left, vi) +function DynamicPPL.tilde_observe!!(context::DebugContext, right, left, vn, vi) record_pre_tilde_observe!(context, left, right, vi) - vi = DynamicPPL.tilde_observe(childcontext(context), right, left, vi) + vi = DynamicPPL.tilde_observe!!(childcontext(context), right, left, vn, vi) record_post_tilde_observe!(context, left, right, vi) return vi end -function DynamicPPL.tilde_observe(context::DebugContext, sampler, right, left, vi) +function DynamicPPL.tilde_observe!!(context::DebugContext, sampler, right, left, vn, vi) record_pre_tilde_observe!(context, left, right, vi) - vi = DynamicPPL.tilde_observe(childcontext(context), sampler, right, left, vi) + vi = DynamicPPL.tilde_observe!!(childcontext(context), sampler, right, left, vn, vi) record_post_tilde_observe!(context, left, right, vi) return vi end diff --git a/src/model.jl b/src/model.jl index 2431935f0..7df61e939 100644 --- a/src/model.jl +++ b/src/model.jl @@ -882,7 +882,7 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadsafe!!`](@ref) """ function evaluate_threadunsafe!!(model, varinfo, context) - return _evaluate!!(model, resetlogp!!(varinfo), context) + return _evaluate!!(model, resetaccs!!(varinfo), context) end """ @@ -897,7 +897,7 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadunsafe!!`](@ref) """ function evaluate_threadsafe!!(model, varinfo, context) - wrapper = ThreadSafeVarInfo(resetlogp!!(varinfo)) + wrapper = ThreadSafeVarInfo(resetaccs!!(varinfo)) result, wrapper_new = _evaluate!!(model, wrapper, context) return result, setlogp!!(wrapper_new.varinfo, getlogp(wrapper_new)) end diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 83b0a7476..c856adffe 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -1,135 +1,68 @@ -# Context version -struct PointwiseLogdensityContext{A,Ctx} <: AbstractContext - logdensities::A - context::Ctx +struct PointwiseLogProbAccumulator{whichlogprob,KeyType,D<:AbstractDict{KeyType}} <: + AbstractAccumulator + logps::D end -function PointwiseLogdensityContext( - likelihoods=OrderedDict{VarName,Vector{Float64}}(), - context::AbstractContext=DefaultContext(), -) - return PointwiseLogdensityContext{typeof(likelihoods),typeof(context)}( - likelihoods, context - ) -end - -NodeTrait(::PointwiseLogdensityContext) = IsParent() -childcontext(context::PointwiseLogdensityContext) = context.context -function setchildcontext(context::PointwiseLogdensityContext, child) - return PointwiseLogdensityContext(context.logdensities, child) -end - -function Base.push!( - context::PointwiseLogdensityContext{<:AbstractDict{VarName,Vector{Float64}}}, - vn::VarName, - logp::Real, -) - lookup = context.logdensities - ℓ = get!(lookup, vn, Float64[]) - return push!(ℓ, logp) +function PointwiseLogProbAccumulator{whichlogprob}() where {whichlogprob} + return PointwiseLogProbAccumulator{whichlogprob,VarName}() end -function Base.push!( - context::PointwiseLogdensityContext{<:AbstractDict{VarName,Float64}}, - vn::VarName, - logp::Real, -) - return context.logdensities[vn] = logp +function PointwiseLogProbAccumulator{whichlogprob,KeyType}() where {whichlogprob,KeyType} + logps = 
OrderedDict{KeyType,Vector{LogProbType}}() + return PointwiseLogProbAccumulator{whichlogprob,KeyType,typeof(logps)}(logps) end -function Base.push!( - context::PointwiseLogdensityContext{<:AbstractDict{String,Vector{Float64}}}, - vn::VarName, - logp::Real, -) - lookup = context.logdensities - ℓ = get!(lookup, string(vn), Float64[]) - return push!(ℓ, logp) +function Base.push!(acc::PointwiseLogProbAccumulator, vn, logp) + logps = acc.logps + # The last(fieldtypes(eltype(...))) gets the type of the values, rather than the keys. + T = last(fieldtypes(eltype(logps))) + logpvec = get!(logps, vn, T()) + return push!(logpvec, logp) end function Base.push!( - context::PointwiseLogdensityContext{<:AbstractDict{String,Float64}}, - vn::VarName, - logp::Real, -) - return context.logdensities[string(vn)] = logp + acc::PointwiseLogProbAccumulator{whichlogprob,String}, vn::VarName, logp +) where {whichlogprob} + return push!(acc, string(vn), logp) end -function Base.push!( - context::PointwiseLogdensityContext{<:AbstractDict{String,Vector{Float64}}}, - vn::String, - logp::Real, -) - lookup = context.logdensities - ℓ = get!(lookup, vn, Float64[]) - return push!(ℓ, logp) +function accumulator_name( + ::Type{<:PointwiseLogProbAccumulator{whichlogprob}} +) where {whichlogprob} + return Symbol("PointwiseLogProbAccumulator{$whichlogprob}") end -function Base.push!( - context::PointwiseLogdensityContext{<:AbstractDict{String,Float64}}, - vn::String, - logp::Real, -) - return context.logdensities[vn] = logp -end +# TODO(mhauru) Implement these to make PointwiseLogProbAccumulator work with +# ThreadSafeVarInfo. +# split(::LogPrior{T}) where {T} = LogPrior(zero(T)) +# combine(acc::LogPrior, acc2::LogPrior) = LogPrior(acc.logp + acc2.logp) +# acc!!(acc::PointwiseLogPrior, logp) = LogPrior(acc.logp + logp) -function _include_prior(context::PointwiseLogdensityContext) - return leafcontext(context) isa Union{PriorContext,DefaultContext} -end -function _include_likelihood(context::PointwiseLogdensityContext) - return leafcontext(context) isa Union{LikelihoodContext,DefaultContext} -end +resetacc!!(acc::PointwiseLogProbAccumulator) = acc -function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vi) - # Defer literal `observe` to child-context. - return tilde_observe!!(context.context, right, left, vi) -end -function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi) - # Completely defer to child context if we are not tracking likelihoods. - if !(_include_likelihood(context)) - return tilde_observe!!(context.context, right, left, vn, vi) +function accumulate_assume!!( + acc::PointwiseLogProbAccumulator{whichlogprob}, val, logjac, vn, right +) where {whichlogprob} + if whichlogprob == :both || whichlogprob == :prior + subacc = accumulate_assume!!(LogPrior{LogProbType}(), val, logjac, vn, right) + push!(acc, vn, subacc.logp) end - - # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. - # we have to intercept the call to `tilde_observe!`. - logp, vi = tilde_observe(context.context, right, left, vi) - - # Track loglikelihood value. - push!(context, vn, logp) - - return left, acclogp!!(vi, logp) + return acc end -# Note on submodels (penelopeysm) -# -# We don't need to overload tilde_observe!! for Sampleables (yet), because it -# is currently not possible to evaluate a model with a Sampleable on the RHS -# of an observe statement. -# -# Note that calling tilde_assume!! 
on a Sampleable does not necessarily imply -# that there are no observe statements inside the Sampleable. There could well -# be likelihood terms in there, which must be included in the returned logp. -# See e.g. the `demo_dot_assume_observe_submodel` demo model. -# -# This is handled by passing the same context to rand_like!!, which figures out -# which terms to include using the context, and also mutates the context and vi -# appropriately. Thus, we don't need to check against _include_prior(context) -# here. -function tilde_assume!!(context::PointwiseLogdensityContext, right::Sampleable, vn, vi) - value, vi = DynamicPPL.rand_like!!(right, context, vi) - return value, vi -end - -function tilde_assume!!(context::PointwiseLogdensityContext, right, vn, vi) - !_include_prior(context) && return (tilde_assume!!(context.context, right, vn, vi)) - value, vi = tilde_assume(context.context, right, vn, vi) - # Track loglikelihood value. - # TODO(mhauru) logp here should be the logp that resulted from this tilde call. - # Implement this with a suitable accumulator. The current setting to zero is just to - # make this run, it produces nonsense results. - logp = zero(getlogjoint(vi)) - push!(context, vn, logp) - return value, acclogp!!(vi, logp) +function accumulate_observe!!( + acc::PointwiseLogProbAccumulator{whichlogprob}, right, left, vn +) where {whichlogprob} + # If `vn` is nothing the LHS of ~ is a literal and we don't have a name to attach this + # acc to, and thus do nothing. + if vn === nothing + return acc + end + if whichlogprob == :both || whichlogprob == :likelihood + subacc = accumulate_observe!!(LogLikelihood{LogProbType}(), right, left, vn) + push!(acc, vn, subacc.logp) + end + return acc end """ @@ -238,14 +171,19 @@ julia> m = demo([1.0; 1.0]); julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])])) (-1.4189385332046727, -1.4189385332046727) ``` - """ function pointwise_logdensities( - model::Model, chain, keytype::Type{T}=String, context::AbstractContext=DefaultContext() -) where {T} + model::Model, + chain, + ::Type{KeyType}=String, + context::AbstractContext=DefaultContext(), + ::Val{whichlogprob}=Val(:both), +) where {KeyType,whichlogprob} # Get the data by executing the model once vi = VarInfo(model) - point_context = PointwiseLogdensityContext(OrderedDict{T,Vector{Float64}}(), context) + + acctype = PointwiseLogProbAccumulator{whichlogprob,KeyType} + vi = setaccs!!(vi, AccumulatorTuple(acctype())) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) for (sample_idx, chain_idx) in iters @@ -253,26 +191,29 @@ function pointwise_logdensities( setval!(vi, chain, sample_idx, chain_idx) # Execute model - model(vi, point_context) + vi = last(evaluate!!(model, vi, context)) end + logps = getacc(vi, acctype).logps niters = size(chain, 1) nchains = size(chain, 3) logdensities = OrderedDict( - varname => reshape(logliks, niters, nchains) for - (varname, logliks) in point_context.logdensities + varname => reshape(vals, niters, nchains) for (varname, vals) in logps ) return logdensities end function pointwise_logdensities( - model::Model, varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext() -) - point_context = PointwiseLogdensityContext( - OrderedDict{VarName,Vector{Float64}}(), context - ) - model(varinfo, point_context) - return point_context.logdensities + model::Model, + varinfo::AbstractVarInfo, + context::AbstractContext=DefaultContext(), + ::Val{whichlogprob}=Val(:both), +) where {whichlogprob} + acctype = 
PointwiseLogProbAccumulator{whichlogprob} + # TODO(mhauru) Don't needlessly evaluate the model twice. + varinfo = setaccs!!(varinfo, AccumulatorTuple(acctype())) + varinfo = last(evaluate!!(model, varinfo, context)) + return getacc(varinfo, acctype).logps end """ @@ -284,26 +225,15 @@ including the likelihood terms. See also: [`pointwise_logdensities`](@ref). """ function pointwise_loglikelihoods( - model::Model, - chain, - keytype::Type{T}=String, - context::AbstractContext=LikelihoodContext(), + model::Model, chain, keytype::Type{T}=String, context::AbstractContext=DefaultContext() ) where {T} - if !(leafcontext(context) isa LikelihoodContext) - throw(ArgumentError("Leaf context should be a LikelihoodContext")) - end - - return pointwise_logdensities(model, chain, T, context) + return pointwise_logdensities(model, chain, T, context, Val(:likelihood)) end function pointwise_loglikelihoods( - model::Model, varinfo::AbstractVarInfo, context::AbstractContext=LikelihoodContext() + model::Model, varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext() ) - if !(leafcontext(context) isa LikelihoodContext) - throw(ArgumentError("Leaf context should be a LikelihoodContext")) - end - - return pointwise_logdensities(model, varinfo, context) + return pointwise_logdensities(model, varinfo, context, Val(:likelihood)) end """ @@ -315,21 +245,13 @@ including the prior terms. See also: [`pointwise_logdensities`](@ref). """ function pointwise_prior_logdensities( - model::Model, chain, keytype::Type{T}=String, context::AbstractContext=PriorContext() + model::Model, chain, keytype::Type{T}=String, context::AbstractContext=DefaultContext() ) where {T} - if !(leafcontext(context) isa PriorContext) - throw(ArgumentError("Leaf context should be a PriorContext")) - end - - return pointwise_logdensities(model, chain, T, context) + return pointwise_logdensities(model, chain, T, context, Val(:prior)) end function pointwise_prior_logdensities( - model::Model, varinfo::AbstractVarInfo, context::AbstractContext=PriorContext() + model::Model, varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext() ) - if !(leafcontext(context) isa PriorContext) - throw(ArgumentError("Leaf context should be a PriorContext")) - end - - return pointwise_logdensities(model, varinfo, context) + return pointwise_logdensities(model, varinfo, context, Val(:prior)) end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index fce1c7d44..e31d0337c 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -252,7 +252,7 @@ function unflatten(svi::SimpleVarInfo, x::AbstractVector) end function BangBang.empty!!(vi::SimpleVarInfo) - return resetlogp!!(Accessors.@set vi.values = empty!!(vi.values)) + return resetaccs!!(Accessors.@set vi.values = empty!!(vi.values)) end Base.isempty(vi::SimpleVarInfo) = isempty(vi.values) diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index 17c9a08fe..46b4e477d 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -5,7 +5,7 @@ """ Context that multiplies each log-prior by mod -used to test whether varwise_logpriors respects child-context. +used to test whether pointwise_logpriors respects child-context. 
""" struct TestLogModifyingChildContext{T,Ctx} <: DynamicPPL.AbstractContext mod::T @@ -23,12 +23,14 @@ function DynamicPPL.setchildcontext(context::TestLogModifyingChildContext, child return TestLogModifyingChildContext(context.mod, child) end function DynamicPPL.tilde_assume(context::TestLogModifyingChildContext, right, vn, vi) - value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi) - return value, logp * context.mod, vi + value, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi) + return value, vi end -function DynamicPPL.tilde_observe(context::TestLogModifyingChildContext, right, left, vi) - logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi) - return logp * context.mod, vi +function DynamicPPL.tilde_observe!!( + context::TestLogModifyingChildContext, right, left, vn, vi +) + vi = DynamicPPL.tilde_observe!!(context.context, right, left, vn, vi) + return vi end # Dummy context to test nested behaviors. diff --git a/src/test_utils/models.jl b/src/test_utils/models.jl index e29614982..90b6ac7ac 100644 --- a/src/test_utils/models.jl +++ b/src/test_utils/models.jl @@ -148,7 +148,7 @@ Simple model for which [`default_transformation`](@ref) returns a [`StaticTransf 1.5 ~ Normal(m, sqrt(s)) 2.0 ~ Normal(m, sqrt(s)) - return (; s, m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) + return (; s, m, x=[1.5, 2.0]) end function DynamicPPL.default_transformation(::Model{typeof(demo_static_transformation)}) @@ -194,7 +194,7 @@ end m ~ product_distribution(Normal.(0, sqrt.(s))) x ~ MvNormal(m, Diagonal(s)) - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_dot_assume_observe)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) @@ -225,7 +225,7 @@ end end x ~ MvNormal(m, Diagonal(s)) - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_assume_index_observe)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) @@ -248,7 +248,7 @@ end m ~ MvNormal(zero(x), Diagonal(s)) x ~ MvNormal(m, Diagonal(s)) - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_assume_multivariate_observe)}, s, m) s_dist = product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) @@ -279,7 +279,7 @@ end x[i] ~ Normal(m[i], sqrt(s[i])) end - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_dot_assume_observe_index)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) @@ -304,7 +304,7 @@ end m ~ Normal(0, sqrt(s)) x .~ Normal(m, sqrt(s)) - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_assume_dot_observe)}, s, m) return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) @@ -327,7 +327,7 @@ end m ~ MvNormal(zeros(2), Diagonal(s)) [1.5, 2.0] ~ MvNormal(m, Diagonal(s)) - return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0]) end function logprior_true(model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m) s_dist = product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) @@ -358,7 +358,7 @@ end 1.5 ~ Normal(m[1], sqrt(s[1])) 2.0 ~ Normal(m[2], sqrt(s[2])) - return (; s=s, m=m, x=[1.5, 2.0], 
logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0]) end function logprior_true(model::Model{typeof(demo_dot_assume_observe_index_literal)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) @@ -384,7 +384,7 @@ end 1.5 ~ Normal(m, sqrt(s)) 2.0 ~ Normal(m, sqrt(s)) - return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0]) end function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, s, m) return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) @@ -407,7 +407,7 @@ end m ~ Normal(0, sqrt(s)) [1.5, 2.0] .~ Normal(m, sqrt(s)) - return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0]) end function logprior_true(model::Model{typeof(demo_assume_dot_observe_literal)}, s, m) return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) @@ -440,7 +440,7 @@ end 1.5 ~ Normal(m[1], sqrt(s[1])) 2.0 ~ Normal(m[2], sqrt(s[2])) - return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0]) end function logprior_true( model::Model{typeof(demo_assume_submodel_observe_index_literal)}, s, m @@ -478,7 +478,7 @@ end # capture the result, so we just use a dummy variable _ignore ~ to_submodel(_likelihood_multivariate_observe(s, m, x)) - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) @@ -505,7 +505,7 @@ end x[:, 1] ~ MvNormal(m, Diagonal(s)) - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_dot_assume_observe_matrix_index)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) @@ -535,7 +535,7 @@ end x[:, 1] ~ MvNormal(m, Diagonal(s_vec)) - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_assume_matrix_observe_matrix_index)}, s, m) n = length(model.args.x) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index bed865526..ebc42e0fd 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -41,9 +41,9 @@ function accumulate_assume!!(vi::ThreadSafeVarInfo, r, logp, vn, right) return vi end -function accumulate_observe!!(vi::ThreadSafeVarInfo, left, right) +function accumulate_observe!!(vi::ThreadSafeVarInfo, right, left, vn) tid = Threads.threadid() - vi.accs_by_thread[tid] = accumulate_observe!!(vi.accs_by_thread[tid], left, right) + vi.accs_by_thread[tid] = accumulate_observe!!(vi.accs_by_thread[tid], right, left, vn) return vi end @@ -171,16 +171,15 @@ end isempty(vi::ThreadSafeVarInfo) = isempty(vi.varinfo) function BangBang.empty!!(vi::ThreadSafeVarInfo) - return resetlogp!!(Accessors.@set(vi.varinfo = empty!!(vi.varinfo))) + return resetaccs!!(Accessors.@set(vi.varinfo = empty!!(vi.varinfo))) end -function resetlogp!!(vi::ThreadSafeVarInfo) - vi = Accessors.@set vi.varinfo = resetlogp!!(vi.varinfo) - logprior = split(getacc(vi.varinfo, LogPrior)) - loglikelihood = split(getacc(vi.varinfo, LogLikelihood)) +function resetaccs!!(vi::ThreadSafeVarInfo) + vi = Accessors.@set vi.varinfo = resetaccs!!(vi.varinfo) for i in eachindex(vi.accs_by_thread) - vi.accs_by_thread[i] = setacc!!(vi.accs_by_thread[i], logprior) - vi.accs_by_thread[i] = setacc!!(vi.accs_by_thread[i], loglikelihood) + for acc in getaccs(vi.varinfo) + 
vi.accs_by_thread[i] = setacc!!(vi.accs_by_thread[i], split(acc)) + end end return vi end diff --git a/src/transforming.jl b/src/transforming.jl index 9d7e9e587..a9a75ebf5 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -31,8 +31,8 @@ function tilde_assume( return r, setindex!!(vi, r_transformed, vn) end -function tilde_observe(::DynamicTransformationContext, right, vn, vi) - return observe(right, vn, vi) +function tilde_observe!!(::DynamicTransformationContext, right, left, vn, vi) + return tilde_observe!!(DefaultContext(), right, left, vn, vi) end function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) diff --git a/src/varinfo.jl b/src/varinfo.jl index 58523ca2a..c8c5c2640 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -101,8 +101,14 @@ struct VarInfo{Tmeta,Accs<:AccumulatorTuple} <: AbstractVarInfo metadata::Tmeta accs::Accs end -VarInfo(meta=Metadata()) = VarInfo(meta, AccumulatorTuple(LogPrior{LogProbType}(), - LogLikelihood{LogProbType}(), NumProduce{Int}())) +function VarInfo(meta=Metadata()) + return VarInfo( + meta, + AccumulatorTuple( + LogPrior{LogProbType}(), LogLikelihood{LogProbType}(), NumProduce{Int}() + ), + ) +end """ VarInfo([rng, ]model[, sampler, context]) @@ -959,7 +965,7 @@ end function BangBang.empty!!(vi::VarInfo) _empty!(vi.metadata) - vi = resetlogp!!(vi) + vi = resetaccs!!(vi) vi = reset_num_produce!!(vi) return vi end diff --git a/test/contexts.jl b/test/contexts.jl index fdff5f6f7..b77658fd1 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -8,7 +8,6 @@ using DynamicPPL: NodeTrait, IsLeaf, IsParent, - PointwiseLogdensityContext, contextual_isassumption, ConditionContext, decondition_context, @@ -40,15 +39,12 @@ Base.IteratorSize(::Type{<:AbstractContext}) = Base.SizeUnknown() Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @testset "contexts.jl" begin - child_contexts = Dict( - :default => DefaultContext(), - ) + child_contexts = Dict(:default => DefaultContext()) parent_contexts = Dict( :testparent => DynamicPPL.TestUtils.TestParentContext(DefaultContext()), :sampling => SamplingContext(), :prefix => PrefixContext{:x}(DefaultContext()), - :pointwiselogdensity => PointwiseLogdensityContext(), :condition1 => ConditionContext((x=1.0,)), :condition2 => ConditionContext( (x=1.0,), DynamicPPL.TestUtils.TestParentContext(ConditionContext((y=2.0,))) diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index 61c842638..bb64d072f 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -1,7 +1,7 @@ @testset "logdensities_likelihoods.jl" begin mod_ctx = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2) mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx) - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS[1:1] example_values = DynamicPPL.TestUtils.rand_prior_true(model) # Instantiate a `VarInfo` with the example values. diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index c8994c5a8..e489ae88d 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -207,8 +207,8 @@ svi_eval = DynamicPPL.setindex!!(svi_eval, get(values_eval, vn), vn) end - # Reset the logp field. - svi_eval = DynamicPPL.resetlogp!!(svi_eval) + # Reset the logp accumulators. + svi_eval = DynamicPPL.resetaccs!!(svi_eval) # Compute `logjoint` using the varinfo. 
logπ = logjoint(model, svi_eval) diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 625b94c33..12df5706a 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -6,7 +6,9 @@ @test threadsafe_vi.varinfo === vi @test threadsafe_vi.accs_by_thread isa Vector{<:DynamicPPL.AccumulatorTuple} @test length(threadsafe_vi.accs_by_thread) == Threads.nthreads() - expected_accs = DynamicPPL.AccumulatorTuple((DynamicPPL.split(acc) for acc in vi.accs)...) + expected_accs = DynamicPPL.AccumulatorTuple( + (DynamicPPL.split(acc) for acc in vi.accs)... + ) @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) end @@ -25,12 +27,16 @@ threadsafe_vi = resetlogp!!(threadsafe_vi) @test iszero(getlogp(threadsafe_vi)) - expected_accs = DynamicPPL.AccumulatorTuple((DynamicPPL.split(acc) for acc in threadsafe_vi.varinfo.accs)...) + expected_accs = DynamicPPL.AccumulatorTuple( + (DynamicPPL.split(acc) for acc in threadsafe_vi.varinfo.accs)... + ) @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) threadsafe_vi = setlogp!!(threadsafe_vi, 42) @test getlogp(threadsafe_vi) == 42 - expected_accs = DynamicPPL.AccumulatorTuple((DynamicPPL.split(acc) for acc in threadsafe_vi.varinfo.accs)...) + expected_accs = DynamicPPL.AccumulatorTuple( + (DynamicPPL.split(acc) for acc in threadsafe_vi.varinfo.accs)... + ) @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) end From 7fe03ecd4989a536afe5849bd9e3aa84b1f86363 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 15 Apr 2025 12:23:36 +0100 Subject: [PATCH 08/48] Map rather than broadcast Co-authored-by: Tor Erlend Fjelde --- src/accumulators.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/accumulators.jl b/src/accumulators.jl index 4a94ca4dc..7b82b0ee9 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -16,7 +16,7 @@ struct AccumulatorTuple{N,T<:NamedTuple} nt::T function AccumulatorTuple(t::T) where {N,T<:NTuple{N,AbstractAccumulator}} - names = accumulator_name.(t) + names = map(accumulator_name, t) nt = NamedTuple{names}(t) return new{N,typeof(nt)}(nt) end From d49f7be25e35167c362b424c029bccfcce62278d Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 15 Apr 2025 14:46:38 +0100 Subject: [PATCH 09/48] Start documenting accumulators --- src/accumulators.jl | 44 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 41 insertions(+), 3 deletions(-) diff --git a/src/accumulators.jl b/src/accumulators.jl index 7b82b0ee9..3b3320414 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -1,3 +1,26 @@ +""" + AbstractAccumulator + +An abstract type for accumulators. + +An accumulator is an object that may change its value at every tilde_assume or tilde_observe +call based on the value of the random variable in question. The obvious examples of +accumulators or the log prior and log likelihood. Others examples might be a variable that +counts the number of observations in a trace, or the names of random variables seen so far. + +An accumulator must implement the following methods: +- `accumulator_name(acc::AbstractAccumulator)`: returns a Symbol by which accumulators of +this type are identified. This name is unique in the sense that a `VarInfo` can only have +one accumulator for each name. Often the name is just the name of the type. +- `accumulate_observe!!(acc::AbstractAccumulator, right, left, vn)`: updates `acc` based on +observing the random variable `vn` with value `left`, with `right` being the distribution on +the RHS of the tilde statement. 
`accumulate_observe!!` may mutate `acc`, but not any of the +other arguments. `vn` is `nothing` in the case of literal observations like +`0.0 ~ Normal()`. `accumulate_observe!!` is called within `tilde_observe!!` for each +accumulator in the current `VarInfo`. +- `accumulate_assume!!(acc::AbstractAccumulator, val, logjac, vn, right)`: updates `acc` +at when a `tilde_assume` call is made. `vn` is the name of the variable being assumed +""" abstract type AbstractAccumulator end accumulator_name(acc::AbstractAccumulator) = accumulator_name(typeof(acc)) @@ -5,12 +28,19 @@ accumulator_name(acc::AbstractAccumulator) = accumulator_name(typeof(acc)) """ AccumulatorTuple{N,T<:NamedTuple} -A collection of accumulators, stored as a `NamedTuple`. +A collection of accumulators, stored as a `NamedTuple` of length `N` This is defined as a separate type to be able to dispatch on it cleanly and without method ambiguities or conflicts with other `NamedTuple` types. We also use this type to enforce the constraint that the name in the tuple for each accumulator `acc` must be -`accumulator_name(acc)`. +`accumulator_name(acc)`, and these names must be unique. + +The constructor can be called with a tuple or a `VarArgs` of `AbstractAccumulators`. The +names will be generated automatically. One can also call the constructor with a `NamedTuple` +but the names in the argument will be discarded in favour of the generated ones. + +# Fields +$(TYPEDFIELDS) """ struct AccumulatorTuple{N,T<:NamedTuple} nt::T @@ -29,7 +59,15 @@ Base.getindex(at::AccumulatorTuple, idx) = at.nt[idx] Base.length(::AccumulatorTuple{N}) where {N} = N Base.iterate(at::AccumulatorTuple, args...) = iterate(at.nt, args...) -function setacc!!(at::AccumulatorTuple, acc::AbstractAccumulator) +""" + setacc(at::AccumulatorTuple, acc::AbstractAccumulator) + +Add `acc` to `at`. Returns a new `AccumulatorTuple`. + +If an `AbstractAccumulator` with the same `accumulator_name` already exists in `at` it is +replaced. +""" +function setacc(at::AccumulatorTuple, acc::AbstractAccumulator) return Accessors.@set at.nt[accumulator_name(acc)] = acc end From 28bbf1c453a911de692a176bb0a5b280f10ebf95 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 15 Apr 2025 15:00:48 +0100 Subject: [PATCH 10/48] Use Val{symbols} instead of AccTypes to index --- src/abstract_varinfo.jl | 18 +++++++++--------- src/accumulators.jl | 14 +++++++------- src/pointwise_logdensities.jl | 4 ++-- src/threadsafe.jl | 10 +++++----- src/varinfo.jl | 2 +- 5 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 9cc51c916..e5a773b20 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -102,8 +102,8 @@ getlogp(vi::AbstractVarInfo) = getlogjoint(vi) function setaccs!! 
end function getaccs end -getlogprior(vi::AbstractVarInfo) = getacc(vi, LogPrior).logp -getloglikelihood(vi::AbstractVarInfo) = getacc(vi, LogLikelihood).logp +getlogprior(vi::AbstractVarInfo) = getacc(vi, Val(:LogPrior)).logp +getloglikelihood(vi::AbstractVarInfo) = getacc(vi, Val(:LogLikelihood)).logp function setacc!!(vi::AbstractVarInfo, acc::AbstractAccumulator) return setaccs!!(vi, setacc!!(getaccs(vi), acc)) @@ -124,8 +124,8 @@ function setlogp!!(vi::AbstractVarInfo, logp) return vi end -function getacc(vi::AbstractVarInfo, ::Type{AccType}) where {AccType} - return getacc(getaccs(vi), AccType) +function getacc(vi::AbstractVarInfo, accname) + return getacc(getaccs(vi), accname) end function accumulate_assume!!(vi::AbstractVarInfo, r, logjac, vn, right) @@ -136,16 +136,16 @@ function accumulate_observe!!(vi::AbstractVarInfo, right, left, vn) return setaccs!!(vi, accumulate_observe!!(getaccs(vi), right, left, vn)) end -function acc!!(vi::AbstractVarInfo, ::Type{AccType}, args...) where {AccType} - return setaccs!!(vi, acc!!(getaccs(vi), AccType, args...)) +function acc!!(vi::AbstractVarInfo, accname, args...) + return setaccs!!(vi, acc!!(getaccs(vi), accname, args...)) end function acclogprior!!(vi::AbstractVarInfo, logp) - return acc!!(vi, LogPrior, logp) + return acc!!(vi, Val(:LogPrior), logp) end function accloglikelihood!!(vi::AbstractVarInfo, logp) - return acc!!(vi, LogLikelihood, logp) + return acc!!(vi, Val(:LogLikelihood), logp) end """ @@ -170,7 +170,7 @@ function resetaccs!!(vi::AbstractVarInfo) return setaccs!!(vi, accs) end -haslogp(vi::AbstractVarInfo) = hasacc(vi, LogPrior) || hasacc(vi, LogLikelihood) +haslogp(vi::AbstractVarInfo) = hasacc(vi, Val(:LogPrior)) || hasacc(vi, Val(:LogLikelihood)) # Variables and their realizations. @doc """ diff --git a/src/accumulators.jl b/src/accumulators.jl index 3b3320414..383b7aa3b 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -60,19 +60,20 @@ Base.length(::AccumulatorTuple{N}) where {N} = N Base.iterate(at::AccumulatorTuple, args...) = iterate(at.nt, args...) """ - setacc(at::AccumulatorTuple, acc::AbstractAccumulator) + setacc!!(at::AccumulatorTuple, acc::AbstractAccumulator) Add `acc` to `at`. Returns a new `AccumulatorTuple`. If an `AbstractAccumulator` with the same `accumulator_name` already exists in `at` it is -replaced. +replaced. `at` will never be mutated, but the name has the `!!` for consistency with the +corresponding function for `AbstractVarInfo`. """ -function setacc(at::AccumulatorTuple, acc::AbstractAccumulator) +function setacc!!(at::AccumulatorTuple, acc::AbstractAccumulator) return Accessors.@set at.nt[accumulator_name(acc)] = acc end -function getacc(at::AccumulatorTuple, ::Type{AccType}) where {AccType} - return at[accumulator_name(AccType)] +function getacc(at::AccumulatorTuple, ::Val{accname}) where {accname} + return at[accname] end function accumulate_assume!!(at::AccumulatorTuple, r, logjac, vn, right) @@ -85,8 +86,7 @@ function accumulate_observe!!(at::AccumulatorTuple, right, left, vn) return AccumulatorTuple(map(acc -> accumulate_observe!!(acc, right, left, vn), at.nt)) end -function acc!!(at::AccumulatorTuple, ::Type{AccType}, args...) where {AccType} - accname = accumulator_name(AccType) +function acc!!(at::AccumulatorTuple, ::Val{accname}, args...) where {accname} return Accessors.@set at.nt[accname] = acc!!(at[accname], args...) 
end diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index c856adffe..8779460a2 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -194,7 +194,7 @@ function pointwise_logdensities( vi = last(evaluate!!(model, vi, context)) end - logps = getacc(vi, acctype).logps + logps = getacc(vi, Val(accumulator_name(acctype))).logps niters = size(chain, 1) nchains = size(chain, 3) logdensities = OrderedDict( @@ -213,7 +213,7 @@ function pointwise_logdensities( # TODO(mhauru) Don't needlessly evaluate the model twice. varinfo = setaccs!!(varinfo, AccumulatorTuple(acctype())) varinfo = last(evaluate!!(model, varinfo, context)) - return getacc(varinfo, acctype).logps + return getacc(varinfo, Val(accumulator_name(acctype))).logps end """ diff --git a/src/threadsafe.jl b/src/threadsafe.jl index ebc42e0fd..e29f50b73 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -28,9 +28,9 @@ end # Get both the main accumulator and the thread-specific accumulators of the same type and # combine them. -function getacc(vi::ThreadSafeVarInfo, ::Type{AccType}) where {AccType} - main_acc = getacc(vi.varinfo, AccType) - other_accs = map(accs -> getacc(accs, AccType), vi.accs_by_thread) +function getacc(vi::ThreadSafeVarInfo, accname) + main_acc = getacc(vi.varinfo, accname) + other_accs = map(accs -> getacc(accs, accname), vi.accs_by_thread) return foldl(combine, other_accs; init=main_acc) end @@ -47,9 +47,9 @@ function accumulate_observe!!(vi::ThreadSafeVarInfo, right, left, vn) return vi end -function acc!!(vi::ThreadSafeVarInfo, ::Type{AccType}, args...) where {AccType} +function acc!!(vi::ThreadSafeVarInfo, accname, args...) tid = Threads.threadid() - vi.accs_by_thread[tid] = acc!!(vi.accs_by_thread[tid], AccType, args...) + vi.accs_by_thread[tid] = acc!!(vi.accs_by_thread[tid], accname, args...) return vi end diff --git a/src/varinfo.jl b/src/varinfo.jl index c8c5c2640..e4ae7a13a 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1008,7 +1008,7 @@ setaccs!!(vi::VarInfo, accs) = Accessors.@set vi.accs = accs Return the `num_produce` of `vi`. """ -get_num_produce(vi::VarInfo) = getacc(vi, NumProduce).num +get_num_produce(vi::VarInfo) = getacc(vi, Val(:NumProduce)).num """ set_num_produce!!(vi::VarInfo, n::Int) From a0ed6657aab05e866fec0ba38f174cb760a97af0 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 15 Apr 2025 15:33:22 +0100 Subject: [PATCH 11/48] More documentation for accumulators --- src/accumulators.jl | 170 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 155 insertions(+), 15 deletions(-) diff --git a/src/accumulators.jl b/src/accumulators.jl index 383b7aa3b..9a1cd5500 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -5,26 +5,106 @@ An abstract type for accumulators. An accumulator is an object that may change its value at every tilde_assume or tilde_observe call based on the value of the random variable in question. The obvious examples of -accumulators or the log prior and log likelihood. Others examples might be a variable that -counts the number of observations in a trace, or the names of random variables seen so far. - -An accumulator must implement the following methods: -- `accumulator_name(acc::AbstractAccumulator)`: returns a Symbol by which accumulators of -this type are identified. This name is unique in the sense that a `VarInfo` can only have -one accumulator for each name. Often the name is just the name of the type. 
-- `accumulate_observe!!(acc::AbstractAccumulator, right, left, vn)`: updates `acc` based on -observing the random variable `vn` with value `left`, with `right` being the distribution on -the RHS of the tilde statement. `accumulate_observe!!` may mutate `acc`, but not any of the -other arguments. `vn` is `nothing` in the case of literal observations like -`0.0 ~ Normal()`. `accumulate_observe!!` is called within `tilde_observe!!` for each -accumulator in the current `VarInfo`. -- `accumulate_assume!!(acc::AbstractAccumulator, val, logjac, vn, right)`: updates `acc` -at when a `tilde_assume` call is made. `vn` is the name of the variable being assumed +accumulators are the log prior and log likelihood. Others examples might be a variable that +counts the number of observations in a trace, or a list of the names of random variables +seen so far. + +An accumulator type `T` must implement the following methods: +- `accumulator_name(acc::T)` +- `accumulate_observe!!(acc::T, right, left, vn)` +- `accumulate_assume!!(acc::T, val, logjac, vn, right)` + +To be able to work with multi-threading, it should also implement: +- `split(acc::T)` +- `combine(acc::T, acc2::T)` + +It may also want to implement +- `acc!!(acc::T, args...)` + +See the documentation for each of these functions for more details. """ abstract type AbstractAccumulator end +# TODO(mhauru) Add to the above docstring stuff about resets. + +""" + accumulator_name(acc::AbstractAccumulator) + +Return a Symbol which can be used as a name for `acc`. + +The name has to be unique in the sense that a `VarInfo` can only have one accumulator for +each name. The most typical case, and the default implementation, is that the name only +depends on the type of `acc`, not on its value. +""" accumulator_name(acc::AbstractAccumulator) = accumulator_name(typeof(acc)) +""" + accumulate_observe!!(acc::AbstractAccumulator, right, left, vn) + +Update `acc` in a `tilde_observe` call. Returns the updated `acc`. + +`vn` is the name of the variable being observed, `left` is the value of the variable, and +`right` is the distribution on the RHS of the tilde statement. `vn` is `nothing` in the case +of literal observations like `0.0 ~ Normal()`. + +`accumulate_observe!!` may mutate `acc`, but not any of the other arguments. + +See also: [`accumulate_assume!!`](@ref) +""" +function accumulate_observe!! end + +""" + accumulate_assume!!(acc::AbstractAccumulator, val, logjac, vn, right) + +Update `acc` in a `tilde_assume` call. Returns the updated `acc`. + +`vn` is the name of the variable being assumed, `val` is the value of the variable, and +`right` is the distribution on the RHS of the tilde statement. `logjac` is the log +determinant of the Jacobian of the transformation that was done to convert the value of `vn` +as it was given (e.g. by sampler operating in linked space) to `val`. + +`accumulate_assume!!` may mutate `acc`, but not any of the other arguments. + +See also: [`accumulate_observe!!`](@ref) +""" +function accumulate_assume!! end + +""" + split(acc::AbstractAccumulator) + +Return a new accumulator like `acc` but empty. + +The precise meaning of "empty" is that that the returned value should be such that +`combine(acc, split(acc))` is equal to `acc`. This is used in the context of multi-threading +where different threads may accumulate independently and the results are the combined. 
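A minimal sketch of the `combine(acc, split(acc))` contract described above, using the log-likelihood accumulator defined later in this file; `DynamicPPL.split`, `DynamicPPL.combine`, and `DynamicPPL.LogLikelihood` are internal names, written fully qualified here because they are not exported.

```julia
acc = DynamicPPL.LogLikelihood(-3.2)
emptied = DynamicPPL.split(acc)           # a fresh, zeroed accumulator of the same type
DynamicPPL.combine(acc, emptied) == acc   # true: combining with a freshly split copy is a no-op
```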
+ +See also: [`combine`](@ref) +""" +function split end + +""" + combine(acc::AbstractAccumulator, acc2::AbstractAccumulator) + +Combine two accumulators of the same type. Returns a new accumulator. + +See also: [`split`](@ref) +""" +function combine end + +""" + acc!!(acc::AbstractAccumulator, args...) + +Update `acc` with the values in `args`. Returns the updated `acc`. + +What this means depends greatly on the type of `acc`. For example, for `LogPrior` `args` +would be just `logp`. The utility of this function is that one can call +`acc!!(varinfo::AbstractVarinfo, Val(accname), args...)`, and this call will be propagated +to a call on the particular accumulator. +""" +function acc!! end + +# END ABSTRACT ACCUMULATOR, BEGIN ACCUMULATOR TUPLE + """ AccumulatorTuple{N,T<:NamedTuple} @@ -72,16 +152,35 @@ function setacc!!(at::AccumulatorTuple, acc::AbstractAccumulator) return Accessors.@set at.nt[accumulator_name(acc)] = acc end +""" + getacc(at::AccumulatorTuple, ::Val{accname}) + +Get the accumulator with name `accname` from `at`. +""" function getacc(at::AccumulatorTuple, ::Val{accname}) where {accname} return at[accname] end +""" + accumulate_assume!!(at::AccumulatorTuple, r, logjac, vn, right) + +Call `accumulate_assume!!` on each accumulator in `at`. + +Returns a new AccumulatorTuple. +""" function accumulate_assume!!(at::AccumulatorTuple, r, logjac, vn, right) return AccumulatorTuple( map(acc -> accumulate_assume!!(acc, r, logjac, vn, right), at.nt) ) end +""" + accumulate_observe!!(at::AccumulatorTuple, right, left, vn) + +Call `accumulate_observe!!` on each accumulator in `at`. + +Returns a new AccumulatorTuple. +""" function accumulate_observe!!(at::AccumulatorTuple, right, left, vn) return AccumulatorTuple(map(acc -> accumulate_observe!!(acc, right, left, vn), at.nt)) end @@ -90,22 +189,63 @@ function acc!!(at::AccumulatorTuple, ::Val{accname}, args...) where {accname} return Accessors.@set at.nt[accname] = acc!!(at[accname], args...) end +# END ACCUMULATOR TUPLE, BEGIN LOG PROB AND NUM PRODUCE ACCUMULATORS + +""" + LogPrior{T} <: AbstractAccumulator + +An accumulator that tracks the cumulative log prior during model execution. + +# Fields +$(TYPEDFIELDS) +""" struct LogPrior{T} <: AbstractAccumulator logp::T end +""" + LogPrior{T}() where {T} + +Create a new `LogPrior` accumulator with the log prior initialized to zero. +""" LogPrior{T}() where {T} = LogPrior(zero(T)) +""" + LogLikelihood{T} <: AbstractAccumulator + +An accumulator that tracks the cumulative log likelihood during model execution. + +# Fields +$(TYPEDFIELDS) +""" struct LogLikelihood{T} <: AbstractAccumulator logp::T end +""" + LogLikelihood{T}() where {T} + +Create a new `LogLikelihood` accumulator with the log likelihood initialized to zero. +""" LogLikelihood{T}() where {T} = LogLikelihood(zero(T)) +""" + NumProduce{T} <: AbstractAccumulator + +An accumulator that tracks the number of observations during model execution. + +# Fields +$(TYPEDFIELDS) +""" struct NumProduce{T<:Integer} <: AbstractAccumulator num::T end +""" + NumProduce{T}() where {T<:Integer} + +Create a new `NumProduce` accumulator with the number of observations initialized to zero. 
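The three accumulators documented above are the defaults attached to every `VarInfo`. Below is a rough sketch of bundling them into an `AccumulatorTuple` by hand and looking one up by name; all names are internal to DynamicPPL, and the concrete element types are an assumption made only for illustration.

```julia
accs = DynamicPPL.AccumulatorTuple(
    DynamicPPL.LogPrior{Float64}(),
    DynamicPPL.LogLikelihood{Float64}(),
    DynamicPPL.NumProduce{Int}(),
)
DynamicPPL.getacc(accs, Val(:LogPrior)).logp   # 0.0, the freshly initialised log prior
length(accs)                                   # 3
```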
+""" NumProduce{T}() where {T} = NumProduce(zero(T)) accumulator_name(::Type{<:LogPrior}) = :LogPrior From be2763633a8c47103cf4943d41d893849d22b6ec Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 16 Apr 2025 13:28:50 +0100 Subject: [PATCH 12/48] Link varinfo by default in AD testing utilities; make test suite run on linked varinfos (#890) * Link VarInfo by default * Tweak interface * Fix tests * Fix interface so that callers can inspect results * Document * Fix tests * Fix changelog * Test linked varinfos Closes #891 * Fix docstring + use AbstractFloat --- HISTORY.md | 12 +++++ docs/src/api.md | 1 + src/test_utils/ad.jl | 104 ++++++++++++++++++++++++++--------------- src/transforming.jl | 4 +- test/ad.jl | 18 +++---- test/simple_varinfo.jl | 6 --- 6 files changed, 91 insertions(+), 54 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index a21258ec0..a45644a64 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,6 +4,18 @@ **Breaking changes** +### AD testing utilities + +`DynamicPPL.TestUtils.AD.run_ad` now links the VarInfo by default. +To disable this, pass the `linked=false` keyword argument. +If the calculated value or gradient is incorrect, it also throws a `DynamicPPL.TestUtils.AD.ADIncorrectException` rather than a test failure. +This exception contains the actual and expected gradient so you can inspect it if needed; see the documentation for more information. +From a practical perspective, this means that if you need to add this to a test suite, you need to use `@test run_ad(...) isa Any` rather than just `run_ad(...)`. + +### SimpleVarInfo linking / invlinking + +Linking a linked SimpleVarInfo, or invlinking an unlinked SimpleVarInfo, now displays a warning instead of an error. + ### VarInfo constructors `VarInfo(vi::VarInfo, values)` has been removed. You can replace this directly with `unflatten(vi, values)` instead. diff --git a/docs/src/api.md b/docs/src/api.md index 2c61f54fc..ec741c9ad 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -212,6 +212,7 @@ To test and/or benchmark the performance of an AD backend on a model, DynamicPPL ```@docs DynamicPPL.TestUtils.AD.run_ad DynamicPPL.TestUtils.AD.ADResult +DynamicPPL.TestUtils.AD.ADIncorrectException ``` ## Demo models diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index 06c76df5e..d38915c12 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -4,19 +4,13 @@ using ADTypes: AbstractADType, AutoForwardDiff using Chairmarks: @be import DifferentiationInterface as DI using DocStringExtensions -using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo +using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, link using LogDensityProblems: logdensity, logdensity_and_gradient using Random: Random, Xoshiro using Statistics: median using Test: @test -export ADResult, run_ad - -# This function needed to work around the fact that different backends can -# return different AbstractArrays for the gradient. See -# https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 for more -# context. -_to_vec_f64(x::AbstractArray) = x isa Vector{Float64} ? x : collect(Float64, x) +export ADResult, run_ad, ADIncorrectException """ REFERENCE_ADTYPE @@ -27,33 +21,50 @@ it's the default AD backend used in Turing.jl. const REFERENCE_ADTYPE = AutoForwardDiff() """ - ADResult + ADIncorrectException{T<:AbstractFloat} + +Exception thrown when an AD backend returns an incorrect value or gradient. + +The type parameter `T` is the numeric type of the value and gradient. 
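A short sketch of how this exception is expected to surface from `run_ad`. The demo model is invented, and checking ForwardDiff against the ForwardDiff reference should never actually throw; the `catch` branch only shows which fields one would inspect.

```julia
using DynamicPPL, Distributions, ADTypes
import ForwardDiff
using DynamicPPL.TestUtils.AD: run_ad, ADIncorrectException

@model demo_ad() = x ~ Normal()

try
    result = run_ad(demo_ad(), AutoForwardDiff(); benchmark=false)
    @info "AD check passed" result.value_actual result.grad_actual
catch err
    err isa ADIncorrectException || rethrow()
    @warn "AD backend disagreed with the reference" err.value_expected err.value_actual
end
```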
+""" +struct ADIncorrectException{T<:AbstractFloat} <: Exception + value_expected::T + value_actual::T + grad_expected::Vector{T} + grad_actual::Vector{T} +end + +""" + ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat} Data structure to store the results of the AD correctness test. + +The type parameter `Tparams` is the numeric type of the parameters passed in; +`Tresult` is the type of the value and the gradient. """ -struct ADResult +struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat} "The DynamicPPL model that was tested" model::Model "The VarInfo that was used" varinfo::AbstractVarInfo "The values at which the model was evaluated" - params::Vector{<:Real} + params::Vector{Tparams} "The AD backend that was tested" adtype::AbstractADType "The absolute tolerance for the value of logp" - value_atol::Real + value_atol::Tresult "The absolute tolerance for the gradient of logp" - grad_atol::Real + grad_atol::Tresult "The expected value of logp" - value_expected::Union{Nothing,Float64} + value_expected::Union{Nothing,Tresult} "The expected gradient of logp" - grad_expected::Union{Nothing,Vector{Float64}} + grad_expected::Union{Nothing,Vector{Tresult}} "The value of logp (calculated using `adtype`)" - value_actual::Union{Nothing,Real} + value_actual::Union{Nothing,Tresult} "The gradient of logp (calculated using `adtype`)" - grad_actual::Union{Nothing,Vector{Float64}} + grad_actual::Union{Nothing,Vector{Tresult}} "If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself" - time_vs_primal::Union{Nothing,Float64} + time_vs_primal::Union{Nothing,Tresult} end """ @@ -64,26 +75,27 @@ end benchmark=false, value_atol=1e-6, grad_atol=1e-6, - varinfo::AbstractVarInfo=VarInfo(model), - params::Vector{<:Real}=varinfo[:], + varinfo::AbstractVarInfo=link(VarInfo(model), model), + params::Union{Nothing,Vector{<:AbstractFloat}}=nothing, reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE, - expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing, + expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing, verbose=true, )::ADResult +### Description + Test the correctness and/or benchmark the AD backend `adtype` for the model `model`. Whether to test and benchmark is controlled by the `test` and `benchmark` keyword arguments. By default, `test` is `true` and `benchmark` is `false`. -Returns an [`ADResult`](@ref) object, which contains the results of the -test and/or benchmark. - Note that to run AD successfully you will need to import the AD backend itself. For example, to test with `AutoReverseDiff()` you will need to run `import ReverseDiff`. +### Arguments + There are two positional arguments, which absolutely must be provided: 1. `model` - The model being tested. @@ -96,7 +108,9 @@ Everything else is optional, and can be categorised into several groups: DynamicPPL contains several different types of VarInfo objects which change the way model evaluation occurs. If you want to use a specific type of VarInfo, pass it as the `varinfo` argument. Otherwise, it will default to - using a `TypedVarInfo` generated from the model. + using a linked `TypedVarInfo` generated from the model. Here, _linked_ + means that the parameters in the VarInfo have been transformed to + unconstrained Euclidean space if they aren't already in that space. 2. 
_How to specify the parameters._ @@ -140,27 +154,40 @@ Everything else is optional, and can be categorised into several groups: By default, this function prints messages when it runs. To silence it, set `verbose=false`. + +### Returns / Throws + +Returns an [`ADResult`](@ref) object, which contains the results of the +test and/or benchmark. + +If `test` is `true` and the AD backend returns an incorrect value or gradient, an +`ADIncorrectException` is thrown. If a different error occurs, it will be +thrown as-is. """ function run_ad( model::Model, adtype::AbstractADType; - test=true, - benchmark=false, - value_atol=1e-6, - grad_atol=1e-6, - varinfo::AbstractVarInfo=VarInfo(model), - params::Vector{<:Real}=varinfo[:], + test::Bool=true, + benchmark::Bool=false, + value_atol::AbstractFloat=1e-6, + grad_atol::AbstractFloat=1e-6, + varinfo::AbstractVarInfo=link(VarInfo(model), model), + params::Union{Nothing,Vector{<:AbstractFloat}}=nothing, reference_adtype::AbstractADType=REFERENCE_ADTYPE, - expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing, + expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing, verbose=true, )::ADResult + if isnothing(params) + params = varinfo[:] + end + params = map(identity, params) # Concretise + verbose && @info "Running AD on $(model.f) with $(adtype)\n" - params = map(identity, params) verbose && println(" params : $(params)") ldf = LogDensityFunction(model, varinfo; adtype=adtype) value, grad = logdensity_and_gradient(ldf, params) - grad = _to_vec_f64(grad) + grad = collect(grad) verbose && println(" actual : $((value, grad))") if test @@ -172,10 +199,11 @@ function run_ad( expected_value_and_grad end verbose && println(" expected : $((value_true, grad_true))") - grad_true = _to_vec_f64(grad_true) - # Then compare - @test isapprox(value, value_true; atol=value_atol) - @test isapprox(grad, grad_true; atol=grad_atol) + grad_true = collect(grad_true) + + exc() = throw(ADIncorrectException(value, value_true, grad, grad_true)) + isapprox(value, value_true; atol=value_atol) || exc() + isapprox(grad, grad_true; atol=grad_atol) || exc() else value_true = nothing grad_true = nothing diff --git a/src/transforming.jl b/src/transforming.jl index 0239725ae..429562ec8 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -19,9 +19,9 @@ function tilde_assume( lp = Bijectors.logpdf_with_trans(right, r, !isinverse) if istrans(vi, vn) - @assert isinverse "Trying to link already transformed variables" + isinverse || @warn "Trying to link an already transformed variable ($vn)" else - @assert !isinverse "Trying to invlink non-transformed variables" + isinverse && @warn "Trying to invlink a non-transformed variable ($vn)" end # Only transform if `!isinverse` since `vi[vn, right]` diff --git a/test/ad.jl b/test/ad.jl index 33d581228..69ab99e19 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -23,21 +23,23 @@ using DynamicPPL: LogDensityFunction varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - f = LogDensityFunction(m, varinfo) + linked_varinfo = DynamicPPL.link(varinfo, m) + f = LogDensityFunction(m, linked_varinfo) x = DynamicPPL.getparams(f) # Calculate reference logp + gradient of logp using ForwardDiff - ref_ldf = LogDensityFunction(m, varinfo; adtype=ref_adtype) + ref_ldf = LogDensityFunction(m, linked_varinfo; adtype=ref_adtype) ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x) @testset "$adtype" 
for adtype in test_adtypes - @info "Testing AD on: $(m.f) - $(short_varinfo_name(varinfo)) - $adtype" + @info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype" # Put predicates here to avoid long lines is_mooncake = adtype isa AutoMooncake is_1_10 = v"1.10" <= VERSION < v"1.11" is_1_11 = v"1.11" <= VERSION < v"1.12" - is_svi_vnv = varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector} - is_svi_od = varinfo isa SimpleVarInfo{<:OrderedDict} + is_svi_vnv = + linked_varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector} + is_svi_od = linked_varinfo isa SimpleVarInfo{<:OrderedDict} # Mooncake doesn't work with several combinations of SimpleVarInfo. if is_mooncake && is_1_11 && is_svi_vnv @@ -56,12 +58,12 @@ using DynamicPPL: LogDensityFunction ref_ldf, adtype ) else - DynamicPPL.TestUtils.AD.run_ad( + @test DynamicPPL.TestUtils.AD.run_ad( m, adtype; - varinfo=varinfo, + varinfo=linked_varinfo, expected_value_and_grad=(ref_logp, ref_grad), - ) + ) isa Any end end end diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index aa3b592f7..380c24e7d 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -111,12 +111,6 @@ # Should be approx. the same as the "lazy" transformation. @test logjoint(model, vi_linked) ≈ lp_linked - # TODO: Should not `VarInfo` also error here? The current implementation - # only warns and acts as a no-op. - if vi isa SimpleVarInfo - @test_throws AssertionError link!!(vi_linked, model) - end - # `invlink!!` vi_invlinked = invlink!!(deepcopy(vi_linked), model) lp_invlinked = getlogp(vi_invlinked) From e6453feb0e80cb3332d597d71fad9266e940f36e Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 16 Apr 2025 14:32:56 +0100 Subject: [PATCH 13/48] Fix resetlogp!! and type stability for accumulators --- src/DynamicPPL.jl | 2 +- src/abstract_varinfo.jl | 164 +++++++++++++++++++++++++++++----- src/accumulators.jl | 67 +++++++------- src/model.jl | 4 +- src/pointwise_logdensities.jl | 3 - src/simple_varinfo.jl | 2 +- src/threadsafe.jl | 32 ++++--- src/varinfo.jl | 9 +- test/simple_varinfo.jl | 2 +- 9 files changed, 205 insertions(+), 80 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index e743bbd8b..cfb730f52 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -52,7 +52,7 @@ export AbstractVarInfo, getlogp, setlogp!!, acclogp!!, - resetaccs!!, + resetlogp!!, get_num_produce, set_num_produce!!, reset_num_produce!!, diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index e5a773b20..e86525298 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -93,23 +93,91 @@ function transformation end """ getlogjoint(vi::AbstractVarInfo) -Return the log of the joint probability of the observed data and parameters sampled in -`vi`. +Return the log of the joint probability of the observed data and parameters in `vi`. + +See also: [`getlogprior`](@ref), [`getloglikelihood`](@ref). """ getlogjoint(vi::AbstractVarInfo) = getlogprior(vi) + getloglikelihood(vi) -getlogp(vi::AbstractVarInfo) = getlogjoint(vi) +function getlogp(vi::AbstractVarInfo) + Base.depwarn("getlogp is deprecated, use getlogjoint instead", :getlogp) + return getlogjoint(vi) +end + +""" + setaccs!!(vi::AbstractVarInfo, accs::AccumulatorTuple) + +Update the `AccumulatorTuple` of `vi` to `accs`, mutating if it makes sense. +This should be implemented by each subtype of `AbstractVarInfo`. `setaccs!!` is not +user-facing, but used in the implementation of many other functions. +""" function setaccs!! 
end + +""" + getaccs(vi::AbstractVarInfo) + +Return the `AccumulatorTuple` of `vi`. + +This should be implemented by each subtype of `AbstractVarInfo`. `getaccs` is not +user-facing, but used in the implementation of many other functions. +""" function getaccs end +""" + hasacc(vi::AbstractVarInfo, ::Val{accname}) where {accname} + +Return a boolean for whether `vi` has an accumulator with name `accname`. +""" +hasacc(vi::AbstractVarInfo, accname::Val) = haskey(getaccs(vi), accname) + +""" + getlogprior(vi::AbstractVarInfo) + +Return the log of the prior probability of the parameters in `vi`. + +See also: [`getlogjoint`](@ref), [`getloglikelihood`](@ref), [`setlogprior!!`](@ref). +""" getlogprior(vi::AbstractVarInfo) = getacc(vi, Val(:LogPrior)).logp + +""" + getloglikelihood(vi::AbstractVarInfo) + +Return the log of the likelihood probability of the observed data in `vi`. + +See also: [`getlogjoint`](@ref), [`getlogprior`](@ref), [`setloglikelihood!!`](@ref). +""" getloglikelihood(vi::AbstractVarInfo) = getacc(vi, Val(:LogLikelihood)).logp +""" + setacc!!(vi::AbstractVarInfo, acc::AbstractAccumulator) + +Add `acc` to the `AccumulatorTuple` of `vi`, mutating if it makes sense. + +If an accumulator with the same [`accumulator_name`](@ref) already exists, it will be +replaced. + +See also: [`getaccs`](@ref). +""" function setacc!!(vi::AbstractVarInfo, acc::AbstractAccumulator) return setaccs!!(vi, setacc!!(getaccs(vi), acc)) end +""" + setlogprior!!(vi::AbstractVarInfo, logp) + +Set the log of the prior probability of the parameters sampled in `vi` to `logp`. + +See also: [`setloglikelihood!!`](@ref), [`setlogp!!`](@ref), [`getlogprior`](@ref). +""" setlogprior!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogPrior(logp)) + +""" + setloglikelihood!!(vi::AbstractVarInfo, logp) + +Set the log of the likelihood probability of the observed data sampled in `vi` to `logp`. + +See also: [`setlogprior!!`](@ref), [`setlogp!!`](@ref), [`getloglikelihood`](@ref). +""" setloglikelihood!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogLikelihood(logp)) """ @@ -117,35 +185,86 @@ setloglikelihood!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogLikelihood(logp) Set the log of the joint probability of the observed data and parameters sampled in `vi` to `logp`, mutating if it makes sense. + +See also: [`setlogprior!!`](@ref), [`setloglikelihood!!`](@ref), [`getlogp`](@ref). """ function setlogp!!(vi::AbstractVarInfo, logp) + Base.depwarn( + "setlogp!! is deprecated, use setlogprior!! or setloglikelihood!! instead", + :setlogp!!, + ) vi = setlogprior!!(vi, zero(logp)) vi = setloglikelihood!!(vi, logp) return vi end +""" + getacc(vi::AbstractVarInfo, accname) + +Return the `AbstractAccumulator` of `vi` with name `accname`. +""" function getacc(vi::AbstractVarInfo, accname) return getacc(getaccs(vi), accname) end -function accumulate_assume!!(vi::AbstractVarInfo, r, logjac, vn, right) - return setaccs!!(vi, accumulate_assume!!(getaccs(vi), r, logjac, vn, right)) +""" + accumulate_assume!!(vi::AbstractVarInfo, val, logjac, vn, right) + +Update all the accumulators of `vi` by calling `accumulate_assume!!` on them. +""" +function accumulate_assume!!(vi::AbstractVarInfo, val, logjac, vn, right) + return map_accumulator!!(vi, accumulate_assume!!, val, logjac, vn, right) end +""" + accumulate_observe!!(vi::AbstractVarInfo, right, left, vn) + +Update all the accumulators of `vi` by calling `accumulate_observe!!` on them. 
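A minimal sketch of the accessor split documented in this patch: the prior and likelihood contributions live in separate accumulators and are recombined on demand by `getlogjoint`. The demo model is invented, and the accessors are written fully qualified since not all of them may be exported.

```julia
using DynamicPPL, Distributions

@model function demo()
    m ~ Normal()
    1.5 ~ Normal(m, 1.0)
end

vi = VarInfo(demo())
# getlogjoint is defined as exactly this sum, so the comparison holds bit-for-bit.
DynamicPPL.getlogprior(vi) + DynamicPPL.getloglikelihood(vi) == DynamicPPL.getlogjoint(vi)  # true
```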
+""" function accumulate_observe!!(vi::AbstractVarInfo, right, left, vn) - return setaccs!!(vi, accumulate_observe!!(getaccs(vi), right, left, vn)) + return map_accumulator!!(vi, accumulate_observe!!, right, left, vn) +end + +""" + map_accumulator!!(vi::AbstractVarInfo, func::Function, args...) where {accname} + +Update all accumulators of `vi` by calling `func(acc, args...)` on them and replacing +them with the return values. +""" +function map_accumulator!!(vi::AbstractVarInfo, func::Function, args...) + return setaccs!!(vi, map_accumulator!!(getaccs(vi), func, args...)) end -function acc!!(vi::AbstractVarInfo, accname, args...) - return setaccs!!(vi, acc!!(getaccs(vi), accname, args...)) +""" + map_accumulator!!(vi::AbstractVarInfo, ::Val{accname}, func::Function, args...) where {accname} + +Update the accumulator `accname` of `vi` by calling `func(acc, args...)` on and replacing +it with the return value. +""" +function map_accumulator!!(vi::AbstractVarInfo, accname::Val, func::Function, args...) + return setaccs!!(vi, map_accumulator!!(getaccs(vi), accname, func, args...)) end +""" + acclogprior!!(vi::AbstractVarInfo, logp) + +Add `logp` to the value of the log of the prior probability in `vi`. + +See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getlogprior`](@ref), [`setlogprior!!`](@ref). +""" function acclogprior!!(vi::AbstractVarInfo, logp) - return acc!!(vi, Val(:LogPrior), logp) + return map_accumulator!!(vi, Val(:LogPrior), +, logp) end +""" + accloglikelihood!!(vi::AbstractVarInfo, logp) + +Add `logp` to the value of the log of the likelihood in `vi`. + +See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getloglikelihood`](@ref), [`setloglikelihood!!`](@ref). +""" function accloglikelihood!!(vi::AbstractVarInfo, logp) - return acc!!(vi, Val(:LogLikelihood), logp) + return map_accumulator!!(vi, Val(:LogLikelihood), +, logp) end """ @@ -154,24 +273,29 @@ end Add `logp` to the value of the log of the joint probability of the observed data and parameters sampled in `vi`, mutating if it makes sense. """ -acclogp!!(vi::AbstractVarInfo, logp) = accloglikelihood!!(vi, logp) +function acclogp!!(vi::AbstractVarInfo, logp) + Base.depwarn( + "acclogp!! is deprecated, use acclogprior!! or accloglikelihood!! instead", + :acclogp!!, + ) + return accloglikelihood!!(vi, logp) +end """ resetlogp!!(vi::AbstractVarInfo) -Reset the value of the log of the joint probability of the observed data and parameters -sampled in `vi` to 0, mutating if it makes sense. +Reset the values of the log probabilities (prior and likelihood) in `vi` """ -function resetaccs!!(vi::AbstractVarInfo) - accs = getaccs(vi) - for acc in accs - accs = setacc!!(accs, resetacc!!(acc)) +function resetlogp!!(vi::AbstractVarInfo) + if hasacc(vi, Val(:LogPrior)) + vi = map_accumulator!!(vi, Val(:LogPrior), zero) + end + if hasacc(vi, Val(:LogLikelihood)) + vi = map_accumulator!!(vi, Val(:LogLikelihood), zero) end - return setaccs!!(vi, accs) + return vi end -haslogp(vi::AbstractVarInfo) = hasacc(vi, Val(:LogPrior)) || hasacc(vi, Val(:LogLikelihood)) - # Variables and their realizations. 
@doc """ keys(vi::AbstractVarInfo) diff --git a/src/accumulators.jl b/src/accumulators.jl index 9a1cd5500..6ae3ffcbf 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -18,9 +18,6 @@ To be able to work with multi-threading, it should also implement: - `split(acc::T)` - `combine(acc::T, acc2::T)` -It may also want to implement -- `acc!!(acc::T, args...)` - See the documentation for each of these functions for more details. """ abstract type AbstractAccumulator end @@ -138,6 +135,7 @@ AccumulatorTuple(nt::NamedTuple) = AccumulatorTuple(tuple(nt...)) Base.getindex(at::AccumulatorTuple, idx) = at.nt[idx] Base.length(::AccumulatorTuple{N}) where {N} = N Base.iterate(at::AccumulatorTuple, args...) = iterate(at.nt, args...) +Base.haskey(at::AccumulatorTuple, ::Val{accname}) where {accname} = haskey(at.nt, accname) """ setacc!!(at::AccumulatorTuple, acc::AbstractAccumulator) @@ -149,7 +147,9 @@ replaced. `at` will never be mutated, but the name has the `!!` for consistency corresponding function for `AbstractVarInfo`. """ function setacc!!(at::AccumulatorTuple, acc::AbstractAccumulator) - return Accessors.@set at.nt[accumulator_name(acc)] = acc + accname = accumulator_name(acc) + new_nt = merge(at.nt, NamedTuple{(accname,)}((acc,))) + return AccumulatorTuple(new_nt) end """ @@ -162,31 +162,37 @@ function getacc(at::AccumulatorTuple, ::Val{accname}) where {accname} end """ - accumulate_assume!!(at::AccumulatorTuple, r, logjac, vn, right) + map_accumulator!!(at::AccumulatorTuple, func::Function, args...) -Call `accumulate_assume!!` on each accumulator in `at`. +Update the accumulators in `at` by calling `func(acc, args...)` on them and replacing them +with the return values. -Returns a new AccumulatorTuple. +Returns a new `AccumulatorTuple`. The `!!` in the name is for consistency with the +corresponding function for `AbstractVarInfo`. """ -function accumulate_assume!!(at::AccumulatorTuple, r, logjac, vn, right) - return AccumulatorTuple( - map(acc -> accumulate_assume!!(acc, r, logjac, vn, right), at.nt) - ) +function map_accumulator!!(at::AccumulatorTuple, func::Function, args...) + return AccumulatorTuple(map(acc -> func(acc, args...), at.nt)) end """ - accumulate_observe!!(at::AccumulatorTuple, right, left, vn) + map_accumulator!!(at::AccumulatorTuple, ::Val{accname}, func::Function, args...) -Call `accumulate_observe!!` on each accumulator in `at`. +Update the accumulator with name `accname` in `at` by calling `func(acc, args...)` on it +and replacing it with the return value. -Returns a new AccumulatorTuple. +Returns a new `AccumulatorTuple`. The `!!` in the name is for consistency with the +corresponding function for `AbstractVarInfo`. """ -function accumulate_observe!!(at::AccumulatorTuple, right, left, vn) - return AccumulatorTuple(map(acc -> accumulate_observe!!(acc, right, left, vn), at.nt)) -end - -function acc!!(at::AccumulatorTuple, ::Val{accname}, args...) where {accname} - return Accessors.@set at.nt[accname] = acc!!(at[accname], args...) +function map_accumulator!!( + at::AccumulatorTuple, ::Val{accname}, func::Function, args... +) where {accname} + # Would like to write this as + # return Accessors.@set at.nt[accname] = func(at[accname], args...) + # for readability, but that one isn't type stable due to + # https://github.com/JuliaObjects/Accessors.jl/issues/198 + new_val = func(at[accname], args...) 
+ new_nt = merge(at.nt, NamedTuple{(accname,)}((new_val,))) + return AccumulatorTuple(new_nt) end # END ACCUMULATOR TUPLE, BEGIN LOG PROB AND NUM PRODUCE ACCUMULATORS @@ -252,12 +258,6 @@ accumulator_name(::Type{<:LogPrior}) = :LogPrior accumulator_name(::Type{<:LogLikelihood}) = :LogLikelihood accumulator_name(::Type{<:NumProduce}) = :NumProduce -resetacc!!(acc::LogPrior) = LogPrior(zero(acc.logp)) -resetacc!!(acc::LogLikelihood) = LogLikelihood(zero(acc.logp)) -# TODO(mhauru) How to handle reset for NumProduce? Do we need to define different types of -# resets? -resetacc!!(acc::NumProduce) = acc - split(::LogPrior{T}) where {T} = LogPrior(zero(T)) split(::LogLikelihood{T}) where {T} = LogLikelihood(zero(T)) split(acc::NumProduce) = acc @@ -268,19 +268,22 @@ function combine(acc::NumProduce, acc2::NumProduce) return NumProduce(max(acc.num, acc2.num)) end -acc!!(acc::LogPrior, logp) = LogPrior(acc.logp + logp) -acc!!(acc::LogLikelihood, logp) = LogLikelihood(acc.logp + logp) -acc!!(acc::NumProduce, n) = NumProduce(acc.num + n) +Base.:+(acc::LogPrior{T}, logp::T) where {T} = LogPrior(acc.logp + logp) +Base.:+(acc::LogLikelihood{T}, logp::T) where {T} = LogLikelihood(acc.logp + logp) +Base.:+(acc::NumProduce{T}, num::T) where {T} = NumProduce(acc.num + num) +Base.zero(acc::LogPrior) = LogPrior(zero(acc.logp)) +Base.zero(acc::LogLikelihood) = LogLikelihood(zero(acc.logp)) +Base.zero(acc::NumProduce) = NumProduce(zero(acc.num)) function accumulate_assume!!(acc::LogPrior, val, logjac, vn, right) - return acc!!(acc, logpdf(right, val) + logjac) + return acc + (logpdf(right, val) + logjac) end accumulate_observe!!(acc::LogPrior, right, left, vn) = acc accumulate_assume!!(acc::LogLikelihood, val, logjac, vn, right) = acc function accumulate_observe!!(acc::LogLikelihood, right, left, vn) - return acc!!(acc, logpdf(right, left)) + return acc + logpdf(right, left) end accumulate_assume!!(acc::NumProduce, val, logjac, vn, right) = acc -accumulate_observe!!(acc::NumProduce, right, left, vn) = acc!!(acc, 1) +accumulate_observe!!(acc::NumProduce, right, left, vn) = acc + 1 diff --git a/src/model.jl b/src/model.jl index 7df61e939..2431935f0 100644 --- a/src/model.jl +++ b/src/model.jl @@ -882,7 +882,7 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadsafe!!`](@ref) """ function evaluate_threadunsafe!!(model, varinfo, context) - return _evaluate!!(model, resetaccs!!(varinfo), context) + return _evaluate!!(model, resetlogp!!(varinfo), context) end """ @@ -897,7 +897,7 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadunsafe!!`](@ref) """ function evaluate_threadsafe!!(model, varinfo, context) - wrapper = ThreadSafeVarInfo(resetaccs!!(varinfo)) + wrapper = ThreadSafeVarInfo(resetlogp!!(varinfo)) result, wrapper_new = _evaluate!!(model, wrapper, context) return result, setlogp!!(wrapper_new.varinfo, getlogp(wrapper_new)) end diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 8779460a2..256dc2b47 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -36,9 +36,6 @@ end # ThreadSafeVarInfo. 
# split(::LogPrior{T}) where {T} = LogPrior(zero(T)) # combine(acc::LogPrior, acc2::LogPrior) = LogPrior(acc.logp + acc2.logp) -# acc!!(acc::PointwiseLogPrior, logp) = LogPrior(acc.logp + logp) - -resetacc!!(acc::PointwiseLogProbAccumulator) = acc function accumulate_assume!!( acc::PointwiseLogProbAccumulator{whichlogprob}, val, logjac, vn, right diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index e31d0337c..fce1c7d44 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -252,7 +252,7 @@ function unflatten(svi::SimpleVarInfo, x::AbstractVector) end function BangBang.empty!!(vi::SimpleVarInfo) - return resetaccs!!(Accessors.@set vi.values = empty!!(vi.values)) + return resetlogp!!(Accessors.@set vi.values = empty!!(vi.values)) end Base.isempty(vi::SimpleVarInfo) = isempty(vi.values) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index e29f50b73..6f8bc58e6 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -34,22 +34,19 @@ function getacc(vi::ThreadSafeVarInfo, accname) return foldl(combine, other_accs; init=main_acc) end -# Calls to accumulate_assume!!, accumulate_observe!!, and acc!! are thread-specific. -function accumulate_assume!!(vi::ThreadSafeVarInfo, r, logp, vn, right) +# Calls to map_accumulator!! are thread-specific by default. For any use of them that should +# _not_ be thread-specific a specific method has to be written. +function map_accumulator!!(vi::ThreadSafeVarInfo, accname::Val, func::Function, args...) tid = Threads.threadid() - vi.accs_by_thread[tid] = accumulate_assume!!(vi.accs_by_thread[tid], r, logp, vn, right) - return vi -end - -function accumulate_observe!!(vi::ThreadSafeVarInfo, right, left, vn) - tid = Threads.threadid() - vi.accs_by_thread[tid] = accumulate_observe!!(vi.accs_by_thread[tid], right, left, vn) + vi.accs_by_thread[tid] = map_accumulator!!( + vi.accs_by_thread[tid], accname, func, args... + ) return vi end -function acc!!(vi::ThreadSafeVarInfo, accname, args...) +function map_accumulator!!(vi::ThreadSafeVarInfo, func::Function, args...) tid = Threads.threadid() - vi.accs_by_thread[tid] = acc!!(vi.accs_by_thread[tid], accname, args...) + vi.accs_by_thread[tid] = map_accumulator!!(vi.accs_by_thread[tid], func, args...) return vi end @@ -171,15 +168,16 @@ end isempty(vi::ThreadSafeVarInfo) = isempty(vi.varinfo) function BangBang.empty!!(vi::ThreadSafeVarInfo) - return resetaccs!!(Accessors.@set(vi.varinfo = empty!!(vi.varinfo))) + return resetlogp!!(Accessors.@set(vi.varinfo = empty!!(vi.varinfo))) end -function resetaccs!!(vi::ThreadSafeVarInfo) - vi = Accessors.@set vi.varinfo = resetaccs!!(vi.varinfo) +function resetlogp!!(vi::ThreadSafeVarInfo) + vi = Accessors.@set vi.varinfo = resetlogp!!(vi.varinfo) for i in eachindex(vi.accs_by_thread) - for acc in getaccs(vi.varinfo) - vi.accs_by_thread[i] = setacc!!(vi.accs_by_thread[i], split(acc)) - end + vi.accs_by_thread[i] = map_accumulator!!(vi.accs_by_thread[i], Val(:LogPrior), zero) + vi.accs_by_thread[i] = map_accumulator!!( + vi.accs_by_thread[i], Val(:LogLikelihood), zero + ) end return vi end diff --git a/src/varinfo.jl b/src/varinfo.jl index e4ae7a13a..db9b6376f 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -965,7 +965,7 @@ end function BangBang.empty!!(vi::VarInfo) _empty!(vi.metadata) - vi = resetaccs!!(vi) + vi = resetlogp!!(vi) vi = reset_num_produce!!(vi) return vi end @@ -1022,7 +1022,10 @@ set_num_produce!!(vi::VarInfo, n::Int) = setacc!!(vi, NumProduce(n)) Add 1 to `num_produce` in `vi`. 
""" -increment_num_produce!!(vi::VarInfo) = set_num_produce!!(vi, get_num_produce(vi) + 1) +function increment_num_produce!!(vi::VarInfo) + num_produce = get_num_produce(vi) + return set_num_produce!!(vi, num_produce + oneunit(num_produce)) +end """ reset_num_produce!!(vi::VarInfo) @@ -1030,7 +1033,7 @@ increment_num_produce!!(vi::VarInfo) = set_num_produce!!(vi, get_num_produce(vi) Reset the value of `num_produce` the log of the joint probability of the observed data and parameters sampled in `vi` to 0. """ -reset_num_produce!!(vi::VarInfo) = set_num_produce!!(vi, 0) +reset_num_produce!!(vi::VarInfo) = map_accumulator!!(vi, Val(:NumProduce), zero) # Need to introduce the _isempty to avoid type piracy of isempty(::NamedTuple). isempty(vi::VarInfo) = _isempty(vi.metadata) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index e489ae88d..dbef4223a 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -208,7 +208,7 @@ end # Reset the logp accumulators. - svi_eval = DynamicPPL.resetaccs!!(svi_eval) + svi_eval = DynamicPPL.resetlogp!!(svi_eval) # Compute `logjoint` using the varinfo. logπ = logjoint(model, svi_eval) From c59400d8426457cd8f9fdfcf40dcd9e67ae02292 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 16 Apr 2025 14:45:30 +0100 Subject: [PATCH 14/48] Fix type rigidity of LogProbs and NumProduce --- src/abstract_varinfo.jl | 4 ++-- src/accumulators.jl | 13 +++++++------ src/varinfo.jl | 5 +---- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index e86525298..1ea613466 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -253,7 +253,7 @@ Add `logp` to the value of the log of the prior probability in `vi`. See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getlogprior`](@ref), [`setlogprior!!`](@ref). """ function acclogprior!!(vi::AbstractVarInfo, logp) - return map_accumulator!!(vi, Val(:LogPrior), +, logp) + return map_accumulator!!(vi, Val(:LogPrior), +, LogPrior(logp)) end """ @@ -264,7 +264,7 @@ Add `logp` to the value of the log of the likelihood in `vi`. See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getloglikelihood`](@ref), [`setloglikelihood!!`](@ref). 
""" function accloglikelihood!!(vi::AbstractVarInfo, logp) - return map_accumulator!!(vi, Val(:LogLikelihood), +, logp) + return map_accumulator!!(vi, Val(:LogLikelihood), +, LogLikelihood(logp)) end """ diff --git a/src/accumulators.jl b/src/accumulators.jl index 6ae3ffcbf..af2979937 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -268,22 +268,23 @@ function combine(acc::NumProduce, acc2::NumProduce) return NumProduce(max(acc.num, acc2.num)) end -Base.:+(acc::LogPrior{T}, logp::T) where {T} = LogPrior(acc.logp + logp) -Base.:+(acc::LogLikelihood{T}, logp::T) where {T} = LogLikelihood(acc.logp + logp) -Base.:+(acc::NumProduce{T}, num::T) where {T} = NumProduce(acc.num + num) +Base.:+(acc1::LogPrior, acc2::LogPrior) = LogPrior(acc1.logp + acc2.logp) +Base.:+(acc1::LogLikelihood, acc2::LogLikelihood) = LogLikelihood(acc1.logp + acc2.logp) +increment(acc::NumProduce) = NumProduce(acc.num + oneunit(acc.num)) + Base.zero(acc::LogPrior) = LogPrior(zero(acc.logp)) Base.zero(acc::LogLikelihood) = LogLikelihood(zero(acc.logp)) Base.zero(acc::NumProduce) = NumProduce(zero(acc.num)) function accumulate_assume!!(acc::LogPrior, val, logjac, vn, right) - return acc + (logpdf(right, val) + logjac) + return acc + LogPrior(logpdf(right, val) + logjac) end accumulate_observe!!(acc::LogPrior, right, left, vn) = acc accumulate_assume!!(acc::LogLikelihood, val, logjac, vn, right) = acc function accumulate_observe!!(acc::LogLikelihood, right, left, vn) - return acc + logpdf(right, left) + return acc + LogLikelihood(logpdf(right, left)) end accumulate_assume!!(acc::NumProduce, val, logjac, vn, right) = acc -accumulate_observe!!(acc::NumProduce, right, left, vn) = acc + 1 +accumulate_observe!!(acc::NumProduce, right, left, vn) = increment(acc) diff --git a/src/varinfo.jl b/src/varinfo.jl index db9b6376f..bdf996c6d 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1022,10 +1022,7 @@ set_num_produce!!(vi::VarInfo, n::Int) = setacc!!(vi, NumProduce(n)) Add 1 to `num_produce` in `vi`. 
""" -function increment_num_produce!!(vi::VarInfo) - num_produce = get_num_produce(vi) - return set_num_produce!!(vi, num_produce + oneunit(num_produce)) -end +increment_num_produce!!(vi::VarInfo) = map_accumulator!!(vi, Val(:NumProduce), increment) """ reset_num_produce!!(vi::VarInfo) From 47033ce2591249f0a8d8eded00d6c3befd79d7de Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 17 Apr 2025 15:18:03 +0100 Subject: [PATCH 15/48] Fix uses of getlogp and other assorted issues --- Project.toml | 2 ++ docs/src/api.md | 12 +++++--- ext/DynamicPPLMCMCChainsExt.jl | 12 ++++---- src/DynamicPPL.jl | 9 ++++++ src/abstract_varinfo.jl | 45 ++++++++++++++++++++++++++---- src/accumulators.jl | 12 +++++++- src/contexts.jl | 12 ++++---- src/debug_utils.jl | 10 +++---- src/logdensityfunction.jl | 5 ++-- src/model.jl | 8 +++--- src/simple_varinfo.jl | 12 ++++---- src/submodel_macro.jl | 4 +-- src/threadsafe.jl | 24 ++++++++++++---- src/utils.jl | 49 +++++++++++++-------------------- src/varinfo.jl | 48 ++++++++++++++++++++++---------- test/compiler.jl | 18 ++++++------ test/context_implementations.jl | 5 ++-- test/linking.jl | 12 ++++---- test/model.jl | 4 +-- test/sampler.jl | 8 +++--- test/simple_varinfo.jl | 28 +++++++++---------- test/threadsafe.jl | 22 +++++++-------- test/utils.jl | 10 +++---- test/varinfo.jl | 46 +++++++++++++++++++------------ test/varnamedvector.jl | 4 +-- 25 files changed, 258 insertions(+), 163 deletions(-) diff --git a/Project.toml b/Project.toml index 516dee26e..ef9684043 100644 --- a/Project.toml +++ b/Project.toml @@ -20,6 +20,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -65,6 +66,7 @@ MCMCChains = "6" MacroTools = "0.5.6" Mooncake = "0.4.95" OrderedCollections = "1" +Printf = "1.10" Random = "1.6" Requires = "1" Test = "1.6" diff --git a/docs/src/api.md b/docs/src/api.md index abc2e3016..215c66ddb 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -163,7 +163,7 @@ returned(::Model) It is possible to manually increase (or decrease) the accumulated log density from within a model function. ```@docs -@addlogprob! +@addloglikelihood! ``` Return values of the model function for a collection of samples can be obtained with [`returned(model, chain)`](@ref). @@ -340,9 +340,13 @@ SimpleVarInfo #### Accumulation of log-probabilities ```@docs -getlogp -setlogp!! -acclogp!! +getlogprior +getloglikelihood +getlogjoint +setlogprior!! +setloglikelihood!! +acclogprior!! +accloglikelihood!! resetlogp!! ``` diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 7fcbd6a7c..70f0f0182 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -48,10 +48,10 @@ end Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample in `chain`, and return the resulting `Chains`. -The `model` passed to `predict` is often different from the one used to generate `chain`. -Typically, the model from which `chain` originated treats certain variables as observed (i.e., -data points), while the model you pass to `predict` may mark these same variables as missing -or unobserved. 
Calling `predict` then leverages the previously inferred parameter values to +The `model` passed to `predict` is often different from the one used to generate `chain`. +Typically, the model from which `chain` originated treats certain variables as observed (i.e., +data points), while the model you pass to `predict` may mark these same variables as missing +or unobserved. Calling `predict` then leverages the previously inferred parameter values to simulate what new, unobserved data might look like, given your posterior beliefs. For each parameter configuration in `chain`: @@ -59,7 +59,7 @@ For each parameter configuration in `chain`: 2. Any variables not included in `chain` are sampled from their prior distributions. If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by -the samples in `chain`. This is useful when you want to sample only new variables from the posterior +the samples in `chain`. This is useful when you want to sample only new variables from the posterior predictive distribution. # Examples @@ -124,7 +124,7 @@ function DynamicPPL.predict( map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)), ) - return (varname_and_values=varname_vals, logp=DynamicPPL.getlogp(varinfo)) + return (varname_and_values=varname_vals, logp=DynamicPPL.getlogjoint(varinfo)) end chain_result = reduce( diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index cfb730f52..b2c79b7fa 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -6,6 +6,7 @@ using Bijectors using Compat using Distributions using OrderedCollections: OrderedCollections, OrderedDict +using Printf: Printf using AbstractMCMC: AbstractMCMC using ADTypes: ADTypes @@ -50,8 +51,15 @@ export AbstractVarInfo, empty!!, subset, getlogp, + getlogjoint, + getlogprior, + getloglikelihood, setlogp!!, + setlogprior!!, + setloglikelihood!!, acclogp!!, + acclogprior!!, + accloglikelihood!!, resetlogp!!, get_num_produce, set_num_produce!!, @@ -115,6 +123,7 @@ export AbstractVarInfo, to_submodel, # Convenience macros @addlogprob!, + @addloglikelihood!, @submodel, value_iterator_from_chain, check_model, diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 1ea613466..e067b6e6b 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -129,6 +129,21 @@ function getaccs end Return a boolean for whether `vi` has an accumulator with name `accname`. """ hasacc(vi::AbstractVarInfo, accname::Val) = haskey(getaccs(vi), accname) +function hassacc(vi::AbstractVarInfo, accname::Symbol) + return error( + """ + The method hasacc(vi::AbstractVarInfo, accname::Symbol) does not exist. For type + stability reasons use hasacc(vi::AbstractVarInfo, Val(accname)) instead. + """ + ) +end + +""" + acckeys(vi::AbstractVarInfo) + +Return the names of the accumulators in `vi`. +""" +acckeys(vi::AbstractVarInfo) = keys(getaccs(vi)) """ getlogprior(vi::AbstractVarInfo) @@ -203,9 +218,17 @@ end Return the `AbstractAccumulator` of `vi` with name `accname`. """ -function getacc(vi::AbstractVarInfo, accname) +function getacc(vi::AbstractVarInfo, accname::Val) return getacc(getaccs(vi), accname) end +function getacc(vi::AbstractVarInfo, accname::Symbol) + return error( + """ + The method getacc(vi::AbstractVarInfo, accname::Symbol) does not exist. For type + stability reasons use getacc(vi::AbstractVarInfo, Val(accname)) instead. 
+ """ + ) +end """ accumulate_assume!!(vi::AbstractVarInfo, val, logjac, vn, right) @@ -245,6 +268,18 @@ function map_accumulator!!(vi::AbstractVarInfo, accname::Val, func::Function, ar return setaccs!!(vi, map_accumulator!!(getaccs(vi), accname, func, args...)) end +function map_accumulator!!(vi::AbstractVarInfo, accname::Symbol, func::Function, args...) + return error( + """ + The method + map_accumulator!!(vi::AbstractVarInfo, accname::Symbol, func::Function, args...) + does not exist. For type stability reasons use + map_accumulator!!(vi::AbstractVarInfo, accname::Val, func::Function, args...) + instead. + """ + ) +end + """ acclogprior!!(vi::AbstractVarInfo, logp) @@ -732,8 +767,8 @@ function link!!( x = vi[:] y, logjac = with_logabsdet_jacobian(b, x) - lp_new = getlogp(vi) - logjac - vi_new = setlogp!!(unflatten(vi, y), lp_new) + lp_new = getlogprior(vi) - logjac + vi_new = setlogprior!!(unflatten(vi, y), lp_new) return settrans!!(vi_new, t) end @@ -744,8 +779,8 @@ function invlink!!( y = vi[:] x, logjac = with_logabsdet_jacobian(b, y) - lp_new = getlogp(vi) + logjac - vi_new = setlogp!!(unflatten(vi, x), lp_new) + lp_new = getlogprior(vi) + logjac + vi_new = setlogprior!!(unflatten(vi, x), lp_new) return settrans!!(vi_new, NoTransformation()) end diff --git a/src/accumulators.jl b/src/accumulators.jl index af2979937..39de37ad9 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -132,10 +132,13 @@ end AccumulatorTuple(accs::Vararg{AbstractAccumulator}) = AccumulatorTuple(accs) AccumulatorTuple(nt::NamedTuple) = AccumulatorTuple(tuple(nt...)) +# When showing with text/plain, leave out information about the wrapper AccumulatorTuple. +Base.show(io::IO, mime::MIME"text/plain", at::AccumulatorTuple) = show(io, mime, at.nt) Base.getindex(at::AccumulatorTuple, idx) = at.nt[idx] Base.length(::AccumulatorTuple{N}) where {N} = N Base.iterate(at::AccumulatorTuple, args...) = iterate(at.nt, args...) Base.haskey(at::AccumulatorTuple, ::Val{accname}) where {accname} = haskey(at.nt, accname) +Base.keys(at::AccumulatorTuple) = keys(at.nt) """ setacc!!(at::AccumulatorTuple, acc::AbstractAccumulator) @@ -254,6 +257,10 @@ Create a new `NumProduce` accumulator with the number of observations initialize """ NumProduce{T}() where {T} = NumProduce(zero(T)) +Base.show(io::IO, acc::LogPrior) = print(io, "LogPrior($(repr(acc.logp)))") +Base.show(io::IO, acc::LogLikelihood) = print(io, "LogLikelihood($(repr(acc.logp)))") +Base.show(io::IO, acc::NumProduce) = print(io, "NumProduce($(repr(acc.num)))") + accumulator_name(::Type{<:LogPrior}) = :LogPrior accumulator_name(::Type{<:LogLikelihood}) = :LogLikelihood accumulator_name(::Type{<:NumProduce}) = :NumProduce @@ -283,7 +290,10 @@ accumulate_observe!!(acc::LogPrior, right, left, vn) = acc accumulate_assume!!(acc::LogLikelihood, val, logjac, vn, right) = acc function accumulate_observe!!(acc::LogLikelihood, right, left, vn) - return acc + LogLikelihood(logpdf(right, left)) + # Note that it's important to use the loglikelihood function here, not logpdf, because + # they handle vectors differently: + # https://github.com/JuliaStats/Distributions.jl/issues/1972 + return acc + LogLikelihood(Distributions.loglikelihood(right, left)) end accumulate_assume!!(acc::NumProduce, val, logjac, vn, right) = acc diff --git a/src/contexts.jl b/src/contexts.jl index 941ccd75e..cede1c672 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -46,15 +46,17 @@ effectively updating the child context. 
# Examples ```jldoctest +julia> using DynamicPPL: DynamicTransformationContext + julia> ctx = SamplingContext(); julia> DynamicPPL.childcontext(ctx) DefaultContext() -julia> ctx_prior = DynamicPPL.setchildcontext(ctx, PriorContext()); # only compute the logprior +julia> ctx_prior = DynamicPPL.setchildcontext(ctx, DynamicTransformationContext{true}()); julia> DynamicPPL.childcontext(ctx_prior) -PriorContext() +DynamicTransformationContext{true}() ``` """ setchildcontext @@ -79,7 +81,7 @@ original leaf context of `left`. # Examples ```jldoctest -julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext +julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext, DynamicTransformationContext julia> struct ParentContext{C} <: AbstractContext context::C @@ -97,8 +99,8 @@ julia> ctx = ParentContext(ParentContext(DefaultContext())) ParentContext(ParentContext(DefaultContext())) julia> # Replace the leaf context with another leaf. - leafcontext(setleafcontext(ctx, PriorContext())) -PriorContext() + leafcontext(setleafcontext(ctx, DynamicTransformationContext{true}())) +DynamicTransformationContext{true}() julia> # Append another parent context. setleafcontext(ctx, ParentContext(DefaultContext())) diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 505383e24..bed818fdf 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -88,8 +88,7 @@ function Base.show(io::IO, stmt::AssumeStmt) print(io, " ") print(io, RESULT_SYMBOL) print(io, " ") - print(io, stmt.value) - return print(io, ")") + return print(io, stmt.value) end Base.@kwdef struct ObserveStmt <: Stmt @@ -103,8 +102,7 @@ function Base.show(io::IO, stmt::ObserveStmt) print(io, "observe: ") show_right(io, stmt.left) print(io, " ~ ") - show_right(io, stmt.right) - return print(io, ")") + return show_right(io, stmt.right) end # Some utility methods for extracting information from a trace. @@ -397,7 +395,7 @@ julia> issuccess true julia> print(trace) - assume: x ~ Normal{Float64}(μ=0.0, σ=1.0) ⟼ -0.670252 (logprob = -1.14356) + assume: x ~ Normal{Float64}(μ=0.0, σ=1.0) ⟼ -0.670252 julia> issuccess, trace = check_model_and_trace(rng, demo_correct() | (x = 1.0,)); @@ -405,7 +403,7 @@ julia> issuccess true julia> print(trace) -observe: 1.0 ~ Normal{Float64}(μ=0.0, σ=1.0) (logprob = -1.41894) +observe: 1.0 ~ Normal{Float64}(μ=0.0, σ=1.0) ``` ## Incorrect model diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index a42855f05..76626dd3f 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -1,4 +1,5 @@ import DifferentiationInterface as DI +# TOOD(mhauru) Rework this file to use LogPrior and LogLikelihood. """ is_supported(adtype::AbstractADType) @@ -175,13 +176,13 @@ end Evaluate the log density of the given `model` at the given parameter values `x`, using the given `varinfo` and `context`. Note that the `varinfo` argument is provided only for its structure, in the sense that the parameters from the vector `x` are inserted into -it, and its own parameters are discarded. +it, and its own parameters are discarded. 
""" function logdensity_at( x::AbstractVector, model::Model, varinfo::AbstractVarInfo, context::AbstractContext ) varinfo_new = unflatten(varinfo, x) - return getlogp(last(evaluate!!(model, varinfo_new, context))) + return getlogjoint(last(evaluate!!(model, varinfo_new, context))) end ### LogDensityProblems interface diff --git a/src/model.jl b/src/model.jl index 2431935f0..d7303cb17 100644 --- a/src/model.jl +++ b/src/model.jl @@ -899,7 +899,7 @@ See also: [`evaluate_threadunsafe!!`](@ref) function evaluate_threadsafe!!(model, varinfo, context) wrapper = ThreadSafeVarInfo(resetlogp!!(varinfo)) result, wrapper_new = _evaluate!!(model, wrapper, context) - return result, setlogp!!(wrapper_new.varinfo, getlogp(wrapper_new)) + return result, setaccs!!(wrapper_new.varinfo, getaccs(wrapper_new)) end """ @@ -1009,7 +1009,7 @@ Return the log joint probability of variables `varinfo` for the probabilistic `m See [`logprior`](@ref) and [`loglikelihood`](@ref). """ function logjoint(model::Model, varinfo::AbstractVarInfo) - return getlogp(last(evaluate!!(model, varinfo, DefaultContext()))) + return getlogjoint(last(evaluate!!(model, varinfo, DefaultContext()))) end """ @@ -1357,7 +1357,7 @@ We can check that the log joint probability of the model accumulated in `vi` is ```jldoctest submodel-to_submodel julia> x = vi[@varname(a.x)]; -julia> getlogp(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) +julia> getlogjoint(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) true ``` @@ -1421,7 +1421,7 @@ julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x); julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4); -julia> getlogp(vi) ≈ logprior + loglikelihood +julia> getlogjoint(vi) ≈ logprior + loglikelihood true ``` diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index fce1c7d44..dccb24cbb 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -125,18 +125,18 @@ Evaluation in transformed space of course also works: ```jldoctest simplevarinfo-general julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true) -Transformed SimpleVarInfo((x = -1.0,), 0.0) +Transformed SimpleVarInfo((x = -1.0,), (LogLikelihood = LogLikelihood(0.0), LogPrior = LogPrior(0.0))) julia> # (✓) Positive probability mass on negative numbers! - getlogp(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) + getlogjoint(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) -1.3678794411714423 julia> # While if we forget to indicate that it's transformed: vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false) -SimpleVarInfo((x = -1.0,), 0.0) +SimpleVarInfo((x = -1.0,), (LogLikelihood = LogLikelihood(0.0), LogPrior = LogPrior(0.0))) julia> # (✓) No probability mass on negative numbers! - getlogp(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) + getlogjoint(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) -Inf ``` @@ -267,12 +267,12 @@ Return an iterator of keys present in `vi`. 
Base.keys(vi::SimpleVarInfo) = keys(vi.values) Base.keys(vi::SimpleVarInfo{<:NamedTuple}) = map(k -> VarName{k}(), keys(vi.values)) -function Base.show(io::IO, ::MIME"text/plain", svi::SimpleVarInfo) +function Base.show(io::IO, mime::MIME"text/plain", svi::SimpleVarInfo) if !(svi.transformation isa NoTransformation) print(io, "Transformed ") end - return print(io, "SimpleVarInfo(", svi.values, ", ", getaccs(svi), ")") + return print(io, "SimpleVarInfo(", svi.values, ", ", repr(mime, getaccs(svi)), ")") end function Base.getindex(vi::SimpleVarInfo, vn::VarName, dist::Distribution) diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index f6b9c4479..799fcf011 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -45,7 +45,7 @@ We can check that the log joint probability of the model accumulated in `vi` is ```jldoctest submodel julia> x = vi[@varname(x)]; -julia> getlogp(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) +julia> getlogjoint(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) true ``` """ @@ -124,7 +124,7 @@ julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x); julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4); -julia> getlogp(vi) ≈ logprior + loglikelihood +julia> getlogjoint(vi) ≈ logprior + loglikelihood true ``` diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 6f8bc58e6..9155f5da5 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -28,12 +28,24 @@ end # Get both the main accumulator and the thread-specific accumulators of the same type and # combine them. -function getacc(vi::ThreadSafeVarInfo, accname) +function getacc(vi::ThreadSafeVarInfo, accname::Val) main_acc = getacc(vi.varinfo, accname) other_accs = map(accs -> getacc(accs, accname), vi.accs_by_thread) return foldl(combine, other_accs; init=main_acc) end +hasacc(vi::ThreadSafeVarInfo, accname::Val) = hasacc(vi.varinfo, accname) +acckeys(vi::ThreadSafeVarInfo) = acckeys(vi.varinfo) + +function getaccs(vi::ThreadSafeVarInfo) + # This method is a bit finicky to maintain type stability. For instance, moving the + # accname -> Val(accname) part in the main `map` call makes constant propagation fail + # and this becomes unstable. Do check the effects if you make edits. + accnames = acckeys(vi) + accname_vals = map(Val, accnames) + return AccumulatorTuple(map(anv -> getacc(vi, anv), accname_vals)) +end + # Calls to map_accumulator!! are thread-specific by default. For any use of them that should # _not_ be thread-specific a specific method has to be written. function map_accumulator!!(vi::ThreadSafeVarInfo, accname::Val, func::Function, args...) @@ -96,8 +108,8 @@ end # Need to define explicitly for `DynamicTransformation` to avoid method ambiguity. # NOTE: We also can't just defer to the wrapped varinfo, because we need to ensure -# consistency between `vi.logps` field and `getlogp(vi.varinfo)`, which accumulates -# to define `getlogp(vi)`. +# consistency between `vi.accs_by_thread` field and `getacc(vi.varinfo)`, which accumulates +# to define `getacc(vi)`. function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) end @@ -132,9 +144,9 @@ end function maybe_invlink_before_eval!!(vi::ThreadSafeVarInfo, model::Model) # Defer to the wrapped `AbstractVarInfo` object. 
- # NOTE: When computing `getlogp` for `ThreadSafeVarInfo` we do include the - # `getlogp(vi.varinfo)` hence the log-absdet-jacobian term will correctly be included in - # the `getlogp(vi)`. + # NOTE: When computing `getacc` for `ThreadSafeVarInfo` we do include the + # `getacc(vi.varinfo)` hence the log-absdet-jacobian term will correctly be included in + # the `getlogprior(vi)`. return Accessors.@set vi.varinfo = maybe_invlink_before_eval!!(vi.varinfo, model) end diff --git a/src/utils.jl b/src/utils.jl index 398fd499b..aaa36ad90 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -18,7 +18,21 @@ const LogProbType = float(Real) """ @addlogprob!(ex) -Add the result of the evaluation of `ex` to the joint log probability. +A deprecated alias for `@addloglikelihood!`. +""" +macro addlogprob!(ex) + return quote + depwarn( + "`@addlogprob!` is deprecated, use `@addloglikelihood!` instead.", :addlogprob! + ) + $(esc(:(__varinfo__))) = acclogp!!($(esc(:(__varinfo__))), $(esc(ex))) + end +end + +""" + @addloglikelihood!(ex) + +Add the result of the evaluation of `ex` to the joint log prior probability. # Examples @@ -29,7 +43,7 @@ julia> myloglikelihood(x, μ) = loglikelihood(Normal(μ, 1), x); julia> @model function demo(x) μ ~ Normal() - @addlogprob! myloglikelihood(x, μ) + @addloglikelihood! myloglikelihood(x, μ) end; julia> x = [1.3, -2.1]; @@ -44,7 +58,7 @@ and to [reject samples](https://github.com/TuringLang/Turing.jl/issues/1328): julia> @model function demo(x) m ~ MvNormal(zero(x), I) if dot(m, x) < 0 - @addlogprob! -Inf + @addloglikelihood! -Inf # Exit the model evaluation early return end @@ -55,35 +69,10 @@ julia> @model function demo(x) julia> logjoint(demo([-2.1]), (m=[0.2],)) == -Inf true ``` - -!!! note - The `@addlogprob!` macro increases the accumulated log probability regardless of the evaluation context, - i.e., regardless of whether you evaluate the log prior, the log likelihood or the log joint density. - If you would like to avoid this behaviour you should check the evaluation context. - It can be accessed with the internal variable `__context__`. - For instance, in the following example the log density is not accumulated when only the log prior is computed: - ```jldoctest; setup = :(using Distributions) - julia> myloglikelihood(x, μ) = loglikelihood(Normal(μ, 1), x); - - julia> @model function demo(x) - μ ~ Normal() - if DynamicPPL.leafcontext(__context__) !== PriorContext() - @addlogprob! 
myloglikelihood(x, μ) - end - end; - - julia> x = [1.3, -2.1]; - - julia> logprior(demo(x), (μ=0.2,)) ≈ logpdf(Normal(), 0.2) - true - - julia> loglikelihood(demo(x), (μ=0.2,)) ≈ myloglikelihood(x, 0.2) - true - ``` """ -macro addlogprob!(ex) +macro addloglikelihood!(ex) return quote - $(esc(:(__varinfo__))) = acclogp!!($(esc(:(__varinfo__))), $(esc(ex))) + $(esc(:(__varinfo__))) = accloglikelihood!!($(esc(:(__varinfo__))), $(esc(ex))) end end diff --git a/src/varinfo.jl b/src/varinfo.jl index bdf996c6d..3dfe32e83 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1738,19 +1738,35 @@ function Base.haskey(vi::NTVarInfo, vn::VarName) end function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo) - vi_str = """ - /======================================================================= - | VarInfo - |----------------------------------------------------------------------- - | Varnames : $(string(vi.metadata.vns)) - | Range : $(vi.metadata.ranges) - | Vals : $(vi.metadata.vals) - | Orders : $(vi.metadata.orders) - | Logp : $(getlogp(vi)) - | #produce : $(get_num_produce(vi)) - | flags : $(vi.metadata.flags) - \\======================================================================= - """ + lines = Tuple{String,Any}[ + ("VarNames", vi.metadata.vns), + ("Range", vi.metadata.ranges), + ("Vals", vi.metadata.vals), + ("Orders", vi.metadata.orders), + ] + for accname in acckeys(vi) + push!(lines, (string(accname), getacc(vi, Val(accname)))) + end + push!(lines, ("flags", vi.metadata.flags)) + max_name_length = maximum(map(length ∘ first, lines)) + fmt = Printf.Format("%-$(max_name_length)s") + vi_str = ( + """ + /======================================================================= + | VarInfo + |----------------------------------------------------------------------- + """ * + prod( + map(lines) do (name, value) + """ + | $(Printf.format(fmt, name)) : $(value) + """ + end, + ) * + """ + \\======================================================================= + """ + ) return print(io, vi_str) end @@ -1780,7 +1796,11 @@ end function Base.show(io::IO, vi::UntypedVarInfo) print(io, "VarInfo (") _show_varnames(io, vi) - print(io, "; logp: ", round(getlogp(vi); digits=3)) + print(io, "; accumulators: ") + # TODO(mhauru) This uses "text/plain" because we are doing quite a condensed repretation + # of vi anyway. However, technically `show(io, x)` should give full details of x and + # preferably output valid Julia code. 
+ show(io, MIME"text/plain"(), getaccs(vi)) return print(io, ")") end diff --git a/test/compiler.jl b/test/compiler.jl index a0286d405..81c018111 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -189,12 +189,12 @@ module Issue537 end global model_ = __model__ global context_ = __context__ global rng_ = __context__.rng - global lp = getlogp(__varinfo__) + global lp = getlogjoint(__varinfo__) return x end model = testmodel_missing3([1.0]) varinfo = VarInfo(model) - @test getlogp(varinfo) == lp + @test getlogjoint(varinfo) == lp @test varinfo_ isa AbstractVarInfo @test model_ === model @test context_ isa SamplingContext @@ -208,13 +208,13 @@ module Issue537 end global model_ = __model__ global context_ = __context__ global rng_ = __context__.rng - global lp = getlogp(__varinfo__) + global lp = getlogjoint(__varinfo__) return x end false lpold = lp model = testmodel_missing4([1.0]) varinfo = VarInfo(model) - @test getlogp(varinfo) == lp == lpold + @test getlogjoint(varinfo) == lp == lpold # test DPPL#61 @model function testmodel_missing5(z) @@ -333,14 +333,14 @@ module Issue537 end function makemodel(p) @model function testmodel(x) x[1] ~ Bernoulli(p) - global lp = getlogp(__varinfo__) + global lp = getlogjoint(__varinfo__) return x end return testmodel end model = makemodel(0.5)([1.0]) varinfo = VarInfo(model) - @test getlogp(varinfo) == lp + @test getlogjoint(varinfo) == lp end @testset "user-defined variable name" begin @model f1() = x ~ NamedDist(Normal(), :y) @@ -364,9 +364,9 @@ module Issue537 end # TODO(torfjelde): We need conditioning for `Dict`. @test_broken f2_c() == 1 @test_broken f3_c() == 1 - @test_broken getlogp(VarInfo(f1_c)) == - getlogp(VarInfo(f2_c)) == - getlogp(VarInfo(f3_c)) + @test_broken getlogjoint(VarInfo(f1_c)) == + getlogjoint(VarInfo(f2_c)) == + getlogjoint(VarInfo(f3_c)) end @testset "custom tilde" begin @model demo() = begin diff --git a/test/context_implementations.jl b/test/context_implementations.jl index 7e1397ae3..ac6321d69 100644 --- a/test/context_implementations.jl +++ b/test/context_implementations.jl @@ -18,13 +18,14 @@ @model function test(x, size) y = Array{Float64,length(size)}(undef, size...) y .~ Normal(x) - return y, getlogp(__varinfo__) + return y end for ysize in ((2,), (2, 3), (2, 3, 4)) x = randn() model = test(x, ysize) - y, lp = model() + y = model() + lp = logjoint(model, (; y=y)) @test lp ≈ sum(logpdf.(Normal.(x), y)) ys = [first(model()) for _ in 1:10_000] diff --git a/test/linking.jl b/test/linking.jl index d424a9c2d..4f1707263 100644 --- a/test/linking.jl +++ b/test/linking.jl @@ -85,7 +85,7 @@ end DynamicPPL.link(vi, model) end # Difference should just be the log-absdet-jacobian "correction". - @test DynamicPPL.getlogp(vi) - DynamicPPL.getlogp(vi_linked) ≈ log(2) + @test DynamicPPL.getlogjoint(vi) - DynamicPPL.getlogjoint(vi_linked) ≈ log(2) @test vi_linked[@varname(m), dist] == LowerTriangular(vi[@varname(m), dist]) # Linked one should be working with a lower-dimensional representation. @test length(vi_linked[:]) < length(vi[:]) @@ -98,7 +98,7 @@ end end @test length(vi_invlinked[:]) == length(vi[:]) @test vi_invlinked[@varname(m), dist] ≈ LowerTriangular(vi[@varname(m), dist]) - @test DynamicPPL.getlogp(vi_invlinked) ≈ DynamicPPL.getlogp(vi) + @test DynamicPPL.getlogjoint(vi_invlinked) ≈ DynamicPPL.getlogjoint(vi) end end @@ -130,7 +130,7 @@ end end @test length(vi_linked[:]) == d * (d - 1) ÷ 2 # Should now include the log-absdet-jacobian correction. 
- @test !(getlogp(vi_linked) ≈ lp) + @test !(getlogjoint(vi_linked) ≈ lp) # Invlinked. vi_invlinked = if mutable DynamicPPL.invlink!!(deepcopy(vi_linked), model) @@ -138,7 +138,7 @@ end DynamicPPL.invlink(vi_linked, model) end @test length(vi_invlinked[:]) == d^2 - @test getlogp(vi_invlinked) ≈ lp + @test getlogjoint(vi_invlinked) ≈ lp end end end @@ -164,7 +164,7 @@ end end @test length(vi_linked[:]) == d - 1 # Should now include the log-absdet-jacobian correction. - @test !(getlogp(vi_linked) ≈ lp) + @test !(getlogjoint(vi_linked) ≈ lp) # Invlinked. vi_invlinked = if mutable DynamicPPL.invlink!!(deepcopy(vi_linked), model) @@ -172,7 +172,7 @@ end DynamicPPL.invlink(vi_linked, model) end @test length(vi_invlinked[:]) == d - @test getlogp(vi_invlinked) ≈ lp + @test getlogjoint(vi_invlinked) ≈ lp end end end diff --git a/test/model.jl b/test/model.jl index dd5a35fe6..6e4a24ae6 100644 --- a/test/model.jl +++ b/test/model.jl @@ -41,7 +41,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() m = vi[@varname(m)] # extract log pdf of variable object - lp = getlogp(vi) + lp = getlogjoint(vi) # log prior probability lprior = logprior(model, vi) @@ -494,7 +494,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() varinfo_linked_result = last( DynamicPPL.evaluate!!(model, deepcopy(varinfo_linked), DefaultContext()) ) - @test getlogp(varinfo_linked) ≈ getlogp(varinfo_linked_result) + @test getlogjoint(varinfo_linked) ≈ getlogjoint(varinfo_linked_result) end end diff --git a/test/sampler.jl b/test/sampler.jl index 8c4f1ed96..fe9fd331a 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -84,7 +84,7 @@ let inits = (; p=0.2) chain = sample(model, sampler, 1; initial_params=inits, progress=false) @test chain[1].metadata.p.vals == [0.2] - @test getlogp(chain[1]) == lptrue + @test getlogjoint(chain[1]) == lptrue # parallel sampling chains = sample( @@ -98,7 +98,7 @@ ) for c in chains @test c[1].metadata.p.vals == [0.2] - @test getlogp(c[1]) == lptrue + @test getlogjoint(c[1]) == lptrue end end @@ -113,7 +113,7 @@ chain = sample(model, sampler, 1; initial_params=inits, progress=false) @test chain[1].metadata.s.vals == [4] @test chain[1].metadata.m.vals == [-1] - @test getlogp(chain[1]) == lptrue + @test getlogjoint(chain[1]) == lptrue # parallel sampling chains = sample( @@ -128,7 +128,7 @@ for c in chains @test c[1].metadata.s.vals == [4] @test c[1].metadata.m.vals == [-1] - @test getlogp(c[1]) == lptrue + @test getlogjoint(c[1]) == lptrue end end diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index dbef4223a..88c3edc71 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -2,12 +2,12 @@ @testset "constructor & indexing" begin @testset "NamedTuple" begin svi = SimpleVarInfo(; m=1.0) - @test getlogp(svi) == 0.0 + @test getlogjoint(svi) == 0.0 @test haskey(svi, @varname(m)) @test !haskey(svi, @varname(m[1])) svi = SimpleVarInfo(; m=[1.0]) - @test getlogp(svi) == 0.0 + @test getlogjoint(svi) == 0.0 @test haskey(svi, @varname(m)) @test haskey(svi, @varname(m[1])) @test !haskey(svi, @varname(m[2])) @@ -21,21 +21,21 @@ @test !haskey(svi, @varname(m.a.b)) svi = SimpleVarInfo{Float32}(; m=1.0) - @test getlogp(svi) isa Float32 + @test getlogjoint(svi) isa Float32 svi = SimpleVarInfo((m=1.0,)) - svi = acclogp!!(svi, 1.0) - @test getlogp(svi) == 1.0 + svi = accloglikelihood!!(svi, 1.0) + @test getlogjoint(svi) == 1.0 end @testset "Dict" begin svi = SimpleVarInfo(Dict(@varname(m) => 1.0)) - @test getlogp(svi) == 0.0 + @test getlogjoint(svi) 
== 0.0 @test haskey(svi, @varname(m)) @test !haskey(svi, @varname(m[1])) svi = SimpleVarInfo(Dict(@varname(m) => [1.0])) - @test getlogp(svi) == 0.0 + @test getlogjoint(svi) == 0.0 @test haskey(svi, @varname(m)) @test haskey(svi, @varname(m[1])) @test !haskey(svi, @varname(m[2])) @@ -60,12 +60,12 @@ @testset "VarNamedVector" begin svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m) => 1.0)) - @test getlogp(svi) == 0.0 + @test getlogjoint(svi) == 0.0 @test haskey(svi, @varname(m)) @test !haskey(svi, @varname(m[1])) svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m) => [1.0])) - @test getlogp(svi) == 0.0 + @test getlogjoint(svi) == 0.0 @test haskey(svi, @varname(m)) @test haskey(svi, @varname(m[1])) @test !haskey(svi, @varname(m[2])) @@ -102,7 +102,7 @@ # `link!!` vi_linked = link!!(deepcopy(vi), model) - lp_linked = getlogp(vi_linked) + lp_linked = getlogjoint(vi_linked) values_unconstrained, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( model, values_constrained... ) @@ -119,7 +119,7 @@ # `invlink!!` vi_invlinked = invlink!!(deepcopy(vi_linked), model) - lp_invlinked = getlogp(vi_invlinked) + lp_invlinked = getlogjoint(vi_invlinked) lp_invlinked_true = DynamicPPL.TestUtils.logjoint_true( model, values_constrained... ) @@ -172,7 +172,7 @@ end # Logjoint should be non-zero wp. 1. - @test getlogp(svi_new) != 0 + @test getlogjoint(svi_new) != 0 ### Evaluation ### values_eval_constrained = DynamicPPL.TestUtils.rand_prior_true(model) @@ -256,7 +256,7 @@ end # `getlogp` should be equal to the logjoint with log-absdet-jac correction. - lp = getlogp(svi) + lp = getlogjoint(svi) # needs higher atol because of https://github.com/TuringLang/Bijectors.jl/issues/375 @test lp ≈ lp_true atol = 1.2e-5 end @@ -312,7 +312,7 @@ DynamicPPL.tovec(retval_unconstrained.m) # The resulting varinfo should hold the correct logp. - lp = getlogp(vi_linked_result) + lp = getlogjoint(vi_linked_result) @test lp ≈ lp_true end end diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 12df5706a..2fa84bad8 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -17,23 +17,23 @@ vi = VarInfo(gdemo_default) threadsafe_vi = DynamicPPL.ThreadSafeVarInfo(vi) - lp = getlogp(vi) - @test getlogp(threadsafe_vi) == lp + lp = getlogjoint(vi) + @test getlogjoint(threadsafe_vi) == lp threadsafe_vi = DynamicPPL.acclogprior!!(threadsafe_vi, 42) @test threadsafe_vi.accs_by_thread[Threads.threadid()][:LogPrior].logp == 42 - @test getlogp(vi) == lp - @test getlogp(threadsafe_vi) == lp + 42 + @test getlogjoint(vi) == lp + @test getlogjoint(threadsafe_vi) == lp + 42 threadsafe_vi = resetlogp!!(threadsafe_vi) - @test iszero(getlogp(threadsafe_vi)) + @test iszero(getlogjoint(threadsafe_vi)) expected_accs = DynamicPPL.AccumulatorTuple( (DynamicPPL.split(acc) for acc in threadsafe_vi.varinfo.accs)... ) @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) - threadsafe_vi = setlogp!!(threadsafe_vi, 42) - @test getlogp(threadsafe_vi) == 42 + threadsafe_vi = setlogprior!!(threadsafe_vi, 42) + @test getlogjoint(threadsafe_vi) == 42 expected_accs = DynamicPPL.AccumulatorTuple( (DynamicPPL.split(acc) for acc in threadsafe_vi.varinfo.accs)... 
) @@ -55,7 +55,7 @@ vi = VarInfo() wthreads(x)(vi) - lp_w_threads = getlogp(vi) + lp_w_threads = getlogjoint(vi) if Threads.nthreads() == 1 @test vi_ isa VarInfo else @@ -72,7 +72,7 @@ vi, SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()), ) - @test getlogp(vi) ≈ lp_w_threads + @test getlogjoint(vi) ≈ lp_w_threads @test vi_ isa DynamicPPL.ThreadSafeVarInfo println(" evaluate_threadsafe!!:") @@ -92,7 +92,7 @@ vi = VarInfo() wothreads(x)(vi) - lp_wo_threads = getlogp(vi) + lp_wo_threads = getlogjoint(vi) if Threads.nthreads() == 1 @test vi_ isa VarInfo else @@ -111,7 +111,7 @@ vi, SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()), ) - @test getlogp(vi) ≈ lp_w_threads + @test getlogjoint(vi) ≈ lp_w_threads @test vi_ isa VarInfo println(" evaluate_threadunsafe!!:") diff --git a/test/utils.jl b/test/utils.jl index d683f132d..196effdf0 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,15 +1,15 @@ @testset "utils.jl" begin - @testset "addlogprob!" begin + @testset "addloglikelihood!" begin @model function testmodel() - global lp_before = getlogp(__varinfo__) - @addlogprob!(42) - return global lp_after = getlogp(__varinfo__) + global lp_before = getlogjoint(__varinfo__) + @addloglikelihood!(42) + return global lp_after = getlogjoint(__varinfo__) end model = testmodel() varinfo = VarInfo(model) @test iszero(lp_before) - @test getlogp(varinfo) == lp_after == 42 + @test getlogjoint(varinfo) == lp_after == 42 end @testset "getargs_dottilde" begin diff --git a/test/varinfo.jl b/test/varinfo.jl index 1ca3308aa..4c0dea5f7 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -80,7 +80,7 @@ end function test_base!!(vi_original) vi = empty!!(vi_original) - @test getlogp(vi) == 0 + @test getlogjoint(vi) == 0 @test isempty(vi[:]) vn = @varname x @@ -123,13 +123,25 @@ end @testset "get/set/acc/resetlogp" begin function test_varinfo_logp!(vi) - @test DynamicPPL.getlogp(vi) === 0.0 - vi = DynamicPPL.setlogp!!(vi, 1.0) - @test DynamicPPL.getlogp(vi) === 1.0 - vi = DynamicPPL.acclogp!!(vi, 1.0) - @test DynamicPPL.getlogp(vi) === 2.0 + @test DynamicPPL.getlogjoint(vi) === 0.0 + vi = DynamicPPL.setlogprior!!(vi, 1.0) + @test DynamicPPL.getlogprior(vi) === 1.0 + @test DynamicPPL.getloglikelihood(vi) === 0.0 + @test DynamicPPL.getlogjoint(vi) === 1.0 + vi = DynamicPPL.acclogprior!!(vi, 1.0) + @test DynamicPPL.getlogprior(vi) === 2.0 + @test DynamicPPL.getloglikelihood(vi) === 0.0 + @test DynamicPPL.getlogjoint(vi) === 2.0 + vi = DynamicPPL.setloglikelihood!!(vi, 1.0) + @test DynamicPPL.getlogprior(vi) === 2.0 + @test DynamicPPL.getloglikelihood(vi) === 1.0 + @test DynamicPPL.getlogjoint(vi) === 3.0 + vi = DynamicPPL.accloglikelihood!!(vi, 1.0) + @test DynamicPPL.getlogprior(vi) === 2.0 + @test DynamicPPL.getloglikelihood(vi) === 2.0 + @test DynamicPPL.getlogjoint(vi) === 4.0 vi = DynamicPPL.resetlogp!!(vi) - @test DynamicPPL.getlogp(vi) === 0.0 + @test DynamicPPL.getlogjoint(vi) === 0.0 end vi = VarInfo() @@ -460,7 +472,7 @@ end vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) ## `typed_varinfo` vi = DynamicPPL.typed_varinfo(model) @@ -469,7 +481,7 @@ end vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, 
vn)) - @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) ### `SimpleVarInfo` ## `SimpleVarInfo{<:NamedTuple}` @@ -478,7 +490,7 @@ end vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) ## `SimpleVarInfo{<:Dict}` vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict()), true) @@ -486,7 +498,7 @@ end vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) ## `SimpleVarInfo{<:VarNamedVector}` vi = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) @@ -494,7 +506,7 @@ end vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) end @testset "values_as" begin @@ -596,8 +608,8 @@ end lp = logjoint(model, varinfo) @test lp ≈ lp_true - @test getlogp(varinfo) ≈ lp_true - lp_linked = getlogp(varinfo_linked) + @test getlogjoint(varinfo) ≈ lp_true + lp_linked = getlogjoint(varinfo_linked) @test lp_linked ≈ lp_linked_true # TODO: Compare values once we are no longer working with `NamedTuple` for @@ -609,7 +621,7 @@ end varinfo_linked_unflattened, model ) @test length(varinfo_invlinked[:]) == length(varinfo[:]) - @test getlogp(varinfo_invlinked) ≈ lp_true + @test getlogjoint(varinfo_invlinked) ≈ lp_true end end end @@ -1017,8 +1029,8 @@ end n = length(varinfo[:]) # `Bool`. - @test getlogp(DynamicPPL.unflatten(varinfo, fill(true, n))) isa typeof(float(1)) + @test getlogjoint(DynamicPPL.unflatten(varinfo, fill(true, n))) isa typeof(float(1)) # `Int`. - @test getlogp(DynamicPPL.unflatten(varinfo, fill(1, n))) isa typeof(float(1)) + @test getlogjoint(DynamicPPL.unflatten(varinfo, fill(1, n))) isa typeof(float(1)) end end diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl index bd3f5553f..f21d458a8 100644 --- a/test/varnamedvector.jl +++ b/test/varnamedvector.jl @@ -607,7 +607,7 @@ end DynamicPPL.evaluate!!(model, deepcopy(varinfo), DefaultContext()) ) # Log density should be the same. - @test getlogp(varinfo_eval) ≈ logp_true + @test getlogjoint(varinfo_eval) ≈ logp_true # Values should be the same. DynamicPPL.TestUtils.test_values(varinfo_eval, value_true, vns) @@ -616,7 +616,7 @@ end DynamicPPL.evaluate!!(model, deepcopy(varinfo), SamplingContext()) ) # Log density should be different. - @test getlogp(varinfo_sample) != getlogp(varinfo) + @test getlogjoint(varinfo_sample) != getlogjoint(varinfo) # Values should be different. DynamicPPL.TestUtils.test_values( varinfo_sample, value_true, vns; compare=!isequal From 8b841c98bfa098bc61a235b7404c0fc700ecf9ba Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 17 Apr 2025 15:35:22 +0100 Subject: [PATCH 16/48] setaccs!! 
nicer interface and logdensity function fixes --- src/DynamicPPL.jl | 3 +++ src/abstract_varinfo.jl | 9 +++++++-- src/accumulators.jl | 3 +++ src/logdensityfunction.jl | 20 +++++++++++++++----- src/simple_varinfo.jl | 2 +- src/varinfo.jl | 2 +- 6 files changed, 30 insertions(+), 9 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index b2c79b7fa..455cc5a29 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -47,6 +47,9 @@ import Base: export AbstractVarInfo, VarInfo, SimpleVarInfo, + LogLikelihood, + LogPrior, + NumProduce, push!!, empty!!, subset, diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index e067b6e6b..8fcdd70ce 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -105,14 +105,19 @@ end """ setaccs!!(vi::AbstractVarInfo, accs::AccumulatorTuple) + setaccs!!(vi::AbstractVarInfo, accs::AbstractAccumulator...) Update the `AccumulatorTuple` of `vi` to `accs`, mutating if it makes sense. -This should be implemented by each subtype of `AbstractVarInfo`. `setaccs!!` is not -user-facing, but used in the implementation of many other functions. +`setaccs!!(vi:AbstractVarInfo, accs::AccumulatorTuple) should be implemented by each subtype +of `AbstractVarInfo`. """ function setaccs!! end +function setaccs!!(vi::AbstractVarInfo, accs::NTuple{N,AbstractAccumulator}) where {N} + return setaccs!!(vi, AccumulatorTuple(accs)) +end + """ getaccs(vi::AbstractVarInfo) diff --git a/src/accumulators.jl b/src/accumulators.jl index 39de37ad9..a1f019610 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -218,6 +218,7 @@ end Create a new `LogPrior` accumulator with the log prior initialized to zero. """ LogPrior{T}() where {T} = LogPrior(zero(T)) +LogPrior() = LogPrior{LogProbType}() """ LogLikelihood{T} <: AbstractAccumulator @@ -237,6 +238,7 @@ end Create a new `LogLikelihood` accumulator with the log likelihood initialized to zero. """ LogLikelihood{T}() where {T} = LogLikelihood(zero(T)) +LogLikelihood() = LogLikelihood{LogProbType}() """ NumProduce{T} <: AbstractAccumulator @@ -256,6 +258,7 @@ end Create a new `NumProduce` accumulator with the number of observations initialized to zero. """ NumProduce{T}() where {T} = NumProduce(zero(T)) +NumProduce() = NumProduce{Int}() Base.show(io::IO, acc::LogPrior) = print(io, "LogPrior($(repr(acc.logp)))") Base.show(io::IO, acc::LogLikelihood) = print(io, "LogLikelihood($(repr(acc.logp)))") diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 76626dd3f..97c9ed2fd 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -1,5 +1,4 @@ import DifferentiationInterface as DI -# TOOD(mhauru) Rework this file to use LogPrior and LogLikelihood. """ is_supported(adtype::AbstractADType) @@ -52,7 +51,7 @@ $(FIELDS) ```jldoctest julia> using Distributions -julia> using DynamicPPL: LogDensityFunction, contextualize +julia> using DynamicPPL: LogDensityFunction, setaccs!! julia> @model function demo(x) m ~ Normal() @@ -79,8 +78,8 @@ julia> # By default it uses `VarInfo` under the hood, but this is not necessary. julia> LogDensityProblems.logdensity(f, [0.0]) -2.3378770664093453 -julia> # This also respects the context in `model`. 
- f_prior = LogDensityFunction(contextualize(model, DynamicPPL.PriorContext()), VarInfo(model)); +julia> # LogDensityFunction respects the accumulators in VarInfo: + f_prior = LogDensityFunction(model, setaccs!!(VarInfo(model), (LogPrior(),))); julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0) true @@ -182,7 +181,18 @@ function logdensity_at( x::AbstractVector, model::Model, varinfo::AbstractVarInfo, context::AbstractContext ) varinfo_new = unflatten(varinfo, x) - return getlogjoint(last(evaluate!!(model, varinfo_new, context))) + varinfo_eval = last(evaluate!!(model, varinfo_new, context)) + has_prior = hasacc(varinfo_eval, Val(:LogPrior)) + has_likelihood = hasacc(varinfo_eval, Val(:LogLikelihood)) + if has_prior && has_likelihood + return getlogjoint(varinfo_eval) + elseif has_prior + return getlogprior(varinfo_eval) + elseif has_likelihood + return getloglikelihood(varinfo_eval) + else + error("LogDensityFunction: varinfo tracks neither log prior nor log likelihood") + end end ### LogDensityProblems interface diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index dccb24cbb..c0f85fee4 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -257,7 +257,7 @@ end Base.isempty(vi::SimpleVarInfo) = isempty(vi.values) getaccs(vi::SimpleVarInfo) = vi.accs -setaccs!!(vi::SimpleVarInfo, accs) = Accessors.@set vi.accs = accs +setaccs!!(vi::SimpleVarInfo, accs::AccumulatorTuple) = Accessors.@set vi.accs = accs """ keys(vi::SimpleVarInfo) diff --git a/src/varinfo.jl b/src/varinfo.jl index 3dfe32e83..98882c423 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1001,7 +1001,7 @@ istrans(vi::VarInfo, vn::VarName) = istrans(getmetadata(vi, vn), vn) istrans(md::Metadata, vn::VarName) = is_flagged(md, vn, "trans") getaccs(vi::VarInfo) = vi.accs -setaccs!!(vi::VarInfo, accs) = Accessors.@set vi.accs = accs +setaccs!!(vi::VarInfo, accs::AccumulatorTuple) = Accessors.@set vi.accs = accs """ get_num_produce(vi::VarInfo) From 3ee398994d46dcbf1e7781f1911816d18f36a930 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 22 Apr 2025 16:37:08 +0100 Subject: [PATCH 17/48] Revert back to calling the macro @addlogprob! --- docs/src/api.md | 2 +- src/DynamicPPL.jl | 1 - src/utils.jl | 22 ++++------------------ test/utils.jl | 4 ++-- 4 files changed, 7 insertions(+), 22 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 215c66ddb..5f9ecdc9f 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -163,7 +163,7 @@ returned(::Model) It is possible to manually increase (or decrease) the accumulated log density from within a model function. ```@docs -@addloglikelihood! +@addlogprob! ``` Return values of the model function for a collection of samples can be obtained with [`returned(model, chain)`](@ref). diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 455cc5a29..4b912e751 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -126,7 +126,6 @@ export AbstractVarInfo, to_submodel, # Convenience macros @addlogprob!, - @addloglikelihood!, @submodel, value_iterator_from_chain, check_model, diff --git a/src/utils.jl b/src/utils.jl index aaa36ad90..385145fc5 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -18,21 +18,7 @@ const LogProbType = float(Real) """ @addlogprob!(ex) -A deprecated alias for `@addloglikelihood!`. -""" -macro addlogprob!(ex) - return quote - depwarn( - "`@addlogprob!` is deprecated, use `@addloglikelihood!` instead.", :addlogprob! 
- ) - $(esc(:(__varinfo__))) = acclogp!!($(esc(:(__varinfo__))), $(esc(ex))) - end -end - -""" - @addloglikelihood!(ex) - -Add the result of the evaluation of `ex` to the joint log prior probability. +Add the result of the evaluation of `ex` to the log likelihood. # Examples @@ -43,7 +29,7 @@ julia> myloglikelihood(x, μ) = loglikelihood(Normal(μ, 1), x); julia> @model function demo(x) μ ~ Normal() - @addloglikelihood! myloglikelihood(x, μ) + @addlogprob! myloglikelihood(x, μ) end; julia> x = [1.3, -2.1]; @@ -58,7 +44,7 @@ and to [reject samples](https://github.com/TuringLang/Turing.jl/issues/1328): julia> @model function demo(x) m ~ MvNormal(zero(x), I) if dot(m, x) < 0 - @addloglikelihood! -Inf + @addlogprob! -Inf # Exit the model evaluation early return end @@ -70,7 +56,7 @@ julia> logjoint(demo([-2.1]), (m=[0.2],)) == -Inf true ``` """ -macro addloglikelihood!(ex) +macro addlogprob!(ex) return quote $(esc(:(__varinfo__))) = accloglikelihood!!($(esc(:(__varinfo__))), $(esc(ex))) end diff --git a/test/utils.jl b/test/utils.jl index 196effdf0..4aa2d9943 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,8 +1,8 @@ @testset "utils.jl" begin - @testset "addloglikelihood!" begin + @testset "addlogprob!" begin @model function testmodel() global lp_before = getlogjoint(__varinfo__) - @addloglikelihood!(42) + @addlogprob!(42) return global lp_after = getlogjoint(__varinfo__) end From 13163f29dcbfcdf0017d0b50b6a8cbe298d1baba Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 22 Apr 2025 16:43:51 +0100 Subject: [PATCH 18/48] Remove a dead test --- src/test_utils/contexts.jl | 30 ------------------------------ test/pointwise_logdensities.jl | 7 ------- 2 files changed, 37 deletions(-) diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index 46b4e477d..08acdfada 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -3,36 +3,6 @@ # # Utilities for testing contexts. -""" -Context that multiplies each log-prior by mod -used to test whether pointwise_logpriors respects child-context. -""" -struct TestLogModifyingChildContext{T,Ctx} <: DynamicPPL.AbstractContext - mod::T - context::Ctx -end -function TestLogModifyingChildContext( - mod=1.2, context::DynamicPPL.AbstractContext=DynamicPPL.DefaultContext() -) - return TestLogModifyingChildContext{typeof(mod),typeof(context)}(mod, context) -end - -DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent() -DynamicPPL.childcontext(context::TestLogModifyingChildContext) = context.context -function DynamicPPL.setchildcontext(context::TestLogModifyingChildContext, child) - return TestLogModifyingChildContext(context.mod, child) -end -function DynamicPPL.tilde_assume(context::TestLogModifyingChildContext, right, vn, vi) - value, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi) - return value, vi -end -function DynamicPPL.tilde_observe!!( - context::TestLogModifyingChildContext, right, left, vn, vi -) - vi = DynamicPPL.tilde_observe!!(context.context, right, left, vn, vi) - return vi -end - # Dummy context to test nested behaviors. 
struct TestParentContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext context::C diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index bb64d072f..3406a58e1 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -1,6 +1,4 @@ @testset "logdensities_likelihoods.jl" begin - mod_ctx = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2) - mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx) @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS[1:1] example_values = DynamicPPL.TestUtils.rand_prior_true(model) @@ -37,11 +35,6 @@ lps = pointwise_logdensities(model, vi) logp = sum(sum, values(lps)) @test logp ≈ (logprior_true + loglikelihood_true) - - # Test that modifications of Setup are picked up - lps = pointwise_logdensities(model, vi, mod_ctx2) - logp = sum(sum, values(lps)) - @test logp ≈ (logprior_true + loglikelihood_true) * 1.2 * 1.4 end end From 37dd6ddfb51fccdfbf4b0ac5ae83dc6ca8be8664 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 22 Apr 2025 16:44:01 +0100 Subject: [PATCH 19/48] Clarify a comment --- src/logdensityfunction.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 97c9ed2fd..a3eea9360 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -174,8 +174,9 @@ end Evaluate the log density of the given `model` at the given parameter values `x`, using the given `varinfo` and `context`. Note that the `varinfo` argument is provided -only for its structure, in the sense that the parameters from the vector `x` are inserted into -it, and its own parameters are discarded. +only for its structure, in the sense that the parameters from the vector `x` are inserted +into it, and its own parameters are discarded. It does, however, determine whether the log +prior, likelihood, or joint is returned, based on which accumulators are set in it. """ function logdensity_at( x::AbstractVector, model::Model, varinfo::AbstractVarInfo, context::AbstractContext From d7013b61a5abd9e0f701a0cd7d270f1ee4729a82 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 22 Apr 2025 18:01:07 +0100 Subject: [PATCH 20/48] Implement split/combine for PointwiseLogdensityAccumulator --- src/pointwise_logdensities.jl | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 256dc2b47..e608e523d 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -1,8 +1,27 @@ +""" + PointwiseLogProbAccumulator{whichlogprob,KeyType,D<:AbstractDict{KeyType}} <: AbstractAccumulator + +An accumulator that stores the log-probabilities of each variable in a model. + +Internally this context stores the log-probabilities in a dictionary, where the keys are +the variable names and the values are vectors of log-probabilities. Each element in a vector +corresponds to one execution of the model. + +`whichlogprob` is a symbol that can be `:both`, `:prior`, or `:likelihood`, and specifies +which log-probabilities to store in the accumulator. `KeyType` is the type by which variable +names are stored, and should be `String` or `VarName`. `D` is the type of the dictionary +used internally to store the log-probabilities, by default +`OrderedDict{KeyType, Vector{LogProbType}}`. 
+""" struct PointwiseLogProbAccumulator{whichlogprob,KeyType,D<:AbstractDict{KeyType}} <: AbstractAccumulator logps::D end +function PointwiseLogProbAccumulator{whichlogprob}(logps) where {whichlogprob} + return PointwiseLogProbAccumulator{whichlogprob,keytype(logps),typeof(logps)}(logps) +end + function PointwiseLogProbAccumulator{whichlogprob}() where {whichlogprob} return PointwiseLogProbAccumulator{whichlogprob,VarName}() end @@ -32,10 +51,16 @@ function accumulator_name( return Symbol("PointwiseLogProbAccumulator{$whichlogprob}") end -# TODO(mhauru) Implement these to make PointwiseLogProbAccumulator work with -# ThreadSafeVarInfo. -# split(::LogPrior{T}) where {T} = LogPrior(zero(T)) -# combine(acc::LogPrior, acc2::LogPrior) = LogPrior(acc.logp + acc2.logp) +function split(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob} + return PointwiseLogProbAccumulator{whichlogprob}(empty(acc.logps)) +end + +function combine( + acc::PointwiseLogProbAccumulator{whichlogprob}, + acc2::PointwiseLogProbAccumulator{whichlogprob}, +) where {whichlogprob} + return PointwiseLogProbAccumulator{whichlogprob}(mergewith(vcat, acc.logps, acc2.logps)) +end function accumulate_assume!!( acc::PointwiseLogProbAccumulator{whichlogprob}, val, logjac, vn, right From 40d4caa748f996f5824db004aeae6d0cc95f335b Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 22 Apr 2025 18:02:04 +0100 Subject: [PATCH 21/48] Switch ThreadSafeVarInfo.accs_by_thread to be a tuple --- src/threadsafe.jl | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 9155f5da5..311663574 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -4,14 +4,15 @@ A `ThreadSafeVarInfo` object wraps an [`AbstractVarInfo`](@ref) object and an array of accumulators for thread-safe execution of a probabilistic model. """ -struct ThreadSafeVarInfo{V<:AbstractVarInfo,L<:AccumulatorTuple} <: AbstractVarInfo +struct ThreadSafeVarInfo{V<:AbstractVarInfo,L<:NTuple{N,AccumulatorTuple} where {N}} <: + AbstractVarInfo varinfo::V - accs_by_thread::Vector{L} + accs_by_thread::L end + function ThreadSafeVarInfo(vi::AbstractVarInfo) - accs_by_thread = [ - AccumulatorTuple(map(split, vi.accs.nt)) for _ in 1:Threads.nthreads() - ] + split_accs = AccumulatorTuple(map(split, vi.accs.nt)) + accs_by_thread = ntuple(_ -> deepcopy(split_accs), Val(Threads.nthreads())) return ThreadSafeVarInfo(vi, accs_by_thread) end ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi @@ -50,16 +51,17 @@ end # _not_ be thread-specific a specific method has to be written. function map_accumulator!!(vi::ThreadSafeVarInfo, accname::Val, func::Function, args...) tid = Threads.threadid() - vi.accs_by_thread[tid] = map_accumulator!!( - vi.accs_by_thread[tid], accname, func, args... - ) - return vi + new_accs = map_accumulator!!(vi.accs_by_thread[tid], accname, func, args...) + accs_by_thread = BangBang.setindex!!(vi.accs_by_thread, new_accs, tid) + return ThreadSafeVarInfo(vi.varinfo, accs_by_thread) end function map_accumulator!!(vi::ThreadSafeVarInfo, func::Function, args...) tid = Threads.threadid() - vi.accs_by_thread[tid] = map_accumulator!!(vi.accs_by_thread[tid], func, args...) - return vi + new_accs = map_accumulator!!(vi.accs_by_thread[tid], func, args...) + # We need to use setindex!! in case new_accs has a different type than the original. 
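    # (Since `accs_by_thread` is now an NTuple rather than a Vector, it cannot be
    # mutated in place; we build a new tuple and return a new ThreadSafeVarInfo.)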
+ accs_by_thread = BangBang.setindex!!(vi.accs_by_thread, new_accs, tid) + return ThreadSafeVarInfo(vi.varinfo, accs_by_thread) end has_varnamedvector(vi::ThreadSafeVarInfo) = has_varnamedvector(vi.varinfo) @@ -185,13 +187,12 @@ end function resetlogp!!(vi::ThreadSafeVarInfo) vi = Accessors.@set vi.varinfo = resetlogp!!(vi.varinfo) - for i in eachindex(vi.accs_by_thread) - vi.accs_by_thread[i] = map_accumulator!!(vi.accs_by_thread[i], Val(:LogPrior), zero) - vi.accs_by_thread[i] = map_accumulator!!( - vi.accs_by_thread[i], Val(:LogLikelihood), zero - ) + accs_by_thread = map(vi.accs_by_thread) do accs + new_accs = map_accumulator!!(accs, Val(:LogPrior), zero) + new_accs = map_accumulator!!(new_accs, Val(:LogLikelihood), zero) + new_accs end - return vi + return ThreadSafeVarInfo(vi.varinfo, accs_by_thread) end values_as(vi::ThreadSafeVarInfo) = values_as(vi.varinfo) From ff5f2cba98aecac764288267f059c43e162729cb Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 23 Apr 2025 12:09:31 +0100 Subject: [PATCH 22/48] Fix `condition` and `fix` in submodels (#892) * Fix conditioning in submodels * Simplify contextual_isassumption * Add documentation * Fix some tests * Add tests; fix a bunch of nested submodel issues * Fix fix as well * Fix doctests * Add unit tests for new functions * Add changelog entry * Update changelog Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> * Finish docs * Add a test for conditioning submodel via arguments * Clean new tests up a bit * Fix for VarNames with non-identity lenses * Apply suggestions from code review Co-authored-by: Markus Hauru * Apply suggestions from code review * Make PrefixContext contain a varname rather than symbol (#896) --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> Co-authored-by: Markus Hauru --- HISTORY.md | 98 +++++-- docs/Project.toml | 1 + docs/make.jl | 4 +- docs/src/api.md | 14 +- docs/src/internals/submodel_condition.md | 356 +++++++++++++++++++++++ src/compiler.jl | 60 ++-- src/context_implementations.jl | 42 ++- src/contexts.jl | 211 ++++++++++++-- src/model.jl | 71 ++--- src/submodel_macro.jl | 4 +- src/utils.jl | 9 +- test/contexts.jl | 273 +++++++++++------ test/runtests.jl | 1 + test/submodels.jl | 199 +++++++++++++ 14 files changed, 1109 insertions(+), 234 deletions(-) create mode 100644 docs/src/internals/submodel_condition.md create mode 100644 test/submodels.jl diff --git a/HISTORY.md b/HISTORY.md index a45644a64..ac3e40970 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,38 +4,25 @@ **Breaking changes** -### AD testing utilities - -`DynamicPPL.TestUtils.AD.run_ad` now links the VarInfo by default. -To disable this, pass the `linked=false` keyword argument. -If the calculated value or gradient is incorrect, it also throws a `DynamicPPL.TestUtils.AD.ADIncorrectException` rather than a test failure. -This exception contains the actual and expected gradient so you can inspect it if needed; see the documentation for more information. -From a practical perspective, this means that if you need to add this to a test suite, you need to use `@test run_ad(...) isa Any` rather than just `run_ad(...)`. +### Submodels: conditioning -### SimpleVarInfo linking / invlinking - -Linking a linked SimpleVarInfo, or invlinking an unlinked SimpleVarInfo, now displays a warning instead of an error. +Variables in a submodel can now be conditioned and fixed in a correct way. 
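This applies to `fix` as well as `condition`: fixing a variable inside a submodel now also targets the un-prefixed name, just like the conditioning example shown below. A minimal sketch (the model name `outer_fixed` and the value are purely illustrative):

```julia
using DynamicPPL, Distributions

@model function inner()
    x ~ Normal()
    return y ~ Normal()
end

@model function outer_fixed()
    return a ~ to_submodel(DynamicPPL.fix(inner(), x=1.0))
end

keys(VarInfo(outer_fixed()))  # only a.y remains a random variable; a.x is fixed to 1.0
```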
+See https://github.com/TuringLang/DynamicPPL.jl/issues/857 for a full illustration, but essentially it means you can now do this: -### VarInfo constructors - -`VarInfo(vi::VarInfo, values)` has been removed. You can replace this directly with `unflatten(vi, values)` instead. - -The `metadata` argument to `VarInfo([rng, ]model[, sampler, context, metadata])` has been removed. -If you were not using this argument (most likely), then there is no change needed. -If you were using the `metadata` argument to specify a blank `VarNamedVector`, then you should replace calls to `VarInfo` with `DynamicPPL.typed_vector_varinfo` instead (see 'Other changes' below). - -The `UntypedVarInfo` constructor and type is no longer exported. -If you needed to construct one, you should now use `DynamicPPL.untyped_varinfo` instead. - -The `TypedVarInfo` constructor and type is no longer exported. -The _type_ has been replaced with `DynamicPPL.NTVarInfo`. -The _constructor_ has been replaced with `DynamicPPL.typed_varinfo`. +```julia +@model function inner() + x ~ Normal() + return y ~ Normal() +end +@model function outer() + return a ~ to_submodel(inner() | (x=1.0,)) +end +``` -Note that the exact kind of VarInfo returned by `VarInfo(rng, model, ...)` is an implementation detail. -Previously, it was guaranteed that this would always be a VarInfo whose metadata was a `NamedTuple` containing `Metadata` structs. -Going forward, this is no longer the case, and you should only assume that the returned object obeys the `AbstractVarInfo` interface. +and the `a.x` variable will be correctly conditioned. +(Previously, you would have to condition `inner()` with the variable `a.x`, meaning that you would need to know what prefix to use before you had actually prefixed it.) -### VarName prefixing behaviour +### Submodel prefixing The way in which VarNames in submodels are prefixed has been changed. This is best explained through an example. @@ -77,9 +64,62 @@ outer() | (@varname(var"a.x") => 1.0,) outer() | (a.x=1.0,) ``` -If you are sampling from a model with submodels, this doesn't affect the way you interact with the `MCMCChains.Chains` object, because VarNames are converted into Symbols when stored in the chain. +In a similar way, if the variable on the left-hand side of your tilde statement is not just a single identifier, any fields or indices it accesses are now properly respected. +Consider the following setup: + +```julia +using DynamicPPL, Distributions +@model inner() = x ~ Normal() +@model function outer() + a = Vector{Float64}(undef, 1) + a[1] ~ to_submodel(inner()) + return a +end +``` + +In this case, the variable sampled is actually the `x` field of the first element of `a`: + +```julia +julia> only(keys(VarInfo(outer()))) == @varname(a[1].x) +true +``` + +Before this version, it used to be a single variable called `var"a[1].x"`. + +Note that if you are sampling from a model with submodels, this doesn't affect the way you interact with the `MCMCChains.Chains` object, because VarNames are converted into Symbols when stored in the chain. (This behaviour will likely be changed in the future, in that Chains should be indexable by VarNames and not just Symbols, but that has not been implemented yet.) +### AD testing utilities + +`DynamicPPL.TestUtils.AD.run_ad` now links the VarInfo by default. +To disable this, pass the `linked=false` keyword argument. +If the calculated value or gradient is incorrect, it also throws a `DynamicPPL.TestUtils.AD.ADIncorrectException` rather than a test failure. 
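In a test suite this typically looks something like the following sketch (the positional arguments to `run_ad` are an assumption here; only the `linked` keyword is described in this section):

```julia
using DynamicPPL, Distributions, Test
using ADTypes: AutoForwardDiff
using ForwardDiff
using DynamicPPL.TestUtils.AD: run_ad

@model demo_ad() = x ~ Normal()

# Assert that the call completed without throwing an ADIncorrectException (or any other error):
@test run_ad(demo_ad(), AutoForwardDiff(); linked=false) isa Any
```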
+This exception contains the actual and expected gradient so you can inspect it if needed; see the documentation for more information. +From a practical perspective, this means that if you need to add this to a test suite, you need to use `@test run_ad(...) isa Any` rather than just `run_ad(...)`. + +### SimpleVarInfo linking / invlinking + +Linking a linked SimpleVarInfo, or invlinking an unlinked SimpleVarInfo, now displays a warning instead of an error. + +### VarInfo constructors + +`VarInfo(vi::VarInfo, values)` has been removed. You can replace this directly with `unflatten(vi, values)` instead. + +The `metadata` argument to `VarInfo([rng, ]model[, sampler, context, metadata])` has been removed. +If you were not using this argument (most likely), then there is no change needed. +If you were using the `metadata` argument to specify a blank `VarNamedVector`, then you should replace calls to `VarInfo` with `DynamicPPL.typed_vector_varinfo` instead (see 'Other changes' below). + +The `UntypedVarInfo` constructor and type is no longer exported. +If you needed to construct one, you should now use `DynamicPPL.untyped_varinfo` instead. + +The `TypedVarInfo` constructor and type is no longer exported. +The _type_ has been replaced with `DynamicPPL.NTVarInfo`. +The _constructor_ has been replaced with `DynamicPPL.typed_varinfo`. + +Note that the exact kind of VarInfo returned by `VarInfo(rng, model, ...)` is an implementation detail. +Previously, it was guaranteed that this would always be a VarInfo whose metadata was a `NamedTuple` containing `Metadata` structs. +Going forward, this is no longer the case, and you should only assume that the returned object obeys the `AbstractVarInfo` interface. + **Other changes** While these are technically breaking, they are only internal changes and do not affect the public API. diff --git a/docs/Project.toml b/docs/Project.toml index 40a719e03..93f449308 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,4 +1,5 @@ [deps] +AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/docs/make.jl b/docs/make.jl index c69b72fb8..7984fa1d1 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -24,7 +24,9 @@ makedocs(; format=Documenter.HTML(; size_threshold=2^10 * 400), modules=[DynamicPPL, Base.get_extension(DynamicPPL, :DynamicPPLMCMCChainsExt)], pages=[ - "Home" => "index.md", "API" => "api.md", "Internals" => ["internals/varinfo.md"] + "Home" => "index.md", + "API" => "api.md", + "Internals" => ["internals/varinfo.md", "internals/submodel_condition.md"], ], checkdocs=:exports, doctest=false, diff --git a/docs/src/api.md b/docs/src/api.md index ec741c9ad..08522e2ce 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -78,9 +78,9 @@ decondition ## Fixing and unfixing -We can also _fix_ a collection of variables in a [`Model`](@ref) to certain using [`fix`](@ref). +We can also _fix_ a collection of variables in a [`Model`](@ref) to certain values using [`DynamicPPL.fix`](@ref). 
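As a quick sketch of the practical difference (the model, names, and numbers below are purely illustrative; the bullet points that follow spell out the semantics):

```julia
using DynamicPPL, Distributions

@model function demo_fix_vs_condition()
    x ~ Normal()
    return y ~ Normal(x)
end

m_cond = demo_fix_vs_condition() | (x=1.0,)             # `x` is treated as an observation
m_fix = DynamicPPL.fix(demo_fix_vs_condition(), x=1.0)  # `x` is treated as a constant

# The conditioned model still accumulates `logpdf(Normal(), 1.0)` for `x`;
# the fixed model does not.
logjoint(m_cond, (y=0.5,)) ≈ logpdf(Normal(), 1.0) + logpdf(Normal(1.0), 0.5)  # true
logjoint(m_fix, (y=0.5,)) ≈ logpdf(Normal(1.0), 0.5)                           # true
```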
-This might seem quite similar to the aforementioned [`condition`](@ref) and its siblings, +This is quite similar to the aforementioned [`condition`](@ref) and its siblings, but they are indeed different operations: - `condition`ed variables are considered to be _observations_, and are thus @@ -89,19 +89,19 @@ but they are indeed different operations: - `fix`ed variables are considered to be _constant_, and are thus not included in any log-probability computations. -The differences are more clearly spelled out in the docstring of [`fix`](@ref) below. +The differences are more clearly spelled out in the docstring of [`DynamicPPL.fix`](@ref) below. ```@docs -fix +DynamicPPL.fix DynamicPPL.fixed ``` -The difference between [`fix`](@ref) and [`condition`](@ref) is described in the docstring of [`fix`](@ref) above. +The difference between [`DynamicPPL.fix`](@ref) and [`DynamicPPL.condition`](@ref) is described in the docstring of [`DynamicPPL.fix`](@ref) above. -Similarly, we can [`unfix`](@ref) variables, i.e. return them to their original meaning: +Similarly, we can revert this with [`DynamicPPL.unfix`](@ref), i.e. return the variables to their original meaning: ```@docs -unfix +DynamicPPL.unfix ``` ## Predicting diff --git a/docs/src/internals/submodel_condition.md b/docs/src/internals/submodel_condition.md new file mode 100644 index 000000000..ecb9d452b --- /dev/null +++ b/docs/src/internals/submodel_condition.md @@ -0,0 +1,356 @@ +# How `PrefixContext` and `ConditionContext` interact + +```@meta +ShareDefaultModule = true +``` + +## PrefixContext + +`PrefixContext` is a context that, as the name suggests, prefixes all variables inside a model with a given symbol. +Thus, for example: + +```@example +using DynamicPPL, Distributions + +@model function f() + x ~ Normal() + return y ~ Normal() +end + +@model function g() + return a ~ to_submodel(f()) +end +``` + +inside the submodel `f`, the variables `x` and `y` become `a.x` and `a.y` respectively. +This is easiest to observe by running the model: + +```@example +vi = VarInfo(g()) +keys(vi) +``` + +!!! note + + In this case, where `to_submodel` is called without any other arguments, the prefix to be used is automatically inferred from the name of the variable on the left-hand side of the tilde. + We will return to the 'manual prefixing' case later. + +The phrase 'becoming' a different variable is a little underspecified: it is useful to pinpoint the exact location where the prefixing occurs, which is `tilde_assume`. +The method responsible for it is `tilde_assume(::PrefixContext, right, vn, vi)`: this attaches the prefix in the context to the `VarName` argument, before recursively calling `tilde_assume` with the new prefixed `VarName`. +This means that even though a statement `x ~ dist` still enters the tilde pipeline at the top level as `x`, if the model evaluation context contains a `PrefixContext`, any function from `tilde_assume` onwards will see `a.x` instead. + +## ConditionContext + +`ConditionContext` is a context which stores values of variables that are to be conditioned on. +These values may be stored as a `Dict` which maps `VarName`s to values, or alternatively as a `NamedTuple`. +The latter only works correctly if all `VarName`s are 'basic', in that they have an identity optic (i.e., something like `a.x` or `a[1]` is forbidden). +Because of this limitation, we will only use `Dict` in this example. + +!!! 
note + + If a `ConditionContext` with a `NamedTuple` encounters anything to do with a prefix, its internal `NamedTuple` is converted to a `Dict` anyway, so it is quite reasonable to ignore the `NamedTuple` case in this exposition. + +One can inspect the conditioning values with, for example: + +```@example +@model function d() + x ~ Normal() + return y ~ Normal() +end + +cond_model = d() | (@varname(x) => 1.0) +cond_ctx = cond_model.context +``` + +There are several internal functions that are used to determine whether a variable is conditioned, and if so, what its value is. + +```@example +DynamicPPL.hasconditioned_nested(cond_ctx, @varname(x)) +``` + +```@example +DynamicPPL.getconditioned_nested(cond_ctx, @varname(x)) +``` + +These functions are in turn used by the function `DynamicPPL.contextual_isassumption`, which is largely the same as `hasconditioned_nested`, but also checks whether the value is `missing` (in which case it isn't really conditioned). + +```@example +DynamicPPL.contextual_isassumption(cond_ctx, @varname(x)) +``` + +!!! note + + Notice that (neglecting `missing` values) the return value of `contextual_isassumption` is the _opposite_ of `hasconditioned_nested`, i.e. for a variable that _is_ conditioned on, `contextual_isassumption` returns `false`. + +If a variable `x` is conditioned on, then the effect of this is to set the value of `x` to the given value (while still including its contribution to the log probability density). +Since `x` is no longer a random variable, if we were to evaluate the model, we would find only one key in the `VarInfo`: + +```@example +keys(VarInfo(cond_model)) +``` + +## Joint behaviour: desiderata at the model level + +When paired together, these two contexts have the potential to cause substantial confusion: `PrefixContext` modifies the variable names that are seen, which may cause them to be out of sync with the values contained inside the `ConditionContext`. + +We begin by mentioning some high-level desiderata for their joint behaviour. +Take these models, for example: + +```@example +# We define a helper function to unwrap a layer of SamplingContext, to +# avoid cluttering the print statements. +unwrap_sampling_context(ctx::DynamicPPL.SamplingContext) = ctx.context +unwrap_sampling_context(ctx::DynamicPPL.AbstractContext) = ctx +@model function inner() + println("inner context: $(unwrap_sampling_context(__context__))") + x ~ Normal() + return y ~ Normal() +end + +@model function outer() + println("outer context: $(unwrap_sampling_context(__context__))") + return a ~ to_submodel(inner()) +end + +# 'Outer conditioning' +with_outer_cond = outer() | (@varname(a.x) => 1.0) + +# 'Inner conditioning' +inner_cond = inner() | (@varname(x) => 1.0) +@model function outer2() + println("outer context: $(unwrap_sampling_context(__context__))") + return a ~ to_submodel(inner_cond) +end +with_inner_cond = outer2() +``` + +We want that: + + 1. `keys(VarInfo(outer()))` should return `[a.x, a.y]`; + 2. `keys(VarInfo(with_outer_cond))` should return `[a.y]`; + 3. 
`keys(VarInfo(with_inner_cond))` should return `[a.y]`, + +**In other words, we can condition submodels either from the outside (point (2)) or from the inside (point (3)), and the variable name we use to specify the conditioning should match the level at which we perform the conditioning.** + +This is an incredibly salient point because it means that submodels can be treated as individual, opaque objects, and we can condition them without needing to know what it will be prefixed with, or the context in which that submodel is being used. +For example, this means we can reuse `inner_cond` in another model with a different prefix, and it will _still_ have its inner `x` value be conditioned, despite the prefix differing. + +!!! info + + In the current version of DynamicPPL, these criteria are all fulfilled. However, this was not the case in the past: in particular, point (3) was not fulfilled, and users had to condition the internal submodel with the prefixes that were used outside. (See [this GitHub issue](https://github.com/TuringLang/DynamicPPL.jl/issues/857) for more information; this issue was the direct motivation for this documentation page.) + +## Desiderata at the context level + +The above section describes how we expect conditioning and prefixing to behave from a user's perpective. +We now turn to the question of how we implement this in terms of DynamicPPL contexts. +We do not specify the implementation details here, but we will sketch out something resembling an API that will allow us to achieve the target behaviour. + +**Point (1)** does not involve any conditioning, only prefixing; it is therefore already satisfied by virtue of the `tilde_assume` method shown above. + +**Points (2) and (3)** are more tricky. +As the reader may surmise, the difference between them is the order in which the contexts are stacked. + +For the _outer_ conditioning case (point (2)), the `ConditionContext` will contain a `VarName` that is already prefixed. +When we enter the inner submodel, this `ConditionContext` has to be passed down and somehow combined with the `PrefixContext` that is created when we enter the submodel. +We make the claim here that the best way to do this is to nest the `PrefixContext` _inside_ the `ConditionContext`. +This is indeed what happens, as can be demonstrated by running the model. + +```@example +with_outer_cond(); +nothing; +``` + +!!! info + + The `; nothing` at the end is purely to circumvent a Documenter.jl quirk where stdout is only shown if the return value of the final statement is `nothing`. + If these documentation pages are moved to Quarto, it will be possible to remove this. + +For the _inner_ conditioning case (point (3)), the outer model is not run with any special context. +The inner model will itself contain a `ConditionContext` will contain a `VarName` that is not prefixed. +When we run the model, this `ConditionContext` should be then nested _inside_ a `PrefixContext` to form the final evaluation context. 
+Again, we can run the model to see this in action: + +```@example +with_inner_cond(); +nothing; +``` + +Putting all of the information so far together, what it means is that if we have these two inner contexts (taken from above): + +```@example +using DynamicPPL: PrefixContext, ConditionContext, DefaultContext + +inner_ctx_with_outer_cond = ConditionContext( + Dict(@varname(a.x) => 1.0), PrefixContext(@varname(a)) +) +inner_ctx_with_inner_cond = PrefixContext( + @varname(a), ConditionContext(Dict(@varname(x) => 1.0)) +) +``` + +then we want both of these to be `true` (and thankfully, they are!): + +```@example +DynamicPPL.hasconditioned_nested(inner_ctx_with_outer_cond, @varname(a.x)) +``` + +```@example +DynamicPPL.hasconditioned_nested(inner_ctx_with_inner_cond, @varname(a.x)) +``` + +This allows us to finally specify our task as follows: + +(1) Given the correct arguments, we need to make sure that `hasconditioned_nested` and `getconditioned_nested` behave correctly. + +(2) We need to make sure that both the correct arguments are supplied. In order to do so: + + - (2a) We need to make sure that when evaluating a submodel, the context stack is arranged such that `PrefixContext` is applied _inside_ the parent model's context, but _outside_ the submodel's own context. + + - (2b) We also need to make sure that the `VarName` passed to it is prefixed correctly. + +## How do we do it? + +(1) `hasconditioned_nested` and `getconditioned_nested` accomplish this by first 'collapsing' the context stack, i.e. they go through the context stack, remove all `PrefixContext`s, and apply those prefixes to any conditioned variables below it in the stack. +Once the `PrefixContext`s have been removed, one can then iterate through the context stack and check if any of the `ConditionContext`s contain the variable, or get the value itself. +For more details the reader is encouraged to read the source code. + +(2a) We ensure that the context stack is correctly arranged by relying on the behaviour of `make_evaluate_args_and_kwargs`. +This function is called whenever a model (which itself contains a context) is evaluated with a separate ('external') context, and makes sure to arrange both of these contexts such that _the model's context is nested inside the external context_. +Thus, as long as prefixing is implemented by applying a `PrefixContext` on the outermost layer of the _inner_ model context, this will be correctly combined with an external context to give the behaviour seen above. + +(2b) At first glance, it seems like `tilde_assume` can take care of the `VarName` prefixing for us (as described in the first section). +However, this is not actually the case: `contextual_isassumption`, which is the function that calls `hasconditioned_nested`, is much higher in the call stack than `tilde_assume` is. +So, we need to explicitly prefix it before passing it to `contextual_isassumption`. +This is done inside the `@model` macro, or technically, its subsidiary function `isassumption`. + +## Nested submodels + +Just in case the above wasn't complicated enough, we need to also be very careful when dealing with nested submodels, which have multiple layers of `PrefixContext`s which may be interspersed with `ConditionContext`s. 
+For example, in this series of nested submodels, + +```@example +@model function charlie() + x ~ Normal() + y ~ Normal() + return z ~ Normal() +end +@model function bravo() + return b ~ to_submodel(charlie() | (@varname(x) => 1.0)) +end +@model function alpha() + return a ~ to_submodel(bravo() | (@varname(b.y) => 1.0)) +end +``` + +we expect that the only variable to be sampled should be `z` inside `charlie`, or rather, `a.b.z` once it has been through the prefixes. + +```@example +keys(VarInfo(alpha())) +``` + +The general strategy that we adopt is similar to above. +Following the principle that `PrefixContext` should be nested inside the outer context, but outside the inner submodel's context, we can infer that the correct context inside `charlie` should be: + +```@example +big_ctx = PrefixContext( + @varname(a), + ConditionContext( + Dict(@varname(b.y) => 1.0), + PrefixContext(@varname(b), ConditionContext(Dict(@varname(x) => 1.0))), + ), +) +``` + +We need several things to work correctly here: we need the `VarName` prefixing to behave correctly, and then we need to implement `hasconditioned_nested` and `getconditioned_nested` on the resulting prefixed `VarName`. +It turns out that the prefixing itself is enough to illustrate the most important point in this section, namely, the need to traverse the context stack in a _different direction_ to what most of DynamicPPL does. + +Let's work with a function called `myprefix(::AbstractContext, ::VarName)` (to avoid confusion with any existing DynamicPPL function). +We should like `myprefix(big_ctx, @varname(x))` to return `@varname(a.b.x)`. +Consider the following naive implementation, which mirrors a lot of code in the tilde-pipeline: + +```@example +using DynamicPPL: NodeTrait, IsLeaf, IsParent, childcontext, AbstractContext +using AbstractPPL: AbstractPPL + +function myprefix(ctx::DynamicPPL.AbstractContext, vn::VarName) + return myprefix(NodeTrait(ctx), ctx, vn) +end +function myprefix(::IsLeaf, ::AbstractContext, vn::VarName) + return vn +end +function myprefix(::IsParent, ctx::AbstractContext, vn::VarName) + return myprefix(childcontext(ctx), vn) +end +function myprefix(ctx::DynamicPPL.PrefixContext, vn::VarName) + # The functionality to actually manipulate the VarNames is in AbstractPPL + new_vn = AbstractPPL.prefix(vn, ctx.vn_prefix) + # Then pass to the child context + return myprefix(childcontext(ctx), new_vn) +end + +myprefix(big_ctx, @varname(x)) +``` + +This implementation clearly is not correct, because it applies the _inner_ `PrefixContext` before the outer one. + +The right way to implement `myprefix` is to, essentially, reverse the order of two lines above: + +```@example +function myprefix(ctx::DynamicPPL.PrefixContext, vn::VarName) + # Pass to the child context first + new_vn = myprefix(childcontext(ctx), vn) + # Then apply this context's prefix + return AbstractPPL.prefix(new_vn, ctx.vn_prefix) +end + +myprefix(big_ctx, @varname(x)) +``` + +This is a much better result! +The implementation of related functions such as `hasconditioned_nested` and `getconditioned_nested`, under the hood, use a similar recursion scheme, so you will find that this is a common pattern when reading the source code of various prefixing-related functions. +When editing this code, it is worth being mindful of this as a potential source of incorrectness. + +!!! 
info + + If you have encountered left and right folds, the above discussion illustrates the difference between them: the wrong implementation of `myprefix` uses a left fold (which collects prefixes in the opposite order from which they are encountered), while the correct implementation uses a right fold. + +## Loose ends 1: Manual prefixing + +Sometimes users may want to manually prefix a model, for example: + +```@example +@model function inner_manual() + x ~ Normal() + return y ~ Normal() +end + +@model function outer_manual() + return _unused ~ to_submodel(prefix(inner_manual(), :a), false) +end +``` + +In this case, the `VarName` on the left-hand side of the tilde is not used, and the prefix is instead specified using the `prefix` function. + +The way to deal with this follows on from the previous discussion. +Specifically, we said that: + +> [...] as long as prefixing is implemented by applying a `PrefixContext` on the outermost layer of the _inner_ model context, this will be correctly combined [...] + +When automatic prefixing is used, this application of `PrefixContext` occurs inside the `tilde_assume!!` method. +In the manual prefixing case, we need to make sure that `prefix(submodel::Model, ::Symbol)` does the same thing, i.e. it inserts a `PrefixContext` at the outermost layer of `submodel`'s context. +We can see that this is precisely what happens: + +```@example +@model f() = x ~ Normal() + +model = f() +prefixed_model = prefix(model, :a) + +(model.context, prefixed_model.context) +``` + +## Loose ends 2: FixedContext + +Finally, note that all of the above also applies to the interaction between `PrefixContext` and `FixedContext`, except that the functions have different names. +(`FixedContext` behaves the same way as `ConditionContext`, except that unlike conditioned variables, fixed variables do not contribute to the log probability density.) +This generally results in a large amount of code duplication, but the concepts that underlie both contexts are exactly the same. diff --git a/src/compiler.jl b/src/compiler.jl index 4771b0171..6f7489b8e 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -53,7 +53,9 @@ function isassumption( vn=AbstractPPL.drop_escape(varname(expr, need_concretize(expr))), ) return quote - if $(DynamicPPL.contextual_isassumption)(__context__, $vn) + if $(DynamicPPL.contextual_isassumption)( + __context__, $(DynamicPPL.prefix)(__context__, $vn) + ) # Considered an assumption by `__context__` which means either: # 1. We hit the default implementation, e.g. using `DefaultContext`, # which in turn means that we haven't considered if it's one of @@ -87,67 +89,45 @@ isassumption(expr) = :(false) contextual_isassumption(context, vn) Return `true` if `vn` is considered an assumption by `context`. - -The default implementation for `AbstractContext` always returns `true`. """ -contextual_isassumption(::IsLeaf, context, vn) = true -function contextual_isassumption(::IsParent, context, vn) - return contextual_isassumption(childcontext(context), vn) -end function contextual_isassumption(context::AbstractContext, vn) - return contextual_isassumption(NodeTrait(context), context, vn) -end -function contextual_isassumption(context::ConditionContext, vn) - if hasconditioned(context, vn) - val = getconditioned(context, vn) + if hasconditioned_nested(context, vn) + val = getconditioned_nested(context, vn) # TODO: Do we even need the `>: Missing`, i.e. does it even help the compiler? 
if eltype(val) >: Missing && val === missing return true else return false end + else + return true end - - # We might have nested contexts, e.g. `ConditionContext{.., <:PrefixContext{..., <:ConditionContext}}` - # so we defer to `childcontext` if we haven't concluded that anything yet. - return contextual_isassumption(childcontext(context), vn) -end -function contextual_isassumption(context::PrefixContext, vn) - return contextual_isassumption(childcontext(context), prefix(context, vn)) end isfixed(expr, vn) = false -isfixed(::Union{Symbol,Expr}, vn) = :($(DynamicPPL.contextual_isfixed)(__context__, $vn)) +function isfixed(::Union{Symbol,Expr}, vn) + return :($(DynamicPPL.contextual_isfixed)( + __context__, $(DynamicPPL.prefix)(__context__, $vn) + )) +end """ contextual_isfixed(context, vn) Return `true` if `vn` is considered fixed by `context`. """ -contextual_isfixed(::IsLeaf, context, vn) = false -function contextual_isfixed(::IsParent, context, vn) - return contextual_isfixed(childcontext(context), vn) -end function contextual_isfixed(context::AbstractContext, vn) - return contextual_isfixed(NodeTrait(context), context, vn) -end -function contextual_isfixed(context::PrefixContext, vn) - return contextual_isfixed(childcontext(context), prefix(context, vn)) -end -function contextual_isfixed(context::FixedContext, vn) - if hasfixed(context, vn) - val = getfixed(context, vn) + if hasfixed_nested(context, vn) + val = getfixed_nested(context, vn) # TODO: Do we even need the `>: Missing`, i.e. does it even help the compiler? if eltype(val) >: Missing && val === missing return false else return true end + else + return false end - - # We might have nested contexts, e.g. `FixedContext{.., <:PrefixContext{..., <:FixedContext}}` - # so we defer to `childcontext` if we haven't concluded that anything yet. - return contextual_isfixed(childcontext(context), vn) end # If we're working with, say, a `Symbol`, then we're not going to `view`. @@ -467,13 +447,17 @@ function generate_tilde(left, right) ) $isassumption = $(DynamicPPL.isassumption(left, vn)) if $(DynamicPPL.isfixed(left, vn)) - $left = $(DynamicPPL.getfixed_nested)(__context__, $vn) + $left = $(DynamicPPL.getfixed_nested)( + __context__, $(DynamicPPL.prefix)(__context__, $vn) + ) elseif $isassumption $(generate_tilde_assume(left, dist, vn)) else # If `vn` is not in `argnames`, we need to make sure that the variable is defined. if !$(DynamicPPL.inargnames)($vn, __model__) - $left = $(DynamicPPL.getconditioned_nested)(__context__, $vn) + $left = $(DynamicPPL.getconditioned_nested)( + __context__, $(DynamicPPL.prefix)(__context__, $vn) + ) end $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( diff --git a/src/context_implementations.jl b/src/context_implementations.jl index e4ba5d252..eb025dec8 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -85,12 +85,23 @@ function tilde_assume(rng::Random.AbstractRNG, ::LikelihoodContext, sampler, rig end function tilde_assume(context::PrefixContext, right, vn, vi) - return tilde_assume(context.context, right, prefix(context, vn), vi) + # Note that we can't use something like this here: + # new_vn = prefix(context, vn) + # return tilde_assume(childcontext(context), right, new_vn, vi) + # This is because `prefix` applies _all_ prefixes in a given context to a + # variable name. Thus, if we had two levels of nested prefixes e.g. 
+ # `PrefixContext{:a}(PrefixContext{:b}(DefaultContext()))`, then the + # first call would apply the prefix `a.b._`, and the recursive call + # would apply the prefix `b._`, resulting in `b.a.b._`. + # This is why we need a special function, `prefix_and_strip_contexts`. + new_vn, new_context = prefix_and_strip_contexts(context, vn) + return tilde_assume(new_context, right, new_vn, vi) end function tilde_assume( rng::Random.AbstractRNG, context::PrefixContext, sampler, right, vn, vi ) - return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), vi) + new_vn, new_context = prefix_and_strip_contexts(context, vn) + return tilde_assume(rng, new_context, sampler, right, new_vn, vi) end """ @@ -104,12 +115,27 @@ probability of `vi` with the returned value. """ function tilde_assume!!(context, right, vn, vi) return if is_rhs_model(right) - # Prefix the variables using the `vn`. - rand_like!!( - right, - should_auto_prefix(right) ? PrefixContext{Symbol(vn)}(context) : context, - vi, - ) + # Here, we apply the PrefixContext _not_ to the parent `context`, but + # to the context of the submodel being evaluated. This means that later= + # on in `make_evaluate_args_and_kwargs`, the context stack will be + # correctly arranged such that it goes like this: + # parent_context[1] -> parent_context[2] -> ... -> PrefixContext -> + # submodel_context[1] -> submodel_context[2] -> ... -> leafcontext + # See the docstring of `make_evaluate_args_and_kwargs`, and the internal + # DynamicPPL documentation on submodel conditioning, for more details. + # + # NOTE: This relies on the existence of `right.model.model`. Right now, + # the only thing that can return true for `is_rhs_model` is something + # (a `Sampleable`) that has a `model` field that itself (a + # `ReturnedModelWrapper`) has a `model` field. This may or may not + # change in the future. + if should_auto_prefix(right) + dppl_model = right.model.model # This isa DynamicPPL.Model + prefixed_submodel_context = PrefixContext(vn, dppl_model.context) + new_dppl_model = contextualize(dppl_model, prefixed_submodel_context) + right = to_submodel(new_dppl_model, true) + end + rand_like!!(right, context, vi) else value, logp, vi = tilde_assume(context, right, vn, vi) value, acclogp_assume!!(context, vi, logp) diff --git a/src/contexts.jl b/src/contexts.jl index 58ac612b8..8ac085663 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -237,27 +237,34 @@ function setchildcontext(parent::MiniBatchContext, child) end """ - PrefixContext{Prefix}(context) + PrefixContext(vn::VarName[, context::AbstractContext]) + PrefixContext(vn::Val{sym}[, context::AbstractContext]) where {sym} Create a context that allows you to use the wrapped `context` when running the model and -adds the `Prefix` to all parameters. +prefixes all parameters with the VarName `vn`. + +`PrefixContext(Val(:a), context)` is equivalent to `PrefixContext(@varname(a), context)`. +If `context` is not provided, it defaults to `DefaultContext()`. This context is useful in nested models to ensure that the names of the parameters are unique. 
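A small sketch of the new constructors (`DynamicPPL.prefix` is used here only to show their effect, it is not part of the construction itself):

```julia
using DynamicPPL
using DynamicPPL: PrefixContext, DefaultContext

ctx1 = PrefixContext(@varname(a))                # child context defaults to DefaultContext()
ctx2 = PrefixContext(Val(:a), DefaultContext())  # equivalent construction via Val

DynamicPPL.prefix(ctx1, @varname(x)) == @varname(a.x)  # true
```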
See also: [`to_submodel`](@ref) """ -struct PrefixContext{Prefix,C} <: AbstractContext +struct PrefixContext{Tvn<:VarName,C<:AbstractContext} <: AbstractContext + vn_prefix::Tvn context::C end -function PrefixContext{Prefix}(context::AbstractContext) where {Prefix} - return PrefixContext{Prefix,typeof(context)}(context) +PrefixContext(vn::VarName) = PrefixContext(vn, DefaultContext()) +function PrefixContext(::Val{sym}, context::AbstractContext) where {sym} + return PrefixContext(VarName{sym}(), context) end +PrefixContext(::Val{sym}) where {sym} = PrefixContext(VarName{sym}()) NodeTrait(::PrefixContext) = IsParent() childcontext(context::PrefixContext) = context.context -function setchildcontext(::PrefixContext{Prefix}, child) where {Prefix} - return PrefixContext{Prefix}(child) +function setchildcontext(ctx::PrefixContext, child::AbstractContext) + return PrefixContext(ctx.vn_prefix, child) end """ @@ -265,8 +272,8 @@ end Apply the prefixes in the context `ctx` to the variable name `vn`. """ -function prefix(ctx::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} - return AbstractPPL.prefix(prefix(childcontext(ctx), vn), VarName{Symbol(Prefix)}()) +function prefix(ctx::PrefixContext, vn::VarName) + return AbstractPPL.prefix(prefix(childcontext(ctx), vn), ctx.vn_prefix) end function prefix(ctx::AbstractContext, vn::VarName) return prefix(NodeTrait(ctx), ctx, vn) @@ -277,11 +284,52 @@ function prefix(::IsParent, ctx::AbstractContext, vn::VarName) end """ - prefix(model::Model, x) + prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) + +Same as `prefix`, but additionally returns a new context stack that has all the +PrefixContexts removed. -Return `model` but with all random variables prefixed by `x`. +NOTE: This does _not_ modify any variables in any `ConditionContext` and +`FixedContext` that may be present in the context stack. This is because this +function is only used in `tilde_assume`, which is lower in the tilde-pipeline +than `contextual_isassumption` and `contextual_isfixed` (the functions which +actually use the `ConditionContext` and `FixedContext` values). Thus, by this +time, any `ConditionContext`s and `FixedContext`s present have already served +their purpose. -If `x` is known at compile-time, use `Val{x}()` to avoid runtime overheads for prefixing. +If you call this function, you must therefore be careful to ensure that you _do +not_ need to modify any inner `ConditionContext`s and `FixedContext`s. If you +_do_ need to modify them, then you may need to use +`prefix_cond_and_fixed_variables` instead. 
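A minimal sketch of the behaviour (mirroring the pattern used in the test suite):

```julia
using DynamicPPL
using DynamicPPL: PrefixContext, ConditionContext, prefix_and_strip_contexts

ctx = PrefixContext(@varname(a), ConditionContext((x=1.0,)))
vn, stripped = prefix_and_strip_contexts(ctx, @varname(x))
vn == @varname(a.x)            # true: the prefix has been applied to the VarName
stripped isa ConditionContext  # true: the PrefixContext is gone, while the
                               # conditioned values are left un-prefixed
```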
+""" +function prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) + child_context = childcontext(ctx) + # vn_prefixed contains the prefixes from all lower levels + vn_prefixed, child_context_without_prefixes = prefix_and_strip_contexts( + child_context, vn + ) + return AbstractPPL.prefix(vn_prefixed, ctx.vn_prefix), child_context_without_prefixes +end +function prefix_and_strip_contexts(ctx::AbstractContext, vn::VarName) + return prefix_and_strip_contexts(NodeTrait(ctx), ctx, vn) +end +prefix_and_strip_contexts(::IsLeaf, ctx::AbstractContext, vn::VarName) = (vn, ctx) +function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName) + vn, new_ctx = prefix_and_strip_contexts(childcontext(ctx), vn) + return vn, setchildcontext(ctx, new_ctx) +end + +""" + prefix(model::Model, x::VarName) + prefix(model::Model, x::Val{sym}) + prefix(model::Model, x::Any) + +Return `model` but with all random variables prefixed by `x`, where `x` is either: +- a `VarName` (e.g. `@varname(a)`), +- a `Val{sym}` (e.g. `Val(:a)`), or +- for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that + this will introduce runtime overheads so is not recommended unless absolutely + necessary. # Examples @@ -291,17 +339,19 @@ julia> using DynamicPPL: prefix julia> @model demo() = x ~ Dirac(1) demo (generic function with 2 methods) -julia> rand(prefix(demo(), :my_prefix)) +julia> rand(prefix(demo(), @varname(my_prefix))) (var"my_prefix.x" = 1,) -julia> # One can also use `Val` to avoid runtime overheads. - rand(prefix(demo(), Val(:my_prefix))) +julia> rand(prefix(demo(), Val(:my_prefix))) (var"my_prefix.x" = 1,) ``` """ -prefix(model::Model, x) = contextualize(model, PrefixContext{Symbol(x)}(model.context)) -function prefix(model::Model, ::Val{x}) where {x} - return contextualize(model, PrefixContext{Symbol(x)}(model.context)) +prefix(model::Model, x::VarName) = contextualize(model, PrefixContext(x, model.context)) +function prefix(model::Model, x::Val{sym}) where {sym} + return contextualize(model, PrefixContext(VarName{sym}(), model.context)) +end +function prefix(model::Model, x) + return contextualize(model, PrefixContext(VarName{Symbol(x)}(), model.context)) end """ @@ -370,7 +420,9 @@ Return value of `vn` in `context`. 
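For example (a small sketch):

```julia
using DynamicPPL
using DynamicPPL: ConditionContext, getconditioned

getconditioned(ConditionContext(Dict(@varname(x) => 1.0)), @varname(x))  # returns 1.0
```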
function getconditioned(context::AbstractContext, vn::VarName) return error("context $(context) does not contain value for $vn") end -getconditioned(context::ConditionContext, vn::VarName) = getvalue(context.values, vn) +function getconditioned(context::ConditionContext, vn::VarName) + return getvalue(context.values, vn) +end """ hasconditioned_nested(context, vn) @@ -388,7 +440,7 @@ function hasconditioned_nested(::IsParent, context, vn) return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn) end function hasconditioned_nested(context::PrefixContext, vn) - return hasconditioned_nested(childcontext(context), prefix(context, vn)) + return hasconditioned_nested(collapse_prefix_stack(context), vn) end """ @@ -406,7 +458,7 @@ function getconditioned_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end function getconditioned_nested(context::PrefixContext, vn) - return getconditioned_nested(childcontext(context), prefix(context, vn)) + return getconditioned_nested(collapse_prefix_stack(context), vn) end function getconditioned_nested(::IsParent, context, vn) return if hasconditioned(context, vn) @@ -476,6 +528,9 @@ function conditioned(context::ConditionContext) # precedence over decendants of `context`. return _merge(context.values, conditioned(childcontext(context))) end +function conditioned(context::PrefixContext) + return conditioned(collapse_prefix_stack(context)) +end struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractContext values::Values @@ -539,7 +594,7 @@ function hasfixed_nested(::IsParent, context, vn) return hasfixed(context, vn) || hasfixed_nested(childcontext(context), vn) end function hasfixed_nested(context::PrefixContext, vn) - return hasfixed_nested(childcontext(context), prefix(context, vn)) + return hasfixed_nested(collapse_prefix_stack(context), vn) end """ @@ -557,7 +612,7 @@ function getfixed_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end function getfixed_nested(context::PrefixContext, vn) - return getfixed_nested(childcontext(context), prefix(context, vn)) + return getfixed_nested(collapse_prefix_stack(context), vn) end function getfixed_nested(::IsParent, context, vn) return if hasfixed(context, vn) @@ -652,3 +707,113 @@ function fixed(context::FixedContext) # precedence over decendants of `context`. return _merge(context.values, fixed(childcontext(context))) end +function fixed(context::PrefixContext) + return fixed(collapse_prefix_stack(context)) +end + +""" + collapse_prefix_stack(context::AbstractContext) + +Apply `PrefixContext`s to any conditioned or fixed values inside them, and remove +the `PrefixContext`s from the context stack. + +!!! note + If you are reading this docstring, you might probably be interested in a more +thorough explanation of how PrefixContext and ConditionContext / FixedContext +interact with one another, especially in the context of submodels. + The DynamicPPL documentation contains [a separate page on this +topic](https://turinglang.org/DynamicPPL.jl/previews/PR892/internals/submodel_condition/) +which explains this in much more detail. + +```jldoctest +julia> using DynamicPPL: collapse_prefix_stack + +julia> c1 = PrefixContext(@varname(a), ConditionContext((x=1, ))); + +julia> collapse_prefix_stack(c1) +ConditionContext(Dict(a.x => 1), DefaultContext()) + +julia> # Here, `x` gets prefixed only with `a`, whereas `y` is prefixed with both. 
+ c2 = PrefixContext(@varname(a), ConditionContext((x=1, ), PrefixContext(@varname(b), ConditionContext((y=2,))))); + +julia> collapsed = collapse_prefix_stack(c2); + +julia> # `collapsed` really looks something like this: + # ConditionContext(Dict{VarName{:a}, Int64}(a.b.y => 2, a.x => 1), DefaultContext()) + # To avoid fragility arising from the order of the keys in the doctest, we test + # this indirectly: + collapsed.values[@varname(a.x)], collapsed.values[@varname(a.b.y)] +(1, 2) +``` +""" +function collapse_prefix_stack(context::PrefixContext) + # Collapse the child context (thus applying any inner prefixes first) + collapsed = collapse_prefix_stack(childcontext(context)) + # Prefix any conditioned variables with the current prefix + # Note: prefix_conditioned_variables is O(N) in the depth of the context stack. + # So is this function. In the worst case scenario, this is O(N^2) in the + # depth of the context stack. + return prefix_cond_and_fixed_variables(collapsed, context.vn_prefix) +end +function collapse_prefix_stack(context::AbstractContext) + return collapse_prefix_stack(NodeTrait(collapse_prefix_stack, context), context) +end +collapse_prefix_stack(::IsLeaf, context) = context +function collapse_prefix_stack(::IsParent, context) + new_child_context = collapse_prefix_stack(childcontext(context)) + return setchildcontext(context, new_child_context) +end + +""" + prefix_cond_and_fixed_variables(context::AbstractContext, prefix::VarName) + +Prefix all the conditioned and fixed variables in a given context with a single +`prefix`. + +```jldoctest +julia> using DynamicPPL: prefix_cond_and_fixed_variables, ConditionContext + +julia> c1 = ConditionContext((a=1, )) +ConditionContext((a = 1,), DefaultContext()) + +julia> prefix_cond_and_fixed_variables(c1, @varname(y)) +ConditionContext(Dict(y.a => 1), DefaultContext()) +``` +""" +function prefix_cond_and_fixed_variables(ctx::ConditionContext, prefix::VarName) + # Replace the prefix of the conditioned variables + vn_dict = to_varname_dict(ctx.values) + prefixed_vn_dict = Dict( + AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict + ) + # Prefix the child context as well + prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) + return ConditionContext(prefixed_vn_dict, prefixed_child_ctx) +end +function prefix_cond_and_fixed_variables(ctx::FixedContext, prefix::VarName) + # Replace the prefix of the conditioned variables + vn_dict = to_varname_dict(ctx.values) + prefixed_vn_dict = Dict( + AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict + ) + # Prefix the child context as well + prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) + return FixedContext(prefixed_vn_dict, prefixed_child_ctx) +end +function prefix_cond_and_fixed_variables(c::AbstractContext, prefix::VarName) + return prefix_cond_and_fixed_variables( + NodeTrait(prefix_cond_and_fixed_variables, c), c, prefix + ) +end +function prefix_cond_and_fixed_variables( + ::IsLeaf, context::AbstractContext, prefix::VarName +) + return context +end +function prefix_cond_and_fixed_variables( + ::IsParent, context::AbstractContext, prefix::VarName +) + return setchildcontext( + context, prefix_cond_and_fixed_variables(childcontext(context), prefix) + ) +end diff --git a/src/model.jl b/src/model.jl index b4d5f6bb7..c7c4bdf57 100644 --- a/src/model.jl +++ b/src/model.jl @@ -425,29 +425,32 @@ julia> # Returns all the variables we have conditioned on + their values. 
conditioned(condition(m, x=100.0, m=1.0)) (x = 100.0, m = 1.0) -julia> # Nested ones also work (note that `PrefixContext` does nothing to the result). - cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext((m=1.0,)))), x=100.0); +julia> # Nested ones also work. + # (Note that `PrefixContext` also prefixes the variables of any + # ConditionContext that is _inside_ it; because of this, the type of the + # container has to be broadened to a `Dict`.) + cm = condition(contextualize(m, PrefixContext(@varname(a), ConditionContext((m=1.0,)))), x=100.0); -julia> conditioned(cm) -(x = 100.0, m = 1.0) +julia> Set(keys(conditioned(cm))) == Set([@varname(a.m), @varname(x)]) +true -julia> # Since we conditioned on `m`, not `a.m` as it will appear after prefixed, - # `a.m` is treated as a random variable. +julia> # Since we conditioned on `a.m`, it is not treated as a random variable. + # However, `a.x` will still be a random variable. keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:m}}}: - a.m - -julia> # If we instead condition on `a.m`, `m` in the model will be considered an observation. - cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext(Dict(@varname(a.m) => 1.0)))), x=100.0); +1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: + a.x -julia> conditioned(cm)[@varname(x)] -100.0 +julia> # We can also condition on `a.m` _outside_ of the PrefixContext: + cm = condition(contextualize(m, PrefixContext(@varname(a))), (@varname(a.m) => 1.0)); -julia> conditioned(cm)[@varname(a.m)] -1.0 +julia> conditioned(cm) +Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry: + a.m => 1.0 -julia> keys(VarInfo(cm)) # No variables are sampled -VarName[] +julia> # Now `a.x` will be sampled. + keys(VarInfo(cm)) +1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: + a.x ``` """ conditioned(model::Model) = conditioned(model.context) @@ -765,29 +768,27 @@ julia> # Returns all the variables we have fixed on + their values. fixed(fix(m, x=100.0, m=1.0)) (x = 100.0, m = 1.0) -julia> # Nested ones also work (note that `PrefixContext` does nothing to the result). - cm = fix(contextualize(m, PrefixContext{:a}(fix(m=1.0))), x=100.0); +julia> # The rest of this is the same as the `condition` example above. + cm = fix(contextualize(m, PrefixContext(@varname(a), fix(m=1.0))), x=100.0); -julia> fixed(cm) -(x = 100.0, m = 1.0) - -julia> # Since we fixed on `m`, not `a.m` as it will appear after prefixed, - # `a.m` is treated as a random variable. - keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:m}}}: - a.m +julia> Set(keys(fixed(cm))) == Set([@varname(a.m), @varname(x)]) +true -julia> # If we instead fix on `a.m`, `m` in the model will be considered an observation. - cm = fix(contextualize(m, PrefixContext{:a}(fix(@varname(a.m) => 1.0,))), x=100.0); +julia> keys(VarInfo(cm)) +1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: + a.x -julia> fixed(cm)[@varname(x)] -100.0 +julia> # We can also condition on `a.m` _outside_ of the PrefixContext: + cm = fix(contextualize(m, PrefixContext(@varname(a))), (@varname(a.m) => 1.0)); -julia> fixed(cm)[@varname(a.m)] -1.0 +julia> fixed(cm) +Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry: + a.m => 1.0 -julia> keys(VarInfo(cm)) # <= no variables are sampled -VarName[] +julia> # Now `a.x` will be sampled. 
+ keys(VarInfo(cm)) +1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: + a.x ``` """ fixed(model::Model) = fixed(model.context) diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index f6b9c4479..5f1ec95ec 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -223,12 +223,12 @@ end prefix_submodel_context(prefix, left, ctx) = prefix_submodel_context(prefix, ctx) function prefix_submodel_context(prefix, ctx) # E.g. `prefix="asd[$i]"` or `prefix=asd` with `asd` to be evaluated. - return :($(PrefixContext){$(Symbol)($(esc(prefix)))}($ctx)) + return :($(PrefixContext)($(Val)($(Symbol)($(esc(prefix)))), $ctx)) end function prefix_submodel_context(prefix::Union{AbstractString,Symbol}, ctx) # E.g. `prefix="asd"`. - return :($(PrefixContext){$(esc(Meta.quot(Symbol(prefix))))}($ctx)) + return :($(PrefixContext)($(esc(Meta.quot(Val(Symbol(prefix))))), $ctx)) end function prefix_submodel_context(prefix::Bool, ctx) diff --git a/src/utils.jl b/src/utils.jl index 56c3d70af..71919480c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1286,7 +1286,10 @@ broadcast_safe(x::Distribution) = (x,) broadcast_safe(x::AbstractContext) = (x,) # Convert (x=1,) to Dict(@varname(x) => 1) -_nt_to_varname_dict(nt) = Dict(VarName{k}() => v for (k, v) in pairs(nt)) +function to_varname_dict(nt::NamedTuple) + return Dict{VarName,Any}(VarName{k}() => v for (k, v) in pairs(nt)) +end +to_varname_dict(d::AbstractDict) = d # Version of `merge` used by `conditioned` and `fixed` to handle # the scenario where we might try to merge a dict with an empty # tuple. @@ -1294,9 +1297,9 @@ _nt_to_varname_dict(nt) = Dict(VarName{k}() => v for (k, v) in pairs(nt)) _merge(left::NamedTuple, right::NamedTuple) = merge(left, right) _merge(left::AbstractDict, right::AbstractDict) = merge(left, right) _merge(left::AbstractDict, ::NamedTuple{()}) = left -_merge(left::AbstractDict, right::NamedTuple) = merge(left, _nt_to_varname_dict(right)) +_merge(left::AbstractDict, right::NamedTuple) = merge(left, to_varname_dict(right)) _merge(::NamedTuple{()}, right::AbstractDict) = right -_merge(left::NamedTuple, right::AbstractDict) = merge(_nt_to_varname_dict(left), right) +_merge(left::NamedTuple, right::AbstractDict) = merge(to_varname_dict(left), right) """ unique_syms(vns::T) where {T<:NTuple{N,VarName}} diff --git a/test/contexts.jl b/test/contexts.jl index 11e591f8f..1ba099a37 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -1,4 +1,5 @@ using Test, DynamicPPL, Accessors +using AbstractPPL: getoptic using DynamicPPL: leafcontext, setleafcontext, @@ -10,12 +11,18 @@ using DynamicPPL: IsParent, PointwiseLogdensityContext, contextual_isassumption, + FixedContext, ConditionContext, decondition_context, hasconditioned, getconditioned, + conditioned, + fixed, hasconditioned_nested, - getconditioned_nested + getconditioned_nested, + collapse_prefix_stack, + prefix_cond_and_fixed_variables, + getvalue using EnzymeCore @@ -50,14 +57,15 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() :testparent => DynamicPPL.TestUtils.TestParentContext(DefaultContext()), :sampling => SamplingContext(), :minibatch => MiniBatchContext(DefaultContext(), 0.0), - :prefix => PrefixContext{:x}(DefaultContext()), + :prefix => PrefixContext(@varname(x)), :pointwiselogdensity => PointwiseLogdensityContext(), :condition1 => ConditionContext((x=1.0,)), :condition2 => ConditionContext( (x=1.0,), DynamicPPL.TestUtils.TestParentContext(ConditionContext((y=2.0,))) ), :condition3 => ConditionContext( - (x=1.0,), 
PrefixContext{:a}(ConditionContext(Dict(@varname(a.y) => 2.0))) + (x=1.0,), + PrefixContext(@varname(a), ConditionContext(Dict(@varname(y) => 2.0))), ), :condition4 => ConditionContext((x=[1.0, missing],)), ) @@ -70,91 +78,52 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end end - @testset "contextual_isassumption" begin - @testset "$(name)" for (name, context) in contexts - # Any `context` should return `true` by default. - @test contextual_isassumption(context, VarName{gensym(:x)}()) - - if any(Base.Fix2(isa, ConditionContext), context) - # We have a `ConditionContext` among us. - # Let's first extract the conditioned variables. - conditioned_values = DynamicPPL.conditioned(context) + @testset "extracting conditioned values" begin + # This testset tests `contextual_isassumption`, `getconditioned_nested`, and + # `hasconditioned_nested`. - # The conditioned values might be a NamedTuple, or a Dict. - # We convert to a Dict for consistency - if conditioned_values isa NamedTuple - conditioned_values = Dict( - VarName{sym}() => val for (sym, val) in pairs(conditioned_values) - ) - end - - for (vn, val) in pairs(conditioned_values) - # We need to drop the prefix of `var` since in `contextual_isassumption` - # it will be threaded through the `PrefixContext` before it reaches - # `ConditionContext` with the conditioned variable. - vn_without_prefix = if getoptic(vn) isa PropertyLens - # Hacky: This assumes that there is exactly one level of prefixing - # that we need to undo. This is appropriate for the :condition3 - # test case above, but is not generally correct. - AbstractPPL.unprefix(vn, VarName{getsym(vn)}()) - else - vn - end - - @show DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) - # Let's check elementwise. - for vn_child in - DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) - if getoptic(vn_child)(val) === missing - @test contextual_isassumption(context, vn_child) - else - @test !contextual_isassumption(context, vn_child) - end - end - end - end - end - end - - @testset "getconditioned_nested & hasconditioned_nested" begin - @testset "$name" for (name, context) in contexts + @testset "$(name)" for (name, context) in contexts + # If the varname doesn't exist, it should always be an assumption. fake_vn = VarName{gensym(:x)}() + @test contextual_isassumption(context, fake_vn) @test !hasconditioned_nested(context, fake_vn) @test_throws ErrorException getconditioned_nested(context, fake_vn) if any(Base.Fix2(isa, ConditionContext), context) - # `ConditionContext` specific. - + # We have a `ConditionContext` among us. # Let's first extract the conditioned variables. conditioned_values = DynamicPPL.conditioned(context) + # The conditioned values might be a NamedTuple, or a Dict. # We convert to a Dict for consistency - if conditioned_values isa NamedTuple - conditioned_values = Dict( - VarName{sym}() => val for (sym, val) in pairs(conditioned_values) - ) - end + conditioned_values = DynamicPPL.to_varname_dict(conditioned_values) + + # Extract all conditioned variables. We also use varname_leaves + # here to split up arrays which could potentially have some, + # but not all, elements being `missing`. 
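+                # For example, for the pair `@varname(x) => [1.0, missing]` (as in the
+                # :condition4 context above), `varname_leaves` yields the leaf varnames
+                # `x[1]` and `x[2]`, so each element can be checked individually.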
+ conditioned_vns = mapreduce( + p -> DynamicPPL.TestUtils.varname_leaves(p.first, p.second), + vcat, + pairs(conditioned_values), + ) - for (vn, val) in pairs(conditioned_values) - # We need to drop the prefix of `var` since in `contextual_isassumption` - # it will be threaded through the `PrefixContext` before it reaches - # `ConditionContext` with the conditioned variable. - vn_without_prefix = if getoptic(vn) isa PropertyLens - # Hacky: This assumes that there is exactly one level of prefixing - # that we need to undo. This is appropriate for the :condition3 - # test case above, but is not generally correct. - AbstractPPL.unprefix(vn, VarName{getsym(vn)}()) + # We can now loop over them to check which ones are missing. We use + # `getvalue` to handle the awkward case where sometimes + # `conditioned_values` contains the full Varname (e.g. `a.x`) and + # sometimes only the main symbol (e.g. it contains `x` when + # `vn` is `x[1]`) + for vn in conditioned_vns + val = DynamicPPL.getvalue(conditioned_values, vn) + # These VarNames are present in the conditioning values, so + # we should always be able to extract the value. + @test hasconditioned_nested(context, vn) + @test getconditioned_nested(context, vn) === val + # However, the return value of contextual_isassumption depends on + # whether the value is missing or not. + if ismissing(val) + @test contextual_isassumption(context, vn) else - vn - end - - for vn_child in - DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) - # `vn_child` should be in `context`. - @test hasconditioned_nested(context, vn_child) - # Value should be the same as extracted above. - @test getconditioned_nested(context, vn_child) === - getoptic(vn_child)(val) + @test !contextual_isassumption(context, vn) end end end @@ -163,39 +132,68 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @testset "PrefixContext" begin @testset "prefixing" begin - ctx = @inferred PrefixContext{:a}( - PrefixContext{:b}( - PrefixContext{:c}( - PrefixContext{:d}( - PrefixContext{:e}(PrefixContext{:f}(DefaultContext())) + ctx = @inferred PrefixContext( + @varname(a), + PrefixContext( + @varname(b), + PrefixContext( + @varname(c), + PrefixContext( + @varname(d), + PrefixContext( + @varname(e), PrefixContext(@varname(f), DefaultContext()) + ), ), ), ), ) - vn = VarName{:x}() + vn = @varname(x) vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) @test vn_prefixed == @varname(a.b.c.d.e.f.x) - vn = VarName{:x}(((1,),)) + vn = @varname(x[1]) vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) @test vn_prefixed == @varname(a.b.c.d.e.f.x[1]) end @testset "nested within arbitrary context stacks" begin vn = @varname(x[1]) - ctx1 = PrefixContext{:a}(DefaultContext()) + ctx1 = PrefixContext(@varname(a)) @test DynamicPPL.prefix(ctx1, vn) == @varname(a.x[1]) ctx2 = SamplingContext(ctx1) @test DynamicPPL.prefix(ctx2, vn) == @varname(a.x[1]) - ctx3 = PrefixContext{:b}(ctx2) + ctx3 = PrefixContext(@varname(b), ctx2) @test DynamicPPL.prefix(ctx3, vn) == @varname(b.a.x[1]) ctx4 = DynamicPPL.ValuesAsInModelContext(OrderedDict(), false, ctx3) @test DynamicPPL.prefix(ctx4, vn) == @varname(b.a.x[1]) end + @testset "prefix_and_strip_contexts" begin + vn = @varname(x[1]) + ctx1 = PrefixContext(@varname(a)) + new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx1, vn) + @test new_vn == @varname(a.x[1]) + @test new_ctx == DefaultContext() + + ctx2 = SamplingContext(PrefixContext(@varname(a))) + new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx2, vn) + @test new_vn == 
@varname(a.x[1]) + @test new_ctx == SamplingContext() + + ctx3 = PrefixContext(@varname(a), ConditionContext((a=1,))) + new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx3, vn) + @test new_vn == @varname(a.x[1]) + @test new_ctx == ConditionContext((a=1,)) + + ctx4 = SamplingContext(PrefixContext(@varname(a), ConditionContext((a=1,)))) + new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx4, vn) + @test new_vn == @varname(a.x[1]) + @test new_ctx == SamplingContext(ConditionContext((a=1,))) + end + @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - prefix = :my_prefix - context = DynamicPPL.PrefixContext{prefix}(SamplingContext()) + prefix_vn = @varname(my_prefix) + context = DynamicPPL.PrefixContext(prefix_vn, SamplingContext()) # Sample with the context. varinfo = DynamicPPL.VarInfo() DynamicPPL.evaluate!!(model, varinfo, context) @@ -204,7 +202,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() # Extract the ground truth varnames vns_expected = Set([ - AbstractPPL.prefix(vn, VarName{prefix}()) for + AbstractPPL.prefix(vn, prefix_vn) for vn in DynamicPPL.TestUtils.varnames(model) ]) @@ -343,4 +341,103 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m)) end end + + @testset "PrefixContext + Condition/FixedContext interactions" begin + @testset "prefix_cond_and_fixed_variables" begin + c1 = ConditionContext((c=1, d=2)) + c1_prefixed = prefix_cond_and_fixed_variables(c1, @varname(a)) + @test c1_prefixed isa ConditionContext + @test childcontext(c1_prefixed) isa DefaultContext + @test c1_prefixed.values[@varname(a.c)] == 1 + @test c1_prefixed.values[@varname(a.d)] == 2 + + c2 = FixedContext((f=1, g=2)) + c2_prefixed = prefix_cond_and_fixed_variables(c2, @varname(a)) + @test c2_prefixed isa FixedContext + @test childcontext(c2_prefixed) isa DefaultContext + @test c2_prefixed.values[@varname(a.f)] == 1 + @test c2_prefixed.values[@varname(a.g)] == 2 + + c3 = ConditionContext((c=1, d=2), FixedContext((f=1, g=2))) + c3_prefixed = prefix_cond_and_fixed_variables(c3, @varname(a)) + c3_prefixed_child = childcontext(c3_prefixed) + @test c3_prefixed isa ConditionContext + @test c3_prefixed.values[@varname(a.c)] == 1 + @test c3_prefixed.values[@varname(a.d)] == 2 + @test c3_prefixed_child isa FixedContext + @test c3_prefixed_child.values[@varname(a.f)] == 1 + @test c3_prefixed_child.values[@varname(a.g)] == 2 + @test childcontext(c3_prefixed_child) isa DefaultContext + end + + @testset "collapse_prefix_stack" begin + # Utility function to make sure that there are no PrefixContexts in + # the context stack. 
+ function has_no_prefixcontexts(ctx::AbstractContext) + return !(ctx isa PrefixContext) && ( + NodeTrait(ctx) isa IsLeaf || has_no_prefixcontexts(childcontext(ctx)) + ) + end + + # Prefix -> Condition + c1 = PrefixContext(@varname(a), ConditionContext((c=1, d=2))) + c1 = collapse_prefix_stack(c1) + @test has_no_prefixcontexts(c1) + c1_vals = conditioned(c1) + @test length(c1_vals) == 2 + @test getvalue(c1_vals, @varname(a.c)) == 1 + @test getvalue(c1_vals, @varname(a.d)) == 2 + + # Condition -> Prefix + c2 = ConditionContext((c=1, d=2), PrefixContext(@varname(a))) + c2 = collapse_prefix_stack(c2) + @test has_no_prefixcontexts(c2) + c2_vals = conditioned(c2) + @test length(c2_vals) == 2 + @test getvalue(c2_vals, @varname(c)) == 1 + @test getvalue(c2_vals, @varname(d)) == 2 + + # Prefix -> Fixed + c3 = PrefixContext(@varname(a), FixedContext((f=1, g=2))) + c3 = collapse_prefix_stack(c3) + c3_vals = fixed(c3) + @test length(c3_vals) == 2 + @test length(c3_vals) == 2 + @test getvalue(c3_vals, @varname(a.f)) == 1 + @test getvalue(c3_vals, @varname(a.g)) == 2 + + # Fixed -> Prefix + c4 = FixedContext((f=1, g=2), PrefixContext(@varname(a))) + c4 = collapse_prefix_stack(c4) + @test has_no_prefixcontexts(c4) + c4_vals = fixed(c4) + @test length(c4_vals) == 2 + @test getvalue(c4_vals, @varname(f)) == 1 + @test getvalue(c4_vals, @varname(g)) == 2 + + # Prefix -> Condition -> Prefix -> Condition + c5 = PrefixContext( + @varname(a), + ConditionContext( + (c=1,), PrefixContext(@varname(b), ConditionContext((d=2,))) + ), + ) + c5 = collapse_prefix_stack(c5) + @test has_no_prefixcontexts(c5) + c5_vals = conditioned(c5) + @test length(c5_vals) == 2 + @test getvalue(c5_vals, @varname(a.c)) == 1 + @test getvalue(c5_vals, @varname(a.b.d)) == 2 + + # Prefix -> Condition -> Prefix -> Fixed + c6 = PrefixContext( + @varname(a), + ConditionContext((c=1,), PrefixContext(@varname(b), FixedContext((d=2,)))), + ) + c6 = collapse_prefix_stack(c6) + @test has_no_prefixcontexts(c6) + @test conditioned(c6) == Dict(@varname(a.c) => 1) + @test fixed(c6) == Dict(@varname(a.b.d) => 2) + end + end end diff --git a/test/runtests.jl b/test/runtests.jl index 3473d5594..72f33f2d0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -67,6 +67,7 @@ include("test_util.jl") include("threadsafe.jl") include("debug_utils.jl") include("deprecated.jl") + include("submodels.jl") end if GROUP == "All" || GROUP == "Group2" diff --git a/test/submodels.jl b/test/submodels.jl new file mode 100644 index 000000000..e79eed2c3 --- /dev/null +++ b/test/submodels.jl @@ -0,0 +1,199 @@ +module DPPLSubmodelTests + +using DynamicPPL +using Distributions +using Test + +@testset "submodels.jl" begin + @testset "$op with AbstractPPL API" for op in [condition, fix] + x_val = 1.0 + x_logp = op == condition ? 
logpdf(Normal(), x_val) : 0.0 + + @testset "Auto prefix" begin + @model function inner() + x ~ Normal() + y ~ Normal() + return (x, y) + end + @model function outer() + return a ~ to_submodel(inner()) + end + inner_op = op(inner(), (@varname(x) => x_val)) + @model function outer2() + return a ~ to_submodel(inner_op) + end + with_inner_op = outer2() + with_outer_op = op(outer(), (@varname(a.x) => x_val)) + + # No conditioning/fixing + @test Set(keys(VarInfo(outer()))) == Set([@varname(a.x), @varname(a.y)]) + + # With conditioning/fixing + models = [("inner", with_inner_op), ("outer", with_outer_op)] + @testset "$name" for (name, model) in models + # Test that the value was correctly set + @test model()[1] == x_val + # Test that the logp was correctly set + vi = VarInfo(model) + @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(a.y)]) + # Check the keys + @test Set(keys(VarInfo(model))) == Set([@varname(a.y)]) + end + end + + @testset "No prefix" begin + @model function inner() + x ~ Normal() + y ~ Normal() + return (x, y) + end + @model function outer() + return a ~ to_submodel(inner(), false) + end + @model function outer2() + return a ~ to_submodel(inner_op, false) + end + with_inner_op = outer2() + inner_op = op(inner(), (@varname(x) => x_val)) + with_outer_op = op(outer(), (@varname(x) => x_val)) + + # No conditioning/fixing + @test Set(keys(VarInfo(outer()))) == Set([@varname(x), @varname(y)]) + + # With conditioning/fixing + models = [("inner", with_inner_op), ("outer", with_outer_op)] + @testset "$name" for (name, model) in models + # Test that the value was correctly set + @test model()[1] == x_val + # Test that the logp was correctly set + vi = VarInfo(model) + @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(y)]) + # Check the keys + @test Set(keys(VarInfo(model))) == Set([@varname(y)]) + end + end + + @testset "Manual prefix" begin + @model function inner() + x ~ Normal() + y ~ Normal() + return (x, y) + end + @model function outer() + return a ~ to_submodel(prefix(inner(), :b), false) + end + inner_op = op(inner(), (@varname(x) => x_val)) + @model function outer2() + return a ~ to_submodel(prefix(inner_op, :b), false) + end + with_inner_op = outer2() + with_outer_op = op(outer(), (@varname(b.x) => x_val)) + + # No conditioning/fixing + @test Set(keys(VarInfo(outer()))) == Set([@varname(b.x), @varname(b.y)]) + + # With conditioning/fixing + models = [("inner", with_inner_op), ("outer", with_outer_op)] + @testset "$name" for (name, model) in models + # Test that the value was correctly set + @test model()[1] == x_val + # Test that the logp was correctly set + vi = VarInfo(model) + @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(b.y)]) + # Check the keys + @test Set(keys(VarInfo(model))) == Set([@varname(b.y)]) + end + end + + @testset "Complex prefixes" begin + mutable struct P + a::Float64 + b::Float64 + end + @model function f() + x = Vector{Float64}(undef, 1) + x[1] ~ Normal() + y ~ Normal() + return x[1] + end + @model function g() + p = P(1.0, 2.0) + p.a ~ to_submodel(f()) + p.b ~ Normal() + return (p.a, p.b) + end + expected_vns = Set([@varname(p.a.x[1]), @varname(p.a.y), @varname(p.b)]) + @test Set(keys(VarInfo(g()))) == expected_vns + + # Check that we can condition/fix on any of them from the outside + for vn in expected_vns + op_g = op(g(), (vn => 1.0)) + vi = VarInfo(op_g) + @test Set(keys(vi)) == symdiff(expected_vns, Set([vn])) + end + end + + @testset "Nested submodels" begin + @model function f() + x ~ Normal() + return y ~ Normal() + 
end + @model function g() + return _unused ~ to_submodel(prefix(f(), :b), false) + end + @model function h() + return a ~ to_submodel(g()) + end + + # No conditioning + vi = VarInfo(h()) + @test Set(keys(vi)) == Set([@varname(a.b.x), @varname(a.b.y)]) + @test getlogp(vi) == + logpdf(Normal(), vi[@varname(a.b.x)]) + + logpdf(Normal(), vi[@varname(a.b.y)]) + + # Conditioning/fixing at the top level + op_h = op(h(), (@varname(a.b.x) => x_val)) + + # Conditioning/fixing at the second level + op_g = op(g(), (@varname(b.x) => x_val)) + @model function h2() + return a ~ to_submodel(op_g) + end + + # Conditioning/fixing at the very bottom + op_f = op(f(), (@varname(x) => x_val)) + @model function g2() + return _unused ~ to_submodel(prefix(op_f, :b), false) + end + @model function h3() + return a ~ to_submodel(g2()) + end + + models = [("top", op_h), ("middle", h2()), ("bottom", h3())] + @testset "$name" for (name, model) in models + vi = VarInfo(model) + @test Set(keys(vi)) == Set([@varname(a.b.y)]) + @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(a.b.y)]) + end + end + end + + @testset "conditioning via model arguments" begin + @model function f(x) + x ~ Normal() + return y ~ Normal() + end + @model function g(inner_x) + return a ~ to_submodel(f(inner_x)) + end + + vi = VarInfo(g(1.0)) + @test Set(keys(vi)) == Set([@varname(a.y)]) + + vi = VarInfo(g(missing)) + @test Set(keys(vi)) == Set([@varname(a.x), @varname(a.y)]) + end +end + +end From 13da08a990943bf3b5ec3ec81394e55784644a26 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 24 Apr 2025 10:58:00 +0100 Subject: [PATCH 23/48] Revert ThreadSafeVarInfo back to Vectors and fix some AD type casting in (Simple)VarInfo --- src/accumulators.jl | 22 ++++++++++++++++++++++ src/simple_varinfo.jl | 6 +++++- src/threadsafe.jl | 35 +++++++++++++++++------------------ src/varinfo.jl | 12 +++++++++++- 4 files changed, 55 insertions(+), 20 deletions(-) diff --git a/src/accumulators.jl b/src/accumulators.jl index a1f019610..c096a25a7 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -140,6 +140,10 @@ Base.iterate(at::AccumulatorTuple, args...) = iterate(at.nt, args...) Base.haskey(at::AccumulatorTuple, ::Val{accname}) where {accname} = haskey(at.nt, accname) Base.keys(at::AccumulatorTuple) = keys(at.nt) +function Base.convert(::Type{AccumulatorTuple{N,T}}, accs::AccumulatorTuple{N}) where {N,T} + return AccumulatorTuple(convert(T, accs.nt)) +end + """ setacc!!(at::AccumulatorTuple, acc::AbstractAccumulator) @@ -301,3 +305,21 @@ end accumulate_assume!!(acc::NumProduce, val, logjac, vn, right) = acc accumulate_observe!!(acc::NumProduce, right, left, vn) = increment(acc) + +Base.convert(::Type{LogPrior{T}}, acc::LogPrior) where {T} = LogPrior(convert(T, acc.logp)) +function Base.convert(::Type{LogLikelihood{T}}, acc::LogLikelihood) where {T} + return LogLikelihood(convert(T, acc.logp)) +end +function Base.convert(::Type{NumProduce{T}}, acc::NumProduce) where {T} + return NumProduce(convert(T, acc.num)) +end + +convert_eltype(acc::LogPrior, ::Type{T}) where {T} = LogPrior(convert(T, acc.logp)) +function convert_eltype(acc::LogLikelihood, ::Type{T}) where {T} + return LogLikelihood(convert(T, acc.logp)) +end +# TODO(mhauru) +# We ignore the convert_eltype calls for NumProduce. This is because they are only used to +# deal with dual number types of AD backends, which shouldn't concern NumProduce. This is +# horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`. 
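+# For example, `convert_eltype(LogPrior(1.0), Float32)` yields `LogPrior{Float32}(1.0f0)`,
+# whereas for `NumProduce` the integer count is passed through unchanged: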
+convert_eltype(acc::NumProduce, ::Type) = NumProduce(acc.num) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index c0f85fee4..aba7e7d39 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -248,7 +248,11 @@ end function unflatten(svi::SimpleVarInfo, x::AbstractVector) vals = unflatten(svi.values, x) - return SimpleVarInfo(vals, svi.accs, svi.transformation) + # TODO(mhauru) See comment in unflatten in src/varinfo.jl for why this conversion is + # required but undesireable. + et = float_type_with_fallback(eltype(x)) + accs = map_accumulator!!(svi.accs, convert_eltype, et) + return SimpleVarInfo(vals, accs, svi.transformation) end function BangBang.empty!!(vi::SimpleVarInfo) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 311663574..9155f5da5 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -4,15 +4,14 @@ A `ThreadSafeVarInfo` object wraps an [`AbstractVarInfo`](@ref) object and an array of accumulators for thread-safe execution of a probabilistic model. """ -struct ThreadSafeVarInfo{V<:AbstractVarInfo,L<:NTuple{N,AccumulatorTuple} where {N}} <: - AbstractVarInfo +struct ThreadSafeVarInfo{V<:AbstractVarInfo,L<:AccumulatorTuple} <: AbstractVarInfo varinfo::V - accs_by_thread::L + accs_by_thread::Vector{L} end - function ThreadSafeVarInfo(vi::AbstractVarInfo) - split_accs = AccumulatorTuple(map(split, vi.accs.nt)) - accs_by_thread = ntuple(_ -> deepcopy(split_accs), Val(Threads.nthreads())) + accs_by_thread = [ + AccumulatorTuple(map(split, vi.accs.nt)) for _ in 1:Threads.nthreads() + ] return ThreadSafeVarInfo(vi, accs_by_thread) end ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi @@ -51,17 +50,16 @@ end # _not_ be thread-specific a specific method has to be written. function map_accumulator!!(vi::ThreadSafeVarInfo, accname::Val, func::Function, args...) tid = Threads.threadid() - new_accs = map_accumulator!!(vi.accs_by_thread[tid], accname, func, args...) - accs_by_thread = BangBang.setindex!!(vi.accs_by_thread, new_accs, tid) - return ThreadSafeVarInfo(vi.varinfo, accs_by_thread) + vi.accs_by_thread[tid] = map_accumulator!!( + vi.accs_by_thread[tid], accname, func, args... + ) + return vi end function map_accumulator!!(vi::ThreadSafeVarInfo, func::Function, args...) tid = Threads.threadid() - new_accs = map_accumulator!!(vi.accs_by_thread[tid], func, args...) - # We need to use setindex!! in case new_accs has a different type than the original. - accs_by_thread = BangBang.setindex!!(vi.accs_by_thread, new_accs, tid) - return ThreadSafeVarInfo(vi.varinfo, accs_by_thread) + vi.accs_by_thread[tid] = map_accumulator!!(vi.accs_by_thread[tid], func, args...) 
+ return vi end has_varnamedvector(vi::ThreadSafeVarInfo) = has_varnamedvector(vi.varinfo) @@ -187,12 +185,13 @@ end function resetlogp!!(vi::ThreadSafeVarInfo) vi = Accessors.@set vi.varinfo = resetlogp!!(vi.varinfo) - accs_by_thread = map(vi.accs_by_thread) do accs - new_accs = map_accumulator!!(accs, Val(:LogPrior), zero) - new_accs = map_accumulator!!(new_accs, Val(:LogLikelihood), zero) - new_accs + for i in eachindex(vi.accs_by_thread) + vi.accs_by_thread[i] = map_accumulator!!(vi.accs_by_thread[i], Val(:LogPrior), zero) + vi.accs_by_thread[i] = map_accumulator!!( + vi.accs_by_thread[i], Val(:LogLikelihood), zero + ) end - return ThreadSafeVarInfo(vi.varinfo, accs_by_thread) + return vi end values_as(vi::ThreadSafeVarInfo) = values_as(vi.varinfo) diff --git a/src/varinfo.jl b/src/varinfo.jl index 98882c423..574b64160 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -441,7 +441,17 @@ vector_length(md::Metadata) = sum(length, md.ranges) function unflatten(vi::VarInfo, x::AbstractVector) md = unflatten_metadata(vi.metadata, x) - return VarInfo(md, vi.accs) + # Use of float_type_with_fallback(eltype(x)) is necessary to deal with cases where x is + # a gradient type of some AD backend. + # TODO(mhauru) How could we do this more cleanly? The problem case is map_accumulator!! + # for ThreadSafeVarInfo. In that one, if the map produces e.g a ForwardDiff.Dual, but + # the accumulators in the VarInfo are plain floats, we error since we can't change the + # element type of ThreadSafeVarInfo.accs_by_thread. However, doing this conversion here + # messes with cases like using Float32 of logprobs and Float64 for x. Also, this is just + # plain ugly and hacky. + et = float_type_with_fallback(eltype(x)) + accs = map_accumulator!!(vi.accs, convert_eltype, et) + return VarInfo(md, accs) end # We would call this `unflatten` if not for `unflatten` having a method for NamedTuples in From 221e797c8377a9d73063100aba4b273b3a8f6e00 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 24 Apr 2025 12:07:40 +0100 Subject: [PATCH 24/48] Improve accumulator docs --- src/accumulators.jl | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/src/accumulators.jl b/src/accumulators.jl index c096a25a7..fb63ecf2f 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -22,8 +22,6 @@ See the documentation for each of these functions for more details. """ abstract type AbstractAccumulator end -# TODO(mhauru) Add to the above docstring stuff about resets. - """ accumulator_name(acc::AbstractAccumulator) @@ -88,17 +86,18 @@ See also: [`split`](@ref) """ function combine end +# TODO(mhauru) The existence of this function makes me sad. See comment in unflatten in +# src/varinfo.jl. """ - acc!!(acc::AbstractAccumulator, args...) + convert_eltype(acc::AbstractAccumulator, ::Type{T}) -Update `acc` with the values in `args`. Returns the updated `acc`. +Convert `acc` to use element type `T`. -What this means depends greatly on the type of `acc`. For example, for `LogPrior` `args` -would be just `logp`. The utility of this function is that one can call -`acc!!(varinfo::AbstractVarinfo, Val(accname), args...)`, and this call will be propagated -to a call on the particular accumulator. +What "element type" means depends on the type of `acc`. By default this function does +nothing. Accumulator types that need to hold differentiable values, such as dual numbers +used by various AD backends, should implement a method for this function. """ -function acc!! 
end +convert_eltype(acc::AbstractAccumulator, ::Type) = acc # END ABSTRACT ACCUMULATOR, BEGIN ACCUMULATOR TUPLE @@ -314,12 +313,12 @@ function Base.convert(::Type{NumProduce{T}}, acc::NumProduce) where {T} return NumProduce(convert(T, acc.num)) end +# TODO(mhauru) +# We ignore the convert_eltype calls for NumProduce, by letting them fallback on +# convert_eltype(::AbstractAccumulator, ::Type). This is because they are only used to +# deal with dual number types of AD backends, which shouldn't concern NumProduce. This is +# horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`. convert_eltype(acc::LogPrior, ::Type{T}) where {T} = LogPrior(convert(T, acc.logp)) function convert_eltype(acc::LogLikelihood, ::Type{T}) where {T} return LogLikelihood(convert(T, acc.logp)) end -# TODO(mhauru) -# We ignore the convert_eltype calls for NumProduce. This is because they are only used to -# deal with dual number types of AD backends, which shouldn't concern NumProduce. This is -# horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`. -convert_eltype(acc::NumProduce, ::Type) = NumProduce(acc.num) From 1dbcb2ccc051bfa9734e009e6c682c98c0622ae3 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 24 Apr 2025 12:08:09 +0100 Subject: [PATCH 25/48] Add test/accumulators.jl --- test/accumulators.jl | 161 +++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 2 files changed, 162 insertions(+) create mode 100644 test/accumulators.jl diff --git a/test/accumulators.jl b/test/accumulators.jl new file mode 100644 index 000000000..40c7619f9 --- /dev/null +++ b/test/accumulators.jl @@ -0,0 +1,161 @@ +module AccumulatoTests + +using Test +using Distributions +using DynamicPPL +using DynamicPPL: + AccumulatorTuple, + LogLikelihood, + LogPrior, + NumProduce, + accumulate_assume!!, + accumulate_observe!!, + combine, + convert_eltype, + getacc, + increment, + map_accumulator!!, + setacc!!, + split + +@testset "accumulators" begin + @testset "individual accumulator types" begin + @testset "constructors" begin + @test LogPrior(0.0) == + LogPrior() == + LogPrior{Float64}() == + LogPrior{Float64}(0.0) == + zero(LogPrior(1.0)) + @test LogLikelihood(0.0) == + LogLikelihood() == + LogLikelihood{Float64}() == + LogLikelihood{Float64}(0.0) == + zero(LogLikelihood(1.0)) + @test NumProduce(0) == + NumProduce() == + NumProduce{Int}() == + NumProduce{Int}(0) == + zero(NumProduce(1)) + end + + @testset "addition and incrementation" begin + @test LogPrior(1.0f0) + LogPrior(1.0f0) == LogPrior(2.0f0) + @test LogPrior(1.0) + LogPrior(1.0f0) == LogPrior(2.0) + @test LogLikelihood(1.0f0) + LogLikelihood(1.0f0) == LogLikelihood(2.0f0) + @test LogLikelihood(1.0) + LogLikelihood(1.0f0) == LogLikelihood(2.0) + @test increment(NumProduce()) == NumProduce(1) + @test increment(NumProduce{UInt8}()) == NumProduce{UInt8}(1) + end + + @testset "split and combine" begin + for acc in [ + LogPrior(1.0), + LogLikelihood(1.0), + NumProduce(1), + LogPrior(1.0f0), + LogLikelihood(1.0f0), + NumProduce(UInt8(1)), + ] + @test combine(acc, split(acc)) == acc + end + end + + @testset "conversions" begin + @test convert(LogPrior{Float32}, LogPrior(1.0)) == LogPrior{Float32}(1.0f0) + @test convert(LogLikelihood{Float32}, LogLikelihood(1.0)) == + LogLikelihood{Float32}(1.0f0) + @test convert(NumProduce{UInt8}, NumProduce(1)) == NumProduce{UInt8}(1) + + @test convert_eltype(LogPrior(1.0), Float32) == LogPrior{Float32}(1.0f0) + @test convert_eltype(LogLikelihood(1.0), Float32) == + 
LogLikelihood{Float32}(1.0f0) + end + + @testset "accumulate_assume" begin + val = 2.0 + logjac = pi + vn = @varname(x) + dist = Normal() + @test accumulate_assume!!(LogPrior(1.0), val, logjac, vn, dist) == + LogPrior(1.0 + logjac + logpdf(dist, val)) + @test accumulate_assume!!(LogLikelihood(1.0), val, logjac, vn, dist) == + LogLikelihood(1.0) + @test accumulate_assume!!(NumProduce(1), val, logjac, vn, dist) == NumProduce(1) + end + + @testset "accumulate_observe" begin + right = Normal() + left = 2.0 + vn = @varname(x) + @test accumulate_observe!!(LogPrior(1.0), right, left, vn) == LogPrior(1.0) + @test accumulate_observe!!(LogLikelihood(1.0), right, left, vn) == + LogLikelihood(1.0 + logpdf(right, left)) + @test accumulate_observe!!(NumProduce(1), right, left, vn) == NumProduce(2) + end + end + + @testset "accumulator tuples" begin + # Some accumulators we'll use for testing + lp_f64 = LogPrior(1.0) + lp_f32 = LogPrior(1.0f0) + ll_f64 = LogLikelihood(1.0) + ll_f32 = LogLikelihood(1.0f0) + np_i64 = NumProduce(1) + + @testset "constructors" begin + @test AccumulatorTuple(lp_f64, ll_f64) == AccumulatorTuple((lp_f64, ll_f64)) + # Names in NamedTuple arguments are ignored + @test AccumulatorTuple((; a=lp_f64)) == AccumulatorTuple(lp_f64) + + # Can't have two accumulators of the same type. + @test_throws "duplicate field name" AccumulatorTuple(lp_f64, lp_f64) + # Not even if their element types differ. + @test_throws "duplicate field name" AccumulatorTuple(lp_f64, lp_f32) + end + + @testset "basic operations" begin + at_all64 = AccumulatorTuple(lp_f64, ll_f64, np_i64) + + @test at_all64[:LogPrior] == lp_f64 + @test at_all64[:LogLikelihood] == ll_f64 + @test at_all64[:NumProduce] == np_i64 + + @test haskey(AccumulatorTuple(np_i64), Val(:NumProduce)) + @test ~haskey(AccumulatorTuple(np_i64), Val(:LogPrior)) + @test length(AccumulatorTuple(lp_f64, ll_f64, np_i64)) == 3 + @test keys(at_all64) == (:LogPrior, :LogLikelihood, :NumProduce) + @test collect(at_all64) == [lp_f64, ll_f64, np_i64] + + # Replace the existing LogPrior + @test setacc!!(at_all64, lp_f32)[:LogPrior] == lp_f32 + # Check that setacc!! didn't modify the original + @test at_all64 == AccumulatorTuple(lp_f64, ll_f64, np_i64) + # Add a new accumulator type. + @test setacc!!(AccumulatorTuple(lp_f64), ll_f64) == + AccumulatorTuple(lp_f64, ll_f64) + + @test getacc(at_all64, Val(:LogPrior)) == lp_f64 + end + + @testset "map_accumulator!!" begin + # map over all accumulators + accs = AccumulatorTuple(lp_f32, ll_f32) + @test map_accumulator!!(accs, zero) == + AccumulatorTuple(LogPrior(0.0f0), LogLikelihood(0.0f0)) + # Test that the original wasn't modified. + @test accs == AccumulatorTuple(lp_f32, ll_f32) + + # A map with extra arguments that changes the types of the accumulators. 
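+            # The trailing `Float64` argument is forwarded to each accumulator, i.e.
+            # every entry is replaced by `convert_eltype(acc, Float64)`.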
+ @test map_accumulator!!(accs, convert_eltype, Float64) == + AccumulatorTuple(LogPrior(1.0), LogLikelihood(1.0)) + + # only apply to a particular accumulator + @test map_accumulator!!(accs, Val(:LogLikelihood), zero) == + AccumulatorTuple(lp_f32, LogLikelihood(0.0f0)) + @test map_accumulator!!(accs, Val(:LogLikelihood), convert_eltype, Float64) == + AccumulatorTuple(lp_f32, LogLikelihood(1.0)) + end + end +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index 79284f2c4..4a9acf4e1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -49,6 +49,7 @@ include("test_util.jl") include("Aqua.jl") end include("utils.jl") + include("accumulators.jl") include("compiler.jl") include("varnamedvector.jl") include("varinfo.jl") From e1b70e062c96f2e8772112f92ba4b30c9e5fa98a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 24 Apr 2025 12:14:33 +0100 Subject: [PATCH 26/48] Docs fixes --- docs/src/api.md | 11 +++++------ src/contexts.jl | 4 ++-- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index c6e055cf9..b85b71064 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -328,9 +328,9 @@ The following functions were used for sequential Monte Carlo methods. ```@docs get_num_produce -set_num_produce! -increment_num_produce! -reset_num_produce! +set_num_produce!! +increment_num_produce!! +reset_num_produce!! setorder! set_retained_vns_del! ``` @@ -431,8 +431,6 @@ Contexts are subtypes of `AbstractPPL.AbstractContext`. ```@docs SamplingContext DefaultContext -LikelihoodContext -PriorContext PrefixContext ConditionContext ``` @@ -477,9 +475,10 @@ DynamicPPL.Experimental.is_suitable_varinfo ### [Model-Internal Functions](@id model_internal) ```@docs +tilde_assume!! tilde_assume ``` ```@docs -tilde_observe +tilde_observe!! ``` diff --git a/src/contexts.jl b/src/contexts.jl index 6c82e2e13..3926282e3 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -192,8 +192,8 @@ getsampler(::IsLeaf, ::AbstractContext) = error("No sampler found in context") """ struct DefaultContext <: AbstractContext end -The `DefaultContext` is used by default to compute the log joint probability of the data -and parameters when running the model. +The `DefaultContext` is used by default to accumulate values like the log joint probability +when running the model. """ struct DefaultContext <: AbstractContext end NodeTrait(::DefaultContext) = IsLeaf() From 3f195e5f8adc8e5a4fa54e459005c7db861f817c Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 24 Apr 2025 16:20:15 +0100 Subject: [PATCH 27/48] Various small fixes --- src/abstract_varinfo.jl | 4 +--- src/varinfo.jl | 14 +++++++------- test/contexts.jl | 7 ++----- test/pointwise_logdensities.jl | 2 +- 4 files changed, 11 insertions(+), 16 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 8fcdd70ce..e8a757e5e 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -105,15 +105,13 @@ end """ setaccs!!(vi::AbstractVarInfo, accs::AccumulatorTuple) - setaccs!!(vi::AbstractVarInfo, accs::AbstractAccumulator...) + setaccs!!(vi::AbstractVarInfo, accs::NTuple{N,AbstractAccumulator} where {N}) Update the `AccumulatorTuple` of `vi` to `accs`, mutating if it makes sense. `setaccs!!(vi:AbstractVarInfo, accs::AccumulatorTuple) should be implemented by each subtype of `AbstractVarInfo`. """ -function setaccs!! 
end - function setaccs!!(vi::AbstractVarInfo, accs::NTuple{N,AbstractAccumulator}) where {N} return setaccs!!(vi, AccumulatorTuple(accs)) end diff --git a/src/varinfo.jl b/src/varinfo.jl index 574b64160..b805b2bf4 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -292,7 +292,7 @@ function typed_varinfo(vi::UntypedVarInfo) ) end nt = NamedTuple{syms_tuple}(Tuple(new_metas)) - return VarInfo(nt, vi.accs) + return VarInfo(nt, deepcopy(vi.accs)) end function typed_varinfo(vi::NTVarInfo) # This function preserves the behaviour of typed_varinfo(vi) where vi is @@ -353,7 +353,7 @@ single `VarNamedVector` as its metadata field. """ function untyped_vector_varinfo(vi::UntypedVarInfo) md = metadata_to_varnamedvector(vi.metadata) - return VarInfo(md, vi.accs) + return VarInfo(md, deepcopy(vi.accs)) end function untyped_vector_varinfo( rng::Random.AbstractRNG, @@ -396,12 +396,12 @@ NamedTuple of `VarNamedVector`s as its metadata field. """ function typed_vector_varinfo(vi::NTVarInfo) md = map(metadata_to_varnamedvector, vi.metadata) - return VarInfo(md, vi.accs) + return VarInfo(md, deepcopy(vi.accs)) end function typed_vector_varinfo(vi::UntypedVectorVarInfo) new_metas = group_by_symbol(vi.metadata) nt = NamedTuple(new_metas) - return VarInfo(nt, vi.accs) + return VarInfo(nt, deepcopy(vi.accs)) end function typed_vector_varinfo( rng::Random.AbstractRNG, @@ -450,7 +450,7 @@ function unflatten(vi::VarInfo, x::AbstractVector) # messes with cases like using Float32 of logprobs and Float64 for x. Also, this is just # plain ugly and hacky. et = float_type_with_fallback(eltype(x)) - accs = map_accumulator!!(vi.accs, convert_eltype, et) + accs = map_accumulator!!(deepcopy(vi.accs), convert_eltype, et) return VarInfo(md, accs) end @@ -533,7 +533,7 @@ end function subset(varinfo::VarInfo, vns::AbstractVector{<:VarName}) metadata = subset(varinfo.metadata, vns) - return VarInfo(metadata, varinfo.accs) + return VarInfo(metadata, deepcopy(varinfo.accs)) end function subset(metadata::NamedTuple, vns::AbstractVector{<:VarName}) @@ -622,7 +622,7 @@ end function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo) metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata) - return VarInfo(metadata, varinfo_right.accs) + return VarInfo(metadata, deepcopy(varinfo_right.accs)) end function merge_metadata(vnv_left::VarNamedVector, vnv_right::VarNamedVector) diff --git a/test/contexts.jl b/test/contexts.jl index e01c34a41..5f22b75eb 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -46,9 +46,8 @@ Base.IteratorSize(::Type{<:AbstractContext}) = Base.SizeUnknown() Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @testset "contexts.jl" begin - child_contexts = Dict(:default => DefaultContext()) - - parent_contexts = Dict( + contexts = Dict( + :default => DefaultContext(), :testparent => DynamicPPL.TestUtils.TestParentContext(DefaultContext()), :sampling => SamplingContext(), :prefix => PrefixContext(@varname(x)), @@ -63,8 +62,6 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() :condition4 => ConditionContext((x=[1.0, missing],)), ) - contexts = merge(child_contexts, parent_contexts) - @testset "$(name)" for (name, context) in contexts @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS DynamicPPL.TestUtils.test_context(context, model) diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index 3406a58e1..cfb222b66 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -1,5 +1,5 @@ @testset 
"logdensities_likelihoods.jl" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS[1:1] + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS example_values = DynamicPPL.TestUtils.rand_prior_true(model) # Instantiate a `VarInfo` with the example values. From 68b974a5249a76dc67de8743ae216dc56a5de08a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 24 Apr 2025 16:21:01 +0100 Subject: [PATCH 28/48] Make DynamicTransformation not use accumulators other than LogPrior --- src/transforming.jl | 36 ++++++++++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/src/transforming.jl b/src/transforming.jl index 40b3df61c..300a2a09a 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -27,7 +27,9 @@ function tilde_assume( # Only transform if `!isinverse` since `vi[vn, right]` # already performs the inverse transformation if it's transformed. r_transformed = isinverse ? r : link_transform(right)(r) - vi = acclogprior!!(vi, lp) + if hasacc(vi, Val(:LogPrior)) + vi = acclogprior!!(vi, lp) + end return r, setindex!!(vi, r_transformed, vn) end @@ -36,14 +38,36 @@ function tilde_observe!!(::DynamicTransformationContext, right, left, vn, vi) end function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) - return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) + return _transform!!(t, DynamicTransformationContext{false}(), vi, model) end function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model) - return settrans!!( - last(evaluate!!(model, vi, DynamicTransformationContext{true}())), - NoTransformation(), - ) + return _transform!!(NoTransformation(), DynamicTransformationContext{true}(), vi, model) +end + +function _transform( + t::AbstractTransformation, + ctx::DynamicTransformationContext, + vi::AbstractVarInfo, + model::Model, +) + # To transform using DynamicTransformationContext, we evaluate the model, but we do not + # need to use any accumulators other than LogPrior (which is affected by the Jacobian of + # the transformation). + accs = getaccs(vi.accs) + has_logprior = hasacc(accs, Val(:LogPrior)) + if has_logprior + old_logprior = getacc(accs, Val(:LogPrior)) + vi = setaccs!!(vi, (old_logprior,)) + end + vi = settrans!!(last(evaluate!!(model, vi, ctx)), t) + # Restore the accumulators. + if has_logprior + new_logprior = getacc(vi, Val(:LogPrior)) + accs = setacc!!(accs, new_logprior) + end + vi = setaccs!!(vi, accs) + return vi end function link(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) From 2b405d9ebb6c889ab7674c5e416e9fe6d25a86c4 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 24 Apr 2025 16:57:49 +0100 Subject: [PATCH 29/48] Fix variable order and name of map_accumulator!! --- src/abstract_varinfo.jl | 40 +++++++++++++++++++--------------------- src/accumulators.jl | 35 ++++++++++++----------------------- src/simple_varinfo.jl | 2 +- src/threadsafe.jl | 20 +++++++++----------- src/transforming.jl | 6 +++--- src/varinfo.jl | 6 +++--- test/accumulators.jl | 22 +++++++++++----------- 7 files changed, 58 insertions(+), 73 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index e8a757e5e..cd59475ed 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -239,7 +239,7 @@ end Update all the accumulators of `vi` by calling `accumulate_assume!!` on them. 
""" function accumulate_assume!!(vi::AbstractVarInfo, val, logjac, vn, right) - return map_accumulator!!(vi, accumulate_assume!!, val, logjac, vn, right) + return map_accumulators!!(acc -> accumulate_assume!!(acc, val, logjac, vn, right), vi) end """ @@ -248,37 +248,35 @@ end Update all the accumulators of `vi` by calling `accumulate_observe!!` on them. """ function accumulate_observe!!(vi::AbstractVarInfo, right, left, vn) - return map_accumulator!!(vi, accumulate_observe!!, right, left, vn) + return map_accumulators!!(acc -> accumulate_observe!!(acc, right, left, vn), vi) end """ - map_accumulator!!(vi::AbstractVarInfo, func::Function, args...) where {accname} + map_accumulators(vi::AbstractVarInfo, func::Function) -Update all accumulators of `vi` by calling `func(acc, args...)` on them and replacing -them with the return values. +Update all accumulators of `vi` by calling `func` on them and replacing them with the return +values. """ -function map_accumulator!!(vi::AbstractVarInfo, func::Function, args...) - return setaccs!!(vi, map_accumulator!!(getaccs(vi), func, args...)) +function map_accumulators!!(func::Function, vi::AbstractVarInfo) + return setaccs!!(vi, map(func, getaccs(vi))) end """ - map_accumulator!!(vi::AbstractVarInfo, ::Val{accname}, func::Function, args...) where {accname} + map_accumulator!!(func::Function, vi::AbstractVarInfo, ::Val{accname}) where {accname} -Update the accumulator `accname` of `vi` by calling `func(acc, args...)` on and replacing -it with the return value. +Update the accumulator `accname` of `vi` by calling `func` on it and replacing it with the +return value. """ -function map_accumulator!!(vi::AbstractVarInfo, accname::Val, func::Function, args...) - return setaccs!!(vi, map_accumulator!!(getaccs(vi), accname, func, args...)) +function map_accumulator!!(func::Function, vi::AbstractVarInfo, accname::Val) + return setaccs!!(vi, map_accumulator(func, getaccs(vi), accname)) end -function map_accumulator!!(vi::AbstractVarInfo, accname::Symbol, func::Function, args...) +function map_accumulator!!(func::Function, vi::AbstractVarInfo, accname::Symbol) return error( """ - The method - map_accumulator!!(vi::AbstractVarInfo, accname::Symbol, func::Function, args...) + The method map_accumulator!!(func::Function, vi::AbstractVarInfo, accname::Symbol) does not exist. For type stability reasons use - map_accumulator!!(vi::AbstractVarInfo, accname::Val, func::Function, args...) - instead. + map_accumulator!!(func::Function, vi::AbstractVarInfo, ::Val{accname}) instead. """ ) end @@ -291,7 +289,7 @@ Add `logp` to the value of the log of the prior probability in `vi`. See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getlogprior`](@ref), [`setlogprior!!`](@ref). """ function acclogprior!!(vi::AbstractVarInfo, logp) - return map_accumulator!!(vi, Val(:LogPrior), +, LogPrior(logp)) + return map_accumulator!!(acc -> acc + LogPrior(logp), vi, Val(:LogPrior)) end """ @@ -302,7 +300,7 @@ Add `logp` to the value of the log of the likelihood in `vi`. See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getloglikelihood`](@ref), [`setloglikelihood!!`](@ref). 
""" function accloglikelihood!!(vi::AbstractVarInfo, logp) - return map_accumulator!!(vi, Val(:LogLikelihood), +, LogLikelihood(logp)) + return map_accumulator!!(acc -> acc + LogLikelihood(logp), vi, Val(:LogLikelihood)) end """ @@ -326,10 +324,10 @@ Reset the values of the log probabilities (prior and likelihood) in `vi` """ function resetlogp!!(vi::AbstractVarInfo) if hasacc(vi, Val(:LogPrior)) - vi = map_accumulator!!(vi, Val(:LogPrior), zero) + vi = map_accumulator!!(zero, vi, Val(:LogPrior)) end if hasacc(vi, Val(:LogLikelihood)) - vi = map_accumulator!!(vi, Val(:LogLikelihood), zero) + vi = map_accumulator!!(zero, vi, Val(:LogLikelihood)) end return vi end diff --git a/src/accumulators.jl b/src/accumulators.jl index fb63ecf2f..ec91e20c2 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -89,7 +89,7 @@ function combine end # TODO(mhauru) The existence of this function makes me sad. See comment in unflatten in # src/varinfo.jl. """ - convert_eltype(acc::AbstractAccumulator, ::Type{T}) + convert_eltype(::Type{T}, acc::AbstractAccumulator) Convert `acc` to use element type `T`. @@ -97,7 +97,7 @@ What "element type" means depends on the type of `acc`. By default this function nothing. Accumulator types that need to hold differentiable values, such as dual numbers used by various AD backends, should implement a method for this function. """ -convert_eltype(acc::AbstractAccumulator, ::Type) = acc +convert_eltype(::Type, acc::AbstractAccumulator) = acc # END ABSTRACT ACCUMULATOR, BEGIN ACCUMULATOR TUPLE @@ -167,36 +167,25 @@ function getacc(at::AccumulatorTuple, ::Val{accname}) where {accname} return at[accname] end -""" - map_accumulator!!(at::AccumulatorTuple, func::Function, args...) - -Update the accumulators in `at` by calling `func(acc, args...)` on them and replacing them -with the return values. - -Returns a new `AccumulatorTuple`. The `!!` in the name is for consistency with the -corresponding function for `AbstractVarInfo`. -""" -function map_accumulator!!(at::AccumulatorTuple, func::Function, args...) - return AccumulatorTuple(map(acc -> func(acc, args...), at.nt)) +function Base.map(func::Function, at::AccumulatorTuple) + return AccumulatorTuple(map(func, at.nt)) end """ - map_accumulator!!(at::AccumulatorTuple, ::Val{accname}, func::Function, args...) + map_accumulator(func::Function, at::AccumulatorTuple, ::Val{accname}) -Update the accumulator with name `accname` in `at` by calling `func(acc, args...)` on it -and replacing it with the return value. +Update the accumulator with name `accname` in `at` by calling `func` on it. -Returns a new `AccumulatorTuple`. The `!!` in the name is for consistency with the -corresponding function for `AbstractVarInfo`. +Returns a new `AccumulatorTuple`. """ -function map_accumulator!!( - at::AccumulatorTuple, ::Val{accname}, func::Function, args... +function map_accumulator( + func::Function, at::AccumulatorTuple, ::Val{accname} ) where {accname} # Would like to write this as # return Accessors.@set at.nt[accname] = func(at[accname], args...) # for readability, but that one isn't type stable due to # https://github.com/JuliaObjects/Accessors.jl/issues/198 - new_val = func(at[accname], args...) + new_val = func(at[accname]) new_nt = merge(at.nt, NamedTuple{(accname,)}((new_val,))) return AccumulatorTuple(new_nt) end @@ -318,7 +307,7 @@ end # convert_eltype(::AbstractAccumulator, ::Type). This is because they are only used to # deal with dual number types of AD backends, which shouldn't concern NumProduce. 
This is # horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`. -convert_eltype(acc::LogPrior, ::Type{T}) where {T} = LogPrior(convert(T, acc.logp)) -function convert_eltype(acc::LogLikelihood, ::Type{T}) where {T} +convert_eltype(::Type{T}, acc::LogPrior) where {T} = LogPrior(convert(T, acc.logp)) +function convert_eltype(::Type{T}, acc::LogLikelihood) where {T} return LogLikelihood(convert(T, acc.logp)) end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index aba7e7d39..2a2e9eaf4 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -251,7 +251,7 @@ function unflatten(svi::SimpleVarInfo, x::AbstractVector) # TODO(mhauru) See comment in unflatten in src/varinfo.jl for why this conversion is # required but undesireable. et = float_type_with_fallback(eltype(x)) - accs = map_accumulator!!(svi.accs, convert_eltype, et) + accs = map(acc -> convert_eltype(et, acc), svi.accs) return SimpleVarInfo(vals, accs, svi.transformation) end diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 9155f5da5..56836b366 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -46,19 +46,17 @@ function getaccs(vi::ThreadSafeVarInfo) return AccumulatorTuple(map(anv -> getacc(vi, anv), accname_vals)) end -# Calls to map_accumulator!! are thread-specific by default. For any use of them that should -# _not_ be thread-specific a specific method has to be written. -function map_accumulator!!(vi::ThreadSafeVarInfo, accname::Val, func::Function, args...) +# Calls to map_accumulator(s)!! are thread-specific by default. For any use of them that +# should _not_ be thread-specific a specific method has to be written. +function map_accumulator!!(func::Function, vi::ThreadSafeVarInfo, accname::Val) tid = Threads.threadid() - vi.accs_by_thread[tid] = map_accumulator!!( - vi.accs_by_thread[tid], accname, func, args... - ) + vi.accs_by_thread[tid] = map_accumulator(func, vi.accs_by_thread[tid], accname) return vi end -function map_accumulator!!(vi::ThreadSafeVarInfo, func::Function, args...) +function map_accumulators!!(func::Function, vi::ThreadSafeVarInfo) tid = Threads.threadid() - vi.accs_by_thread[tid] = map_accumulator!!(vi.accs_by_thread[tid], func, args...) + vi.accs_by_thread[tid] = map(func, vi.accs_by_thread[tid]) return vi end @@ -186,9 +184,9 @@ end function resetlogp!!(vi::ThreadSafeVarInfo) vi = Accessors.@set vi.varinfo = resetlogp!!(vi.varinfo) for i in eachindex(vi.accs_by_thread) - vi.accs_by_thread[i] = map_accumulator!!(vi.accs_by_thread[i], Val(:LogPrior), zero) - vi.accs_by_thread[i] = map_accumulator!!( - vi.accs_by_thread[i], Val(:LogLikelihood), zero + vi.accs_by_thread[i] = map_accumulator(zero, vi.accs_by_thread[i], Val(:LogPrior)) + vi.accs_by_thread[i] = map_accumulator( + zero, vi.accs_by_thread[i], Val(:LogLikelihood) ) end return vi diff --git a/src/transforming.jl b/src/transforming.jl index 300a2a09a..773d8fb7e 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -45,7 +45,7 @@ function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model) return _transform!!(NoTransformation(), DynamicTransformationContext{true}(), vi, model) end -function _transform( +function _transform!!( t::AbstractTransformation, ctx::DynamicTransformationContext, vi::AbstractVarInfo, @@ -54,8 +54,8 @@ function _transform( # To transform using DynamicTransformationContext, we evaluate the model, but we do not # need to use any accumulators other than LogPrior (which is affected by the Jacobian of # the transformation). 
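+    # The remaining accumulators are stashed and restored unchanged around the
+    # evaluation below: the transformation should not alter, say, the accumulated
+    # log-likelihood, and re-accumulating it here would be wasted work.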
- accs = getaccs(vi.accs) - has_logprior = hasacc(accs, Val(:LogPrior)) + accs = getaccs(vi) + has_logprior = haskey(accs, Val(:LogPrior)) if has_logprior old_logprior = getacc(accs, Val(:LogPrior)) vi = setaccs!!(vi, (old_logprior,)) diff --git a/src/varinfo.jl b/src/varinfo.jl index b805b2bf4..34f622c61 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -450,7 +450,7 @@ function unflatten(vi::VarInfo, x::AbstractVector) # messes with cases like using Float32 of logprobs and Float64 for x. Also, this is just # plain ugly and hacky. et = float_type_with_fallback(eltype(x)) - accs = map_accumulator!!(deepcopy(vi.accs), convert_eltype, et) + accs = map(acc -> convert_eltype(et, acc), deepcopy(getaccs(vi))) return VarInfo(md, accs) end @@ -1032,7 +1032,7 @@ set_num_produce!!(vi::VarInfo, n::Int) = setacc!!(vi, NumProduce(n)) Add 1 to `num_produce` in `vi`. """ -increment_num_produce!!(vi::VarInfo) = map_accumulator!!(vi, Val(:NumProduce), increment) +increment_num_produce!!(vi::VarInfo) = map_accumulator!!(increment, vi, Val(:NumProduce)) """ reset_num_produce!!(vi::VarInfo) @@ -1040,7 +1040,7 @@ increment_num_produce!!(vi::VarInfo) = map_accumulator!!(vi, Val(:NumProduce), i Reset the value of `num_produce` the log of the joint probability of the observed data and parameters sampled in `vi` to 0. """ -reset_num_produce!!(vi::VarInfo) = map_accumulator!!(vi, Val(:NumProduce), zero) +reset_num_produce!!(vi::VarInfo) = map_accumulator!!(zero, vi, Val(:NumProduce)) # Need to introduce the _isempty to avoid type piracy of isempty(::NamedTuple). isempty(vi::VarInfo) = _isempty(vi.metadata) diff --git a/test/accumulators.jl b/test/accumulators.jl index 40c7619f9..a4a31f805 100644 --- a/test/accumulators.jl +++ b/test/accumulators.jl @@ -14,7 +14,7 @@ using DynamicPPL: convert_eltype, getacc, increment, - map_accumulator!!, + map_accumulator, setacc!!, split @@ -66,8 +66,8 @@ using DynamicPPL: LogLikelihood{Float32}(1.0f0) @test convert(NumProduce{UInt8}, NumProduce(1)) == NumProduce{UInt8}(1) - @test convert_eltype(LogPrior(1.0), Float32) == LogPrior{Float32}(1.0f0) - @test convert_eltype(LogLikelihood(1.0), Float32) == + @test convert_eltype(Float32, LogPrior(1.0)) == LogPrior{Float32}(1.0f0) + @test convert_eltype(Float32, LogLikelihood(1.0)) == LogLikelihood{Float32}(1.0f0) end @@ -137,23 +137,23 @@ using DynamicPPL: @test getacc(at_all64, Val(:LogPrior)) == lp_f64 end - @testset "map_accumulator!!" begin + @testset "map_accumulator(s)!!" begin # map over all accumulators accs = AccumulatorTuple(lp_f32, ll_f32) - @test map_accumulator!!(accs, zero) == - AccumulatorTuple(LogPrior(0.0f0), LogLikelihood(0.0f0)) + @test map(zero, accs) == AccumulatorTuple(LogPrior(0.0f0), LogLikelihood(0.0f0)) # Test that the original wasn't modified. @test accs == AccumulatorTuple(lp_f32, ll_f32) - # A map with extra arguments that changes the types of the accumulators. - @test map_accumulator!!(accs, convert_eltype, Float64) == + # A map with a closure that changes the types of the accumulators. 
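+            # The target element type is captured in the closure, so the `Float32`
+            # accumulators come back as `Float64` ones and compare equal below.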
+ @test map(acc -> convert_eltype(Float64, acc), accs) == AccumulatorTuple(LogPrior(1.0), LogLikelihood(1.0)) # only apply to a particular accumulator - @test map_accumulator!!(accs, Val(:LogLikelihood), zero) == + @test map_accumulator(zero, accs, Val(:LogLikelihood)) == AccumulatorTuple(lp_f32, LogLikelihood(0.0f0)) - @test map_accumulator!!(accs, Val(:LogLikelihood), convert_eltype, Float64) == - AccumulatorTuple(lp_f32, LogLikelihood(1.0)) + @test map_accumulator( + acc -> convert_eltype(Float64, acc), accs, Val(:LogLikelihood) + ) == AccumulatorTuple(lp_f32, LogLikelihood(1.0)) end end end From 00cd304352804b2279a64fea8c98cab9f41760a3 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 24 Apr 2025 17:03:13 +0100 Subject: [PATCH 30/48] Typo fixing --- src/threadsafe.jl | 4 +--- test/threadsafe.jl | 6 +++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 56836b366..cd8eda97c 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -9,9 +9,7 @@ struct ThreadSafeVarInfo{V<:AbstractVarInfo,L<:AccumulatorTuple} <: AbstractVarI accs_by_thread::Vector{L} end function ThreadSafeVarInfo(vi::AbstractVarInfo) - accs_by_thread = [ - AccumulatorTuple(map(split, vi.accs.nt)) for _ in 1:Threads.nthreads() - ] + accs_by_thread = [map(split, getaccs(vi)) for _ in 1:Threads.nthreads()] return ThreadSafeVarInfo(vi, accs_by_thread) end ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 2fa84bad8..d88111e46 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -7,7 +7,7 @@ @test threadsafe_vi.accs_by_thread isa Vector{<:DynamicPPL.AccumulatorTuple} @test length(threadsafe_vi.accs_by_thread) == Threads.nthreads() expected_accs = DynamicPPL.AccumulatorTuple( - (DynamicPPL.split(acc) for acc in vi.accs)... + (DynamicPPL.split(acc) for acc in getaccs(vi))... ) @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) end @@ -28,14 +28,14 @@ threadsafe_vi = resetlogp!!(threadsafe_vi) @test iszero(getlogjoint(threadsafe_vi)) expected_accs = DynamicPPL.AccumulatorTuple( - (DynamicPPL.split(acc) for acc in threadsafe_vi.varinfo.accs)... + (DynamicPPL.split(acc) for acc in getaccs(threadsafe_vi.varinfo))... ) @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) threadsafe_vi = setlogprior!!(threadsafe_vi, 42) @test getlogjoint(threadsafe_vi) == 42 expected_accs = DynamicPPL.AccumulatorTuple( - (DynamicPPL.split(acc) for acc in threadsafe_vi.varinfo.accs)... + (DynamicPPL.split(acc) for acc in getaccs(threadsafe_vi.varinfo))... 
) @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) end From 6d1048df42f15e6d6bd20b2d1c8572b0850e6da9 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 24 Apr 2025 17:13:28 +0100 Subject: [PATCH 31/48] Small improvement to ThreadSafeVarInfo --- src/threadsafe.jl | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index cd8eda97c..7d2d768a6 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -182,10 +182,16 @@ end function resetlogp!!(vi::ThreadSafeVarInfo) vi = Accessors.@set vi.varinfo = resetlogp!!(vi.varinfo) for i in eachindex(vi.accs_by_thread) - vi.accs_by_thread[i] = map_accumulator(zero, vi.accs_by_thread[i], Val(:LogPrior)) - vi.accs_by_thread[i] = map_accumulator( - zero, vi.accs_by_thread[i], Val(:LogLikelihood) - ) + if hasacc(vi, Val(:LogPrior)) + vi.accs_by_thread[i] = map_accumulator( + zero, vi.accs_by_thread[i], Val(:LogPrior) + ) + end + if hasacc(vi, Val(:LogLikelihood)) + vi.accs_by_thread[i] = map_accumulator( + zero, vi.accs_by_thread[i], Val(:LogLikelihood) + ) + end end return vi end From 4fef20f2463d11775a85226c26148883e8b8a89a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 24 Apr 2025 17:24:09 +0100 Subject: [PATCH 32/48] Fix demo_dot_assume_observe_submodel prefixing --- src/test_utils/models.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test_utils/models.jl b/src/test_utils/models.jl index 90b6ac7ac..12f88acad 100644 --- a/src/test_utils/models.jl +++ b/src/test_utils/models.jl @@ -476,7 +476,7 @@ end # Submodel likelihood # With to_submodel, we have to have a left-hand side variable to # capture the result, so we just use a dummy variable - _ignore ~ to_submodel(_likelihood_multivariate_observe(s, m, x)) + _ignore ~ to_submodel(_likelihood_multivariate_observe(s, m, x), false) return (; s=s, m=m, x=x) end From 557954a1142fb4c70a5379ee6886e161b1c8a8b6 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 24 Apr 2025 18:14:04 +0100 Subject: [PATCH 33/48] Typo fixing --- src/simple_varinfo.jl | 2 +- test/threadsafe.jl | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 39efbfffc..200c8bece 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -231,7 +231,7 @@ function SimpleVarInfo{T}( end # Constructor from `VarInfo`. -function SimpleVarInfo(vi::NTVarInfo, ::Type{D}) where {names,D} +function SimpleVarInfo(vi::NTVarInfo, ::Type{D}) where {D} values = values_as(vi, D) return SimpleVarInfo(values, vi.accs) end diff --git a/test/threadsafe.jl b/test/threadsafe.jl index d88111e46..5b4f6951f 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -7,7 +7,7 @@ @test threadsafe_vi.accs_by_thread isa Vector{<:DynamicPPL.AccumulatorTuple} @test length(threadsafe_vi.accs_by_thread) == Threads.nthreads() expected_accs = DynamicPPL.AccumulatorTuple( - (DynamicPPL.split(acc) for acc in getaccs(vi))... + (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(vi))... ) @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) end @@ -28,14 +28,14 @@ threadsafe_vi = resetlogp!!(threadsafe_vi) @test iszero(getlogjoint(threadsafe_vi)) expected_accs = DynamicPPL.AccumulatorTuple( - (DynamicPPL.split(acc) for acc in getaccs(threadsafe_vi.varinfo))... + (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(threadsafe_vi.varinfo))... 
) @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) threadsafe_vi = setlogprior!!(threadsafe_vi, 42) @test getlogjoint(threadsafe_vi) == 42 expected_accs = DynamicPPL.AccumulatorTuple( - (DynamicPPL.split(acc) for acc in getaccs(threadsafe_vi.varinfo))... + (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(threadsafe_vi.varinfo))... ) @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) end From 6f702c99f5ff57ce5b75639c7ebea4e91f9408bd Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 25 Apr 2025 15:20:12 +0100 Subject: [PATCH 34/48] Miscellaneous small fixes --- src/DynamicPPL.jl | 3 +++ src/model.jl | 11 +++++++++ src/pointwise_logdensities.jl | 43 ++++++++++++++++++++++++----------- src/simple_varinfo.jl | 22 ++++++++++++++---- src/utils.jl | 4 +++- src/varinfo.jl | 4 ++-- 6 files changed, 67 insertions(+), 20 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 801840e8f..e0604f458 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -152,6 +152,9 @@ macro prob_str(str) )) end +# TODO(mhauru) We should write down the list of methods that any subtype of AbstractVarInfo +# has to implement. Not sure what the full list is for parameters values, but for +# accumulators we only need `getaccs` and `setaccs!!`. """ AbstractVarInfo diff --git a/src/model.jl b/src/model.jl index b54adbde0..4754b5906 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1057,6 +1057,10 @@ Return the log prior probability of variables `varinfo` for the probabilistic `m See also [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logprior(model::Model, varinfo::AbstractVarInfo) + # Remove other accumulators from varinfo, since they are unnecessary. + logprior = + hasacc(varinfo, Val(:LogPrior)) ? getacc(varinfo, Val(:LogPrior)) : LogPrior() + varinfo = setaccs!!(deepcopy(varinfo), (logprior,)) return getlogprior(last(evaluate!!(model, varinfo, DefaultContext()))) end @@ -1104,6 +1108,13 @@ Return the log likelihood of variables `varinfo` for the probabilistic `model`. See also [`logjoint`](@ref) and [`logprior`](@ref). """ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) + # Remove other accumulators from varinfo, since they are unnecessary. + loglikelihood = if hasacc(varinfo, Val(:LogLikelihood)) + getacc(varinfo, Val(:LogLikelihood)) + else + LogLikelihood() + end + varinfo = setaccs!!(deepcopy(varinfo), (loglikelihood,)) return getloglikelihood(last(evaluate!!(model, varinfo, DefaultContext()))) end diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index e608e523d..d998858c1 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -66,7 +66,10 @@ function accumulate_assume!!( acc::PointwiseLogProbAccumulator{whichlogprob}, val, logjac, vn, right ) where {whichlogprob} if whichlogprob == :both || whichlogprob == :prior - subacc = accumulate_assume!!(LogPrior{LogProbType}(), val, logjac, vn, right) + # T is the element type of the vectors that are the values of `acc.logps`. Usually + # it's LogProbType. + T = eltype(last(fieldtypes(eltype(acc.logps)))) + subacc = accumulate_assume!!(LogPrior{T}(), val, logjac, vn, right) push!(acc, vn, subacc.logp) end return acc @@ -81,21 +84,34 @@ function accumulate_observe!!( return acc end if whichlogprob == :both || whichlogprob == :likelihood - subacc = accumulate_observe!!(LogLikelihood{LogProbType}(), right, left, vn) + # T is the element type of the vectors that are the values of `acc.logps`. 
Usually + # it's LogProbType. + T = eltype(last(fieldtypes(eltype(acc.logps)))) + subacc = accumulate_observe!!(LogLikelihood{T}(), right, left, vn) push!(acc, vn, subacc.logp) end return acc end """ - pointwise_logdensities(model::Model, chain::Chains, keytype = String) + pointwise_logdensities( + model::Model, + chain::Chains, + keytype=String, + context=DefaultContext(), + ::Val{whichlogprob}=Val(:both), + ) Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}` with keys corresponding to symbols of the variables, and values being matrices of shape `(num_chains, num_samples)`. `keytype` specifies what the type of the keys used in the returned `OrderedDict` are. -Currently, only `String` and `VarName` are supported. +Currently, only `String` and `VarName` are supported. `context` is the evaluation context, +and `whichlogprob` specifies which log-probabilities to compute. It can be `:both`, +`:prior`, or `:likelihood`. + +See also: [`pointwise_loglikelihoods`](@ref), [`pointwise_loglikelihoods`](@ref). # Notes Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ` @@ -204,8 +220,8 @@ function pointwise_logdensities( # Get the data by executing the model once vi = VarInfo(model) - acctype = PointwiseLogProbAccumulator{whichlogprob,KeyType} - vi = setaccs!!(vi, AccumulatorTuple(acctype())) + AccType = PointwiseLogProbAccumulator{whichlogprob,KeyType} + vi = setaccs!!(vi, (AccType(),)) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) for (sample_idx, chain_idx) in iters @@ -216,7 +232,7 @@ function pointwise_logdensities( vi = last(evaluate!!(model, vi, context)) end - logps = getacc(vi, Val(accumulator_name(acctype))).logps + logps = getacc(vi, Val(accumulator_name(AccType))).logps niters = size(chain, 1) nchains = size(chain, 3) logdensities = OrderedDict( @@ -231,11 +247,10 @@ function pointwise_logdensities( context::AbstractContext=DefaultContext(), ::Val{whichlogprob}=Val(:both), ) where {whichlogprob} - acctype = PointwiseLogProbAccumulator{whichlogprob} - # TODO(mhauru) Don't needlessly evaluate the model twice. - varinfo = setaccs!!(varinfo, AccumulatorTuple(acctype())) + AccType = PointwiseLogProbAccumulator{whichlogprob} + varinfo = setaccs!!(varinfo, (AccType(),)) varinfo = last(evaluate!!(model, varinfo, context)) - return getacc(varinfo, Val(accumulator_name(acctype))).logps + return getacc(varinfo, Val(accumulator_name(AccType))).logps end """ @@ -244,7 +259,8 @@ end Compute the pointwise log-likelihoods of the model given the chain. This is the same as `pointwise_logdensities(model, chain, context)`, but only including the likelihood terms. -See also: [`pointwise_logdensities`](@ref). + +See also: [`pointwise_logdensities`](@ref), [`pointwise_prior_logdensities`](@ref). """ function pointwise_loglikelihoods( model::Model, chain, keytype::Type{T}=String, context::AbstractContext=DefaultContext() @@ -264,7 +280,8 @@ end Compute the pointwise log-prior-densities of the model given the chain. This is the same as `pointwise_logdensities(model, chain, context)`, but only including the prior terms. -See also: [`pointwise_logdensities`](@ref). + +See also: [`pointwise_logdensities`](@ref), [`pointwise_loglikelihoods`](@ref). 
""" function pointwise_prior_logdensities( model::Model, chain, keytype::Type{T}=String, context::AbstractContext=DefaultContext() diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 200c8bece..8c377de3f 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -209,6 +209,15 @@ end function SimpleVarInfo(values) return SimpleVarInfo{LogProbType}(values) end +function SimpleVarInfo(values::Union{<:NamedTuple,<:AbstractDict}) + return if isempty(values) + # Can't infer from values, so we just use default. + SimpleVarInfo{LogProbType}(values) + else + # Infer from `values`. + SimpleVarInfo{float_type_with_fallback(infer_nested_eltype(typeof(values)))}(values) + end +end # Using `kwargs` to specify the values. function SimpleVarInfo{T}(; kwargs...) where {T<:Real} @@ -233,7 +242,12 @@ end # Constructor from `VarInfo`. function SimpleVarInfo(vi::NTVarInfo, ::Type{D}) where {D} values = values_as(vi, D) - return SimpleVarInfo(values, vi.accs) + return SimpleVarInfo(values, deepcopy(getaccs(vi))) +end +function SimpleVarInfo{T}(vi::NTVarInfo, ::Type{D}) where {T<:Real,D} + values = values_as(vi, D) + accs = map(acc -> convert_eltype(T, acc), getaccs(vi)) + return SimpleVarInfo(values, accs) end function untyped_simple_varinfo(model::Model) @@ -250,8 +264,8 @@ function unflatten(svi::SimpleVarInfo, x::AbstractVector) vals = unflatten(svi.values, x) # TODO(mhauru) See comment in unflatten in src/varinfo.jl for why this conversion is # required but undesireable. - et = float_type_with_fallback(eltype(x)) - accs = map(acc -> convert_eltype(et, acc), svi.accs) + T = float_type_with_fallback(eltype(x)) + accs = map(acc -> convert_eltype(T, acc), getaccs(svi)) return SimpleVarInfo(vals, accs, svi.transformation) end @@ -423,7 +437,7 @@ _subset(x::VarNamedVector, vns) = subset(x, vns) # `merge` function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo) values = merge(varinfo_left.values, varinfo_right.values) - accs = getaccs(varinfo_right) + accs = deepcopy(getaccs(varinfo_right)) transformation = merge_transformations( varinfo_left.transformation, varinfo_right.transformation ) diff --git a/src/utils.jl b/src/utils.jl index 1fb9faab7..2ed0c049e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -58,7 +58,9 @@ true """ macro addlogprob!(ex) return quote - $(esc(:(__varinfo__))) = accloglikelihood!!($(esc(:(__varinfo__))), $(esc(ex))) + if $hasacc($(esc(:(__varinfo__))), Val(:LogLikelihood)) + $(esc(:(__varinfo__))) = $accloglikelihood!!($(esc(:(__varinfo__))), $(esc(ex))) + end end end diff --git a/src/varinfo.jl b/src/varinfo.jl index 34f622c61..ad624a253 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -449,8 +449,8 @@ function unflatten(vi::VarInfo, x::AbstractVector) # element type of ThreadSafeVarInfo.accs_by_thread. However, doing this conversion here # messes with cases like using Float32 of logprobs and Float64 for x. Also, this is just # plain ugly and hacky. 
- et = float_type_with_fallback(eltype(x)) - accs = map(acc -> convert_eltype(et, acc), deepcopy(getaccs(vi))) + T = float_type_with_fallback(eltype(x)) + accs = map(acc -> convert_eltype(T, acc), deepcopy(getaccs(vi))) return VarInfo(md, accs) end From f748775e5c54ea296a8cea6a0c50f97becb804a7 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 25 Apr 2025 16:08:50 +0100 Subject: [PATCH 35/48] HISTORY entry and more miscellanea --- HISTORY.md | 15 +++++++++++++++ src/abstract_varinfo.jl | 9 ++++----- src/contexts.jl | 1 - 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 9a70e8d1f..2251b0ae3 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,20 @@ # DynamicPPL Changelog +## 0.37.0 + +**Breaking changes** + +### Accumulators + +This release overhauls how VarInfo objects track variables such as the log joint probability. The new approach is to use what we call accumulators: Objects that the VarInfo carries on it that may change their state at each `tilde_assume!!` and `tilde_observe!!` call based on the value of the variable in question. They replace both variables that were previously hard-coded in the `VarInfo` object (`logp` and `num_produce`) and some contexts. This brings with it a number of breaking changes: + + - `PriorContext` and `LikelihoodContext` no longer exist. By default, a `VarInfo` tracks both the log prior and the log likelihood separately, and they can be accessed with `getlogprior` and `getloglikelihood`. If you want to execute a model while only accumulating one of the two (to save clock cycles), you can do so by creating a `VarInfo` that only has one accumulator in it, e.g. `varinfo = setaccs!!(varinfo, (LogPrior(),))`. + - `MiniBatchContext` does not exist anymore. It can be replaced by creating and using a custom accumulator that replaces the default `LikelihoodContext`. We may introduce such an accumulator in DynamicPPL in the future, but for now you'll need to do it yourself. + - `tilde_observe` and `observe` have been removed. `tilde_observe!!` still exists, and any contexts should modify its behaviour. We may further rework the call stack under `tilde_observe!!` in the near future. + - `tilde_assume` no longer returns the log density of the current assumption as its second return value. We may further rework the `tilde_assume!!` call stack as well. + - `set/reset/increment_num_produce!` have become `set/reset/increment_num_produce!!` (note the second exclamation mark). They are no longer guaranteed to modify the `VarInfo` in place, and one should always use the return value. + - `@addlogprob!` now _always_ adds to the log likelihood. Previously it added to the log probability that the execution context specified, e.g. the log prior when using `PriorContext`. + ## 0.36.0 **Breaking changes** diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index cd59475ed..fdeb3281b 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -121,8 +121,7 @@ end Return the `AccumulatorTuple` of `vi`. -This should be implemented by each subtype of `AbstractVarInfo`. `getaccs` is not -user-facing, but used in the implementation of many other functions. +This should be implemented by each subtype of `AbstractVarInfo`. """ function getaccs end @@ -217,7 +216,7 @@ function setlogp!!(vi::AbstractVarInfo, logp) end """ - getacc(vi::AbstractVarInfo, accname) + getacc(vi::AbstractVarInfo, ::Val{accname}) Return the `AbstractAccumulator` of `vi` with name `accname`. 
""" @@ -252,7 +251,7 @@ function accumulate_observe!!(vi::AbstractVarInfo, right, left, vn) end """ - map_accumulators(vi::AbstractVarInfo, func::Function) + map_accumulators!!(func::Function, vi::AbstractVarInfo) Update all accumulators of `vi` by calling `func` on them and replacing them with the return values. @@ -320,7 +319,7 @@ end """ resetlogp!!(vi::AbstractVarInfo) -Reset the values of the log probabilities (prior and likelihood) in `vi` +Reset the values of the log probabilities (prior and likelihood) in `vi` to zero. """ function resetlogp!!(vi::AbstractVarInfo) if hasacc(vi, Val(:LogPrior)) diff --git a/src/contexts.jl b/src/contexts.jl index 3926282e3..addadfa1a 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -37,7 +37,6 @@ Return the descendant context of `context`. """ childcontext -# TODO(mhauru) Rework the below docstring to not use PriorContext. """ setchildcontext(parent::AbstractContext, child::AbstractContext) From 5f4a5329c016f46acb167bb7b2df75d171165689 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 25 Apr 2025 16:21:38 +0100 Subject: [PATCH 36/48] Add more tests for accumulators --- test/varinfo.jl | 50 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/test/varinfo.jl b/test/varinfo.jl index 260df2931..5d3533b23 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -152,6 +152,56 @@ end test_varinfo_logp!(SimpleVarInfo(DynamicPPL.VarNamedVector())) end + @testset "accumulators" begin + @model function demo() + a ~ Normal() + b ~ Normal() + c ~ Normal() + d ~ Normal() + return nothing + end + + values = (; a=1.0, b=2.0, c=3.0, d=4.0) + lp_a = logpdf(Normal(), values.a) + lp_b = logpdf(Normal(), values.b) + lp_c = logpdf(Normal(), values.c) + lp_d = logpdf(Normal(), values.d) + m = demo() | (; c=values.c, d=values.d) + + vi = DynamicPPL.reset_num_produce!!( + DynamicPPL.unflatten(VarInfo(m), collect(values)) + ) + + vi = last(DynamicPPL.evaluate!!(m, deepcopy(vi))) + @test getlogprior(vi) == lp_a + lp_b + @test getloglikelihood(vi) == lp_c + lp_d + @test get_num_produce(vi) == 2 + + vi = last( + DynamicPPL.evaluate!!(m, DynamicPPL.setaccs!!(deepcopy(vi), (LogPrior(),))) + ) + @test getlogprior(vi) == lp_a + lp_b + @test_throws "has no field LogLikelihood" getloglikelihood(vi) + @test_throws "has no field LogLikelihood" getlogjoint(vi) + @test_throws "has no field NumProduce" get_num_produce(vi) + + vi = last( + DynamicPPL.evaluate!!(m, DynamicPPL.setaccs!!(deepcopy(vi), (NumProduce(),))) + ) + @test_throws "has no field LogPrior" getlogprior(vi) + @test_throws "has no field LogPrior" getlogjoint(vi) + @test_throws "has no field LogLikelihood" getloglikelihood(vi) + @test get_num_produce(vi) == 2 + + # Test evaluating without any accumulators. + vi = last(DynamicPPL.evaluate!!(m, DynamicPPL.setaccs!!(deepcopy(vi), ()))) + @test_throws "has no field LogPrior" getlogprior(vi) + @test_throws "has no field LogPrior" getlogjoint(vi) + @test_throws "has no field LogLikelihood" getloglikelihood(vi) + @test_throws "has no field NumProduce" get_num_produce(vi) + @test_throws "has no field NumProduce" reset_num_produce!!(vi) + end + @testset "flags" begin # Test flag setting: # is_flagged, set_flag!, unset_flag! 
From 31967fd11c6814007869afa4eecb547bc9cf528e Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 25 Apr 2025 16:37:55 +0100 Subject: [PATCH 37/48] Improve accumulators docstrings --- src/accumulators.jl | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/src/accumulators.jl b/src/accumulators.jl index ec91e20c2..010ede08b 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -3,14 +3,14 @@ An abstract type for accumulators. -An accumulator is an object that may change its value at every tilde_assume or tilde_observe -call based on the value of the random variable in question. The obvious examples of +An accumulator is an object that may change its value at every tilde_assume!! or +tilde_observe!! call based on the random variable in question. The obvious examples of accumulators are the log prior and log likelihood. Others examples might be a variable that counts the number of observations in a trace, or a list of the names of random variables seen so far. -An accumulator type `T` must implement the following methods: -- `accumulator_name(acc::T)` +An accumulator type `T <: AbstractAccumulator` must implement the following methods: +- `accumulator_name(acc::T)` or `accumulator_name(::Type{T})` - `accumulate_observe!!(acc::T, right, left, vn)` - `accumulate_assume!!(acc::T, val, logjac, vn, right)` @@ -36,7 +36,7 @@ accumulator_name(acc::AbstractAccumulator) = accumulator_name(typeof(acc)) """ accumulate_observe!!(acc::AbstractAccumulator, right, left, vn) -Update `acc` in a `tilde_observe` call. Returns the updated `acc`. +Update `acc` in a `tilde_observe!!` call. Returns the updated `acc`. `vn` is the name of the variable being observed, `left` is the value of the variable, and `right` is the distribution on the RHS of the tilde statement. `vn` is `nothing` in the case @@ -51,7 +51,7 @@ function accumulate_observe!! end """ accumulate_assume!!(acc::AbstractAccumulator, val, logjac, vn, right) -Update `acc` in a `tilde_assume` call. Returns the updated `acc`. +Update `acc` in a `tilde_assume!!` call. Returns the updated `acc`. `vn` is the name of the variable being assumed, `val` is the value of the variable, and `right` is the distribution on the RHS of the tilde statement. `logjac` is the log @@ -114,9 +114,6 @@ constraint that the name in the tuple for each accumulator `acc` must be The constructor can be called with a tuple or a `VarArgs` of `AbstractAccumulators`. The names will be generated automatically. One can also call the constructor with a `NamedTuple` but the names in the argument will be discarded in favour of the generated ones. - -# Fields -$(TYPEDFIELDS) """ struct AccumulatorTuple{N,T<:NamedTuple} nt::T @@ -136,7 +133,10 @@ Base.show(io::IO, mime::MIME"text/plain", at::AccumulatorTuple) = show(io, mime, Base.getindex(at::AccumulatorTuple, idx) = at.nt[idx] Base.length(::AccumulatorTuple{N}) where {N} = N Base.iterate(at::AccumulatorTuple, args...) = iterate(at.nt, args...) -Base.haskey(at::AccumulatorTuple, ::Val{accname}) where {accname} = haskey(at.nt, accname) +function Base.haskey(at::AccumulatorTuple, ::Val{accname}) where {accname} + # @inline to ensure constant propagation can resolve this to a compile-time constant. + @inline return haskey(at.nt, accname) +end Base.keys(at::AccumulatorTuple) = keys(at.nt) function Base.convert(::Type{AccumulatorTuple{N,T}}, accs::AccumulatorTuple{N}) where {N,T} @@ -201,11 +201,12 @@ An accumulator that tracks the cumulative log prior during model execution. 
$(TYPEDFIELDS) """ struct LogPrior{T} <: AbstractAccumulator + "the scalar log prior value" logp::T end """ - LogPrior{T}() where {T} + LogPrior{T}() Create a new `LogPrior` accumulator with the log prior initialized to zero. """ @@ -221,11 +222,12 @@ An accumulator that tracks the cumulative log likelihood during model execution. $(TYPEDFIELDS) """ struct LogLikelihood{T} <: AbstractAccumulator + "the scalar log likelihood value" logp::T end """ - LogLikelihood{T}() where {T} + LogLikelihood{T}() Create a new `LogLikelihood` accumulator with the log likelihood initialized to zero. """ @@ -241,11 +243,12 @@ An accumulator that tracks the number of observations during model execution. $(TYPEDFIELDS) """ struct NumProduce{T<:Integer} <: AbstractAccumulator + "the number of observations" num::T end """ - NumProduce{T}() where {T<:Integer} + NumProduce{T<:Integer}() Create a new `NumProduce` accumulator with the number of observations initialized to zero. """ From ad2f564aa1531f6f02b85f1561bb7dda687817e8 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 25 Apr 2025 16:38:15 +0100 Subject: [PATCH 38/48] Fix a typo --- test/accumulators.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/accumulators.jl b/test/accumulators.jl index a4a31f805..e6a8dfcda 100644 --- a/test/accumulators.jl +++ b/test/accumulators.jl @@ -1,4 +1,4 @@ -module AccumulatoTests +module AccumulatorTests using Test using Distributions From 10b4f2f5af0ab20dd532f0eea91fae6333b3c342 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 25 Apr 2025 16:43:33 +0100 Subject: [PATCH 39/48] Expand HISTORY entry --- HISTORY.md | 1 + 1 file changed, 1 insertion(+) diff --git a/HISTORY.md b/HISTORY.md index 2251b0ae3..cadf97bce 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -12,6 +12,7 @@ This release overhauls how VarInfo objects track variables such as the log joint - `MiniBatchContext` does not exist anymore. It can be replaced by creating and using a custom accumulator that replaces the default `LikelihoodContext`. We may introduce such an accumulator in DynamicPPL in the future, but for now you'll need to do it yourself. - `tilde_observe` and `observe` have been removed. `tilde_observe!!` still exists, and any contexts should modify its behaviour. We may further rework the call stack under `tilde_observe!!` in the near future. - `tilde_assume` no longer returns the log density of the current assumption as its second return value. We may further rework the `tilde_assume!!` call stack as well. + - For literal observation statements like `0.0 ~ Normal(blahblah)` we used to call `tilde_observe!!` without the `vn` argument. This method no longer exists. Rather we call `tilde_observe!!` with `vn` set to `nothing`. - `set/reset/increment_num_produce!` have become `set/reset/increment_num_produce!!` (note the second exclamation mark). They are no longer guaranteed to modify the `VarInfo` in place, and one should always use the return value. - `@addlogprob!` now _always_ adds to the log likelihood. Previously it added to the log probability that the execution context specified, e.g. the log prior when using `PriorContext`. 
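To illustrate the last bullet, a minimal sketch of `@addlogprob!` under the new behaviour (the model and data here are made up for illustration):

```julia
using DynamicPPL, Distributions

@model function demo_addlogprob(x)
    μ ~ Normal()
    # The added term always goes to the log likelihood accumulator:
    @addlogprob! sum(logpdf.(Normal(μ), x))
    return μ
end

vi = VarInfo(demo_addlogprob(randn(10)))
getloglikelihood(vi)  # includes the manually added term
getlogprior(vi)       # unaffected by @addlogprob!
```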
From d2b670d939ebb04ce925f3c8b1f770ddfced1175 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 25 Apr 2025 16:51:06 +0100 Subject: [PATCH 40/48] Add accumulators to API docs --- docs/src/api.md | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/docs/src/api.md b/docs/src/api.md index b85b71064..5095f8279 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -160,7 +160,7 @@ returned(::Model) ## Utilities -It is possible to manually increase (or decrease) the accumulated log density from within a model function. +It is possible to manually increase (or decrease) the accumulated log likelihood from within a model function. ```@docs @addlogprob! @@ -345,6 +345,22 @@ Base.empty! SimpleVarInfo ``` +### Accumulators + +The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during executing, in what are called accumulators. + +```@docs +AbstractAccumulators +``` + +DynamicPPL provides the following default accumulators. + +```@docs +LogPrior +LogLikelihood +NumProduce +``` + ### Common API #### Accumulation of log-probabilities From 8241d126c68fdeb22dcbe9cc70eac674b9bc0282 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 25 Apr 2025 16:53:36 +0100 Subject: [PATCH 41/48] Remove unexported functions from API docs --- docs/src/api.md | 5 ----- 1 file changed, 5 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 5095f8279..13e9a30e1 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -491,10 +491,5 @@ DynamicPPL.Experimental.is_suitable_varinfo ### [Model-Internal Functions](@id model_internal) ```@docs -tilde_assume!! tilde_assume ``` - -```@docs -tilde_observe!! -``` From 7b7a3e2aec2e9642c1ef15ec1227491c4142117d Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 25 Apr 2025 17:33:12 +0100 Subject: [PATCH 42/48] Add NamedTuple methods for get/set/acclogp --- HISTORY.md | 2 ++ docs/src/api.md | 9 +++-- src/abstract_varinfo.jl | 73 ++++++++++++++++++++++++++++++----------- test/submodels.jl | 10 +++--- test/varinfo.jl | 43 ++++++++++++++++++++++-- 5 files changed, 108 insertions(+), 29 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index cadf97bce..429811b04 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -15,6 +15,8 @@ This release overhauls how VarInfo objects track variables such as the log joint - For literal observation statements like `0.0 ~ Normal(blahblah)` we used to call `tilde_observe!!` without the `vn` argument. This method no longer exists. Rather we call `tilde_observe!!` with `vn` set to `nothing`. - `set/reset/increment_num_produce!` have become `set/reset/increment_num_produce!!` (note the second exclamation mark). They are no longer guaranteed to modify the `VarInfo` in place, and one should always use the return value. - `@addlogprob!` now _always_ adds to the log likelihood. Previously it added to the log probability that the execution context specified, e.g. the log prior when using `PriorContext`. + - `getlogp` now returns a `NamedTuple` with keys `logprior` and `loglikelihood`. If you want the log joint probability, which is what `getlogp` used to return, use `getlogjoint`. + - Correspondingly `setlogp!!` and `acclogp!!` should now be called with a `NamedTuple` with keys `logprior` and `loglikelihood`. The method with a single scalar value has been deprecated, and falls back on `setloglikelihood!!` or `accloglikelihood!!`. Corresponding setter/accumulator functions exist for the log prior as well. 
## 0.36.0 diff --git a/docs/src/api.md b/docs/src/api.md index 13e9a30e1..0b333b604 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -366,12 +366,15 @@ NumProduce #### Accumulation of log-probabilities ```@docs -getlogprior -getloglikelihood +getlogp +setlogp!! +acclogp!! getlogjoint +getlogprior setlogprior!! -setloglikelihood!! acclogprior!! +getloglikelihood +setloglikelihood!! accloglikelihood!! resetlogp!! ``` diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index fdeb3281b..d70d9a70b 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -98,9 +98,17 @@ Return the log of the joint probability of the observed data and parameters in ` See also: [`getlogprior`](@ref), [`getloglikelihood`](@ref). """ getlogjoint(vi::AbstractVarInfo) = getlogprior(vi) + getloglikelihood(vi) + +""" + getlogp(vi::AbstractVarInfo) + +Return a NamedTuple of the log prior and log likelihood probabilities. + +The keys are called `logprior` and `loglikelihood`. If either one is not present in `vi` an +error will be thrown. +""" function getlogp(vi::AbstractVarInfo) - Base.depwarn("getlogp is deprecated, use getlogjoint instead", :getlogp) - return getlogjoint(vi) + return (; logprior=getlogprior(vi), loglikelihood=getloglikelihood(vi)) end """ @@ -198,23 +206,31 @@ See also: [`setlogprior!!`](@ref), [`setlogp!!`](@ref), [`getloglikelihood`](@re setloglikelihood!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogLikelihood(logp)) """ - setlogp!!(vi::AbstractVarInfo, logp) + setlogp!!(vi::AbstractVarInfo, logp::NamedTuple) -Set the log of the joint probability of the observed data and parameters sampled in -`vi` to `logp`, mutating if it makes sense. +Set both the log prior and the log likelihood probabilities in `vi`. + +`logp` should have fields `logprior` and `loglikelihood` and no other fields. See also: [`setlogprior!!`](@ref), [`setloglikelihood!!`](@ref), [`getlogp`](@ref). """ -function setlogp!!(vi::AbstractVarInfo, logp) - Base.depwarn( - "setlogp!! is deprecated, use setlogprior!! or setloglikelihood!! instead", - :setlogp!!, - ) - vi = setlogprior!!(vi, zero(logp)) - vi = setloglikelihood!!(vi, logp) +function setlogp!!(vi::AbstractVarInfo, logp::NamedTuple{names}) where {names} + if !(names == (:logprior, :loglikelihood) || names == (:loglikelihood, :logprior)) + error("logp must have the fields logprior and loglikelihood and no other fields.") + end + vi = setlogprior!!(vi, logp.logprior) + vi = setloglikelihood!!(vi, logp.loglikelihood) return vi end +function setlogp!!(vi::AbstractVarInfo, logp::Number) + depwarn( + "`setlogp!!(vi::AbstractVarInfo, logp::Number)` is deprecated. Use `setloglikelihood!!(vi, logp)` instead.", + :setlogp, + ) + return setloglikelihood!!(vi, logp) +end + """ getacc(vi::AbstractVarInfo, ::Val{accname}) @@ -303,15 +319,34 @@ function accloglikelihood!!(vi::AbstractVarInfo, logp) end """ - acclogp!!(vi::AbstractVarInfo, logp) + acclogp!!(vi::AbstractVarInfo, logp::NamedTuple) + +Add to both the log prior and the log likelihood probabilities in `vi`. -Add `logp` to the value of the log of the joint probability of the observed data and -parameters sampled in `vi`, mutating if it makes sense. +`logp` should have fields `logprior` and/or `loglikelihood`, and no other fields. """ -function acclogp!!(vi::AbstractVarInfo, logp) - Base.depwarn( - "acclogp!! is deprecated, use acclogprior!! or accloglikelihood!! 
instead", - :acclogp!!, +function acclogp!!(vi::AbstractVarInfo, logp::NamedTuple{names}) where {names} + if !( + names == (:logprior, :loglikelihood) || + names == (:loglikelihood, :logprior) || + names == (:logprior,) || + names == (:loglikelihood,) + ) + error("logp must have fields logprior and/or loglikelihood and no other fields.") + end + if haskey(logp, :logprior) + vi = acclogprior!!(vi, logp.logprior) + end + if haskey(logp, :loglikelihood) + vi = accloglikelihood!!(vi, logp.loglikelihood) + end + return vi +end + +function acclogp!!(vi::AbstractVarInfo, logp::Number) + depwarn( + "`acclogp!!(vi::AbstractVarInfo, logp::Number)` is deprecated. Use `accloglikelihood!!(vi, logp)` instead.", + :acclogp, ) return accloglikelihood!!(vi, logp) end diff --git a/test/submodels.jl b/test/submodels.jl index e79eed2c3..d3a2f17e7 100644 --- a/test/submodels.jl +++ b/test/submodels.jl @@ -35,7 +35,7 @@ using Test @test model()[1] == x_val # Test that the logp was correctly set vi = VarInfo(model) - @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(a.y)]) + @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(a.y)]) # Check the keys @test Set(keys(VarInfo(model))) == Set([@varname(a.y)]) end @@ -67,7 +67,7 @@ using Test @test model()[1] == x_val # Test that the logp was correctly set vi = VarInfo(model) - @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(y)]) + @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(y)]) # Check the keys @test Set(keys(VarInfo(model))) == Set([@varname(y)]) end @@ -99,7 +99,7 @@ using Test @test model()[1] == x_val # Test that the logp was correctly set vi = VarInfo(model) - @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(b.y)]) + @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(b.y)]) # Check the keys @test Set(keys(VarInfo(model))) == Set([@varname(b.y)]) end @@ -148,7 +148,7 @@ using Test # No conditioning vi = VarInfo(h()) @test Set(keys(vi)) == Set([@varname(a.b.x), @varname(a.b.y)]) - @test getlogp(vi) == + @test getlogjoint(vi) == logpdf(Normal(), vi[@varname(a.b.x)]) + logpdf(Normal(), vi[@varname(a.b.y)]) @@ -174,7 +174,7 @@ using Test @testset "$name" for (name, model) in models vi = VarInfo(model) @test Set(keys(vi)) == Set([@varname(a.b.y)]) - @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(a.b.y)]) + @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(a.b.y)]) end end end diff --git a/test/varinfo.jl b/test/varinfo.jl index 5d3533b23..386b9eeba 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -175,29 +175,68 @@ end vi = last(DynamicPPL.evaluate!!(m, deepcopy(vi))) @test getlogprior(vi) == lp_a + lp_b @test getloglikelihood(vi) == lp_c + lp_d + @test getlogp(vi) == (; logprior=lp_a + lp_b, loglikelihood=lp_c + lp_d) + @test getlogjoint(vi) == lp_a + lp_b + lp_c + lp_d @test get_num_produce(vi) == 2 + @test begin + vi = acclogprior!!(vi, 1.0) + getlogprior(vi) == lp_a + lp_b + 1.0 + end + @test begin + vi = accloglikelihood!!(vi, 1.0) + getloglikelihood(vi) == lp_c + lp_d + 1.0 + end + @test begin + vi = setlogprior!!(vi, -1.0) + getlogprior(vi) == -1.0 + end + @test begin + vi = setloglikelihood!!(vi, -1.0) + getloglikelihood(vi) == -1.0 + end + @test begin + vi = setlogp!!(vi, (logprior=-3.0, loglikelihood=-3.0)) + getlogp(vi) == (; logprior=-3.0, loglikelihood=-3.0) + end + @test begin + vi = acclogp!!(vi, (logprior=1.0, loglikelihood=1.0)) + getlogp(vi) == (; logprior=-2.0, loglikelihood=-2.0) + end + @test getlogp(setlogp!!(vi, getlogp(vi))) == getlogp(vi) vi 
= last( DynamicPPL.evaluate!!(m, DynamicPPL.setaccs!!(deepcopy(vi), (LogPrior(),))) ) @test getlogprior(vi) == lp_a + lp_b @test_throws "has no field LogLikelihood" getloglikelihood(vi) + @test_throws "has no field LogLikelihood" getlogp(vi) @test_throws "has no field LogLikelihood" getlogjoint(vi) @test_throws "has no field NumProduce" get_num_produce(vi) + @test begin + vi = acclogprior!!(vi, 1.0) + getlogprior(vi) == lp_a + lp_b + 1.0 + end + @test begin + vi = setlogprior!!(vi, -1.0) + getlogprior(vi) == -1.0 + end + @test_throws "has no field LogLikelihood" setlogp!!(getlogp(vi)) vi = last( DynamicPPL.evaluate!!(m, DynamicPPL.setaccs!!(deepcopy(vi), (NumProduce(),))) ) @test_throws "has no field LogPrior" getlogprior(vi) - @test_throws "has no field LogPrior" getlogjoint(vi) @test_throws "has no field LogLikelihood" getloglikelihood(vi) + @test_throws "has no field LogPrior" getlogp(vi) + @test_throws "has no field LogPrior" getlogjoint(vi) @test get_num_produce(vi) == 2 # Test evaluating without any accumulators. vi = last(DynamicPPL.evaluate!!(m, DynamicPPL.setaccs!!(deepcopy(vi), ()))) @test_throws "has no field LogPrior" getlogprior(vi) - @test_throws "has no field LogPrior" getlogjoint(vi) @test_throws "has no field LogLikelihood" getloglikelihood(vi) + @test_throws "has no field LogPrior" getlogp(vi) + @test_throws "has no field LogPrior" getlogjoint(vi) @test_throws "has no field NumProduce" get_num_produce(vi) @test_throws "has no field NumProduce" reset_num_produce!!(vi) end From 0b08237d6152f73728fe8d527d53cd95fbcc8ad2 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 25 Apr 2025 17:42:11 +0100 Subject: [PATCH 43/48] Fix setlogp!! with single scalar to error --- HISTORY.md | 2 +- src/abstract_varinfo.jl | 9 ++++----- test/varinfo.jl | 1 - 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 429811b04..38af16970 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -16,7 +16,7 @@ This release overhauls how VarInfo objects track variables such as the log joint - `set/reset/increment_num_produce!` have become `set/reset/increment_num_produce!!` (note the second exclamation mark). They are no longer guaranteed to modify the `VarInfo` in place, and one should always use the return value. - `@addlogprob!` now _always_ adds to the log likelihood. Previously it added to the log probability that the execution context specified, e.g. the log prior when using `PriorContext`. - `getlogp` now returns a `NamedTuple` with keys `logprior` and `loglikelihood`. If you want the log joint probability, which is what `getlogp` used to return, use `getlogjoint`. - - Correspondingly `setlogp!!` and `acclogp!!` should now be called with a `NamedTuple` with keys `logprior` and `loglikelihood`. The method with a single scalar value has been deprecated, and falls back on `setloglikelihood!!` or `accloglikelihood!!`. Corresponding setter/accumulator functions exist for the log prior as well. + - Correspondingly `setlogp!!` and `acclogp!!` should now be called with a `NamedTuple` with keys `logprior` and `loglikelihood`. The `acclogp!!` method with a single scalar value has been deprecated and falls back on `accloglikelihood!!`, and the single scalar version of `setlogp!!` has been removed. Corresponding setter/accumulator functions exist for the log prior as well. 
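In code, a sketch of how the old scalar calls map onto the revised API (`demo` is a made-up one-variable model and the values are arbitrary):

```julia
using DynamicPPL, Distributions

@model demo() = x ~ Normal()
vi = VarInfo(demo())

# setlogp!!(vi, -1.2) now throws; set the two components explicitly instead:
vi = setlogprior!!(vi, -0.3)
vi = setloglikelihood!!(vi, -1.2)

# acclogp!!(vi, 0.3) is deprecated and falls back on:
vi = accloglikelihood!!(vi, 0.3)

# The NamedTuple forms of setlogp!! and acclogp!! remain available:
vi = setlogp!!(vi, (logprior=0.0, loglikelihood=-0.9))
```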
## 0.36.0 diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index d70d9a70b..00c9952ce 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -224,11 +224,10 @@ function setlogp!!(vi::AbstractVarInfo, logp::NamedTuple{names}) where {names} end function setlogp!!(vi::AbstractVarInfo, logp::Number) - depwarn( - "`setlogp!!(vi::AbstractVarInfo, logp::Number)` is deprecated. Use `setloglikelihood!!(vi, logp)` instead.", - :setlogp, - ) - return setloglikelihood!!(vi, logp) + return error(""" + `setlogp!!(vi::AbstractVarInfo, logp::Number)` is no longer supported. Use + `setloglikelihood!!` and/or `setlogprior!!` instead. + """) end """ diff --git a/test/varinfo.jl b/test/varinfo.jl index 386b9eeba..3ee2305de 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -220,7 +220,6 @@ end vi = setlogprior!!(vi, -1.0) getlogprior(vi) == -1.0 end - @test_throws "has no field LogLikelihood" setlogp!!(getlogp(vi)) vi = last( DynamicPPL.evaluate!!(m, DynamicPPL.setaccs!!(deepcopy(vi), (NumProduce(),))) From 2a4b874f868feeb7c252a3427ba56f9dd0f536f6 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 25 Apr 2025 17:52:33 +0100 Subject: [PATCH 44/48] Export AbstractAccumulator, fix a docs typo --- docs/src/api.md | 2 +- src/DynamicPPL.jl | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/src/api.md b/docs/src/api.md index 0b333b604..7a2e5e5e5 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -350,7 +350,7 @@ SimpleVarInfo The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during executing, in what are called accumulators. ```@docs -AbstractAccumulators +AbstractAccumulator ``` DynamicPPL provides the following default accumulators. diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index e0604f458..0553f8d79 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -47,6 +47,7 @@ import Base: export AbstractVarInfo, VarInfo, SimpleVarInfo, + AbstractAccumulator, LogLikelihood, LogPrior, NumProduce, From c1e90f7d5c19b661faa37fab95ad7ce0efc1d9ee Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 28 Apr 2025 15:52:35 +0100 Subject: [PATCH 45/48] Apply suggestions from code review Co-authored-by: Penelope Yong --- src/abstract_varinfo.jl | 2 +- src/accumulators.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 00c9952ce..89cf693df 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -139,7 +139,7 @@ function getaccs end Return a boolean for whether `vi` has an accumulator with name `accname`. """ hasacc(vi::AbstractVarInfo, accname::Val) = haskey(getaccs(vi), accname) -function hassacc(vi::AbstractVarInfo, accname::Symbol) +function hasacc(vi::AbstractVarInfo, accname::Symbol) return error( """ The method hasacc(vi::AbstractVarInfo, accname::Symbol) does not exist. For type diff --git a/src/accumulators.jl b/src/accumulators.jl index 010ede08b..acd28e69c 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -5,7 +5,7 @@ An abstract type for accumulators. An accumulator is an object that may change its value at every tilde_assume!! or tilde_observe!! call based on the random variable in question. The obvious examples of -accumulators are the log prior and log likelihood. Others examples might be a variable that +accumulators are the log prior and log likelihood. 
Other examples might be a variable that counts the number of observations in a trace, or a list of the names of random variables seen so far. From cb1c6c641752215554b106e8358a5253349dd7f7 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 28 Apr 2025 16:07:21 +0100 Subject: [PATCH 46/48] Rename LogPrior -> LogPriorAccumulator, and Likelihood and NumProduce --- HISTORY.md | 2 +- docs/src/api.md | 6 +- src/DynamicPPL.jl | 6 +- src/abstract_varinfo.jl | 10 +-- src/accumulators.jl | 132 ++++++++++++++++++++-------------- src/logdensityfunction.jl | 2 +- src/model.jl | 9 ++- src/pointwise_logdensities.jl | 4 +- src/simple_varinfo.jl | 8 ++- src/transforming.jl | 2 +- src/varinfo.jl | 6 +- test/accumulators.jl | 129 ++++++++++++++++++--------------- test/varinfo.jl | 36 +++++----- 13 files changed, 201 insertions(+), 151 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 38af16970..68650f9d1 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -8,7 +8,7 @@ This release overhauls how VarInfo objects track variables such as the log joint probability. The new approach is to use what we call accumulators: Objects that the VarInfo carries on it that may change their state at each `tilde_assume!!` and `tilde_observe!!` call based on the value of the variable in question. They replace both variables that were previously hard-coded in the `VarInfo` object (`logp` and `num_produce`) and some contexts. This brings with it a number of breaking changes: - - `PriorContext` and `LikelihoodContext` no longer exist. By default, a `VarInfo` tracks both the log prior and the log likelihood separately, and they can be accessed with `getlogprior` and `getloglikelihood`. If you want to execute a model while only accumulating one of the two (to save clock cycles), you can do so by creating a `VarInfo` that only has one accumulator in it, e.g. `varinfo = setaccs!!(varinfo, (LogPrior(),))`. + - `PriorContext` and `LikelihoodContext` no longer exist. By default, a `VarInfo` tracks both the log prior and the log likelihood separately, and they can be accessed with `getlogprior` and `getloglikelihood`. If you want to execute a model while only accumulating one of the two (to save clock cycles), you can do so by creating a `VarInfo` that only has one accumulator in it, e.g. `varinfo = setaccs!!(varinfo, (LogPriorAccumulator(),))`. - `MiniBatchContext` does not exist anymore. It can be replaced by creating and using a custom accumulator that replaces the default `LikelihoodContext`. We may introduce such an accumulator in DynamicPPL in the future, but for now you'll need to do it yourself. - `tilde_observe` and `observe` have been removed. `tilde_observe!!` still exists, and any contexts should modify its behaviour. We may further rework the call stack under `tilde_observe!!` in the near future. - `tilde_assume` no longer returns the log density of the current assumption as its second return value. We may further rework the `tilde_assume!!` call stack as well. diff --git a/docs/src/api.md b/docs/src/api.md index 7a2e5e5e5..83342245a 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -356,9 +356,9 @@ AbstractAccumulator DynamicPPL provides the following default accumulators. 
```@docs -LogPrior -LogLikelihood -NumProduce +LogPriorAccumulator +LogLikelihoodAccumulator +NumProduceAccumulator ``` ### Common API diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 0553f8d79..9cbdfe229 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -48,9 +48,9 @@ export AbstractVarInfo, VarInfo, SimpleVarInfo, AbstractAccumulator, - LogLikelihood, - LogPrior, - NumProduce, + LogLikelihoodAccumulator, + LogPriorAccumulator, + NumProduceAccumulator, push!!, empty!!, subset, diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 89cf693df..44097dd2f 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -194,7 +194,7 @@ Set the log of the prior probability of the parameters sampled in `vi` to `logp` See also: [`setloglikelihood!!`](@ref), [`setlogp!!`](@ref), [`getlogprior`](@ref). """ -setlogprior!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogPrior(logp)) +setlogprior!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogPriorAccumulator(logp)) """ setloglikelihood!!(vi::AbstractVarInfo, logp) @@ -203,7 +203,7 @@ Set the log of the likelihood probability of the observed data sampled in `vi` t See also: [`setlogprior!!`](@ref), [`setlogp!!`](@ref), [`getloglikelihood`](@ref). """ -setloglikelihood!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogLikelihood(logp)) +setloglikelihood!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogLikelihoodAccumulator(logp)) """ setlogp!!(vi::AbstractVarInfo, logp::NamedTuple) @@ -303,7 +303,7 @@ Add `logp` to the value of the log of the prior probability in `vi`. See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getlogprior`](@ref), [`setlogprior!!`](@ref). """ function acclogprior!!(vi::AbstractVarInfo, logp) - return map_accumulator!!(acc -> acc + LogPrior(logp), vi, Val(:LogPrior)) + return map_accumulator!!(acc -> acc + LogPriorAccumulator(logp), vi, Val(:LogPrior)) end """ @@ -314,7 +314,9 @@ Add `logp` to the value of the log of the likelihood in `vi`. See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getloglikelihood`](@ref), [`setloglikelihood!!`](@ref). """ function accloglikelihood!!(vi::AbstractVarInfo, logp) - return map_accumulator!!(acc -> acc + LogLikelihood(logp), vi, Val(:LogLikelihood)) + return map_accumulator!!( + acc -> acc + LogLikelihoodAccumulator(logp), vi, Val(:LogLikelihood) + ) end """ diff --git a/src/accumulators.jl b/src/accumulators.jl index acd28e69c..f8f8a6f31 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -193,124 +193,146 @@ end # END ACCUMULATOR TUPLE, BEGIN LOG PROB AND NUM PRODUCE ACCUMULATORS """ - LogPrior{T} <: AbstractAccumulator + LogPriorAccumulator{T} <: AbstractAccumulator An accumulator that tracks the cumulative log prior during model execution. # Fields $(TYPEDFIELDS) """ -struct LogPrior{T} <: AbstractAccumulator +struct LogPriorAccumulator{T} <: AbstractAccumulator "the scalar log prior value" logp::T end """ - LogPrior{T}() + LogPriorAccumulator{T}() -Create a new `LogPrior` accumulator with the log prior initialized to zero. +Create a new `LogPriorAccumulator` accumulator with the log prior initialized to zero. """ -LogPrior{T}() where {T} = LogPrior(zero(T)) -LogPrior() = LogPrior{LogProbType}() +LogPriorAccumulator{T}() where {T} = LogPriorAccumulator(zero(T)) +LogPriorAccumulator() = LogPriorAccumulator{LogProbType}() """ - LogLikelihood{T} <: AbstractAccumulator + LogLikelihoodAccumulator{T} <: AbstractAccumulator An accumulator that tracks the cumulative log likelihood during model execution. 
# Fields $(TYPEDFIELDS) """ -struct LogLikelihood{T} <: AbstractAccumulator +struct LogLikelihoodAccumulator{T} <: AbstractAccumulator "the scalar log likelihood value" logp::T end """ - LogLikelihood{T}() + LogLikelihoodAccumulator{T}() -Create a new `LogLikelihood` accumulator with the log likelihood initialized to zero. +Create a new `LogLikelihoodAccumulator` accumulator with the log likelihood initialized to zero. """ -LogLikelihood{T}() where {T} = LogLikelihood(zero(T)) -LogLikelihood() = LogLikelihood{LogProbType}() +LogLikelihoodAccumulator{T}() where {T} = LogLikelihoodAccumulator(zero(T)) +LogLikelihoodAccumulator() = LogLikelihoodAccumulator{LogProbType}() """ - NumProduce{T} <: AbstractAccumulator + NumProduceAccumulator{T} <: AbstractAccumulator An accumulator that tracks the number of observations during model execution. # Fields $(TYPEDFIELDS) """ -struct NumProduce{T<:Integer} <: AbstractAccumulator +struct NumProduceAccumulator{T<:Integer} <: AbstractAccumulator "the number of observations" num::T end """ - NumProduce{T<:Integer}() + NumProduceAccumulator{T<:Integer}() -Create a new `NumProduce` accumulator with the number of observations initialized to zero. +Create a new `NumProduceAccumulator` accumulator with the number of observations initialized to zero. """ -NumProduce{T}() where {T} = NumProduce(zero(T)) -NumProduce() = NumProduce{Int}() +NumProduceAccumulator{T}() where {T} = NumProduceAccumulator(zero(T)) +NumProduceAccumulator() = NumProduceAccumulator{Int}() -Base.show(io::IO, acc::LogPrior) = print(io, "LogPrior($(repr(acc.logp)))") -Base.show(io::IO, acc::LogLikelihood) = print(io, "LogLikelihood($(repr(acc.logp)))") -Base.show(io::IO, acc::NumProduce) = print(io, "NumProduce($(repr(acc.num)))") +function Base.show(io::IO, acc::LogPriorAccumulator) + return print(io, "LogPriorAccumulator($(repr(acc.logp)))") +end +function Base.show(io::IO, acc::LogLikelihoodAccumulator) + return print(io, "LogLikelihoodAccumulator($(repr(acc.logp)))") +end +function Base.show(io::IO, acc::NumProduceAccumulator) + return print(io, "NumProduceAccumulator($(repr(acc.num)))") +end -accumulator_name(::Type{<:LogPrior}) = :LogPrior -accumulator_name(::Type{<:LogLikelihood}) = :LogLikelihood -accumulator_name(::Type{<:NumProduce}) = :NumProduce +accumulator_name(::Type{<:LogPriorAccumulator}) = :LogPrior +accumulator_name(::Type{<:LogLikelihoodAccumulator}) = :LogLikelihood +accumulator_name(::Type{<:NumProduceAccumulator}) = :NumProduce -split(::LogPrior{T}) where {T} = LogPrior(zero(T)) -split(::LogLikelihood{T}) where {T} = LogLikelihood(zero(T)) -split(acc::NumProduce) = acc +split(::LogPriorAccumulator{T}) where {T} = LogPriorAccumulator(zero(T)) +split(::LogLikelihoodAccumulator{T}) where {T} = LogLikelihoodAccumulator(zero(T)) +split(acc::NumProduceAccumulator) = acc -combine(acc::LogPrior, acc2::LogPrior) = LogPrior(acc.logp + acc2.logp) -combine(acc::LogLikelihood, acc2::LogLikelihood) = LogLikelihood(acc.logp + acc2.logp) -function combine(acc::NumProduce, acc2::NumProduce) - return NumProduce(max(acc.num, acc2.num)) +function combine(acc::LogPriorAccumulator, acc2::LogPriorAccumulator) + return LogPriorAccumulator(acc.logp + acc2.logp) +end +function combine(acc::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) + return LogLikelihoodAccumulator(acc.logp + acc2.logp) +end +function combine(acc::NumProduceAccumulator, acc2::NumProduceAccumulator) + return NumProduceAccumulator(max(acc.num, acc2.num)) end -Base.:+(acc1::LogPrior, acc2::LogPrior) = 
LogPrior(acc1.logp + acc2.logp) -Base.:+(acc1::LogLikelihood, acc2::LogLikelihood) = LogLikelihood(acc1.logp + acc2.logp) -increment(acc::NumProduce) = NumProduce(acc.num + oneunit(acc.num)) +function Base.:+(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) + return LogPriorAccumulator(acc1.logp + acc2.logp) +end +function Base.:+(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) + return LogLikelihoodAccumulator(acc1.logp + acc2.logp) +end +increment(acc::NumProduceAccumulator) = NumProduceAccumulator(acc.num + oneunit(acc.num)) -Base.zero(acc::LogPrior) = LogPrior(zero(acc.logp)) -Base.zero(acc::LogLikelihood) = LogLikelihood(zero(acc.logp)) -Base.zero(acc::NumProduce) = NumProduce(zero(acc.num)) +Base.zero(acc::LogPriorAccumulator) = LogPriorAccumulator(zero(acc.logp)) +Base.zero(acc::LogLikelihoodAccumulator) = LogLikelihoodAccumulator(zero(acc.logp)) +Base.zero(acc::NumProduceAccumulator) = NumProduceAccumulator(zero(acc.num)) -function accumulate_assume!!(acc::LogPrior, val, logjac, vn, right) - return acc + LogPrior(logpdf(right, val) + logjac) +function accumulate_assume!!(acc::LogPriorAccumulator, val, logjac, vn, right) + return acc + LogPriorAccumulator(logpdf(right, val) + logjac) end -accumulate_observe!!(acc::LogPrior, right, left, vn) = acc +accumulate_observe!!(acc::LogPriorAccumulator, right, left, vn) = acc -accumulate_assume!!(acc::LogLikelihood, val, logjac, vn, right) = acc -function accumulate_observe!!(acc::LogLikelihood, right, left, vn) +accumulate_assume!!(acc::LogLikelihoodAccumulator, val, logjac, vn, right) = acc +function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn) # Note that it's important to use the loglikelihood function here, not logpdf, because # they handle vectors differently: # https://github.com/JuliaStats/Distributions.jl/issues/1972 - return acc + LogLikelihood(Distributions.loglikelihood(right, left)) + return acc + LogLikelihoodAccumulator(Distributions.loglikelihood(right, left)) end -accumulate_assume!!(acc::NumProduce, val, logjac, vn, right) = acc -accumulate_observe!!(acc::NumProduce, right, left, vn) = increment(acc) +accumulate_assume!!(acc::NumProduceAccumulator, val, logjac, vn, right) = acc +accumulate_observe!!(acc::NumProduceAccumulator, right, left, vn) = increment(acc) -Base.convert(::Type{LogPrior{T}}, acc::LogPrior) where {T} = LogPrior(convert(T, acc.logp)) -function Base.convert(::Type{LogLikelihood{T}}, acc::LogLikelihood) where {T} - return LogLikelihood(convert(T, acc.logp)) +function Base.convert(::Type{LogPriorAccumulator{T}}, acc::LogPriorAccumulator) where {T} + return LogPriorAccumulator(convert(T, acc.logp)) end -function Base.convert(::Type{NumProduce{T}}, acc::NumProduce) where {T} - return NumProduce(convert(T, acc.num)) +function Base.convert( + ::Type{LogLikelihoodAccumulator{T}}, acc::LogLikelihoodAccumulator +) where {T} + return LogLikelihoodAccumulator(convert(T, acc.logp)) +end +function Base.convert( + ::Type{NumProduceAccumulator{T}}, acc::NumProduceAccumulator +) where {T} + return NumProduceAccumulator(convert(T, acc.num)) end # TODO(mhauru) -# We ignore the convert_eltype calls for NumProduce, by letting them fallback on +# We ignore the convert_eltype calls for NumProduceAccumulator, by letting them fallback on # convert_eltype(::AbstractAccumulator, ::Type). This is because they are only used to -# deal with dual number types of AD backends, which shouldn't concern NumProduce. 
This is +# deal with dual number types of AD backends, which shouldn't concern NumProduceAccumulator. This is # horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`. -convert_eltype(::Type{T}, acc::LogPrior) where {T} = LogPrior(convert(T, acc.logp)) -function convert_eltype(::Type{T}, acc::LogLikelihood) where {T} - return LogLikelihood(convert(T, acc.logp)) +function convert_eltype(::Type{T}, acc::LogPriorAccumulator) where {T} + return LogPriorAccumulator(convert(T, acc.logp)) +end +function convert_eltype(::Type{T}, acc::LogLikelihoodAccumulator) where {T} + return LogLikelihoodAccumulator(convert(T, acc.logp)) end diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index a3eea9360..1b5e9b8c4 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -79,7 +79,7 @@ julia> LogDensityProblems.logdensity(f, [0.0]) -2.3378770664093453 julia> # LogDensityFunction respects the accumulators in VarInfo: - f_prior = LogDensityFunction(model, setaccs!!(VarInfo(model), (LogPrior(),))); + f_prior = LogDensityFunction(model, setaccs!!(VarInfo(model), (LogPriorAccumulator(),))); julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0) true diff --git a/src/model.jl b/src/model.jl index 4754b5906..e8f2f3528 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1058,8 +1058,11 @@ See also [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logprior(model::Model, varinfo::AbstractVarInfo) # Remove other accumulators from varinfo, since they are unnecessary. - logprior = - hasacc(varinfo, Val(:LogPrior)) ? getacc(varinfo, Val(:LogPrior)) : LogPrior() + logprior = if hasacc(varinfo, Val(:LogPrior)) + getacc(varinfo, Val(:LogPrior)) + else + LogPriorAccumulator() + end varinfo = setaccs!!(deepcopy(varinfo), (logprior,)) return getlogprior(last(evaluate!!(model, varinfo, DefaultContext()))) end @@ -1112,7 +1115,7 @@ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) loglikelihood = if hasacc(varinfo, Val(:LogLikelihood)) getacc(varinfo, Val(:LogLikelihood)) else - LogLikelihood() + LogLikelihoodAccumulator() end varinfo = setaccs!!(deepcopy(varinfo), (loglikelihood,)) return getloglikelihood(last(evaluate!!(model, varinfo, DefaultContext()))) diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index d998858c1..b6b97c8f9 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -69,7 +69,7 @@ function accumulate_assume!!( # T is the element type of the vectors that are the values of `acc.logps`. Usually # it's LogProbType. T = eltype(last(fieldtypes(eltype(acc.logps)))) - subacc = accumulate_assume!!(LogPrior{T}(), val, logjac, vn, right) + subacc = accumulate_assume!!(LogPriorAccumulator{T}(), val, logjac, vn, right) push!(acc, vn, subacc.logp) end return acc @@ -87,7 +87,7 @@ function accumulate_observe!!( # T is the element type of the vectors that are the values of `acc.logps`. Usually # it's LogProbType. 
T = eltype(last(fieldtypes(eltype(acc.logps)))) - subacc = accumulate_observe!!(LogLikelihood{T}(), right, left, vn) + subacc = accumulate_observe!!(LogLikelihoodAccumulator{T}(), right, left, vn) push!(acc, vn, subacc.logp) end return acc diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 8c377de3f..257ccb004 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -125,7 +125,7 @@ Evaluation in transformed space of course also works: ```jldoctest simplevarinfo-general julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true) -Transformed SimpleVarInfo((x = -1.0,), (LogLikelihood = LogLikelihood(0.0), LogPrior = LogPrior(0.0))) +Transformed SimpleVarInfo((x = -1.0,), (LogLikelihood = LogLikelihoodAccumulator(0.0), LogPrior = LogPriorAccumulator(0.0))) julia> # (✓) Positive probability mass on negative numbers! getlogjoint(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) @@ -133,7 +133,7 @@ julia> # (✓) Positive probability mass on negative numbers! julia> # While if we forget to indicate that it's transformed: vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false) -SimpleVarInfo((x = -1.0,), (LogLikelihood = LogLikelihood(0.0), LogPrior = LogPrior(0.0))) +SimpleVarInfo((x = -1.0,), (LogLikelihood = LogLikelihoodAccumulator(0.0), LogPrior = LogPriorAccumulator(0.0))) julia> # (✓) No probability mass on negative numbers! getlogjoint(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) @@ -204,7 +204,9 @@ function SimpleVarInfo(values, accs) return SimpleVarInfo(values, accs, NoTransformation()) end function SimpleVarInfo{T}(values) where {T<:Real} - return SimpleVarInfo(values, AccumulatorTuple(LogLikelihood{T}(), LogPrior{T}())) + return SimpleVarInfo( + values, AccumulatorTuple(LogLikelihoodAccumulator{T}(), LogPriorAccumulator{T}()) + ) end function SimpleVarInfo(values) return SimpleVarInfo{LogProbType}(values) diff --git a/src/transforming.jl b/src/transforming.jl index 773d8fb7e..ddd1ab59f 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -52,7 +52,7 @@ function _transform!!( model::Model, ) # To transform using DynamicTransformationContext, we evaluate the model, but we do not - # need to use any accumulators other than LogPrior (which is affected by the Jacobian of + # need to use any accumulators other than LogPriorAccumulator (which is affected by the Jacobian of # the transformation). accs = getaccs(vi) has_logprior = haskey(accs, Val(:LogPrior)) diff --git a/src/varinfo.jl b/src/varinfo.jl index ad624a253..ec55f6476 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -105,7 +105,9 @@ function VarInfo(meta=Metadata()) return VarInfo( meta, AccumulatorTuple( - LogPrior{LogProbType}(), LogLikelihood{LogProbType}(), NumProduce{Int}() + LogPriorAccumulator{LogProbType}(), + LogLikelihoodAccumulator{LogProbType}(), + NumProduceAccumulator{Int}(), ), ) end @@ -1025,7 +1027,7 @@ get_num_produce(vi::VarInfo) = getacc(vi, Val(:NumProduce)).num Set the `num_produce` field of `vi` to `n`. 
""" -set_num_produce!!(vi::VarInfo, n::Int) = setacc!!(vi, NumProduce(n)) +set_num_produce!!(vi::VarInfo, n::Int) = setacc!!(vi, NumProduceAccumulator(n)) """ increment_num_produce!!(vi::VarInfo) diff --git a/test/accumulators.jl b/test/accumulators.jl index e6a8dfcda..36bb95e46 100644 --- a/test/accumulators.jl +++ b/test/accumulators.jl @@ -5,9 +5,9 @@ using Distributions using DynamicPPL using DynamicPPL: AccumulatorTuple, - LogLikelihood, - LogPrior, - NumProduce, + LogLikelihoodAccumulator, + LogPriorAccumulator, + NumProduceAccumulator, accumulate_assume!!, accumulate_observe!!, combine, @@ -21,54 +21,63 @@ using DynamicPPL: @testset "accumulators" begin @testset "individual accumulator types" begin @testset "constructors" begin - @test LogPrior(0.0) == - LogPrior() == - LogPrior{Float64}() == - LogPrior{Float64}(0.0) == - zero(LogPrior(1.0)) - @test LogLikelihood(0.0) == - LogLikelihood() == - LogLikelihood{Float64}() == - LogLikelihood{Float64}(0.0) == - zero(LogLikelihood(1.0)) - @test NumProduce(0) == - NumProduce() == - NumProduce{Int}() == - NumProduce{Int}(0) == - zero(NumProduce(1)) + @test LogPriorAccumulator(0.0) == + LogPriorAccumulator() == + LogPriorAccumulator{Float64}() == + LogPriorAccumulator{Float64}(0.0) == + zero(LogPriorAccumulator(1.0)) + @test LogLikelihoodAccumulator(0.0) == + LogLikelihoodAccumulator() == + LogLikelihoodAccumulator{Float64}() == + LogLikelihoodAccumulator{Float64}(0.0) == + zero(LogLikelihoodAccumulator(1.0)) + @test NumProduceAccumulator(0) == + NumProduceAccumulator() == + NumProduceAccumulator{Int}() == + NumProduceAccumulator{Int}(0) == + zero(NumProduceAccumulator(1)) end @testset "addition and incrementation" begin - @test LogPrior(1.0f0) + LogPrior(1.0f0) == LogPrior(2.0f0) - @test LogPrior(1.0) + LogPrior(1.0f0) == LogPrior(2.0) - @test LogLikelihood(1.0f0) + LogLikelihood(1.0f0) == LogLikelihood(2.0f0) - @test LogLikelihood(1.0) + LogLikelihood(1.0f0) == LogLikelihood(2.0) - @test increment(NumProduce()) == NumProduce(1) - @test increment(NumProduce{UInt8}()) == NumProduce{UInt8}(1) + @test LogPriorAccumulator(1.0f0) + LogPriorAccumulator(1.0f0) == + LogPriorAccumulator(2.0f0) + @test LogPriorAccumulator(1.0) + LogPriorAccumulator(1.0f0) == + LogPriorAccumulator(2.0) + @test LogLikelihoodAccumulator(1.0f0) + LogLikelihoodAccumulator(1.0f0) == + LogLikelihoodAccumulator(2.0f0) + @test LogLikelihoodAccumulator(1.0) + LogLikelihoodAccumulator(1.0f0) == + LogLikelihoodAccumulator(2.0) + @test increment(NumProduceAccumulator()) == NumProduceAccumulator(1) + @test increment(NumProduceAccumulator{UInt8}()) == + NumProduceAccumulator{UInt8}(1) end @testset "split and combine" begin for acc in [ - LogPrior(1.0), - LogLikelihood(1.0), - NumProduce(1), - LogPrior(1.0f0), - LogLikelihood(1.0f0), - NumProduce(UInt8(1)), + LogPriorAccumulator(1.0), + LogLikelihoodAccumulator(1.0), + NumProduceAccumulator(1), + LogPriorAccumulator(1.0f0), + LogLikelihoodAccumulator(1.0f0), + NumProduceAccumulator(UInt8(1)), ] @test combine(acc, split(acc)) == acc end end @testset "conversions" begin - @test convert(LogPrior{Float32}, LogPrior(1.0)) == LogPrior{Float32}(1.0f0) - @test convert(LogLikelihood{Float32}, LogLikelihood(1.0)) == - LogLikelihood{Float32}(1.0f0) - @test convert(NumProduce{UInt8}, NumProduce(1)) == NumProduce{UInt8}(1) - - @test convert_eltype(Float32, LogPrior(1.0)) == LogPrior{Float32}(1.0f0) - @test convert_eltype(Float32, LogLikelihood(1.0)) == - LogLikelihood{Float32}(1.0f0) + @test convert(LogPriorAccumulator{Float32}, 
LogPriorAccumulator(1.0)) == + LogPriorAccumulator{Float32}(1.0f0) + @test convert( + LogLikelihoodAccumulator{Float32}, LogLikelihoodAccumulator(1.0) + ) == LogLikelihoodAccumulator{Float32}(1.0f0) + @test convert(NumProduceAccumulator{UInt8}, NumProduceAccumulator(1)) == + NumProduceAccumulator{UInt8}(1) + + @test convert_eltype(Float32, LogPriorAccumulator(1.0)) == + LogPriorAccumulator{Float32}(1.0f0) + @test convert_eltype(Float32, LogLikelihoodAccumulator(1.0)) == + LogLikelihoodAccumulator{Float32}(1.0f0) end @testset "accumulate_assume" begin @@ -76,31 +85,35 @@ using DynamicPPL: logjac = pi vn = @varname(x) dist = Normal() - @test accumulate_assume!!(LogPrior(1.0), val, logjac, vn, dist) == - LogPrior(1.0 + logjac + logpdf(dist, val)) - @test accumulate_assume!!(LogLikelihood(1.0), val, logjac, vn, dist) == - LogLikelihood(1.0) - @test accumulate_assume!!(NumProduce(1), val, logjac, vn, dist) == NumProduce(1) + @test accumulate_assume!!(LogPriorAccumulator(1.0), val, logjac, vn, dist) == + LogPriorAccumulator(1.0 + logjac + logpdf(dist, val)) + @test accumulate_assume!!( + LogLikelihoodAccumulator(1.0), val, logjac, vn, dist + ) == LogLikelihoodAccumulator(1.0) + @test accumulate_assume!!(NumProduceAccumulator(1), val, logjac, vn, dist) == + NumProduceAccumulator(1) end @testset "accumulate_observe" begin right = Normal() left = 2.0 vn = @varname(x) - @test accumulate_observe!!(LogPrior(1.0), right, left, vn) == LogPrior(1.0) - @test accumulate_observe!!(LogLikelihood(1.0), right, left, vn) == - LogLikelihood(1.0 + logpdf(right, left)) - @test accumulate_observe!!(NumProduce(1), right, left, vn) == NumProduce(2) + @test accumulate_observe!!(LogPriorAccumulator(1.0), right, left, vn) == + LogPriorAccumulator(1.0) + @test accumulate_observe!!(LogLikelihoodAccumulator(1.0), right, left, vn) == + LogLikelihoodAccumulator(1.0 + logpdf(right, left)) + @test accumulate_observe!!(NumProduceAccumulator(1), right, left, vn) == + NumProduceAccumulator(2) end end @testset "accumulator tuples" begin # Some accumulators we'll use for testing - lp_f64 = LogPrior(1.0) - lp_f32 = LogPrior(1.0f0) - ll_f64 = LogLikelihood(1.0) - ll_f32 = LogLikelihood(1.0f0) - np_i64 = NumProduce(1) + lp_f64 = LogPriorAccumulator(1.0) + lp_f32 = LogPriorAccumulator(1.0f0) + ll_f64 = LogLikelihoodAccumulator(1.0) + ll_f32 = LogLikelihoodAccumulator(1.0f0) + np_i64 = NumProduceAccumulator(1) @testset "constructors" begin @test AccumulatorTuple(lp_f64, ll_f64) == AccumulatorTuple((lp_f64, ll_f64)) @@ -126,7 +139,7 @@ using DynamicPPL: @test keys(at_all64) == (:LogPrior, :LogLikelihood, :NumProduce) @test collect(at_all64) == [lp_f64, ll_f64, np_i64] - # Replace the existing LogPrior + # Replace the existing LogPriorAccumulator @test setacc!!(at_all64, lp_f32)[:LogPrior] == lp_f32 # Check that setacc!! didn't modify the original @test at_all64 == AccumulatorTuple(lp_f64, ll_f64, np_i64) @@ -140,20 +153,22 @@ using DynamicPPL: @testset "map_accumulator(s)!!" begin # map over all accumulators accs = AccumulatorTuple(lp_f32, ll_f32) - @test map(zero, accs) == AccumulatorTuple(LogPrior(0.0f0), LogLikelihood(0.0f0)) + @test map(zero, accs) == AccumulatorTuple( + LogPriorAccumulator(0.0f0), LogLikelihoodAccumulator(0.0f0) + ) # Test that the original wasn't modified. @test accs == AccumulatorTuple(lp_f32, ll_f32) # A map with a closure that changes the types of the accumulators. 
@test map(acc -> convert_eltype(Float64, acc), accs) == - AccumulatorTuple(LogPrior(1.0), LogLikelihood(1.0)) + AccumulatorTuple(LogPriorAccumulator(1.0), LogLikelihoodAccumulator(1.0)) # only apply to a particular accumulator @test map_accumulator(zero, accs, Val(:LogLikelihood)) == - AccumulatorTuple(lp_f32, LogLikelihood(0.0f0)) + AccumulatorTuple(lp_f32, LogLikelihoodAccumulator(0.0f0)) @test map_accumulator( acc -> convert_eltype(Float64, acc), accs, Val(:LogLikelihood) - ) == AccumulatorTuple(lp_f32, LogLikelihood(1.0)) + ) == AccumulatorTuple(lp_f32, LogLikelihoodAccumulator(1.0)) end end end diff --git a/test/varinfo.jl b/test/varinfo.jl index 3ee2305de..efa8c6e4c 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -205,13 +205,15 @@ end @test getlogp(setlogp!!(vi, getlogp(vi))) == getlogp(vi) vi = last( - DynamicPPL.evaluate!!(m, DynamicPPL.setaccs!!(deepcopy(vi), (LogPrior(),))) + DynamicPPL.evaluate!!( + m, DynamicPPL.setaccs!!(deepcopy(vi), (LogPriorAccumulator(),)) + ), ) @test getlogprior(vi) == lp_a + lp_b - @test_throws "has no field LogLikelihood" getloglikelihood(vi) - @test_throws "has no field LogLikelihood" getlogp(vi) - @test_throws "has no field LogLikelihood" getlogjoint(vi) - @test_throws "has no field NumProduce" get_num_produce(vi) + @test_throws "has no field LogLikelihoodAccumulator" getloglikelihood(vi) + @test_throws "has no field LogLikelihoodAccumulator" getlogp(vi) + @test_throws "has no field LogLikelihoodAccumulator" getlogjoint(vi) + @test_throws "has no field NumProduceAccumulator" get_num_produce(vi) @test begin vi = acclogprior!!(vi, 1.0) getlogprior(vi) == lp_a + lp_b + 1.0 @@ -222,22 +224,24 @@ end end vi = last( - DynamicPPL.evaluate!!(m, DynamicPPL.setaccs!!(deepcopy(vi), (NumProduce(),))) + DynamicPPL.evaluate!!( + m, DynamicPPL.setaccs!!(deepcopy(vi), (NumProduceAccumulator(),)) + ), ) - @test_throws "has no field LogPrior" getlogprior(vi) - @test_throws "has no field LogLikelihood" getloglikelihood(vi) - @test_throws "has no field LogPrior" getlogp(vi) - @test_throws "has no field LogPrior" getlogjoint(vi) + @test_throws "has no field LogPriorAccumulator" getlogprior(vi) + @test_throws "has no field LogLikelihoodAccumulator" getloglikelihood(vi) + @test_throws "has no field LogPriorAccumulator" getlogp(vi) + @test_throws "has no field LogPriorAccumulator" getlogjoint(vi) @test get_num_produce(vi) == 2 # Test evaluating without any accumulators. 
vi = last(DynamicPPL.evaluate!!(m, DynamicPPL.setaccs!!(deepcopy(vi), ()))) - @test_throws "has no field LogPrior" getlogprior(vi) - @test_throws "has no field LogLikelihood" getloglikelihood(vi) - @test_throws "has no field LogPrior" getlogp(vi) - @test_throws "has no field LogPrior" getlogjoint(vi) - @test_throws "has no field NumProduce" get_num_produce(vi) - @test_throws "has no field NumProduce" reset_num_produce!!(vi) + @test_throws "has no field LogPriorAccumulator" getlogprior(vi) + @test_throws "has no field LogLikelihoodAccumulator" getloglikelihood(vi) + @test_throws "has no field LogPriorAccumulator" getlogp(vi) + @test_throws "has no field LogPriorAccumulator" getlogjoint(vi) + @test_throws "has no field NumProduceAccumulator" get_num_produce(vi) + @test_throws "has no field NumProduceAccumulator" reset_num_produce!!(vi) end @testset "flags" begin From 00ef0cfe43cd706828e3196500cbf4cbc9326727 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 28 Apr 2025 16:11:32 +0100 Subject: [PATCH 47/48] Type bound log prob accumulators with T<:Real --- src/accumulators.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/accumulators.jl b/src/accumulators.jl index f8f8a6f31..e241abf1c 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -193,14 +193,14 @@ end # END ACCUMULATOR TUPLE, BEGIN LOG PROB AND NUM PRODUCE ACCUMULATORS """ - LogPriorAccumulator{T} <: AbstractAccumulator + LogPriorAccumulator{T<:Real} <: AbstractAccumulator An accumulator that tracks the cumulative log prior during model execution. # Fields $(TYPEDFIELDS) """ -struct LogPriorAccumulator{T} <: AbstractAccumulator +struct LogPriorAccumulator{T<:Real} <: AbstractAccumulator "the scalar log prior value" logp::T end @@ -210,18 +210,18 @@ end Create a new `LogPriorAccumulator` accumulator with the log prior initialized to zero. """ -LogPriorAccumulator{T}() where {T} = LogPriorAccumulator(zero(T)) +LogPriorAccumulator{T}() where {T<:Real} = LogPriorAccumulator(zero(T)) LogPriorAccumulator() = LogPriorAccumulator{LogProbType}() """ - LogLikelihoodAccumulator{T} <: AbstractAccumulator + LogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator An accumulator that tracks the cumulative log likelihood during model execution. # Fields $(TYPEDFIELDS) """ -struct LogLikelihoodAccumulator{T} <: AbstractAccumulator +struct LogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator "the scalar log likelihood value" logp::T end @@ -231,7 +231,7 @@ end Create a new `LogLikelihoodAccumulator` accumulator with the log likelihood initialized to zero. """ -LogLikelihoodAccumulator{T}() where {T} = LogLikelihoodAccumulator(zero(T)) +LogLikelihoodAccumulator{T}() where {T<:Real} = LogLikelihoodAccumulator(zero(T)) LogLikelihoodAccumulator() = LogLikelihoodAccumulator{LogProbType}() """ @@ -252,7 +252,7 @@ end Create a new `NumProduceAccumulator` accumulator with the number of observations initialized to zero. """ -NumProduceAccumulator{T}() where {T} = NumProduceAccumulator(zero(T)) +NumProduceAccumulator{T}() where {T<:Integer} = NumProduceAccumulator(zero(T)) NumProduceAccumulator() = NumProduceAccumulator{Int}() function Base.show(io::IO, acc::LogPriorAccumulator) From 14f478811217f59d2c4e7ca0bc71aa989ca6a1fe Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 28 Apr 2025 16:47:25 +0100 Subject: [PATCH 48/48] Add @addlogprior! and @addloglikelihood! 
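A minimal usage sketch of the two macros this patch introduces, based on the docstrings added in the diff below; the model `demo_penalised` and the helper `penalty` are illustrative names only, not part of the patch.

```julia
using DynamicPPL, Distributions

# Illustrative extra prior term (not part of this patch).
penalty(μ) = μ > 0 ? -1.0 : 0.0

@model function demo_penalised(x)
    μ ~ Normal()
    @addlogprior! penalty(μ)                           # added to the log prior accumulator
    @addloglikelihood! loglikelihood(Normal(μ, 1), x)  # added to the log likelihood accumulator
end

x = [1.3, -2.1]
logprior(demo_penalised(x), (μ=0.2,))       # ≈ logpdf(Normal(), 0.2) + penalty(0.2)
loglikelihood(demo_penalised(x), (μ=0.2,))  # ≈ loglikelihood(Normal(0.2, 1), x)
```

As the macro definitions below show, each macro checks for the corresponding accumulator with `hasacc` and is a no-op when that accumulator is absent from the VarInfo being evaluated.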
--- docs/src/api.md | 4 +- src/DynamicPPL.jl | 2 + src/abstract_varinfo.jl | 17 ++++-- src/utils.jl | 121 ++++++++++++++++++++++++++++++++++++++-- test/utils.jl | 50 ++++++++++++++++- 5 files changed, 181 insertions(+), 13 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 83342245a..e104193f2 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -160,10 +160,12 @@ returned(::Model) ## Utilities -It is possible to manually increase (or decrease) the accumulated log likelihood from within a model function. +It is possible to manually increase (or decrease) the accumulated log likelihood or prior from within a model function. ```@docs @addlogprob! +@addloglikelihood! +@addlogprior! ``` Return values of the model function for a collection of samples can be obtained with [`returned(model, chain)`](@ref). diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 9cbdfe229..c8bbda020 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -127,6 +127,8 @@ export AbstractVarInfo, to_submodel, # Convenience macros @addlogprob!, + @addlogprior!, + @addloglikelihood!, @submodel, value_iterator_from_chain, check_model, diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 44097dd2f..2f5da2c31 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -320,13 +320,18 @@ function accloglikelihood!!(vi::AbstractVarInfo, logp) end """ - acclogp!!(vi::AbstractVarInfo, logp::NamedTuple) + acclogp!!(vi::AbstractVarInfo, logp::NamedTuple; ignore_missing_accumulator::Bool=false) Add to both the log prior and the log likelihood probabilities in `vi`. `logp` should have fields `logprior` and/or `loglikelihood`, and no other fields. + +By default if the necessary accumulators are not in `vi` an error is thrown. If +`ignore_missing_accumulator` is set to `true` then this is silently ignored instead. """ -function acclogp!!(vi::AbstractVarInfo, logp::NamedTuple{names}) where {names} +function acclogp!!( + vi::AbstractVarInfo, logp::NamedTuple{names}; ignore_missing_accumulator=false +) where {names} if !( names == (:logprior, :loglikelihood) || names == (:loglikelihood, :logprior) || @@ -335,17 +340,19 @@ function acclogp!!(vi::AbstractVarInfo, logp::NamedTuple{names}) where {names} ) error("logp must have fields logprior and/or loglikelihood and no other fields.") end - if haskey(logp, :logprior) + if haskey(logp, :logprior) && + (!ignore_missing_accumulator || hasacc(vi, Val(:LogPrior))) vi = acclogprior!!(vi, logp.logprior) end - if haskey(logp, :loglikelihood) + if haskey(logp, :loglikelihood) && + (!ignore_missing_accumulator || hasacc(vi, Val(:LogLikelihood))) vi = accloglikelihood!!(vi, logp.loglikelihood) end return vi end function acclogp!!(vi::AbstractVarInfo, logp::Number) - depwarn( + Base.depwarn( "`acclogp!!(vi::AbstractVarInfo, logp::Number)` is deprecated. Use `accloglikelihood!!(vi, logp)` instead.", :acclogp, ) diff --git a/src/utils.jl b/src/utils.jl index 2ed0c049e..a141148a0 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -18,8 +18,86 @@ const LogProbType = float(Real) """ @addlogprob!(ex) +Add a term to the log joint. + +If `ex` evaluates to a `NamedTuple` with keys `:loglikelihood` and/or `:logprior`, the +values are added to the log likelihood and log prior respectively. + +If `ex` evaluates to a number it is added to the log likelihood. This use is deprecated +and should be replaced with either the `NamedTuple` version or calls to +[`@addloglikelihood!`](@ref). + +See also [`@addloglikelihood!`](@ref), [`@addlogprior!`](@ref). 
+ +# Examples + +```jldoctest; setup = :(using Distributions) +julia> mylogjoint(x, μ) = (; loglikelihood=loglikelihood(Normal(μ, 1), x), logprior=1.0); + +julia> @model function demo(x) + μ ~ Normal() + @addlogprob! mylogjoint(x, μ) + end; + +julia> x = [1.3, -2.1]; + +julia> loglikelihood(demo(x), (μ=0.2,)) ≈ mylogjoint(x, 0.2).loglikelihood +true + +julia> logprior(demo(x), (μ=0.2,)) ≈ logpdf(Normal(), 0.2) + mylogjoint(x, 0.2).logprior +true +``` + +and to [reject samples](https://github.com/TuringLang/Turing.jl/issues/1328): + +```jldoctest; setup = :(using Distributions, LinearAlgebra) +julia> @model function demo(x) + m ~ MvNormal(zero(x), I) + if dot(m, x) < 0 + @addlogprob! (; loglikelihood=-Inf) + # Exit the model evaluation early + return + end + x ~ MvNormal(m, I) + return + end; + +julia> logjoint(demo([-2.1]), (m=[0.2],)) == -Inf +true +``` +""" +macro addlogprob!(ex) + return quote + val = $(esc(ex)) + vi = $(esc(:(__varinfo__))) + if val isa Number + Base.depwarn( + """ + @addlogprob! with a single number argument is deprecated. Please use + @addlogprob! (; loglikelihood=x) or @addloglikelihood! instead. + """, + :addlogprob!, + ) + if hasacc(vi, Val(:LogLikelihood)) + $(esc(:(__varinfo__))) = accloglikelihood!!($(esc(:(__varinfo__))), val) + end + elseif !isa(val, NamedTuple) + error("logp must be a NamedTuple.") + else + $(esc(:(__varinfo__))) = acclogp!!( + $(esc(:(__varinfo__))), val; ignore_missing_accumulator=true + ) + end + end +end + +""" + @addloglikelihood!(ex) + Add the result of the evaluation of `ex` to the log likelihood. +See also [`@addlogprob!`](@ref), [`@addlogprior!`](@ref). + # Examples This macro allows you to [include arbitrary terms in the likelihood](https://github.com/TuringLang/Turing.jl/issues/1332) @@ -29,7 +107,7 @@ julia> myloglikelihood(x, μ) = loglikelihood(Normal(μ, 1), x); julia> @model function demo(x) μ ~ Normal() - @addlogprob! myloglikelihood(x, μ) + @addloglikelihood! myloglikelihood(x, μ) end; julia> x = [1.3, -2.1]; @@ -44,7 +122,7 @@ and to [reject samples](https://github.com/TuringLang/Turing.jl/issues/1328): julia> @model function demo(x) m ~ MvNormal(zero(x), I) if dot(m, x) < 0 - @addlogprob! -Inf + @addloglikelihood! -Inf # Exit the model evaluation early return end @@ -56,10 +134,43 @@ julia> logjoint(demo([-2.1]), (m=[0.2],)) == -Inf true ``` """ -macro addlogprob!(ex) +macro addloglikelihood!(ex) + return quote + if hasacc($(esc(:(__varinfo__))), Val(:LogLikelihood)) + $(esc(:(__varinfo__))) = accloglikelihood!!($(esc(:(__varinfo__))), $(esc(ex))) + end + end +end + +""" + @addlogprior!(ex) + +Add the result of the evaluation of `ex` to the log prior. + +See also [`@addloglikelihood!`](@ref), [`@addlogprob!`](@ref). + +# Examples + +This macro allows you to include arbitrary terms in the prior. + +```jldoctest; setup = :(using Distributions) +julia> mylogpriorextraterm(μ) = μ > 0 ? -1.0 : 0.0; + +julia> @model function demo(x) + μ ~ Normal() + @addlogprior! 
mylogpriorextraterm(μ) + end; + +julia> x = [1.3, -2.1]; + +julia> logprior(demo(x), (μ=0.2,)) ≈ logpdf(Normal(), 0.2) + mylogpriorextraterm(0.2) +true +``` +""" +macro addlogprior!(ex) return quote - if $hasacc($(esc(:(__varinfo__))), Val(:LogLikelihood)) - $(esc(:(__varinfo__))) = $accloglikelihood!!($(esc(:(__varinfo__))), $(esc(ex))) + if hasacc($(esc(:(__varinfo__))), Val(:LogPrior)) + $(esc(:(__varinfo__))) = acclogprior!!($(esc(:(__varinfo__))), $(esc(ex))) end end end diff --git a/test/utils.jl b/test/utils.jl index 4aa2d9943..b85d21c41 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -6,10 +6,56 @@ return global lp_after = getlogjoint(__varinfo__) end - model = testmodel() - varinfo = VarInfo(model) + varinfo = VarInfo(testmodel()) @test iszero(lp_before) @test getlogjoint(varinfo) == lp_after == 42 + @test getloglikelihood(varinfo) == 42 + + @model function testmodel_nt() + global lp_before = getlogjoint(__varinfo__) + @addlogprob! (; logprior=(pi + 1), loglikelihood=42) + return global lp_after = getlogjoint(__varinfo__) + end + + varinfo = VarInfo(testmodel_nt()) + @test iszero(lp_before) + @test getlogjoint(varinfo) == lp_after == 42 + 1 + pi + @test getloglikelihood(varinfo) == 42 + @test getlogprior(varinfo) == pi + 1 + + @model function testmodel_nt2() + global lp_before = getlogjoint(__varinfo__) + llh_nt = (; loglikelihood=42) + @addlogprob! llh_nt + return global lp_after = getlogjoint(__varinfo__) + end + + varinfo = VarInfo(testmodel_nt2()) + @test iszero(lp_before) + @test getlogjoint(varinfo) == lp_after == 42 + @test getloglikelihood(varinfo) == 42 + + @model function testmodel_likelihood() + global lp_before = getlogjoint(__varinfo__) + @addloglikelihood! 42 + return global lp_after = getlogjoint(__varinfo__) + end + + varinfo = VarInfo(testmodel_likelihood()) + @test iszero(lp_before) + @test getlogjoint(varinfo) == lp_after == 42 + @test getloglikelihood(varinfo) == 42 + + @model function testmodel_prior() + global lp_before = getlogjoint(__varinfo__) + @addlogprior! 42 + return global lp_after = getlogjoint(__varinfo__) + end + + varinfo = VarInfo(testmodel_prior()) + @test iszero(lp_before) + @test getlogjoint(varinfo) == lp_after == 42 + @test getlogprior(varinfo) == 42 end @testset "getargs_dottilde" begin