From 54b34355533ad700fc46c863f024135a29751ae5 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 23 Dec 2025 01:41:39 +0000 Subject: [PATCH 01/60] Begin VarName rework --- src/AbstractPPL.jl | 1 + src/varname/optic.jl | 145 ++++++++++++++ src/varname/varname.jl | 432 ++++++++++------------------------------- 3 files changed, 251 insertions(+), 327 deletions(-) create mode 100644 src/varname/optic.jl diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index d7585784..3fb08f83 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -32,6 +32,7 @@ export AbstractModelTrace include("abstractmodeltrace.jl") include("abstractprobprog.jl") include("evaluate.jl") +include("varname/optic.jl") include("varname/varname.jl") include("varname/subsumes.jl") include("varname/hasvalue.jl") diff --git a/src/varname/optic.jl b/src/varname/optic.jl new file mode 100644 index 00000000..9227612a --- /dev/null +++ b/src/varname/optic.jl @@ -0,0 +1,145 @@ +using Accessors: Accessors + +""" + AbstractOptic + +An abstract type that represents the non-symbol part of a VarName, i.e., the section of the +variable that is of interest. For example, in `x.a[1][2]`, the `AbstractOptic` represents +the `.a[1][2]` part. + +# Interface + +This is WIP. + +- Base.show +- to_accessors(optic) -> Accessors.Lens (recovering the old representation) + +Not sure if we want to introduce getters and setters and BangBang-style stuff. +""" +abstract type AbstractOptic end +function Base.show(io::IO, optic::AbstractOptic) + print(io, "Optic(") + pretty_print_optic(io, optic) + return print(io, ")") +end + +""" + Iden() + +The identity optic. This is the optic used when we are referring to the entire variable. +It is also the base case for composing optics. +""" +struct Iden <: AbstractOptic end +pretty_print_optic(::IO, ::Iden) = nothing +to_accessors(::Iden) = identity +concretize(i::Iden, ::Any) = i + +""" + DynamicIndex + +An abstract type representing dynamic indices such as `begin`, `end`, and `:`. These indices +are things which cannot be resolved until we provide the value that is being indexed into. +When parsing VarNames, we convert such indices into subtypes of `DynamicIndex`, and we later +mark them as requiring concretisation. +""" +abstract type DynamicIndex end +# Fallback for all other indices +concretize(@nospecialize(ix::Any), ::Any, ::Any) = ix + +struct DynamicBegin <: DynamicIndex end +concretize(::DynamicBegin, val, dim::Nothing) = Base.firstindex(val) +concretize(::DynamicBegin, val, dim) = Base.firstindex(val, dim) + +struct DynamicEnd <: DynamicIndex end +concretize(::DynamicEnd, val, dim::Nothing) = Base.lastindex(val) +concretize(::DynamicEnd, val, dim) = Base.lastindex(val, dim) + +struct DynamicColon <: DynamicIndex end +concretize(::DynamicColon, val, dim::Nothing) = Base.firstindex(val):Base.lastindex(val) +concretize(::DynamicColon, val, dim) = Base.firstindex(val, dim):Base.lastindex(val, dim) + +struct DynamicRange{T1,T2} <: DynamicIndex + start::T1 + stop::T2 +end +function concretize(dr::DynamicRange, axis) + start = dr.start isa DynamicIndex ? concretize(dr.start, axis) : dr.start + stop = dr.stop isa DynamicIndex ? concretize(dr.stop, axis) : dr.stop + return start:stop +end + +""" + Index(ix, child=Iden()) + +An indexing optic representing access to indices `ix`. A VarName{:x} with this optic +represents access to `x[ix...]`. The child optic represents any further indexing or +property access after this indexing operation. +""" +struct Index{I<:Tuple,C<:AbstractOptic} <: AbstractOptic + ix::I + child::C +end +Index(ix::Tuple, child::C=Iden()) where {C<:AbstractOptic} = Index{typeof(ix),C}(ix, child) + +Base.:(==)(a::Index, b::Index) = a.ix == b.ix && a.child == b.child +Base.isequal(a::Index, b::Index) = a == b +function pretty_print_optic(io::IO, idx::Index) + ixs = join(idx.ix, ", ") + print(io, "[$(ixs)]") + return pretty_print_optic(io, idx.child) +end +function to_accessors(idx::Index) + ilens = Accessors.IndexLens(idx.ix) + return if idx.child isa Iden + ilens + else + Base.ComposedFunction(to_accessors(idx.child), ilens) + end +end +function concretize(idx::Index, val) + concretized_indices = if length(idx.ix) == 0 + [] + elseif length(idx.ix) == 1 + # If there's only one index, it's linear indexing. This code is mostly lifted from + # Accessors.jl. + [concretize(only(idx.ix), val, nothing)] + else + # If there are multiple indices, then each index corresponds to a different + # dimension. + [concretize(ix, val, dim) for (dim, ix) in enumerate(idx.ix)] + end + inner_concretized = concretize(idx.child, val[concretized_indices...]) + return Index((concretized_indices...,), inner_concretized) +end + +""" + Property{sym}(child=Iden()) + +A property access optic representing access to property `sym`. A VarName{:x} with this +optic represents access to `x.sym`. The child optic represents any further indexing +or property access after this property access operation. +""" +struct Property{sym,C<:AbstractOptic} <: AbstractOptic + child::C +end +Property{sym}(child::C=Iden()) where {sym,C<:AbstractOptic} = Property{sym,C}(child) + +Base.:(==)(a::Property{sym}, b::Property{sym}) where {sym} = a.child == b.child +Base.:(==)(a::Property, b::Property) = false +Base.isequal(a::Property, b::Property) = a == b +function pretty_print_optic(io::IO, prop::Property{sym}) where {sym} + print(io, ".$(sym)") + return pretty_print_optic(io, prop.child) +end +function to_accessors(prop::Property{sym}) where {sym} + plens = Accessors.PropertyLens{sym}() + return if prop.child isa Iden + plens + else + Base.ComposedFunction(to_accessors(prop.child), plens) + end +end +function concretize(prop::Property{sym}, val) where {sym} + inner_concretized = concretize(prop.child, getproperty(val, sym)) + return Property{sym}(inner_concretized) +end diff --git a/src/varname/varname.jl b/src/varname/varname.jl index c2916de6..057b1d41 100644 --- a/src/varname/varname.jl +++ b/src/varname/varname.jl @@ -1,122 +1,20 @@ -using Accessors -using Accessors: PropertyLens, IndexLens, DynamicIndexLens - -# nb. ComposedFunction is the same as Accessors.ComposedOptic -const ALLOWED_OPTICS = Union{typeof(identity),PropertyLens,IndexLens,ComposedFunction} - """ VarName{sym}(optic=identity) -A variable identifier for a symbol `sym` and optic `optic`. - -The Julia variable in the model corresponding to `sym` can refer to a single value or to a -hierarchical array structure of univariate, multivariate or matrix variables. The field `lens` -stores the indices requires to access the random variable from the Julia variable indicated by `sym` -as a tuple of tuples. Each element of the tuple thereby contains the indices of one optic -operation. - -`VarName`s can be manually constructed using the `VarName{sym}(optic)` constructor, or from an -optic expression through the [`@varname`](@ref) convenience macro. - -# Examples - -```jldoctest; setup=:(using Accessors) -julia> vn = VarName{:x}(Accessors.IndexLens((Colon(), 1)) ⨟ Accessors.IndexLens((2, ))) -x[:, 1][2] - -julia> getoptic(vn) -(@o _[Colon(), 1][2]) +A variable identifier for a symbol `sym` and optic `optic`. `sym` refers to the name of the +top-level Julia variable, while `optic` allows one to specify a particular property or index +inside that variable. -julia> @varname x[:, 1][1+1] -x[:, 1][2] -``` +`VarName`s can be manually constructed using the `VarName{sym}(optic)` constructor, or from +an optic expression through the [`@varname`](@ref) convenience macro. """ -struct VarName{sym,T<:ALLOWED_OPTICS} +struct VarName{sym,T<:AbstractOptic} optic::T - - function VarName{sym}(optic=identity) where {sym} - optic = normalise(optic) - if !is_static_optic(typeof(optic)) - throw( - ArgumentError( - "attempted to construct `VarName` with unsupported optic of type $(nameof(typeof(optic)))", - ), - ) - end + function VarName{sym}(optic=Iden()) where {sym} return new{sym,typeof(optic)}(optic) end end -""" - is_static_optic(l) - -Return `true` if `l` is one or a composition of `identity`, `PropertyLens`, and `IndexLens`; `false` if `l` is -one or a composition of `DynamicIndexLens`; and undefined otherwise. -""" -is_static_optic(::Type{<:Union{typeof(identity),PropertyLens,IndexLens}}) = true -function is_static_optic(::Type{ComposedFunction{LO,LI}}) where {LO,LI} - return is_static_optic(LO) && is_static_optic(LI) -end -is_static_optic(::Type{<:DynamicIndexLens}) = false - -""" - normalise(optic) - -Enforce that compositions of optics are always nested in the same way, in that -a ComposedFunction never has a ComposedFunction as its inner lens. Thus, for -example, - -```jldoctest; setup=:(using Accessors) -julia> op1 = ((@o _.c) ∘ (@o _.b)) ∘ (@o _.a) -(@o _.a.b.c) - -julia> op2 = (@o _.c) ∘ ((@o _.b) ∘ (@o _.a)) -(@o _.c) ∘ ((@o _.a.b)) - -julia> op1 == op2 -false - -julia> AbstractPPL.normalise(op1) == AbstractPPL.normalise(op2) == @o _.a.b.c -true -``` - -This function also removes redundant `identity` optics from ComposedFunctions: - -```jldoctest; setup=:(using Accessors) -julia> op3 = ((@o _.b) ∘ identity) ∘ (@o _.a) -(@o identity(_.a).b) - -julia> op4 = (@o _.b) ∘ (identity ∘ (@o _.a)) -(@o _.b) ∘ ((@o identity(_.a))) - -julia> AbstractPPL.normalise(op3) == AbstractPPL.normalise(op4) == @o _.a.b -true -``` -""" -function normalise(o::ComposedFunction{Outer,<:ComposedFunction}) where {Outer} - # `o` is currently (outer ∘ (inner_outer ∘ inner_inner)). - # We want to change this to: - # o = (outer ∘ inner_outer) ∘ inner_inner - inner_inner = o.inner.inner - inner_outer = o.inner.outer - # Recursively call normalise because inner_inner could itself be a - # ComposedFunction - return normalise((o.outer ∘ inner_outer) ∘ inner_inner) -end -function normalise(o::ComposedFunction{Outer,typeof(identity)} where {Outer}) - # strip outer identity - return normalise(o.outer) -end -function normalise(o::ComposedFunction{typeof(identity),Inner} where {Inner}) - # strip inner identity - return normalise(o.inner) -end -normalise(o::ComposedFunction) = normalise(o.outer) ∘ o.inner -normalise(o::ALLOWED_OPTICS) = o -# These two methods are needed to avoid method ambiguity. -normalise(o::ComposedFunction{typeof(identity),<:ComposedFunction}) = normalise(o.inner) -normalise(::ComposedFunction{typeof(identity),typeof(identity)}) = identity - """ getsym(vn::VarName) @@ -143,78 +41,26 @@ Return the optic of the Julia variable used to generate `vn`. ```jldoctest julia> getoptic(@varname(x[1][2:3])) -(@o _[1][2:3]) +[1][2:3] julia> getoptic(@varname(y)) -identity (generic function with 1 method) +Iden() ``` """ getoptic(vn::VarName) = vn.optic -""" - get(obj, vn::VarName{sym}) - -Alias for `(PropertyLens{sym}() ⨟ getoptic(vn))(obj)`. -``` -""" -function Base.get(obj, vn::VarName{sym}) where {sym} - return (PropertyLens{sym}() ⨟ getoptic(vn))(obj) -end - -""" - set(obj, vn::VarName{sym}, value) - -Alias for `set(obj, PropertyLens{sym}() ⨟ getoptic(vn), value)`. - -# Example - -```jldoctest; setup = :(using AbstractPPL: Accessors; nt = (a = 1, b = (c = [1, 2, 3],)); name = :nt) -julia> Accessors.set(nt, @varname(a), 10) -(a = 10, b = (c = [1, 2, 3],)) - -julia> Accessors.set(nt, @varname(b.c[1]), 10) -(a = 1, b = (c = [10, 2, 3],)) -``` -""" -function Accessors.set(obj, vn::VarName{sym}, value) where {sym} - return Accessors.set(obj, PropertyLens{sym}() ⨟ getoptic(vn), value) -end - -# Allow compositions with optic. -function Base.:∘(optic::ALLOWED_OPTICS, vn::VarName{sym}) where {sym} - return VarName{sym}(optic ∘ getoptic(vn)) -end - -Base.hash(vn::VarName, h::UInt) = hash((getsym(vn), getoptic(vn)), h) function Base.:(==)(x::VarName, y::VarName) return getsym(x) == getsym(y) && getoptic(x) == getoptic(y) end +Base.isequal(x::VarName, y::VarName) = x == y + +Base.hash(vn::VarName, h::UInt) = hash((getsym(vn), getoptic(vn)), h) function Base.show(io::IO, vn::VarName{sym,T}) where {sym,T} print(io, getsym(vn)) - return _show_optic(io, getoptic(vn)) -end - -# modified from https://github.com/JuliaObjects/Accessors.jl/blob/01528a81fdf17c07436e1f3d99119d3f635e4c26/src/sugar.jl#L502 -function _show_optic(io::IO, optic) - opts = Accessors.deopcompose(optic) - inner = Iterators.takewhile(x -> applicable(_shortstring, "", x), opts) - outer = Iterators.dropwhile(x -> applicable(_shortstring, "", x), opts) - if !isempty(outer) - show(io, opcompose(outer...)) - print(io, " ∘ ") - end - shortstr = reduce(_shortstring, inner; init="") - return print(io, shortstr) + return pretty_print_optic(io, getoptic(vn)) end -_shortstring(prev, o::IndexLens) = "$prev[$(join(map(prettify_index, o.indices), ", "))]" -_shortstring(prev, ::typeof(identity)) = "$prev" -_shortstring(prev, o) = Accessors._shortstring(prev, o) - -prettify_index(x) = repr(x) -prettify_index(::Colon) = ":" - """ Symbol(vn::VarName) @@ -229,93 +75,18 @@ julia> Symbol(@varname(x[1][:])) Symbol("x[1][:]") ``` """ -Base.Symbol(vn::VarName) = Symbol(string(vn)) # simplified symbol - -""" - ConcretizedSlice(::Base.Slice) - -An indexing object wrapping the range of a `Base.Slice` object representing the concrete indices a -`:` indicates. Behaves the same, but prints differently, namely, still as `:`. -""" -struct ConcretizedSlice{T,R} <: AbstractVector{T} - range::R -end - -function ConcretizedSlice(s::Base.Slice{R}) where {R} - return ConcretizedSlice{eltype(s.indices),R}(s.indices) -end -Base.show(io::IO, s::ConcretizedSlice) = print(io, ":") -function Base.show(io::IO, ::MIME"text/plain", s::ConcretizedSlice) - return print(io, "ConcretizedSlice(", s.range, ")") -end -Base.size(s::ConcretizedSlice) = size(s.range) -Base.iterate(s::ConcretizedSlice, state...) = Base.iterate(s.range, state...) -Base.collect(s::ConcretizedSlice) = collect(s.range) -Base.getindex(s::ConcretizedSlice, i) = s.range[i] -Base.hasfastin(::Type{<:ConcretizedSlice}) = true -Base.in(i, s::ConcretizedSlice) = i in s.range - -# and this is the reason why we are doing this: -Base.to_index(A, s::ConcretizedSlice) = Base.Slice(s.range) - -""" - reconcretize_index(original_index, lowered_index) - -Create the index to be emitted in `concretize`. `original_index` is the original, unconcretized -index, and `lowered_index` the respective position of the result of `to_indices`. - -The only purpose of this are special cases like `:`, which we want to avoid becoming a -`Base.Slice(OneTo(...))` -- it would confuse people when printed. Instead, we concretize to a -`ConcretizedSlice` based on the `lowered_index`, just what you'd get with an explicit `begin:end` -""" -reconcretize_index(original_index, lowered_index) = lowered_index -function reconcretize_index(original_index::Colon, lowered_index::Base.Slice) - return ConcretizedSlice(lowered_index) -end - -""" - concretize(l, x) - -Return `l` instantiated on `x`, i.e. any information related to the runtime shape of `x` is -evaluated. This concerns `begin`, `end`, and `:` slices. - -Basically, every index is converted to a concrete value using `Base.to_index` on `x`. However, `:` -slices are only converted to `ConcretizedSlice` (as opposed to `Base.Slice{Base.OneTo}`), to keep -the result close to the original indexing. -""" -concretize(I::ALLOWED_OPTICS, x) = I -concretize(I::DynamicIndexLens, x) = concretize(IndexLens(I.f(x)), x) -function concretize(I::IndexLens, x) - return IndexLens(reconcretize_index.(I.indices, to_indices(x, I.indices))) -end -function concretize(I::ComposedFunction, x) - x_inner = I.inner(x) # TODO: get view here - return ComposedFunction(concretize(I.outer, x_inner), concretize(I.inner, x)) -end +Base.Symbol(vn::VarName) = Symbol(string(vn)) """ concretize(vn::VarName, x) Return `vn` concretized on `x`, i.e. any information related to the runtime shape of `x` is -evaluated. This concerns `begin`, `end`, and `:` slices. - -# Examples -```jldoctest; setup=:(using Accessors) -julia> x = (a = [1.0 2.0; 3.0 4.0; 5.0 6.0], ); - -julia> getoptic(@varname(x.a[1:end, end][:], true)) # concrete=true required for @varname -(@o _.a[1:3, 2][:]) - -julia> y = zeros(10, 10); - -julia> @varname(y[:], true) -y[:] - -julia> # The underlying value is concretized, though: - AbstractPPL.getoptic(AbstractPPL.concretize(@varname(y[:]), y)).indices[1] -ConcretizedSlice(Base.OneTo(100)) -``` +evaluated. This will convert any Colon indices to `Base.Slice`, which contains information +about the length of the dimension being sliced. """ +# TODO(penelopeysm): Does this affect begin/end? The old docstring said it would, but I +# could not see where in the implementation this was actually done. I remember that this is +# not the first time I've been confused about this. concretize(vn::VarName{sym}, x) where {sym} = VarName{sym}(concretize(getoptic(vn), x)) """ @@ -355,104 +126,109 @@ julia> # Potentially surprising behaviour, but this is equivalent to what Base d (x[2:2:4], 2:2:4) ``` -### General indexing +# Interpolation -Under the hood `optic`s are used for the indexing: +Property names can also be constructed from interpolated symbols: ```jldoctest -julia> getoptic(@varname(x)) -identity (generic function with 1 method) - -julia> getoptic(@varname(x[1])) -(@o _[1]) - -julia> getoptic(@varname(x[:, 1])) -(@o _[Colon(), 1]) - -julia> getoptic(@varname(x[:, 1][2])) -(@o _[Colon(), 1][2]) - -julia> getoptic(@varname(x[1,2][1+5][45][3])) -(@o _[1, 2][6][45][3]) +julia> name = :hello; @varname(x.\$name) +x.hello ``` -This also means that we support property access: +For indices, you don't need to use `\$` to interpolate, just use the variable directly: ```jldoctest -julia> getoptic(@varname(x.a)) -(@o _.a) - -julia> getoptic(@varname(x.a[1])) -(@o _.a[1]) - -julia> x = (a = [(b = rand(2), )], ); getoptic(@varname(x.a[1].b[end], true)) -(@o _.a[1].b[2]) +julia> ix = 2; @varname(x[ix]) +x[2] ``` +""" -Interpolation can be used for variable names, or array name, but not the lhs of a `.` expression. -Variables within indices are always evaluated in the calling scope. - -```jldoctest -julia> name, i = :a, 10; - -julia> @varname(x.\$name[i, i+1]) -x.a[10, 11] - -julia> @varname(\$name) -a - -julia> @varname(\$name[1]) -a[1] - -julia> @varname(\$name.x[1]) -a.x[1] +struct VarNameParseException <: Exception + expr::Expr +end +function Base.showerror(io::IO, e::VarNameParseException) + return print(io, "malformed variable name `$(e.expr)`") +end -julia> @varname(b.\$name.x[1]) -b.a.x[1] -``` -""" -macro varname(expr::Union{Expr,Symbol}, concretize::Bool=Accessors.need_dynamic_optic(expr)) - return varname(expr, concretize) +macro varname(expr) + return _varname(expr, :(Iden())) +end +function _varname(sym::Symbol, inner_expr) + return :($VarName{$(QuoteNode(sym))}($inner_expr)) +end +function _varname(expr::Expr, inner_expr) + next_inner = if expr.head == :(.) + sym = _handle_property(expr.args[2], expr) + :(Property{$(sym)}($inner_expr)) + elseif expr.head == :ref + ixs = map(first ∘ _handle_index, expr.args[2:end]) + # TODO(penelopeysm): Technically, here we could track whether any of the indices are + # dynamic, and store this for later use. + # isdyn = any(last, ixs_and_isdyn) + # What we do now (generate the dynamic VarName first, and then later check whether + # it needs concretization) is slightly inefficient. + :(Index(tuple($(ixs...)), $inner_expr)) + else + # some other expression we can't parse + throw(VarNameParseException(expr)) + end + return _varname(expr.args[1], next_inner) end -varname(sym::Symbol) = :($(AbstractPPL.VarName){$(QuoteNode(sym))}()) -varname(sym::Symbol, _) = varname(sym) -function varname(expr::Expr, concretize=Accessors.need_dynamic_optic(expr)) - if Meta.isexpr(expr, :ref) || Meta.isexpr(expr, :.) - # Split into object/base symbol and lens. - sym_escaped, optics = _parse_obj_optic(expr) - # Setfield.jl escapes the return symbol, so we need to unescape - # to call `QuoteNode` on it. - sym = drop_escape(sym_escaped) - - # This is to handle interpolated heads -- Setfield treats them differently: - # julia> AbstractPPL._parse_obj_optics(Meta.parse("\$name.a")) - # (:($(Expr(:escape, :_))), (:($(Expr(:escape, :name))), :((PropertyLens){:a}()))) - # julia> AbstractPPL._parse_obj_optic(:(x.a)) - # (:($(Expr(:escape, :x))), :(Accessors.opticcompose((PropertyLens){:a}()))) - if sym != :_ - sym = QuoteNode(sym) - else - sym = optics.args[2] - optics = Expr(:call, optics.args[1], optics.args[3:end]...) - end +function _handle_property(qn::QuoteNode, original_expr) + if qn.value isa Symbol # no interpolation e.g. @varname(x.a) + return qn + elseif Meta.isexpr(qn.value, :$, 1) && qn.value.args[1] isa Symbol + # interpolated property e.g. @varname(x.$name). + # TODO(penelopeysm): Note that $name must evaluate to a Symbol, or else you will get + # a slightly inscrutable error: "ERROR: TypeError: in Type, in parameter, expected + # Type, got a value of type String". This should probably be fixed, but I don't + # actually *know* how to do it. Again, this is not a new issue, the old VarName + # also had the same problem. + return esc(qn.value.args[1]) + else + throw(VarNameParseException(original_expr)) + end +end +function _handle_property(::Any, original_expr) + throw(VarNameParseException(original_expr)) +end - if concretize - return :($(AbstractPPL.VarName){$sym}( - $(AbstractPPL.concretize)($optics, $sym_escaped) - )) - elseif Accessors.need_dynamic_optic(expr) - error("Variable name `$(expr)` is dynamic and requires concretization!") +_handle_index(ix::Int) = ix, false +function _handle_index(ix::Symbol) + # NOTE(penelopeysm): We could just use `:end` instead of Symbol(:end), but the former + # messes up syntax highlighting with Treesitter + # https://github.com/tree-sitter/tree-sitter-julia/issues/104 + if ix == Symbol(:end) + return :(DynamicEnd()), true + elseif ix == Symbol(:begin) + return :(DynamicBegin()), true + elseif ix == :(:) + return :(DynamicColon()), true + else + # an interpolated symbol + return ix, false + end +end +function _handle_index(ix::Expr) + if Meta.isexpr(ix, :call, 3) && ix.args[1] == :(:) + # This is a range + start, isdyn = _handle_index(ix.args[2]) + stop, isdyn2 = _handle_index(ix.args[3]) + if isdyn || isdyn2 + return :(DynamicRange($start, $stop)), true else - return :($(AbstractPPL.VarName){$sym}($optics)) + return :(($start):($stop)), false end - elseif Meta.isexpr(expr, :$, 1) - return :($(AbstractPPL.VarName){$(esc(expr.args[1]))}()) else - error("Malformed variable name `$(expr)`!") + # Some other expression. We don't want to parse this any further, but we also don't + # want to error, because it may well be an expression that evaluates to a valid + # index. + return ix, false end end +#= drop_escape(x) = x function drop_escape(expr::Expr) Meta.isexpr(expr, :escape) && return drop_escape(expr.args[1]) @@ -691,3 +467,5 @@ end _init(::Accessors.PropertyLens) = identity _init(::Accessors.IndexLens) = identity _init(::typeof(identity)) = identity + +=# From 1957e5d895600fad83af84d48a75ff089e4a779a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 23 Dec 2025 02:42:27 +0000 Subject: [PATCH 02/60] finish optic.jl and varname.jl --- src/varname/optic.jl | 112 ++++++++++- src/varname/varname.jl | 441 +++++++++++++---------------------------- 2 files changed, 244 insertions(+), 309 deletions(-) diff --git a/src/varname/optic.jl b/src/varname/optic.jl index 9227612a..9712221a 100644 --- a/src/varname/optic.jl +++ b/src/varname/optic.jl @@ -7,19 +7,21 @@ An abstract type that represents the non-symbol part of a VarName, i.e., the sec variable that is of interest. For example, in `x.a[1][2]`, the `AbstractOptic` represents the `.a[1][2]` part. -# Interface +# Public interface This is WIP. - Base.show - to_accessors(optic) -> Accessors.Lens (recovering the old representation) +- is_dynamic(optic) -> Bool (whether the optic contains any dynamic indices) +- concretize(optic, val) -> AbstractOptic (resolving any dynamic indices given the value) Not sure if we want to introduce getters and setters and BangBang-style stuff. """ abstract type AbstractOptic end function Base.show(io::IO, optic::AbstractOptic) print(io, "Optic(") - pretty_print_optic(io, optic) + _pretty_print_optic(io, optic) return print(io, ")") end @@ -30,8 +32,9 @@ The identity optic. This is the optic used when we are referring to the entire v It is also the base case for composing optics. """ struct Iden <: AbstractOptic end -pretty_print_optic(::IO, ::Iden) = nothing +_pretty_print_optic(::IO, ::Iden) = nothing to_accessors(::Iden) = identity +is_dynamic(::Iden) = false concretize(i::Iden, ::Any) = i """ @@ -43,30 +46,39 @@ When parsing VarNames, we convert such indices into subtypes of `DynamicIndex`, mark them as requiring concretisation. """ abstract type DynamicIndex end +_is_dynamic_idx(::DynamicIndex) = true +_is_dynamic_idx(::Any) = false # Fallback for all other indices concretize(@nospecialize(ix::Any), ::Any, ::Any) = ix +_pretty_print_index(x::Any) = string(x) struct DynamicBegin <: DynamicIndex end concretize(::DynamicBegin, val, dim::Nothing) = Base.firstindex(val) concretize(::DynamicBegin, val, dim) = Base.firstindex(val, dim) +_pretty_print_index(::DynamicBegin) = "begin" struct DynamicEnd <: DynamicIndex end concretize(::DynamicEnd, val, dim::Nothing) = Base.lastindex(val) concretize(::DynamicEnd, val, dim) = Base.lastindex(val, dim) +_pretty_print_index(::DynamicEnd) = "end" struct DynamicColon <: DynamicIndex end concretize(::DynamicColon, val, dim::Nothing) = Base.firstindex(val):Base.lastindex(val) concretize(::DynamicColon, val, dim) = Base.firstindex(val, dim):Base.lastindex(val, dim) +_pretty_print_index(::DynamicColon) = ":" struct DynamicRange{T1,T2} <: DynamicIndex start::T1 stop::T2 end -function concretize(dr::DynamicRange, axis) - start = dr.start isa DynamicIndex ? concretize(dr.start, axis) : dr.start - stop = dr.stop isa DynamicIndex ? concretize(dr.stop, axis) : dr.stop +function concretize(dr::DynamicRange, axis, dim) + start = dr.start isa DynamicIndex ? concretize(dr.start, axis, dim) : dr.start + stop = dr.stop isa DynamicIndex ? concretize(dr.stop, axis, dim) : dr.stop return start:stop end +function _pretty_print_index(dr::DynamicRange) + return "$(_pretty_print_index(dr.start)):$(_pretty_print_index(dr.stop))" +end """ Index(ix, child=Iden()) @@ -83,10 +95,10 @@ Index(ix::Tuple, child::C=Iden()) where {C<:AbstractOptic} = Index{typeof(ix),C} Base.:(==)(a::Index, b::Index) = a.ix == b.ix && a.child == b.child Base.isequal(a::Index, b::Index) = a == b -function pretty_print_optic(io::IO, idx::Index) - ixs = join(idx.ix, ", ") +function _pretty_print_optic(io::IO, idx::Index) + ixs = join(map(_pretty_print_index, idx.ix), ", ") print(io, "[$(ixs)]") - return pretty_print_optic(io, idx.child) + return _pretty_print_optic(io, idx.child) end function to_accessors(idx::Index) ilens = Accessors.IndexLens(idx.ix) @@ -96,6 +108,7 @@ function to_accessors(idx::Index) Base.ComposedFunction(to_accessors(idx.child), ilens) end end +is_dynamic(idx::Index) = any(_is_dynamic_idx, idx.ix) || is_dynamic(idx.child) function concretize(idx::Index, val) concretized_indices = if length(idx.ix) == 0 [] @@ -127,9 +140,9 @@ Property{sym}(child::C=Iden()) where {sym,C<:AbstractOptic} = Property{sym,C}(ch Base.:(==)(a::Property{sym}, b::Property{sym}) where {sym} = a.child == b.child Base.:(==)(a::Property, b::Property) = false Base.isequal(a::Property, b::Property) = a == b -function pretty_print_optic(io::IO, prop::Property{sym}) where {sym} +function _pretty_print_optic(io::IO, prop::Property{sym}) where {sym} print(io, ".$(sym)") - return pretty_print_optic(io, prop.child) + return _pretty_print_optic(io, prop.child) end function to_accessors(prop::Property{sym}) where {sym} plens = Accessors.PropertyLens{sym}() @@ -139,7 +152,84 @@ function to_accessors(prop::Property{sym}) where {sym} Base.ComposedFunction(to_accessors(prop.child), plens) end end +is_dynamic(prop::Property) = is_dynamic(prop.child) function concretize(prop::Property{sym}, val) where {sym} inner_concretized = concretize(prop.child, getproperty(val, sym)) return Property{sym}(inner_concretized) end + +function Base.:(∘)(outer::AbstractOptic, inner::AbstractOptic) + if outer isa Iden + return inner + elseif inner isa Iden + return outer + else + # TODO... + error("not implemented") + end +end + +""" + _head(optic) + +Get the innermost layer of an AbstractOptic. For all optics, we have that `_tail(optic) ∘ +_head(optic) == optic`. +""" +_head(::Property{s}) where {s} = Property{s}(Iden()) +_head(idx::Index) = Index((idx.ix...,), Iden()) +_head(i::Iden) = i + +""" + _tail(optic) + +Get everything but the innermost layer of an optic. For all optics, we have that +`_tail(optic) ∘ _head(optic) == optic`. +``` +""" +_tail(p::Property) = p.child +_tail(idx::Index) = idx.child +_tail(i::Iden) = i + +""" + _last(optic) + +Get the outermost layer of an optic. For all optics, we have that `_last(optic) ∘ +_init(optic) == optic`. +""" +function _last(p::Property{s}) where {s} + if p.child isa Iden + return p + else + return _last(p.child) + end +end +function _last(idx::Index) + if idx.child isa Iden + return idx + else + return _last(idx.child) + end +end +_last(i::Iden) = i + +""" + _init(optic) + +Get everything but the outermost layer of an optic. For all optics, we have that +`_last(optic) ∘ _init(optic) == optic`. +""" +function _init(p::Property{s}) where {s} + return if p.child isa Iden + Iden() + else + Property{s}(_init(p.child)) + end +end +function _init(idx::Index) + return if idx.child isa Iden + Iden() + else + Index(idx.ix, _init(idx.child)) + end +end +_init(i::Iden) = i diff --git a/src/varname/varname.jl b/src/varname/varname.jl index 057b1d41..40124385 100644 --- a/src/varname/varname.jl +++ b/src/varname/varname.jl @@ -41,10 +41,10 @@ Return the optic of the Julia variable used to generate `vn`. ```jldoctest julia> getoptic(@varname(x[1][2:3])) -[1][2:3] +Optic([1][2:3]) julia> getoptic(@varname(y)) -Iden() +Optic() ``` """ getoptic(vn::VarName) = vn.optic @@ -58,7 +58,7 @@ Base.hash(vn::VarName, h::UInt) = hash((getsym(vn), getoptic(vn)), h) function Base.show(io::IO, vn::VarName{sym,T}) where {sym,T} print(io, getsym(vn)) - return pretty_print_optic(io, getoptic(vn)) + return _pretty_print_optic(io, getoptic(vn)) end """ @@ -81,58 +81,125 @@ Base.Symbol(vn::VarName) = Symbol(string(vn)) concretize(vn::VarName, x) Return `vn` concretized on `x`, i.e. any information related to the runtime shape of `x` is -evaluated. This will convert any Colon indices to `Base.Slice`, which contains information -about the length of the dimension being sliced. +evaluated. This will convert any `begin`, `end`, or `:` indices in `vn` to concrete indices +with information about the length of the dimension being sliced. """ -# TODO(penelopeysm): Does this affect begin/end? The old docstring said it would, but I -# could not see where in the implementation this was actually done. I remember that this is -# not the first time I've been confused about this. concretize(vn::VarName{sym}, x) where {sym} = VarName{sym}(concretize(getoptic(vn), x)) +""" + is_dynamic(vn::VarName) + +Return `true` if `vn` contains any dynamic indices (i.e., `begin`, `end`, or `:`). If a +`VarName` has been concretized, this will always return `false`. +""" +is_dynamic(vn::VarName) = is_dynamic(getoptic(vn)) + +""" + VarNameParseException(expr) + +An exception thrown when a variable name expression cannot be parsed by the +[`@varname`](@ref) macro. +""" +struct VarNameParseException <: Exception + expr::Expr +end +function Base.showerror(io::IO, e::VarNameParseException) + return print(io, "malformed variable name `$(e.expr)`") +end + """ @varname(expr, concretize=false) -A macro that returns an instance of [`VarName`](@ref) given a symbol or indexing expression `expr`. +Create a [`VarName`](@ref) given an expression `expr` representing a variable or part of it +(any expression that can be assigned to). -If `concretize` is `true`, the resulting expression will be wrapped in a `concretize()` call. +# Basic examples -Note that expressions involving dynamic indexing, i.e. `begin` and/or `end`, will always need to be -concretized as `VarName` only supports non-dynamic indexing as determined by -`is_static_optic`. See examples below. +In general, `VarName`s must have a top-level symbol representing the identifier itself, and +can then have any number of property accesses or indexing operations chained to it. -## Examples +```jldoctest +julia> @varname(x) +x + +julia> @varname(x.a.b.c) +x.a.b.c + +julia> @varname(x[1][2][3]) +x[1][2][3] + +julia> @varname(x.a[1:3].b[2]) +x.a[1:3].b[2] +``` + +# Dynamic indices + +Some expressions may involve dynamic indices, e.g., `begin`, `end`, and `:`. These indices +cannot be resolved, or 'concretized', until the value being indexed into is known. By +default, `@varname(...)` will not automatically concretize these expressions, and thus +the resulting `VarName` will contain markers for these. + +```jldoctest +julia> # VarNames are pretty-printed, so at first glance, it's not special... + vn = @varname(x[end]) +x[end] + +julia> # But if you look under the hood, you can see that the index is dynamic. + vn = @varname(x[end]); getoptic(vn).ix +(DynamicEnd(),) + +julia> vn = @varname(x[1:end, end]); getoptic(vn).ix +(DynamicRange{Int64, DynamicEnd}(1, DynamicEnd()), DynamicEnd()) +``` + +You can detect whether a `VarName` contains any dynamic indices using `is_dynamic(vn)`: -### Dynamic indexing ```jldoctest -julia> x = (a = [1.0 2.0; 3.0 4.0; 5.0 6.0], ); +julia> vn = @varname(x[1:end, end]); is_dynamic(vn) +true +``` -julia> @varname(x.a[1:end, end][:], true) -x.a[1:3, 2][:] +To concretize such expressions, you can call `concretize(vn, val)` on the resulting +`VarName`. After concretization, the resulting `VarName` will no longer be dynamic. -julia> @varname(x.a[end], false) # disable concretization -ERROR: LoadError: Variable name `x.a[end]` is dynamic and requires concretization! -[...] +```jldoctest +julia> x = randn(2, 3); -julia> @varname(x.a[end]) # concretization occurs by default if deemed necessary -x.a[6] +julia> vn = @varname(x[1:end, end]); vn2 = concretize(vn, x) +x[1:2, 3][1:2] -julia> # Note that "dynamic" here refers to usage of `begin` and/or `end`, - # _not_ "information only available at runtime", i.e. the following works. - [@varname(x.a[i]) for i = 1:length(x.a)][end] -x.a[6] +julia> getoptic(vn2).ix +((1:2), 3) -julia> # Potentially surprising behaviour, but this is equivalent to what Base does: - @varname(x[2:2:5]), 2:2:5 -(x[2:2:4], 2:2:4) +julia> is_dynamic(vn2) +false +``` + +Alternatively, you can pass `true` as the second positional argument the `@varname` macro +(note that it is not a keyword argument!). This will call `concretize` for you, using the +top-level symbol to look up the value used for concretization. + +```jldoctest +julia> x = randn(2, 3); + +julia> @varname(x[1:end, end][:], true) +x[1:2, 3][1:2] ``` # Interpolation -Property names can also be constructed from interpolated symbols: +Property names, as well as top-level symbols, can also be constructed from interpolated +symbols: ```jldoctest julia> name = :hello; @varname(x.\$name) x.hello + +julia> @varname(\$name) +hello + +julia> @varname(\$name.a.\$name[1]) +hello.a.hello[1] ``` For indices, you don't need to use `\$` to interpolate, just use the variable directly: @@ -142,37 +209,57 @@ julia> ix = 2; @varname(x[ix]) x[2] ``` """ - -struct VarNameParseException <: Exception - expr::Expr -end -function Base.showerror(io::IO, e::VarNameParseException) - return print(io, "malformed variable name `$(e.expr)`") -end - -macro varname(expr) - return _varname(expr, :(Iden())) +macro varname(expr, concretize::Bool=false) + unconcretized_vn, sym = _varname(expr, :(Iden())) + return if concretize + if sym === nothing + throw( + ArgumentError( + "cannot automatically concretize VarName with interpolated top-level symbol; call `concretize(vn, val)` manually instead", + ), + ) + end + :(concretize($unconcretized_vn, $(esc(sym)))) + else + unconcretized_vn + end end function _varname(sym::Symbol, inner_expr) - return :($VarName{$(QuoteNode(sym))}($inner_expr)) + return :($VarName{$(QuoteNode(sym))}($inner_expr)), sym end function _varname(expr::Expr, inner_expr) - next_inner = if expr.head == :(.) - sym = _handle_property(expr.args[2], expr) - :(Property{$(sym)}($inner_expr)) - elseif expr.head == :ref - ixs = map(first ∘ _handle_index, expr.args[2:end]) - # TODO(penelopeysm): Technically, here we could track whether any of the indices are - # dynamic, and store this for later use. - # isdyn = any(last, ixs_and_isdyn) - # What we do now (generate the dynamic VarName first, and then later check whether - # it needs concretization) is slightly inefficient. - :(Index(tuple($(ixs...)), $inner_expr)) + if Meta.isexpr(expr, :$, 1) + # Interpolation of the top-level symbol e.g. @varname($name). If we hit this branch, + # it means that there are no further property/indexing accesses (because otherwise + # expr.head would be :ref or :.) Thus we don't need to recurse further, and we can + # just return `inner_expr` as-is. + # TODO(penelopeysm): Is there a way to make auto-concretisation work here? To + # be clear, what we want is something like the following to work: + # name = :hello; hello = rand(3); @varname($name[:], true) + # I've tried every combination of `esc`, `QuoteNode`, and `$` I can think of, but + # with no success yet. It didn't work with old AbstractPPL either ("syntax: + # all-underscore identifiers are write-only and their values cannot be used in + # expressions"); at least now we give a more sensible error message. + sym_expr = expr.args[1] + return :($VarName{$(sym_expr)}($inner_expr)), nothing else - # some other expression we can't parse - throw(VarNameParseException(expr)) + next_inner = if expr.head == :(.) + sym = _handle_property(expr.args[2], expr) + :(Property{$(sym)}($inner_expr)) + elseif expr.head == :ref + ixs = map(first ∘ _handle_index, expr.args[2:end]) + # TODO(penelopeysm): Technically, here we could track whether any of the indices are + # dynamic, and store this for later use. + # isdyn = any(last, ixs_and_isdyn) + # What we do now (generate the dynamic VarName first, and then later check whether + # it needs concretization) is slightly inefficient. + :(Index(tuple($(ixs...)), $inner_expr)) + else + # some other expression we can't parse + throw(VarNameParseException(expr)) + end + return _varname(expr.args[1], next_inner) end - return _varname(expr.args[1], next_inner) end function _handle_property(qn::QuoteNode, original_expr) @@ -227,245 +314,3 @@ function _handle_index(ix::Expr) return ix, false end end - -#= -drop_escape(x) = x -function drop_escape(expr::Expr) - Meta.isexpr(expr, :escape) && return drop_escape(expr.args[1]) - return Expr(expr.head, map(x -> drop_escape(x), expr.args)...) -end - -function _parse_obj_optic(ex) - obj, optics = _parse_obj_optics(ex) - optic = Expr(:call, Accessors.opticcompose, optics...) - return obj, optic -end - -# Accessors doesn't have the same support for interpolation -# so this function is copied and altered from `Setfield._parse_obj_lens` -function _parse_obj_optics(ex) - if Meta.isexpr(ex, :$, 1) - return esc(:_), (esc(ex.args[1]),) - elseif Meta.isexpr(ex, :ref) && !isempty(ex.args) - front, indices... = ex.args - obj, frontoptics = _parse_obj_optics(front) - if any(Accessors.need_dynamic_optic, indices) - @gensym collection - indices = Accessors.replace_underscore.(indices, collection) - dims = length(indices) == 1 ? nothing : 1:length(indices) - lindices = esc.(Accessors.lower_index.(collection, indices, dims)) - optics = - :($(Accessors.DynamicIndexLens)($(esc(collection)) -> ($(lindices...),))) - else - index = esc(Expr(:tuple, indices...)) - optics = :($(Accessors.IndexLens)($index)) - end - elseif Meta.isexpr(ex, :., 2) - front = ex.args[1] - property = ex.args[2].value # ex.args[2] is a QuoteNode - obj, frontoptics = _parse_obj_optics(front) - if property isa Union{Symbol,String} - optics = :($(Accessors.PropertyLens){$(QuoteNode(property))}()) - elseif Meta.isexpr(property, :$, 1) - optics = :($(Accessors.PropertyLens){$(esc(property.args[1]))}()) - else - throw( - ArgumentError( - string( - "Error while parsing :($ex). Second argument to `getproperty` can only be", - "a `Symbol` or `String` literal, received `$property` instead.", - ), - ), - ) - end - else - obj = esc(ex) - return obj, () - end - return obj, tuple(frontoptics..., optics) -end - -""" - @vsym(expr) - -A macro that returns the variable symbol given the input variable expression `expr`. -For example, `@vsym x[1]` returns `:x`. - -## Examples - -```jldoctest -julia> @vsym x -:x - -julia> @vsym x[1,1][2,3] -:x - -julia> @vsym x[end] -:x -``` -""" -macro vsym(expr::Union{Expr,Symbol}) - return QuoteNode(vsym(expr)) -end - -""" - vsym(expr) - -Return name part of the [`@varname`](@ref)-compatible expression `expr` as a symbol for input of the -[`VarName`](@ref) constructor. -""" -function vsym end - -vsym(expr::Symbol) = expr -function vsym(expr::Expr) - if Meta.isexpr(expr, :ref) || Meta.isexpr(expr, :.) - return vsym(expr.args[1]) - else - error("Malformed variable name `$(expr)`!") - end -end - -""" - _head(optic) - -Get the innermost layer of an optic. - -For all (normalised) optics, we have that `normalise(_tail(optic) ∘ -_head(optic) == optic)`. - -!!! note - Does not perform optic normalisation on the input. You may wish to call - `normalise(optic)` before using this function if the optic you are passing - was not obtained from a VarName. - -```jldoctest; setup=:(using Accessors) -julia> AbstractPPL._head(Accessors.@o _.a.b.c) -(@o _.a) - -julia> AbstractPPL._head(Accessors.@o _[1][2][3]) -(@o _[1]) - -julia> AbstractPPL._head(Accessors.@o _.a) -(@o _.a) - -julia> AbstractPPL._head(Accessors.@o _[1]) -(@o _[1]) - -julia> AbstractPPL._head(Accessors.@o _) -identity (generic function with 1 method) -``` -""" -_head(o::ComposedFunction{Outer,Inner}) where {Outer,Inner} = o.inner -_head(o::Accessors.PropertyLens) = o -_head(o::Accessors.IndexLens) = o -_head(::typeof(identity)) = identity - -""" - _tail(optic) - -Get everything but the innermost layer of an optic. - -For all (normalised) optics, we have that `normalise(_tail(optic) ∘ -_head(optic) == optic)`. - -!!! note - Does not perform optic normalisation on the input. You may wish to call - `normalise(optic)` before using this function if the optic you are passing - was not obtained from a VarName. - -```jldoctest; setup=:(using Accessors) -julia> AbstractPPL._tail(Accessors.@o _.a.b.c) -(@o _.b.c) - -julia> AbstractPPL._tail(Accessors.@o _[1][2][3]) -(@o _[2][3]) - -julia> AbstractPPL._tail(Accessors.@o _.a) -identity (generic function with 1 method) - -julia> AbstractPPL._tail(Accessors.@o _[1]) -identity (generic function with 1 method) - -julia> AbstractPPL._tail(Accessors.@o _) -identity (generic function with 1 method) -``` -""" -_tail(o::ComposedFunction{Outer,Inner}) where {Outer,Inner} = o.outer -_tail(::Accessors.PropertyLens) = identity -_tail(::Accessors.IndexLens) = identity -_tail(::typeof(identity)) = identity - -""" - _last(optic) - -Get the outermost layer of an optic. - -For all (normalised) optics, we have that `normalise(_last(optic) ∘ -_init(optic)) == optic`. - -!!! note - Does not perform optic normalisation on the input. You may wish to call - `normalise(optic)` before using this function if the optic you are passing - was not obtained from a VarName. - -```jldoctest; setup=:(using Accessors) -julia> AbstractPPL._last(Accessors.@o _.a.b.c) -(@o _.c) - -julia> AbstractPPL._last(Accessors.@o _[1][2][3]) -(@o _[3]) - -julia> AbstractPPL._last(Accessors.@o _.a) -(@o _.a) - -julia> AbstractPPL._last(Accessors.@o _[1]) -(@o _[1]) - -julia> AbstractPPL._last(Accessors.@o _) -identity (generic function with 1 method) -``` -""" -_last(o::ComposedFunction{Outer,Inner}) where {Outer,Inner} = _last(o.outer) -_last(o::Accessors.PropertyLens) = o -_last(o::Accessors.IndexLens) = o -_last(::typeof(identity)) = identity - -""" - _init(optic) - -Get everything but the outermost layer of an optic. - -For all (normalised) optics, we have that `normalise(_last(optic) ∘ -_init(optic)) == optic`. - -!!! note - Does not perform optic normalisation on the input. You may wish to call - `normalise(optic)` before using this function if the optic you are passing - was not obtained from a VarName. - -```jldoctest; setup=:(using Accessors) -julia> AbstractPPL._init(Accessors.@o _.a.b.c) -(@o _.a.b) - -julia> AbstractPPL._init(Accessors.@o _[1][2][3]) -(@o _[1][2]) - -julia> AbstractPPL._init(Accessors.@o _.a) -identity (generic function with 1 method) - -julia> AbstractPPL._init(Accessors.@o _[1]) -identity (generic function with 1 method) - -julia> AbstractPPL._init(Accessors.@o _) -identity (generic function with 1 method) -""" -# This one needs normalise because it's going 'against' the direction of the -# linked list (otherwise you will end up with identities scattered throughout) -function _init(o::ComposedFunction{Outer,Inner}) where {Outer,Inner} - return normalise(_init(o.outer) ∘ o.inner) -end -_init(::Accessors.PropertyLens) = identity -_init(::Accessors.IndexLens) = identity -_init(::typeof(identity)) = identity - -=# From b4aee1c9e6a3bf3e47fb0fc7e27d92224f9f873c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 23 Dec 2025 02:53:06 +0000 Subject: [PATCH 03/60] fix optic composition --- src/varname/optic.jl | 71 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 62 insertions(+), 9 deletions(-) diff --git a/src/varname/optic.jl b/src/varname/optic.jl index 9712221a..fad69563 100644 --- a/src/varname/optic.jl +++ b/src/varname/optic.jl @@ -140,6 +140,7 @@ Property{sym}(child::C=Iden()) where {sym,C<:AbstractOptic} = Property{sym,C}(ch Base.:(==)(a::Property{sym}, b::Property{sym}) where {sym} = a.child == b.child Base.:(==)(a::Property, b::Property) = false Base.isequal(a::Property, b::Property) = a == b +getsym(::Property{s}) where {s} = s function _pretty_print_optic(io::IO, prop::Property{sym}) where {sym} print(io, ".$(sym)") return _pretty_print_optic(io, prop.child) @@ -158,32 +159,68 @@ function concretize(prop::Property{sym}, val) where {sym} return Property{sym}(inner_concretized) end +""" + ∘(outer::AbstractOptic, inner::AbstractOptic) + +Compose two `AbstractOptic`s together. + +```jldoctest +julia> p1 = Property{:a}(Index((1,))) +Optic(.a[1]) + +julia> p2 = Property{:b}(Index((2,3))) +Optic(.b[2, 3]) + +julia> p1 ∘ p2 +Optic(.b[2, 3].a[1]) +``` +""" function Base.:(∘)(outer::AbstractOptic, inner::AbstractOptic) if outer isa Iden return inner elseif inner isa Iden return outer else - # TODO... - error("not implemented") + if inner isa Property + return Property{getsym(inner)}(outer ∘ inner.child) + elseif inner isa Index + return Index(inner.ix, outer ∘ inner.child) + else + error("unreachable; unknown AbstractOptic subtype $(typeof(inner))") + end end end """ - _head(optic) + _head(optic::AbstractOptic) -Get the innermost layer of an AbstractOptic. For all optics, we have that `_tail(optic) ∘ +Get the innermost layer of an optic. For all optics, we have that `_tail(optic) ∘ _head(optic) == optic`. + +```jldoctest +julia> _head(getoptic(@varname(x.a[1][2]))) +Optic(.a) + +julia> _head(getoptic(@varname(x))) +Optic() +``` """ _head(::Property{s}) where {s} = Property{s}(Iden()) _head(idx::Index) = Index((idx.ix...,), Iden()) _head(i::Iden) = i """ - _tail(optic) + _tail(optic::AbstractOptic) -Get everything but the innermost layer of an optic. For all optics, we have that +Get everything but the innermost layer of an optic. For all optics, we have that `_tail(optic) ∘ _head(optic) == optic`. + +```jldoctest +julia> _tail(getoptic(@varname(x.a[1][2]))) +Optic([1][2]) + +julia> _tail(getoptic(@varname(x))) +Optic() ``` """ _tail(p::Property) = p.child @@ -191,10 +228,18 @@ _tail(idx::Index) = idx.child _tail(i::Iden) = i """ - _last(optic) + _last(optic::AbstractOptic) -Get the outermost layer of an optic. For all optics, we have that `_last(optic) ∘ +Get the outermost layer of an optic. For all optics, we have that `_last(optic) ∘ _init(optic) == optic`. + +```jldoctest +julia> _last(getoptic(@varname(x.a[1][2]))) +Optic([2]) + +julia> _last(getoptic(@varname(x))) +Optic() +``` """ function _last(p::Property{s}) where {s} if p.child isa Iden @@ -213,10 +258,18 @@ end _last(i::Iden) = i """ - _init(optic) + _init(optic::AbstractOptic) Get everything but the outermost layer of an optic. For all optics, we have that `_last(optic) ∘ _init(optic) == optic`. + +```jldoctest +julia> _init(getoptic(@varname(x.a[1][2]))) +Optic(.a[1]) + +julia> _init(getoptic(@varname(x))) +Optic() +``` """ function _init(p::Property{s}) where {s} return if p.child isa Iden From 9a8a236c59e6656d0e52809d73aa120f57ed7a73 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 23 Dec 2025 02:54:46 +0000 Subject: [PATCH 04/60] use view --- src/varname/optic.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/varname/optic.jl b/src/varname/optic.jl index fad69563..a67089c5 100644 --- a/src/varname/optic.jl +++ b/src/varname/optic.jl @@ -121,7 +121,7 @@ function concretize(idx::Index, val) # dimension. [concretize(ix, val, dim) for (dim, ix) in enumerate(idx.ix)] end - inner_concretized = concretize(idx.child, val[concretized_indices...]) + inner_concretized = concretize(idx.child, view(val, concretized_indices...)) return Index((concretized_indices...,), inner_concretized) end From 41430d3c5d6a0c04e9c54b4519511a3dd16638f5 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 23 Dec 2025 03:03:49 +0000 Subject: [PATCH 05/60] Add some tsts --- src/AbstractPPL.jl | 10 +++--- src/varname/optic.jl | 71 ++++++++++++++++++++++--------------------- test/runtests.jl | 5 +-- test/varname.jl | 58 ----------------------------------- test/varname/optic.jl | 29 ++++++++++++++++++ 5 files changed, 74 insertions(+), 99 deletions(-) create mode 100644 test/varname/optic.jl diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index 3fb08f83..3bea9757 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -34,10 +34,10 @@ include("abstractprobprog.jl") include("evaluate.jl") include("varname/optic.jl") include("varname/varname.jl") -include("varname/subsumes.jl") -include("varname/hasvalue.jl") -include("varname/leaves.jl") -include("varname/prefix.jl") -include("varname/serialize.jl") +# include("varname/subsumes.jl") +# include("varname/hasvalue.jl") +# include("varname/leaves.jl") +# include("varname/prefix.jl") +# include("varname/serialize.jl") end # module diff --git a/src/varname/optic.jl b/src/varname/optic.jl index a67089c5..eae8adc7 100644 --- a/src/varname/optic.jl +++ b/src/varname/optic.jl @@ -12,6 +12,9 @@ the `.a[1][2]` part. This is WIP. - Base.show +- Base.:(==), Base.isequal +- Base.:(∘) (composition) +- ohead, otail, olast, oinit (decomposition) - to_accessors(optic) -> Accessors.Lens (recovering the old representation) - is_dynamic(optic) -> Bool (whether the optic contains any dynamic indices) - concretize(optic, val) -> AbstractOptic (resolving any dynamic indices given the value) @@ -192,97 +195,97 @@ function Base.:(∘)(outer::AbstractOptic, inner::AbstractOptic) end """ - _head(optic::AbstractOptic) + ohead(optic::AbstractOptic) -Get the innermost layer of an optic. For all optics, we have that `_tail(optic) ∘ -_head(optic) == optic`. +Get the innermost layer of an optic. For all optics, we have that `otail(optic) ∘ +ohead(optic) == optic`. ```jldoctest -julia> _head(getoptic(@varname(x.a[1][2]))) +julia> ohead(getoptic(@varname(x.a[1][2]))) Optic(.a) -julia> _head(getoptic(@varname(x))) +julia> ohead(getoptic(@varname(x))) Optic() ``` """ -_head(::Property{s}) where {s} = Property{s}(Iden()) -_head(idx::Index) = Index((idx.ix...,), Iden()) -_head(i::Iden) = i +ohead(::Property{s}) where {s} = Property{s}(Iden()) +ohead(idx::Index) = Index((idx.ix...,), Iden()) +ohead(i::Iden) = i """ - _tail(optic::AbstractOptic) + otail(optic::AbstractOptic) Get everything but the innermost layer of an optic. For all optics, we have that -`_tail(optic) ∘ _head(optic) == optic`. +`otail(optic) ∘ ohead(optic) == optic`. ```jldoctest -julia> _tail(getoptic(@varname(x.a[1][2]))) +julia> otail(getoptic(@varname(x.a[1][2]))) Optic([1][2]) -julia> _tail(getoptic(@varname(x))) +julia> otail(getoptic(@varname(x))) Optic() ``` """ -_tail(p::Property) = p.child -_tail(idx::Index) = idx.child -_tail(i::Iden) = i +otail(p::Property) = p.child +otail(idx::Index) = idx.child +otail(i::Iden) = i """ - _last(optic::AbstractOptic) + olast(optic::AbstractOptic) -Get the outermost layer of an optic. For all optics, we have that `_last(optic) ∘ -_init(optic) == optic`. +Get the outermost layer of an optic. For all optics, we have that `olast(optic) ∘ +oinit(optic) == optic`. ```jldoctest -julia> _last(getoptic(@varname(x.a[1][2]))) +julia> olast(getoptic(@varname(x.a[1][2]))) Optic([2]) -julia> _last(getoptic(@varname(x))) +julia> olast(getoptic(@varname(x))) Optic() ``` """ -function _last(p::Property{s}) where {s} +function olast(p::Property{s}) where {s} if p.child isa Iden return p else - return _last(p.child) + return olast(p.child) end end -function _last(idx::Index) +function olast(idx::Index) if idx.child isa Iden return idx else - return _last(idx.child) + return olast(idx.child) end end -_last(i::Iden) = i +olast(i::Iden) = i """ - _init(optic::AbstractOptic) + oinit(optic::AbstractOptic) Get everything but the outermost layer of an optic. For all optics, we have that -`_last(optic) ∘ _init(optic) == optic`. +`olast(optic) ∘ oinit(optic) == optic`. ```jldoctest -julia> _init(getoptic(@varname(x.a[1][2]))) +julia> oinit(getoptic(@varname(x.a[1][2]))) Optic(.a[1]) -julia> _init(getoptic(@varname(x))) +julia> oinit(getoptic(@varname(x))) Optic() ``` """ -function _init(p::Property{s}) where {s} +function oinit(p::Property{s}) where {s} return if p.child isa Iden Iden() else - Property{s}(_init(p.child)) + Property{s}(oinit(p.child)) end end -function _init(idx::Index) +function oinit(idx::Index) return if idx.child isa Iden Iden() else - Index(idx.ix, _init(idx.child)) + Index(idx.ix, oinit(idx.child)) end end -_init(i::Iden) = i +oinit(i::Iden) = i diff --git a/test/runtests.jl b/test/runtests.jl index cb07ee02..a4047777 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,9 +7,10 @@ const GROUP = get(ENV, "GROUP", "All") @testset "AbstractPPL.jl" begin if GROUP == "All" || GROUP == "Tests" include("Aqua.jl") - include("varname.jl") include("abstractprobprog.jl") - include("hasvalue.jl") + include("varname/optic.jl") + # include("varname.jl") + # include("hasvalue.jl") end if GROUP == "All" || GROUP == "Doctests" diff --git a/test/varname.jl b/test/varname.jl index c86c4b4e..8de88d8c 100644 --- a/test/varname.jl +++ b/test/varname.jl @@ -3,31 +3,6 @@ using InvertedIndices using OffsetArrays using LinearAlgebra: LowerTriangular, UpperTriangular, cholesky -using AbstractPPL: ⊑, ⊒, ⋢, ⋣, ≍ - -using AbstractPPL: Accessors -using AbstractPPL.Accessors: IndexLens, PropertyLens, ⨟ - -macro test_strict_subsumption(x, y) - quote - @test $((varname(x))) ⊑ $((varname(y))) - @test $((varname(x))) ⋣ $((varname(y))) - end -end - -function test_equal(o1::VarName{sym1}, o2::VarName{sym2}) where {sym1,sym2} - return sym1 === sym2 && test_equal(o1.optic, o2.optic) -end -function test_equal(o1::ComposedFunction, o2::ComposedFunction) - return test_equal(o1.inner, o2.inner) && test_equal(o1.outer, o2.outer) -end -function test_equal(o1::Accessors.IndexLens, o2::Accessors.IndexLens) - return test_equal(o1.indices, o2.indices) -end -function test_equal(o1, o2) - return o1 == o2 -end - @testset "varnames" begin @testset "string and symbol conversion" begin vn1 = @varname x[1][2] @@ -42,39 +17,6 @@ end @test hash(vn2) == hash(vn1) end - @testset "inspace" begin - space = (:x, :y, @varname(z[1]), @varname(M[1:10, :])) - @test inspace(@varname(x), space) - @test inspace(@varname(y), space) - @test inspace(@varname(x[1]), space) - @test inspace(@varname(z[1][1]), space) - @test inspace(@varname(z[1][:]), space) - @test inspace(@varname(z[1][2:3:10]), space) - @test inspace(@varname(M[[2, 3], 1]), space) - @test_throws ErrorException inspace(@varname(M[:, 1:4]), space) - @test inspace(@varname(M[1, [2, 4, 6]]), space) - @test !inspace(@varname(z[2]), space) - @test !inspace(@varname(z), space) - end - - @testset "optic normalisation" begin - # Push the limits a bit with four optics, one of which is identity, and - # we'll parenthesise them in every possible way. (Some of these are - # going to be equal even before normalisation, but we should test that - # `normalise` works regardless of how Base or Accessors.jl define - # associativity.) - op1 = ((@o _.c) ∘ (@o _.b)) ∘ identity ∘ (@o _.a) - op2 = (@o _.c) ∘ ((@o _.b) ∘ identity) ∘ (@o _.a) - op3 = (@o _.c) ∘ (@o _.b) ∘ (identity ∘ (@o _.a)) - op4 = ((@o _.c) ∘ (@o _.b) ∘ identity) ∘ (@o _.a) - op5 = (@o _.c) ∘ ((@o _.b) ∘ identity ∘ (@o _.a)) - op6 = (@o _.c) ∘ (@o _.b) ∘ identity ∘ (@o _.a) - for op in (op1, op2, op3, op4, op5, op6) - @test AbstractPPL.normalise(op) == (@o _.c) ∘ (@o _.b) ∘ (@o _.a) - end - # Prefix and unprefix also provide further testing for normalisation. - end - @testset "construction & concretization" begin i = 1:10 j = 2:2:5 diff --git a/test/varname/optic.jl b/test/varname/optic.jl new file mode 100644 index 00000000..162cb4ff --- /dev/null +++ b/test/varname/optic.jl @@ -0,0 +1,29 @@ +module OpticTests + +using Test +using AbstractPPL + +@testset verbose = true "varname/optic.jl" begin + # Note that much of the functionality in optic.jl is tested by varname.jl (for example, + # pretty-printing VarNames essentially boils down to pretty-printing optics). So, this + # file focuses on tests that are specific to optics. + + @testset "composition" begin + @testset "with identity" begin + i = AbstractPPL.Iden() + o = getoptic(@varname(x.a.b)) + @test i ∘ i == i + @test i ∘ o == o + @test o ∘ i == o + end + + o1 = getoptic(@varname(x.a.b)) + o2 = getoptic(@varname(x[1][2])) + @test o1 ∘ o2 == getoptic(@varname(x[1][2].a.b)) + @test o2 ∘ o1 == getoptic(@varname(x.a.b[1][2])) + end + + @testset "decomposition" begin end +end + +end # module From f0ac406acfd1bc15341c98029df237e7e160eec9 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 23 Dec 2025 03:08:49 +0000 Subject: [PATCH 06/60] fix some stuff --- src/varname/optic.jl | 6 ++++-- test/varname/optic.jl | 1 - 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/varname/optic.jl b/src/varname/optic.jl index eae8adc7..10f14f49 100644 --- a/src/varname/optic.jl +++ b/src/varname/optic.jl @@ -9,17 +9,19 @@ the `.a[1][2]` part. # Public interface -This is WIP. +TODO - Base.show - Base.:(==), Base.isequal - Base.:(∘) (composition) - ohead, otail, olast, oinit (decomposition) + - to_accessors(optic) -> Accessors.Lens (recovering the old representation) - is_dynamic(optic) -> Bool (whether the optic contains any dynamic indices) - concretize(optic, val) -> AbstractOptic (resolving any dynamic indices given the value) -Not sure if we want to introduce getters and setters and BangBang-style stuff. +We probably want to introduce getters and setters. See e.g. +https://juliaobjects.github.io/Accessors.jl/stable/examples/custom_macros/ """ abstract type AbstractOptic end function Base.show(io::IO, optic::AbstractOptic) diff --git a/test/varname/optic.jl b/test/varname/optic.jl index 162cb4ff..3b472d21 100644 --- a/test/varname/optic.jl +++ b/test/varname/optic.jl @@ -7,7 +7,6 @@ using AbstractPPL # Note that much of the functionality in optic.jl is tested by varname.jl (for example, # pretty-printing VarNames essentially boils down to pretty-printing optics). So, this # file focuses on tests that are specific to optics. - @testset "composition" begin @testset "with identity" begin i = AbstractPPL.Iden() From e7dd774797b3403b510fe3c8f0d6cf870a8771c5 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 23 Dec 2025 03:26:11 +0000 Subject: [PATCH 07/60] fix x[1: :] --- src/varname/optic.jl | 13 +++++++++---- src/varname/varname.jl | 6 +++++- test/runtests.jl | 1 + test/varname/optic.jl | 2 +- test/varname/varname.jl | 22 ++++++++++++++++++++++ 5 files changed, 38 insertions(+), 6 deletions(-) create mode 100644 test/varname/varname.jl diff --git a/src/varname/optic.jl b/src/varname/optic.jl index 10f14f49..8f10c8d3 100644 --- a/src/varname/optic.jl +++ b/src/varname/optic.jl @@ -51,18 +51,19 @@ When parsing VarNames, we convert such indices into subtypes of `DynamicIndex`, mark them as requiring concretisation. """ abstract type DynamicIndex end +abstract type DynamicSingleIndex <: DynamicIndex end _is_dynamic_idx(::DynamicIndex) = true _is_dynamic_idx(::Any) = false # Fallback for all other indices concretize(@nospecialize(ix::Any), ::Any, ::Any) = ix _pretty_print_index(x::Any) = string(x) -struct DynamicBegin <: DynamicIndex end +struct DynamicBegin <: DynamicSingleIndex end concretize(::DynamicBegin, val, dim::Nothing) = Base.firstindex(val) concretize(::DynamicBegin, val, dim) = Base.firstindex(val, dim) _pretty_print_index(::DynamicBegin) = "begin" -struct DynamicEnd <: DynamicIndex end +struct DynamicEnd <: DynamicSingleIndex end concretize(::DynamicEnd, val, dim::Nothing) = Base.lastindex(val) concretize(::DynamicEnd, val, dim) = Base.lastindex(val, dim) _pretty_print_index(::DynamicEnd) = "end" @@ -72,7 +73,9 @@ concretize(::DynamicColon, val, dim::Nothing) = Base.firstindex(val):Base.lastin concretize(::DynamicColon, val, dim) = Base.firstindex(val, dim):Base.lastindex(val, dim) _pretty_print_index(::DynamicColon) = ":" -struct DynamicRange{T1,T2} <: DynamicIndex +struct DynamicRange{ + T1<:Union{Real,DynamicSingleIndex},T2<:Union{Real,DynamicSingleIndex} +} <: DynamicIndex start::T1 stop::T2 end @@ -95,8 +98,10 @@ property access after this indexing operation. struct Index{I<:Tuple,C<:AbstractOptic} <: AbstractOptic ix::I child::C + function Index(ix::Tuple, child::C=Iden()) where {C<:AbstractOptic} + return new{typeof(ix),C}(ix, child) + end end -Index(ix::Tuple, child::C=Iden()) where {C<:AbstractOptic} = Index{typeof(ix),C}(ix, child) Base.:(==)(a::Index, b::Index) = a.ix == b.ix && a.child == b.child Base.isequal(a::Index, b::Index) = a == b diff --git a/src/varname/varname.jl b/src/varname/varname.jl index 40124385..43b5f837 100644 --- a/src/varname/varname.jl +++ b/src/varname/varname.jl @@ -101,7 +101,7 @@ An exception thrown when a variable name expression cannot be parsed by the [`@varname`](@ref) macro. """ struct VarNameParseException <: Exception - expr::Expr + expr end function Base.showerror(io::IO, e::VarNameParseException) return print(io, "malformed variable name `$(e.expr)`") @@ -224,6 +224,10 @@ macro varname(expr, concretize::Bool=false) unconcretized_vn end end +function _varname(@nospecialize(expr::Any), ::Any) + # fallback: it's not a variable! + throw(VarNameParseException(expr)) +end function _varname(sym::Symbol, inner_expr) return :($VarName{$(QuoteNode(sym))}($inner_expr)), sym end diff --git a/test/runtests.jl b/test/runtests.jl index a4047777..06e91615 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,6 +9,7 @@ const GROUP = get(ENV, "GROUP", "All") include("Aqua.jl") include("abstractprobprog.jl") include("varname/optic.jl") + include("varname/varname.jl") # include("varname.jl") # include("hasvalue.jl") end diff --git a/test/varname/optic.jl b/test/varname/optic.jl index 3b472d21..933c4976 100644 --- a/test/varname/optic.jl +++ b/test/varname/optic.jl @@ -3,7 +3,7 @@ module OpticTests using Test using AbstractPPL -@testset verbose = true "varname/optic.jl" begin +@testset "varname/optic.jl" verbose = true begin # Note that much of the functionality in optic.jl is tested by varname.jl (for example, # pretty-printing VarNames essentially boils down to pretty-printing optics). So, this # file focuses on tests that are specific to optics. diff --git a/test/varname/varname.jl b/test/varname/varname.jl new file mode 100644 index 00000000..e7e80457 --- /dev/null +++ b/test/varname/varname.jl @@ -0,0 +1,22 @@ +module VarNameTests + +using AbstractPPL +using Test + +@testset "varname/varname.jl" verbose = true begin + # TODO + + @testset "errors on nonsensical inputs" begin + # Note: have to wrap in eval to avoid throwing an error before the actual test + errmsg = "malformed variable name" + @test_throws errmsg eval(:(@varname(1))) + @test_throws errmsg eval(:(@varname(x + y))) + # This doesn't fail to parse, but it will throw a MethodError because you can't + # construct a DynamicColon with a DynamicColon as an argument + # TODO(penelopeysm): I would like to test this, but JuliaFormatter reformats + # this into x[1::] which then fails to parse. Grr. + # @test_throws MethodError eval(:(@varname(x[1: :]))) + end +end + +end # module VarNameTests From bb93f3dee31265c0998a4ba26be564ceb86489d1 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 23 Dec 2025 17:14:28 +0000 Subject: [PATCH 08/60] Fix dynamic indices fully --- Project.toml | 2 + src/AbstractPPL.jl | 37 ++++----- src/varname/optic.jl | 111 +++++++++++++++------------ src/varname/varname.jl | 165 ++++++++++++++++++++++++----------------- 4 files changed, 176 insertions(+), 139 deletions(-) diff --git a/Project.toml b/Project.toml index d6db8338..99c04fb8 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" @@ -27,6 +28,7 @@ DensityInterface = "0.4" Distributions = "0.25" JSON = "0.19 - 0.21, 1" LinearAlgebra = "<0.0.1, 1.10" +MacroTools = "0.5.16" Random = "1.6" StatsBase = "0.32, 0.33, 0.34" julia = "1.10" diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index 3bea9757..f31da95c 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -1,26 +1,21 @@ module AbstractPPL -# VarName -export VarName, - getsym, - getoptic, - inspace, - subsumes, - subsumedby, - varname, - vsym, - @varname, - @vsym, - index_to_dict, - dict_to_index, - varname_to_string, - string_to_varname, - prefix, - unprefix, - getvalue, - hasvalue, - varname_leaves, - varname_and_value_leaves +# VarName and Optic functions +export VarName, getsym, getoptic, concretize, is_dynamic, @varname, @opticof +export AbstractOptic, Iden, Index, Property, DynamicIndex, ohead, otail, olast, oinit + +# subsumes, +# subsumedby, +# index_to_dict, +# dict_to_index, +# varname_to_string, +# string_to_varname, +# prefix, +# unprefix, +# getvalue, +# hasvalue, +# varname_leaves, +# varname_and_value_leaves # Abstract model functions export AbstractProbabilisticProgram, diff --git a/src/varname/optic.jl b/src/varname/optic.jl index 8f10c8d3..b5a86c58 100644 --- a/src/varname/optic.jl +++ b/src/varname/optic.jl @@ -1,4 +1,5 @@ using Accessors: Accessors +using MacroTools: MacroTools """ AbstractOptic @@ -47,47 +48,69 @@ concretize(i::Iden, ::Any) = i An abstract type representing dynamic indices such as `begin`, `end`, and `:`. These indices are things which cannot be resolved until we provide the value that is being indexed into. -When parsing VarNames, we convert such indices into subtypes of `DynamicIndex`, and we later -mark them as requiring concretisation. +When parsing VarNames, we convert such indices into subtypes of `DynamicIndex`. + +Because a `DynamicIndex` cannot be resolved until we have the value being indexed into, it +is actually a wrapper around a function that, when called on the value, returns the concrete +index. + +For example: + +- the index `begin` is turned into `DynamicIndex(:begin, (val) -> Base.firstindex(val))`. +- the index `1:end` is turned into `DynamicIndex(:(1:end), (val) -> 1:Base.lastindex(val))`. + +The `expr` field stores the original expression solely for pretty-printing purposes. """ -abstract type DynamicIndex end -abstract type DynamicSingleIndex <: DynamicIndex end -_is_dynamic_idx(::DynamicIndex) = true -_is_dynamic_idx(::Any) = false -# Fallback for all other indices -concretize(@nospecialize(ix::Any), ::Any, ::Any) = ix -_pretty_print_index(x::Any) = string(x) - -struct DynamicBegin <: DynamicSingleIndex end -concretize(::DynamicBegin, val, dim::Nothing) = Base.firstindex(val) -concretize(::DynamicBegin, val, dim) = Base.firstindex(val, dim) -_pretty_print_index(::DynamicBegin) = "begin" - -struct DynamicEnd <: DynamicSingleIndex end -concretize(::DynamicEnd, val, dim::Nothing) = Base.lastindex(val) -concretize(::DynamicEnd, val, dim) = Base.lastindex(val, dim) -_pretty_print_index(::DynamicEnd) = "end" - -struct DynamicColon <: DynamicIndex end -concretize(::DynamicColon, val, dim::Nothing) = Base.firstindex(val):Base.lastindex(val) -concretize(::DynamicColon, val, dim) = Base.firstindex(val, dim):Base.lastindex(val, dim) -_pretty_print_index(::DynamicColon) = ":" - -struct DynamicRange{ - T1<:Union{Real,DynamicSingleIndex},T2<:Union{Real,DynamicSingleIndex} -} <: DynamicIndex - start::T1 - stop::T2 +struct DynamicIndex{E<:Union{Expr,Symbol},F} + expr::E + f::F end -function concretize(dr::DynamicRange, axis, dim) - start = dr.start isa DynamicIndex ? concretize(dr.start, axis, dim) : dr.start - stop = dr.stop isa DynamicIndex ? concretize(dr.stop, axis, dim) : dr.stop - return start:stop +function _make_dynamicindex_expr(symbol::Symbol, dim::Union{Nothing,Int}) + # NOTE(penelopeysm): We could just use `:end` instead of Symbol(:end), but the former + # messes up syntax highlighting with Treesitter + # https://github.com/tree-sitter/tree-sitter-julia/issues/104 + if symbol === Symbol(:begin) + func = dim === nothing ? :(Base.firstindex) : :(Base.Fix2(firstindex, $dim)) + return :(DynamicIndex($(QuoteNode(symbol)), $func)) + elseif symbol === Symbol(:end) + func = dim === nothing ? :(Base.lastindex) : :(Base.Fix2(lastindex, $dim)) + return :(DynamicIndex($(QuoteNode(symbol)), $func)) + else + # Just a variable. + return symbol + end end -function _pretty_print_index(dr::DynamicRange) - return "$(_pretty_print_index(dr.start)):$(_pretty_print_index(dr.stop))" +function _make_dynamicindex_expr(expr::Expr, dim::Union{Nothing,Int}) + @gensym val + replaced_expr = MacroTools.postwalk(x -> replace_begin_and_end(x, val, dim), expr) + return if replaced_expr == expr + # Nothing to replace, just use the original expr. + expr + else + :(DynamicIndex($(QuoteNode(expr)), $val -> $replaced_expr)) + end end +# Replace all instances of `begin` in `expr` with `_firstindex_dim(val, dim)` and +# all instances of `end` with `_lastindex_dim(val, dim)`. +replace_begin_and_end(x, ::Any, ::Any) = x +function replace_begin_and_end(x::Symbol, val_sym, dim) + return if (x === :begin) + dim === nothing ? :(Base.firstindex($val_sym)) : :(Base.firstindex($val_sym, $dim)) + elseif (x === :end) + dim === nothing ? :(Base.lastindex($val_sym)) : :(Base.lastindex($val_sym, $dim)) + else + # It's some other symbol; we need to escape it to allow interpolation. + esc(x) + end +end +_pretty_string_index(ix) = string(ix) +_pretty_string_index(::Colon) = ":" +_pretty_string_index(di::DynamicIndex) = "DynamicIndex($(di.expr))" + +_concretize_index(idx::Any, ::Any) = idx +_concretize_index(idx::DynamicIndex, val) = idx.f(val) + """ Index(ix, child=Iden()) @@ -106,7 +129,7 @@ end Base.:(==)(a::Index, b::Index) = a.ix == b.ix && a.child == b.child Base.isequal(a::Index, b::Index) = a == b function _pretty_print_optic(io::IO, idx::Index) - ixs = join(map(_pretty_print_index, idx.ix), ", ") + ixs = join(map(_pretty_string_index, idx.ix), ", ") print(io, "[$(ixs)]") return _pretty_print_optic(io, idx.child) end @@ -118,19 +141,9 @@ function to_accessors(idx::Index) Base.ComposedFunction(to_accessors(idx.child), ilens) end end -is_dynamic(idx::Index) = any(_is_dynamic_idx, idx.ix) || is_dynamic(idx.child) +is_dynamic(idx::Index) = any(ix -> ix isa DynamicIndex, idx.ix) || is_dynamic(idx.child) function concretize(idx::Index, val) - concretized_indices = if length(idx.ix) == 0 - [] - elseif length(idx.ix) == 1 - # If there's only one index, it's linear indexing. This code is mostly lifted from - # Accessors.jl. - [concretize(only(idx.ix), val, nothing)] - else - # If there are multiple indices, then each index corresponds to a different - # dimension. - [concretize(ix, val, dim) for (dim, ix) in enumerate(idx.ix)] - end + concretized_indices = map(Base.Fix2(_concretize_index, val), idx.ix) inner_concretized = concretize(idx.child, view(val, concretized_indices...)) return Index((concretized_indices...,), inner_concretized) end diff --git a/src/varname/varname.jl b/src/varname/varname.jl index 43b5f837..0416673e 100644 --- a/src/varname/varname.jl +++ b/src/varname/varname.jl @@ -107,6 +107,31 @@ function Base.showerror(io::IO, e::VarNameParseException) return print(io, "malformed variable name `$(e.expr)`") end +""" + VarNameConcretizationException() + +When constructing a `VarName` using [`@varname`](@ref) (or [`@opticof`](@ref)), we allow +for interpolation of the top-level symbol, e.g. using `name = :x; @varname(\$name)`. However, +if this is done, it is not possible to automatically concretize the resulting `VarName` by +passing `true` as the second argument to `@varname`. + +Because macros are confusing, this is probably worth more explanation. For example, consider +the user input `name = :x; @varname(\$name, true)`. + +Without concretization, we can easily handle this as `VarName{name}(Iden())`. `name` is then +resolved outside the macro to produce `VarName{:x}(Iden())`. However, to correctly +concretize this, we would need to generate the output `concretize(VarName{name}(), x)`; +i.e., we need to know at macro-expansion time that `name` evaluates to `:x`. This is not +possible given the expression `\$name` alone, which is why this error is thrown. +""" +struct VarNameConcretizationException <: Exception end +function Base.showerror(io::IO, ::VarNameConcretizationException) + return print( + io, + "cannot automatically concretize VarName with interpolated top-level symbol; call `concretize(vn, val)` manually instead", + ) +end + """ @varname(expr, concretize=false) @@ -134,28 +159,25 @@ x.a[1:3].b[2] # Dynamic indices -Some expressions may involve dynamic indices, e.g., `begin`, `end`, and `:`. These indices +Some expressions may involve dynamic indices, specifically, `begin`, `end`. These indices cannot be resolved, or 'concretized', until the value being indexed into is known. By -default, `@varname(...)` will not automatically concretize these expressions, and thus -the resulting `VarName` will contain markers for these. +default, `@varname(...)` will not automatically concretize these expressions, and thus the +resulting `VarName` will contain markers for these. -```jldoctest -julia> # VarNames are pretty-printed, so at first glance, it's not special... - vn = @varname(x[end]) -x[end] +Note that colons are not considered dynamic. -julia> # But if you look under the hood, you can see that the index is dynamic. - vn = @varname(x[end]); getoptic(vn).ix -(DynamicEnd(),) +```jldoctest +julia> vn = @varname(x[end]) +x[DynamicIndex(end)] -julia> vn = @varname(x[1:end, end]); getoptic(vn).ix -(DynamicRange{Int64, DynamicEnd}(1, DynamicEnd()), DynamicEnd()) +julia> vn = @varname(x[1, end-1]) +x[1, DynamicIndex(end - 1)] ``` You can detect whether a `VarName` contains any dynamic indices using `is_dynamic(vn)`: ```jldoctest -julia> vn = @varname(x[1:end, end]); is_dynamic(vn) +julia> vn = @varname(x[1, end-1]); AbstractPPL.is_dynamic(vn) true ``` @@ -165,13 +187,13 @@ To concretize such expressions, you can call `concretize(vn, val)` on the result ```jldoctest julia> x = randn(2, 3); -julia> vn = @varname(x[1:end, end]); vn2 = concretize(vn, x) -x[1:2, 3][1:2] +julia> vn = @varname(x[1, end-1]); vn2 = AbstractPPL.concretize(vn, x) +x[1, 2] -julia> getoptic(vn2).ix -((1:2), 3) +julia> getoptic(vn2).ix # Just an ordinary tuple. +(1, 2) -julia> is_dynamic(vn2) +julia> AbstractPPL.is_dynamic(vn2) false ``` @@ -183,7 +205,7 @@ top-level symbol to look up the value used for concretization. julia> x = randn(2, 3); julia> @varname(x[1:end, end][:], true) -x[1:2, 3][1:2] +x[1:2, 3][:] ``` # Interpolation @@ -208,17 +230,20 @@ For indices, you don't need to use `\$` to interpolate, just use the variable di julia> ix = 2; @varname(x[ix]) x[2] ``` + +However, if the top-level symbol is interpolated, automatic concretization is not +possible: + +```jldoctest +julia> name = :x; @varname(\$name[1:end], true) +ERROR: LoadError: cannot automatically concretize VarName with interpolated top-level symbol; call `concretize(vn, val)` manually instead +[...] +``` """ macro varname(expr, concretize::Bool=false) unconcretized_vn, sym = _varname(expr, :(Iden())) return if concretize - if sym === nothing - throw( - ArgumentError( - "cannot automatically concretize VarName with interpolated top-level symbol; call `concretize(vn, val)` manually instead", - ), - ) - end + sym === nothing && throw(VarNameConcretizationException()) :(concretize($unconcretized_vn, $(esc(sym)))) else unconcretized_vn @@ -237,13 +262,6 @@ function _varname(expr::Expr, inner_expr) # it means that there are no further property/indexing accesses (because otherwise # expr.head would be :ref or :.) Thus we don't need to recurse further, and we can # just return `inner_expr` as-is. - # TODO(penelopeysm): Is there a way to make auto-concretisation work here? To - # be clear, what we want is something like the following to work: - # name = :hello; hello = rand(3); @varname($name[:], true) - # I've tried every combination of `esc`, `QuoteNode`, and `$` I can think of, but - # with no success yet. It didn't work with old AbstractPPL either ("syntax: - # all-underscore identifiers are write-only and their values cannot be used in - # expressions"); at least now we give a more sensible error message. sym_expr = expr.args[1] return :($VarName{$(sym_expr)}($inner_expr)), nothing else @@ -251,12 +269,11 @@ function _varname(expr::Expr, inner_expr) sym = _handle_property(expr.args[2], expr) :(Property{$(sym)}($inner_expr)) elseif expr.head == :ref - ixs = map(first ∘ _handle_index, expr.args[2:end]) - # TODO(penelopeysm): Technically, here we could track whether any of the indices are - # dynamic, and store this for later use. - # isdyn = any(last, ixs_and_isdyn) - # What we do now (generate the dynamic VarName first, and then later check whether - # it needs concretization) is slightly inefficient. + original_ixs = expr.args[2:end] + is_single_index = length(original_ixs) == 1 + ixs = map(enumerate(original_ixs)) do (dim, ix) + _handle_index(ix, is_single_index ? nothing : dim) + end :(Index(tuple($(ixs...)), $inner_expr)) else # some other expression we can't parse @@ -285,36 +302,46 @@ function _handle_property(::Any, original_expr) throw(VarNameParseException(original_expr)) end -_handle_index(ix::Int) = ix, false -function _handle_index(ix::Symbol) - # NOTE(penelopeysm): We could just use `:end` instead of Symbol(:end), but the former - # messes up syntax highlighting with Treesitter - # https://github.com/tree-sitter/tree-sitter-julia/issues/104 - if ix == Symbol(:end) - return :(DynamicEnd()), true - elseif ix == Symbol(:begin) - return :(DynamicBegin()), true - elseif ix == :(:) - return :(DynamicColon()), true - else - # an interpolated symbol - return ix, false - end -end -function _handle_index(ix::Expr) - if Meta.isexpr(ix, :call, 3) && ix.args[1] == :(:) - # This is a range - start, isdyn = _handle_index(ix.args[2]) - stop, isdyn2 = _handle_index(ix.args[3]) - if isdyn || isdyn2 - return :(DynamicRange($start, $stop)), true - else - return :(($start):($stop)), false - end +_handle_index(ix::Int, ::Any) = ix +_handle_index(ix::Symbol, dim) = _make_dynamicindex_expr(ix, dim) +_handle_index(ix::Expr, dim) = _make_dynamicindex_expr(ix, dim) + +""" + @opticof(expr, concretize=false) + +Extract the optic from `@varname(expr, concretize)`. This is a thin wrapper around +`getoptic(@varname(...))`. + +If you don't need to concretize, you should use `_` as the top-level symbol to +indicate that it is not relevant: + +```jldoctest +julia> AbstractPPL.@opticof(_.a.b) +Optic(.a.b) +``` + +Only if you need to concretize should you provide a real variable name (in which case +it is then used to look up the value for concretization): + +```jldoctest +julia> x = randn(3, 4); AbstractPPL.@opticof(x[1:end, end], true) +Optic([1:3, 4]) +``` + +Note that concretization with `@opticof` has the same limitations as with `@varname`, +specifically, if the top-level symbol is interpolated, automatic concretization is not +possible. +""" +macro opticof(expr, concretize::Bool=false) + # This implementation is a bit ugly, as it copies the logic from `@varname`. However, + # getting the output of `@varname` and then processing it is a bit tricky, specifically + # when concretization is involved (because the top-level value must be escaped, but not + # anything else!). So it's easier to just duplicate the logic here. + unconcretized_vn, sym = _varname(expr, :(Iden())) + return if concretize + sym === nothing && throw(VarNameConcretizationException()) + :(getoptic(concretize($unconcretized_vn, $(esc(sym))))) else - # Some other expression. We don't want to parse this any further, but we also don't - # want to error, because it may well be an expression that evaluates to a valid - # index. - return ix, false + :(getoptic($unconcretized_vn)) end end From 5e9401ada4ffec92a4dbaa16458ad32fc1a111d2 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 23 Dec 2025 17:26:44 +0000 Subject: [PATCH 09/60] changelog --- HISTORY.md | 36 ++++++++++++++++++++++++++++++++++++ Project.toml | 2 +- 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index 35f1dc8a..66f95cce 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,39 @@ +## 0.14.0 + +This release overhauls the `VarName` type. +Much of the external API for traversing and manipulating `VarName`s has been preserved, but there are significant changes: + +**Internal representation** + +The `optic` field of VarName now uses our hand-rolled optic types, which are subtypes of `AbstractPPL.AbstractOptic`. +Previously these were optics from Accessors.jl. + +This change was made for two reasons: firstly, it is easier to provide custom behaviour for VarNames as we avoid running into possible type piracy issues, and secondly, the linked-list data structure used in `AbstractOptic` is easier to work with than Accessors.jl, which used `Base.ComposedFunction` to represent optic compositions and required a lot of care to avoid issues with associativity and identity optics. + +To construct an optic, the easiest way is to use the `@opticof` macro, which superficially behaves similarly to `Accessors.@optic` (for example, you can write `@opticof _[1].y.z`), but also supports automatic concretization by passing a second parameter (just like `@varname`). + +**Concretization** + +VarNames using 'dynamic' indices, i.e., `begin` and `end`, are now instantiated in a 'dynamic' form, meaning that these indices are unresolved. +These indices need to be resolved, or concretized, against the actual container. +For example, `@varname(x[end])` is dynamic, but when concretized against `x = randn(3)`, this becomes `@varname(x[3])`. +This can be done using `concretize(varname, x)`. + +The idea of concretization is not new to AbstractPPL. +However, there are some differences: + + - Colons are no longer concretized: they *always* remain as Colons, even after calling `concretize`. + - Previously, AbstractPPL would refuse to allow you to construct unconcretized versions of `begin` and `end`. This is no longer the case; you can now create such VarNames in their unconcretized forms. + This is useful, for example, when indexing into a chain that contains `x` as a variable-length vector. This change allows you to write `chain[@varname(x[end])]` without having AbstractPPL throw an error. + +**Interface** + +The `vsym` function (and `@vsym`) has been removed; you should use `getsym(vn)` instead. + +The `Base.get` and `Base.set!` methods for VarNames have been removed (these were responsible for method ambiguities). + +The `inspace` function has been removed (it used to be relevant for Turing's old Gibbs sampler; but now it no longer serves any use). + ## 0.13.6 Fix a missing qualifier in AbstractPPLDistributionsExt. diff --git a/Project.toml b/Project.toml index 99c04fb8..78e90428 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" keywords = ["probablistic programming"] license = "MIT" desc = "Common interfaces for probabilistic programming" -version = "0.13.6" +version = "0.14.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From 6d73a5a635ee79b0505b4b2f4a23ee5b21d4f6c9 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 23 Dec 2025 18:15:41 +0000 Subject: [PATCH 10/60] Docs --- docs/Project.toml | 3 + docs/make.jl | 2 +- docs/src/api.md | 64 ----------- docs/src/pplapi.md | 20 ++++ docs/src/varname.md | 170 +++++++++++++++++++++++++++++ ext/AbstractPPLDistributionsExt.jl | 3 + src/AbstractPPL.jl | 22 +++- src/varname/optic.jl | 35 ++++-- src/varname/varname.jl | 37 +++++-- 9 files changed, 266 insertions(+), 90 deletions(-) delete mode 100644 docs/src/api.md create mode 100644 docs/src/pplapi.md create mode 100644 docs/src/varname.md diff --git a/docs/Project.toml b/docs/Project.toml index 15b2ec43..9aed942a 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -3,3 +3,6 @@ AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[sources] +AbstractPPL = {path = "../"} diff --git a/docs/make.jl b/docs/make.jl index a94901da..abab3f69 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -9,7 +9,7 @@ DocMeta.setdocmeta!(AbstractPPL, :DocTestSetup, :(using AbstractPPL); recursive= makedocs(; sitename="AbstractPPL", modules=[AbstractPPL, Base.get_extension(AbstractPPL, :AbstractPPLDistributionsExt)], - pages=["index.md", "api.md", "interface.md"], + pages=["index.md", "varname.md", "pplapi.md", "interface.md"], checkdocs=:exports, doctest=false, ) diff --git a/docs/src/api.md b/docs/src/api.md deleted file mode 100644 index 9f289ab5..00000000 --- a/docs/src/api.md +++ /dev/null @@ -1,64 +0,0 @@ -# API - -## VarNames - -```@docs -VarName -getsym -getoptic -inspace -subsumes -subsumedby -vsym -@varname -@vsym -``` - -## VarName prefixing and unprefixing - -```@docs -prefix -unprefix -``` - -## Extracting values corresponding to a VarName - -```@docs -hasvalue -getvalue -``` - -## Splitting VarNames up into components - -```@docs -varname_leaves -varname_and_value_leaves -``` - -## VarName serialisation - -```@docs -index_to_dict -dict_to_index -varname_to_string -string_to_varname -``` - -## Abstract model functions - -```@docs -AbstractProbabilisticProgram -condition -decondition -fix -unfix -logdensityof -AbstractContext -evaluate!! -``` - -## Abstract traces - -```@docs -AbstractModelTrace -``` diff --git a/docs/src/pplapi.md b/docs/src/pplapi.md new file mode 100644 index 00000000..492ab294 --- /dev/null +++ b/docs/src/pplapi.md @@ -0,0 +1,20 @@ +# Probabilistic programming API + +## Abstract model functions + +```@docs +AbstractProbabilisticProgram +condition +decondition +fix +unfix +logdensityof +AbstractContext +evaluate!! +``` + +## Abstract traces + +```@docs +AbstractModelTrace +``` diff --git a/docs/src/varname.md b/docs/src/varname.md new file mode 100644 index 00000000..66a8b72f --- /dev/null +++ b/docs/src/varname.md @@ -0,0 +1,170 @@ +# VarNames and optics + +## VarNames: an overview + +One of the most important parts of AbstractPPL.jl is the `VarName` type, which is used throughout the TuringLang ecosystem to represent names of random variables. + +Fundamentally, a `VarName` comprises a symbol (which represents the name of the variable itself) and an optic (which tells us which part of the variable we might be interested in). +For example, `x.a[1]` means the first element of the field `a` of the variable `x`. +Here, `x` is the symbol, and `.a[1]` is the optic. + +VarNames can be created using the `@varname` macro: + +```@example vn +using AbstractPPL + +vn = @varname(x.a[1]) +``` + +```@docs +VarName +@varname +``` + +You can obtain the components of a `VarName` using the `getsym` and `getoptic` functions: + +```@example vn +getsym(vn), getoptic(vn) +``` + +```@docs +getsym +getoptic +``` + +## Dynamic indices + +VarNames may contain 'dynamic' indices, that is, indices whose meaning is not known until they are resolved against a specific value. +For example, `x[end]` refers to the last element of `x`; but we don't know what that means until we know what `x` is. + +Specifically, `begin` and `end` symbols in indices are treated as dynamic indices. +This is also true for any expression that contains `begin` or `end`, such as `end-1` or `1:3:end`. + +Dynamic indices are represented using an internal type, `AbstractPPL.DynamicIndex`. + +```@example vn +vn_dyn = @varname(x[1:2:end]) +``` + +You can detect whether a VarName contains dynamic indices using the `is_dynamic` function: + +```@example vn +is_dynamic(vn_dyn) +``` + +```@docs +is_dynamic +``` + +These dynamic indices can be resolved, or _concretized_, by passing a specific value to the `concretize` function: + +```@example vn +x = randn(5) +vn_conc = concretize(vn_dyn, x) +``` + +```@docs +concretize +``` + +## Optics + +The optics used in AbstractPPL.jl are represented as a linked list. +For example, the optic `.a[1]` is a `Property` optic that contains an `Index` optic as its child. +That means that the 'elements' of the linked list can be read from left-to-right: + +``` +Property{:a} -> Index{1} -> Iden +``` + +All optic linked lists are terminated with an `Iden` optic, which represents the identity function. + +```@example vn +optic = getoptic(@varname x.a[1]) +dump(optic) +``` + +```@docs +AbstractOptic +Property +Index +Iden +``` + +Instead of calling `getoptic(@varname(...))`, you can directly use the [`@opticof`](@ref) macro to create optics: + +```@example vn +optic = @opticof(_.a[1]) +``` + +```@docs +@opticof +``` + +## Composing and decomposing optics + +If you have two optics, you can compose them using the `∘` operator: + +```@example vn +optic1 = @opticof(_.a) +optic2 = @opticof(_[1]) +composed = optic2 ∘ optic1 +``` + +Notice the order of composition here, which can be counterintuitive: `optic2 ∘ optic1` means "first apply `optic1`, then apply `optic2`", and thus this represents the optic `.a[1]` (not `.[1].a`). + +```@docs +Base.:∘(::AbstractOptic, ::AbstractOptic) +``` + +`Base.cat(optics...)` is also provided, which composes optics in a more intuitive sense (indeed, if you think of an optic as a linked list, this can be thought of as concatenating the lists). +The following is equivalent to the previous example: + +```@example vn +composed2 = Base.cat(optic1, optic2) +``` + +```@docs +Base.cat(::AbstractOptic...) +``` + +Several functions are provided to decompose optics, which all stem from their linked-list structure. +Their names directly mirror Haskell's functions for decomposing lists, but are prefixed with `o`: + +```@docs +ohead +otail +oinit +olast +``` + +For example, `ohead` returns the first element of the optic linked list, and `otail` returns the rest of the list after removing the head: + +```@example vn +optic = @opticof(_.a[1].b[2]) +ohead(optic), otail(optic) +``` + +Convesely, `oinit` returns the optic linked list without its last element, and `olast` returns the last element: + +```@example vn +oinit(optic), olast(optic) +``` + +If the optic only has a single element, then `oinit` and `otail` return `Iden`, while `ohead` and `olast` return the optic itself: + +```@example vn +optic_single = @opticof(_.a) +oinit(optic_single), olast(optic_single), ohead(optic_single), otail(optic_single) +``` + +## Converting VarNames to optics and back + +Sometimes it is useful to treat a VarName's top level symbol as if it were part of the optic. +For example, when indexing into a NamedTuple `nt`, we might want to treat the entire VarName `x.a[1]` as an optic that can be applied to a NamedTuple: i.e., we want to access the `nt.x` field rather than the variable `x` itself. +This can be achieved with: + +```@docs +varname_to_optic +optic_to_varname +``` diff --git a/ext/AbstractPPLDistributionsExt.jl b/ext/AbstractPPLDistributionsExt.jl index d10b748c..2824e75e 100644 --- a/ext/AbstractPPLDistributionsExt.jl +++ b/ext/AbstractPPLDistributionsExt.jl @@ -49,6 +49,8 @@ This decision may be revisited in the future. module AbstractPPLDistributionsExt +#= + using AbstractPPL: AbstractPPL, VarName, Accessors, LinearAlgebra using Distributions: Distributions using LinearAlgebra: Cholesky, LowerTriangular, UpperTriangular @@ -322,4 +324,5 @@ function AbstractPPL.getvalue( end end +=# end diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index f31da95c..f657131d 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -1,8 +1,24 @@ module AbstractPPL -# VarName and Optic functions -export VarName, getsym, getoptic, concretize, is_dynamic, @varname, @opticof -export AbstractOptic, Iden, Index, Property, DynamicIndex, ohead, otail, olast, oinit +# Optics +export AbstractOptic, + Iden, + Index, + Property, + ohead, + otail, + olast, + oinit, + # VarName + VarName, + getsym, + getoptic, + concretize, + is_dynamic, + @varname, + @opticof, + varname_to_optic, + optic_to_varname # subsumes, # subsumedby, diff --git a/src/varname/optic.jl b/src/varname/optic.jl index b5a86c58..df591de8 100644 --- a/src/varname/optic.jl +++ b/src/varname/optic.jl @@ -76,8 +76,8 @@ function _make_dynamicindex_expr(symbol::Symbol, dim::Union{Nothing,Int}) func = dim === nothing ? :(Base.lastindex) : :(Base.Fix2(lastindex, $dim)) return :(DynamicIndex($(QuoteNode(symbol)), $func)) else - # Just a variable. - return symbol + # Just a variable; but we need to escape it to allow interpolation. + return esc(symbol) end end function _make_dynamicindex_expr(expr::Expr, dim::Union{Nothing,Int}) @@ -188,10 +188,10 @@ end Compose two `AbstractOptic`s together. ```jldoctest -julia> p1 = Property{:a}(Index((1,))) +julia> p1 = @opticof(_.a[1]) Optic(.a[1]) -julia> p2 = Property{:b}(Index((2,3))) +julia> p2 = @opticof(_.b[2, 3]) Optic(.b[2, 3]) julia> p1 ∘ p2 @@ -214,6 +214,17 @@ function Base.:(∘)(outer::AbstractOptic, inner::AbstractOptic) end end +""" + cat(optics::AbstractOptic...) + +Compose multiple `AbstractOptic`s together. The optics should be provided from +innermost to outermost, i.e., `cat(o1, o2, o3)` corresponds to `o3 ∘ o2 ∘ o1`. + +""" +function Base.cat(optics::AbstractOptic...) + return foldl((a, b) -> b ∘ a, optics; init=Iden()) +end + """ ohead(optic::AbstractOptic) @@ -221,10 +232,10 @@ Get the innermost layer of an optic. For all optics, we have that `otail(optic) ohead(optic) == optic`. ```jldoctest -julia> ohead(getoptic(@varname(x.a[1][2]))) +julia> ohead(@opticof _.a[1][2]) Optic(.a) -julia> ohead(getoptic(@varname(x))) +julia> ohead(@opticof _) Optic() ``` """ @@ -239,10 +250,10 @@ Get everything but the innermost layer of an optic. For all optics, we have that `otail(optic) ∘ ohead(optic) == optic`. ```jldoctest -julia> otail(getoptic(@varname(x.a[1][2]))) +julia> otail(@opticof _.a[1][2]) Optic([1][2]) -julia> otail(getoptic(@varname(x))) +julia> otail(@opticof _) Optic() ``` """ @@ -257,10 +268,10 @@ Get the outermost layer of an optic. For all optics, we have that `olast(optic) oinit(optic) == optic`. ```jldoctest -julia> olast(getoptic(@varname(x.a[1][2]))) +julia> olast(@opticof _.a[1][2]) Optic([2]) -julia> olast(getoptic(@varname(x))) +julia> olast(@opticof _) Optic() ``` """ @@ -287,10 +298,10 @@ Get everything but the outermost layer of an optic. For all optics, we have that `olast(optic) ∘ oinit(optic) == optic`. ```jldoctest -julia> oinit(getoptic(@varname(x.a[1][2]))) +julia> oinit(@opticof _.a[1][2]) Optic(.a[1]) -julia> oinit(getoptic(@varname(x))) +julia> oinit(@opticof _) Optic() ``` """ diff --git a/src/varname/varname.jl b/src/varname/varname.jl index 0416673e..09b8203a 100644 --- a/src/varname/varname.jl +++ b/src/varname/varname.jl @@ -81,8 +81,8 @@ Base.Symbol(vn::VarName) = Symbol(string(vn)) concretize(vn::VarName, x) Return `vn` concretized on `x`, i.e. any information related to the runtime shape of `x` is -evaluated. This will convert any `begin`, `end`, or `:` indices in `vn` to concrete indices -with information about the length of the dimension being sliced. +evaluated. This will convert any `begin` and `end` indices in `vn` to concrete indices with +information about the length of the dimension being indexed into. """ concretize(vn::VarName{sym}, x) where {sym} = VarName{sym}(concretize(getoptic(vn), x)) @@ -135,8 +135,7 @@ end """ @varname(expr, concretize=false) -Create a [`VarName`](@ref) given an expression `expr` representing a variable or part of it -(any expression that can be assigned to). +Create a [`VarName`](@ref) given an expression `expr` representing a variable or part of it. # Basic examples @@ -174,14 +173,14 @@ julia> vn = @varname(x[1, end-1]) x[1, DynamicIndex(end - 1)] ``` -You can detect whether a `VarName` contains any dynamic indices using `is_dynamic(vn)`: +You can detect whether a `VarName` contains any dynamic indices using [`is_dynamic`](@ref): ```jldoctest julia> vn = @varname(x[1, end-1]); AbstractPPL.is_dynamic(vn) true ``` -To concretize such expressions, you can call `concretize(vn, val)` on the resulting +To concretize such expressions, you can call [`concretize`](@ref) on the resulting `VarName`. After concretization, the resulting `VarName` will no longer be dynamic. ```jldoctest @@ -197,9 +196,9 @@ julia> AbstractPPL.is_dynamic(vn2) false ``` -Alternatively, you can pass `true` as the second positional argument the `@varname` macro -(note that it is not a keyword argument!). This will call `concretize` for you, using the -top-level symbol to look up the value used for concretization. +Alternatively, you can pass `true` as the second positional argument to the `@varname` macro +(note that it is not a keyword argument!). This will automatically call [`concretize`](@ref) +for you, using the top-level symbol to look up the value used for concretization. ```jldoctest julia> x = randn(2, 3); @@ -224,7 +223,7 @@ julia> @varname(\$name.a.\$name[1]) hello.a.hello[1] ``` -For indices, you don't need to use `\$` to interpolate, just use the variable directly: +For indices, you do nott need to use `\$` to interpolate, just use the variable directly: ```jldoctest julia> ix = 2; @varname(x[ix]) @@ -345,3 +344,21 @@ macro opticof(expr, concretize::Bool=false) :(getoptic($unconcretized_vn)) end end + +""" + varname_to_optic(vn::VarName) + +Convert a `VarName` to an optic, by converting the top-level symbol to a `Property` optic. +""" +varname_to_optic(vn::VarName{sym}) where {sym} = Property{sym}(getoptic(vn)) + +""" + optic_to_varname(optic::Property{sym}) where {sym} + +Convert a `Property` optic to a `VarName`, by converting the top-level property to a symbol. +This fails for all other optics. +""" +optic_to_varname(optic::Property{sym}) where {sym} = VarName{sym}(otail(optic)) +function optic_to_varname(::AbstractOptic) + throw(ArgumentError("to_varname: can only convert Property optics to VarName")) +end From ed3ba78259fdc7fa6d740099c2024618d177bc14 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 23 Dec 2025 18:40:33 +0000 Subject: [PATCH 11/60] Add a bunch of tests --- src/varname/optic.jl | 49 ++++++++++++++++-------- test/varname/varname.jl | 83 +++++++++++++++++++++++++++++++++++++++-- 2 files changed, 113 insertions(+), 19 deletions(-) diff --git a/src/varname/optic.jl b/src/varname/optic.jl index df591de8..83622b7a 100644 --- a/src/varname/optic.jl +++ b/src/varname/optic.jl @@ -82,28 +82,45 @@ function _make_dynamicindex_expr(symbol::Symbol, dim::Union{Nothing,Int}) end function _make_dynamicindex_expr(expr::Expr, dim::Union{Nothing,Int}) @gensym val - replaced_expr = MacroTools.postwalk(x -> replace_begin_and_end(x, val, dim), expr) - return if replaced_expr == expr - # Nothing to replace, just use the original expr. - expr + if has_begin_or_end(expr) + replaced_expr = MacroTools.postwalk(x -> _make_dynamicindex_expr(x, val, dim), expr) + return :(DynamicIndex($(QuoteNode(expr)), $val -> $replaced_expr)) else - :(DynamicIndex($(QuoteNode(expr)), $val -> $replaced_expr)) + return esc(expr) end end - -# Replace all instances of `begin` in `expr` with `_firstindex_dim(val, dim)` and -# all instances of `end` with `_lastindex_dim(val, dim)`. -replace_begin_and_end(x, ::Any, ::Any) = x -function replace_begin_and_end(x::Symbol, val_sym, dim) - return if (x === :begin) - dim === nothing ? :(Base.firstindex($val_sym)) : :(Base.firstindex($val_sym, $dim)) - elseif (x === :end) - dim === nothing ? :(Base.lastindex($val_sym)) : :(Base.lastindex($val_sym, $dim)) +function _make_dynamicindex_expr(symbol::Symbol, val_sym::Symbol, dim::Union{Nothing,Int}) + # NOTE(penelopeysm): We could just use `:end` instead of Symbol(:end), but the former + # messes up syntax highlighting with Treesitter + # https://github.com/tree-sitter/tree-sitter-julia/issues/104 + if symbol === Symbol(:begin) + return if dim === nothing + :(Base.firstindex($val_sym)) + else + :(Base.Fix2(firstindex, $dim)($val_sym)) + end + elseif symbol === Symbol(:end) + return if dim === nothing + :(Base.lastindex($val_sym)) + else + :(Base.Fix2(lastindex, $dim)($val_sym)) + end else - # It's some other symbol; we need to escape it to allow interpolation. - esc(x) + # Just a variable; but we need to escape it to allow interpolation. + return esc(symbol) end end +function _make_dynamicindex_expr(i::Any, ::Symbol, ::Union{Nothing,Int}) + return i +end + +has_begin_or_end(expr::Expr) = has_begin_or_end_inner(expr, false) +function has_begin_or_end_inner(x, found::Bool) + return found || + x ∈ (:end, :begin, Expr(:end), Expr(:begin)) || + (x isa Expr && any(arg -> has_begin_or_end_inner(arg, found), x.args)) +end + _pretty_string_index(ix) = string(ix) _pretty_string_index(::Colon) = ":" _pretty_string_index(di::DynamicIndex) = "DynamicIndex($(di.expr))" diff --git a/test/varname/varname.jl b/test/varname/varname.jl index e7e80457..49ec02e3 100644 --- a/test/varname/varname.jl +++ b/test/varname/varname.jl @@ -4,19 +4,96 @@ using AbstractPPL using Test @testset "varname/varname.jl" verbose = true begin - # TODO + @testset "basic construction" begin + @test @varname(x) == VarName{:x}(Iden()) + @test @varname(x[1]) == VarName{:x}(Index((1,), Iden())) + @test @varname(x.a) == VarName{:x}(Property{:a}(Iden())) + @test @varname(x.a[1]) == VarName{:x}(Property{:a}(Index((1,), Iden()))) + end @testset "errors on nonsensical inputs" begin # Note: have to wrap in eval to avoid throwing an error before the actual test errmsg = "malformed variable name" @test_throws errmsg eval(:(@varname(1))) @test_throws errmsg eval(:(@varname(x + y))) - # This doesn't fail to parse, but it will throw a MethodError because you can't - # construct a DynamicColon with a DynamicColon as an argument # TODO(penelopeysm): I would like to test this, but JuliaFormatter reformats # this into x[1::] which then fails to parse. Grr. # @test_throws MethodError eval(:(@varname(x[1: :]))) end + + @testset "dynamic indices and manual concretization" begin + @testset "begin" begin + vn = @varname(x[begin]) + @test vn isa VarName + @test is_dynamic(vn) + @test concretize(vn, [1.0]) == @varname(x[1]) + end + + @testset "end" begin + vn = @varname(x[end]) + @test vn isa VarName + @test is_dynamic(vn) + @test concretize(vn, randn(5)) == @varname(x[5]) + end + + @testset "expressions thereof" begin + vn = @varname(x[(begin + 2):(end - 1)]) + @test vn isa VarName + @test is_dynamic(vn) + arr = randn(6) + @test concretize(vn, arr) == @varname(x[3:5]) + end + + @testset "different dimensions" begin + vn = @varname(x[end, begin:(end - 1)]) + @test vn isa VarName + @test is_dynamic(vn) + arr = randn(4, 4) + @test concretize(vn, arr) == @varname(x[4, 1:3]) + end + + @testset "linear indexing for matrices" begin + vn = @varname(x[begin:end]) + @test vn isa VarName + @test is_dynamic(vn) + arr = randn(4, 4) + @test concretize(vn, arr) == @varname(x[:]) + end + end + + @testset "things that shouldn't be dynamic aren't dynamic" begin + @test !is_dynamic(@varname(x)) + @test !is_dynamic(@varname(x[3])) + @test !is_dynamic(@varname(x[:])) + @test !is_dynamic(@varname(x[1:3])) + @test !is_dynamic(@varname(x[1:3, 3, 2 + 9])) + i = 10 + @test !is_dynamic(@varname(x[1:3, 3, 2 + 9, 1:3:i])) + end + + @testset "automatic concretization" begin + test_array = randn(5, 5) + @testset "begin" begin + vn = @varname(test_array[begin], true) + @test vn == concretize(@varname(test_array[begin]), test_array) + end + @testset "end" begin + vn = @varname(test_array[end], true) + @test vn == concretize(@varname(test_array[end]), test_array) + end + @testset "expressions thereof" begin + vn = @varname(test_array[(begin + 1):(end - 2)], true) + @test vn == concretize(@varname(test_array[(begin + 1):(end - 2)]), test_array) + end + @testset "different dimensions" begin + vn = @varname(test_array[end, begin:(end - 1)], true) + @test vn == concretize(@varname(test_array[end, begin:(end - 1)]), test_array) + end + @testset "linear indexing for matrices" begin + vn = @varname(test_array[begin:end], true) + @test vn == concretize(@varname(test_array[begin:end]), test_array) + end + end end end # module VarNameTests From 07a7e533a00c1e0f38154170b5ecb345d0e46a24 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 23 Dec 2025 18:52:58 +0000 Subject: [PATCH 12/60] Fix more bugs, add more tests --- src/varname/optic.jl | 1 + src/varname/varname.jl | 2 +- test/varname/varname.jl | 49 +++++++++++++++++++++++++++++++++++++---- 3 files changed, 47 insertions(+), 5 deletions(-) diff --git a/src/varname/optic.jl b/src/varname/optic.jl index 83622b7a..51ecc37c 100644 --- a/src/varname/optic.jl +++ b/src/varname/optic.jl @@ -111,6 +111,7 @@ function _make_dynamicindex_expr(symbol::Symbol, val_sym::Symbol, dim::Union{Not end end function _make_dynamicindex_expr(i::Any, ::Symbol, ::Union{Nothing,Int}) + # this handles things like integers, colons, etc. return i end diff --git a/src/varname/varname.jl b/src/varname/varname.jl index 09b8203a..5ed67b91 100644 --- a/src/varname/varname.jl +++ b/src/varname/varname.jl @@ -262,7 +262,7 @@ function _varname(expr::Expr, inner_expr) # expr.head would be :ref or :.) Thus we don't need to recurse further, and we can # just return `inner_expr` as-is. sym_expr = expr.args[1] - return :($VarName{$(sym_expr)}($inner_expr)), nothing + return :(VarName{$(esc(sym_expr))}($inner_expr)), nothing else next_inner = if expr.head == :(.) sym = _handle_property(expr.args[2], expr) diff --git a/test/varname/varname.jl b/test/varname/varname.jl index 49ec02e3..64afdcd0 100644 --- a/test/varname/varname.jl +++ b/test/varname/varname.jl @@ -11,14 +11,12 @@ using Test @test @varname(x.a[1]) == VarName{:x}(Property{:a}(Index((1,), Iden()))) end - @testset "errors on nonsensical inputs" begin + @testset "errors on invalid inputs" begin # Note: have to wrap in eval to avoid throwing an error before the actual test errmsg = "malformed variable name" @test_throws errmsg eval(:(@varname(1))) @test_throws errmsg eval(:(@varname(x + y))) - # TODO(penelopeysm): I would like to test this, but JuliaFormatter reformats - # this into x[1::] which then fails to parse. Grr. - # @test_throws MethodError eval(:(@varname(x[1: :]))) + @test_throws MethodError eval(:(@varname(x[1:Colon()]))) end @testset "dynamic indices and manual concretization" begin @@ -94,6 +92,49 @@ using Test @test vn == concretize(@varname(test_array[begin:end]), test_array) end end + + @testset "interpolation" begin + @testset "of property names" begin + prop = :myprop + vn = @varname(x.$prop) + @test vn == @varname(x.myprop) + end + + @testset "of indices" begin + idx = 3 + vn = @varname(x[idx]) + @test vn == @varname(x[3]) + end + + @testset "with dynamic indices" begin + idx = 3 + vn = @varname(x[end - idx]) + @test vn isa VarName + @test is_dynamic(vn) + arr = randn(6) + @test concretize(vn, arr) == @varname(x[3]) + # Note that `idx` is only resolved at concretization time (because it's stored + # in a function that looks like (val -> lastindex(val) - idx) -- the VALUE of + # `idx` is not interpolated at macro time because we have no way of obtaining + # values inside the macro). So we could change it and re-concretize... + idx = 4 + @test concretize(vn, arr) == @varname(x[2]) + end + + @testset "of top-level name" begin + name = :x + @test @varname($name) == @varname(x) + @test @varname($name[1]) == @varname(x[1]) + @test @varname($name.a) == @varname(x.a) + end + + @testset "mashup of everything" begin + name = :x + index = 2 + prop = :b + @test @varname($name.$prop[3 * index]) == @varname(x.b[6]) + end + end end end # module VarNameTests From 8464ea23b6aa5e3c942914ec343aa748b357d69a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 23 Dec 2025 19:01:30 +0000 Subject: [PATCH 13/60] more tests --- HISTORY.md | 2 ++ test/varname.jl | 79 ----------------------------------------- test/varname/varname.jl | 32 ++++++++++++++--- 3 files changed, 29 insertions(+), 84 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 66f95cce..13c6bd4f 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -32,6 +32,8 @@ The `vsym` function (and `@vsym`) has been removed; you should use `getsym(vn)` The `Base.get` and `Base.set!` methods for VarNames have been removed (these were responsible for method ambiguities). +VarNames cannot be composed with optics now (you need to compose the optics yourself). + The `inspace` function has been removed (it used to be relevant for Turing's old Gibbs sampler; but now it no longer serves any use). ## 0.13.6 diff --git a/test/varname.jl b/test/varname.jl index 8de88d8c..047a9dcb 100644 --- a/test/varname.jl +++ b/test/varname.jl @@ -4,60 +4,6 @@ using OffsetArrays using LinearAlgebra: LowerTriangular, UpperTriangular, cholesky @testset "varnames" begin - @testset "string and symbol conversion" begin - vn1 = @varname x[1][2] - @test string(vn1) == "x[1][2]" - @test Symbol(vn1) == Symbol("x[1][2]") - end - - @testset "equality and hashing" begin - vn1 = @varname x[1][2] - vn2 = @varname x[1][2] - @test vn2 == vn1 - @test hash(vn2) == hash(vn1) - end - - @testset "construction & concretization" begin - i = 1:10 - j = 2:2:5 - @test @varname(A[1].b[i]) == @varname(A[1].b[1:10]) - @test @varname(A[j]) == @varname(A[2:2:5]) - - @test @varname(A[:, 1][1 + 1]) == @varname(A[:, 1][2]) - @test(@varname(A[:, 1][2]) == VarName{:A}(@o(_[:, 1]) ⨟ @o(_[2]))) - - # concretization - y = zeros(10, 10) - x = (a=[1.0 2.0; 3.0 4.0; 5.0 6.0],) - - @test @varname(y[begin, i], true) == @varname(y[1, 1:10]) - @test test_equal(@varname(y[:], true), @varname(y[1:100])) - @test test_equal(@varname(y[:, begin], true), @varname(y[1:10, 1])) - @test getoptic(AbstractPPL.concretize(@varname(y[:]), y)).indices[1] === - AbstractPPL.ConcretizedSlice(to_indices(y, (:,))[1]) - @test test_equal(@varname(x.a[1:end, end][:], true), @varname(x.a[1:3, 2][1:3])) - end - - @testset "compose and opcompose" begin - @test IndexLens(1) ∘ @varname(x.a) == @varname(x.a[1]) - @test @varname(x.a) ⨟ IndexLens(1) == @varname(x.a[1]) - - @test @varname(x) ⨟ identity == @varname(x) - @test identity ∘ @varname(x) == @varname(x) - @test @varname(x.a) ⨟ identity == @varname(x.a) - @test identity ∘ @varname(x.a) == @varname(x.a) - @test @varname(x[1].b) ⨟ identity == @varname(x[1].b) - @test identity ∘ @varname(x[1].b) == @varname(x[1].b) - end - - @testset "get & set" begin - x = (a=[1.0 2.0; 3.0 4.0; 5.0 6.0], b=1.0) - @test get(x, @varname(a[1, 2])) == 2.0 - @test get(x, @varname(b)) == 1.0 - @test set(x, @varname(a[1, 2]), 10) == (a=[1.0 10.0; 3.0 4.0; 5.0 6.0], b=1.0) - @test set(x, @varname(b), 10) == (a=[1.0 2.0; 3.0 4.0; 5.0 6.0], b=10.0) - end - @testset "subsumption with standard indexing" begin # x ⊑ x @test @varname(x) ⊑ @varname(x) @@ -104,31 +50,6 @@ using LinearAlgebra: LowerTriangular, UpperTriangular, cholesky @test_throws ErrorException (@varname(x[:]) ⊑ @varname(x[:])) end - @testset "non-standard indexing" begin - A = rand(10, 10) - @test test_equal( - @varname(A[1, Not(3)], true), @varname(A[1, [1, 2, 4, 5, 6, 7, 8, 9, 10]]) - ) - - B = OffsetArray(A, -5, -5) # indices -4:5×-4:5 - @test test_equal(@varname(B[1, :], true), @varname(B[1, -4:5])) - end - @testset "type stability" begin - @inferred VarName{:a}() - @inferred VarName{:a}(IndexLens(1)) - @inferred VarName{:a}(IndexLens(1, 2)) - @inferred VarName{:a}(PropertyLens(:b)) - @inferred VarName{:a}(Accessors.opcompose(IndexLens(1), PropertyLens(:b))) - - b = (a=[1, 2, 3],) - @inferred get(b, @varname(a[1])) - @inferred Accessors.set(b, @varname(a[1]), 10) - - c = (b=(a=[1, 2, 3],),) - @inferred get(c, @varname(b.a[1])) - @inferred Accessors.set(c, @varname(b.a[1]), 10) - end - @testset "de/serialisation of VarNames" begin y = ones(10) z = ones(5, 2) diff --git a/test/varname/varname.jl b/test/varname/varname.jl index 64afdcd0..0f2975bd 100644 --- a/test/varname/varname.jl +++ b/test/varname/varname.jl @@ -4,11 +4,11 @@ using AbstractPPL using Test @testset "varname/varname.jl" verbose = true begin - @testset "basic construction" begin - @test @varname(x) == VarName{:x}(Iden()) - @test @varname(x[1]) == VarName{:x}(Index((1,), Iden())) - @test @varname(x.a) == VarName{:x}(Property{:a}(Iden())) - @test @varname(x.a[1]) == VarName{:x}(Property{:a}(Index((1,), Iden()))) + @testset "basic construction (and type stability)" begin + @test @varname(x) == (@inferred VarName{:x}(Iden())) + @test @varname(x[1]) == (@inferred VarName{:x}(Index((1,), Iden()))) + @test @varname(x.a) == (@inferred VarName{:x}(Property{:a}(Iden()))) + @test @varname(x.a[1]) == (@inferred VarName{:x}(Property{:a}(Index((1,), Iden())))) end @testset "errors on invalid inputs" begin @@ -19,6 +19,28 @@ using Test @test_throws MethodError eval(:(@varname(x[1:Colon()]))) end + @testset "equality" begin + @test @varname(x) == @varname(x) + @test @varname(x) != @varname(y) + @test @varname(x[1]) == @varname(x[1]) + @test @varname(x[1]) != @varname(x[2]) + @test @varname(x.a) == @varname(x.a) + @test @varname(x.a) != @varname(x.b) + @test @varname(x.a[1]) == @varname(x.a[1]) + @test @varname(x.a[1]) != @varname(x.a[2]) + @test @varname(x.a[1]) != @varname(x.b[1]) + end + + @testset "pretty-printing" begin + @test string(@varname(x)) == "x" + @test string(@varname(x[1])) == "x[1]" + @test string(@varname(x.a)) == "x.a" + @test string(@varname(x.a[1])) == "x.a[1]" + @test string(@varname(x[begin])) == "x[DynamicIndex(begin)]" + @test string(@varname(x[end])) == "x[DynamicIndex(end)]" + @test string(@varname(x[:])) == "x[:]" + end + @testset "dynamic indices and manual concretization" begin @testset "begin" begin vn = @varname(x[begin]) From 5aa399712518ba38afcc19e4859e1141adf8be4b Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 23 Dec 2025 19:02:57 +0000 Subject: [PATCH 14/60] tests --- test/runtests.jl | 2 +- test/varname/varname.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 06e91615..dd958eb1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,7 +6,7 @@ const GROUP = get(ENV, "GROUP", "All") @testset "AbstractPPL.jl" begin if GROUP == "All" || GROUP == "Tests" - include("Aqua.jl") + # include("Aqua.jl") include("abstractprobprog.jl") include("varname/optic.jl") include("varname/varname.jl") diff --git a/test/varname/varname.jl b/test/varname/varname.jl index 0f2975bd..87b63b2f 100644 --- a/test/varname/varname.jl +++ b/test/varname/varname.jl @@ -77,7 +77,7 @@ using Test @test vn isa VarName @test is_dynamic(vn) arr = randn(4, 4) - @test concretize(vn, arr) == @varname(x[:]) + @test concretize(vn, arr) == @varname(x[1:16]) end end From 8826e6030850a8a79bad69e35c49b881687a6970 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 23 Dec 2025 19:38:39 +0000 Subject: [PATCH 15/60] fix DynamicIndex comparisons --- src/varname/optic.jl | 28 ++++++++++++++++++++++++- test/varname/varname.jl | 45 ++++++++++++++++++++++++++++++++--------- 2 files changed, 62 insertions(+), 11 deletions(-) diff --git a/src/varname/optic.jl b/src/varname/optic.jl index 51ecc37c..9ebedb0b 100644 --- a/src/varname/optic.jl +++ b/src/varname/optic.jl @@ -59,12 +59,37 @@ For example: - the index `begin` is turned into `DynamicIndex(:begin, (val) -> Base.firstindex(val))`. - the index `1:end` is turned into `DynamicIndex(:(1:end), (val) -> 1:Base.lastindex(val))`. -The `expr` field stores the original expression solely for pretty-printing purposes. +# Stored `Expr` + +The `expr` field stores the original expression and is used both for pretty-printing as well +as comparisons. + +Note that because the stored function `f` is an anonymous function that is generated +dynamically, we should not include it in equality comparisons, as two functions that are +actually equivalent will not compare equal: + +```julia +julia> (x -> x + 1) == (x -> x + 1) +false +``` + +But, thankfully, we can just compare the `expr` field to determine whether the +DynamicIndices were constructed from the same expression (which implies that their +functions are equivalent). + +Note that these definitions also allow us some degree of resilience towards whitespace +changes, or parenthesisation, in the original expression. For example, `begin+1` and `(begin ++ 1)` will be treated as the same expression. However, it does not handle commutative +expressions; e.g., `begin + 1` and `1 + begin` will be treated as different expressions. """ struct DynamicIndex{E<:Union{Expr,Symbol},F} expr::E f::F end +Base.:(==)(a::DynamicIndex, b::DynamicIndex) = a.expr == b.expr +Base.isequal(a::DynamicIndex, b::DynamicIndex) = isequal(a.expr, b.expr) +Base.hash(di::DynamicIndex, h::UInt) = hash(di.expr, h) + function _make_dynamicindex_expr(symbol::Symbol, dim::Union{Nothing,Int}) # NOTE(penelopeysm): We could just use `:end` instead of Symbol(:end), but the former # messes up syntax highlighting with Treesitter @@ -146,6 +171,7 @@ end Base.:(==)(a::Index, b::Index) = a.ix == b.ix && a.child == b.child Base.isequal(a::Index, b::Index) = a == b +Base.hash(a::Index, h::UInt) = hash((a.ix, a.child), h) function _pretty_print_optic(io::IO, idx::Index) ixs = join(map(_pretty_string_index, idx.ix), ", ") print(io, "[$(ixs)]") diff --git a/test/varname/varname.jl b/test/varname/varname.jl index 87b63b2f..7deed66d 100644 --- a/test/varname/varname.jl +++ b/test/varname/varname.jl @@ -19,16 +19,41 @@ using Test @test_throws MethodError eval(:(@varname(x[1:Colon()]))) end - @testset "equality" begin - @test @varname(x) == @varname(x) - @test @varname(x) != @varname(y) - @test @varname(x[1]) == @varname(x[1]) - @test @varname(x[1]) != @varname(x[2]) - @test @varname(x.a) == @varname(x.a) - @test @varname(x.a) != @varname(x.b) - @test @varname(x.a[1]) == @varname(x.a[1]) - @test @varname(x.a[1]) != @varname(x.a[2]) - @test @varname(x.a[1]) != @varname(x.b[1]) + @testset "equality and hash" begin + function check_doubleeq_and_hash(vn1, vn2, is_equal) + if is_equal + @test vn1 == vn2 + @test hash(vn1) == hash(vn2) + else + @test vn1 != vn2 + @test hash(vn1) != hash(vn2) + end + end + check_doubleeq_and_hash(@varname(x), @varname(x), true) + check_doubleeq_and_hash(@varname(x), @varname(y), false) + check_doubleeq_and_hash(@varname(x[1]), @varname(x[1]), true) + check_doubleeq_and_hash(@varname(x[1]), @varname(x[2]), false) + check_doubleeq_and_hash(@varname(x.a), @varname(x.a), true) + check_doubleeq_and_hash(@varname(x.a), @varname(x.b), false) + check_doubleeq_and_hash(@varname(x.a[1]), @varname(x.a[1]), true) + check_doubleeq_and_hash(@varname(x.a[1]), @varname(x.a[2]), false) + check_doubleeq_and_hash(@varname(x.a[1]), @varname(x.b[1]), false) + + @testset "dynamic indices" begin + check_doubleeq_and_hash(@varname(x[begin]), @varname(x[begin]), true) + check_doubleeq_and_hash(@varname(x[end]), @varname(x[end]), true) + check_doubleeq_and_hash(@varname(x[begin]), @varname(x[end]), false) + check_doubleeq_and_hash(@varname(x[begin + 1]), @varname(x[begin + 1]), true) + check_doubleeq_and_hash(@varname(x[begin + 1]), @varname(x[(begin + 1)]), true) + check_doubleeq_and_hash(@varname(x[begin + 1]), @varname(x[begin + 2]), false) + check_doubleeq_and_hash(@varname(x[end - 1]), @varname(x[end - 1]), true) + check_doubleeq_and_hash(@varname(x[end - 1]), @varname(x[end - 2]), false) + check_doubleeq_and_hash( + @varname(x[(begin * end - begin):end]), + @varname(x[((begin * end) - begin):end]), + true, + ) + end end @testset "pretty-printing" begin From 0fe06bb09011ec85f4f15fe1bb6e7b1c0ca19c7c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 23 Dec 2025 19:41:00 +0000 Subject: [PATCH 16/60] typo --- src/varname/varname.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/varname/varname.jl b/src/varname/varname.jl index 5ed67b91..623ca8ca 100644 --- a/src/varname/varname.jl +++ b/src/varname/varname.jl @@ -360,5 +360,5 @@ This fails for all other optics. """ optic_to_varname(optic::Property{sym}) where {sym} = VarName{sym}(otail(optic)) function optic_to_varname(::AbstractOptic) - throw(ArgumentError("to_varname: can only convert Property optics to VarName")) + throw(ArgumentError("optic_to_varname: can only convert Property optics to VarName")) end From e9b1f9883a2bdd926c9b5377edce3ead89c914ba Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 23 Dec 2025 21:15:49 +0000 Subject: [PATCH 17/60] implement get and set --- HISTORY.md | 2 +- docs/src/varname.md | 21 +++++++++++++ src/AbstractPPL.jl | 43 +++++++++++++------------ src/varname/optic.jl | 64 ++++++++++++++++--------------------- test/Project.toml | 1 + test/varname/optic.jl | 73 ++++++++++++++++++++++++++++++++++++++++--- 6 files changed, 141 insertions(+), 63 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 13c6bd4f..24821dd7 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -8,7 +8,7 @@ Much of the external API for traversing and manipulating `VarName`s has been pre The `optic` field of VarName now uses our hand-rolled optic types, which are subtypes of `AbstractPPL.AbstractOptic`. Previously these were optics from Accessors.jl. -This change was made for two reasons: firstly, it is easier to provide custom behaviour for VarNames as we avoid running into possible type piracy issues, and secondly, the linked-list data structure used in `AbstractOptic` is easier to work with than Accessors.jl, which used `Base.ComposedFunction` to represent optic compositions and required a lot of care to avoid issues with associativity and identity optics. +This change was made for two reasons: firstly, it is easier to provide custom behaviour for VarNames as we avoid running into possible type piracy issues, and secondly, the linked-list data structure used in `AbstractOptic` is easier to work with than Accessors.jl, which used `Base.ComposedFunction` to represent optic compositions and required a lot of care to avoid a litany of issues with associativity and identity optics (see e.g. https://github.com/JuliaLang/julia/pull/54877). To construct an optic, the easiest way is to use the `@opticof` macro, which superficially behaves similarly to `Accessors.@optic` (for example, you can write `@opticof _[1].y.z`), but also supports automatic concretization by passing a second parameter (just like `@varname`). diff --git a/docs/src/varname.md b/docs/src/varname.md index 66a8b72f..20df1737 100644 --- a/docs/src/varname.md +++ b/docs/src/varname.md @@ -101,6 +101,27 @@ optic = @opticof(_.a[1]) @opticof ``` +## Getting and setting + +Optics are callable structs, and when passed a value will extract the relevant part of that +value. + +```@example vn +data = (a=[10, 20, 30], b="hello") +optic = @opticof(_.a[2]) +optic(data) +``` + +You can set values using `Accessors.set` (which AbstractPPL re-exports). +Note, though, that this will not mutate the original value. +Furthermore, you cannot use the handy macros like `Accessors.@set`, since those will use the +optics from Accessors.jl. + +```@example vn +new_data = set(data, optic, 99) +new_data, data +``` + ## Composing and decomposing optics If you have two optics, you can compose them using the `∘` operator: diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index f657131d..302d2439 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -1,25 +1,5 @@ module AbstractPPL -# Optics -export AbstractOptic, - Iden, - Index, - Property, - ohead, - otail, - olast, - oinit, - # VarName - VarName, - getsym, - getoptic, - concretize, - is_dynamic, - @varname, - @opticof, - varname_to_optic, - optic_to_varname - # subsumes, # subsumedby, # index_to_dict, @@ -51,4 +31,27 @@ include("varname/varname.jl") # include("varname/prefix.jl") # include("varname/serialize.jl") +# Optics +export AbstractOptic, + Iden, + Index, + Property, + ohead, + otail, + olast, + oinit, + # VarName + VarName, + getsym, + getoptic, + concretize, + is_dynamic, + @varname, + @opticof, + varname_to_optic, + optic_to_varname + +using Accessors: set +export set + end # module diff --git a/src/varname/optic.jl b/src/varname/optic.jl index 9ebedb0b..1ea3a3b8 100644 --- a/src/varname/optic.jl +++ b/src/varname/optic.jl @@ -7,22 +7,6 @@ using MacroTools: MacroTools An abstract type that represents the non-symbol part of a VarName, i.e., the section of the variable that is of interest. For example, in `x.a[1][2]`, the `AbstractOptic` represents the `.a[1][2]` part. - -# Public interface - -TODO - -- Base.show -- Base.:(==), Base.isequal -- Base.:(∘) (composition) -- ohead, otail, olast, oinit (decomposition) - -- to_accessors(optic) -> Accessors.Lens (recovering the old representation) -- is_dynamic(optic) -> Bool (whether the optic contains any dynamic indices) -- concretize(optic, val) -> AbstractOptic (resolving any dynamic indices given the value) - -We probably want to introduce getters and setters. See e.g. -https://juliaobjects.github.io/Accessors.jl/stable/examples/custom_macros/ """ abstract type AbstractOptic end function Base.show(io::IO, optic::AbstractOptic) @@ -39,9 +23,10 @@ It is also the base case for composing optics. """ struct Iden <: AbstractOptic end _pretty_print_optic(::IO, ::Iden) = nothing -to_accessors(::Iden) = identity is_dynamic(::Iden) = false concretize(i::Iden, ::Any) = i +(::Iden)(obj) = obj +Accessors.set(obj::Any, ::Iden, val) = Accessors.set(obj, identity, val) """ DynamicIndex @@ -59,7 +44,7 @@ For example: - the index `begin` is turned into `DynamicIndex(:begin, (val) -> Base.firstindex(val))`. - the index `1:end` is turned into `DynamicIndex(:(1:end), (val) -> 1:Base.lastindex(val))`. -# Stored `Expr` +# The stored `Expr` The `expr` field stores the original expression and is used both for pretty-printing as well as comparisons. @@ -177,19 +162,23 @@ function _pretty_print_optic(io::IO, idx::Index) print(io, "[$(ixs)]") return _pretty_print_optic(io, idx.child) end -function to_accessors(idx::Index) - ilens = Accessors.IndexLens(idx.ix) - return if idx.child isa Iden - ilens - else - Base.ComposedFunction(to_accessors(idx.child), ilens) - end -end is_dynamic(idx::Index) = any(ix -> ix isa DynamicIndex, idx.ix) || is_dynamic(idx.child) function concretize(idx::Index, val) - concretized_indices = map(Base.Fix2(_concretize_index, val), idx.ix) - inner_concretized = concretize(idx.child, view(val, concretized_indices...)) - return Index((concretized_indices...,), inner_concretized) + concretized_indices = tuple(map(Base.Fix2(_concretize_index, val), idx.ix)...) + inner_concretized = concretize(idx.child, val[concretized_indices...]) + return Index(concretized_indices, inner_concretized) +end +function (idx::Index)(obj) + cidx = concretize(idx, obj) + return cidx.child(obj[cidx.ix...]) +end +function Accessors.set(obj, idx::Index, newval) + cidx = concretize(idx, obj) + inner_obj = obj[cidx.ix...] + inner_newval = Accessors.set(inner_obj, idx.child, newval) + # Defer to Accessors' implementation, so that we don't have to reinvent the wheel + # (well, not more than what we have already done...) + return Accessors.set(obj, Accessors.IndexLens(cidx.ix), inner_newval) end """ @@ -212,19 +201,20 @@ function _pretty_print_optic(io::IO, prop::Property{sym}) where {sym} print(io, ".$(sym)") return _pretty_print_optic(io, prop.child) end -function to_accessors(prop::Property{sym}) where {sym} - plens = Accessors.PropertyLens{sym}() - return if prop.child isa Iden - plens - else - Base.ComposedFunction(to_accessors(prop.child), plens) - end -end is_dynamic(prop::Property) = is_dynamic(prop.child) function concretize(prop::Property{sym}, val) where {sym} inner_concretized = concretize(prop.child, getproperty(val, sym)) return Property{sym}(inner_concretized) end +function (prop::Property{sym})(obj) where {sym} + return prop.child(getproperty(obj, sym)) +end +function Accessors.set(obj, prop::Property{sym}, newval) where {sym} + inner_obj = getproperty(obj, sym) + inner_newval = Accessors.set(inner_obj, prop.child, newval) + # Defer to Accessors' implementation again. + return Accessors.set(obj, Accessors.PropertyLens{sym}(), inner_newval) +end """ ∘(outer::AbstractOptic, inner::AbstractOptic) diff --git a/test/Project.toml b/test/Project.toml index 170c644f..d538d4db 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,7 @@ [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" diff --git a/test/varname/optic.jl b/test/varname/optic.jl index 933c4976..b165ea86 100644 --- a/test/varname/optic.jl +++ b/test/varname/optic.jl @@ -1,6 +1,7 @@ module OpticTests using Test +using DimensionalData: DimensionalData as DD using AbstractPPL @testset "varname/optic.jl" verbose = true begin @@ -10,19 +11,81 @@ using AbstractPPL @testset "composition" begin @testset "with identity" begin i = AbstractPPL.Iden() - o = getoptic(@varname(x.a.b)) + o = @opticof(_.a.b) @test i ∘ i == i @test i ∘ o == o @test o ∘ i == o end - o1 = getoptic(@varname(x.a.b)) - o2 = getoptic(@varname(x[1][2])) - @test o1 ∘ o2 == getoptic(@varname(x[1][2].a.b)) - @test o2 ∘ o1 == getoptic(@varname(x.a.b[1][2])) + o1 = @opticof(_.a.b) + o2 = @opticof(_[1][2]) + @test o1 ∘ o2 == @opticof(_[1][2].a.b) + @test o2 ∘ o1 == @opticof(_.a.b[1][2]) + @test cat(o1, o2) == @opticof(_.a.b[1][2]) + @test cat(o2, o1) == @opticof(_[1][2].a.b) + @test cat(o1, o2, o2, o1) == @opticof(_.a.b[1][2][1][2].a.b) end + # TODO @testset "decomposition" begin end + + @testset "getting and setting" begin + @testset "basic" begin + v = (a=(b=42, c=3.14), d=[0.0 1.0; 2.0 3.0]) + @test @opticof(_.a)(v) == v.a + @test set(v, @opticof(_.a), nothing) == (a=nothing, d=v.d) + @test @opticof(_.a.b)(v) == v.a.b + @test set(v, @opticof(_.a.b), 100) == (a=(b=100, c=v.a.c), d=v.d) + @test @opticof(_.a.c)(v) == v.a.c + @test set(v, @opticof(_.a.c), 2.71) == (a=(b=v.a.b, c=2.71), d=v.d) + @test @opticof(_.d)(v) == v.d + @test set(v, @opticof(_.d), zeros(2, 2)) == (a=v.a, d=zeros(2, 2)) + @test @opticof(_.d[1])(v) == v.d[1] + @test set(v, @opticof(_.d[1]), 9.0) == (a=v.a, d=[9.0 1.0; 2.0 3.0]) + @test @opticof(_.d[2])(v) == v.d[2] + @test set(v, @opticof(_.d[2]), 9.0) == (a=v.a, d=[0.0 1.0; 9.0 3.0]) + @test @opticof(_.d[3])(v) == v.d[3] + @test set(v, @opticof(_.d[3]), 9.0) == (a=v.a, d=[0.0 9.0; 2.0 3.0]) + @test @opticof(_.d[4])(v) == v.d[4] + @test set(v, @opticof(_.d[4]), 9.0) == (a=v.a, d=[0.0 1.0; 2.0 9.0]) + @test @opticof(_.d[:])(v) == v.d[:] + @test set(v, @opticof(_.d[:]), fill(9.9, 2, 2)) == (a=v.a, d=fill(9.9, 2, 2)) + end + + @testset "dynamic indices" begin + x = [0.0 1.0; 2.0 3.0] + @test @opticof(_[begin])(x) == x[begin] + @test set(x, @opticof(_[begin]), 9.0) == [9.0 1.0; 2.0 3.0] + @test @opticof(_[end])(x) == x[end] + @test set(x, @opticof(_[end]), 9.0) == [0.0 1.0; 2.0 9.0] + @test @opticof(_[1:end, 2])(x) == x[1:end, 2] + @test set(x, @opticof(_[1:end, 2]), [9.0; 8.0]) == [0.0 9.0; 2.0 8.0] + end + + @testset "unusual indices" begin + x = randn(3, 3) + @test @opticof(_[1:2:4])(x) == x[1:2:4] + @test @opticof(_[CartesianIndex(1, 1)])(x) == x[CartesianIndex(1, 1)] + # `Not` is actually from InvertedIndices.jl (but re-exported by DimensionalData) + @test @opticof(_[DD.Not(3)])(x) == x[DD.Not(3)] + dimarray = DD.DimArray(randn(2, 3), (DD.X, DD.Y)) + @test @opticof(_[DD.X(1)])(dimarray) == dimarray[DD.X(1)] + # TODO(penelopeysm): This doesn't support keyword arguments to getindex yet. + # For example: + # dimarray = DD.DimArray(randn(2, 3), (:x, :y)) + # @test @opticof(_[x=1])(dimarray) == dimarray[x=1] + end + + struct SampleStruct + a::Int + b::Float64 + end + s = SampleStruct(3, 1.5) + @test @opticof(_.a)(s) == 3 + @test @opticof(_.b)(s) == 1.5 + @test set(s, @opticof(_.a), 10) == SampleStruct(10, s.b) + @test set(s, @opticof(_.b), 2.5) == SampleStruct(s.a, 2.5) + end end end # module From ea892724151d89cb5a6c0434df08547d8f25a079 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 23 Dec 2025 21:44:11 +0000 Subject: [PATCH 18/60] Handle keyword arguments to getindex --- src/varname/optic.jl | 62 ++++++++++++++++++++++++++--------------- src/varname/varname.jl | 24 +++++++++++++--- test/varname/optic.jl | 11 +++++--- test/varname/varname.jl | 16 ++++++++--- 4 files changed, 79 insertions(+), 34 deletions(-) diff --git a/src/varname/optic.jl b/src/varname/optic.jl index 1ea3a3b8..1d1daac4 100644 --- a/src/varname/optic.jl +++ b/src/varname/optic.jl @@ -140,45 +140,63 @@ _concretize_index(idx::Any, ::Any) = idx _concretize_index(idx::DynamicIndex, val) = idx.f(val) """ - Index(ix, child=Iden()) + Index(ix, kw, child=Iden()) -An indexing optic representing access to indices `ix`. A VarName{:x} with this optic -represents access to `x[ix...]`. The child optic represents any further indexing or -property access after this indexing operation. +An indexing optic representing access to indices `ix`, which may also take the form of +keyword arguments `kw`. A VarName{:x} with this optic represents access to `x[ix..., +kw...]`. The child optic represents any further indexing or property access after this +indexing operation. """ -struct Index{I<:Tuple,C<:AbstractOptic} <: AbstractOptic +struct Index{I<:Tuple,N<:NamedTuple,C<:AbstractOptic} <: AbstractOptic ix::I + kw::N child::C - function Index(ix::Tuple, child::C=Iden()) where {C<:AbstractOptic} - return new{typeof(ix),C}(ix, child) + function Index(ix::Tuple, kw::NamedTuple, child::C=Iden()) where {C<:AbstractOptic} + return new{typeof(ix),typeof(kw),C}(ix, kw, child) end end -Base.:(==)(a::Index, b::Index) = a.ix == b.ix && a.child == b.child -Base.isequal(a::Index, b::Index) = a == b -Base.hash(a::Index, h::UInt) = hash((a.ix, a.child), h) +Base.:(==)(a::Index, b::Index) = a.ix == b.ix && a.kw == b.kw && a.child == b.child +function Base.isequal(a::Index, b::Index) + return isequal(a.ix, b.ix) && isequal(a.kw, b.kw) && isequal(a.child, b.child) +end +Base.hash(a::Index, h::UInt) = hash((a.ix, a.kw, a.child), h) function _pretty_print_optic(io::IO, idx::Index) - ixs = join(map(_pretty_string_index, idx.ix), ", ") - print(io, "[$(ixs)]") + ixs = collect(map(_pretty_string_index, idx.ix)) + kws = map( + kv -> "$(kv.first)=$(_pretty_string_index(kv.second))", collect(pairs(idx.kw)) + ) + print(io, "[$(join(vcat(ixs, kws), ", "))]") return _pretty_print_optic(io, idx.child) end is_dynamic(idx::Index) = any(ix -> ix isa DynamicIndex, idx.ix) || is_dynamic(idx.child) function concretize(idx::Index, val) concretized_indices = tuple(map(Base.Fix2(_concretize_index, val), idx.ix)...) - inner_concretized = concretize(idx.child, val[concretized_indices...]) - return Index(concretized_indices, inner_concretized) + inner_concretized = concretize(idx.child, val[concretized_indices..., idx.kw...]) + return Index(concretized_indices, idx.kw, inner_concretized) end function (idx::Index)(obj) cidx = concretize(idx, obj) - return cidx.child(obj[cidx.ix...]) + return cidx.child(Base.getindex(obj, cidx.ix...; cidx.kw...)) end function Accessors.set(obj, idx::Index, newval) cidx = concretize(idx, obj) - inner_obj = obj[cidx.ix...] + inner_obj = Base.getindex(obj, cidx.ix...; cidx.kw...) inner_newval = Accessors.set(inner_obj, idx.child, newval) - # Defer to Accessors' implementation, so that we don't have to reinvent the wheel - # (well, not more than what we have already done...) - return Accessors.set(obj, Accessors.IndexLens(cidx.ix), inner_newval) + return if !isempty(cidx.kw) + # `Accessors.IndexLens` does not handle keyword arguments so we need to do this + # ourselves. Note that the following code essentially assumes that `obj` is an + # AbstractArray or similar type that directly implements `setindex!`. + newobj = similar(obj) + copy!(newobj, obj) + Base.setindex!(newobj, inner_newval, cidx.ix...; cidx.kw...) + newobj + else + # Defer to Accessors' implementation, so that we don't have to reinvent the wheel + # (well, not more than what we have already done...). This is helpful because + # Accessors implements a lot of methods for different types of `obj`. + Accessors.set(obj, Accessors.IndexLens(cidx.ix), inner_newval) + end end """ @@ -241,7 +259,7 @@ function Base.:(∘)(outer::AbstractOptic, inner::AbstractOptic) if inner isa Property return Property{getsym(inner)}(outer ∘ inner.child) elseif inner isa Index - return Index(inner.ix, outer ∘ inner.child) + return Index(inner.ix, inner.kw, outer ∘ inner.child) else error("unreachable; unknown AbstractOptic subtype $(typeof(inner))") end @@ -274,7 +292,7 @@ Optic() ``` """ ohead(::Property{s}) where {s} = Property{s}(Iden()) -ohead(idx::Index) = Index((idx.ix...,), Iden()) +ohead(idx::Index) = Index(idx.ix, idx.kw, Iden()) ohead(i::Iden) = i """ @@ -350,7 +368,7 @@ function oinit(idx::Index) return if idx.child isa Iden Iden() else - Index(idx.ix, oinit(idx.child)) + Index(idx.ix, idx.kw, oinit(idx.child)) end end oinit(i::Iden) = i diff --git a/src/varname/varname.jl b/src/varname/varname.jl index 623ca8ca..abecf248 100644 --- a/src/varname/varname.jl +++ b/src/varname/varname.jl @@ -269,11 +269,25 @@ function _varname(expr::Expr, inner_expr) :(Property{$(sym)}($inner_expr)) elseif expr.head == :ref original_ixs = expr.args[2:end] - is_single_index = length(original_ixs) == 1 - ixs = map(enumerate(original_ixs)) do (dim, ix) - _handle_index(ix, is_single_index ? nothing : dim) + positional_args = [] + keyword_args = [] + for (dim, ix_expr) in enumerate(original_ixs) + if _is_kw(ix_expr) + push!(keyword_args, :($(ix_expr.args[1]) = $(esc(ix_expr.args[2])))) + else + push!(positional_args, (dim, ix_expr)) + end end - :(Index(tuple($(ixs...)), $inner_expr)) + is_single_index = length(positional_args) == 1 + positional_ixs = map(positional_args) do (dim, ix_expr) + _handle_index(ix_expr, is_single_index ? nothing : dim) + end + kwarg_expr = if isempty(keyword_args) + :((;)) + else + Expr(:tuple, keyword_args...) + end + :(Index(tuple($(positional_ixs...)), $kwarg_expr, $inner_expr)) else # some other expression we can't parse throw(VarNameParseException(expr)) @@ -301,6 +315,8 @@ function _handle_property(::Any, original_expr) throw(VarNameParseException(original_expr)) end +_is_kw(e::Expr) = Meta.isexpr(e, :kw, 2) +_is_kw(::Any) = false _handle_index(ix::Int, ::Any) = ix _handle_index(ix::Symbol, dim) = _make_dynamicindex_expr(ix, dim) _handle_index(ix::Expr, dim) = _make_dynamicindex_expr(ix, dim) diff --git a/test/varname/optic.jl b/test/varname/optic.jl index b165ea86..c3a23efa 100644 --- a/test/varname/optic.jl +++ b/test/varname/optic.jl @@ -70,10 +70,13 @@ using AbstractPPL @test @opticof(_[DD.Not(3)])(x) == x[DD.Not(3)] dimarray = DD.DimArray(randn(2, 3), (DD.X, DD.Y)) @test @opticof(_[DD.X(1)])(dimarray) == dimarray[DD.X(1)] - # TODO(penelopeysm): This doesn't support keyword arguments to getindex yet. - # For example: - # dimarray = DD.DimArray(randn(2, 3), (:x, :y)) - # @test @opticof(_[x=1])(dimarray) == dimarray[x=1] + end + + @testset "keyword arguments to getindex" begin + dimarray = DD.DimArray([0.0 1.0; 2.0 3.0], (:x, :y)) + @test @opticof(_[x=1])(dimarray) == dimarray[x=1] + @test set(dimarray, @opticof(_[y=2]), [9.0; 8.0]) == + DD.DimArray([0.0 9.0; 2.0 8.0], (:x, :y)) end struct SampleStruct diff --git a/test/varname/varname.jl b/test/varname/varname.jl index 7deed66d..c10c5624 100644 --- a/test/varname/varname.jl +++ b/test/varname/varname.jl @@ -6,9 +6,10 @@ using Test @testset "varname/varname.jl" verbose = true begin @testset "basic construction (and type stability)" begin @test @varname(x) == (@inferred VarName{:x}(Iden())) - @test @varname(x[1]) == (@inferred VarName{:x}(Index((1,), Iden()))) + @test @varname(x[1]) == (@inferred VarName{:x}(Index((1,), (;), Iden()))) @test @varname(x.a) == (@inferred VarName{:x}(Property{:a}(Iden()))) - @test @varname(x.a[1]) == (@inferred VarName{:x}(Property{:a}(Index((1,), Iden())))) + @test @varname(x.a[1]) == + (@inferred VarName{:x}(Property{:a}(Index((1,), (;), Iden())))) end @testset "errors on invalid inputs" begin @@ -38,6 +39,8 @@ using Test check_doubleeq_and_hash(@varname(x.a[1]), @varname(x.a[1]), true) check_doubleeq_and_hash(@varname(x.a[1]), @varname(x.a[2]), false) check_doubleeq_and_hash(@varname(x.a[1]), @varname(x.b[1]), false) + check_doubleeq_and_hash(@varname(x[1, i=2]), @varname(x[1, i=2]), true) + check_doubleeq_and_hash(@varname(x[i=2, 4]), @varname(x[4, i=2]), true) @testset "dynamic indices" begin check_doubleeq_and_hash(@varname(x[begin]), @varname(x[begin]), true) @@ -64,6 +67,7 @@ using Test @test string(@varname(x[begin])) == "x[DynamicIndex(begin)]" @test string(@varname(x[end])) == "x[DynamicIndex(end)]" @test string(@varname(x[:])) == "x[:]" + @test string(@varname(x[1, i=3])) == "x[1, i=3]" end @testset "dynamic indices and manual concretization" begin @@ -114,6 +118,7 @@ using Test @test !is_dynamic(@varname(x[1:3, 3, 2 + 9])) i = 10 @test !is_dynamic(@varname(x[1:3, 3, 2 + 9, 1:3:i])) + @test !is_dynamic(@varname(x[k=i])) end @testset "automatic concretization" begin @@ -149,8 +154,11 @@ using Test @testset "of indices" begin idx = 3 - vn = @varname(x[idx]) - @test vn == @varname(x[3]) + @test @varname(x[idx]) == @varname(x[3]) + @test @varname(x[2 * idx]) == @varname(x[6]) + @test @varname(x[1:idx]) == @varname(x[1:3]) + @test @varname(x[k=idx]) == @varname(x[k=3]) + @test @varname(x[k=2 * idx]) == @varname(x[k=6]) end @testset "with dynamic indices" begin From 6a72e8b895eab6b5d44c27b4b914ef3709c49934 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 23 Dec 2025 21:49:28 +0000 Subject: [PATCH 19/60] Changelog --- HISTORY.md | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 24821dd7..5124c603 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -26,15 +26,24 @@ However, there are some differences: - Previously, AbstractPPL would refuse to allow you to construct unconcretized versions of `begin` and `end`. This is no longer the case; you can now create such VarNames in their unconcretized forms. This is useful, for example, when indexing into a chain that contains `x` as a variable-length vector. This change allows you to write `chain[@varname(x[end])]` without having AbstractPPL throw an error. -**Interface** +**Keyword arguments to `getindex`** + +VarNames can now be constructed with keyword arguments in `Index` optics, for example `@varname(x[i=1])`. +This is specifically implemented to support DimensionalData.jl's DimArrays. + +**Other interface functions** The `vsym` function (and `@vsym`) has been removed; you should use `getsym(vn)` instead. -The `Base.get` and `Base.set!` methods for VarNames have been removed (these were responsible for method ambiguities). +The `Base.get` and `Accessors.set` methods for VarNames have been removed (these were responsible for method ambiguities). +Instead of using these methods you can first convert the `VarName` to an optic using `varname_to_optic(vn)`, and then use the getter and setter methods on the optics. + +VarNames cannot be composed with optics now (compose the optics yourself). -VarNames cannot be composed with optics now (you need to compose the optics yourself). +The `inspace` function has been removed. +It used to be relevant for Turing's old Gibbs sampler; but now it no longer serves any use. -The `inspace` function has been removed (it used to be relevant for Turing's old Gibbs sampler; but now it no longer serves any use). +`ConcretizedSlice` has been removed (since colons are no longer concretized). ## 0.13.6 From 04231f07a2f40dc1c905ed0d2726c9e2c15ea65c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 23 Dec 2025 22:07:23 +0000 Subject: [PATCH 20/60] make a start on subsumption --- HISTORY.md | 3 + src/AbstractPPL.jl | 2 +- src/varname/subsumes.jl | 141 +++++++-------------------------------- test/runtests.jl | 1 + test/varname.jl | 46 ------------- test/varname/subsumes.jl | 54 +++++++++++++++ 6 files changed, 84 insertions(+), 163 deletions(-) create mode 100644 test/varname/subsumes.jl diff --git a/HISTORY.md b/HISTORY.md index 5124c603..d2d8246a 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -45,6 +45,9 @@ It used to be relevant for Turing's old Gibbs sampler; but now it no longer serv `ConcretizedSlice` has been removed (since colons are no longer concretized). +The subsumption interface has been pared down to just a single function, `subsumes`. +All other functions, such as `subsumedby`, `uncomparable`, and the Unicode operators, have been removed. + ## 0.13.6 Fix a missing qualifier in AbstractPPLDistributionsExt. diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index 302d2439..e7ce1633 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -25,7 +25,7 @@ include("abstractprobprog.jl") include("evaluate.jl") include("varname/optic.jl") include("varname/varname.jl") -# include("varname/subsumes.jl") +include("varname/subsumes.jl") # include("varname/hasvalue.jl") # include("varname/leaves.jl") # include("varname/prefix.jl") diff --git a/src/varname/subsumes.jl b/src/varname/subsumes.jl index f43a92b4..4a3f3f3d 100644 --- a/src/varname/subsumes.jl +++ b/src/varname/subsumes.jl @@ -1,135 +1,44 @@ """ - inspace(vn::Union{VarName, Symbol}, space::Tuple) + subsumes(parent::VarName, child::VarName) -Check whether `vn`'s variable symbol is in `space`. The empty tuple counts as the "universal space" -containing all variables. Subsumption (see [`subsumes`](@ref)) is respected. - -## Examples +Check whether the variable name `child` describes a sub-range of the variable `parent`, +i.e., is contained within it. ```jldoctest -julia> inspace(@varname(x[1][2:3]), ()) -true - -julia> inspace(@varname(x[1][2:3]), (:x,)) -true - -julia> inspace(@varname(x[1][2:3]), (@varname(x),)) +julia> subsumes(@varname(x), @varname(x[1, 2])) true -julia> inspace(@varname(x[1][2:3]), (@varname(x[1:10]), :y)) -true - -julia> inspace(@varname(x[1][2:3]), (@varname(x[:][2:4]), :y)) -true - -julia> inspace(@varname(x[1][2:3]), (@varname(x[1:10]),)) +julia> subsumes(@varname(x[1, 2]), @varname(x[1, 2][3])) true ``` -""" -inspace(vn, space::Tuple{}) = true # empty tuple is treated as universal space -inspace(vn, space::Tuple) = vn in space -inspace(vn::VarName, space::Tuple{}) = true -inspace(vn::VarName, space::Tuple) = any(_in(vn, s) for s in space) -_in(vn::VarName, s::Symbol) = getsym(vn) == s -_in(vn::VarName, s::VarName) = subsumes(s, vn) +Note that often this is not possible to determine statically. For example: -""" - subsumes(u::VarName, v::VarName) - -Check whether the variable name `v` describes a sub-range of the variable `u`. Supported -indexing: - - - Scalar: - - ```jldoctest - julia> subsumes(@varname(x), @varname(x[1, 2])) - true - - julia> subsumes(@varname(x[1, 2]), @varname(x[1, 2][3])) - true - ``` - - - Array of scalar: basically everything that fulfills `issubset`. - - ```jldoctest - julia> subsumes(@varname(x[[1, 2], 3]), @varname(x[1, 3])) - true - - julia> subsumes(@varname(x[1:3]), @varname(x[2][1])) - true - ``` - - - Slices: - - ```jldoctest - julia> subsumes(@varname(x[2, :]), @varname(x[2, 10][1])) - true - ``` - -Currently _not_ supported are: +- When dynamic indices are present, subsumption cannot be determined, unless `child == + parent`. +- Subsumption between different forms of indexing is not supported, e.g. `x[4]` and `x[2, + 2]` are not considered to subsume each other, even though they might in practice (e.g. if + `x` is a 2x2 matrix). - - Boolean indexing, literal `CartesianIndex` (these could be added, though) - - Linear indexing of multidimensional arrays: `x[4]` does not subsume `x[2, 2]` for a matrix `x` - - Trailing ones: `x[2, 1]` does not subsume `x[2]` for a vector `x` +In such cases, `subsumes` will conservatively return `false`. """ function subsumes(u::VarName, v::VarName) return getsym(u) == getsym(v) && subsumes(getoptic(u), getoptic(v)) end - -# Idea behind `subsumes` for `Lens` is that we traverse the two lenses in parallel, -# checking `subsumes` for every level. This for example means that if we are comparing -# `PropertyLens{:a}` and `PropertyLens{:b}` we immediately know that they do not subsume -# each other since at the same level/depth they access different properties. -# E.g. `x`, `x[1]`, i.e. `u` is always subsumed by `t` -subsumes(::typeof(identity), ::typeof(identity)) = true -subsumes(::typeof(identity), ::ALLOWED_OPTICS) = true -subsumes(::ALLOWED_OPTICS, ::typeof(identity)) = false - -function subsumes(t::ComposedFunction, u::ComposedFunction) - return subsumes(t.outer, u.outer) && subsumes(t.inner, u.inner) +subsumes(::Iden, ::Iden) = true +subsumes(::Iden, ::AbstractOptic) = true +subsumes(::AbstractOptic, ::Iden) = false +subsumes(t::Property{name}, u::Property{name}) where {name} = subsumes(t.child, u.child) +subsumes(t::Property, u::Property) = false +subsumes(::Property, ::Index) = false +subsumes(::Index, ::Property) = false + +function subsumes(i::Index, j::Index) + # TODO + return error("Not implemented.") end -# If `t` is still a composed lens, then there is no way it can subsume `u` since `u` is a -# leaf of the "lens-tree". -subsumes(t::ComposedFunction, u::PropertyLens) = false -# Here we need to check if `u.inner` (i.e. the next lens to be applied from `u`) is -# subsumed by `t`, since this would mean that the rest of the composition is also subsumed -# by `t`. -subsumes(t::PropertyLens, u::ComposedFunction) = subsumes(t, u.inner) - -# For `PropertyLens` either they have the same `name` and thus they are indeed the same. -subsumes(t::PropertyLens{name}, u::PropertyLens{name}) where {name} = true -# Otherwise they represent different properties, and thus are not the same. -subsumes(t::PropertyLens, u::PropertyLens) = false - -# PropertyLens and IndexLens can't subsume each other -subsumes(::PropertyLens, ::IndexLens) = false -subsumes(::IndexLens, ::PropertyLens) = false - -# Indices subsumes if they are subindices, i.e. we just call `_issubindex`. -# FIXME: Does not support `DynamicIndexLens`. -# FIXME: Does not correctly handle cases such as `subsumes(x, x[:])` -# (but neither did old implementation). -function subsumes( - t::Union{IndexLens,ComposedFunction{<:ALLOWED_OPTICS,<:IndexLens}}, - u::Union{IndexLens,ComposedFunction{<:ALLOWED_OPTICS,<:IndexLens}}, -) - return subsumes_indices(t, u) -end - -""" - subsumedby(t, u) - -True if `t` is subsumed by `u`, i.e., if `subsumes(u, t)` is true. -""" -subsumedby(t, u) = subsumes(u, t) -uncomparable(t, u) = t ⋢ u && u ⋢ t -const ⊒ = subsumes -const ⊑ = subsumedby -const ⋣ = !subsumes -const ⋢ = !subsumedby -const ≍ = uncomparable +#= # Since expressions such as `x[:][:][:][1]` and `x[1]` are equal, # the indexing behavior must be considered jointly. @@ -237,4 +146,4 @@ subsumes_index(i, ::Colon) = error("Colons cannot be subsumed") subsumes_index(::AbstractVector, ::Colon) = error("Colons cannot be subsumed") subsumes_index(i::Colon, j) = true subsumes_index(i::AbstractVector, j) = issubset(j, i) -subsumes_index(i, j) = i == j +subsumes_index(i, j) = i == j =# diff --git a/test/runtests.jl b/test/runtests.jl index dd958eb1..e1b8e16c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,7 @@ const GROUP = get(ENV, "GROUP", "All") include("abstractprobprog.jl") include("varname/optic.jl") include("varname/varname.jl") + include("varname/subsumes.jl") # include("varname.jl") # include("hasvalue.jl") end diff --git a/test/varname.jl b/test/varname.jl index 047a9dcb..88078d40 100644 --- a/test/varname.jl +++ b/test/varname.jl @@ -4,52 +4,6 @@ using OffsetArrays using LinearAlgebra: LowerTriangular, UpperTriangular, cholesky @testset "varnames" begin - @testset "subsumption with standard indexing" begin - # x ⊑ x - @test @varname(x) ⊑ @varname(x) - @test @varname(x[1]) ⊑ @varname(x[1]) - @test @varname(x.a) ⊑ @varname(x.a) - - # x ≍ y - @test @varname(x) ≍ @varname(y) - @test @varname(x.a) ≍ @varname(y.a) - @test @varname(a.x) ≍ @varname(a.y) - @test @varname(a.x[1]) ≍ @varname(a.x.z) - @test @varname(x[1]) ≍ @varname(y[1]) - @test @varname(x[1]) ≍ @varname(x.y) - - # x ∘ ℓ ⊑ x - @test_strict_subsumption x.a x - @test_strict_subsumption x[1] x - @test_strict_subsumption x[2:2:5] x - @test_strict_subsumption x[10, 20] x - - # x ∘ ℓ₁ ⊑ x ∘ ℓ₂ ⇔ ℓ₁ ⊑ ℓ₂ - @test_strict_subsumption x.a.b x.a - @test_strict_subsumption x[1].a x[1] - @test_strict_subsumption x.a[1] x.a - @test_strict_subsumption x[1:10][2] x[1:10] - - @test_strict_subsumption x[1] x[1:10] - @test_strict_subsumption x[1:5] x[1:10] - @test_strict_subsumption x[4:6] x[1:10] - - @test_strict_subsumption x[[2, 3, 5]] x[[7, 6, 5, 4, 3, 2, 1]] - - @test_strict_subsumption x[:a][1] x[:a] - - # boolean indexing works as long as it is concretized - A = rand(10, 10) - @test @varname(A[iseven.(1:10), 1], true) ⊑ @varname(A[1:10, 1]) - @test @varname(A[iseven.(1:10), 1], true) ⋣ @varname(A[1:10, 1]) - - # we can reasonably allow colons on the right side ("universal set") - @test @varname(x[1]) ⊑ @varname(x[:]) - @test @varname(x[1:10, 1]) ⊑ @varname(x[:, 1:10]) - @test_throws ErrorException (@varname(x[:]) ⊑ @varname(x[1])) - @test_throws ErrorException (@varname(x[:]) ⊑ @varname(x[:])) - end - @testset "de/serialisation of VarNames" begin y = ones(10) z = ones(5, 2) diff --git a/test/varname/subsumes.jl b/test/varname/subsumes.jl new file mode 100644 index 00000000..16a3b07d --- /dev/null +++ b/test/varname/subsumes.jl @@ -0,0 +1,54 @@ +module VarNameSubsumesTests + +using AbstractPPL +using Test + +@testset "varname/subsumes.jl" verbose = true begin + @test subsumes(@varname(x), @varname(x)) + @test subsumes(@varname(x[1]), @varname(x[1])) + @test subsumes(@varname(x.a), @varname(x.a)) + + uncomparable(vn1, vn2) = !subsumes(vn1, vn2) && !subsumes(vn2, vn1) + @test uncomparable(@varname(x), @varname(y)) + @test uncomparable(@varname(x.a), @varname(y.a)) + @test uncomparable(@varname(a.x), @varname(a.y)) + @test uncomparable(@varname(a.x[1]), @varname(a.x.z)) + @test uncomparable(@varname(x[1]), @varname(y[1])) + @test uncomparable(@varname(x[1]), @varname(x.y)) + + strictly_subsumes(vn1, vn2) = subsumes(vn1, vn2) && !subsumes(vn2, vn1) + # Subsumption via field/indexing + @test strictly_subsumes(@varname(x), @varname(x.a)) + @test strictly_subsumes(@varname(x), @varname(x[1])) + @test strictly_subsumes(@varname(x), @varname(x[2:2:5])) + @test strictly_subsumes(@varname(x), @varname(x[10, 20])) + @test strictly_subsumes(@varname(x.a), @varname(x.a.b)) + @test strictly_subsumes(@varname(x[1]), @varname(x[1].a)) + @test strictly_subsumes(@varname(x.a), @varname(x.a[1])) + @test strictly_subsumes(@varname(x[1:10]), @varname(x[1:10][2])) + # Range subsumption + @test strictly_subsumes(@varname(x[1:10]), @varname(x[1])) + @test strictly_subsumes(@varname(x[1:10]), @varname(x[1:5])) + @test strictly_subsumes(@varname(x[1:10]), @varname(x[4:6])) + @test strictly_subsumes(@varname(x[1:10, 1:10]), @varname(x[1:5, 1:5])) + @test strictly_subsumes(@varname(x[[7, 6, 5, 4, 3, 2, 1]]), @varname(x[[2, 3, 5]])) + + # TODO reenable + # @test_strict_subsumption x[:a][1] x[:a] + # # boolean indexing works as long as it is concretized + # A = rand(10, 10) + # @test @varname(A[iseven.(1:10), 1], true) ⊑ @varname(A[1:10, 1]) + # @test @varname(A[iseven.(1:10), 1], true) ⋣ @varname(A[1:10, 1]) + # + # # we can reasonably allow colons on the right side ("universal set") + # @test @varname(x[1]) ⊑ @varname(x[:]) + # @test @varname(x[1:10, 1]) ⊑ @varname(x[:, 1:10]) + # @test_throws ErrorException (@varname(x[:]) ⊑ @varname(x[1])) + # @test_throws ErrorException (@varname(x[:]) ⊑ @varname(x[:])) + # + # TODO dynamic indices + # + # TODO keyword indices +end + +end # module From 6bccab02f26d2c9712424e370be07d3bb82a7f9c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 23 Dec 2025 22:13:23 +0000 Subject: [PATCH 21/60] allow string/symbol indices and stuff like that --- src/AbstractPPL.jl | 6 +++--- src/varname/optic.jl | 2 ++ src/varname/subsumes.jl | 11 ++++++++--- src/varname/varname.jl | 2 +- test/varname/optic.jl | 13 +++++++++++++ 5 files changed, 27 insertions(+), 7 deletions(-) diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index e7ce1633..38bffd6a 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -1,7 +1,5 @@ module AbstractPPL -# subsumes, -# subsumedby, # index_to_dict, # dict_to_index, # varname_to_string, @@ -49,7 +47,9 @@ export AbstractOptic, @varname, @opticof, varname_to_optic, - optic_to_varname + optic_to_varname, + # subsumes + subsumes using Accessors: set export set diff --git a/src/varname/optic.jl b/src/varname/optic.jl index 1d1daac4..5dcb7a5a 100644 --- a/src/varname/optic.jl +++ b/src/varname/optic.jl @@ -134,6 +134,8 @@ end _pretty_string_index(ix) = string(ix) _pretty_string_index(::Colon) = ":" +_pretty_string_index(x::Symbol) = repr(x) +_pretty_string_index(x::String) = repr(x) _pretty_string_index(di::DynamicIndex) = "DynamicIndex($(di.expr))" _concretize_index(idx::Any, ::Any) = idx diff --git a/src/varname/subsumes.jl b/src/varname/subsumes.jl index 4a3f3f3d..03fda14a 100644 --- a/src/varname/subsumes.jl +++ b/src/varname/subsumes.jl @@ -34,12 +34,17 @@ subsumes(::Property, ::Index) = false subsumes(::Index, ::Property) = false function subsumes(i::Index, j::Index) - # TODO - return error("Not implemented.") + # TODO(penelopeysm): What we really want to do is to zip i.ix and j.ix + # and check that each index in `i.ix` subsumes the corresponding + # entry in `j.ix`. If that is true, then we can continue recursing. + return if i.ix == j.ix && i.kw == j.kw + subsumes(i.child, j.child) + else + error("Not implemented.") + end end #= - # Since expressions such as `x[:][:][:][1]` and `x[1]` are equal, # the indexing behavior must be considered jointly. # Therefore we must recurse until we reach something that is NOT diff --git a/src/varname/varname.jl b/src/varname/varname.jl index abecf248..9960e54e 100644 --- a/src/varname/varname.jl +++ b/src/varname/varname.jl @@ -317,7 +317,7 @@ end _is_kw(e::Expr) = Meta.isexpr(e, :kw, 2) _is_kw(::Any) = false -_handle_index(ix::Int, ::Any) = ix +_handle_index(ix::Any, ::Any) = ix _handle_index(ix::Symbol, dim) = _make_dynamicindex_expr(ix, dim) _handle_index(ix::Expr, dim) = _make_dynamicindex_expr(ix, dim) diff --git a/test/varname/optic.jl b/test/varname/optic.jl index c3a23efa..496e8154 100644 --- a/test/varname/optic.jl +++ b/test/varname/optic.jl @@ -66,10 +66,23 @@ using AbstractPPL x = randn(3, 3) @test @opticof(_[1:2:4])(x) == x[1:2:4] @test @opticof(_[CartesianIndex(1, 1)])(x) == x[CartesianIndex(1, 1)] + # `Not` is actually from InvertedIndices.jl (but re-exported by DimensionalData) @test @opticof(_[DD.Not(3)])(x) == x[DD.Not(3)] + + # DimArray selectors dimarray = DD.DimArray(randn(2, 3), (DD.X, DD.Y)) @test @opticof(_[DD.X(1)])(dimarray) == dimarray[DD.X(1)] + + # Symbols on NamedTuples + nt = (a=10, b=20, c=30) + @test @opticof(_[:a])(nt) == nt[:a] + @test set(nt, @opticof(_[:b]), 99) == (a=10, b=99, c=30) + + # Strings on Dicts + dict = Dict("one" => 1, "two" => 2) + @test @opticof(_["two"])(dict) == dict["two"] + @test set(dict, @opticof(_["two"]), 22) == Dict("one" => 1, "two" => 22) end @testset "keyword arguments to getindex" begin From 0bbaf31d9d4fcc186790ecdf2c9d9a0a403a97da Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 23 Dec 2025 23:43:29 +0000 Subject: [PATCH 22/60] subsumes works --- src/varname/subsumes.jl | 158 ++++++++++----------------------------- test/varname/subsumes.jl | 103 +++++++++++++++---------- 2 files changed, 102 insertions(+), 159 deletions(-) diff --git a/src/varname/subsumes.jl b/src/varname/subsumes.jl index 03fda14a..e9d1c3c3 100644 --- a/src/varname/subsumes.jl +++ b/src/varname/subsumes.jl @@ -12,15 +12,28 @@ julia> subsumes(@varname(x[1, 2]), @varname(x[1, 2][3])) true ``` -Note that often this is not possible to determine statically. For example: +This is done by recursively comparing each layer of the VarNames' optics. + +Note that often this is not possible to determine statically, and so the results should +not be over-interpreted. In particular, `Index` optics pose a problem. An `i::Index` will +only subsume `j::Index` if: + +1. They have the same number of positional indices (`i.ix` and `j.ix`); +2. Each positional index in `i` can be determined to comprise the corresponding positional + index in `j`; and +3. The keyword indices of `i` (`i.kw`) are a superset of those in `j.kw`). + +In all other cases, `subsumes` will conservatively return `false`, even though in practice +it might well be that `i` does subsume `j`. Some examples where subsumption cannot be +determined statically are: -- When dynamic indices are present, subsumption cannot be determined, unless `child == - parent`. - Subsumption between different forms of indexing is not supported, e.g. `x[4]` and `x[2, 2]` are not considered to subsume each other, even though they might in practice (e.g. if `x` is a 2x2 matrix). - -In such cases, `subsumes` will conservatively return `false`. +- When dynamic indices (that are not equal) are present. (Dynamic indices that are equal do + subsume each other.) +- Non-standard indices, e.g. `Not(4)`, `2..3`, etc. Again, these only subsume each other + when they are equal. """ function subsumes(u::VarName, v::VarName) return getsym(u) == getsym(v) && subsumes(getoptic(u), getoptic(v)) @@ -34,121 +47,30 @@ subsumes(::Property, ::Index) = false subsumes(::Index, ::Property) = false function subsumes(i::Index, j::Index) - # TODO(penelopeysm): What we really want to do is to zip i.ix and j.ix - # and check that each index in `i.ix` subsumes the corresponding - # entry in `j.ix`. If that is true, then we can continue recursing. - return if i.ix == j.ix && i.kw == j.kw - subsumes(i.child, j.child) - else - error("Not implemented.") - end + return _subsumes_positional(i.ix, j.ix) && + _subsumes_keyword(i.kw, j.kw) && + subsumes(i.child, j.child) end -#= -# Since expressions such as `x[:][:][:][1]` and `x[1]` are equal, -# the indexing behavior must be considered jointly. -# Therefore we must recurse until we reach something that is NOT -# indexing, and then consider the sequence of indices leading up to this. -""" - subsumes_indices(t, u) - -Return `true` if the indexing represented by `t` subsumes `u`. - -This is mostly useful for comparing compositions involving `IndexLens` -e.g. `_[1][2].a[2]` and `_[1][2].a`. In such a scenario we do the following: -1. Combine `[1][2]` into a `Tuple` of indices using [`combine_indices`](@ref). -2. Do the same for `[1][2]`. -3. Compare the two tuples from (1) and (2) using `subsumes_indices`. -4. Since we're still undecided, we call `subsume(@o(_.a[2]), @o(_.a))` - which then returns `false`. - -# Example -```jldoctest; setup=:(using Accessors; using AbstractPPL: subsumes_indices) -julia> t = @o(_[1].a); u = @o(_[1]); - -julia> subsumes_indices(t, u) -false - -julia> subsumes_indices(u, t) -true - -julia> # `identity` subsumes all. - subsumes_indices(identity, t) -true - -julia> # None subsumes `identity`. - subsumes_indices(t, identity) -false - -julia> AbstractPPL.subsumes(@o(_[1][2].a[2]), @o(_[1][2].a)) -false - -julia> AbstractPPL.subsumes(@o(_[1][2].a), @o(_[1][2].a[2])) -true -``` -""" -function subsumes_indices(t::ALLOWED_OPTICS, u::ALLOWED_OPTICS) - t_indices, t_next = combine_indices(t) - u_indices, u_next = combine_indices(u) - - # If we already know that `u` is not subsumed by `t`, return early. - if !subsumes_indices(t_indices, u_indices) - return false - end - - if t_next === nothing - # Means that there's nothing left for `t` and either nothing - # or something left for `u`, i.e. `t` indeed `subsumes` `u`. - return true - elseif u_next === nothing - # If `t_next` is not `nothing` but `u_next` is, then - # `t` does not subsume `u`. - return false - end - - # If neither is `nothing` we continue. - return subsumes(t_next, u_next) -end - -""" - combine_indices(optic) - -Return sequential indexing into a single `Tuple` of indices, -e.g. `x[:][1][2]` becomes `((Colon(), ), (1, ), (2, ))`. - -The result is compatible with [`subsumes_indices`](@ref) for `Tuple` input. -""" -combine_indices(optic::ALLOWED_OPTICS) = (), optic -combine_indices(optic::IndexLens) = (optic.indices,), nothing -function combine_indices(optic::ComposedFunction{<:ALLOWED_OPTICS,<:IndexLens}) - indices, next = combine_indices(optic.outer) - return (optic.inner.indices, indices...), next +function _subsumes_positional(i::Tuple, j::Tuple) + return (length(i) == length(j)) && all(_subsumes_index.(i, j)) end - -""" - subsumes_indices(left_indices::Tuple, right_indices::Tuple) - -Return `true` if `right_indices` is subsumed by `left_indices`. `left_indices` is assumed to be -concretized and consist of either `Int`s or `AbstractArray`s of scalar indices that are supported -by array A. - -Currently _not_ supported are: -- Boolean indexing, literal `CartesianIndex` (these could be added, though) -- Linear indexing of multidimensional arrays: `x[4]` does not subsume `x[2, 2]` for a matrix `x` -- Trailing ones: `x[2, 1]` does not subsume `x[2]` for a vector `x` -""" -subsumes_indices(::Tuple{}, ::Tuple{}) = true # x subsumes x -subsumes_indices(::Tuple{}, ::Tuple) = true # x subsumes x... -subsumes_indices(::Tuple, ::Tuple{}) = false # x... does not subsume x -function subsumes_indices(t1::Tuple, t2::Tuple) # does x[i]... subsume x[j]...? - first_subsumed = all(Base.splat(subsumes_index), zip(first(t1), first(t2))) - return first_subsumed && subsumes_indices(Base.tail(t1), Base.tail(t2)) +function _subsumes_keyword(i::NamedTuple{f1}, j::NamedTuple{f2}) where {f1,f2} + for name in f2 + if !(name in f1) || !(_subsumes_index(i[name], j[name])) + return false + end + end + return true end -subsumes_index(i::Colon, ::Colon) = error("Colons cannot be subsumed") -subsumes_index(i, ::Colon) = error("Colons cannot be subsumed") -# Necessary to avoid ambiguity errors. -subsumes_index(::AbstractVector, ::Colon) = error("Colons cannot be subsumed") -subsumes_index(i::Colon, j) = true -subsumes_index(i::AbstractVector, j) = issubset(j, i) -subsumes_index(i, j) = i == j =# +_subsumes_index(a::DynamicIndex, b::DynamicIndex) = a == b +_subsumes_index(a::DynamicIndex, ::Any) = false +_subsumes_index(::DynamicIndex, ::Colon) = false +_subsumes_index(::Colon, ::DynamicIndex) = true +_subsumes_index(::Any, ::DynamicIndex) = false +_subsumes_index(::Colon, ::Any) = true +_subsumes_index(::Any, ::Colon) = false +_subsumes_index(a::AbstractVector, b::Any) = issubset(b, a) +_subsumes_index(a::AbstractVector, b::Colon) = false +_subsumes_index(a::Any, b::Any) = a == b diff --git a/test/varname/subsumes.jl b/test/varname/subsumes.jl index 16a3b07d..90d5dd2b 100644 --- a/test/varname/subsumes.jl +++ b/test/varname/subsumes.jl @@ -4,51 +4,72 @@ using AbstractPPL using Test @testset "varname/subsumes.jl" verbose = true begin - @test subsumes(@varname(x), @varname(x)) - @test subsumes(@varname(x[1]), @varname(x[1])) - @test subsumes(@varname(x.a), @varname(x.a)) + @testset "varnames that are equal" begin + @test subsumes(@varname(x), @varname(x)) + @test subsumes(@varname(x[1]), @varname(x[1])) + @test subsumes(@varname(x.a), @varname(x.a)) + end uncomparable(vn1, vn2) = !subsumes(vn1, vn2) && !subsumes(vn2, vn1) - @test uncomparable(@varname(x), @varname(y)) - @test uncomparable(@varname(x.a), @varname(y.a)) - @test uncomparable(@varname(a.x), @varname(a.y)) - @test uncomparable(@varname(a.x[1]), @varname(a.x.z)) - @test uncomparable(@varname(x[1]), @varname(y[1])) - @test uncomparable(@varname(x[1]), @varname(x.y)) + @testset "uncomparable varnames" begin + @test uncomparable(@varname(x), @varname(y)) + @test uncomparable(@varname(x.a), @varname(y.a)) + @test uncomparable(@varname(a.x), @varname(a.y)) + @test uncomparable(@varname(a.x[1]), @varname(a.x.z)) + @test uncomparable(@varname(x[1]), @varname(y[1])) + @test uncomparable(@varname(x[1]), @varname(x.y)) + end strictly_subsumes(vn1, vn2) = subsumes(vn1, vn2) && !subsumes(vn2, vn1) - # Subsumption via field/indexing - @test strictly_subsumes(@varname(x), @varname(x.a)) - @test strictly_subsumes(@varname(x), @varname(x[1])) - @test strictly_subsumes(@varname(x), @varname(x[2:2:5])) - @test strictly_subsumes(@varname(x), @varname(x[10, 20])) - @test strictly_subsumes(@varname(x.a), @varname(x.a.b)) - @test strictly_subsumes(@varname(x[1]), @varname(x[1].a)) - @test strictly_subsumes(@varname(x.a), @varname(x.a[1])) - @test strictly_subsumes(@varname(x[1:10]), @varname(x[1:10][2])) - # Range subsumption - @test strictly_subsumes(@varname(x[1:10]), @varname(x[1])) - @test strictly_subsumes(@varname(x[1:10]), @varname(x[1:5])) - @test strictly_subsumes(@varname(x[1:10]), @varname(x[4:6])) - @test strictly_subsumes(@varname(x[1:10, 1:10]), @varname(x[1:5, 1:5])) - @test strictly_subsumes(@varname(x[[7, 6, 5, 4, 3, 2, 1]]), @varname(x[[2, 3, 5]])) - - # TODO reenable - # @test_strict_subsumption x[:a][1] x[:a] - # # boolean indexing works as long as it is concretized - # A = rand(10, 10) - # @test @varname(A[iseven.(1:10), 1], true) ⊑ @varname(A[1:10, 1]) - # @test @varname(A[iseven.(1:10), 1], true) ⋣ @varname(A[1:10, 1]) - # - # # we can reasonably allow colons on the right side ("universal set") - # @test @varname(x[1]) ⊑ @varname(x[:]) - # @test @varname(x[1:10, 1]) ⊑ @varname(x[:, 1:10]) - # @test_throws ErrorException (@varname(x[:]) ⊑ @varname(x[1])) - # @test_throws ErrorException (@varname(x[:]) ⊑ @varname(x[:])) - # - # TODO dynamic indices - # - # TODO keyword indices + @testset "strict subsumption - no index comparisons" begin + @test strictly_subsumes(@varname(x), @varname(x.a)) + @test strictly_subsumes(@varname(x), @varname(x[1])) + @test strictly_subsumes(@varname(x), @varname(x[2:2:5])) + @test strictly_subsumes(@varname(x), @varname(x[10, 20])) + @test strictly_subsumes(@varname(x.a), @varname(x.a.b)) + @test strictly_subsumes(@varname(x[1]), @varname(x[1].a)) + @test strictly_subsumes(@varname(x.a), @varname(x.a[1])) + @test strictly_subsumes(@varname(x[1:10]), @varname(x[1:10][2])) + end + + @testset "strict subsumption - index comparisons" begin + @testset "integer vectors" begin + @test strictly_subsumes(@varname(x[1:10]), @varname(x[1])) + @test strictly_subsumes(@varname(x[1:10]), @varname(x[1:5])) + @test strictly_subsumes(@varname(x[1:10]), @varname(x[4:6])) + @test strictly_subsumes(@varname(x[1:10, 1:10]), @varname(x[1:5, 1:5])) + @test strictly_subsumes(@varname(x[[5, 4, 3, 2, 1]]), @varname(x[[2, 4]])) + end + + @testset "non-integer indices" begin + @test strictly_subsumes(@varname(x[:a]), @varname(x[:a][1])) + end + + @testset "colon" begin + @test strictly_subsumes(@varname(x[:]), @varname(x[1])) + @test strictly_subsumes(@varname(x[:, 1:10]), @varname(x[1:10, 1])) + end + + @testset "dynamic indices" begin + @test strictly_subsumes(@varname(x), @varname(x[begin])) + @test subsumes(@varname(x[begin]), @varname(x[begin])) + @test strictly_subsumes(@varname(x[:]), @varname(x[begin])) + @test strictly_subsumes(@varname(x), @varname(x[end])) + @test subsumes(@varname(x[end]), @varname(x[end])) + @test strictly_subsumes(@varname(x[:]), @varname(x[end])) + @test strictly_subsumes(@varname(x[:]), @varname(x[1:end])) + @test strictly_subsumes(@varname(x[:]), @varname(x[end - 3])) + end + + @testset "keyword indices" begin + @test strictly_subsumes(@varname(x), @varname(x[a=1])) + @test strictly_subsumes(@varname(x[a=1:10, b=1:10]), @varname(x[a=1:10])) + @test strictly_subsumes(@varname(x[a=1:10, b=1:10]), @varname(x[a=1:5, b=1:5])) + @test strictly_subsumes(@varname(x[a=:]), @varname(x[a=1])) + @test uncomparable(@varname(x[a=1:10, b=5]), @varname(x[a=5, b=1:10])) + @test uncomparable(@varname(x[a=1]), @varname(x[b=1])) + end + end end end # module From c1036024b057e1ffc4246e6d61dad19134031718 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 23 Dec 2025 23:52:38 +0000 Subject: [PATCH 23/60] use views where possible, add more tests --- src/varname/optic.jl | 7 ++++++- test/varname.jl | 48 ------------------------------------------- test/varname/optic.jl | 46 +++++++++++++++++++++++++++++++++++++++-- 3 files changed, 50 insertions(+), 51 deletions(-) diff --git a/src/varname/optic.jl b/src/varname/optic.jl index 5dcb7a5a..ad80382b 100644 --- a/src/varname/optic.jl +++ b/src/varname/optic.jl @@ -172,9 +172,14 @@ function _pretty_print_optic(io::IO, idx::Index) return _pretty_print_optic(io, idx.child) end is_dynamic(idx::Index) = any(ix -> ix isa DynamicIndex, idx.ix) || is_dynamic(idx.child) +# Things like dictionaries can't be `view`ed into. +_maybe_view(val::AbstractArray, i...; k...) = view(val, i...; k...) +_maybe_view(val, i...; k...) = getindex(val, i...; k...) function concretize(idx::Index, val) concretized_indices = tuple(map(Base.Fix2(_concretize_index, val), idx.ix)...) - inner_concretized = concretize(idx.child, val[concretized_indices..., idx.kw...]) + inner_concretized = concretize( + idx.child, _maybe_view(val, concretized_indices...; idx.kw...) + ) return Index(concretized_indices, idx.kw, inner_concretized) end function (idx::Index)(obj) diff --git a/test/varname.jl b/test/varname.jl index 88078d40..5818f0a9 100644 --- a/test/varname.jl +++ b/test/varname.jl @@ -72,54 +72,6 @@ using LinearAlgebra: LowerTriangular, UpperTriangular, cholesky @test string_to_varname(varname_to_string(vn)) == vn end - @testset "head, tail, init, last" begin - @testset "specification" begin - @test AbstractPPL._head(@o _.a.b.c) == @o _.a - @test AbstractPPL._tail(@o _.a.b.c) == @o _.b.c - @test AbstractPPL._init(@o _.a.b.c) == @o _.a.b - @test AbstractPPL._last(@o _.a.b.c) == @o _.c - - @test AbstractPPL._head(@o _[1][2][3]) == @o _[1] - @test AbstractPPL._tail(@o _[1][2][3]) == @o _[2][3] - @test AbstractPPL._init(@o _[1][2][3]) == @o _[1][2] - @test AbstractPPL._last(@o _[1][2][3]) == @o _[3] - - @test AbstractPPL._head(@o _.a) == @o _.a - @test AbstractPPL._tail(@o _.a) == identity - @test AbstractPPL._init(@o _.a) == identity - @test AbstractPPL._last(@o _.a) == @o _.a - - @test AbstractPPL._head(@o _[1]) == @o _[1] - @test AbstractPPL._tail(@o _[1]) == identity - @test AbstractPPL._init(@o _[1]) == identity - @test AbstractPPL._last(@o _[1]) == @o _[1] - - @test AbstractPPL._head(identity) == identity - @test AbstractPPL._tail(identity) == identity - @test AbstractPPL._init(identity) == identity - @test AbstractPPL._last(identity) == identity - end - - @testset "composition" begin - varnames = ( - @varname(x), - @varname(x[1]), - @varname(x.a), - @varname(x.a.b), - @varname(x[1].a), - ) - for vn in varnames - optic = getoptic(vn) - @test AbstractPPL.normalise( - AbstractPPL._last(optic) ∘ AbstractPPL._init(optic) - ) == optic - @test AbstractPPL.normalise( - AbstractPPL._tail(optic) ∘ AbstractPPL._head(optic) - ) == optic - end - end - end - @testset "prefix and unprefix" begin @testset "basic cases" begin @test prefix(@varname(y), @varname(x)) == @varname(x.y) diff --git a/test/varname/optic.jl b/test/varname/optic.jl index 496e8154..d6c7b502 100644 --- a/test/varname/optic.jl +++ b/test/varname/optic.jl @@ -26,8 +26,50 @@ using AbstractPPL @test cat(o1, o2, o2, o1) == @opticof(_.a.b[1][2][1][2].a.b) end - # TODO - @testset "decomposition" begin end + @testset "decomposition" begin + @testset "specification" begin + @test ohead(@opticof _.a.b.c) == @opticof _.a + @test otail(@opticof _.a.b.c) == @opticof _.b.c + @test oinit(@opticof _.a.b.c) == @opticof _.a.b + @test olast(@opticof _.a.b.c) == @opticof _.c + + @test ohead(@opticof _[1][2][3]) == @opticof _[1] + @test otail(@opticof _[1][2][3]) == @opticof _[2][3] + @test oinit(@opticof _[1][2][3]) == @opticof _[1][2] + @test olast(@opticof _[1][2][3]) == @opticof _[3] + + @test ohead(@opticof _.a) == @opticof _.a + @test otail(@opticof _.a) == @opticof _ + @test oinit(@opticof _.a) == @opticof _ + @test olast(@opticof _.a) == @opticof _.a + + @test ohead(@opticof _[1]) == @opticof _[1] + @test otail(@opticof _[1]) == @opticof _ + @test oinit(@opticof _[1]) == @opticof _ + @test olast(@opticof _[1]) == @opticof _[1] + + @test ohead(@opticof _) == @opticof _ + @test otail(@opticof _) == @opticof _ + @test oinit(@opticof _) == @opticof _ + @test olast(@opticof _) == @opticof _ + end + + @testset "invariants" begin + optics = ( + @opticof(_), + @opticof(_[1]), + @opticof(_.a), + @opticof(_.a.b), + @opticof(_[1].a), + @opticof(_[1, x=1].a), + @opticof(_[].a[:]), + ) + for optic in optics + @test olast(optic) ∘ oinit(optic) == optic + @test otail(optic) ∘ ohead(optic) == optic + end + end + end @testset "getting and setting" begin @testset "basic" begin From 8555095eea3abd9b518f32f6b71549c4231e3f93 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 23 Dec 2025 23:55:41 +0000 Subject: [PATCH 24/60] docs --- HISTORY.md | 2 +- docs/src/varname.md | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index d2d8246a..d7b55c83 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,7 +1,7 @@ ## 0.14.0 This release overhauls the `VarName` type. -Much of the external API for traversing and manipulating `VarName`s has been preserved, but there are significant changes: +Much of the external API for traversing and manipulating `VarName`s (once they have been constructed) has been preserved, but if you use the `VarName` type directly, there are significant changes. **Internal representation** diff --git a/docs/src/varname.md b/docs/src/varname.md index 20df1737..a2ca709d 100644 --- a/docs/src/varname.md +++ b/docs/src/varname.md @@ -189,3 +189,18 @@ This can be achieved with: varname_to_optic optic_to_varname ``` + +## Subsumption + +Sometimes, we want to check whether one VarName 'subsumes' another; that is, whether a VarName refers to a part of another VarName. +This is done using the [`subsumes`](@ref) function: + +```@example vn +vn1 = @varname(x.a) +vn2 = @varname(x.a[1]) +subsumes(vn1, vn2) +``` + +```@docs +subsumes +``` From b2b23b996aace9a1765c7d3f52cde9e32f649ed2 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 00:02:25 +0000 Subject: [PATCH 25/60] prefixing works --- src/AbstractPPL.jl | 8 ++-- src/varname/prefix.jl | 91 ++++++++++-------------------------------- test/varname.jl | 43 -------------------- test/varname/prefix.jl | 45 +++++++++++++++++++++ 4 files changed, 71 insertions(+), 116 deletions(-) create mode 100644 test/varname/prefix.jl diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index 38bffd6a..0688297f 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -26,7 +26,7 @@ include("varname/varname.jl") include("varname/subsumes.jl") # include("varname/hasvalue.jl") # include("varname/leaves.jl") -# include("varname/prefix.jl") +include("varname/prefix.jl") # include("varname/serialize.jl") # Optics @@ -48,8 +48,10 @@ export AbstractOptic, @opticof, varname_to_optic, optic_to_varname, - # subsumes - subsumes + # other functions + subsumes, + prefix, + unprefix using Accessors: set export set diff --git a/src/varname/prefix.jl b/src/varname/prefix.jl index 25d4485b..943ba5c9 100644 --- a/src/varname/prefix.jl +++ b/src/varname/prefix.jl @@ -1,63 +1,12 @@ -### Functionality for prefixing and unprefixing VarNames. - -""" - optic_to_vn(optic) - -Convert an Accessors optic to a VarName. This is best explained through -examples. - -```jldoctest; setup=:(using Accessors) -julia> AbstractPPL.optic_to_vn(Accessors.@o _.a) -a - -julia> AbstractPPL.optic_to_vn(Accessors.@o _.a.b) -a.b - -julia> AbstractPPL.optic_to_vn(Accessors.@o _.a[1]) -a[1] -``` - -The outermost layer of the optic (technically, what Accessors.jl calls the -'innermost') must be a `PropertyLens`, or else it will fail. This is because a -VarName needs to have a symbol. - -```jldoctest; setup=:(using Accessors) -julia> AbstractPPL.optic_to_vn(Accessors.@o _[1]) -ERROR: ArgumentError: optic_to_vn: could not convert optic `(@o _[1])` to a VarName -[...] -``` -""" -function optic_to_vn(::Accessors.PropertyLens{sym}) where {sym} - return VarName{sym}() -end -function optic_to_vn( - o::ComposedFunction{Outer,Accessors.PropertyLens{sym}} -) where {Outer,sym} - return VarName{sym}(o.outer) -end -optic_to_vn(o::ComposedFunction) = optic_to_vn(normalise(o)) -function optic_to_vn(@nospecialize(o)) - msg = "optic_to_vn: could not convert optic `$o` to a VarName" - throw(ArgumentError(msg)) -end - -unprefix_optic(o, ::typeof(identity)) = o # Base case -function unprefix_optic(optic, optic_prefix) - # Technically `unprefix_optic` only receives optics that were part of - # VarNames, so the optics should already be normalised (in the inner - # constructor of the VarName). However I guess it doesn't hurt to do it - # again to be safe. - optic = normalise(optic) - optic_prefix = normalise(optic_prefix) - # strip one layer of the optic and check for equality - head = _head(optic) - head_prefix = _head(optic_prefix) +_unprefix_optic(o, ::Iden) = o +function _unprefix_optic(optic, optic_prefix) + head = ohead(optic) + head_prefix = ohead(optic_prefix) if head != head_prefix - msg = "could not remove prefix $(optic_prefix) from optic $(optic)" + msg = "cannot remove prefix $(optic_prefix) from optic $(optic)" throw(ArgumentError(msg)) end - # recurse - return unprefix_optic(_tail(optic), _tail(optic_prefix)) + return _unprefix_optic(otail(optic), otail(optic_prefix)) end """ @@ -66,17 +15,21 @@ end Remove a prefix from a VarName. ```jldoctest -julia> AbstractPPL.unprefix(@varname(y.x), @varname(y)) +julia> unprefix(@varname(y.x), @varname(y)) x -julia> AbstractPPL.unprefix(@varname(y.x.a), @varname(y)) +julia> unprefix(@varname(y.x.a), @varname(y)) x.a -julia> AbstractPPL.unprefix(@varname(y[1].x), @varname(y[1])) +julia> unprefix(@varname(y[1].x), @varname(y[1])) x -julia> AbstractPPL.unprefix(@varname(y), @varname(n)) -ERROR: ArgumentError: could not remove prefix n from VarName y +julia> unprefix(@varname(y), @varname(n)) +ERROR: ArgumentError: cannot remove prefix n from VarName y +[...] + +julia> unprefix(@varname(y[1]), @varname(y)) +ERROR: ArgumentError: optic_to_varname: can only convert Property optics to VarName [...] ``` """ @@ -84,12 +37,12 @@ function unprefix( vn::VarName{sym_vn}, prefix::VarName{sym_prefix} ) where {sym_vn,sym_prefix} if sym_vn != sym_prefix - msg = "could not remove prefix $(prefix) from VarName $(vn)" + msg = "cannot remove prefix $(prefix) from VarName $(vn)" throw(ArgumentError(msg)) end optic_vn = getoptic(vn) optic_prefix = getoptic(prefix) - return optic_to_vn(unprefix_optic(optic_vn, optic_prefix)) + return optic_to_varname(_unprefix_optic(optic_vn, optic_prefix)) end """ @@ -98,19 +51,17 @@ end Add a prefix to a VarName. ```jldoctest -julia> AbstractPPL.prefix(@varname(x), @varname(y)) +julia> prefix(@varname(x), @varname(y)) y.x -julia> AbstractPPL.prefix(@varname(x.a), @varname(y)) +julia> prefix(@varname(x.a), @varname(y)) y.x.a -julia> AbstractPPL.prefix(@varname(x.a), @varname(y[1])) +julia> prefix(@varname(x.a), @varname(y[1])) y[1].x.a ``` """ function prefix(vn::VarName{sym_vn}, prefix::VarName{sym_prefix}) where {sym_vn,sym_prefix} - optic_vn = getoptic(vn) - optic_prefix = getoptic(prefix) - new_optic_vn = optic_vn ∘ PropertyLens{sym_vn}() ∘ optic_prefix + new_optic_vn = varname_to_optic(vn) ∘ getoptic(prefix) return VarName{sym_prefix}(new_optic_vn) end diff --git a/test/varname.jl b/test/varname.jl index 5818f0a9..af0f8967 100644 --- a/test/varname.jl +++ b/test/varname.jl @@ -72,49 +72,6 @@ using LinearAlgebra: LowerTriangular, UpperTriangular, cholesky @test string_to_varname(varname_to_string(vn)) == vn end - @testset "prefix and unprefix" begin - @testset "basic cases" begin - @test prefix(@varname(y), @varname(x)) == @varname(x.y) - @test prefix(@varname(y), @varname(x[1])) == @varname(x[1].y) - @test prefix(@varname(y), @varname(x.a)) == @varname(x.a.y) - @test prefix(@varname(y[1]), @varname(x)) == @varname(x.y[1]) - @test prefix(@varname(y.a), @varname(x)) == @varname(x.y.a) - - @test unprefix(@varname(x.y[1]), @varname(x)) == @varname(y[1]) - @test unprefix(@varname(x[1].y), @varname(x[1])) == @varname(y) - @test unprefix(@varname(x.a.y), @varname(x.a)) == @varname(y) - @test unprefix(@varname(x.y.a), @varname(x)) == @varname(y.a) - @test_throws ArgumentError unprefix(@varname(x.y.a), @varname(n)) - @test_throws ArgumentError unprefix(@varname(x.y.a), @varname(x[1])) - end - - @testset "round-trip" begin - # These seem similar to the ones above, but in the past they used - # to error because of issues with un-normalised ComposedFunction - # optics. We explicitly test round-trip (un)prefixing here to make - # sure that there aren't any regressions. - # This tuple is probably overkill, but the tests are super fast - # anyway. - vns = ( - @varname(p), - @varname(q), - @varname(r[1]), - @varname(s.a), - @varname(t[1].a), - @varname(u[1].a.b), - @varname(v.a[1][2].b.c.d[3]) - ) - for vn1 in vns - for vn2 in vns - prefixed = prefix(vn1, vn2) - @test subsumes(vn2, prefixed) - unprefixed = unprefix(prefixed, vn2) - @test unprefixed == vn1 - end - end - end - end - @testset "varname{_and_value}_leaves" begin @testset "single value: float, int" begin x = 1.0 diff --git a/test/varname/prefix.jl b/test/varname/prefix.jl new file mode 100644 index 00000000..2276d5bc --- /dev/null +++ b/test/varname/prefix.jl @@ -0,0 +1,45 @@ +module VarNamePrefixTests + +using Test +using AbstractPPL + +@testset "varname/prefix.jl" verbose = true begin + @testset "basic cases" begin + @test prefix(@varname(y), @varname(x)) == @varname(x.y) + @test prefix(@varname(y), @varname(x[1])) == @varname(x[1].y) + @test prefix(@varname(y), @varname(x.a)) == @varname(x.a.y) + @test prefix(@varname(y[1]), @varname(x)) == @varname(x.y[1]) + @test prefix(@varname(y.a), @varname(x)) == @varname(x.y.a) + + @test unprefix(@varname(x.y[1]), @varname(x)) == @varname(y[1]) + @test unprefix(@varname(x[1].y), @varname(x[1])) == @varname(y) + @test unprefix(@varname(x.a.y), @varname(x.a)) == @varname(y) + @test unprefix(@varname(x.y.a), @varname(x)) == @varname(y.a) + @test_throws ArgumentError unprefix(@varname(x.y.a), @varname(n)) + @test_throws ArgumentError unprefix(@varname(x.y.a), @varname(x[1])) + end + + @testset "round-trip + type stability" begin + # This tuple is probably overkill, but the tests are super fast + # anyway. + vns = ( + @varname(p), + @varname(q), + @varname(r[1]), + @varname(s.a), + @varname(t[1].a), + @varname(u[1].a.b), + @varname(v.a[1][2].b.c.d[3]) + ) + for vn1 in vns + for vn2 in vns + prefixed = @inferred prefix(vn1, vn2) + @test subsumes(vn2, prefixed) + unprefixed = @inferred unprefix(prefixed, vn2) + @test unprefixed == vn1 + end + end + end +end + +end # module From 1ab00dd1d639b7bc4c93c147c7bfa563ced845fe Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 00:53:00 +0000 Subject: [PATCH 26/60] hasvalue works --- Project.toml | 2 +- ext/AbstractPPLDistributionsExt.jl | 41 +++++++++------- src/AbstractPPL.jl | 6 ++- src/varname/hasvalue.jl | 75 +++++++++++++++++------------- src/varname/optic.jl | 10 +++- test/runtests.jl | 3 +- test/{ => varname}/hasvalue.jl | 7 +++ 7 files changed, 89 insertions(+), 55 deletions(-) rename test/{ => varname}/hasvalue.jl (99%) diff --git a/Project.toml b/Project.toml index 78e90428..65b7e6b8 100644 --- a/Project.toml +++ b/Project.toml @@ -19,7 +19,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" [extensions] -AbstractPPLDistributionsExt = ["Distributions"] +AbstractPPLDistributionsExt = ["Distributions", "LinearAlgebra"] [compat] AbstractMCMC = "2, 3, 4, 5" diff --git a/ext/AbstractPPLDistributionsExt.jl b/ext/AbstractPPLDistributionsExt.jl index 2824e75e..ff267cb0 100644 --- a/ext/AbstractPPLDistributionsExt.jl +++ b/ext/AbstractPPLDistributionsExt.jl @@ -49,9 +49,7 @@ This decision may be revisited in the future. module AbstractPPLDistributionsExt -#= - -using AbstractPPL: AbstractPPL, VarName, Accessors, LinearAlgebra +using AbstractPPL: AbstractPPL, VarName, Accessors using Distributions: Distributions using LinearAlgebra: Cholesky, LowerTriangular, UpperTriangular @@ -68,17 +66,21 @@ struct Lens!{L} pure::L end (l::Lens!)(o) = l.pure(o) -function Accessors.set(o, l::Lens!{<:ComposedFunction}, val) - o_inner = l.pure.inner(o) - return Accessors.set(o_inner, Lens!(l.pure.outer), val) -end -function Accessors.set(o, l::Lens!{Accessors.PropertyLens{prop}}, val) where {prop} - setproperty!(o, prop, val) - return o +Accessors.set(::Any, l::Lens!{AbstractPPL.Iden}, val) = val +function Accessors.set(obj, l::Lens!{<:AbstractPPL.Property{sym}}, newval) where {sym} + inner_obj = getproperty(obj, sym) + inner_newval = Accessors.set(inner_obj, Lens!(l.pure.child), newval) + # Note that the following line actually does not mutate `obj.sym`. That's fine, because + # the things we are dealing with won't have mutable fields. The point is that + # the inner lens will have mutated whatever `obj.sym` pointed to. + return Accessors.set(obj, l.pure, inner_newval) end -function Accessors.set(o, l::Lens!{<:Accessors.IndexLens}, val) - o[l.pure.indices...] = val - return o +function Accessors.set(obj, l::Lens!{<:AbstractPPL.Index}, newval) + cidx = AbstractPPL.concretize(l.pure, obj) + inner_obj = Base.getindex(obj, cidx.ix...; cidx.kw...) + inner_newval = AbstractPPL.set(inner_obj, Lens!(l.pure.child), newval) + setindex!(obj, inner_newval, cidx.ix...; cidx.kw...) + return obj end """ @@ -92,7 +94,7 @@ function get_optics( dist::Union{Distributions.MultivariateDistribution,Distributions.MatrixDistribution} ) indices = CartesianIndices(size(dist)) - return map(idx -> Accessors.IndexLens(idx.I), indices) + return map(idx -> AbstractPPL.Index(idx.I, (;), AbstractPPL.Iden()), indices) end function get_optics(dist::Distributions.LKJCholesky) is_up = dist.uplo == 'U' @@ -102,8 +104,14 @@ function get_optics(dist::Distributions.LKJCholesky) end # there is an additional layer as we need to access `.L` or `.U` before we # can index into it - field_lens = is_up ? (Accessors.@o _.U) : (Accessors.@o _.L) - return map(idx -> Accessors.IndexLens(idx.I) ∘ field_lens, cartesian_indices) + function make_lens(idx) + if is_up + AbstractPPL.Property{:U}(AbstractPPL.Index(idx.I, (;), AbstractPPL.Iden())) + else + AbstractPPL.Property{:L}(AbstractPPL.Index(idx.I, (;), AbstractPPL.Iden())) + end + end + return map(make_lens, cartesian_indices) end """ @@ -324,5 +332,4 @@ function AbstractPPL.getvalue( end end -=# end diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index 0688297f..7ef93eca 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -24,7 +24,7 @@ include("evaluate.jl") include("varname/optic.jl") include("varname/varname.jl") include("varname/subsumes.jl") -# include("varname/hasvalue.jl") +include("varname/hasvalue.jl") # include("varname/leaves.jl") include("varname/prefix.jl") # include("varname/serialize.jl") @@ -51,7 +51,9 @@ export AbstractOptic, # other functions subsumes, prefix, - unprefix + unprefix, + hasvalue, + getvalue using Accessors: set export set diff --git a/src/varname/hasvalue.jl b/src/varname/hasvalue.jl index 31c5f098..4dd71e7d 100644 --- a/src/varname/hasvalue.jl +++ b/src/varname/hasvalue.jl @@ -4,37 +4,48 @@ Return `true` if `optic` can be used to view `container`, and `false` otherwise. # Examples -```jldoctest; setup=:(using Accessors) -julia> AbstractPPL.canview(@o(_.a), (a = 1.0, )) +```jldoctest +julia> AbstractPPL.canview(@opticof(_.a), (a = 1.0, )) true -julia> AbstractPPL.canview(@o(_.a), (b = 1.0, )) # property `a` does not exist +julia> AbstractPPL.canview(@opticof(_.a), (b = 1.0, )) # property `a` does not exist false -julia> AbstractPPL.canview(@o(_.a[1]), (a = [1.0, 2.0], )) +julia> AbstractPPL.canview(@opticof(_.a[1]), (a = [1.0, 2.0], )) true -julia> AbstractPPL.canview(@o(_.a[3]), (a = [1.0, 2.0], )) # out of bounds +julia> AbstractPPL.canview(@opticof(_.a[3]), (a = [1.0, 2.0], )) # out of bounds false ``` """ canview(optic, container) = false -canview(::typeof(identity), _) = true -function canview(::Accessors.PropertyLens{field}, x) where {field} - return hasproperty(x, field) +canview(::Iden, ::Any) = true +function canview(prop::Property{field}, x) where {field} + return hasproperty(x, field) && canview(prop.child, getproperty(x, field)) end - -# `IndexLens`: only relevant if `x` supports indexing. -canview(optic::Accessors.IndexLens, x) = false -function canview(optic::Accessors.IndexLens, x::AbstractArray) - return checkbounds(Bool, x, optic.indices...) +function canview(optic::Index, x::AbstractArray) + # TODO(penelopeysm): `checkbounds` doesn't work with keyword arguments for + # DimArray. Hence if we have keyword arguments, we just return false for now. + # https://github.com/rafaqz/DimensionalData.jl/issues/1156 + return isempty(optic.kw) && + checkbounds(Bool, x, optic.ix...) && + canview(optic.child, getindex(x, optic.ix...)) end - -# `ComposedFunction`: check that we can view `.inner` and `.outer`, but using -# value extracted using `.inner`. -function canview(optic::ComposedFunction, x) - return canview(optic.inner, x) && canview(optic.outer, optic.inner(x)) +# Handle some other edge cases. +function canview(optic::Index, x::AbstractDict) + return isempty(optic.kw) && + optic.ix.length == 1 && + haskey(x, only(optic.ix)) && + canview(optic.child, x[only(optic.ix)]) +end +function canview(optic::Index, x::NamedTuple) + return isempty(optic.kw) && + optic.ix.length == 1 && + haskey(x, only(optic.ix)) && + canview(optic.child, x[only(optic.ix)]) end +# Give up on all other edge cases. +canview(optic::Index, x) = false """ getvalue(vals::NamedTuple, vn::VarName) @@ -153,21 +164,21 @@ function getvalue(vals::AbstractDict{<:VarName}, vn::VarName{sym}) where {sym} # `x[1][2]`, we start by checking if `x[1]` is present, then `x`). We will # then keep removing optics from `test_optic`, either until we find a key # that is present, or until we run out of optics to test (which happens - # after getoptic(test_vn) == identity). + # after getoptic(test_vn) isa Iden). o = getoptic(vn) - test_vn = VarName{sym}(_init(o)) - test_optic = _last(o) + test_vn = VarName{sym}(oinit(o)) + test_optic = olast(o) while true if haskey(vals, test_vn) && canview(test_optic, vals[test_vn]) return test_optic(vals[test_vn]) else # Try to move the outermost optic from test_vn into test_optic. - # If test_vn is already an identity, we can't, so we stop. + # If test_vn is already an Iden, we can't, so we stop. o = getoptic(test_vn) - o == identity && error("$(vn) was not found in the dictionary provided") - test_vn = VarName{sym}(_init(o)) - test_optic = normalise(test_optic ∘ _last(o)) + o isa Iden && error("$(vn) was not found in the dictionary provided") + test_vn = VarName{sym}(oinit(o)) + test_optic = test_optic ∘ olast(o) end end end @@ -246,21 +257,21 @@ function hasvalue(vals::AbstractDict{<:VarName}, vn::VarName{sym}) where {sym} # `x[1][2]`, we start by checking if `x[1]` is present, then `x`). We will # then keep removing optics from `test_optic`, either until we find a key # that is present, or until we run out of optics to test (which happens - # after getoptic(test_vn) == identity). + # after getoptic(test_vn) == Iden()). o = getoptic(vn) - test_vn = VarName{sym}(_init(o)) - test_optic = _last(o) + test_vn = VarName{sym}(oinit(o)) + test_optic = olast(o) while true if haskey(vals, test_vn) && canview(test_optic, vals[test_vn]) return true else # Try to move the outermost optic from test_vn into test_optic. - # If test_vn is already an identity, we can't, so we stop. + # If test_vn is already an Iden, we can't, so we stop. o = getoptic(test_vn) - o == identity && return false - test_vn = VarName{sym}(_init(o)) - test_optic = normalise(test_optic ∘ _last(o)) + o isa Iden && return false + test_vn = VarName{sym}(oinit(o)) + test_optic = test_optic ∘ olast(o) end end return false diff --git a/src/varname/optic.jl b/src/varname/optic.jl index ad80382b..4005bedc 100644 --- a/src/varname/optic.jl +++ b/src/varname/optic.jl @@ -172,9 +172,17 @@ function _pretty_print_optic(io::IO, idx::Index) return _pretty_print_optic(io, idx.child) end is_dynamic(idx::Index) = any(ix -> ix isa DynamicIndex, idx.ix) || is_dynamic(idx.child) -# Things like dictionaries can't be `view`ed into. + +# Helper function to decide whether to use `view` or `getindex`. _maybe_view(val::AbstractArray, i...; k...) = view(val, i...; k...) +# If it's just a single element, don't use a view, as that returns a weird 0-dimensional +# SubArray (rather than an element) that messes things up if there are further layers of +# optics. For example, if it's an Array of NamedTuples, then trying to access fields on that +# 0-dimensional SubArray will fail. +_maybe_view(val::AbstractArray, i::Int...) = getindex(val, i...) +# Other Things like dictionaries can't be `view`ed into. _maybe_view(val, i...; k...) = getindex(val, i...; k...) + function concretize(idx::Index, val) concretized_indices = tuple(map(Base.Fix2(_concretize_index, val), idx.ix)...) inner_concretized = concretize( diff --git a/test/runtests.jl b/test/runtests.jl index e1b8e16c..d3ca0bac 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,8 +11,7 @@ const GROUP = get(ENV, "GROUP", "All") include("varname/optic.jl") include("varname/varname.jl") include("varname/subsumes.jl") - # include("varname.jl") - # include("hasvalue.jl") + include("varname/hasvalue.jl") end if GROUP == "All" || GROUP == "Doctests" diff --git a/test/hasvalue.jl b/test/varname/hasvalue.jl similarity index 99% rename from test/hasvalue.jl rename to test/varname/hasvalue.jl index 5881eaa2..2799bfc8 100644 --- a/test/hasvalue.jl +++ b/test/varname/hasvalue.jl @@ -1,3 +1,8 @@ +module VarNameHasValueTests + +using AbstractPPL +using Test + @testset "base getvalue + hasvalue" begin @testset "basic NamedTuple" begin nt = (a=[1], b=2, c=(x=3, y=[4], z=(; p=[(; q=5)])), d=[1.0 0.5; 0.5 1.0]) @@ -215,3 +220,5 @@ end @test !hasvalue(d, @varname(y), LKJCholesky(3, 1.0, :U); error_on_incomplete=true) end end + +end From 7efa65da5f746fc1da796de15adc56a5555aa012 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 00:55:09 +0000 Subject: [PATCH 27/60] comment --- src/varname/optic.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/varname/optic.jl b/src/varname/optic.jl index 4005bedc..fd90a056 100644 --- a/src/varname/optic.jl +++ b/src/varname/optic.jl @@ -173,14 +173,15 @@ function _pretty_print_optic(io::IO, idx::Index) end is_dynamic(idx::Index) = any(ix -> ix isa DynamicIndex, idx.ix) || is_dynamic(idx.child) -# Helper function to decide whether to use `view` or `getindex`. +# Helper function to decide whether to use `view` or `getindex`. For AbstractArray, the +# default behaviour is to attempt to use a view. _maybe_view(val::AbstractArray, i...; k...) = view(val, i...; k...) # If it's just a single element, don't use a view, as that returns a weird 0-dimensional # SubArray (rather than an element) that messes things up if there are further layers of # optics. For example, if it's an Array of NamedTuples, then trying to access fields on that # 0-dimensional SubArray will fail. _maybe_view(val::AbstractArray, i::Int...) = getindex(val, i...) -# Other Things like dictionaries can't be `view`ed into. +# Other things like dictionaries can't be `view`ed into. _maybe_view(val, i...; k...) = getindex(val, i...; k...) function concretize(idx::Index, val) From 57d50dc54a0a6e0918d5c62e18e82dbb4dd4321d Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 01:04:16 +0000 Subject: [PATCH 28/60] leaves works --- src/AbstractPPL.jl | 23 +++++----- src/varname/leaves.jl | 33 ++++++++------- test/runtests.jl | 1 + test/varname.jl | 91 --------------------------------------- test/varname/leaves.jl | 96 ++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 125 insertions(+), 119 deletions(-) create mode 100644 test/varname/leaves.jl diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index 7ef93eca..51fe0a43 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -1,16 +1,5 @@ module AbstractPPL -# index_to_dict, -# dict_to_index, -# varname_to_string, -# string_to_varname, -# prefix, -# unprefix, -# getvalue, -# hasvalue, -# varname_leaves, -# varname_and_value_leaves - # Abstract model functions export AbstractProbabilisticProgram, condition, decondition, fix, unfix, logdensityof, densityof, AbstractContext, evaluate!! @@ -25,7 +14,7 @@ include("varname/optic.jl") include("varname/varname.jl") include("varname/subsumes.jl") include("varname/hasvalue.jl") -# include("varname/leaves.jl") +include("varname/leaves.jl") include("varname/prefix.jl") # include("varname/serialize.jl") @@ -53,7 +42,15 @@ export AbstractOptic, prefix, unprefix, hasvalue, - getvalue + getvalue, + varname_leaves, + varname_and_value_leaves + +# Serialisation +# index_to_dict, +# dict_to_index, +# varname_to_string, +# string_to_varname, using Accessors: set export set diff --git a/src/varname/leaves.jl b/src/varname/leaves.jl index ca7418ed..6721e51f 100644 --- a/src/varname/leaves.jl +++ b/src/varname/leaves.jl @@ -1,5 +1,4 @@ using LinearAlgebra: LinearAlgebra - """ varname_leaves(vn::VarName, val) @@ -28,41 +27,41 @@ x.z[2][1] varname_leaves(vn::VarName, ::Real) = [vn] function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) return ( - VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)) for + VarName{getsym(vn)}(Index(Tuple(I), (;), Iden()) ∘ getoptic(vn)) for I in CartesianIndices(val) ) end function varname_leaves(vn::VarName, val::AbstractArray) return Iterators.flatten( varname_leaves( - VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), val[I] + VarName{getsym(vn)}(Index(Tuple(I), (;), Iden()) ∘ getoptic(vn)), val[I] ) for I in CartesianIndices(val) ) end function varname_leaves(vn::VarName, val::NamedTuple) iter = Iterators.map(keys(val)) do k - optic = Accessors.PropertyLens{k}() + optic = Property{k}(Iden()) varname_leaves(VarName{getsym(vn)}(optic ∘ getoptic(vn)), optic(val)) end return Iterators.flatten(iter) end function varname_leaves(vn::VarName, val::LinearAlgebra.Cholesky) return if val.uplo == 'L' - optic = Accessors.PropertyLens{:L}() + optic = Property{:L}(Iden()) varname_leaves(VarName{getsym(vn)}(optic ∘ getoptic(vn)), val.L) else - optic = Accessors.PropertyLens{:U}() + optic = Property{:U}(Iden()) varname_leaves(VarName{getsym(vn)}(optic ∘ getoptic(vn)), val.U) end end function varname_leaves(vn::VarName, x::LinearAlgebra.LowerTriangular) return Iterators.map(Iterators.filter(I -> I[1] >= I[2], CartesianIndices(x))) do I - VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)) + VarName{getsym(vn)}(Index(Tuple(I), (;), Iden()) ∘ getoptic(vn)) end end function varname_leaves(vn::VarName, x::LinearAlgebra.UpperTriangular) return Iterators.map(Iterators.filter(I -> I[1] <= I[2], CartesianIndices(x))) do I - VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)) + VarName{getsym(vn)}(Index(Tuple(I), (;), Iden()) ∘ getoptic(vn)) end end # This is a default fallback for types that we don't know how to handle. @@ -215,7 +214,7 @@ function varname_and_value_leaves_inner( ) return ( Leaf( - VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ AbstractPPL.getoptic(vn)), + VarName{getsym(vn)}(Index(Tuple(I), (;), Iden()) ∘ AbstractPPL.getoptic(vn)), val[I], ) for I in CartesianIndices(val) ) @@ -224,14 +223,14 @@ end function varname_and_value_leaves_inner(vn::VarName, val::AbstractArray) return Iterators.flatten( varname_and_value_leaves_inner( - VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ AbstractPPL.getoptic(vn)), + VarName{getsym(vn)}(Index(Tuple(I), (;), Iden()) ∘ AbstractPPL.getoptic(vn)), val[I], ) for I in CartesianIndices(val) ) end function varname_and_value_leaves_inner(vn::VarName, val::NamedTuple) iter = Iterators.map(keys(val)) do k - optic = Accessors.PropertyLens{k}() + optic = Property{k}(Iden()) varname_and_value_leaves_inner( VarName{getsym(vn)}(optic ∘ getoptic(vn)), optic(val) ) @@ -242,21 +241,25 @@ end # Special types. function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.Cholesky) return if x.uplo == 'L' - varname_and_value_leaves_inner(Accessors.PropertyLens{:L}() ∘ vn, x.L) + varname_and_value_leaves_inner( + VarName{getsym(vn)}(Property{:L}(Iden()) ∘ getoptic(vn)), x.L + ) else - varname_and_value_leaves_inner(Accessors.PropertyLens{:U}() ∘ vn, x.U) + varname_and_value_leaves_inner( + VarName{getsym(vn)}(Property{:U}(Iden()) ∘ getoptic(vn)), x.U + ) end end function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.LowerTriangular) return ( - Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), x[I]) + Leaf(VarName{getsym(vn)}(Index(Tuple(I), (;), Iden()) ∘ getoptic(vn)), x[I]) # Iteration over the lower-triangular indices. for I in CartesianIndices(x) if I[1] >= I[2] ) end function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.UpperTriangular) return ( - Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), x[I]) + Leaf(VarName{getsym(vn)}(Index(Tuple(I), (;), Iden()) ∘ getoptic(vn)), x[I]) # Iteration over the upper-triangular indices. for I in CartesianIndices(x) if I[1] <= I[2] ) diff --git a/test/runtests.jl b/test/runtests.jl index d3ca0bac..fdf2e4ae 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,6 +12,7 @@ const GROUP = get(ENV, "GROUP", "All") include("varname/varname.jl") include("varname/subsumes.jl") include("varname/hasvalue.jl") + include("varname/leaves.jl") end if GROUP == "All" || GROUP == "Doctests" diff --git a/test/varname.jl b/test/varname.jl index af0f8967..d399201d 100644 --- a/test/varname.jl +++ b/test/varname.jl @@ -1,7 +1,6 @@ using Accessors using InvertedIndices using OffsetArrays -using LinearAlgebra: LowerTriangular, UpperTriangular, cholesky @testset "varnames" begin @testset "de/serialisation of VarNames" begin @@ -71,94 +70,4 @@ using LinearAlgebra: LowerTriangular, UpperTriangular, cholesky # Serialisation should now work @test string_to_varname(varname_to_string(vn)) == vn end - - @testset "varname{_and_value}_leaves" begin - @testset "single value: float, int" begin - x = 1.0 - @test Set(varname_leaves(@varname(x), x)) == Set([@varname(x)]) - @test Set(collect(varname_and_value_leaves(@varname(x), x))) == - Set([(@varname(x), x)]) - x = 2 - @test Set(varname_leaves(@varname(x), x)) == Set([@varname(x)]) - @test Set(collect(varname_and_value_leaves(@varname(x), x))) == - Set([(@varname(x), x)]) - end - - @testset "Vector" begin - x = randn(2) - @test Set(varname_leaves(@varname(x), x)) == - Set([@varname(x[1]), @varname(x[2])]) - @test Set(collect(varname_and_value_leaves(@varname(x), x))) == - Set([(@varname(x[1]), x[1]), (@varname(x[2]), x[2])]) - x = [(; a=1), (; b=2)] - @test Set(varname_leaves(@varname(x), x)) == - Set([@varname(x[1].a), @varname(x[2].b)]) - @test Set(collect(varname_and_value_leaves(@varname(x), x))) == - Set([(@varname(x[1].a), x[1].a), (@varname(x[2].b), x[2].b)]) - end - - @testset "Matrix" begin - x = randn(2, 2) - @test Set(varname_leaves(@varname(x), x)) == Set([ - @varname(x[1, 1]), @varname(x[1, 2]), @varname(x[2, 1]), @varname(x[2, 2]) - ]) - @test Set(collect(varname_and_value_leaves(@varname(x), x))) == Set([ - (@varname(x[1, 1]), x[1, 1]), - (@varname(x[1, 2]), x[1, 2]), - (@varname(x[2, 1]), x[2, 1]), - (@varname(x[2, 2]), x[2, 2]), - ]) - end - - @testset "Lower/UpperTriangular" begin - x = randn(2, 2) - xl = LowerTriangular(x) - @test Set(varname_leaves(@varname(x), xl)) == - Set([@varname(x[1, 1]), @varname(x[2, 1]), @varname(x[2, 2])]) - @test Set(collect(varname_and_value_leaves(@varname(x), xl))) == Set([ - (@varname(x[1, 1]), x[1, 1]), - (@varname(x[2, 1]), x[2, 1]), - (@varname(x[2, 2]), x[2, 2]), - ]) - xu = UpperTriangular(x) - @test Set(varname_leaves(@varname(x), xu)) == - Set([@varname(x[1, 1]), @varname(x[1, 2]), @varname(x[2, 2])]) - @test Set(collect(varname_and_value_leaves(@varname(x), xu))) == Set([ - (@varname(x[1, 1]), x[1, 1]), - (@varname(x[1, 2]), x[1, 2]), - (@varname(x[2, 2]), x[2, 2]), - ]) - end - - @testset "NamedTuple" begin - x = (a=1.0, b=[2.0, 3.0]) - @test Set(varname_leaves(@varname(x), x)) == - Set([@varname(x.a), @varname(x.b[1]), @varname(x.b[2])]) - @test Set(collect(varname_and_value_leaves(@varname(x), x))) == Set([ - (@varname(x.a), x.a), (@varname(x.b[1]), x.b[1]), (@varname(x.b[2]), x.b[2]) - ]) - end - - @testset "Cholesky" begin - x = cholesky([1.0 0.5; 0.5 1.0]) - @test Set(varname_leaves(@varname(x), x)) == - Set([@varname(x.U[1, 1]), @varname(x.U[1, 2]), @varname(x.U[2, 2])]) - @test Set(collect(varname_and_value_leaves(@varname(x), x))) == Set([ - (@varname(x.U[1, 1]), x.U[1, 1]), - (@varname(x.U[1, 2]), x.U[1, 2]), - (@varname(x.U[2, 2]), x.U[2, 2]), - ]) - end - - @testset "fallback on other types, e.g. string" begin - x = "a string" - @test Set(varname_leaves(@varname(x), x)) == Set([@varname(x)]) - @test Set(collect(varname_and_value_leaves(@varname(x), x))) == - Set([(@varname(x), x)]) - x = 2 - @test Set(varname_leaves(@varname(x), x)) == Set([@varname(x)]) - @test Set(collect(varname_and_value_leaves(@varname(x), x))) == - Set([(@varname(x), x)]) - end - end end diff --git a/test/varname/leaves.jl b/test/varname/leaves.jl new file mode 100644 index 00000000..e21ed29c --- /dev/null +++ b/test/varname/leaves.jl @@ -0,0 +1,96 @@ +module VarNameLeavesTests + +using AbstractPPL +using Test +using LinearAlgebra: LowerTriangular, UpperTriangular, cholesky + +@testset "varname/leaves.jl" verbose = true begin + @testset "single value: float, int" begin + x = 1.0 + @test Set(varname_leaves(@varname(x), x)) == Set([@varname(x)]) + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == + Set([(@varname(x), x)]) + x = 2 + @test Set(varname_leaves(@varname(x), x)) == Set([@varname(x)]) + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == + Set([(@varname(x), x)]) + end + + @testset "Vector" begin + x = randn(2) + @test Set(varname_leaves(@varname(x), x)) == Set([@varname(x[1]), @varname(x[2])]) + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == + Set([(@varname(x[1]), x[1]), (@varname(x[2]), x[2])]) + x = [(; a=1), (; b=2)] + @test Set(varname_leaves(@varname(x), x)) == + Set([@varname(x[1].a), @varname(x[2].b)]) + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == + Set([(@varname(x[1].a), x[1].a), (@varname(x[2].b), x[2].b)]) + end + + @testset "Matrix" begin + x = randn(2, 2) + @test Set(varname_leaves(@varname(x), x)) == Set([ + @varname(x[1, 1]), @varname(x[1, 2]), @varname(x[2, 1]), @varname(x[2, 2]) + ]) + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == Set([ + (@varname(x[1, 1]), x[1, 1]), + (@varname(x[1, 2]), x[1, 2]), + (@varname(x[2, 1]), x[2, 1]), + (@varname(x[2, 2]), x[2, 2]), + ]) + end + + @testset "Lower/UpperTriangular" begin + x = randn(2, 2) + xl = LowerTriangular(x) + @test Set(varname_leaves(@varname(x), xl)) == + Set([@varname(x[1, 1]), @varname(x[2, 1]), @varname(x[2, 2])]) + @test Set(collect(varname_and_value_leaves(@varname(x), xl))) == Set([ + (@varname(x[1, 1]), x[1, 1]), + (@varname(x[2, 1]), x[2, 1]), + (@varname(x[2, 2]), x[2, 2]), + ]) + xu = UpperTriangular(x) + @test Set(varname_leaves(@varname(x), xu)) == + Set([@varname(x[1, 1]), @varname(x[1, 2]), @varname(x[2, 2])]) + @test Set(collect(varname_and_value_leaves(@varname(x), xu))) == Set([ + (@varname(x[1, 1]), x[1, 1]), + (@varname(x[1, 2]), x[1, 2]), + (@varname(x[2, 2]), x[2, 2]), + ]) + end + + @testset "NamedTuple" begin + x = (a=1.0, b=[2.0, 3.0]) + @test Set(varname_leaves(@varname(x), x)) == + Set([@varname(x.a), @varname(x.b[1]), @varname(x.b[2])]) + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == Set([ + (@varname(x.a), x.a), (@varname(x.b[1]), x.b[1]), (@varname(x.b[2]), x.b[2]) + ]) + end + + @testset "Cholesky" begin + x = cholesky([1.0 0.5; 0.5 1.0]) + @test Set(varname_leaves(@varname(x), x)) == + Set([@varname(x.U[1, 1]), @varname(x.U[1, 2]), @varname(x.U[2, 2])]) + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == Set([ + (@varname(x.U[1, 1]), x.U[1, 1]), + (@varname(x.U[1, 2]), x.U[1, 2]), + (@varname(x.U[2, 2]), x.U[2, 2]), + ]) + end + + @testset "fallback on other types, e.g. string" begin + x = "a string" + @test Set(varname_leaves(@varname(x), x)) == Set([@varname(x)]) + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == + Set([(@varname(x), x)]) + x = 2 + @test Set(varname_leaves(@varname(x), x)) == Set([@varname(x)]) + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == + Set([(@varname(x), x)]) + end +end + +end # module From 9bc9f3df41be128bd4891e04fe08e9e675a59d5c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 01:11:08 +0000 Subject: [PATCH 29/60] bump min version to 1.10.8 for extension stuff --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 65b7e6b8..f3cf6619 100644 --- a/Project.toml +++ b/Project.toml @@ -31,4 +31,4 @@ LinearAlgebra = "<0.0.1, 1.10" MacroTools = "0.5.16" Random = "1.6" StatsBase = "0.32, 0.33, 0.34" -julia = "1.10" +julia = "1.10.8" From 839eaa145cd67b2abd6e2b1eeebdc38414e78f15 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 01:13:58 +0000 Subject: [PATCH 30/60] docs --- docs/src/varname.md | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/docs/src/varname.md b/docs/src/varname.md index a2ca709d..6609964b 100644 --- a/docs/src/varname.md +++ b/docs/src/varname.md @@ -204,3 +204,31 @@ subsumes(vn1, vn2) ```@docs subsumes ``` + +## Prefixing and unprefixing + +Composing two optics can be done using the `∘` operator, as shown above. +But what if we want to compose two `VarName`s? +This is used, for example, in DynamicPPL's submodel functionality. + +```@docs +prefix +unprefix +``` + +## VarName leaves + +The following functions are used to extract the 'leaves' of a VarName, that is, the atomic components of a VarName that do not have any further substructure. +For example, for a vector variable `x`, the leaves would be `x[1]`, `x[2]`, etc. + +```@docs +varname_leaves +varname_and_value_leaves +``` + +## Reading from a container with a VarName + +```@docs +hasvalue +getvalue +``` From f31e37dbea404a31c6dee8a150dad1dd7f4ff280 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 02:16:30 +0000 Subject: [PATCH 31/60] don't view --- src/AbstractPPL.jl | 1 + src/varname/optic.jl | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index 51fe0a43..a5e38ac5 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -43,6 +43,7 @@ export AbstractOptic, unprefix, hasvalue, getvalue, + canview, varname_leaves, varname_and_value_leaves diff --git a/src/varname/optic.jl b/src/varname/optic.jl index fd90a056..f52f8ec3 100644 --- a/src/varname/optic.jl +++ b/src/varname/optic.jl @@ -173,6 +173,7 @@ function _pretty_print_optic(io::IO, idx::Index) end is_dynamic(idx::Index) = any(ix -> ix isa DynamicIndex, idx.ix) || is_dynamic(idx.child) +#= # Helper function to decide whether to use `view` or `getindex`. For AbstractArray, the # default behaviour is to attempt to use a view. _maybe_view(val::AbstractArray, i...; k...) = view(val, i...; k...) @@ -183,6 +184,10 @@ _maybe_view(val::AbstractArray, i...; k...) = view(val, i...; k...) _maybe_view(val::AbstractArray, i::Int...) = getindex(val, i...) # Other things like dictionaries can't be `view`ed into. _maybe_view(val, i...; k...) = getindex(val, i...; k...) +=# +# The above implementation works fine in the AbstractPPL test suite, but causes lots of test +# breakage in DynamicPPL. TODO(penelopeysm): Figure out why and see if we can re-enable it. +_maybe_view(val, i...; k...) = getindex(val, i...; k...) function concretize(idx::Index, val) concretized_indices = tuple(map(Base.Fix2(_concretize_index, val), idx.ix)...) From c759a2e23de13fd647786e0112d258e77e7ae2f6 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 03:01:57 +0000 Subject: [PATCH 32/60] workaround undef --- src/varname/optic.jl | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/varname/optic.jl b/src/varname/optic.jl index f52f8ec3..344b45fe 100644 --- a/src/varname/optic.jl +++ b/src/varname/optic.jl @@ -173,7 +173,6 @@ function _pretty_print_optic(io::IO, idx::Index) end is_dynamic(idx::Index) = any(ix -> ix isa DynamicIndex, idx.ix) || is_dynamic(idx.child) -#= # Helper function to decide whether to use `view` or `getindex`. For AbstractArray, the # default behaviour is to attempt to use a view. _maybe_view(val::AbstractArray, i...; k...) = view(val, i...; k...) @@ -184,16 +183,14 @@ _maybe_view(val::AbstractArray, i...; k...) = view(val, i...; k...) _maybe_view(val::AbstractArray, i::Int...) = getindex(val, i...) # Other things like dictionaries can't be `view`ed into. _maybe_view(val, i...; k...) = getindex(val, i...; k...) -=# -# The above implementation works fine in the AbstractPPL test suite, but causes lots of test -# breakage in DynamicPPL. TODO(penelopeysm): Figure out why and see if we can re-enable it. -_maybe_view(val, i...; k...) = getindex(val, i...; k...) function concretize(idx::Index, val) concretized_indices = tuple(map(Base.Fix2(_concretize_index, val), idx.ix)...) - inner_concretized = concretize( - idx.child, _maybe_view(val, concretized_indices...; idx.kw...) - ) + inner_concretized = if idx.child isa Iden + Iden() + else + concretize(idx.child, _maybe_view(val, concretized_indices...; idx.kw...)) + end return Index(concretized_indices, idx.kw, inner_concretized) end function (idx::Index)(obj) @@ -202,8 +199,12 @@ function (idx::Index)(obj) end function Accessors.set(obj, idx::Index, newval) cidx = concretize(idx, obj) - inner_obj = Base.getindex(obj, cidx.ix...; cidx.kw...) - inner_newval = Accessors.set(inner_obj, idx.child, newval) + inner_newval = if idx.child isa Iden + newval + else + inner_obj = Base.getindex(obj, cidx.ix...; cidx.kw...) + Accessors.set(inner_obj, idx.child, newval) + end return if !isempty(cidx.kw) # `Accessors.IndexLens` does not handle keyword arguments so we need to do this # ourselves. Note that the following code essentially assumes that `obj` is an From c4a2a15c025c53ab87d0c798894727470f287410 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 03:07:32 +0000 Subject: [PATCH 33/60] add a comment --- src/varname/optic.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/varname/optic.jl b/src/varname/optic.jl index 344b45fe..0965e2e8 100644 --- a/src/varname/optic.jl +++ b/src/varname/optic.jl @@ -187,6 +187,11 @@ _maybe_view(val, i...; k...) = getindex(val, i...; k...) function concretize(idx::Index, val) concretized_indices = tuple(map(Base.Fix2(_concretize_index, val), idx.ix)...) inner_concretized = if idx.child isa Iden + # Explicitly having this branch allows us to shortcircuit _maybe_view(...), which + # can error if val[concretized_indices...] is an UndefInitializer. Note that if + # val[concretized_indices...] is an UndefInitializer, then it is not meaningful for + # `idx.child` to be anything other than `Iden` anyway, since there is nothing to + # further index into. Iden() else concretize(idx.child, _maybe_view(val, concretized_indices...; idx.kw...)) From 43c969a254b8c00551abf114dab825e4c481c212 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 03:18:22 +0000 Subject: [PATCH 34/60] export canview --- docs/src/varname.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/src/varname.md b/docs/src/varname.md index 6609964b..429df5f7 100644 --- a/docs/src/varname.md +++ b/docs/src/varname.md @@ -226,9 +226,10 @@ varname_leaves varname_and_value_leaves ``` -## Reading from a container with a VarName +## Reading from a container with a VarName (or optic) ```@docs +canview hasvalue getvalue ``` From ecb6590c3bd7b6dece38737cf403578c4246a9a7 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 03:35:14 +0000 Subject: [PATCH 35/60] export the macro AST function too --- docs/src/varname.md | 1 + src/AbstractPPL.jl | 1 + src/varname/varname.jl | 23 ++++++++++++----------- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/docs/src/varname.md b/docs/src/varname.md index 429df5f7..9327b797 100644 --- a/docs/src/varname.md +++ b/docs/src/varname.md @@ -19,6 +19,7 @@ vn = @varname(x.a[1]) ```@docs VarName @varname +varname_macro ``` You can obtain the components of a `VarName` using the `getsym` and `getoptic` functions: diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index a5e38ac5..6a7e97de 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -34,6 +34,7 @@ export AbstractOptic, concretize, is_dynamic, @varname, + varname_macro, @opticof, varname_to_optic, optic_to_varname, diff --git a/src/varname/varname.jl b/src/varname/varname.jl index 9960e54e..084fe1c8 100644 --- a/src/varname/varname.jl +++ b/src/varname/varname.jl @@ -240,6 +240,17 @@ ERROR: LoadError: cannot automatically concretize VarName with interpolated top- ``` """ macro varname(expr, concretize::Bool=false) + return varname_macro(expr, concretize) +end + +""" + varname_macro(expr, concretize::Bool) + +Implementation of the `@varname` macro. See the documentation for `@varname` for details. +This function is exported to allow other macros (e.g. in DynamicPPL) to reuse the same +logic. +""" +function varname_macro(expr, concretize::Bool) unconcretized_vn, sym = _varname(expr, :(Iden())) return if concretize sym === nothing && throw(VarNameConcretizationException()) @@ -348,17 +359,7 @@ specifically, if the top-level symbol is interpolated, automatic concretization possible. """ macro opticof(expr, concretize::Bool=false) - # This implementation is a bit ugly, as it copies the logic from `@varname`. However, - # getting the output of `@varname` and then processing it is a bit tricky, specifically - # when concretization is involved (because the top-level value must be escaped, but not - # anything else!). So it's easier to just duplicate the logic here. - unconcretized_vn, sym = _varname(expr, :(Iden())) - return if concretize - sym === nothing && throw(VarNameConcretizationException()) - :(getoptic(concretize($unconcretized_vn, $(esc(sym))))) - else - :(getoptic($unconcretized_vn)) - end + return :(getoptic($(varname_macro(expr, concretize)))) end """ From add9d3421a496273267cf07969e1c71a3a7a2087 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 03:37:08 +0000 Subject: [PATCH 36/60] improve tuple tupling tupled --- src/varname/varname.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/varname/varname.jl b/src/varname/varname.jl index 084fe1c8..dad0e65e 100644 --- a/src/varname/varname.jl +++ b/src/varname/varname.jl @@ -293,12 +293,13 @@ function _varname(expr::Expr, inner_expr) positional_ixs = map(positional_args) do (dim, ix_expr) _handle_index(ix_expr, is_single_index ? nothing : dim) end + positional_tpl = Expr(:tuple, positional_ixs...) kwarg_expr = if isempty(keyword_args) :((;)) else Expr(:tuple, keyword_args...) end - :(Index(tuple($(positional_ixs...)), $kwarg_expr, $inner_expr)) + :(Index($positional_tpl, $kwarg_expr, $inner_expr)) else # some other expression we can't parse throw(VarNameParseException(expr)) From 58636a8ce7988f4d2e0ff30b0c6f995c9eb99fdb Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 03:37:36 +0000 Subject: [PATCH 37/60] there are two hard things in computer science --- src/varname/varname.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/varname/varname.jl b/src/varname/varname.jl index dad0e65e..5c96c22a 100644 --- a/src/varname/varname.jl +++ b/src/varname/varname.jl @@ -293,13 +293,13 @@ function _varname(expr::Expr, inner_expr) positional_ixs = map(positional_args) do (dim, ix_expr) _handle_index(ix_expr, is_single_index ? nothing : dim) end - positional_tpl = Expr(:tuple, positional_ixs...) + positional_expr = Expr(:tuple, positional_ixs...) kwarg_expr = if isempty(keyword_args) :((;)) else Expr(:tuple, keyword_args...) end - :(Index($positional_tpl, $kwarg_expr, $inner_expr)) + :(Index($positional_expr, $kwarg_expr, $inner_expr)) else # some other expression we can't parse throw(VarNameParseException(expr)) From 2eb7e7f66299a803965cd7bfe17ca5e2ca9b58a4 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 04:05:09 +0000 Subject: [PATCH 38/60] interpolate --- src/varname/optic.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/varname/optic.jl b/src/varname/optic.jl index 0965e2e8..b2c768e0 100644 --- a/src/varname/optic.jl +++ b/src/varname/optic.jl @@ -81,10 +81,10 @@ function _make_dynamicindex_expr(symbol::Symbol, dim::Union{Nothing,Int}) # https://github.com/tree-sitter/tree-sitter-julia/issues/104 if symbol === Symbol(:begin) func = dim === nothing ? :(Base.firstindex) : :(Base.Fix2(firstindex, $dim)) - return :(DynamicIndex($(QuoteNode(symbol)), $func)) + return :($(DynamicIndex)($(QuoteNode(symbol)), $func)) elseif symbol === Symbol(:end) func = dim === nothing ? :(Base.lastindex) : :(Base.Fix2(lastindex, $dim)) - return :(DynamicIndex($(QuoteNode(symbol)), $func)) + return :($(DynamicIndex)($(QuoteNode(symbol)), $func)) else # Just a variable; but we need to escape it to allow interpolation. return esc(symbol) @@ -94,7 +94,7 @@ function _make_dynamicindex_expr(expr::Expr, dim::Union{Nothing,Int}) @gensym val if has_begin_or_end(expr) replaced_expr = MacroTools.postwalk(x -> _make_dynamicindex_expr(x, val, dim), expr) - return :(DynamicIndex($(QuoteNode(expr)), $val -> $replaced_expr)) + return :($(DynamicIndex)($(QuoteNode(expr)), $val -> $replaced_expr)) else return esc(expr) end From 7083a698071c80a89402c9d31212014d7020808b Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 04:05:40 +0000 Subject: [PATCH 39/60] interpolate --- src/varname/varname.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/varname/varname.jl b/src/varname/varname.jl index 5c96c22a..7cf94299 100644 --- a/src/varname/varname.jl +++ b/src/varname/varname.jl @@ -240,21 +240,21 @@ ERROR: LoadError: cannot automatically concretize VarName with interpolated top- ``` """ macro varname(expr, concretize::Bool=false) - return varname_macro(expr, concretize) + return varname(expr, concretize) end """ - varname_macro(expr, concretize::Bool) + varname(expr, concretize::Bool) Implementation of the `@varname` macro. See the documentation for `@varname` for details. This function is exported to allow other macros (e.g. in DynamicPPL) to reuse the same logic. """ -function varname_macro(expr, concretize::Bool) +function varname(expr, concretize::Bool) unconcretized_vn, sym = _varname(expr, :(Iden())) return if concretize sym === nothing && throw(VarNameConcretizationException()) - :(concretize($unconcretized_vn, $(esc(sym)))) + :($(concretize)($unconcretized_vn, $(esc(sym)))) else unconcretized_vn end @@ -264,7 +264,7 @@ function _varname(@nospecialize(expr::Any), ::Any) throw(VarNameParseException(expr)) end function _varname(sym::Symbol, inner_expr) - return :($VarName{$(QuoteNode(sym))}($inner_expr)), sym + return :($(VarName){$(QuoteNode(sym))}($inner_expr)), sym end function _varname(expr::Expr, inner_expr) if Meta.isexpr(expr, :$, 1) @@ -273,11 +273,11 @@ function _varname(expr::Expr, inner_expr) # expr.head would be :ref or :.) Thus we don't need to recurse further, and we can # just return `inner_expr` as-is. sym_expr = expr.args[1] - return :(VarName{$(esc(sym_expr))}($inner_expr)), nothing + return :($(VarName){$(esc(sym_expr))}($inner_expr)), nothing else next_inner = if expr.head == :(.) sym = _handle_property(expr.args[2], expr) - :(Property{$(sym)}($inner_expr)) + :($(Property){$(sym)}($inner_expr)) elseif expr.head == :ref original_ixs = expr.args[2:end] positional_args = [] @@ -299,7 +299,7 @@ function _varname(expr::Expr, inner_expr) else Expr(:tuple, keyword_args...) end - :(Index($positional_expr, $kwarg_expr, $inner_expr)) + :($(Index)($positional_expr, $kwarg_expr, $inner_expr)) else # some other expression we can't parse throw(VarNameParseException(expr)) From ac95c23abb5013fcbc75c14ac41914ca448ff0dc Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 04:08:08 +0000 Subject: [PATCH 40/60] fix name --- docs/src/varname.md | 2 +- src/AbstractPPL.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/src/varname.md b/docs/src/varname.md index 9327b797..f5e94acf 100644 --- a/docs/src/varname.md +++ b/docs/src/varname.md @@ -19,7 +19,7 @@ vn = @varname(x.a[1]) ```@docs VarName @varname -varname_macro +varname ``` You can obtain the components of a `VarName` using the `getsym` and `getoptic` functions: diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index 6a7e97de..69170c7a 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -34,7 +34,7 @@ export AbstractOptic, concretize, is_dynamic, @varname, - varname_macro, + varname, @opticof, varname_to_optic, optic_to_varname, From 456c3221cb14390a644ebeb5770c52eb83635fdd Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 04:09:27 +0000 Subject: [PATCH 41/60] more interpolation --- src/varname/varname.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/varname/varname.jl b/src/varname/varname.jl index 7cf94299..78214163 100644 --- a/src/varname/varname.jl +++ b/src/varname/varname.jl @@ -251,7 +251,7 @@ This function is exported to allow other macros (e.g. in DynamicPPL) to reuse th logic. """ function varname(expr, concretize::Bool) - unconcretized_vn, sym = _varname(expr, :(Iden())) + unconcretized_vn, sym = _varname(expr, :($(Iden)())) return if concretize sym === nothing && throw(VarNameConcretizationException()) :($(concretize)($unconcretized_vn, $(esc(sym)))) From 63b04b94e52191c6e65b58cd09f050904dc52a58 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 04:11:34 +0000 Subject: [PATCH 42/60] fix name --- src/varname/varname.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/varname/varname.jl b/src/varname/varname.jl index 78214163..e4bc45ae 100644 --- a/src/varname/varname.jl +++ b/src/varname/varname.jl @@ -360,7 +360,7 @@ specifically, if the top-level symbol is interpolated, automatic concretization possible. """ macro opticof(expr, concretize::Bool=false) - return :(getoptic($(varname_macro(expr, concretize)))) + return :(getoptic($(varname(expr, concretize)))) end """ From e9eab0a84d138cc88e740eabcc2f163894afc141 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 04:17:46 +0000 Subject: [PATCH 43/60] so dirty ugh --- src/varname/varname.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/varname/varname.jl b/src/varname/varname.jl index e4bc45ae..6a64600e 100644 --- a/src/varname/varname.jl +++ b/src/varname/varname.jl @@ -254,7 +254,7 @@ function varname(expr, concretize::Bool) unconcretized_vn, sym = _varname(expr, :($(Iden)())) return if concretize sym === nothing && throw(VarNameConcretizationException()) - :($(concretize)($unconcretized_vn, $(esc(sym)))) + :($(AbstractPPL.concretize)($unconcretized_vn, $(esc(sym)))) else unconcretized_vn end From 4ef515e169555d4327070a2616e89a8b54e9e57d Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 14:53:09 +0000 Subject: [PATCH 44/60] try to fix tuple equality --- src/varname/optic.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/varname/optic.jl b/src/varname/optic.jl index b2c768e0..d24d2a62 100644 --- a/src/varname/optic.jl +++ b/src/varname/optic.jl @@ -158,7 +158,12 @@ struct Index{I<:Tuple,N<:NamedTuple,C<:AbstractOptic} <: AbstractOptic end end -Base.:(==)(a::Index, b::Index) = a.ix == b.ix && a.kw == b.kw && a.child == b.child +# Workaround for https://github.com/JuliaLang/julia/issues/60470 +_tuple_eq(t1::Tuple, t2::Tuple) = (length(t1) == length(t2)) && _tuple_eq_inner(t1, t2) +_tuple_eq_inner(t1::Tuple{}, t2::Tuple{}) = true +_tuple_eq_inner(t1::Tuple{Any,Vararg{Any}}, t2::Tuple{Any,Vararg{Any}}) = Base._eq(t1, t2) + +Base.:(==)(a::Index, b::Index) = _tuple_eq(a.ix, b.ix) && a.kw == b.kw && a.child == b.child function Base.isequal(a::Index, b::Index) return isequal(a.ix, b.ix) && isequal(a.kw, b.kw) && isequal(a.child, b.child) end From e11e23e2e9b7cf0efe043d54f5b41532b74a0b78 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 15:19:47 +0000 Subject: [PATCH 45/60] workaround for base Julia issue --- src/varname/optic.jl | 50 +++++++++++++++++++++++++++++++++++++++--- src/varname/varname.jl | 4 +++- 2 files changed, 50 insertions(+), 4 deletions(-) diff --git a/src/varname/optic.jl b/src/varname/optic.jl index d24d2a62..04986e44 100644 --- a/src/varname/optic.jl +++ b/src/varname/optic.jl @@ -14,6 +14,10 @@ function Base.show(io::IO, optic::AbstractOptic) _pretty_print_optic(io, optic) return print(io, ")") end +# Lots of (==) and isequal methods as a workaround for +# https://github.com/JuliaLang/julia/issues/60470 :( +Base.:(==)(a::AbstractOptic, b::AbstractOptic) = false +Base.isequal(a::AbstractOptic, b::AbstractOptic) = false """ Iden() @@ -26,6 +30,8 @@ _pretty_print_optic(::IO, ::Iden) = nothing is_dynamic(::Iden) = false concretize(i::Iden, ::Any) = i (::Iden)(obj) = obj +Base.:(==)(::Iden, ::Iden) = true +Base.isequal(::Iden, ::Iden) = true Accessors.set(obj::Any, ::Iden, val) = Accessors.set(obj, identity, val) """ @@ -159,9 +165,46 @@ struct Index{I<:Tuple,N<:NamedTuple,C<:AbstractOptic} <: AbstractOptic end # Workaround for https://github.com/JuliaLang/julia/issues/60470 -_tuple_eq(t1::Tuple, t2::Tuple) = (length(t1) == length(t2)) && _tuple_eq_inner(t1, t2) +# In particular, these four methods are new: _tuple_eq_inner(t1::Tuple{}, t2::Tuple{}) = true -_tuple_eq_inner(t1::Tuple{Any,Vararg{Any}}, t2::Tuple{Any,Vararg{Any}}) = Base._eq(t1, t2) +_tuple_eq_inner(t1::Tuple{}, t2::Tuple) = false +_tuple_eq_inner_missing(t1::Tuple, t2::Tuple{}) = false +_tuple_eq_inner_missing(t1::Tuple{}, t2::Tuple) = false +# These methods are directly copied from Base tuple.jl and correspond to the usual equality +# checks for tuples +_tuple_eq(t1::Tuple, t2::Tuple) = (length(t1) == length(t2)) && _tuple_eq_inner(t1, t2) +_tuple_eq_inner(t1::Tuple, t2::Tuple{}) = false +_tuple_eq_inner_missing(t1::Tuple{}, t2::Tuple{}) = missing +function _tuple_eq_inner(t1::Tuple, t2::Tuple) + eq = t1[1] == t2[1] + if eq === false + return false + elseif ismissing(eq) + return _tuple_eq_inner_missing(Base.tail(t1), Base.tail(t2)) + else + return _tuple_eq_inner(Base.tail(t1), Base.tail(t2)) + end +end +function _tuple_eq_inner_missing(t1::Tuple, t2::Tuple) + eq = t1[1] == t2[1] + if eq === false + return false + else + return _tuple_eq_inner_missing(Base.tail(t1), Base.tail(t2)) + end +end +function _tuple_eq_inner(t1::Base.Any32, t2::Base.Any32) + anymissing = false + for i in eachindex(t1, t2) + eq = (t1[i] == t2[i]) + if ismissing(eq) + anymissing = true + elseif !eq + return false + end + end + return anymissing ? missing : true +end Base.:(==)(a::Index, b::Index) = _tuple_eq(a.ix, b.ix) && a.kw == b.kw && a.child == b.child function Base.isequal(a::Index, b::Index) @@ -245,7 +288,8 @@ Property{sym}(child::C=Iden()) where {sym,C<:AbstractOptic} = Property{sym,C}(ch Base.:(==)(a::Property{sym}, b::Property{sym}) where {sym} = a.child == b.child Base.:(==)(a::Property, b::Property) = false -Base.isequal(a::Property, b::Property) = a == b +Base.isequal(a::Property{sym}, b::Property{sym}) where {sym} = isequal(a.child, b.child) +Base.isequal(a::Property, b::Property) = false getsym(::Property{s}) where {s} = s function _pretty_print_optic(io::IO, prop::Property{sym}) where {sym} print(io, ".$(sym)") diff --git a/src/varname/varname.jl b/src/varname/varname.jl index 6a64600e..22587ae2 100644 --- a/src/varname/varname.jl +++ b/src/varname/varname.jl @@ -52,7 +52,9 @@ getoptic(vn::VarName) = vn.optic function Base.:(==)(x::VarName, y::VarName) return getsym(x) == getsym(y) && getoptic(x) == getoptic(y) end -Base.isequal(x::VarName, y::VarName) = x == y +function Base.isequal(x::VarName, y::VarName) + return getsym(x) == getsym(y) && isequal(getoptic(x), getoptic(y)) +end Base.hash(vn::VarName, h::UInt) = hash((getsym(vn), getoptic(vn)), h) From 5a28760a5dd71e2cc00c93ed638bc572cc4957bb Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 16:05:25 +0000 Subject: [PATCH 46/60] add `with_mutation` --- Project.toml | 2 + docs/src/varname.md | 12 ++++ ext/AbstractPPLDistributionsExt.jl | 38 ++---------- src/AbstractPPL.jl | 1 + src/varname/optic.jl | 97 ++++++++++++++++++++++++++++-- test/varname/optic.jl | 78 ++++++++++++++++++++++++ 6 files changed, 188 insertions(+), 40 deletions(-) diff --git a/Project.toml b/Project.toml index f3cf6619..a0e2a329 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ version = "0.14.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -24,6 +25,7 @@ AbstractPPLDistributionsExt = ["Distributions", "LinearAlgebra"] [compat] AbstractMCMC = "2, 3, 4, 5" Accessors = "0.1" +BangBang = "0.4.6" DensityInterface = "0.4" Distributions = "0.25" JSON = "0.19 - 0.21, 1" diff --git a/docs/src/varname.md b/docs/src/varname.md index f5e94acf..261e9fca 100644 --- a/docs/src/varname.md +++ b/docs/src/varname.md @@ -123,6 +123,18 @@ new_data = set(data, optic, 99) new_data, data ``` +If you want to try to mutate values, you can wrap an optic using `with_mutation`. + +```@example vn +optic_mut = with_mutation(optic) +set(data, optic_mut, 99) +data +``` + +```@docs +with_mutation +``` + ## Composing and decomposing optics If you have two optics, you can compose them using the `∘` operator: diff --git a/ext/AbstractPPLDistributionsExt.jl b/ext/AbstractPPLDistributionsExt.jl index ff267cb0..80995365 100644 --- a/ext/AbstractPPLDistributionsExt.jl +++ b/ext/AbstractPPLDistributionsExt.jl @@ -53,36 +53,6 @@ using AbstractPPL: AbstractPPL, VarName, Accessors using Distributions: Distributions using LinearAlgebra: Cholesky, LowerTriangular, UpperTriangular -#= -This section is copied from Accessors.jl's documentation: -https://juliaobjects.github.io/Accessors.jl/stable/examples/custom_macros/ - -It defines a wrapper that, when called with `set`, mutates the original value -rather than returning a new value. We need this because the non-mutating optics -don't work for triangular matrices (and hence LKJCholesky): see -https://github.com/JuliaObjects/Accessors.jl/issues/203 -=# -struct Lens!{L} - pure::L -end -(l::Lens!)(o) = l.pure(o) -Accessors.set(::Any, l::Lens!{AbstractPPL.Iden}, val) = val -function Accessors.set(obj, l::Lens!{<:AbstractPPL.Property{sym}}, newval) where {sym} - inner_obj = getproperty(obj, sym) - inner_newval = Accessors.set(inner_obj, Lens!(l.pure.child), newval) - # Note that the following line actually does not mutate `obj.sym`. That's fine, because - # the things we are dealing with won't have mutable fields. The point is that - # the inner lens will have mutated whatever `obj.sym` pointed to. - return Accessors.set(obj, l.pure, inner_newval) -end -function Accessors.set(obj, l::Lens!{<:AbstractPPL.Index}, newval) - cidx = AbstractPPL.concretize(l.pure, obj) - inner_obj = Base.getindex(obj, cidx.ix...; cidx.kw...) - inner_newval = AbstractPPL.set(inner_obj, Lens!(l.pure.child), newval) - setindex!(obj, inner_newval, cidx.ix...; cidx.kw...) - return obj -end - """ get_optics(dist::MultivariateDistribution) get_optics(dist::MatrixDistribution) @@ -321,10 +291,10 @@ function AbstractPPL.getvalue( # Retrieve the value of this given index sub_value = AbstractPPL.getvalue(vals, sub_vn) # Set it inside the value we're reconstructing. - # Note: `o` is normally non-mutating. We have to wrap it in `Lens!` - # to make it mutating, because Cholesky distributions are broken - # by https://github.com/JuliaObjects/Accessors.jl/issues/203. - Accessors.set(value, Lens!(o), sub_value) + # Note: `o` is normally non-mutating. We have to use the mutating version, + # because Cholesky distributions are broken by + # https://github.com/JuliaObjects/Accessors.jl/issues/203. + Accessors.set(value, AbstractPPL.with_mutation(o), sub_value) end return value else diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index 69170c7a..9586b67e 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -23,6 +23,7 @@ export AbstractOptic, Iden, Index, Property, + with_mutation, ohead, otail, olast, diff --git a/src/varname/optic.jl b/src/varname/optic.jl index 04986e44..973fc587 100644 --- a/src/varname/optic.jl +++ b/src/varname/optic.jl @@ -1,4 +1,5 @@ using Accessors: Accessors +using BangBang: BangBang using MacroTools: MacroTools """ @@ -166,15 +167,15 @@ end # Workaround for https://github.com/JuliaLang/julia/issues/60470 # In particular, these four methods are new: -_tuple_eq_inner(t1::Tuple{}, t2::Tuple{}) = true -_tuple_eq_inner(t1::Tuple{}, t2::Tuple) = false -_tuple_eq_inner_missing(t1::Tuple, t2::Tuple{}) = false -_tuple_eq_inner_missing(t1::Tuple{}, t2::Tuple) = false +_tuple_eq_inner(::Tuple, ::Tuple{}) = false +_tuple_eq_inner(::Tuple{}, ::Tuple) = false +_tuple_eq_inner_missing(::Tuple, ::Tuple{}) = false +_tuple_eq_inner_missing(::Tuple{}, ::Tuple) = false # These methods are directly copied from Base tuple.jl and correspond to the usual equality # checks for tuples _tuple_eq(t1::Tuple, t2::Tuple) = (length(t1) == length(t2)) && _tuple_eq_inner(t1, t2) -_tuple_eq_inner(t1::Tuple, t2::Tuple{}) = false -_tuple_eq_inner_missing(t1::Tuple{}, t2::Tuple{}) = missing +_tuple_eq_inner(::Tuple{}, ::Tuple{}) = true +_tuple_eq_inner_missing(::Tuple{}, ::Tuple{}) = missing function _tuple_eq_inner(t1::Tuple, t2::Tuple) eq = t1[1] == t2[1] if eq === false @@ -448,3 +449,87 @@ function oinit(idx::Index) end end oinit(i::Iden) = i + +""" + with_mutation(o::AbstractOptic) + +Create a version of the optic `o` which attempts to mutate its input where possible. + +On their own, `AbstractOptic`s are non-mutating: + +```jldoctest +julia> optic = @opticof(_[1]) +Optic([1]) + +julia> x = [0.0, 0.0]; + +julia> set(x, optic, 1.0); x +2-element Vector{Float64}: + 0.0 + 0.0 +``` + +With this function, we can create a mutating version of the optic: + +```jldoctest +julia> optic_mut = with_mutation(@opticof(_[1])) +Optic!!([1]) + +julia> x = [0.0, 0.0]; + +julia> set(x, optic_mut, 1.0); x +2-element Vector{Float64}: + 1.0 + 0.0 +``` + +Thanks to the BangBang.jl package, this optic will gracefully fall back to non-mutating +behaviour if mutation is not possible. For example, if we try to use it on a tuple: + +```jldoctest +julia> optic_mut = with_mutation(@opticof(_[1])) +Optic!!([1]) + +julia> x = (0.0, 0.0); + +julia> set(x, optic_mut, 1.0); x +(0.0, 0.0) +``` +""" +with_mutation(o::AbstractOptic) = _Optic!!(o) + +struct _Optic!!{O<:AbstractOptic} <: AbstractOptic + pure::O +end +_Optic!!(l::_Optic!!) = l +Base.:(==)(a::_Optic!!, b::_Optic!!) = a.pure == b.pure +Base.isequal(a::_Optic!!, b::_Optic!!) = isequal(a.pure, b.pure) +Base.hash(a::_Optic!!, h::UInt) = hash("_Optic!!", hash(a.pure, h)) +function Base.show(io::IO, l::_Optic!!) + print(io, "Optic!!(") + _pretty_print_optic(io, l.pure) + return print(io, ")") +end + +# Getter. +(l::_Optic!!)(o) = l.pure(o) +# Setters. +Accessors.set(::Any, l::_Optic!!{AbstractPPL.Iden}, val) = val +function Accessors.set(obj, l::_Optic!!{<:AbstractPPL.Property{sym}}, newval) where {sym} + inner_obj = getproperty(obj, sym) + inner_newval = Accessors.set(inner_obj, _Optic!!(l.pure.child), newval) + return BangBang.setproperty!!(obj, sym, inner_newval) +end +function Accessors.set(obj, l::_Optic!!{<:AbstractPPL.Index}, newval) + cidx = AbstractPPL.concretize(l.pure, obj) + inner_obj = Base.getindex(obj, cidx.ix...; cidx.kw...) + inner_newval = AbstractPPL.set(inner_obj, _Optic!!(l.pure.child), newval) + return if isempty(cidx.kw) + # setindex!! doesn't support keyword arguments. + BangBang.setindex!!(obj, inner_newval, cidx.ix...) + else + # If there are keyword arguments, we just assume that setindex! will always work. + # This is a bit dangerous, but fine. + setindex!(obj, inner_newval, cidx.ix...; cidx.kw...) + end +end diff --git a/test/varname/optic.jl b/test/varname/optic.jl index d6c7b502..845e09e8 100644 --- a/test/varname/optic.jl +++ b/test/varname/optic.jl @@ -144,6 +144,84 @@ using AbstractPPL @test set(s, @opticof(_.a), 10) == SampleStruct(10, s.b) @test set(s, @opticof(_.b), 2.5) == SampleStruct(s.a, 2.5) end + + @testset "mutating versions" begin + @testset "arrays" begin + x = zeros(4) + old_objid = objectid(x) + optic = with_mutation(@opticof(_[2])) + @test optic(x) === x[2] + set(x, optic, 1.0) + @test x[2] == 1.0 + @test x == [0.0, 1.0, 0.0, 0.0] + @test objectid(x) == old_objid + end + + @testset "dicts" begin + x = Dict("a" => 1, "b" => 2) + old_objid = objectid(x) + optic = with_mutation(@opticof(_["b"])) + @test optic(x) === x["b"] + set(x, optic, 99) + @test x["b"] == 99 + @test x == Dict("a" => 1, "b" => 99) + @test objectid(x) == old_objid + end + + @testset "mutable structs" begin + mutable struct MutableStruct + a::Int + b::Float64 + end + x = MutableStruct(3, 1.5) + old_objid = objectid(x) + optic = with_mutation(@opticof(_.a)) + @test optic(x) === x.a + set(x, optic, 10) + @test x.a == 10 + @test x.b == 1.5 + @test objectid(x) == old_objid + end + + @testset "fallback for immutable data" begin + @testset "NamedTuple" begin + s = (a=1, b=2) + old_objid = objectid(s) + optic = with_mutation(@opticof(_.a)) + @test optic(s) === s.a + s2 = set(s, optic, 10) + @test s2 == (a=10, b=2) + @test s == (a=1, b=2) + @test objectid(s) == old_objid + end + + @testset "tuple" begin + s = (3, 1.5) + old_objid = objectid(s) + optic = with_mutation(@opticof(_[1])) + @test optic(s) === s[1] + s2 = set(s, optic, 10) + @test s2 == (10, 1.5) + @test s == (3, 1.5) + @test objectid(s) == old_objid + end + + @testset "struct" begin + struct SampleStructAgain + a::Int + b::Float64 + end + s = SampleStructAgain(3, 1.5) + old_objid = objectid(s) + optic = with_mutation(@opticof(_.a)) + @test optic(s) === s.a + s2 = set(s, optic, 10) + @test s2 == SampleStructAgain(10, 1.5) + @test s == SampleStructAgain(3, 1.5) + @test objectid(s) == old_objid + end + end + end end end # module From 0317af73c4235d67a0ca698dd08ab0c4c831bfd5 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 16:16:05 +0000 Subject: [PATCH 47/60] fix undef vectors --- src/varname/optic.jl | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/varname/optic.jl b/src/varname/optic.jl index 973fc587..8dc45f5a 100644 --- a/src/varname/optic.jl +++ b/src/varname/optic.jl @@ -514,22 +514,26 @@ end # Getter. (l::_Optic!!)(o) = l.pure(o) # Setters. -Accessors.set(::Any, l::_Optic!!{AbstractPPL.Iden}, val) = val -function Accessors.set(obj, l::_Optic!!{<:AbstractPPL.Property{sym}}, newval) where {sym} +Accessors.set(::Any, l::_Optic!!{Iden}, val) = val +function Accessors.set(obj, l::_Optic!!{<:Property{sym}}, newval) where {sym} inner_obj = getproperty(obj, sym) inner_newval = Accessors.set(inner_obj, _Optic!!(l.pure.child), newval) return BangBang.setproperty!!(obj, sym, inner_newval) end -function Accessors.set(obj, l::_Optic!!{<:AbstractPPL.Index}, newval) - cidx = AbstractPPL.concretize(l.pure, obj) - inner_obj = Base.getindex(obj, cidx.ix...; cidx.kw...) - inner_newval = AbstractPPL.set(inner_obj, _Optic!!(l.pure.child), newval) +function Accessors.set(obj, l::_Optic!!{<:Index}, newval) + cidx = concretize(l.pure, obj) + inner_newval = if l.pure.child isa Iden + newval + else + inner_obj = Base.getindex(obj, cidx.ix...; cidx.kw...) + Accessors.set(inner_obj, _Optic!!(l.pure.child), newval) + end return if isempty(cidx.kw) # setindex!! doesn't support keyword arguments. BangBang.setindex!!(obj, inner_newval, cidx.ix...) else # If there are keyword arguments, we just assume that setindex! will always work. # This is a bit dangerous, but fine. - setindex!(obj, inner_newval, cidx.ix...; cidx.kw...) + Base.setindex!(obj, inner_newval, cidx.ix...; cidx.kw...) end end From 38cbf747cae99dc5f09654da87a5dfae408e695d Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 17:22:38 +0000 Subject: [PATCH 48/60] Fix some docs typos --- src/varname/optic.jl | 8 ++++---- src/varname/varname.jl | 13 ++++++------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/varname/optic.jl b/src/varname/optic.jl index 8dc45f5a..57f612fd 100644 --- a/src/varname/optic.jl +++ b/src/varname/optic.jl @@ -152,7 +152,7 @@ _concretize_index(idx::DynamicIndex, val) = idx.f(val) Index(ix, kw, child=Iden()) An indexing optic representing access to indices `ix`, which may also take the form of -keyword arguments `kw`. A VarName{:x} with this optic represents access to `x[ix..., +keyword arguments `kw`. A `VarName{:x}` with this optic represents access to `x[ix..., kw...]`. The child optic represents any further indexing or property access after this indexing operation. """ @@ -278,9 +278,9 @@ end """ Property{sym}(child=Iden()) -A property access optic representing access to property `sym`. A VarName{:x} with this -optic represents access to `x.sym`. The child optic represents any further indexing -or property access after this property access operation. +A property access optic representing access to property `sym`. A `VarName{:x}` with this +optic represents access to `x.sym`. The child optic represents any further indexing or +property access after this property access operation. """ struct Property{sym,C<:AbstractOptic} <: AbstractOptic child::C diff --git a/src/varname/varname.jl b/src/varname/varname.jl index 22587ae2..e78952b1 100644 --- a/src/varname/varname.jl +++ b/src/varname/varname.jl @@ -225,15 +225,14 @@ julia> @varname(\$name.a.\$name[1]) hello.a.hello[1] ``` -For indices, you do nott need to use `\$` to interpolate, just use the variable directly: +For indices, you do not need to use `\$` to interpolate, just use the variable directly: ```jldoctest julia> ix = 2; @varname(x[ix]) x[2] ``` -However, if the top-level symbol is interpolated, automatic concretization is not -possible: +Note that if the top-level symbol is interpolated, automatic concretization is not possible: ```jldoctest julia> name = :x; @varname(\$name[1:end], true) @@ -345,15 +344,15 @@ If you don't need to concretize, you should use `_` as the top-level symbol to indicate that it is not relevant: ```jldoctest -julia> AbstractPPL.@opticof(_.a.b) +julia> @opticof(_.a.b) Optic(.a.b) ``` -Only if you need to concretize should you provide a real variable name (in which case -it is then used to look up the value for concretization): +If you need to concretize, then you can provide a real variable name (which is then used to +look up the value for concretization): ```jldoctest -julia> x = randn(3, 4); AbstractPPL.@opticof(x[1:end, end], true) +julia> x = randn(3, 4); @opticof(x[1:end, end], true) Optic([1:3, 4]) ``` From 3b0c3761de40834f70c520a66e6fb2e33690e125 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 17:32:44 +0000 Subject: [PATCH 49/60] Add more tests for mutating optics --- test/varname/optic.jl | 40 ++++++++++++++++++++++++++++++++-------- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/test/varname/optic.jl b/test/varname/optic.jl index 845e09e8..bb3e9d98 100644 --- a/test/varname/optic.jl +++ b/test/varname/optic.jl @@ -147,14 +147,38 @@ using AbstractPPL @testset "mutating versions" begin @testset "arrays" begin - x = zeros(4) - old_objid = objectid(x) - optic = with_mutation(@opticof(_[2])) - @test optic(x) === x[2] - set(x, optic, 1.0) - @test x[2] == 1.0 - @test x == [0.0, 1.0, 0.0, 0.0] - @test objectid(x) == old_objid + @testset "static index" begin + x = zeros(4) + old_objid = objectid(x) + optic = with_mutation(@opticof(_[2])) + @test optic(x) === x[2] + set(x, optic, 1.0) + @test x[2] == 1.0 + @test x == [0.0, 1.0, 0.0, 0.0] + @test objectid(x) == old_objid + end + + @testset "dynamic index" begin + x = zeros(2, 2) + old_objid = objectid(x) + optic = with_mutation(@opticof(_[begin, end])) + @test optic(x) === x[begin, end] + set(x, optic, 2.0) + @test x[begin, end] == 2.0 + @test x == [0.0 2.0; 0.0 0.0] + @test objectid(x) == old_objid + end + + @testset "keyword index" begin + x = DD.DimArray(zeros(2, 2), (:x, :y)) + old_objid = objectid(x) + optic = with_mutation(@opticof(_[x=1, y=2])) + @test optic(x) === x[x=1, y=2] + set(x, optic, 2.0) + @test x[x=1, y=2] == 2.0 + @test collect(x) == [0.0 2.0; 0.0 0.0] + @test objectid(x) == old_objid + end end @testset "dicts" begin From 7ee1d04aa7035623490db1a0a4dfcdf638227091 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 17:56:22 +0000 Subject: [PATCH 50/60] serialization works --- HISTORY.md | 3 ++ src/AbstractPPL.jl | 16 +++--- src/varname/serialize.jl | 66 ++++++++++++----------- test/runtests.jl | 1 + test/{varname.jl => varname/serialize.jl} | 43 +++++++++------ 5 files changed, 74 insertions(+), 55 deletions(-) rename test/{varname.jl => varname/serialize.jl} (63%) diff --git a/HISTORY.md b/HISTORY.md index d7b55c83..de69ebe4 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -48,6 +48,9 @@ It used to be relevant for Turing's old Gibbs sampler; but now it no longer serv The subsumption interface has been pared down to just a single function, `subsumes`. All other functions, such as `subsumedby`, `uncomparable`, and the Unicode operators, have been removed. +Serialization still works exactly as before. +However, you will see differences in the serialization output compared to previous versions, due to the changes in the internal structure. + ## 0.13.6 Fix a missing qualifier in AbstractPPLDistributionsExt. diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index 9586b67e..8db124eb 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -16,7 +16,7 @@ include("varname/subsumes.jl") include("varname/hasvalue.jl") include("varname/leaves.jl") include("varname/prefix.jl") -# include("varname/serialize.jl") +include("varname/serialize.jl") # Optics export AbstractOptic, @@ -47,14 +47,14 @@ export AbstractOptic, getvalue, canview, varname_leaves, - varname_and_value_leaves - -# Serialisation -# index_to_dict, -# dict_to_index, -# varname_to_string, -# string_to_varname, + varname_and_value_leaves, + # Serialisation + index_to_dict, + dict_to_index, + varname_to_string, + string_to_varname +# Convenience re-export using Accessors: set export set diff --git a/src/varname/serialize.jl b/src/varname/serialize.jl index 7feb4896..7dca9d08 100644 --- a/src/varname/serialize.jl +++ b/src/varname/serialize.jl @@ -10,7 +10,6 @@ const _BASE_UNITRANGE_TYPE = "Base.UnitRange" const _BASE_STEPRANGE_TYPE = "Base.StepRange" const _BASE_ONETO_TYPE = "Base.OneTo" const _BASE_COLON_TYPE = "Base.Colon" -const _CONCRETIZED_SLICE_TYPE = "AbstractPPL.ConcretizedSlice" const _BASE_TUPLE_TYPE = "Base.Tuple" """ @@ -41,12 +40,16 @@ function index_to_dict(r::Base.OneTo{I}) where {I} return Dict("type" => _BASE_ONETO_TYPE, "stop" => r.stop) end index_to_dict(::Colon) = Dict("type" => _BASE_COLON_TYPE) -function index_to_dict(s::ConcretizedSlice{T,R}) where {T,R} - return Dict("type" => _CONCRETIZED_SLICE_TYPE, "range" => index_to_dict(s.range)) -end function index_to_dict(t::Tuple) return Dict("type" => _BASE_TUPLE_TYPE, "values" => map(index_to_dict, t)) end +function index_to_dict(::DynamicIndex) + throw( + ArgumentError( + "DynamicIndex cannot be serialised; please concretise the VarName before serialising.", + ), + ) +end """ dict_to_index(dict) @@ -90,8 +93,6 @@ function dict_to_index(dict) return Base.OneTo(dict["stop"]) elseif t == _BASE_COLON_TYPE return Colon() - elseif t == _CONCRETIZED_SLICE_TYPE - return ConcretizedSlice(Base.Slice(dict_to_index(dict["range"]))) elseif t == _BASE_TUPLE_TYPE return tuple(map(dict_to_index, dict["values"])...) else @@ -101,28 +102,39 @@ function dict_to_index(dict) end end -optic_to_dict(::typeof(identity)) = Dict("type" => "identity") -function optic_to_dict(::PropertyLens{sym}) where {sym} - return Dict("type" => "property", "field" => String(sym)) +const _OPTIC_IDEN_NAME = "Iden" +const _OPTIC_PROPERTY_NAME = "Property" +const _OPTIC_INDEX_NAME = "Index" +optic_to_dict(::Iden) = Dict("type" => _OPTIC_IDEN_NAME) +function optic_to_dict(p::Property{sym}) where {sym} + return Dict( + "type" => _OPTIC_PROPERTY_NAME, + "field" => String(sym), + "child" => optic_to_dict(p.child), + ) end -optic_to_dict(i::IndexLens) = Dict("type" => "index", "indices" => index_to_dict(i.indices)) -function optic_to_dict(c::ComposedFunction) +function optic_to_dict(i::Index) return Dict( - "type" => "composed", - "outer" => optic_to_dict(c.outer), - "inner" => optic_to_dict(c.inner), + "type" => _OPTIC_INDEX_NAME, + # For some reason if you don't do the isempty check, it gets serialised as `{}` + # rather than `[]` + "ix" => isempty(i.ix) ? [] : collect(map(index_to_dict, i.ix)), + # TODO(penelopeysm): This is potentially lossy since order is not guaranteed + "kw" => Dict(String(x) => index_to_dict(y) for (x, y) in pairs(i.kw)), + "child" => optic_to_dict(i.child), ) end function dict_to_optic(dict) - if dict["type"] == "identity" - return identity - elseif dict["type"] == "index" - return IndexLens(dict_to_index(dict["indices"])) - elseif dict["type"] == "property" - return PropertyLens{Symbol(dict["field"])}() - elseif dict["type"] == "composed" - return dict_to_optic(dict["outer"]) ∘ dict_to_optic(dict["inner"]) + if dict["type"] == _OPTIC_IDEN_NAME + return Iden() + elseif dict["type"] == _OPTIC_INDEX_NAME + ixs = tuple(map(dict_to_index, dict["ix"])...) + kws = NamedTuple(Symbol(k) => dict_to_index(v) for (k, v) in dict["kw"]) + child = dict_to_optic(dict["child"]) + return Index(ixs, kws, child) + elseif dict["type"] == _OPTIC_PROPERTY_NAME + return Property{Symbol(dict["field"])}(dict_to_optic(dict["child"])) else error("Unknown optic type: $(dict["type"])") end @@ -151,16 +163,10 @@ documentation of [`dict_to_index`](@ref) for instructions on how to do this. ```jldoctest julia> varname_to_string(@varname(x)) -"{\\"optic\\":{\\"type\\":\\"identity\\"},\\"sym\\":\\"x\\"}" +"{\\"optic\\":{\\"type\\":\\"Iden\\"},\\"sym\\":\\"x\\"}" julia> varname_to_string(@varname(x.a)) -"{\\"optic\\":{\\"field\\":\\"a\\",\\"type\\":\\"property\\"},\\"sym\\":\\"x\\"}" - -julia> y = ones(2); varname_to_string(@varname(y[:])) -"{\\"optic\\":{\\"indices\\":{\\"values\\":[{\\"type\\":\\"Base.Colon\\"}],\\"type\\":\\"Base.Tuple\\"},\\"type\\":\\"index\\"},\\"sym\\":\\"y\\"}" - -julia> y = ones(2); varname_to_string(@varname(y[:], true)) -"{\\"optic\\":{\\"indices\\":{\\"values\\":[{\\"range\\":{\\"stop\\":2,\\"type\\":\\"Base.OneTo\\"},\\"type\\":\\"AbstractPPL.ConcretizedSlice\\"}],\\"type\\":\\"Base.Tuple\\"},\\"type\\":\\"index\\"},\\"sym\\":\\"y\\"}" +"{\\"optic\\":{\\"child\\":{\\"type\\":\\"Iden\\"},\\"field\\":\\"a\\",\\"type\\":\\"Property\\"},\\"sym\\":\\"x\\"}" ``` """ varname_to_string(vn::VarName) = JSON.json(varname_to_dict(vn)) diff --git a/test/runtests.jl b/test/runtests.jl index fdf2e4ae..a5c4e3cb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,7 @@ const GROUP = get(ENV, "GROUP", "All") include("varname/subsumes.jl") include("varname/hasvalue.jl") include("varname/leaves.jl") + include("varname/serialize.jl") end if GROUP == "All" || GROUP == "Doctests" diff --git a/test/varname.jl b/test/varname/serialize.jl similarity index 63% rename from test/varname.jl rename to test/varname/serialize.jl index d399201d..d84d0149 100644 --- a/test/varname.jl +++ b/test/varname/serialize.jl @@ -1,9 +1,11 @@ -using Accessors -using InvertedIndices -using OffsetArrays +module VarNameSerialisationTests -@testset "varnames" begin - @testset "de/serialisation of VarNames" begin +using AbstractPPL +using InvertedIndices: Not, InvertedIndex +using Test + +@testset "varname/serialize.jl" verbose = true begin + @testset "roundtrip" begin y = ones(10) z = ones(5, 2) vns = [ @@ -23,8 +25,8 @@ using OffsetArrays @varname(x.a[1:10]), @varname(x[1].a), @varname(y[:]), - @varname(y[begin:end]), - @varname(y[end]), + @varname(y[begin:end], true), + @varname(y[end], true), @varname(y[:], false), @varname(y[:], true), @varname(z[:], false), @@ -35,6 +37,8 @@ using OffsetArrays @varname(z[:, :], true), @varname(z[2:5, :], false), @varname(z[2:5, :], true), + @varname(x[i=1]), + @varname(x[].a[j=2].b[3, 4, 5, [6]]), ] for vn in vns @test string_to_varname(varname_to_string(vn)) == vn @@ -50,24 +54,29 @@ using OffsetArrays @test hash(vn_vec) == hash(vn_vec2) end - @testset "de/serialisation of VarNames with custom index types" begin - using OffsetArrays: OffsetArrays, Origin - weird = Origin(4)(ones(10)) - vn = @varname(weird[:], true) + @testset "deserialisation fails for unconcretised dynamic indices" begin + for vn in (@varname(x[1:end]), @varname(x[begin:end]), @varname(x[2:step:end])) + @test_throws ArgumentError varname_to_string(vn) + end + end + + @testset "custom index types" begin + vn = @varname(x[Not(3)]) # This won't work as we don't yet know how to handle OffsetArray @test_throws MethodError varname_to_string(vn) # Now define the relevant methods - AbstractPPL.index_to_dict(o::OffsetArrays.IdOffsetRange{I,R}) where {I,R} = Dict( - "type" => "OffsetArrays.OffsetArray", - "parent" => AbstractPPL.index_to_dict(o.parent), - "offset" => o.offset, + AbstractPPL.index_to_dict(o::InvertedIndex{I}) where {I} = Dict( + "type" => "InvertedIndices.InvertedIndex", + "skip" => AbstractPPL.index_to_dict(o.skip), ) - AbstractPPL.dict_to_index(::Val{Symbol("OffsetArrays.OffsetArray")}, d) = - OffsetArrays.IdOffsetRange(AbstractPPL.dict_to_index(d["parent"]), d["offset"]) + AbstractPPL.dict_to_index(::Val{Symbol("InvertedIndices.InvertedIndex")}, d) = + InvertedIndex(AbstractPPL.dict_to_index(d["skip"])) # Serialisation should now work @test string_to_varname(varname_to_string(vn)) == vn end end + +end # module From 84b3865f971fd12d4be852407903c0ef0b8794c8 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 18:25:21 +0000 Subject: [PATCH 51/60] Add more tests --- test/varname/optic.jl | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/test/varname/optic.jl b/test/varname/optic.jl index bb3e9d98..524079f8 100644 --- a/test/varname/optic.jl +++ b/test/varname/optic.jl @@ -94,6 +94,15 @@ using AbstractPPL @test set(v, @opticof(_.d[:]), fill(9.9, 2, 2)) == (a=v.a, d=fill(9.9, 2, 2)) end + @testset "no indices" begin + x = [0.0] + @test @opticof(_[])(x) == x[] + @test set(x, @opticof(_[]), 9.0) == [9.0] + x = 0.0 + @test @opticof(_[])(x) == x[] + @test set(x, @opticof(_[]), 9.0) == 9.0 + end + @testset "dynamic indices" begin x = [0.0 1.0; 2.0 3.0] @test @opticof(_[begin])(x) == x[begin] @@ -169,6 +178,17 @@ using AbstractPPL @test objectid(x) == old_objid end + @testset "DimArray, setting a single element" begin + dimarray = DD.DimArray(zeros(2, 3), (DD.X, DD.Y)) + old_objid = objectid(dimarray) + optic = with_mutation(@opticof(_[DD.X(1), DD.Y(1)])) + @test optic(dimarray) == dimarray[DD.X(1), DD.Y(1)] + dimarray = set(dimarray, optic, 1.0) + @test dimarray[DD.X(1), DD.Y(1)] == 1.0 + @test collect(dimarray) == [1.0 0.0 0.0; 0.0 0.0 0.0] + @test objectid(dimarray) == old_objid + end + @testset "keyword index" begin x = DD.DimArray(zeros(2, 2), (:x, :y)) old_objid = objectid(x) @@ -219,6 +239,27 @@ using AbstractPPL @test objectid(s) == old_objid end + # NOTE(penelopeysm): This SHOULD really mutate. It is not an error with + # AbstractPPL, though, it is an interface problem between BangBang and + # DimensionalData (essentially BangBang can't detect that DimArray is mutable). + # + # To be precise, the test fails because BangBang thinks that `DD.Y(1)` is an + # index that extracts a single element from the DimArray. For example, this + # would be the case if it was dimarray[1]. So, BangBang thinks that you can't + # set an array there, and so it falls back to the immutable behavior. + # Specifically, it's this line that returns false: + # https://github.com/JuliaFolds2/BangBang.jl/blob/e92b4c1673a686533b5f9724a198b63d8974d52f/src/base.jl#L528 + @testset "DimArray, setting a vector" begin + dimarray = DD.DimArray(zeros(2, 3), (DD.X, DD.Y)) + old_objid = objectid(dimarray) + optic = with_mutation(@opticof(_[DD.Y(1)])) + @test optic(dimarray) == dimarray[DD.Y(1)] + dimarray = set(dimarray, optic, [1.0, 2.0]) + @test collect(dimarray[DD.Y(1)]) == [1.0; 2.0] + @test collect(dimarray) == [1.0 0.0 0.0; 2.0 0.0 0.0] + # @test objectid(dimarray) == old_objid + end + @testset "tuple" begin s = (3, 1.5) old_objid = objectid(s) From 95a69e1e06ec5b5dea4aa284e1aff12b5d7290b5 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 18:26:08 +0000 Subject: [PATCH 52/60] Docs --- docs/src/varname.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/src/varname.md b/docs/src/varname.md index 261e9fca..88672ba5 100644 --- a/docs/src/varname.md +++ b/docs/src/varname.md @@ -246,3 +246,12 @@ canview hasvalue getvalue ``` + +## Serializing VarNames + +```@docs +index_to_dict +dict_to_index +varname_to_string +string_to_varname +``` From 0282e150074baea6fd86670c1ce779db0fda3df3 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 18:28:40 +0000 Subject: [PATCH 53/60] reenable Aqua tests --- Project.toml | 6 +++--- test/Aqua.jl | 4 +--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index a0e2a329..de3ab112 100644 --- a/Project.toml +++ b/Project.toml @@ -25,12 +25,12 @@ AbstractPPLDistributionsExt = ["Distributions", "LinearAlgebra"] [compat] AbstractMCMC = "2, 3, 4, 5" Accessors = "0.1" -BangBang = "0.4.6" +BangBang = "0.4" DensityInterface = "0.4" Distributions = "0.25" JSON = "0.19 - 0.21, 1" -LinearAlgebra = "<0.0.1, 1.10" -MacroTools = "0.5.16" +LinearAlgebra = "<0.0.1, 1" +MacroTools = "0.5" Random = "1.6" StatsBase = "0.32, 0.33, 0.34" julia = "1.10.8" diff --git a/test/Aqua.jl b/test/Aqua.jl index 6f608e2c..27d4103f 100644 --- a/test/Aqua.jl +++ b/test/Aqua.jl @@ -3,8 +3,6 @@ module AquaTests using Aqua: Aqua using AbstractPPL -# For now, we skip ambiguities since they come from interactions -# with third-party packages rather than issues in AbstractPPL itself -Aqua.test_all(AbstractPPL; ambiguities=false) +Aqua.test_all(AbstractPPL) end From 9733bfe8766523b272610aad2916daafdfa0564a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 18:29:48 +0000 Subject: [PATCH 54/60] fix ambiguities --- src/varname/subsumes.jl | 2 ++ test/runtests.jl | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/varname/subsumes.jl b/src/varname/subsumes.jl index e9d1c3c3..2a29b5e5 100644 --- a/src/varname/subsumes.jl +++ b/src/varname/subsumes.jl @@ -70,7 +70,9 @@ _subsumes_index(::DynamicIndex, ::Colon) = false _subsumes_index(::Colon, ::DynamicIndex) = true _subsumes_index(::Any, ::DynamicIndex) = false _subsumes_index(::Colon, ::Any) = true +_subsumes_index(::Colon, ::Colon) = true _subsumes_index(::Any, ::Colon) = false _subsumes_index(a::AbstractVector, b::Any) = issubset(b, a) _subsumes_index(a::AbstractVector, b::Colon) = false +_subsumes_index(a::AbstractVector, b::DynamicIndex) = false _subsumes_index(a::Any, b::Any) = a == b diff --git a/test/runtests.jl b/test/runtests.jl index a5c4e3cb..fc8d3980 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,7 +6,7 @@ const GROUP = get(ENV, "GROUP", "All") @testset "AbstractPPL.jl" begin if GROUP == "All" || GROUP == "Tests" - # include("Aqua.jl") + include("Aqua.jl") include("abstractprobprog.jl") include("varname/optic.jl") include("varname/varname.jl") From e5277af91aa3e1fab8364b8d35f8878daf2bc4ce Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 20:08:43 +0000 Subject: [PATCH 55/60] Add more tests for coverage --- test/varname/optic.jl | 104 +++++++++++++++++++++++++++++++++++++----- 1 file changed, 93 insertions(+), 11 deletions(-) diff --git a/test/varname/optic.jl b/test/varname/optic.jl index 524079f8..e78f28fb 100644 --- a/test/varname/optic.jl +++ b/test/varname/optic.jl @@ -5,9 +5,50 @@ using DimensionalData: DimensionalData as DD using AbstractPPL @testset "varname/optic.jl" verbose = true begin - # Note that much of the functionality in optic.jl is tested by varname.jl (for example, - # pretty-printing VarNames essentially boils down to pretty-printing optics). So, this - # file focuses on tests that are specific to optics. + @testset "pretty-printing" begin + @test string(@opticof(_.a.b.c)) == "Optic(.a.b.c)" + @test string(@opticof(_[1][2][3])) == "Optic([1][2][3])" + @test string(@opticof(_)) == "Optic()" + @test string(@opticof(_[begin])) == "Optic([DynamicIndex(begin)])" + @test string(@opticof(_[2:end])) == "Optic([DynamicIndex(2:end)])" + + @test string(with_mutation(@opticof(_.a.b.c))) == "Optic!!(.a.b.c)" + @test string(with_mutation(@opticof(_[1][2][3]))) == "Optic!!([1][2][3])" + @test string(with_mutation(@opticof(_))) == "Optic!!()" + @test string(with_mutation(@opticof(_[begin]))) == "Optic!!([DynamicIndex(begin)])" + @test string(with_mutation(@opticof(_[2:end]))) == "Optic!!([DynamicIndex(2:end)])" + end + + @testset "equality" begin + optics = ( + @opticof(_), + @opticof(_[1]), + @opticof(_.a), + @opticof(_[begin]), + @opticof(_[end]), + @opticof(_[:]), + @opticof(_.a[2]), + @opticof(_.a[1, :]), + @opticof(_[1].a), + @opticof(_[1, x=1].a), + ) + for (i, optic1) in enumerate(optics) + for (j, optic2) in enumerate(optics) + if i == j + @test optic1 == optic2 + @test with_mutation(optic1) == with_mutation(optic2) + @test isequal(optic1, optic2) + @test isequal(with_mutation(optic1), with_mutation(optic2)) + else + @test optic1 != optic2 + @test with_mutation(optic1) != with_mutation(optic2) + @test !isequal(optic1, optic2) + @test !isequal(with_mutation(optic1), with_mutation(optic2)) + end + end + end + end + @testset "composition" begin @testset "with identity" begin i = AbstractPPL.Iden() @@ -143,15 +184,31 @@ using AbstractPPL DD.DimArray([0.0 9.0; 2.0 8.0], (:x, :y)) end - struct SampleStruct - a::Int - b::Float64 + @testset "properties on struct" begin + struct SampleStruct + a::Int + b::Float64 + end + s = SampleStruct(3, 1.5) + @test @opticof(_.a)(s) == 3 + @test @opticof(_.b)(s) == 1.5 + @test set(s, @opticof(_.a), 10) == SampleStruct(10, s.b) + @test set(s, @opticof(_.b), 2.5) == SampleStruct(s.a, 2.5) + end + + @testset "nested optics" begin + x = (; a=[(; b=1)]) + optic = @opticof(_.a[1].b) + @test optic(x) == 1 + x2 = set(x, optic, 42) + @test x2 == (; a=[(; b=42)]) + + y = [(; a=[1.0])] + optic2 = @opticof(_[1].a[1]) + @test optic2(y) == 1.0 + y2 = set(y, optic2, 3.14) + @test y2 == [(; a=[3.14])] end - s = SampleStruct(3, 1.5) - @test @opticof(_.a)(s) == 3 - @test @opticof(_.b)(s) == 1.5 - @test set(s, @opticof(_.a), 10) == SampleStruct(10, s.b) - @test set(s, @opticof(_.b), 2.5) == SampleStruct(s.a, 2.5) end @testset "mutating versions" begin @@ -199,6 +256,31 @@ using AbstractPPL @test collect(x) == [0.0 2.0; 0.0 0.0] @test objectid(x) == old_objid end + + @testset "nested optics" begin + x = (; a=[(; b=1)]) + old_objid = objectid(x) + old_inner_objid = objectid(x.a) + optic = with_mutation(@opticof(_.a[1].b)) + @test optic(x) == 1 + set(x, optic, 42) + @test x == (; a=[(; b=42)]) + # Check that mutation happened at the very bottom level. + @test objectid(x) == old_objid + @test objectid(x.a) == old_inner_objid + + y = [(; a=[1.0])] + old_objid = objectid(y) + old_inner_objid = objectid(y[1]) + old_inner_inner_objid = objectid(y[1].a) + optic2 = with_mutation(@opticof(_[1].a[1])) + @test optic2(y) == 1.0 + set(y, optic2, 3.14) + @test y == [(; a=[3.14])] + @test objectid(y) == old_objid + @test objectid(y[1]) == old_inner_objid + @test objectid(y[1].a) == old_inner_inner_objid + end end @testset "dicts" begin From a2a7c1f707c594dcdae72a595def91bf6a12bbcd Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 20:17:45 +0000 Subject: [PATCH 56/60] Add a test for undef --- test/varname/optic.jl | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/test/varname/optic.jl b/test/varname/optic.jl index e78f28fb..fcb25a81 100644 --- a/test/varname/optic.jl +++ b/test/varname/optic.jl @@ -144,6 +144,22 @@ using AbstractPPL @test set(x, @opticof(_[]), 9.0) == 9.0 end + @testset "vector of undef" begin + # For this test to be meaningful, the eltype of `x` must be abstract. This sort + # of situation happens in DynamicPPL's demo models -- the use of `Real` as + # eltype is meant to help with ForwardDiff (even though *technically* that is + # not necessary... see + # https://github.com/TuringLang/DynamicPPL.jl/issues/823#issuecomment-3166049286) + x = Vector{Real}(undef, 3) + optic = opticof(_[2]) + @test_throws UndefRefError optic(x) + x2 = set(x, optic, 3.14) + @test x2[2] == 3.14 + @test !isassigned(x2, 1) + @test isassigned(x2, 2) + @test !isassigned(x2, 3) + end + @testset "dynamic indices" begin x = [0.0 1.0; 2.0 3.0] @test @opticof(_[begin])(x) == x[begin] @@ -224,6 +240,21 @@ using AbstractPPL @test objectid(x) == old_objid end + @testset "vector of undef" begin + # eltype(x) must be abstract for this test to be meaningful. See above for + # discussion. + x = Vector{Real}(undef, 3) + old_objid = objectid(x) + optic = with_mutation(opticof(_[2])) + @test_throws UndefRefError optic(x) + set(x, optic, 3.14) + @test x[2] == 3.14 + @test !isassigned(x, 1) + @test isassigned(x, 2) + @test !isassigned(x, 3) + @test objectid(x) == old_objid + end + @testset "dynamic index" begin x = zeros(2, 2) old_objid = objectid(x) From b77ed52d122dac5074095b8101387f0f523c691e Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 20:37:35 +0000 Subject: [PATCH 57/60] Add JET tests to prevent regressions on base Julia thingy --- src/varname/optic.jl | 12 ++++++++++-- test/Project.toml | 2 ++ test/varname/optic.jl | 5 ++--- test/varname/varname.jl | 23 +++++++++++++++++++++++ 4 files changed, 37 insertions(+), 5 deletions(-) diff --git a/src/varname/optic.jl b/src/varname/optic.jl index 57f612fd..ba0de07e 100644 --- a/src/varname/optic.jl +++ b/src/varname/optic.jl @@ -206,8 +206,16 @@ function _tuple_eq_inner(t1::Base.Any32, t2::Base.Any32) end return anymissing ? missing : true end - -Base.:(==)(a::Index, b::Index) = _tuple_eq(a.ix, b.ix) && a.kw == b.kw && a.child == b.child +# Because the Base methods for NamedTuple equality rely on tuple equality, we also need to +# patch that :( +_nt_eq(a::NamedTuple{n}, b::NamedTuple{n}) where {n} = _tuple_eq(Tuple(a), Tuple(b)) +_nt_eq(a::NamedTuple, b::NamedTuple) = false + +# Note: Do NOT change this to `a.ix == b.ix`! This is a workaround for +# https://github.com/JuliaLang/julia/issues/60470 +function Base.:(==)(a::Index, b::Index) + return _tuple_eq(a.ix, b.ix) && _nt_eq(a.kw, b.kw) && a.child == b.child +end function Base.isequal(a::Index, b::Index) return isequal(a.ix, b.ix) && isequal(a.kw, b.kw) && isequal(a.child, b.child) end diff --git a/test/Project.toml b/test/Project.toml index d538d4db..79803fe6 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -5,6 +5,7 @@ DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" @@ -17,6 +18,7 @@ Aqua = "0.8" Distributions = "0.25" Documenter = "0.26.3, 0.27, 1" InvertedIndices = "1" +JET = "0.9, 0.10, 0.11" OffsetArrays = "1" OrderedCollections = "1" julia = "1" diff --git a/test/varname/optic.jl b/test/varname/optic.jl index fcb25a81..cfaa4238 100644 --- a/test/varname/optic.jl +++ b/test/varname/optic.jl @@ -11,7 +11,6 @@ using AbstractPPL @test string(@opticof(_)) == "Optic()" @test string(@opticof(_[begin])) == "Optic([DynamicIndex(begin)])" @test string(@opticof(_[2:end])) == "Optic([DynamicIndex(2:end)])" - @test string(with_mutation(@opticof(_.a.b.c))) == "Optic!!(.a.b.c)" @test string(with_mutation(@opticof(_[1][2][3]))) == "Optic!!([1][2][3])" @test string(with_mutation(@opticof(_))) == "Optic!!()" @@ -151,7 +150,7 @@ using AbstractPPL # not necessary... see # https://github.com/TuringLang/DynamicPPL.jl/issues/823#issuecomment-3166049286) x = Vector{Real}(undef, 3) - optic = opticof(_[2]) + optic = @opticof(_[2]) @test_throws UndefRefError optic(x) x2 = set(x, optic, 3.14) @test x2[2] == 3.14 @@ -245,7 +244,7 @@ using AbstractPPL # discussion. x = Vector{Real}(undef, 3) old_objid = objectid(x) - optic = with_mutation(opticof(_[2])) + optic = with_mutation(@opticof(_[2])) @test_throws UndefRefError optic(x) set(x, optic, 3.14) @test x[2] == 3.14 diff --git a/test/varname/varname.jl b/test/varname/varname.jl index c10c5624..65b2875d 100644 --- a/test/varname/varname.jl +++ b/test/varname/varname.jl @@ -2,6 +2,7 @@ module VarNameTests using AbstractPPL using Test +using JET: @test_call @testset "varname/varname.jl" verbose = true begin @testset "basic construction (and type stability)" begin @@ -59,6 +60,28 @@ using Test end end + @testset "JET on equality + dynamic dispatch" begin + # This test is very specific, so some context is needed: + # + # In DynamicPPL it's quite common to want to search for a VarName in a collection of + # VarNames. Usually the collection will not have a concrete element type (because + # it's a mixture of different optics). Thus, there will be a fair amount of dynamic + # dispatch when performing the comparisons. + # + # In AbstractPPL, there is custom code in `src/varname/optic.jl` to make sure that + # equality comparisons of `Index` lenses are JET-friendly even when this happens + # (i.e., JET.jl doesn't error on `@report_call`). These were needed because base + # Julia's equality methods on tuples error with JET: + # https://github.com/JuliaLang/julia/issues/60470, and using those default methods + # would cause test failures in DynamicPPLJETExt. + # + # This test therefore makes sure we don't cause any regressions. + vns = [@varname(x), @varname(x[1]), @varname(x.a)] + for vn in vns + @test_call any(k -> k == vn, vns) + end + end + @testset "pretty-printing" begin @test string(@varname(x)) == "x" @test string(@varname(x[1])) == "x[1]" From 15c6968efcfae49d07b583779cdd8ebc10a642b5 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 21:18:24 +0000 Subject: [PATCH 58/60] Add more tests to satisfy the coverage demon --- test/varname/optic.jl | 9 +++++++++ test/varname/subsumes.jl | 3 +++ test/varname/varname.jl | 7 +++++++ 3 files changed, 19 insertions(+) diff --git a/test/varname/optic.jl b/test/varname/optic.jl index cfaa4238..c4e3b736 100644 --- a/test/varname/optic.jl +++ b/test/varname/optic.jl @@ -8,11 +8,13 @@ using AbstractPPL @testset "pretty-printing" begin @test string(@opticof(_.a.b.c)) == "Optic(.a.b.c)" @test string(@opticof(_[1][2][3])) == "Optic([1][2][3])" + @test string(@opticof(_["a"][:b])) == "Optic([\"a\"][:b])" @test string(@opticof(_)) == "Optic()" @test string(@opticof(_[begin])) == "Optic([DynamicIndex(begin)])" @test string(@opticof(_[2:end])) == "Optic([DynamicIndex(2:end)])" @test string(with_mutation(@opticof(_.a.b.c))) == "Optic!!(.a.b.c)" @test string(with_mutation(@opticof(_[1][2][3]))) == "Optic!!([1][2][3])" + @test string(with_mutation(@opticof(_["a"][:b]))) == "Optic!!([\"a\"][:b])" @test string(with_mutation(@opticof(_))) == "Optic!!()" @test string(with_mutation(@opticof(_[begin]))) == "Optic!!([DynamicIndex(begin)])" @test string(with_mutation(@opticof(_[2:end]))) == "Optic!!([DynamicIndex(2:end)])" @@ -227,6 +229,13 @@ using AbstractPPL end @testset "mutating versions" begin + @testset "construction and equality" begin + @test with_mutation(@opticof(_.a.b)) != @opticof(_.a.b) + # check that there is no nesting + @test with_mutation(with_mutation(@opticof(_.a.b))) == + with_mutation(@opticof(_.a.b)) + end + @testset "arrays" begin @testset "static index" begin x = zeros(4) diff --git a/test/varname/subsumes.jl b/test/varname/subsumes.jl index 90d5dd2b..d4187fea 100644 --- a/test/varname/subsumes.jl +++ b/test/varname/subsumes.jl @@ -8,6 +8,7 @@ using Test @test subsumes(@varname(x), @varname(x)) @test subsumes(@varname(x[1]), @varname(x[1])) @test subsumes(@varname(x.a), @varname(x.a)) + @test subsumes(@varname(x[:]), @varname(x[:])) end uncomparable(vn1, vn2) = !subsumes(vn1, vn2) && !subsumes(vn2, vn1) @@ -59,6 +60,8 @@ using Test @test strictly_subsumes(@varname(x[:]), @varname(x[end])) @test strictly_subsumes(@varname(x[:]), @varname(x[1:end])) @test strictly_subsumes(@varname(x[:]), @varname(x[end - 3])) + @test uncomparable(@varname(x[begin]), @varname(x["a"])) + @test uncomparable(@varname(x[begin]), @varname(x[1:5])) end @testset "keyword indices" begin diff --git a/test/varname/varname.jl b/test/varname/varname.jl index 65b2875d..8ddb7802 100644 --- a/test/varname/varname.jl +++ b/test/varname/varname.jl @@ -131,6 +131,13 @@ using JET: @test_call arr = randn(4, 4) @test concretize(vn, arr) == @varname(x[1:16]) end + + @testset "nested" begin + vn = @varname(x.a[end].b) + @test vn isa VarName + @test is_dynamic(vn) + @test concretize(vn, (; a=[(; b=1)])) == @varname(x.a[1].b) + end end @testset "things that shouldn't be dynamic aren't dynamic" begin From bc36cdc503018772a187685e4964dcd3a868cdcb Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 21:43:04 +0000 Subject: [PATCH 59/60] Fix canview for DimArray --- src/varname/hasvalue.jl | 27 ++++++++++++++++++--------- test/varname/hasvalue.jl | 17 +++++++++++++++++ 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/src/varname/hasvalue.jl b/src/varname/hasvalue.jl index 4dd71e7d..96a1246e 100644 --- a/src/varname/hasvalue.jl +++ b/src/varname/hasvalue.jl @@ -18,18 +18,17 @@ julia> AbstractPPL.canview(@opticof(_.a[3]), (a = [1.0, 2.0], )) # out of bounds false ``` """ -canview(optic, container) = false +canview(::AbstractOptic, ::Any) = false canview(::Iden, ::Any) = true function canview(prop::Property{field}, x) where {field} return hasproperty(x, field) && canview(prop.child, getproperty(x, field)) end function canview(optic::Index, x::AbstractArray) - # TODO(penelopeysm): `checkbounds` doesn't work with keyword arguments for - # DimArray. Hence if we have keyword arguments, we just return false for now. - # https://github.com/rafaqz/DimensionalData.jl/issues/1156 - return isempty(optic.kw) && - checkbounds(Bool, x, optic.ix...) && - canview(optic.child, getindex(x, optic.ix...)) + return if isempty(optic.kw) + checkbounds(Bool, x, optic.ix...) && canview(optic.child, getindex(x, optic.ix...)) + else + _canview_fallback_kw(optic, x) + end end # Handle some other edge cases. function canview(optic::Index, x::AbstractDict) @@ -44,8 +43,18 @@ function canview(optic::Index, x::NamedTuple) haskey(x, only(optic.ix)) && canview(optic.child, x[only(optic.ix)]) end -# Give up on all other edge cases. -canview(optic::Index, x) = false +# For cases where there are keyword arguments, we don't have much of a choice but to +# try/catch. For example, this will be hit when using keyword arguments for DimArray. +# However, there's no `checkbounds` method (yet): +# https://github.com/rafaqz/DimensionalData.jl/issues/1156 +function _canview_fallback_kw(optic::Index, x) + try + v = getindex(x, optic.ix...; optic.kw...) + return canview(optic.child, v) + catch + return false + end +end """ getvalue(vals::NamedTuple, vn::VarName) diff --git a/test/varname/hasvalue.jl b/test/varname/hasvalue.jl index 2799bfc8..6914d0ff 100644 --- a/test/varname/hasvalue.jl +++ b/test/varname/hasvalue.jl @@ -1,6 +1,7 @@ module VarNameHasValueTests using AbstractPPL +using DimensionalData: DimensionalData as DD using Test @testset "base getvalue + hasvalue" begin @@ -161,6 +162,22 @@ using Test @test !hasvalue(d, @varname(x[2][1][1][1])) end end + + @testset "DimArray and keyword indices" begin + x = (; a=DD.DimArray(randn(2, 3), (:i, :j))) + @test hasvalue(x, @varname(a)) + @test getvalue(x, @varname(a)) == x.a + @test hasvalue(x, @varname(a[1, 2])) + @test getvalue(x, @varname(a[1, 2])) == x.a[1, 2] + @test hasvalue(x, @varname(a[:])) + @test getvalue(x, @varname(a[:])) == x.a[:] + @test canview(@opticof(_[i=1]), x.a) + @test hasvalue(x, @varname(a[i=1])) + @test getvalue(x, @varname(a[i=1])) == x.a[i=1] + @test canview(@opticof(_[i=1, j=2]), x.a) + @test hasvalue(x, @varname(a[i=1, j=2])) + @test getvalue(x, @varname(a[i=1, j=2])) == x.a[i=1, j=2] + end end @testset "with Distributions: getvalue + hasvalue" begin From 2571d7d0e225812b4a751ebed1f42bae53a5a616 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Dec 2025 21:45:31 +0000 Subject: [PATCH 60/60] even. more. tests. --- test/varname/hasvalue.jl | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/test/varname/hasvalue.jl b/test/varname/hasvalue.jl index 6914d0ff..a0b03b3d 100644 --- a/test/varname/hasvalue.jl +++ b/test/varname/hasvalue.jl @@ -163,7 +163,7 @@ using Test end end - @testset "DimArray and keyword indices" begin + @testset "DimArray indices (including keyword)" begin x = (; a=DD.DimArray(randn(2, 3), (:i, :j))) @test hasvalue(x, @varname(a)) @test getvalue(x, @varname(a)) == x.a @@ -177,6 +177,20 @@ using Test @test canview(@opticof(_[i=1, j=2]), x.a) @test hasvalue(x, @varname(a[i=1, j=2])) @test getvalue(x, @varname(a[i=1, j=2])) == x.a[i=1, j=2] + @test hasvalue(x, @varname(a[i=DD.Not(1)])) + @test getvalue(x, @varname(a[i=DD.Not(1)])) == x.a[i=DD.Not(1)] + + y = (; b=DD.DimArray(randn(2, 3), (DD.X, DD.Y))) + @test hasvalue(y, @varname(b)) + @test getvalue(y, @varname(b)) == y.b + @test hasvalue(y, @varname(b[1, 2])) + @test getvalue(y, @varname(b[1, 2])) == y.b[1, 2] + @test hasvalue(y, @varname(b[:])) + @test getvalue(y, @varname(b[:])) == y.b[:] + @test hasvalue(y, @varname(b[DD.X(1)])) + @test getvalue(y, @varname(b[DD.X(1)])) == y.b[DD.X(1)] + @test hasvalue(y, @varname(b[DD.X(1), DD.Y(2)])) + @test getvalue(y, @varname(b[DD.X(1), DD.Y(2)])) == y.b[DD.X(1), DD.Y(2)] end end