Accumulators, stage 1 #885

Status: Open. Wants to merge 57 commits into base branch `breaking`.

57 commits
061acbe
Release 0.36
penelopeysm Mar 5, 2025
1496868
Merge branch 'main' into breaking
penelopeysm Mar 20, 2025
324e623
Merge branch 'main' into breaking
penelopeysm Mar 22, 2025
bb59885
Merge branch 'main' into breaking
penelopeysm Mar 25, 2025
fc32398
AbstractPPL 0.11 + change prefixing behaviour (#830)
penelopeysm Mar 28, 2025
cc5e581
Remove VarInfo(VarInfo, params) (#870)
penelopeysm Mar 28, 2025
b9c368b
Unify `{untyped,typed}_{vector_,}varinfo` constructor functions (#879)
penelopeysm Apr 9, 2025
1b8f555
Merge remote-tracking branch 'origin/main' into breaking
mhauru Apr 9, 2025
ae9b1cd
Draft of accumulators
mhauru Feb 24, 2025
4fb0bf4
Fix some variable names
mhauru Apr 10, 2025
e410f47
Merge remote-tracking branch 'origin/main' into breaking
penelopeysm Apr 11, 2025
97788bd
Fix pointwise_logdensities, gut tilde_observe, remove resetlogp!!
mhauru Apr 11, 2025
7fe03ec
Map rather than broadcast
mhauru Apr 15, 2025
5ba3530
Merge remote-tracking branch 'origin/main' into breaking
penelopeysm Apr 15, 2025
d49f7be
Start documenting accumulators
mhauru Apr 15, 2025
28bbf1c
Use Val{symbols} instead of AccTypes to index
mhauru Apr 15, 2025
a0ed665
More documentation for accumulators
mhauru Apr 15, 2025
be27636
Link varinfo by default in AD testing utilities; make test suite run …
penelopeysm Apr 16, 2025
e6453fe
Fix resetlogp!! and type stability for accumulators
mhauru Apr 16, 2025
c59400d
Fix type rigidity of LogProbs and NumProduce
mhauru Apr 16, 2025
47033ce
Fix uses of getlogp and other assorted issues
mhauru Apr 17, 2025
8b841c9
setaccs!! nicer interface and logdensity function fixes
mhauru Apr 17, 2025
3ee3989
Revert back to calling the macro @addlogprob!
mhauru Apr 22, 2025
13163f2
Remove a dead test
mhauru Apr 22, 2025
37dd6dd
Clarify a comment
mhauru Apr 22, 2025
d7013b6
Implement split/combine for PointwiseLogdensityAccumulator
mhauru Apr 22, 2025
40d4caa
Switch ThreadSafeVarInfo.accs_by_thread to be a tuple
mhauru Apr 22, 2025
ff5f2cb
Fix `condition` and `fix` in submodels (#892)
penelopeysm Apr 23, 2025
c68f1bb
Merge remote-tracking branch 'origin/main' into breaking
penelopeysm Apr 23, 2025
13da08a
Revert ThreadSafeVarInfo back to Vectors and fix some AD type casting…
mhauru Apr 24, 2025
d52feec
Merge remote-tracking branch 'origin/breaking' into mhauru/custom-acc…
mhauru Apr 24, 2025
221e797
Improve accumulator docs
mhauru Apr 24, 2025
1dbcb2c
Add test/accumulators.jl
mhauru Apr 24, 2025
e1b70e0
Docs fixes
mhauru Apr 24, 2025
3f195e5
Various small fixes
mhauru Apr 24, 2025
68b974a
Make DynamicTransformation not use accumulators other than LogPrior
mhauru Apr 24, 2025
2b405d9
Fix variable order and name of map_accumulator!!
mhauru Apr 24, 2025
00cd304
Typo fixing
mhauru Apr 24, 2025
6d1048d
Small improvement to ThreadSafeVarInfo
mhauru Apr 24, 2025
4fef20f
Fix demo_dot_assume_observe_submodel prefixing
mhauru Apr 24, 2025
905b874
Merge branch 'breaking' into mhauru/custom-accumulators
mhauru Apr 24, 2025
557954a
Typo fixing
mhauru Apr 24, 2025
6f702c9
Miscellaneous small fixes
mhauru Apr 25, 2025
f748775
HISTORY entry and more miscellanea
mhauru Apr 25, 2025
5f4a532
Add more tests for accumulators
mhauru Apr 25, 2025
31967fd
Improve accumulators docstrings
mhauru Apr 25, 2025
ad2f564
Fix a typo
mhauru Apr 25, 2025
10b4f2f
Expand HISTORY entry
mhauru Apr 25, 2025
d2b670d
Add accumulators to API docs
mhauru Apr 25, 2025
8241d12
Remove unexported functions from API docs
mhauru Apr 25, 2025
7b7a3e2
Add NamedTuple methods for get/set/acclogp
mhauru Apr 25, 2025
0b08237
Fix setlogp!! with single scalar to error
mhauru Apr 25, 2025
2a4b874
Export AbstractAccumulator, fix a docs typo
mhauru Apr 25, 2025
c1e90f7
Apply suggestions from code review
mhauru Apr 28, 2025
cb1c6c6
Rename LogPrior -> LogPriorAccumulator, and Likelihood and NumProduce
mhauru Apr 28, 2025
00ef0cf
Type bound log prob accumulators with T<:Real
mhauru Apr 28, 2025
14f4788
Add @addlogprior! and @addloglikelihood!
mhauru Apr 28, 2025
18 changes: 18 additions & 0 deletions HISTORY.md
@@ -1,5 +1,23 @@
# DynamicPPL Changelog

## 0.37.0

**Breaking changes**

### Accumulators

This release overhauls how VarInfo objects track variables such as the log joint probability. The new approach uses what we call accumulators: objects carried by the `VarInfo` that may change their state at each `tilde_assume!!` and `tilde_observe!!` call, based on the value of the variable in question. They replace both the variables that were previously hard-coded in the `VarInfo` object (`logp` and `num_produce`) and some contexts. This brings with it a number of breaking changes:

- `PriorContext` and `LikelihoodContext` no longer exist. By default, a `VarInfo` tracks both the log prior and the log likelihood separately, and they can be accessed with `getlogprior` and `getloglikelihood`. If you want to execute a model while only accumulating one of the two (to save clock cycles), you can do so by creating a `VarInfo` that only has one accumulator in it, e.g. `varinfo = setaccs!!(varinfo, (LogPriorAccumulator(),))`.
- `MiniBatchContext` does not exist anymore. It can be replaced by creating and using a custom accumulator that replaces the default `LogLikelihoodAccumulator`. We may introduce such an accumulator in DynamicPPL in the future, but for now you'll need to do it yourself.
- `tilde_observe` and `observe` have been removed. `tilde_observe!!` still exists, and any contexts should modify its behaviour. We may further rework the call stack under `tilde_observe!!` in the near future.
- `tilde_assume` no longer returns the log density of the current assumption as its second return value. We may further rework the `tilde_assume!!` call stack as well.

Member:

Any reason why we can't remove tilde_assume like tilde_observe and observe?

Member Author:

It's on my list to do, but I didn't want to put in the same PR since I didn't have to. (I had to do tilde_observe because of complications with PointwiseLogdensityAccumulator). Or, more precisely, what's on my list is to revisit the whole call stack for both tilde_assume!! and tilde_observe!! and see what the best way to do things is.

- For literal observation statements like `0.0 ~ Normal(blahblah)` we used to call `tilde_observe!!` without the `vn` argument. This method no longer exists. Rather we call `tilde_observe!!` with `vn` set to `nothing`.
- `set/reset/increment_num_produce!` have become `set/reset/increment_num_produce!!` (note the second exclamation mark). They are no longer guaranteed to modify the `VarInfo` in place, and one should always use the return value.
- `@addlogprob!` now _always_ adds to the log likelihood. Previously it added to the log probability that the execution context specified, e.g. the log prior when using `PriorContext`.

Member:

Adding to the log likelihood is a sensible default. For advanced use by programmable inference, let's also support

@addlogprob! (logprior=0., loglikelihood=0.)

Member Author:

Yep, still need to add that, but happy to do so. I would still prefer to nudge people to use the above syntax or @addloglikelihood! long term, with a depwarning. I would expect to keep @addlogprob! scalar_value around for quite some time still, but long term (v1.0 and beyond) I think the more explicit syntax would be better. It's clearer, the fact that adding to the log prior is even an option means that it's good to be explicit, and as @penelopeysm pointed out the switching cost is minimal.

Member Author:

I added the NamedTuple version @addlogprob! (logprior=0., loglikelihood=0.) and the depwarn I proposed above.

- `getlogp` now returns a `NamedTuple` with keys `logprior` and `loglikelihood`. If you want the log joint probability, which is what `getlogp` used to return, use `getlogjoint`.
- Correspondingly `setlogp!!` and `acclogp!!` should now be called with a `NamedTuple` with keys `logprior` and `loglikelihood`. The `acclogp!!` method with a single scalar value has been deprecated and falls back on `accloglikelihood!!`, and the single scalar version of `setlogp!!` has been removed. Corresponding setter/accumulator functions exist for the log prior as well.
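
The `getlogp`, `getlogjoint`, and `acclogp!!` semantics in the list above can be illustrated with a toy stand-in. This is only a sketch, not DynamicPPL's implementation: `ToyVarInfo` and every `toy_`-prefixed function are hypothetical names invented for this example. It assumes just what the changelog states, namely that the two log probabilities are stored separately, that the scalar method falls back on the log likelihood, and that callers must use the return value of `!!` functions.

```julia
# Toy stand-in for a VarInfo that tracks two log probabilities separately.
# All names here are hypothetical; none of this is DynamicPPL's actual code.
struct ToyVarInfo
    logprior::Float64
    loglikelihood::Float64
end

ToyVarInfo() = ToyVarInfo(0.0, 0.0)

# Analogue of the new `getlogp`: a NamedTuple with both components.
toy_getlogp(vi::ToyVarInfo) = (logprior=vi.logprior, loglikelihood=vi.loglikelihood)

# Analogue of `getlogjoint`: the sum of the two components.
toy_getlogjoint(vi::ToyVarInfo) = vi.logprior + vi.loglikelihood

# NamedTuple method: each component goes to its own running total.
# The struct is immutable, so a new object is returned rather than mutated.
function toy_acclogp!!(vi::ToyVarInfo, logp::NamedTuple{(:logprior, :loglikelihood)})
    return ToyVarInfo(vi.logprior + logp.logprior, vi.loglikelihood + logp.loglikelihood)
end

# Scalar method: falls back on the log likelihood only.
function toy_acclogp!!(vi::ToyVarInfo, logp::Real)
    return toy_acclogp!!(vi, (logprior=0.0, loglikelihood=logp))
end

vi = ToyVarInfo()
vi = toy_acclogp!!(vi, (logprior=-1.0, loglikelihood=-2.0))
vi = toy_acclogp!!(vi, -0.5)  # the scalar is added to the log likelihood

println(toy_getlogp(vi))      # (logprior = -1.0, loglikelihood = -2.5)
println(toy_getlogjoint(vi))  # -3.5
```

Because `ToyVarInfo` is immutable, dropping the return value of `toy_acclogp!!` would silently lose the update, which is the reason the changelog asks callers to always use the return value of functions ending in `!!`.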

## 0.36.0

**Breaking changes**

2 changes: 2 additions & 0 deletions Project.toml
@@ -21,6 +21,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -68,6 +69,7 @@ MCMCChains = "6"
MacroTools = "0.5.6"
Mooncake = "0.4.95"
OrderedCollections = "1"
Printf = "1.10"
Random = "1.6"
Requires = "1"
Statistics = "1"

40 changes: 29 additions & 11 deletions docs/src/api.md
@@ -160,10 +160,12 @@ returned(::Model)

## Utilities

-It is possible to manually increase (or decrease) the accumulated log density from within a model function.
+It is possible to manually increase (or decrease) the accumulated log likelihood or prior from within a model function.

```@docs
@addlogprob!
@addloglikelihood!

Member:

Can we add support for more flexible accumulation via

@addlogprob! 0.
@addlogprob! (logprior=0., loglike=0.)

This would avoid breaking the existing interface which is likely one of the most used features of DynamicPPL.

Member Author:

Are there situations where user code might break if @addlogprob! only adds to the log likelihood? For instance, I think currently if you have a model that uses @addlogprob! and execute it with PriorContext, the extra log prob term will still be included in the logp of the prior. After this PR it wouldn't. If you did the equivalent of PriorContext but using accumulators, the @addlogprob! would not apply. I worry that this might break things for some users in a surprising way, and hence proposed renaming @addlogprob! to @addloglikelihood! and maybe adding a corresponding @addlogprior!, to force the user to be explicit about which one they want.

Note that currently in this PR @addlogprob! produces a deprecation warning, but works as a synonym for @addloglikelihood!. I was thinking of keeping that deprecation in place for a while, hoping people would migrate to using @addloglikelihood! (or maybe in some cases @addlogprior!) due to the warnings, and then eventually we could drop @addlogprob!.

Member (@yebai, Apr 22, 2025):

> Are there situations where user code might break if @addlogprob! only adds to the log likelihood?

No, it would be a bug if any.

I am leaning towards keeping the @addlogprob! interface as-is, but allowing @addlogprob! (logprior=0., loglikelihood=0.) so that expert users can accumulate log joint probabilities more flexibly. This should only be relevant if they are using customised inference algorithms (standard algorithms in AdvancedMH, AdvancedHMC, AdvancedVI don't need this behaviour).

Member Author:

Might anything weird happen with ESS? I don't know how it works, but it seems to use PriorContext in some particular way.

I guess the same applies then to setlogp!! and acclogp!!, namely that any user-code that uses them and expects to set/accumulate anything other than the log likelihood (even when e.g. using PriorContext) can be considered a bug?

Member:

> Might anything weird happen with ESS?

ESS only deals with Gaussian priors, so users should

  1. use ~ to specify the priors so ESS can check whether RHS is Gaussian
  2. avoid using @addlogprob! to interact with the prior

But, using @addlogprob! to accumulate log likelihood should be okay for ESS, since ESS doesn't make any assumption on the distributional form of likelihood.

> I guess the same applies then to setlogp!! and acclogp!!, namely that any user-code that uses them and expects to set/accumulate anything other than the log likelihood (even when e.g. using PriorContext) can be considered a bug?

I suspect so, since these functions were mainly designed for MH and HMC. In these use cases, accumulating log joint and log likelihood are equivalent as long as the model is differentiable.

That said, for prudence, we should provide an explanation on this behavioural change in HISTORY and mark this PR as a breaking change.

Member Author:

Ok, in that case I'm happy to keep @addlogprob! in place; the change in its behaviour in this PR can actually be considered a fix to incorrect earlier behaviour. Also happy to add the NamedTuple version of it. Would you be okay with also adding @addloglikelihood!, and putting in an info print or a deprecation warning in @addlogprob! to nudge people towards using either @addlogprob! (likelihood=val,) or @addloglikelihood! val in the long term? I think long term (v1.0 and beyond) the more explicit syntax would be good.

@addlogprior!
```

Return values of the model function for a collection of samples can be obtained with [`returned(model, chain)`](@ref).
@@ -328,9 +330,9 @@ The following functions were used for sequential Monte Carlo methods.

```@docs
get_num_produce
-set_num_produce!
-increment_num_produce!
-reset_num_produce!
+set_num_produce!!
+increment_num_produce!!
+reset_num_produce!!
setorder!
set_retained_vns_del!
```
@@ -345,6 +347,22 @@ Base.empty!
SimpleVarInfo
```

### Accumulators

The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during execution, in what are called accumulators.

```@docs
AbstractAccumulator
```

DynamicPPL provides the following default accumulators.

```@docs
LogPriorAccumulator
LogLikelihoodAccumulator
NumProduceAccumulator
```
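
The accumulator idea described in this section can be sketched in plain Julia without DynamicPPL. Everything below is a hypothetical toy: `ToyAccumulator`, `ToyLogPrior`, `ToyLogLikelihood`, `on_assume`, `on_observe`, and `toy_model` are names made up for illustration and are not the interface of `AbstractAccumulator`. The sketch assumes only what the changelog says, that each accumulator may update its state at every assume and observe statement, and that carrying fewer accumulators means less work.

```julia
# Hypothetical toy accumulators; not DynamicPPL's AbstractAccumulator interface.
abstract type ToyAccumulator end

struct ToyLogPrior <: ToyAccumulator
    logp::Float64
end

struct ToyLogLikelihood <: ToyAccumulator
    logp::Float64
end

# Log density of Normal(mu, 1) at x, written out to avoid any dependencies.
normal_logpdf(x, mu) = -0.5 * (x - mu)^2 - 0.5 * log(2 * pi)

# An assume statement updates only the prior accumulator ...
on_assume(acc::ToyLogPrior, x, mu) = ToyLogPrior(acc.logp + normal_logpdf(x, mu))
on_assume(acc::ToyAccumulator, x, mu) = acc

# ... and an observe statement updates only the likelihood accumulator.
on_observe(acc::ToyLogLikelihood, y, mu) = ToyLogLikelihood(acc.logp + normal_logpdf(y, mu))
on_observe(acc::ToyAccumulator, y, mu) = acc

# Toy model: x ~ Normal(0, 1) is assumed, y ~ Normal(x, 1) is observed.
# Each statement is offered to every accumulator in the tuple.
function toy_model(accs::Tuple, x, y)
    accs = map(acc -> on_assume(acc, x, 0.0), accs)
    accs = map(acc -> on_observe(acc, y, x), accs)
    return accs
end

# Track both quantities, or only the log prior.
both = toy_model((ToyLogPrior(0.0), ToyLogLikelihood(0.0)), 0.0, 1.0)
prior_only = toy_model((ToyLogPrior(0.0),), 0.0, 1.0)
```

Passing a one-element tuple plays the role that `setaccs!!(varinfo, (LogPriorAccumulator(),))` plays in the HISTORY entry: the likelihood term is never computed because no accumulator asks for it.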

### Common API

#### Accumulation of log-probabilities
@@ -353,6 +371,13 @@ SimpleVarInfo
getlogp
setlogp!!
acclogp!!
getlogjoint
getlogprior
setlogprior!!
acclogprior!!
getloglikelihood
setloglikelihood!!
accloglikelihood!!
resetlogp!!
```

@@ -427,9 +452,6 @@ Contexts are subtypes of `AbstractPPL.AbstractContext`.
```@docs
SamplingContext
DefaultContext

Member:

For clarity

Suggested change:
-DefaultContext
+AccumulatorContext

-LikelihoodContext
-PriorContext
-MiniBatchContext
PrefixContext
ConditionContext
```
@@ -476,7 +498,3 @@ DynamicPPL.Experimental.is_suitable_varinfo
```@docs
tilde_assume
```
-
-```@docs
-tilde_observe
-```
12 changes: 6 additions & 6 deletions ext/DynamicPPLMCMCChainsExt.jl
@@ -48,18 +48,18 @@
Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample
in `chain`, and return the resulting `Chains`.

-The `model` passed to `predict` is often different from the one used to generate `chain`.
-Typically, the model from which `chain` originated treats certain variables as observed (i.e.,
-data points), while the model you pass to `predict` may mark these same variables as missing
-or unobserved. Calling `predict` then leverages the previously inferred parameter values to
+The `model` passed to `predict` is often different from the one used to generate `chain`.
+Typically, the model from which `chain` originated treats certain variables as observed (i.e.,
+data points), while the model you pass to `predict` may mark these same variables as missing
+or unobserved. Calling `predict` then leverages the previously inferred parameter values to
simulate what new, unobserved data might look like, given your posterior beliefs.

For each parameter configuration in `chain`:
1. All random variables present in `chain` are fixed to their sampled values.
2. Any variables not included in `chain` are sampled from their prior distributions.

If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by
-the samples in `chain`. This is useful when you want to sample only new variables from the posterior
+the samples in `chain`. This is useful when you want to sample only new variables from the posterior
predictive distribution.

# Examples
@@ -124,7 +124,7 @@
map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)),
)

-return (varname_and_values=varname_vals, logp=DynamicPPL.getlogp(varinfo))
+return (varname_and_values=varname_vals, logp=DynamicPPL.getlogjoint(varinfo))

Codecov warning (codecov/patch): added line #L127 of ext/DynamicPPLMCMCChainsExt.jl was not covered by tests.

end

chain_result = reduce(

29 changes: 21 additions & 8 deletions src/DynamicPPL.jl
@@ -6,6 +6,7 @@ using Bijectors
using Compat
using Distributions
using OrderedCollections: OrderedCollections, OrderedDict
using Printf: Printf

using AbstractMCMC: AbstractMCMC
using ADTypes: ADTypes
@@ -46,17 +47,28 @@ import Base:
export AbstractVarInfo,
VarInfo,
SimpleVarInfo,
AbstractAccumulator,
LogLikelihoodAccumulator,
LogPriorAccumulator,
NumProduceAccumulator,
push!!,
empty!!,
subset,
getlogp,
getlogjoint,
getlogprior,
getloglikelihood,
setlogp!!,
setlogprior!!,
setloglikelihood!!,
acclogp!!,
acclogprior!!,
accloglikelihood!!,
resetlogp!!,
get_num_produce,
-set_num_produce!,
-reset_num_produce!,
-increment_num_produce!,
+set_num_produce!!,
+reset_num_produce!!,
+increment_num_produce!!,
set_retained_vns_del!,
is_flagged,
set_flag!,
@@ -92,15 +104,10 @@ export AbstractVarInfo,
# Contexts
SamplingContext,
DefaultContext,
-LikelihoodContext,
-PriorContext,
-MiniBatchContext,
PrefixContext,
ConditionContext,
assume,
-observe,
tilde_assume,
-tilde_observe,
# Pseudo distributions
NamedDist,
NoDist,
@@ -120,6 +127,8 @@ export AbstractVarInfo,
to_submodel,
# Convenience macros
@addlogprob!,
@addlogprior!,
@addloglikelihood!,
@submodel,
value_iterator_from_chain,
check_model,
@@ -146,6 +155,9 @@ macro prob_str(str)
))
end

# TODO(mhauru) We should write down the list of methods that any subtype of AbstractVarInfo
# has to implement. Not sure what the full list is for parameters values, but for
# accumulators we only need `getaccs` and `setaccs!!`.
Comment on lines +158 to +160

Member:

I put a list of AbstractVarInfo interface methods together the other day, feel free to ask me for it on Slack or something. We'd have to add the accumulator bits in, of course.

Member Author:

Could be a good contribution to #899

"""
AbstractVarInfo

@@ -166,6 +178,7 @@ include("varname.jl")
include("distribution_wrappers.jl")
include("contexts.jl")
include("varnamedvector.jl")
include("accumulators.jl")
include("abstract_varinfo.jl")
include("threadsafe.jl")
include("varinfo.jl")