Commit 36bc868

Dev (#309)
* update Project.toml
* s/logpdf/logdensity/
* update deps
* scratchpad
* version bump
* representative => rootmeasure
* update dependency versions
* reduce dependencies
* update deps
* scratchpad
* representative => rootmeasure
* update dependency versions
* reduce dependencies
* `as` methods for `xform`
* cleanup
* require latest MeasureTheory
* drop old distributions code
* drop old iid code
* drop extra space
* limit deps to three newest releases
* update dynamichmc
* add Aqua
* bump version
* Better `predict` method
* withmeasures(::ConditionalModel)
* update dependencies
* updating symbolics
* some updates to symbolics
* hmm example
* update MeasureBase bound
* better dispatch for `predict`
* drop redundant method
* remove whitespace
* bump version
1 parent cd93be7 commit 36bc868

5 files changed: +61 -12 lines
Project.toml
Lines changed: 3 additions & 3 deletions

@@ -1,7 +1,7 @@
 name = "Soss"
 uuid = "8ce77f84-9b61-11e8-39ff-d17a774bf41c"
 author = ["Chad Scherrer <[email protected]>"]
-version = "0.20.8"
+version = "0.20.9"
 
 [deps]
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -49,7 +49,7 @@ JuliaVariables = "0.2"
 MLStyle = "0.3,0.4"
 MacroTools = "0.5"
 MappedArrays = "0.3, 0.4"
-MeasureBase = "0.4"
+MeasureBase = "0.5"
 MeasureTheory = "0.13"
 NamedTupleTools = "0.12, 0.13"
 NestedTuples = "0.3"
@@ -65,7 +65,7 @@ SpecialFunctions = "0.9, 0.10, 1"
 StatsBase = "0.33"
 StatsFuns = "0.9"
 SymbolicCodegen = "0.2"
-SymbolicUtils = "0.14, 0.15, 0.16"
+SymbolicUtils = "0.15, 0.16, 0.17"
 TransformVariables = "0.4"
 TupleVectors = "0.1"
 julia = "1.5"

scratchpad/hmm.jl
Lines changed: 48 additions & 0 deletions (new file)

@@ -0,0 +1,48 @@
+using MeasureTheory
+using Base.Iterators
+using Statistics
+
+using Random
+rng = Random.Xoshiro(3)
+
+x = Chain(Normal()) do xj Normal(μ=xj) end
+xobs = rand(rng, x)
+y = For(xobs) do xj Poisson(logλ=xj) end
+yobs = rand(rng, y)
+xv = take(xobs, 10) |> collect
+yv = take(yobs, 10) |> collect
+
+take(xobs.parent, 10) |> collect
+take(yobs.parent, 10) |> collect
+
+
+exp.(xv)
+yv
+
+# using Plots
+
+# plt = scatter(normcdf.(xv, 1, yv), label=false)
+
+# for j in 1:10
+#     xobs = rand(rng, x)
+#     yobs = rand(rng, y)
+#     xv = take(xobs, 100) |> collect;
+#     yv = take(yobs, 100) |> collect;
+#     plt = scatter!(plt, normcdf.(xv, 1, yv), label=false)
+# end
+# plt
+
+using Soss
+
+m = @model begin
+    x ~ Chain(Normal()) do xj Normal(μ=xj) end
+    y ~ For(xobs) do xj Poisson(logλ=xj) end
+end
+
+truth = rand(rng, m())
+
+xobs = take(truth.x, 10) |> collect
+yobs = take(truth.y, 10) |> collect
+
+logdensity(m(), (x=xobs, y=yobs))
+

src/symbolic/codegen.jl
Lines changed: 1 addition & 0 deletions

@@ -32,6 +32,7 @@ function SymbolicCodegen.codegen(cm :: ConditionalModel; kwargs...)
         pushfirst!(code.args, :($v = getproperty(_pars, $vname)))
     end
 
+    code = MacroTools.flatten(code)
 
     return mk_function(getmodule(cm), (:_args, :_data, :_pars), (), code)

src/symbolic/symbolic.jl
Lines changed: 4 additions & 4 deletions

@@ -19,17 +19,17 @@ export schema
 
 export symlogdensity
 
-symlogdensity(d, x::Symbolic) = logdensity(d,x)
+symlogdensity(d, x::Symbolic) = logpdf(d,x)
 
-function symlogdensity(d::ProductMeasure{<:AbstractArray}, x::Symbolic{A}) where {A <: AbstractArray}
+function symlogdensity(d::ProductMeasure{F,S,<:AbstractArray}, x::Symbolic{A}) where {F,S,A <: AbstractArray}
     dims = size(d)
 
     iters = Sym{Int}.(gensym.(Symbol.(:i, 1:length(dims))))
 
-    marginals = d.data
+    mar = marginals(d)
 
     # To begin, the result is just the summand
-    result = getsummand(marginals, x, iters)
+    result = getsummand(mar, x, iters)
 
     # Then we wrap in a summation index for each dimension
     for i in 1:length(dims)

src/transforms/predict.jl
Lines changed: 5 additions & 5 deletions

@@ -1,10 +1,13 @@
 export predict
 using TupleVectors
+using SampleChains
 
-predict(m::AbstractModel, args...) = predict(Random.GLOBAL_RNG, m, args...)
-predict(d::AbstractMeasure, x) = x
+predict(args...; kwargs...) = predict(Random.GLOBAL_RNG, args...; kwargs...)
 
+# TODO: Fix this hack
+predict(d::AbstractMeasure, x) = x
 predict(d::Dists.Distribution, x) = x
+predict(d::AbstractModel, args...; kwargs...) = predict(Random.GLOBAL_RNG, d, args...; kwargs...)
 
 @inline function predict(rng::AbstractRNG, m::AbstractModel, nt::NamedTuple{N}) where {N}
     pred = predictive(Model(m), N...)
@@ -13,7 +16,6 @@ end
 
 predict(rng::AbstractRNG, m::AbstractModel; kwargs...) = predict(rng, m, (;kwargs...))
 
-
 @inline function predict(rng::AbstractRNG, d::AbstractModel, nt::LazyMerge)
     predict(rng, d, convert(NamedTuple, nt))
 end
@@ -33,8 +35,6 @@ function predict(rng::AbstractRNG, d::AbstractModel, post::AbstractVector{<:NamedTuple})
     v
 end
 
-using SampleChains
-
 function predict(rng::AbstractRNG, d::ConditionalModel, post::MultiChain)
     [predict(rng, d, c) for c in getchains(post)]
 end
