Skip to content

Conversation

@mhauru
Copy link
Member

@mhauru mhauru commented Dec 18, 2025

I put together a quick sketch of what it would look like to use VarNamedTuple as a VarInfo directly. By that I mean having a VarInfo type that is nothing but accumulators plus a VarNamedTuple that maps each VarName to a tuple (or actually a tiny struct, but anyway) of three values: Stored value for this variable, whether it's linked, and what transform should be applied to convert the stored value back to "model space". I'm calling this new VarInfo type VNTVarInfo (name to be changed later).

This isn't finished yet, but the majority of tests pass. There are a lot of failures around edge cases like Cholesky and weird VarNames and such, but for most simple models you can do

vi = VNTVarInfo(model)
vi = link!!(vi, model)
evaluate!!(model, vi)

and it'll give you the correct result. unflatten and vi[:] also work.

I'll keep working on this, but at this point I wanted to pause to do some benchmarks, see how viable this is. Benchmark code, very similar to #1182, running evaluate!! on our benchmarking models:

Details
module VIBench

using DynamicPPL, Distributions, Chairmarks
using StableRNGs: StableRNG
include("benchmarks/src/Models.jl")
using .Models: Models

function run()
    rng = StableRNG(23)

    smorgasbord_instance = Models.smorgasbord(randn(rng, 100), randn(rng, 100))

    loop_univariate1k, multivariate1k = begin
        data_1k = randn(rng, 1_000)
        loop = Models.loop_univariate(length(data_1k)) | (; o=data_1k)
        multi = Models.multivariate(length(data_1k)) | (; o=data_1k)
        loop, multi
    end

    loop_univariate10k, multivariate10k = begin
        data_10k = randn(rng, 10_000)
        loop = Models.loop_univariate(length(data_10k)) | (; o=data_10k)
        multi = Models.multivariate(length(data_10k)) | (; o=data_10k)
        loop, multi
    end

    # lda_instance = begin
    #     w = [1, 2, 3, 2, 1, 1]
    #     d = [1, 1, 1, 2, 2, 2]
    #     Models.lda(2, d, w)
    # end

    models = [
        ("simple_assume_observe", Models.simple_assume_observe(randn(rng))),
        ("smorgasbord", smorgasbord_instance),
        ("loop_univariate1k", loop_univariate1k),
        ("multivariate1k", multivariate1k),
        ("loop_univariate10k", loop_univariate10k),
        ("multivariate10k", multivariate10k),
        ("dynamic", Models.dynamic()),
        ("parent", Models.parent(randn(rng))),
        # ("lda", lda_instance),
    ]

    function print_diff(r, ref)
        diff = r.time - ref.time
        units = if diff < 1e-6
            "ns"
        elseif diff < 1e-3
            "µs"
        else
            "ms"
        end
        diff = if units == "ns"
            round(diff / 1e-9; digits=1)
        elseif units == "µs"
            round(diff / 1e-6; digits=1)
        else
            round(diff / 1e-3; digits=1)
        end
        sign = diff < 0 ? "" : "+"
        return println(" ($(sign)$(diff) $units)")
    end

    new = isdefined(DynamicPPL, :(VNTVarInfo))
    prefix = new ? "New" : "Old"

    for (name, m) in models
        println()
        println(name)
        vi = VarInfo(StableRNG(23), m)
        vi_linked = link!!(deepcopy(vi), m)
        # logp = getlogjoint(last(DynamicPPL.evaluate!!(m, vi)))
        # logp_linked = getlogjoint(last(DynamicPPL.evaluate!!(m, vi_linked)))
        # @show logp
        # @show logp_linked
        res = @b DynamicPPL.evaluate!!($m, $vi)
        print("$prefix unlinked: ")
        display(res)
        res = @b DynamicPPL.evaluate!!($m, $vi_linked)
        print("$prefix linked:   ")
        display(res)

        if !isdefined(DynamicPPL, :(VNTVarInfo))
            svi_nt = SimpleVarInfo(vi, NamedTuple)
            try
                res = @b DynamicPPL.evaluate!!($m, $svi_nt)
            catch e
                res = missing
            end
            print("SVI NT:       ")
            display(res)
            svi_od = SimpleVarInfo(vi, OrderedDict)
            res = @b DynamicPPL.evaluate!!($m, $svi_od)
            print("SVI OD:       ")
            display(res)
        end
    end
end

run()

end

Results contrasting the new VarInfo with both the old VarInfo and with SimpleVarInfo{NamedTuple} and SimpleVarInfo{OrderedDict}. Some SVI NT results are missing because it couldn't handle the IndexLenses:

simple_assume_observe
New unlinked: 2.778 ns
New linked:   12.201 ns
Old unlinked: 91.414 ns (4 allocs: 128 bytes)
Old linked:   80.752 ns (4 allocs: 128 bytes)
SVI NT:       2.468 ns
SVI OD:       4.941 ns

smorgasbord
New unlinked: 5.375 μs (12 allocs: 6.156 KiB)
New linked:   6.146 μs (18 allocs: 8.750 KiB)
Old unlinked: 16.375 μs (420 allocs: 33.375 KiB)
Old linked:   13.354 μs (325 allocs: 18.609 KiB)
SVI NT:       missing
SVI OD:       357.333 μs (3514 allocs: 98.891 KiB)

loop_univariate1k
New unlinked: 10.625 μs (6 allocs: 16.125 KiB)
New linked:   12.250 μs (6 allocs: 16.125 KiB)
Old unlinked: 64.542 μs (2009 allocs: 86.688 KiB)
Old linked:   58.625 μs (2009 allocs: 86.688 KiB)
SVI NT:       missing
SVI OD:       7.444 μs (6 allocs: 16.125 KiB)

multivariate1k
New unlinked: 11.125 μs (24 allocs: 80.500 KiB)
New linked:   11.250 μs (24 allocs: 80.500 KiB)
Old unlinked: 11.209 μs (29 allocs: 88.625 KiB)
Old linked:   11.208 μs (29 allocs: 88.625 KiB)
SVI NT:       10.708 μs (24 allocs: 80.500 KiB)
SVI OD:       10.833 μs (24 allocs: 80.500 KiB)

loop_univariate10k
New unlinked: 104.750 μs (6 allocs: 192.125 KiB)
New linked:   142.583 μs (6 allocs: 192.125 KiB)
Old unlinked: 752.542 μs (20009 allocs: 913.188 KiB)
Old linked:   614.750 μs (20009 allocs: 913.188 KiB)
SVI NT:       missing
SVI OD:       155.625 μs (6 allocs: 192.125 KiB)

multivariate10k
New unlinked: 107.500 μs (24 allocs: 896.500 KiB)
New linked:   106.459 μs (24 allocs: 896.500 KiB)
Old unlinked: 112.833 μs (29 allocs: 992.625 KiB)
Old linked:   110.500 μs (29 allocs: 992.625 KiB)
SVI NT:       106.000 μs (24 allocs: 896.500 KiB)
SVI OD:       110.292 μs (24 allocs: 896.500 KiB)

dynamic
New unlinked: 1.109 μs (12 allocs: 672 bytes)
New linked:   2.149 μs (43 allocs: 2.406 KiB)
Old unlinked: 1.854 μs (27 allocs: 1.891 KiB)
Old linked:   3.023 μs (53 allocs: 2.922 KiB)
SVI NT:       1.035 μs (12 allocs: 672 bytes)
SVI OD:       6.927 μs (75 allocs: 2.953 KiB)

parent
New unlinked: 2.777 ns
New linked:   10.967 ns
Old unlinked: 113.683 ns (6 allocs: 192 bytes)
Old linked:   106.579 ns (6 allocs: 192 bytes)
SVI NT:       missing
SVI OD:       4.948 ns

I think a fair TL;DR is that for both small models and models with IndexLenses this is many times faster than the old VarInfo, and not far off from SimpleVarInfo when SimpleVarInfo is at its fastest (NamedTuples for small models, OrderedDicts for IndexLenses). I would still like to close that gap a bit, I don't know why linking causes such a large slowdown in some cases, I suspect it's because the transform system is geared towards assuming we want to vectorise things, and I've hacked this together quickly to just get it to work.

For large models performance is essentially equal, as it should be, because this is about overheads. To fix that, I need to look into using views in some clever way, but that's for later.

I think this is a promising start towards being able to say that all of VarInfo, SimpleVarInfo, and VarNamedVector could be replaced with a direct use of VarNamedTuple (as opposed to e.g. VNT wrapping VarNamedVector), and it would be pretty close to being a best-of-all-worlds solution, in that it's almost as fast as SVI and has full support for all models.

Note that the new VNTVarInfo has no notion of typed and untyped VarInfos. They are all as typed as they can be, which should also help simplify code.

I'll keep working on this tomorrow.

@github-actions
Copy link
Contributor

github-actions bot commented Dec 18, 2025

Benchmark Report

  • this PR's head: 9ae56ab9ccfff435357418ee69b7f4166a514b8d
  • base branch: c917266c3370f7729ad9a756a15e3deabac2cee2

Computer Information

Julia Version 1.11.8
Commit cf1da5e20e3 (2025-11-06 17:49 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

┌───────────────────────┬───────┬─────────────┬───────────────────┬────────┬──────────────────────────────┬────────────────────────────┬────────────────────────────────┐
│                       │       │             │                   │        │       t(eval) / t(ref)       │     t(grad) / t(eval)      │        t(grad) / t(ref)        │
│                       │       │             │                   │        │ ─────────┬─────────┬──────── │ ───────┬─────────┬──────── │ ──────────┬──────────┬──────── │
│                 Model │   Dim │  AD Backend │           VarInfo │ Linked │     base │ this PR │ speedup │   base │ this PR │ speedup │      base │  this PR │ speedup │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼──────────┼─────────┼─────────┼────────┼─────────┼─────────┼───────────┼──────────┼─────────┤
│               Dynamic │    10 │    mooncake │             typed │   true │   367.61 │     err │     err │   9.69 │     err │     err │   3563.01 │      err │     err │
│                   LDA │    12 │ reversediff │             typed │   true │  2642.28 │     err │     err │   5.11 │     err │     err │  13492.44 │      err │     err │
│   Loop univariate 10k │ 10000 │    mooncake │             typed │   true │ 58102.59 │     err │     err │   6.37 │     err │     err │ 369986.61 │      err │     err │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼──────────┼─────────┼─────────┼────────┼─────────┼─────────┼───────────┼──────────┼─────────┤
│    Loop univariate 1k │  1000 │    mooncake │             typed │   true │  5815.97 │     err │     err │   5.83 │     err │     err │  33895.00 │      err │     err │
│      Multivariate 10k │ 10000 │    mooncake │             typed │   true │ 34741.61 │     err │     err │   9.59 │     err │     err │ 333334.06 │      err │     err │
│       Multivariate 1k │  1000 │    mooncake │             typed │   true │  3735.01 │     err │     err │   9.00 │     err │     err │  33627.80 │      err │     err │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼──────────┼─────────┼─────────┼────────┼─────────┼─────────┼───────────┼──────────┼─────────┤
│ Simple assume observe │     0 │ forwarddiff │             typed │  false │      err │    2.35 │     err │    err │    3.81 │     err │       err │     8.97 │     err │
│ Simple assume observe │     1 │ forwarddiff │             typed │  false │     2.64 │     err │     err │   4.20 │     err │     err │     11.07 │      err │     err │
│           Smorgasbord │   201 │ forwarddiff │             typed │  false │  1104.41 │     err │     err │ 134.56 │     err │     err │ 148604.46 │      err │     err │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼──────────┼─────────┼─────────┼────────┼─────────┼─────────┼───────────┼──────────┼─────────┤
│           Smorgasbord │     0 │ forwarddiff │             typed │  false │      err │  967.78 │     err │    err │   66.31 │     err │       err │ 64169.86 │     err │
│           Smorgasbord │   201 │ forwarddiff │       simple_dict │   true │      err │     err │     err │    err │     err │     err │       err │      err │     err │
│           Smorgasbord │   201 │ forwarddiff │ simple_namedtuple │   true │      err │     err │     err │    err │     err │     err │       err │      err │     err │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼──────────┼─────────┼─────────┼────────┼─────────┼─────────┼───────────┼──────────┼─────────┤
│           Smorgasbord │   201 │      enzyme │             typed │   true │  1533.19 │     err │     err │   7.02 │     err │     err │  10769.60 │      err │     err │
│           Smorgasbord │   201 │    mooncake │             typed │   true │  1540.35 │     err │     err │   5.69 │     err │     err │   8761.17 │      err │     err │
│           Smorgasbord │   201 │ reversediff │             typed │   true │  1530.88 │     err │     err │  99.66 │     err │     err │ 152562.37 │      err │     err │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼──────────┼─────────┼─────────┼────────┼─────────┼─────────┼───────────┼──────────┼─────────┤
│           Smorgasbord │   201 │ forwarddiff │      typed_vector │   true │  1519.86 │     err │     err │  61.77 │     err │     err │  93884.22 │      err │     err │
│           Smorgasbord │   201 │ forwarddiff │           untyped │   true │  1523.01 │     err │     err │  65.75 │     err │     err │ 100133.88 │      err │     err │
│           Smorgasbord │   201 │ forwarddiff │    untyped_vector │   true │  1527.68 │     err │     err │  61.44 │     err │     err │  93866.20 │      err │     err │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼──────────┼─────────┼─────────┼────────┼─────────┼─────────┼───────────┼──────────┼─────────┤
│              Submodel │     1 │    mooncake │             typed │   true │     3.34 │     err │     err │  10.80 │     err │     err │     36.10 │      err │     err │
└───────────────────────┴───────┴─────────────┴───────────────────┴────────┴──────────┴─────────┴─────────┴────────┴─────────┴─────────┴───────────┴──────────┴─────────┘

@penelopeysm
Copy link
Member

penelopeysm commented Dec 19, 2025

Darn, that is really good.

tuple (or actually a tiny struct, but anyway) of three values: Stored value for this variable, whether it's linked, and what transform should be applied to convert the stored value back to "model space"

Am I right in saying the latter two are only really needed for DefaultContext?

Edit: Actually, that's a silly question, if not for DefaultContext we don't even need the metadata field at all.

@penelopeysm
Copy link
Member

Also, I'm just eyeing this PR and thinking that it's a prime opportunity to clean up the varinfo interface, especially with the functions that return internal values when they probably shouldn't.

@mhauru
Copy link
Member Author

mhauru commented Dec 19, 2025

Also, I'm just eyeing this PR and thinking that it's a prime opportunity to clean up the varinfo interface, especially with the functions that return internal values when they probably shouldn't.

Yes. I'm first trying to make this work without making huge interface changes, just to make sure this can do everything that is needed to do, but I think interface changes should follow close behind, maybe in the same PR or the same release. They'll be much easier to make once there is only two VarInfo types that need to respect them, namely the new one and Threadsafe.

@mhauru mhauru changed the title VarNamedTuple as VarInfo VNT Part 5: VarNamedTuple as VarInfo Dec 19, 2025
@yebai
Copy link
Member

yebai commented Dec 19, 2025

Looks exciting! Two quick quesitons: would this be suitable to

  1. Implement simulation based inference algorithms, eg, particle MCMC, where model dimentionality or parameters support could change
  2. First model run to bootstrap / infer a VarInfo?

@mhauru
Copy link
Member Author

mhauru commented Dec 19, 2025

  1. Yep. The only thing I foresee being a problem is if some variable turns from e.g. being a Vector to being a Matrix, and you do IndexLens indexing into it. So first you have x[1] and then x[1,1]. That would be a problem. Other than that, should be fine.
  2. Yes. You can use the same type, VNTVarInfo, for both the first run when collecting variables, and for later runs when evaluating with known variables. No need for the typed/untyped distinction.

One thing I haven't benchmarked, and maybe should, is type unstable models. There is a possibility that type unstable models will be slower with the new approach, because VNTVarInfo is pretty aggressive in trying to make element types concrete, and if it keeps trying and failing again and again, that could cost a lot of time. Or it might be a negligible contribution to the performance-disaster that is a type unstable model. Need to benchmark.

Base automatically changed from mhauru/vnt-for-vaimacc to mhauru/vnt-for-fastldf January 7, 2026 16:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants