Skip to content

Commit 4795704

Browse files
authored
Merge pull request #113 from tpapp/tp/fix-staticarray-typeinference
Fix static array type inference.
2 parents a1b25a9 + 50806cd commit 4795704

File tree

5 files changed

+23
-13
lines changed

5 files changed

+23
-13
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TransformVariables"
22
uuid = "84d833dd-6860-57f9-a1a7-6da5db126cff"
33
authors = ["Tamas K. Papp <tkpapp@gmail.com>"]
4-
version = "0.8.5"
4+
version = "0.8.6"
55

66
[deps]
77
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"

src/TransformVariables.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using LogExpFunctions
77
using LinearAlgebra: UpperTriangular, logabsdet
88
using UnPack: @unpack
99
using Random: AbstractRNG, GLOBAL_RNG
10-
using StaticArrays: MMatrix, SMatrix, SArray
10+
using StaticArrays: MMatrix, SMatrix, SArray, SVector, pushfirst
1111

1212
import ChangesOfVariables
1313
import InverseFunctions

src/aggregation.jl

+15-9
Original file line numberDiff line numberDiff line change
@@ -145,16 +145,22 @@ end
145145
result_size(::StaticArrayTransformation{D,S}) where {D,S} = fieldtypes(S)
146146

147147
function transform_with(flag::LogJacFlag, transformation::StaticArrayTransformation{D,S},
148-
x::AbstractVector{T}, index) where {D,S,T}
148+
x::AbstractVector{T}, index::Int) where {D,S,T}
149149
@unpack inner_transformation = transformation
150-
= logjac_zero(flag, robust_eltype(x))
151-
SArray{S}(begin
152-
y, ℓΔ, index′ = transform_with(flag, inner_transformation, x, index)
153-
index = index′
154-
+= ℓΔ
155-
y
156-
end
157-
for _ in 1:D), ℓ, index
150+
# NOTE this is a fix for #112, enforcing types taken from the transformation of the
151+
# first element.
152+
y1, ℓ1, index1 = transform_with(flag, inner_transformation, x, index)
153+
L = typeof(ℓ1)
154+
let::L = ℓ1, index::Int = index1
155+
function _f(_)
156+
y, ℓΔ, index′ = transform_with(flag, inner_transformation, x, index)
157+
index = index′
158+
=+ ℓΔ
159+
y
160+
end
161+
yrest = SVector{D-1}(_f(i) for i in 2:D)
162+
SArray{S}(pushfirst(yrest, y1)), ℓ, index
163+
end
158164
end
159165

160166
function inverse_eltype(transformation::Union{ArrayTransformation,StaticArrayTransformation},

src/generic.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ $(SIGNATURES)
4444
4545
Initial value for log Jacobian calculations.
4646
"""
47-
logjac_zero(::LogJac, T::Type{<:Real}) = log(one(T))
47+
logjac_zero(::LogJac, ::Type{T}) where {T<:Real} = log(one(T))
4848

4949
logjac_zero(::NoLogJac, _) = NOLOGJAC
5050

test/runtests.jl

+5-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using LogDensityProblems: logdensity, logdensity_and_gradient
66
using LogDensityProblemsAD
77
using TransformVariables:
88
AbstractTransform, ScalarTransform, VectorTransform, ArrayTransformation,
9-
unit_triangular_dimension, logistic, logistic_logjac, logit, inverse_and_logjac
9+
unit_triangular_dimension, logistic, logistic_logjac, logit, inverse_and_logjac, NOLOGJAC, transform_with
1010
import ChangesOfVariables, InverseFunctions
1111
using Enzyme: autodiff, Reverse, Active, Const
1212

@@ -640,3 +640,7 @@ end
640640
d = as(SVector{2}, asℝ₊)))
641641
@test [domain_label(t, i) for i in 1:dimension(t)] == [".a", ".b[1,1]", ".c[1]", ".d[1]", ".d[2]"]
642642
end
643+
644+
@testset "static arrays inference" begin
645+
@test @inferred transform_with(NOLOGJAC, as(SVector{3, Float64}), zeros(3), 1) == (SVector(0.0, 0.0, 0.0), NOLOGJAC, 4)
646+
end

0 commit comments

Comments
 (0)