Accumulators, stage 1 #885

Draft
wants to merge 42 commits into
base: breaking

mhauru
Member

@mhauru mhauru commented Apr 10, 2025

This is starting to take shape. It's too early for a review: Everything is undocumented, uncleaned, and some things are still broken. The base design is there though, and most tests pass (pointwiseloglikelihood and doctests being the exceptions), so @penelopeysm, @torfjelde, if you want to have an early look at where this is going, feel free. The most interesting files are accumulators.jl, abstract_varinfo.jl, and context_implementations.jl.

In addition to obvious things that still need doing (documentation, clean-up, new tests, adding deprecations, fixing pointwiseloglikelihood), a few things I have on my mind:

  • Need to decide whether to keep the LogLikelihood and LogPrior accumulators immutable like they are now.
  • Whether getacc and similar functions should take the type of the accumulator as the index, or rather the symbol returned by accumulator_name. Leaning towards latter, but the former is what's currently implemented.
  • Maybe rename DefaultContext to AccumulationContext. Or something else? I'm not fixated on the term "accumulator".
  • Since the signature of (tilde_)assume and (tilde_)observe has changed (they no longer return logp), the whole stack of calls within tilde_obssume!! should be revisited. In particular, I'm thinking of splitting anything sampling-related to a call of tilde_obssume with SamplingContext, that then at the end calls tilde_obssume with DefaultContext. This might be a separate PR though.
  • Benchmark
  • There are a few places where we are now unnecessarily accumulating all of log prior, log likelihood, and num produce. I should clean these up to benefit from being able to do one but not the others.
  • Make metadata.order be an accumulator as well. Probably needs to actually be in the same accumulator with NumProduce, since the two go together. Probably a separate PR though.
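
To make the design questions in the list above concrete, here is a minimal, self-contained sketch of immutable accumulators that can be looked up either by type or by the symbol returned by accumulator_name. The names LogPrior, LogLikelihood, getacc and accumulator_name come from the description above; everything else (make_accs, map_acc, acclogp, the NamedTuple container) is a hypothetical stand-in and not the code in accumulators.jl.

```julia
# Toy sketch only: NOT the implementation in this PR.
abstract type AbstractAccumulator end

# Immutable accumulators: updating one returns a new value.
struct LogPrior{T<:Real} <: AbstractAccumulator
    logp::T
end

struct LogLikelihood{T<:Real} <: AbstractAccumulator
    logp::T
end

# Symbol under which each accumulator type is stored.
accumulator_name(::Type{<:LogPrior}) = :LogPrior
accumulator_name(::Type{<:LogLikelihood}) = :LogLikelihood
accumulator_name(acc::AbstractAccumulator) = accumulator_name(typeof(acc))

# Hypothetical helper: add a term and return a fresh accumulator.
acclogp(acc::LogPrior, val) = LogPrior(acc.logp + val)
acclogp(acc::LogLikelihood, val) = LogLikelihood(acc.logp + val)

# Hypothetical container: a NamedTuple keyed by accumulator_name.
make_accs(accs::AbstractAccumulator...) = NamedTuple{map(accumulator_name, accs)}(accs)

# Both indexing styles under discussion: by symbol and by type.
getacc(accs::NamedTuple, name::Symbol) = accs[name]
getacc(accs::NamedTuple, ::Type{T}) where {T<:AbstractAccumulator} = accs[accumulator_name(T)]

# Hypothetical helper: replace one accumulator with f(old), leaving the others alone.
map_acc(f, accs::NamedTuple, name::Symbol) = merge(accs, NamedTuple{(name,)}((f(accs[name]),)))

accs = make_accs(LogPrior(0.0), LogLikelihood(0.0))
accs = map_acc(acc -> acclogp(acc, -1.5), accs, :LogPrior)
accs = map_acc(acc -> acclogp(acc, -0.25), accs, :LogLikelihood)
logjoint = getacc(accs, :LogPrior).logp + getacc(accs, LogLikelihood).logp
```

In this sketch the symbol-based lookup is the primitive and the type-based lookup forwards to it, which is one way of supporting both conventions while the choice is still open.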

penelopeysm and others added 9 commits March 5, 2025 10:34
* AbstractPPL 0.11; change prefixing behaviour

* Use DynamicPPL.prefix rather than overloading
* Unify {Untyped,Typed}{Vector,}VarInfo constructors

* Update invocations

* NTVarInfo

* Fix tests

* More fixes

* Fixes

* Fixes

* Fixes

* Use lowercase functions, don't deprecate VarInfo

* Rewrite VarInfo docstring

* Fix methods

* Fix methods (really)
Contributor

github-actions bot commented Apr 10, 2025

Benchmark Report for Commit 557954a

Computer Information

Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

|                 Model | Dimension |  AD Backend |      VarInfo Type | Linked | Eval Time / Ref Time | AD Time / Eval Time |
|-----------------------|-----------|-------------|-------------------|--------|----------------------|---------------------|
| Simple assume observe |         1 | forwarddiff |             typed |  false |                 67.6 |                 1.3 |
|           Smorgasbord |       201 | forwarddiff |             typed |  false |                918.1 |                30.7 |
|           Smorgasbord |       201 | forwarddiff | simple_namedtuple |   true |                486.2 |                44.4 |
|           Smorgasbord |       201 | forwarddiff |           untyped |   true |               1318.9 |                28.1 |
|           Smorgasbord |       201 | forwarddiff |       simple_dict |   true |               7306.8 |                21.3 |
|           Smorgasbord |       201 | reversediff |             typed |   true |               1619.2 |                25.9 |
|           Smorgasbord |       201 |    mooncake |             typed |   true |               1078.1 |                 6.0 |
|    Loop univariate 1k |      1000 |    mooncake |             typed |   true |               5996.7 |                 4.1 |
|       Multivariate 1k |      1000 |    mooncake |             typed |   true |               1087.6 |                 9.5 |
|   Loop univariate 10k |     10000 |    mooncake |             typed |   true |              66036.6 |                 3.5 |
|      Multivariate 10k |     10000 |    mooncake |             typed |   true |               8869.5 |                10.0 |
|               Dynamic |        10 |    mooncake |             typed |   true |                216.1 |                17.2 |
|              Submodel |         1 |    mooncake |             typed |   true |                 72.6 |                12.8 |
|                   LDA |        12 | reversediff |             typed |   true |               1263.9 |                 1.8 |

Member

@penelopeysm penelopeysm left a comment

Not reviewing actual code, just one high-level thought that struck me.

Comment on lines 121 to 125
function setlogp!!(vi::AbstractVarInfo, logp)
    vi = setlogprior!!(vi, zero(logp))
    vi = setloglikelihood!!(vi, logp)
    return vi
end
Member

@penelopeysm penelopeysm Apr 10, 2025

I was thinking about this the other day and thought I may as well post now. The ...logp() family of functions are no longer well-defined in a world where everything is cleanly split into prior and likelihood. (only getlogp and resetlogp still make sense) I think last time we chatted about it the decision was to maybe forward the others to the likelihood methods, but I was wondering if it's actually safer to remove them (or make them error informatively) and force people to use likelihood or prior as otherwise it risks introducing subtle bugs. Backward compatibility is important but if it comes at the cost of correctness I feel kinda uneasy.

Member Author

My hope was that we could deprecate them but provide the same functionality through the new functions, like above. It's a good question as to whether there are edge cases where they do not provide the same functionality. I think this is helped by the fact that PriorContext and LikelihoodContext won't exist, and hence one can't be running code where the expectation would be that ...logp() would be referring to logprior or loglikelihood in particular. And I think as long as one expects to get the logjoint out of ...logp() we can do things like above, shoving things into likelihood, and get the same results. Do you think that solves it and lets us use deprecations rather than straight-up removals, or do you see other edge cases?

Member

@penelopeysm penelopeysm Apr 11, 2025

Something like this is a case where setlogp is ill-defined:

lp = getlogp(vi_typed_metadata)
varinfos = map((
    vi_untyped_metadata,
    vi_untyped_vnv,
    vi_typed_metadata,
    vi_typed_vnv,
    svi_typed,
    svi_untyped,
    svi_vnv,
    svi_typed_ref,
    svi_untyped_ref,
    svi_vnv_ref,
)) do vi
    # Set them all to the same values.
    DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp)
end

The logp here contains terms from both prior and likelihood, but after calling setlogp the prior would always be 0, which is inconsistent with the varinfo.

Of course, we can fix this on our end - we would get and set logprior and loglikelihood manually, and we can grep the codebase to make sure that there are no other ill-defined calls to setlogp. We can't guarantee that other people will be similarly careful, though (and us or anyone being careful also doesn't guarantee that everything will be fixed correctly).
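
The inconsistency described above can be reproduced with a toy model of the state. The sketch below assumes a hypothetical ToyVarInfo that stores only the two log-density terms; its setlogp mirrors the setlogp!! definition under review (prior zeroed, joint stored as likelihood) but is not DynamicPPL code.

```julia
# Hypothetical stand-in for a varinfo that tracks prior and likelihood separately.
struct ToyVarInfo
    logprior::Float64
    loglikelihood::Float64
end

getlogprior(vi::ToyVarInfo) = vi.logprior
getloglikelihood(vi::ToyVarInfo) = vi.loglikelihood
getlogp(vi::ToyVarInfo) = vi.logprior + vi.loglikelihood

# Scalar setter in the style of the diff above: the whole value is stored
# as log likelihood and the log prior is reset to zero.
setlogp(vi::ToyVarInfo, logp) = ToyVarInfo(zero(logp), logp)

vi = ToyVarInfo(-2.0, -3.0)
vi_after = setlogp(vi, getlogp(vi))

joint_preserved = getlogp(vi_after) == getlogp(vi)          # true
prior_preserved = getlogprior(vi_after) == getlogprior(vi)  # false, the prior is now 0.0
```

Code that only ever reads the joint value sees no difference, while code that reads the prior or the likelihood on its own after a scalar set gets a different answer than before.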

Member

@penelopeysm penelopeysm Apr 11, 2025

While looking for other uses of setlogp, I encountered this:

https://github.com/TuringLang/Turing.jl/blob/fc32e10bc17ae3fda4d7e825b6fde45dc7bdb179/src/mcmc/hmc.jl#L201-L234

AdvancedHMC.Transition only contains a single notion of log density, so it's not obvious to me how we're going to extract the prior and likelihood components from it 😓 This might require upstream changes to AdvancedHMC. Since the contexts will be removed, I suspect LogDensityFunction also needs to be changed so that there's a way for it to return only the prior or the likelihood (or maybe it should return both).

(For the record, I'd be quite happy with making all of these changes!)

Member Author

> The logp here contains terms from both prior and likelihood, but after calling setlogp the prior would always be 0, which is inconsistent with the varinfo.

It is inconsistent, but as long as the user only uses getlogp, they would never see the difference, right? If some of logprior is accidentally stored in loglikelihood or vice versa, as long as one is using getlogp and DefaultContext that should be undetectable. What would be trouble is if someone mixes using e.g. setlogp!! and getlogprior, which would require adding calls to getlogprior after upgrading to a version that has deprecated setlogp!!, but probably people would end up doing that. Maybe the deprecation warning could say something about this?

> Since the contexts will be removed, I suspect LogDensityFunction also needs to be changed so that there's a way for it to return only the prior or the likelihood (or maybe it should return both).

Yeah, this sort of stuff will come up (and is coming up) in multiple places. Anything that explicitly uses PriorContext or LikelihoodContext would need to be changed to use LogPrior and LogLikelihood accumulators instead. I'm currently doing this for pointwiselogdensities.

Member Author

mhauru commented Apr 11, 2025

pointwise_logdensities now works and uses its own accumulator type rather than a context. This leaves only a handful of tests failing, for quite trivial reasons. Plenty of clean-up to do though: In fixing pointwise_logprobability I had to make substantial changes to the tilde_observe pipeline, because accumulate_observe needed to get the varname as an argument and thus had to be moved higher in the call stack. I'll have to see how to best reorganise tilde_observe in such a way that making ParticleGibbs work with it wouldn't be horrible.
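
As a rough illustration of why the accumulator needs the varname, here is a minimal sketch of a pointwise accumulator that stores one log-density term per observed variable. ToyPointwiseAcc, toy_accumulate_observe! and normal_logpdf are hypothetical names chosen for this example, variable names are plain Symbols rather than VarNames, and the real accumulate_observe signature in the PR may differ.

```julia
# Hypothetical accumulator recording one log-density term per observed variable.
struct ToyPointwiseAcc
    logps::Dict{Symbol,Float64}
end
ToyPointwiseAcc() = ToyPointwiseAcc(Dict{Symbol,Float64}())

# Log density of Normal(mu, sigma) at x, written out by hand to avoid dependencies.
normal_logpdf(x, mu, sigma) = -0.5 * log(2π) - log(sigma) - 0.5 * ((x - mu) / sigma)^2

# The variable name must be passed in, otherwise the term cannot be keyed.
function toy_accumulate_observe!(acc::ToyPointwiseAcc, vn::Symbol, logp::Real)
    acc.logps[vn] = logp
    return acc
end

acc = ToyPointwiseAcc()
for (vn, y) in ((:y1, 0.0), (:y2, 1.0))
    toy_accumulate_observe!(acc, vn, normal_logpdf(y, 0.0, 1.0))
end

# The ordinary log likelihood is recovered by summing the pointwise terms.
total_loglikelihood = sum(values(acc.logps))
```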

Comment on lines +129 to +133
y = getindex_internal(vi, vn)
f = from_maybe_linked_internal_transform(vi, vn, dist)
x, logjac = with_logabsdet_jacobian(f, y)
vi = accumulate_assume!!(vi, x, logjac, vn, dist)
return x, vi
Member

How do we deal with tempering of logpdf and such now that it happens in the leaf of the call stack?

Member

In the past, we would do this by altering the logpdf coming from the assume higher up in the call tree.

Member Author

What are the needs of tempering? What does it need to alter?
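
For readers following the assume snippet in this thread, the sketch below works through the change of variables it relies on, for one concrete case. It assumes the prior term added in linked space is logpdf(dist, x) + logjac, which is the standard change-of-variables formula but is not spelled out in the diff. exp_with_logjac is a hand-written stand-in for with_logabsdet_jacobian(exp, y), and exp1_logpdf is a hand-written Exponential(1) log density.

```julia
# Toy log density of an Exponential(1) distribution, valid for x > 0.
exp1_logpdf(x) = -x

# Stand-in for with_logabsdet_jacobian(exp, y):
# returns x = exp(y) together with log|dx/dy| = y.
exp_with_logjac(y) = (exp(y), y)

# Log density with respect to the unconstrained variable y.
function linked_logdensity(y)
    x, logjac = exp_with_logjac(y)
    return exp1_logpdf(x) + logjac
end

# Sanity check: the density over y should integrate to approximately 1.
ys = -20.0:0.001:5.0
total_mass = sum(exp(linked_logdensity(y)) for y in ys) * step(ys)
```

Without the logjac term the function of y would not be a normalised density, which is why the Jacobian has to reach whichever accumulator handles the prior.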

mhauru and others added 9 commits April 15, 2025 12:23
docs/src/api.md Outdated
@@ -163,7 +163,7 @@ returned(::Model)
It is possible to manually increase (or decrease) the accumulated log density from within a model function.

```@docs
@addlogprob!
@addloglikelihood!
Member

Can we add support for more flexible accumulation via

@addlogprob! 0.
@addlogprob! (logprior=0., loglike=0.)

This would avoid breaking the existing interface which is likely one of the most used features of DynamicPPL.

Member Author

Are there situations where user code might break if @addlogprob! only adds to the log likelihood? For instance, I think currently if you have a model that uses @addlogprob! and execute it with PriorContext, the extra log prob term will still be included in the logp of the prior. After this PR it wouldn't. If you did the equivalent of PriorContext but using accumulators, the @addlogprob! would not apply. I worry that this might break things for some users in a surprising way, and hence proposed renaming @addlogprob! to @addloglikelihood! and maybe adding a corresponding @addlogprior!, to force the user to be explicit about which one they want.

Note that currently in this PR @addlogprob! produces a deprecation warning, but works as a synonym for @addloglikelihood!. I was thinking of keeping that deprecation in place for a while, hoping people would migrate to using @addloglikelihood! (or maybe in some cases @addlogprior!) due to the warnings, and then eventually we could drop @addlogprob!.

Member

@yebai yebai Apr 22, 2025

> Are there situations where user code might break if @addlogprob! only adds to the log likelihood?

No, it would be a bug if any.

I am leaning towards keeping the @addlogprob! interface as-is, but allowing @addlogprob! (logprior=0., loglikelihood=0.) so that expert users can accumulate log joint probabilities more flexibly. This should only be relevant if they are using customised inference algorithms (standard algorithms in AdvancedMH, AdvancedHMC, AdvancedVI don't need this behaviour).

Member Author

Might anything weird happen with ESS? I don't know how it works, but it seems to use PriorContext in some particular way.

I guess the same applies then to setlogp!! and acclogp!!, namely that any user-code that uses them and expects to set/accumulate anything other than the log likelihood (even when e.g. using PriorContext) can be considered a bug?

Member

> Might anything weird happen with ESS?

ESS only deals with Gaussian priors, so users should

  1. use ~ to specify the priors so ESS can check whether RHS is Gaussian
  2. avoid using @addlogprob! to interact with the prior

But, using @addlogprob! to accumulate log likelihood should be okay for ESS, since ESS doesn't make any assumption on the distributional form of likelihood.

> I guess the same applies then to setlogp!! and acclogp!!, namely that any user-code that uses them and expects to set/accumulate anything other than the log likelihood (even when e.g. using PriorContext) can be considered a bug?

I suspect so, since these functions were mainly designed for MH and HMC. In these use cases, accumulating log joint and log likelihood are equivalent as long as the model is differentiable.

That said, for prudence, we should provide an explanation on this behavioural change in HISTORY and mark this PR as a breaking change.

Member Author

Ok, in that case I'm happy to keep @addlogprob! in place; The change in its behaviour in this PR can actually be considered a fix to incorrect earlier behaviour. Also happy to add the NamedTuple version of it. Would you be okay with also adding @addloglikelihood!, and putting in an info print or a deprecation warning in @addlogprob! to nudge people towards using either @addlogprob! (likelihood=val,) or @addloglikelihood! val in the long term? I think long term (v1.0 and beyond) the more explicit syntax would be good.
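
The two calling conventions discussed in this thread can be sketched with ordinary dispatch on the argument type. The example below uses a plain function instead of a macro, and LogpPair and addlogprob are hypothetical names. The key names logprior and loglikelihood are an assumption, since the thread uses loglike, loglikelihood and likelihood in different places.

```julia
# Hypothetical container for the two running totals.
struct LogpPair
    logprior::Float64
    loglikelihood::Float64
end

# A bare number is counted as log likelihood, as agreed in the thread above.
addlogprob(lp::LogpPair, val::Real) = LogpPair(lp.logprior, lp.loglikelihood + val)

# A NamedTuple names its targets explicitly; fields that are left out add zero.
function addlogprob(lp::LogpPair, vals::NamedTuple)
    extra = setdiff(keys(vals), (:logprior, :loglikelihood))
    isempty(extra) || throw(ArgumentError("unexpected keys: $(extra)"))
    return LogpPair(
        lp.logprior + get(vals, :logprior, 0.0),
        lp.loglikelihood + get(vals, :loglikelihood, 0.0),
    )
end

lp = LogpPair(0.0, 0.0)
lp = addlogprob(lp, -1.0)                                       # likelihood only
lp = addlogprob(lp, (logprior=-0.5,))                           # prior only
lp = addlogprob(lp, (logprior=-0.5, loglikelihood=-2.0))        # both
```

Rejecting unknown keys is a choice made for this sketch so that a misspelled field does not silently add nothing.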

getlogp
setlogp!!
acclogp!!
getlogprior
Member

Similar to above:

setlogp!!(x) # equivalent to `setlogp!!(logprior=0., loglikelihood=x)`
setlogp!!((logprior=0., loglikelihood=0.))

acclogp!!(x) # equivalent to `acclogp!!(logprior=0., loglikelihood=x)`
acclogp!!((logprior=0., loglikelihood=0.))

logprior, loglikelihood = getlogp(...) # breaking but okay for clarity
logjoint = sum(getlogp(...)) # convenience function

and avoid changing these API names in a breaking way.

Member Author

I think the NamedTuple versions are good, and they maintain the nice feature that setlogp!!(vi, getlogp(vi)) is a no-op. I would still like to deprecate setlogp!!(vi, x::Float), and preferably acclogp!!(vi, x::Float) as well. setlogp!!(vi, x::Float64) seems a bit dangerous to me since it has to set the prior to zero, which is quite different from what it used to do, and not an operation that I think makes much sense. These are also far less user-facing than @addlogprob!, I suspect the only people calling setlogp!!/acclogp!! are inference algorithm developers.
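
The round-trip property mentioned above is easy to check on a toy state. In this sketch ToyLogpState, getlogp_nt, setlogp_nt and logjoint_of are hypothetical names, and the NamedTuple layout follows the proposal in the preceding comment rather than any released API.

```julia
# Hypothetical state holding the two terms separately.
struct ToyLogpState
    logprior::Float64
    loglikelihood::Float64
end

# Getter returns both components as a NamedTuple.
getlogp_nt(s::ToyLogpState) = (logprior=s.logprior, loglikelihood=s.loglikelihood)

# Setter accepts exactly the NamedTuple shape that the getter produces.
function setlogp_nt(s::ToyLogpState, lp::NamedTuple{(:logprior, :loglikelihood)})
    return ToyLogpState(lp.logprior, lp.loglikelihood)
end

# Convenience: the log joint is the sum of the components.
logjoint_of(lp::NamedTuple) = sum(values(lp))

s = ToyLogpState(-2.0, -3.0)
s_roundtrip = setlogp_nt(s, getlogp_nt(s))
roundtrip_is_noop = s_roundtrip == s
joint = logjoint_of(getlogp_nt(s))
```

Because the setter receives both components, nothing has to be zeroed, which is the difference from the scalar form being considered for deprecation.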

mhauru and others added 9 commits April 22, 2025 16:37
* Fix conditioning in submodels

* Simplify contextual_isassumption

* Add documentation

* Fix some tests

* Add tests; fix a bunch of nested submodel issues

* Fix fix as well

* Fix doctests

* Add unit tests for new functions

* Add changelog entry

* Update changelog

Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>

* Finish docs

* Add a test for conditioning submodel via arguments

* Clean new tests up a bit

* Fix for VarNames with non-identity lenses

* Apply suggestions from code review

Co-authored-by: Markus Hauru <markus@mhauru.org>

* Apply suggestions from code review

* Make PrefixContext contain a varname rather than symbol (#896)

---------

Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>
Co-authored-by: Markus Hauru <markus@mhauru.org>

codecov bot commented Apr 24, 2025

Codecov Report

Attention: Patch coverage is 89.13525% with 49 lines in your changes missing coverage. Please review.

Project coverage is 84.45%. Comparing base (8135113) to head (557954a).

| Files with missing lines  | Patch % | Lines         |
|---------------------------|---------|---------------|
| src/abstract_varinfo.jl   |  71.69% | 15 Missing ⚠️ |
| src/varinfo.jl            |  86.48% | 15 Missing ⚠️ |
| src/threadsafe.jl         |  85.00% | 6 Missing ⚠️  |
| src/debug_utils.jl        |  77.77% | 4 Missing ⚠️  |
| src/simple_varinfo.jl     |  85.71% | 4 Missing ⚠️  |
| src/logdensityfunction.jl |  70.00% | 3 Missing ⚠️  |
| src/values_as_in_model.jl |  50.00% | 2 Missing ⚠️  |

Additional details and impacted files
@@             Coverage Diff              @@
##           breaking     #885      +/-   ##
============================================
- Coverage     85.05%   84.45%   -0.61%     
============================================
  Files            35       36       +1     
  Lines          3922     4002      +80     
============================================
+ Hits           3336     3380      +44     
- Misses          586      622      +36     


Base automatically changed from breaking to main April 24, 2025 16:32
@mhauru mhauru force-pushed the mhauru/custom-accumulators branch from c421740 to 4fef20f Compare April 24, 2025 16:49
@mhauru mhauru changed the base branch from main to breaking April 24, 2025 16:50
@coveralls

Pull Request Test Coverage Report for Build 14646630936

Details

  • 306 of 451 (67.85%) changed or added relevant lines in 18 files are covered.
  • No unchanged relevant lines lost coverage.
  • Overall first build on mhauru/custom-accumulators at 55.812%

| Changes Missing Coverage       | Covered Lines | Changed/Added Lines |      % |
|--------------------------------|---------------|---------------------|--------|
| ext/DynamicPPLMCMCChainsExt.jl |             0 |                   1 |   0.0% |
| src/test_utils/contexts.jl     |             0 |                   1 |   0.0% |
| src/test_utils/models.jl       |            14 |                  15 | 93.33% |
| src/transforming.jl            |            18 |                  19 | 94.74% |
| src/values_as_in_model.jl      |             2 |                   4 |  50.0% |
| src/context_implementations.jl |            24 |                  27 | 88.89% |
| src/logdensityfunction.jl      |             7 |                  10 |  70.0% |
| src/pointwise_logdensities.jl  |            39 |                  44 | 88.64% |
| src/debug_utils.jl             |            11 |                  18 | 61.11% |
| src/accumulators.jl            |            64 |                  72 | 88.89% |

Totals Coverage Status
Change from base Build 14646769416: 55.8%
Covered Lines: 2228
Relevant Lines: 3992



5 participants