
Conversation

@mhauru (Member) commented Oct 20, 2025

A sketch of what using locks to implement a thread-safe VarInfo might look like. The idea is summarised by this:

struct LockingVarInfo{T<:AbstractVarInfo} <: AbstractVarInfo
    inner::Ref{T}
    lock::ReentrantLock
end

# Convenience constructor: wrap an existing VarInfo with a fresh lock.
LockingVarInfo(vi::AbstractVarInfo) = LockingVarInfo(Ref(vi), ReentrantLock())

# Reads just pipe through to the wrapped VarInfo, without taking the lock.
function Base.getindex(vi::LockingVarInfo, vn::VarName, dist::Distribution)
    return getindex(vi.inner[], vn, dist)
end

# Writes replace the wrapped VarInfo while holding the lock.
function BangBang.setindex!!(vi::LockingVarInfo, vals, vn::VarName)
    @lock vi.lock begin
        vi.inner[] = BangBang.setindex!!(vi.inner[], vals, vn)
    end
    return vi
end

# Non-mutating transformations wrap their result in a new LockingVarInfo.
function link(t::AbstractTransformation, vi::LockingVarInfo, args...)
    return LockingVarInfo(link(t, vi.inner[], args...))
end

In other words, all operations that may modify the VarInfo update `vi.inner` while holding `vi.lock`; everything else just pipes through to `vi.inner[]`.

I have not thought this through carefully, and this certainly doesn't provide proper thread-safety. However, it's good enough to run the following little benchmark:

module Bench

using BenchmarkTools
using Distributions
using DynamicPPL
using LinearAlgebra
using Random

Random.seed!(23)

@model function f(dim)
    x ~ MvNormal(fill(0.0, dim), I)
    y = Vector{Float64}(undef, dim)
    # Observe each coordinate in parallel; this is what stresses the VarInfo.
    Threads.@threads for i in eachindex(x)
        y[i] ~ Normal(x[i])
    end
end

m = f(1_000) | (; y=randn(1_000))

vi = VarInfo(m)
tvi = DynamicPPL.ThreadSafeVarInfo(vi)
lvi = DynamicPPL.LockingVarInfo(vi)
@show DynamicPPL.logjoint(m, tvi)
@show DynamicPPL.logjoint(m, lvi)
@btime DynamicPPL.evaluate_threadunsafe!!(m, tvi)
@btime DynamicPPL.evaluate_threadunsafe!!(m, lvi)

end

Results on 1 thread:

DynamicPPL.logjoint(m, tvi) = -3377.744304597237
DynamicPPL.logjoint(m, lvi) = -3377.744304597237
  383.042 μs (9509 allocations: 353.56 KiB)
  1.049 ms (20531 allocations: 682.55 KiB)

Results on 8 threads (on a 10-core laptop):

DynamicPPL.logjoint(m, tvi) = -3377.744304597236
DynamicPPL.logjoint(m, lvi) = -3377.744304597236
  95.500 μs (9544 allocations: 357.34 KiB)
  1.923 ms (20566 allocations: 686.33 KiB)

In sum: this is bad. Either I've done something very suboptimal, or locks have large overheads.

cc @yebai, @penelopeysm

penelopeysm and others added 29 commits August 8, 2025 11:21
* Implement InitContext

* Fix loading order of modules; move `prefix(::Model)` to model.jl

* Add tests for InitContext behaviour

* inline `rand(::Distributions.Uniform)`

Note that, apart from being simpler code, Distributions.Uniform also
doesn't allow the lower and upper bounds to be exactly equal (but we
might like to keep that option open in DynamicPPL, e.g. if the user
wants to initialise all values to the same value in linked space).

* Document

* Add a test to check that `init!!` doesn't change linking

* Fix `push!` for VarNamedVector

This should have been changed in #940, but slipped through as the file
wasn't listed as one of the changed files.

* Add some line breaks

Co-authored-by: Markus Hauru <markus@mhauru.org>

* Add the option of no fallback for ParamsInit

* Improve docstrings

* typo

* `p.default` -> `p.fallback`

* Rename `{Prior,Uniform,Params}Init` -> `InitFrom{Prior,Uniform,Params}`


Co-authored-by: Markus Hauru <markus@mhauru.org>
* use `varname_leaves` from AbstractPPL instead

* add changelog entry

* fix import
…!`, `predict`, `returned`, and `initialize_values` (#984)

* Replace `evaluate_and_sample!!` -> `init!!`

* Use `ParamsInit` for `predict`; remove `setval_and_resample!` and friends

* Use `init!!` for initialisation

* Paper over the `Sampling->Init` context stack (pending removal of SamplingContext)

* Remove SamplingContext from JETExt to avoid triggering `Sampling->Init` pathway

* Remove `predict` on vector of VarInfo

* Fix some tests

* Remove duplicated test

* Simplify context testing

* Rename FooInit -> InitFromFoo

* Fix JETExt

* Fix JETExt properly

* Fix tests

* Improve comments

* Remove duplicated tests

* Docstring improvements

Co-authored-by: Markus Hauru <markus@mhauru.org>

* Concretise `chain_sample_to_varname_dict` using chain value type

* Clarify testset name

* Re-add comment that shouldn't have vanished

* Fix stale Requires dep

* Fix default_varinfo/initialisation for odd models

* Add comment to src/sampler.jl

Co-authored-by: Markus Hauru <markus@mhauru.org>


Co-authored-by: Markus Hauru <markus@mhauru.org>
…niform}`, `{tilde_,}assume` (#985)

* Remove `SamplingContext` for good

* Remove `tilde_assume` as well

* Split up tilde_observe!! for Distribution / Submodel

* Tidy up tilde-pipeline methods and docstrings

* Fix tests

* fix ambiguity

* Add changelog

* Update HISTORY.md

Co-authored-by: Markus Hauru <markus@mhauru.org>


Co-authored-by: Markus Hauru <markus@mhauru.org>
* Delete del

* Fix a typo

* Add HISTORY entry about del
* setleafcontext(model, ctx) and various other fixes

* fix a bug

* Add warning for `initial_parameters=...`
* Remove resume_from

* Format

* Fix test
* Enable NamedTuple/Dict initialisation

* Add more tests
* Fix `include_all` for predict

* Fix include_all for predict, some perf improvements
* Replace Metadata.flags with Metadata.trans

* Fix a bug

* Fix a typo

* Fix two bugs

* Rename trans to is_transformed

* Rename islinked to is_transformed, remove duplication
* Change pointwise_logdensities default key type to VarName

* Fix a doctest
@github-actions (Contributor) commented

DynamicPPL.jl documentation for PR #1078 is available at:
https://TuringLang.github.io/DynamicPPL.jl/previews/PR1078/

@penelopeysm (Member) commented Oct 20, 2025

This is very interesting. Here are my completely unprompted thoughts:

  1. I've never used ReentrantLock before, so I'm not sure whether the drop in performance is due to this specific method of concurrency, or whether it's a concurrency thing in general. The past times I've worked with concurrency, it's always been via a Channel (https://docs.julialang.org/en/v1/base/parallel/#Base.Channel). You can queue (or buffer) up to N things that you want to do, and then there's a separate thread solely dedicated to doing those things one by one; see the sketch after this list. My suspicion is that the lock is completely unbuffered, and I wonder if changing the buffer size makes a difference. (My recollection from working on AbstractMCMC parallel progress bars is that it makes a difference, although unfortunately I don't remember whether it makes things better or worse!)

  2. I did not implement this, so I'm not saying that I think this is the right way to do things or that you should change it. But in my mind, when I was thinking about these things, I was thinking that the lock should be at the higher level of tilde_obssume. That means that instead of queuing individual getindex / setindex operations, you queue entire tilde_obssume calls. The assumption, of course, is that everything within a single tilde_obssume is thread-safe. I don't know if that is better or worse. It could quite conceivably be worse.
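
For what it's worth, here is a minimal sketch of the buffered-Channel pattern from point 1 (all names here are illustrative, not anything in DynamicPPL): a dedicated consumer task drains the channel one item at a time, and the buffer size caps how many items producers can queue before put! blocks.

ch = Channel{Int}(32)  # buffered: up to 32 queued items before put! blocks

consumer = Threads.@spawn for item in ch
    # process items one at a time on a dedicated task
    println("processing ", item)
end

for i in 1:100
    put!(ch, i)  # only blocks when the buffer is full
end
close(ch)        # lets the consumer's loop over the channel terminate
wait(consumer)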

@codecov (bot) commented Oct 20, 2025

Codecov Report

❌ Patch coverage is 0% with 131 lines in your changes missing coverage. Please review.
✅ Project coverage is 78.06%. Comparing base (9bd8f16) to head (8fcfaf4).
⚠️ Report is 2 commits behind head on breaking.

Files with missing lines | Patch % | Lines
src/lockingvarinfo.jl    | 0.00%   | 131 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff              @@
##           breaking    #1078      +/-   ##
============================================
- Coverage     82.40%   78.06%   -4.34%     
============================================
  Files            42       41       -1     
  Lines          3791     3848      +57     
============================================
- Hits           3124     3004     -120     
- Misses          667      844     +177     


@penelopeysm (Member) commented

hold on, why is CI failing........ 😱

@mhauru (Member, Author) commented Oct 20, 2025

A local experiment shows me that about 80% of the time in the 8-core case is spent in the `@lock vi.lock` in `map_accumulators!!`, which is what `accumulate_obssume!!` calls. (In case you were wondering, like me, whether the problem was with the `Ref`s or somewhere unexpected.)
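
For context, the locking in question presumably mirrors the setindex!! pattern from the PR description, something like the sketch below; the exact signature of map_accumulators!! here is an assumption.

function map_accumulators!!(func, vi::LockingVarInfo)
    # Every accumulator update swaps out the wrapped VarInfo while holding
    # the lock, so concurrent tilde-statements serialise at this point.
    @lock vi.lock begin
        vi.inner[] = map_accumulators!!(func, vi.inner[])
    end
    return vi
end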

We should try locking at a higher level, like at the tilde_obssume level. The nice thing about being lower down is that you can do the logpdf calls in parallel and only drop to serial execution when accumulating the results in accumulate_observe!!. We could also try locking at a lower level, namely at the level of individual accumulators. I have no good intuition for whether these would be better or worse.

I've never used Channels, and it's not obvious to me how to use them here. I guess you would queue something like accumulate_observe!! calls? We would need to think about how to handle the non-mutating nature of many accumulators.

@yebai (Member) commented Oct 20, 2025

A quick thought: we could introduce locks at the model level. This would prevent users from using threading, except through an explicit opt-in mechanism such as a hypothetical x ~ threadsafe_product_distribution(...).

@penelopeysm (Member) commented

The nice thing with being lower down is that you can do the logpdf calls in parallel and only drop to serial execution when accumulating the results in accumulate_observe!!

Oh, yes, that's completely correct. I didn't think of that. I guess that means even the current implementation here is too high-level, right? Because the logpdf calculation is actually done by the accumulator itself, deep inside map_accumulator.

And yeah, channels are pretty ugly to work with. The general structure of channel-based code looks like this:

# Base's @sync with Threads.@spawn, so the consumer runs on its own thread
# (Distributed is the multi-process module and isn't what we want here).
@sync begin
    Threads.@spawn begin
        # consumer: read from the channel and process items one at a time
    end
    Threads.@spawn begin
        # producer: here do whatever you like that writes to the channel
    end
end

I think the whole thing would get lumped inside _evaluate!!: the first block we would have to write ourselves, and the second one would basically be model evaluation.

Because the Channel must contain concrete objects (not function calls), we would probably need a struct that contains all the arguments to accumulate_obssume!! plus a boolean to indicate whether it's an assume or an observe.
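
Something like the following, perhaps (entirely hypothetical names, just to illustrate the shape of the work item):

struct TildeWork{D<:Distribution,T}
    vn::Union{VarName,Nothing}  # variable name; nothing for literal observes
    dist::D                     # distribution on the RHS of the tilde
    value::T                    # the assumed / observed value
    is_assume::Bool             # true for assume, false for observe
end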

@mhauru (Member, Author) commented Oct 20, 2025

I guess that means even the current implementation here is too high-level, right? Because the logpdf calculation is actually done by the accumulator itself, deep inside map_accumulator.

Oh yes, silly me.

Another option could be to not even try to support threaded tilde_assume!!, but instead replace ThreadSafeVarInfo with accumulators that use Threads.Atomic. This could be really simple and easy for the LogProbAccumulators.
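
For the log-probability accumulators, that could be as simple as the following sketch (a hypothetical type, not existing DynamicPPL API):

struct AtomicLogPriorAccumulator
    # The running log-probability lives in a Threads.Atomic, so concurrent
    # additions from different threads are safe without any lock.
    logp::Threads.Atomic{Float64}
end
AtomicLogPriorAccumulator() = AtomicLogPriorAccumulator(Threads.Atomic{Float64}(0.0))

add_logp!(acc::AtomicLogPriorAccumulator, logp::Real) =
    Threads.atomic_add!(acc.logp, Float64(logp))  # lock-free atomic addition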

@penelopeysm (Member) commented

I like the sound of AtomicAccumulator. In a way that is better, because we want things like @addlogprob! to also be thread-safe, not just tildes. I think thread-safe distributions are also a good idea, but at the moment I don't know how to generalise them. They would work very nicely for the i.i.d. case, though.

@mhauru (Member, Author) commented Oct 20, 2025

I think the problem with atomic accumulators is that if we take that road, all accumulators must be made thread-safe. I can easily wrap the LogPriorAccumulator.logp field in a Threads.Atomic and be very happy, but then I'm also committed to making DebugAccumulator thread-safe, and that doesn't sound fun.

A possible way out would be to have an abstract type ThreadSafeAccumulator <: AbstractAccumulator, where anything that is not of that type falls back onto using locks. I don't care if I pay an overhead for using locks on something like PointwiseLogdensityAccumulator. I'm not sure where to put this fallback locking code, though, in such a way that it doesn't complicate, or add overhead to, the accumulate_obssume!! call stack.
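
Roughly, the dispatch could look like this sketch (accumulate_safely!! and update!! are made-up names for illustration; update!! stands in for whatever accumulate_assume!!/accumulate_observe!! actually do):

abstract type ThreadSafeAccumulator <: AbstractAccumulator end

# Accumulators that declare themselves thread-safe are updated directly...
accumulate_safely!!(acc::ThreadSafeAccumulator, lock, args...) = update!!(acc, args...)

# ...while everything else falls back to serialising through the lock.
accumulate_safely!!(acc::AbstractAccumulator, lock, args...) =
    @lock lock update!!(acc, args...)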

Base automatically changed from breaking to main October 21, 2025 17:06