Run formatter, start using always_use_return = true #886

Closed · wants to merge 3 commits
4 changes: 4 additions & 0 deletions .JuliaFormatter.toml
````diff
@@ -1,2 +1,6 @@
 style="blue"
 format_markdown = true
+# The below should actually be part of Blue according to
+# https://github.com/JuliaDiff/BlueStyle?tab=readme-ov-file#method-definitions
+# but JuliaFormatter v2.10 doesn't enforce it.
+always_use_return = true
````
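With `always_use_return = true`, JuliaFormatter rewrites implicit trailing-expression returns into explicit `return` statements; that mechanical rewrite is most of what the diffs below contain. A minimal before/after sketch (the `double` function is illustrative, not from this PR):

```julia
# Before formatting: the value of the last expression is returned implicitly.
function double(x)
    2x
end

# After formatting with always_use_return = true: the return is explicit.
function double(x)
    return 2x
end
```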
8 changes: 4 additions & 4 deletions benchmarks/src/Models.jl
````diff
@@ -47,7 +47,7 @@ A short model that tries to cover many DynamicPPL features.
 Includes scalar, vector univariate, and multivariate variables; ~, .~, and loops; allocating
 a variable vector; observations passed as arguments, and as literals.
 """
-@model function smorgasbord(x, y, ::Type{TV}=Vector{Float64}) where {TV}
+@model function smorgasbord(x, y, (::Type{TV})=Vector{Float64}) where {TV}
     @assert length(x) == length(y)
     m ~ truncated(Normal(); lower=0)
     means ~ product_distribution(fill(Exponential(m), length(x)))
@@ -68,7 +68,7 @@ The second variable, `o`, is meant to be conditioned on after model instantiation.
 
 See `multivariate` for a version that uses `product_distribution` rather than loops.
 """
-@model function loop_univariate(num_dims, ::Type{TV}=Vector{Float64}) where {TV}
+@model function loop_univariate(num_dims, (::Type{TV})=Vector{Float64}) where {TV}
     a = TV(undef, num_dims)
     o = TV(undef, num_dims)
     for i in 1:num_dims
@@ -88,7 +88,7 @@ The second variable, `o`, is meant to be conditioned on after model instantiation.
 
 See `loop_univariate` for a version that uses loops rather than `product_distribution`.
 """
-@model function multivariate(num_dims, ::Type{TV}=Vector{Float64}) where {TV}
+@model function multivariate(num_dims, (::Type{TV})=Vector{Float64}) where {TV}
     a = TV(undef, num_dims)
     o = TV(undef, num_dims)
     a ~ product_distribution(fill(Normal(0, 1), num_dims))
@@ -118,7 +118,7 @@ end
 A model with random variables that have changing support under linking, or otherwise
 complicated bijectors.
 """
-@model function dynamic(::Type{T}=Vector{Float64}) where {T}
+@model function dynamic((::Type{T})=Vector{Float64}) where {T}
     eta ~ truncated(Normal(); lower=0.0, upper=0.1)
     mat1 ~ LKJCholesky(4, eta)
     mat2 ~ InverseWishart(3.2, cholesky([1.0 0.5; 0.5 1.0]))
````
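The other recurring change is that the formatter parenthesizes anonymous type-valued arguments that carry a default, turning `::Type{TV}=Vector{Float64}` into `(::Type{TV})=Vector{Float64}`. Both spellings define the same method; the parentheses only make the grouping of the argument and its default explicit. A rough sketch outside `@model` (the names `f` and `g` are illustrative):

```julia
# Both definitions behave identically; the parentheses just group the
# anonymous `::Type{T}` argument with its default value explicitly.
f(::Type{T}=Float64) where {T} = zero(T)
g((::Type{T})=Float64) where {T} = zero(T)

f()     # 0.0 (T defaults to Float64)
g(Int)  # 0   (T is Int)
```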
16 changes: 8 additions & 8 deletions ext/DynamicPPLMCMCChainsExt.jl
````diff
@@ -48,18 +48,18 @@ end
 Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample
 in `chain`, and return the resulting `Chains`.
 
-The `model` passed to `predict` is often different from the one used to generate `chain`. 
-Typically, the model from which `chain` originated treats certain variables as observed (i.e., 
-data points), while the model you pass to `predict` may mark these same variables as missing 
-or unobserved. Calling `predict` then leverages the previously inferred parameter values to 
+The `model` passed to `predict` is often different from the one used to generate `chain`.
+Typically, the model from which `chain` originated treats certain variables as observed (i.e.,
+data points), while the model you pass to `predict` may mark these same variables as missing
+or unobserved. Calling `predict` then leverages the previously inferred parameter values to
 simulate what new, unobserved data might look like, given your posterior beliefs.
 
 For each parameter configuration in `chain`:
 1. All random variables present in `chain` are fixed to their sampled values.
 2. Any variables not included in `chain` are sampled from their prior distributions.
 
 If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by
-the samples in `chain`. This is useful when you want to sample only new variables from the posterior 
+the samples in `chain`. This is useful when you want to sample only new variables from the posterior
 predictive distribution.
 
 # Examples
@@ -161,8 +161,8 @@ function _predictive_samples_to_arrays(predictive_samples)
 
     variable_names = collect(variable_names_set)
     variable_values = [
-        get(sample_dicts[i], key, missing) for i in eachindex(sample_dicts),
-        key in variable_names
+        get(sample_dicts[i], key, missing) for
+        i in eachindex(sample_dicts), key in variable_names
     ]
 
     return variable_names, variable_values
@@ -254,7 +254,7 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Chains)
         DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
         # NOTE: Some of the variables can be a view into the `varinfo`, so we need to
         # `deepcopy` the `varinfo` before passing it to the `model`.
-        model(deepcopy(varinfo))
+        return model(deepcopy(varinfo))
     end
 end
````

(The docstring hunks above differ only in trailing whitespace.)
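As a usage sketch of the `predict` docstring above (the `linreg` model and data are hypothetical, and the sampling step comes from Turing.jl rather than this package):

```julia
using DynamicPPL, Distributions, LinearAlgebra, MCMCChains

@model function linreg(x, y)
    β ~ Normal(0, 1)
    σ ~ truncated(Normal(); lower=0)
    return y ~ MvNormal(β .* x, σ^2 * I)
end

# chain = sample(linreg(x_train, y_train), NUTS(), 500)  # via Turing.jl
# Passing `missing` for `y` marks it unobserved, so `predict` draws it from
# the posterior predictive distribution implied by `chain`:
# predictions = predict(linreg(x_test, missing), chain)
```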
2 changes: 1 addition & 1 deletion src/compiler.jl
````diff
@@ -635,7 +635,7 @@ end
 
 function namedtuple_from_splitargs(splitargs)
     names = map(splitargs) do (arg_name, arg_type, is_splat, default)
-        is_splat ? Symbol("#splat#$(arg_name)") : arg_name
+        return is_splat ? Symbol("#splat#$(arg_name)") : arg_name
     end
     names_expr = Expr(:tuple, map(QuoteNode, names)...)
     vals = Expr(:tuple, map(first, splitargs)...)
````
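Many of the formatter-inserted `return`s land inside `do` blocks, as in the hunk above. These are behavior-preserving: in Julia, `return` inside a `do` block exits only the anonymous function passed to `map`, not the enclosing function. A small sketch:

```julia
function describe(xs)
    labels = map(xs) do x
        # This `return` leaves only the anonymous function, yielding one label
        # per element; it does not return from `describe` itself.
        return x > 0 ? "positive" : "non-positive"
    end
    return join(labels, ", ")
end

describe([1, -2, 3])  # "positive, non-positive, positive"
```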
4 changes: 2 additions & 2 deletions src/debug_utils.jl
````diff
@@ -521,7 +521,7 @@ function has_static_constraints(
     rng::Random.AbstractRNG, model::Model; num_evals=5, kwargs...
 )
     results = map(1:num_evals) do _
-        check_model_and_trace(rng, model; kwargs...)
+        return check_model_and_trace(rng, model; kwargs...)
     end
     issuccess = all(first, results)
     issuccess || throw(ArgumentError("model check failed"))
@@ -530,7 +530,7 @@
     traces = map(last, results)
     dists_per_trace = map(distributions_in_trace, traces)
     transforms = map(dists_per_trace) do dists
-        map(DynamicPPL.link_transform, dists)
+        return map(DynamicPPL.link_transform, dists)
     end
 
     # Check if the distributions are the same across all runs.
````
5 changes: 3 additions & 2 deletions src/extract_priors.jl
````diff
@@ -105,8 +105,9 @@ julia> length(extract_priors(rng, model)[@varname(x)])
 9
 ```
 """
-extract_priors(args::Union{Model,AbstractVarInfo}...) =
-    extract_priors(Random.default_rng(), args...)
+function extract_priors(args::Union{Model,AbstractVarInfo}...)
+    return extract_priors(Random.default_rng(), args...)
+end
 function extract_priors(rng::Random.AbstractRNG, model::Model)
     context = PriorExtractorContext(SamplingContext(rng))
     evaluate!!(model, VarInfo(), context)
````
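A short usage sketch of the rewritten rng-less method (the `demo` model is illustrative):

```julia
using DynamicPPL, Distributions

@model function demo()
    x ~ Normal()
    return y ~ Exponential(1)
end

priors = extract_priors(demo())  # falls back to Random.default_rng()
priors[@varname(x)]              # Normal{Float64}(μ=0.0, σ=1.0)
```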
8 changes: 5 additions & 3 deletions src/logdensityfunction.jl
````diff
@@ -175,7 +175,7 @@ end
 Evaluate the log density of the given `model` at the given parameter values `x`,
 using the given `varinfo` and `context`. Note that the `varinfo` argument is provided
 only for its structure, in the sense that the parameters from the vector `x` are inserted into
-it, and its own parameters are discarded. 
+it, and its own parameters are discarded.
 """
 function logdensity_at(
     x::AbstractVector, model::Model, varinfo::AbstractVarInfo, context::AbstractContext
@@ -245,9 +245,11 @@ model.
 
 By default, this just returns the input unchanged.
 """
-tweak_adtype(
+function tweak_adtype(
     adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo, ::AbstractContext
-) = adtype
+)
+    return adtype
+end
 
 """
     use_closure(adtype::ADTypes.AbstractADType)
````

(The first hunk only strips trailing whitespace.)
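For orientation, `tweak_adtype` is an internal hook that lets a model type swap in a better-suited AD backend; the default method is the identity. A hedged sketch, assuming `model`, `varinfo`, and `context` have already been constructed:

```julia
using ADTypes

adtype = ADTypes.AutoForwardDiff()
# The default method returns `adtype` unchanged; packages can overload
# `tweak_adtype` for specific model types to substitute another backend.
adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo, context)
```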
16 changes: 9 additions & 7 deletions src/model.jl
````diff
@@ -96,8 +96,9 @@ Return a `Model` which now treats variables on the right-hand side as observations.
 
 See [`condition`](@ref) for more information and examples.
 """
-Base.:|(model::Model, values::Union{Pair,Tuple,NamedTuple,AbstractDict{<:VarName}}) =
-    condition(model, values)
+function Base.:|(model::Model, values::Union{Pair,Tuple,NamedTuple,AbstractDict{<:VarName}})
+    return condition(model, values)
+end
 
 """
     condition(model::Model; values...)
@@ -1068,7 +1069,7 @@ function logjoint(model::Model, chain::AbstractMCMC.AbstractChains)
             values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
             vn_parent in keys(var_info)
         )
-        logjoint(model, argvals_dict)
+        return logjoint(model, argvals_dict)
     end
 end
 
@@ -1115,7 +1116,7 @@ function logprior(model::Model, chain::AbstractMCMC.AbstractChains)
             values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
             vn_parent in keys(var_info)
         )
-        logprior(model, argvals_dict)
+        return logprior(model, argvals_dict)
     end
 end
 
@@ -1162,7 +1163,7 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains)
             values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
             vn_parent in keys(var_info)
         )
-        loglikelihood(model, argvals_dict)
+        return loglikelihood(model, argvals_dict)
     end
 end
 
@@ -1467,5 +1468,6 @@ ERROR: ArgumentError: `~` with a model on the right-hand side of an observe statement is not supported
 [...]
 ```
 """
-to_submodel(model::Model, auto_prefix::Bool=true) =
-    to_sampleable(returned(model), auto_prefix)
+function to_submodel(model::Model, auto_prefix::Bool=true)
+    return to_sampleable(returned(model), auto_prefix)
+end
````
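The `Base.:|` method at the top of this file is DynamicPPL's conditioning operator; a brief usage sketch (the `gauss` model and the conditioning value are illustrative):

```julia
using DynamicPPL, Distributions

@model function gauss()
    μ ~ Normal()
    return x ~ Normal(μ, 1)
end

conditioned = gauss() | (x=0.5,)  # equivalent to condition(gauss(), (x=0.5,))
logjoint(conditioned, (μ=0.0,))   # log joint with x observed at 0.5
```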
4 changes: 3 additions & 1 deletion src/model_utils.jl
````diff
@@ -204,6 +204,8 @@
     return Iterators.map(
         Iterators.product(1:size(chain, 1), 1:size(chain, 3))
     ) do (iteration_idx, chain_idx)
-        values_from_chain!(vi, chain, chain_idx, iteration_idx, OrderedDict{VarName,Any}())
+        return values_from_chain!(
+            vi, chain, chain_idx, iteration_idx, OrderedDict{VarName,Any}()
+        )
     end
 end
````

(Codecov: added line src/model_utils.jl#L207 was not covered by tests.)
6 changes: 3 additions & 3 deletions src/simple_varinfo.jl
````diff
@@ -244,7 +244,7 @@
 end
 
 # Constructor from `VarInfo`.
-function SimpleVarInfo(vi::TypedVarInfo, ::Type{D}=NamedTuple; kwargs...) where {D}
+function SimpleVarInfo(vi::TypedVarInfo, (::Type{D})=NamedTuple; kwargs...) where {D}
     return SimpleVarInfo{eltype(getlogp(vi))}(vi, D; kwargs...)
 end
 function SimpleVarInfo{T}(
@@ -315,7 +315,7 @@
 end
 function Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribution)
     vals_linked = mapreduce(vcat, vns) do vn
-        getindex(vi, vn, dist)
+        return getindex(vi, vn, dist)
     end
     return recombine(dist, vals_linked, length(vns))
 end
@@ -362,7 +362,7 @@
     # Attempt to split into `parent` and `child` optic.
     parent, child, issuccess = splitoptic(getoptic(vn)) do optic
         o = optic === nothing ? identity : optic
-        haskey(dict, VarName(vn, o))
+        return haskey(dict, VarName(vn, o))
     end
     # When combined with `VarInfo`, `nothing` is equivalent to `identity`.
     keyoptic = parent === nothing ? identity : parent
````

(Codecov: added line src/simple_varinfo.jl#L318 was not covered by tests.)
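For context, `SimpleVarInfo` is DynamicPPL's lightweight parameter container; a hedged sketch of evaluating a model's log joint at fixed values (the model and values are illustrative):

```julia
using DynamicPPL, Distributions

@model function m()
    x ~ Normal()
    return y ~ Normal(x, 1)
end

svi = SimpleVarInfo((x=0.3, y=-0.1))
logjoint(m(), svi)  # log joint density at x = 0.3, y = -0.1
```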
22 changes: 13 additions & 9 deletions src/test_utils/models.jl
````diff
@@ -49,7 +49,7 @@ x[4:5] ~ Dirichlet([1.0, 2.0])
 ```
 """
 @model function demo_one_variable_multiple_constraints(
-    ::Type{TV}=Vector{Float64}
+    (::Type{TV})=Vector{Float64}
 ) where {TV}
     x = TV(undef, 5)
     x[1] ~ Normal()
@@ -186,7 +186,9 @@ function _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
     return (s=s_unconstrained, m=m), logprior_true(model, s, m) - Δlogp
 end
 
-@model function demo_dot_assume_observe(x=[1.5, 2.0], ::Type{TV}=Vector{Float64}) where {TV}
+@model function demo_dot_assume_observe(
+    x=[1.5, 2.0], (::Type{TV})=Vector{Float64}
+) where {TV}
     # `dot_assume` and `observe`
     s = TV(undef, length(x))
     m = TV(undef, length(x))
@@ -212,7 +214,7 @@ function varnames(model::Model{typeof(demo_dot_assume_observe)})
 end
 
 @model function demo_assume_index_observe(
-    x=[1.5, 2.0], ::Type{TV}=Vector{Float64}
+    x=[1.5, 2.0], (::Type{TV})=Vector{Float64}
 ) where {TV}
     # `assume` with indexing and `observe`
     s = TV(undef, length(x))
@@ -268,7 +270,7 @@ function varnames(model::Model{typeof(demo_assume_multivariate_observe)})
 end
 
 @model function demo_dot_assume_observe_index(
-    x=[1.5, 2.0], ::Type{TV}=Vector{Float64}
+    x=[1.5, 2.0], (::Type{TV})=Vector{Float64}
 ) where {TV}
     # `dot_assume` and `observe` with indexing
     s = TV(undef, length(x))
@@ -348,7 +350,9 @@ function varnames(model::Model{typeof(demo_assume_multivariate_observe_literal)})
     return [@varname(s), @varname(m)]
 end
 
-@model function demo_dot_assume_observe_index_literal(::Type{TV}=Vector{Float64}) where {TV}
+@model function demo_dot_assume_observe_index_literal(
+    (::Type{TV})=Vector{Float64}
+) where {TV}
     # `dot_assume` and literal `observe` with indexing
     s = TV(undef, 2)
     m = TV(undef, 2)
@@ -425,7 +429,7 @@ function varnames(model::Model{typeof(demo_assume_dot_observe_literal)})
 end
 
 # Only used as a submodel
-@model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV}
+@model function _prior_dot_assume((::Type{TV})=Vector{Float64}) where {TV}
     s = TV(undef, 2)
     s .~ InverseGamma(2, 3)
     m = TV(undef, 2)
@@ -466,7 +470,7 @@ end
 end
 
 @model function demo_dot_assume_observe_submodel(
-    x=[1.5, 2.0], ::Type{TV}=Vector{Float64}
+    x=[1.5, 2.0], (::Type{TV})=Vector{Float64}
 ) where {TV}
     s = TV(undef, length(x))
     s .~ InverseGamma(2, 3)
@@ -496,7 +500,7 @@ function varnames(model::Model{typeof(demo_dot_assume_observe_submodel)})
 end
 
 @model function demo_dot_assume_observe_matrix_index(
-    x=transpose([1.5 2.0;]), ::Type{TV}=Vector{Float64}
+    x=transpose([1.5 2.0;]), (::Type{TV})=Vector{Float64}
 ) where {TV}
     s = TV(undef, length(x))
     s .~ InverseGamma(2, 3)
@@ -525,7 +529,7 @@ function varnames(model::Model{typeof(demo_dot_assume_observe_matrix_index)})
 end
 
 @model function demo_assume_matrix_observe_matrix_index(
-    x=transpose([1.5 2.0;]), ::Type{TV}=Array{Float64}
+    x=transpose([1.5 2.0;]), (::Type{TV})=Array{Float64}
 ) where {TV}
     n = length(x)
     d = n ÷ 2
````
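These demo models live in `DynamicPPL.TestUtils` and can be instantiated and run like any other model; a hedged sketch:

```julia
using DynamicPPL

model = DynamicPPL.TestUtils.demo_dot_assume_observe()
vi = VarInfo(model)  # runs the model once, drawing s and m from their priors
keys(vi)             # VarNames such as s[1], s[2], m[1], m[2]
```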
2 changes: 1 addition & 1 deletion src/test_utils/varinfo.jl
````diff
@@ -58,7 +58,7 @@ function setup_varinfos(
         svi_vnv_ref,
     )) do vi
         # Set them all to the same values.
-        DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp)
+        return DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp)
     end
 
     if include_threadsafe
````
10 changes: 5 additions & 5 deletions src/utils.jl
````diff
@@ -710,7 +710,7 @@ ERROR: Could not find x.a[2] in x.a[1]
 function remove_parent_optic(vn_parent::VarName{sym}, vn_child::VarName{sym}) where {sym}
     _, child, issuccess = splitoptic(getoptic(vn_child)) do optic
         o = optic === nothing ? identity : optic
-        VarName(vn_child, o) == vn_parent
+        return VarName(vn_child, o) == vn_parent
     end
 
     issuccess || error("Could not find $vn_parent in $vn_child")
@@ -905,7 +905,7 @@ function hasvalue(vals::AbstractDict, vn::VarName)
     # If `issuccess` is `true`, we found such a split, and hence `vn` is present.
     parent, child, issuccess = splitoptic(getoptic(vn)) do optic
         o = optic === nothing ? identity : optic
-        haskey(vals, VarName(vn, o))
+        return haskey(vals, VarName(vn, o))
     end
     # When combined with `VarInfo`, `nothing` is equivalent to `identity`.
     keyoptic = parent === nothing ? identity : parent
@@ -934,7 +934,7 @@ function nested_getindex(values::AbstractDict, vn::VarName)
     # Split the optic into the key / `parent` and the extraction optic / `child`.
     parent, child, issuccess = splitoptic(getoptic(vn)) do optic
         o = optic === nothing ? identity : optic
-        haskey(values, VarName(vn, o))
+        return haskey(values, VarName(vn, o))
     end
     # When combined with `VarInfo`, `nothing` is equivalent to `identity`.
     keyoptic = parent === nothing ? identity : parent
@@ -1078,7 +1078,7 @@ end
 function varname_leaves(vn::VarName, val::NamedTuple)
     iter = Iterators.map(keys(val)) do sym
         optic = Accessors.PropertyLens{sym}()
-        varname_leaves(VarName(vn, optic ∘ getoptic(vn)), optic(val))
+        return varname_leaves(VarName(vn, optic ∘ getoptic(vn)), optic(val))
     end
     return Iterators.flatten(iter)
 end
@@ -1244,7 +1244,7 @@ end
 function varname_and_value_leaves_inner(vn::DynamicPPL.VarName, val::NamedTuple)
     iter = Iterators.map(keys(val)) do sym
         optic = DynamicPPL.Accessors.PropertyLens{sym}()
-        varname_and_value_leaves_inner(
+        return varname_and_value_leaves_inner(
             VarName{getsym(vn)}(optic ∘ getoptic(vn)), optic(val)
         )
     end
````
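`varname_leaves` recursively splits a structured value into per-leaf `VarName`s. A hedged sketch of the idea (it is an internal helper, so the exact printed form may differ):

```julia
using DynamicPPL

leaves = collect(DynamicPPL.varname_leaves(@varname(x), (a=1.0, b=[2.0, 3.0])))
# Expected leaves: x.a, x.b[1], x.b[2]
```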