Skip to content

Commit 64ef0ee

Browse files
authored
Merge pull request #41 from TuringLang/torfjelde/Bijectors-compat
Bump Bijectors.jl compat bounds
2 parents bb7e85c + 7c2f30c commit 64ef0ee

File tree

5 files changed

+25
-8
lines changed

5 files changed

+25
-8
lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525
- os: macOS-latest
2626
arch: x86
2727
include:
28-
- version: '1.0'
28+
- version: '1.6'
2929
os: ubuntu-latest
3030
arch: x64
3131
- os: ubuntu-latest
@@ -60,4 +60,4 @@ jobs:
6060
if: matrix.coverage
6161
with:
6262
github-token: ${{ secrets.GITHUB_TOKEN }}
63-
path-to-lcov: lcov.info
63+
path-to-lcov: lcov.info

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedVI"
22
uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
3-
version = "0.1.6"
3+
version = "0.2.0"
44

55
[deps]
66
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
@@ -17,7 +17,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1717
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
1818

1919
[compat]
20-
Bijectors = "0.4.0, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10"
20+
Bijectors = "0.11, 0.12"
2121
Distributions = "0.21, 0.22, 0.23, 0.24, 0.25"
2222
DistributionsAD = "0.2, 0.3, 0.4, 0.5, 0.6"
2323
DocStringExtensions = "0.8, 0.9"
@@ -27,7 +27,7 @@ Requires = "0.5, 1.0"
2727
StatsBase = "0.32, 0.33"
2828
StatsFuns = "0.8, 0.9, 1"
2929
Tracker = "0.2.3"
30-
julia = "1"
30+
julia = "1.6"
3131

3232
[extras]
3333
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

src/AdvancedVI.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ end
1919
const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_ADVANCEDVI", "0")))
2020

2121
include("ad.jl")
22+
include("utils.jl")
2223

2324
using Requires
2425
function __init__()

src/advi.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ function (elbo::ELBO)(
8181
# = 𝔼[log p(x, f⁻¹(z̃)) + logabsdet(J(f⁻¹(z̃)))] + ℍ(q̃(z̃))
8282
# = 𝔼[log p(x, z) - logabsdetjac(J(f(z)))] + ℍ(q̃(z̃))
8383

84-
# But our `forward(q)` is using f⁻¹: ℝ → supp(p(z | x)) going forward → `+ logjac`
85-
_, z, logjac, _ = forward(rng, q)
84+
# But our `rand_and_logjac(q)` is using f⁻¹: ℝ → supp(p(z | x)) going forward → `+ logjac`
85+
z, logjac = rand_and_logjac(rng, q)
8686
res = (logπ(z) + logjac) / num_samples
8787

8888
if q isa TransformedDistribution
@@ -92,7 +92,7 @@ function (elbo::ELBO)(
9292
end
9393

9494
for i = 2:num_samples
95-
_, z, logjac, _ = forward(rng, q)
95+
z, logjac = rand_and_logjac(rng, q)
9696
res += (logπ(z) + logjac) / num_samples
9797
end
9898

src/utils.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
using Distributions
2+
3+
using Random: Random
4+
using Bijectors: Bijectors
5+
6+
7+
function rand_and_logjac(rng::Random.AbstractRNG, dist::Distribution)
8+
x = rand(rng, dist)
9+
return x, zero(eltype(x))
10+
end
11+
12+
function rand_and_logjac(rng::Random.AbstractRNG, dist::Bijectors.TransformedDistribution)
13+
x = rand(rng, dist.dist)
14+
y, logjac = Bijectors.with_logabsdet_jacobian(dist.transform, x)
15+
return y, logjac
16+
end

0 commit comments

Comments
 (0)