diff --git a/HISTORY.md b/HISTORY.md index 35f1dc8a..de69ebe4 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,56 @@ +## 0.14.0 + +This release overhauls the `VarName` type. +Much of the external API for traversing and manipulating `VarName`s (once they have been constructed) has been preserved, but if you use the `VarName` type directly, there are significant changes. + +**Internal representation** + +The `optic` field of VarName now uses our hand-rolled optic types, which are subtypes of `AbstractPPL.AbstractOptic`. +Previously these were optics from Accessors.jl. + +This change was made for two reasons: firstly, it is easier to provide custom behaviour for VarNames as we avoid running into possible type piracy issues, and secondly, the linked-list data structure used in `AbstractOptic` is easier to work with than Accessors.jl, which used `Base.ComposedFunction` to represent optic compositions and required a lot of care to avoid a litany of issues with associativity and identity optics (see e.g. https://github.com/JuliaLang/julia/pull/54877). + +To construct an optic, the easiest way is to use the `@opticof` macro, which superficially behaves similarly to `Accessors.@optic` (for example, you can write `@opticof _[1].y.z`), but also supports automatic concretization by passing a second parameter (just like `@varname`). + +**Concretization** + +VarNames using 'dynamic' indices, i.e., `begin` and `end`, are now instantiated in a 'dynamic' form, meaning that these indices are unresolved. +These indices need to be resolved, or concretized, against the actual container. +For example, `@varname(x[end])` is dynamic, but when concretized against `x = randn(3)`, this becomes `@varname(x[3])`. +This can be done using `concretize(varname, x)`. + +The idea of concretization is not new to AbstractPPL. +However, there are some differences: + + - Colons are no longer concretized: they *always* remain as Colons, even after calling `concretize`. + - Previously, AbstractPPL would refuse to allow you to construct unconcretized versions of `begin` and `end`. This is no longer the case; you can now create such VarNames in their unconcretized forms. + This is useful, for example, when indexing into a chain that contains `x` as a variable-length vector. This change allows you to write `chain[@varname(x[end])]` without having AbstractPPL throw an error. + +**Keyword arguments to `getindex`** + +VarNames can now be constructed with keyword arguments in `Index` optics, for example `@varname(x[i=1])`. +This is specifically implemented to support DimensionalData.jl's DimArrays. + +**Other interface functions** + +The `vsym` function (and `@vsym`) has been removed; you should use `getsym(vn)` instead. + +The `Base.get` and `Accessors.set` methods for VarNames have been removed (these were responsible for method ambiguities). +Instead of using these methods you can first convert the `VarName` to an optic using `varname_to_optic(vn)`, and then use the getter and setter methods on the optics. + +VarNames cannot be composed with optics now (compose the optics yourself). + +The `inspace` function has been removed. +It used to be relevant for Turing's old Gibbs sampler; but now it no longer serves any use. + +`ConcretizedSlice` has been removed (since colons are no longer concretized). + +The subsumption interface has been pared down to just a single function, `subsumes`. +All other functions, such as `subsumedby`, `uncomparable`, and the Unicode operators, have been removed. + +Serialization still works exactly as before. +However, you will see differences in the serialization output compared to previous versions, due to the changes in the internal structure. + ## 0.13.6 Fix a missing qualifier in AbstractPPLDistributionsExt. diff --git a/Project.toml b/Project.toml index d6db8338..de3ab112 100644 --- a/Project.toml +++ b/Project.toml @@ -3,14 +3,16 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" keywords = ["probablistic programming"] license = "MIT" desc = "Common interfaces for probabilistic programming" -version = "0.13.6" +version = "0.14.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" @@ -18,15 +20,17 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" [extensions] -AbstractPPLDistributionsExt = ["Distributions"] +AbstractPPLDistributionsExt = ["Distributions", "LinearAlgebra"] [compat] AbstractMCMC = "2, 3, 4, 5" Accessors = "0.1" +BangBang = "0.4" DensityInterface = "0.4" Distributions = "0.25" JSON = "0.19 - 0.21, 1" -LinearAlgebra = "<0.0.1, 1.10" +LinearAlgebra = "<0.0.1, 1" +MacroTools = "0.5" Random = "1.6" StatsBase = "0.32, 0.33, 0.34" -julia = "1.10" +julia = "1.10.8" diff --git a/docs/Project.toml b/docs/Project.toml index 15b2ec43..9aed942a 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -3,3 +3,6 @@ AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[sources] +AbstractPPL = {path = "../"} diff --git a/docs/make.jl b/docs/make.jl index a94901da..abab3f69 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -9,7 +9,7 @@ DocMeta.setdocmeta!(AbstractPPL, :DocTestSetup, :(using AbstractPPL); recursive= makedocs(; sitename="AbstractPPL", modules=[AbstractPPL, Base.get_extension(AbstractPPL, :AbstractPPLDistributionsExt)], - pages=["index.md", "api.md", "interface.md"], + pages=["index.md", "varname.md", "pplapi.md", "interface.md"], checkdocs=:exports, doctest=false, ) diff --git a/docs/src/api.md b/docs/src/api.md deleted file mode 100644 index 9f289ab5..00000000 --- a/docs/src/api.md +++ /dev/null @@ -1,64 +0,0 @@ -# API - -## VarNames - -```@docs -VarName -getsym -getoptic -inspace -subsumes -subsumedby -vsym -@varname -@vsym -``` - -## VarName prefixing and unprefixing - -```@docs -prefix -unprefix -``` - -## Extracting values corresponding to a VarName - -```@docs -hasvalue -getvalue -``` - -## Splitting VarNames up into components - -```@docs -varname_leaves -varname_and_value_leaves -``` - -## VarName serialisation - -```@docs -index_to_dict -dict_to_index -varname_to_string -string_to_varname -``` - -## Abstract model functions - -```@docs -AbstractProbabilisticProgram -condition -decondition -fix -unfix -logdensityof -AbstractContext -evaluate!! -``` - -## Abstract traces - -```@docs -AbstractModelTrace -``` diff --git a/docs/src/pplapi.md b/docs/src/pplapi.md new file mode 100644 index 00000000..492ab294 --- /dev/null +++ b/docs/src/pplapi.md @@ -0,0 +1,20 @@ +# Probabilistic programming API + +## Abstract model functions + +```@docs +AbstractProbabilisticProgram +condition +decondition +fix +unfix +logdensityof +AbstractContext +evaluate!! +``` + +## Abstract traces + +```@docs +AbstractModelTrace +``` diff --git a/docs/src/varname.md b/docs/src/varname.md new file mode 100644 index 00000000..88672ba5 --- /dev/null +++ b/docs/src/varname.md @@ -0,0 +1,257 @@ +# VarNames and optics + +## VarNames: an overview + +One of the most important parts of AbstractPPL.jl is the `VarName` type, which is used throughout the TuringLang ecosystem to represent names of random variables. + +Fundamentally, a `VarName` comprises a symbol (which represents the name of the variable itself) and an optic (which tells us which part of the variable we might be interested in). +For example, `x.a[1]` means the first element of the field `a` of the variable `x`. +Here, `x` is the symbol, and `.a[1]` is the optic. + +VarNames can be created using the `@varname` macro: + +```@example vn +using AbstractPPL + +vn = @varname(x.a[1]) +``` + +```@docs +VarName +@varname +varname +``` + +You can obtain the components of a `VarName` using the `getsym` and `getoptic` functions: + +```@example vn +getsym(vn), getoptic(vn) +``` + +```@docs +getsym +getoptic +``` + +## Dynamic indices + +VarNames may contain 'dynamic' indices, that is, indices whose meaning is not known until they are resolved against a specific value. +For example, `x[end]` refers to the last element of `x`; but we don't know what that means until we know what `x` is. + +Specifically, `begin` and `end` symbols in indices are treated as dynamic indices. +This is also true for any expression that contains `begin` or `end`, such as `end-1` or `1:3:end`. + +Dynamic indices are represented using an internal type, `AbstractPPL.DynamicIndex`. + +```@example vn +vn_dyn = @varname(x[1:2:end]) +``` + +You can detect whether a VarName contains dynamic indices using the `is_dynamic` function: + +```@example vn +is_dynamic(vn_dyn) +``` + +```@docs +is_dynamic +``` + +These dynamic indices can be resolved, or _concretized_, by passing a specific value to the `concretize` function: + +```@example vn +x = randn(5) +vn_conc = concretize(vn_dyn, x) +``` + +```@docs +concretize +``` + +## Optics + +The optics used in AbstractPPL.jl are represented as a linked list. +For example, the optic `.a[1]` is a `Property` optic that contains an `Index` optic as its child. +That means that the 'elements' of the linked list can be read from left-to-right: + +``` +Property{:a} -> Index{1} -> Iden +``` + +All optic linked lists are terminated with an `Iden` optic, which represents the identity function. + +```@example vn +optic = getoptic(@varname x.a[1]) +dump(optic) +``` + +```@docs +AbstractOptic +Property +Index +Iden +``` + +Instead of calling `getoptic(@varname(...))`, you can directly use the [`@opticof`](@ref) macro to create optics: + +```@example vn +optic = @opticof(_.a[1]) +``` + +```@docs +@opticof +``` + +## Getting and setting + +Optics are callable structs, and when passed a value will extract the relevant part of that +value. + +```@example vn +data = (a=[10, 20, 30], b="hello") +optic = @opticof(_.a[2]) +optic(data) +``` + +You can set values using `Accessors.set` (which AbstractPPL re-exports). +Note, though, that this will not mutate the original value. +Furthermore, you cannot use the handy macros like `Accessors.@set`, since those will use the +optics from Accessors.jl. + +```@example vn +new_data = set(data, optic, 99) +new_data, data +``` + +If you want to try to mutate values, you can wrap an optic using `with_mutation`. + +```@example vn +optic_mut = with_mutation(optic) +set(data, optic_mut, 99) +data +``` + +```@docs +with_mutation +``` + +## Composing and decomposing optics + +If you have two optics, you can compose them using the `∘` operator: + +```@example vn +optic1 = @opticof(_.a) +optic2 = @opticof(_[1]) +composed = optic2 ∘ optic1 +``` + +Notice the order of composition here, which can be counterintuitive: `optic2 ∘ optic1` means "first apply `optic1`, then apply `optic2`", and thus this represents the optic `.a[1]` (not `.[1].a`). + +```@docs +Base.:∘(::AbstractOptic, ::AbstractOptic) +``` + +`Base.cat(optics...)` is also provided, which composes optics in a more intuitive sense (indeed, if you think of an optic as a linked list, this can be thought of as concatenating the lists). +The following is equivalent to the previous example: + +```@example vn +composed2 = Base.cat(optic1, optic2) +``` + +```@docs +Base.cat(::AbstractOptic...) +``` + +Several functions are provided to decompose optics, which all stem from their linked-list structure. +Their names directly mirror Haskell's functions for decomposing lists, but are prefixed with `o`: + +```@docs +ohead +otail +oinit +olast +``` + +For example, `ohead` returns the first element of the optic linked list, and `otail` returns the rest of the list after removing the head: + +```@example vn +optic = @opticof(_.a[1].b[2]) +ohead(optic), otail(optic) +``` + +Convesely, `oinit` returns the optic linked list without its last element, and `olast` returns the last element: + +```@example vn +oinit(optic), olast(optic) +``` + +If the optic only has a single element, then `oinit` and `otail` return `Iden`, while `ohead` and `olast` return the optic itself: + +```@example vn +optic_single = @opticof(_.a) +oinit(optic_single), olast(optic_single), ohead(optic_single), otail(optic_single) +``` + +## Converting VarNames to optics and back + +Sometimes it is useful to treat a VarName's top level symbol as if it were part of the optic. +For example, when indexing into a NamedTuple `nt`, we might want to treat the entire VarName `x.a[1]` as an optic that can be applied to a NamedTuple: i.e., we want to access the `nt.x` field rather than the variable `x` itself. +This can be achieved with: + +```@docs +varname_to_optic +optic_to_varname +``` + +## Subsumption + +Sometimes, we want to check whether one VarName 'subsumes' another; that is, whether a VarName refers to a part of another VarName. +This is done using the [`subsumes`](@ref) function: + +```@example vn +vn1 = @varname(x.a) +vn2 = @varname(x.a[1]) +subsumes(vn1, vn2) +``` + +```@docs +subsumes +``` + +## Prefixing and unprefixing + +Composing two optics can be done using the `∘` operator, as shown above. +But what if we want to compose two `VarName`s? +This is used, for example, in DynamicPPL's submodel functionality. + +```@docs +prefix +unprefix +``` + +## VarName leaves + +The following functions are used to extract the 'leaves' of a VarName, that is, the atomic components of a VarName that do not have any further substructure. +For example, for a vector variable `x`, the leaves would be `x[1]`, `x[2]`, etc. + +```@docs +varname_leaves +varname_and_value_leaves +``` + +## Reading from a container with a VarName (or optic) + +```@docs +canview +hasvalue +getvalue +``` + +## Serializing VarNames + +```@docs +index_to_dict +dict_to_index +varname_to_string +string_to_varname +``` diff --git a/ext/AbstractPPLDistributionsExt.jl b/ext/AbstractPPLDistributionsExt.jl index d10b748c..80995365 100644 --- a/ext/AbstractPPLDistributionsExt.jl +++ b/ext/AbstractPPLDistributionsExt.jl @@ -49,36 +49,10 @@ This decision may be revisited in the future. module AbstractPPLDistributionsExt -using AbstractPPL: AbstractPPL, VarName, Accessors, LinearAlgebra +using AbstractPPL: AbstractPPL, VarName, Accessors using Distributions: Distributions using LinearAlgebra: Cholesky, LowerTriangular, UpperTriangular -#= -This section is copied from Accessors.jl's documentation: -https://juliaobjects.github.io/Accessors.jl/stable/examples/custom_macros/ - -It defines a wrapper that, when called with `set`, mutates the original value -rather than returning a new value. We need this because the non-mutating optics -don't work for triangular matrices (and hence LKJCholesky): see -https://github.com/JuliaObjects/Accessors.jl/issues/203 -=# -struct Lens!{L} - pure::L -end -(l::Lens!)(o) = l.pure(o) -function Accessors.set(o, l::Lens!{<:ComposedFunction}, val) - o_inner = l.pure.inner(o) - return Accessors.set(o_inner, Lens!(l.pure.outer), val) -end -function Accessors.set(o, l::Lens!{Accessors.PropertyLens{prop}}, val) where {prop} - setproperty!(o, prop, val) - return o -end -function Accessors.set(o, l::Lens!{<:Accessors.IndexLens}, val) - o[l.pure.indices...] = val - return o -end - """ get_optics(dist::MultivariateDistribution) get_optics(dist::MatrixDistribution) @@ -90,7 +64,7 @@ function get_optics( dist::Union{Distributions.MultivariateDistribution,Distributions.MatrixDistribution} ) indices = CartesianIndices(size(dist)) - return map(idx -> Accessors.IndexLens(idx.I), indices) + return map(idx -> AbstractPPL.Index(idx.I, (;), AbstractPPL.Iden()), indices) end function get_optics(dist::Distributions.LKJCholesky) is_up = dist.uplo == 'U' @@ -100,8 +74,14 @@ function get_optics(dist::Distributions.LKJCholesky) end # there is an additional layer as we need to access `.L` or `.U` before we # can index into it - field_lens = is_up ? (Accessors.@o _.U) : (Accessors.@o _.L) - return map(idx -> Accessors.IndexLens(idx.I) ∘ field_lens, cartesian_indices) + function make_lens(idx) + if is_up + AbstractPPL.Property{:U}(AbstractPPL.Index(idx.I, (;), AbstractPPL.Iden())) + else + AbstractPPL.Property{:L}(AbstractPPL.Index(idx.I, (;), AbstractPPL.Iden())) + end + end + return map(make_lens, cartesian_indices) end """ @@ -311,10 +291,10 @@ function AbstractPPL.getvalue( # Retrieve the value of this given index sub_value = AbstractPPL.getvalue(vals, sub_vn) # Set it inside the value we're reconstructing. - # Note: `o` is normally non-mutating. We have to wrap it in `Lens!` - # to make it mutating, because Cholesky distributions are broken - # by https://github.com/JuliaObjects/Accessors.jl/issues/203. - Accessors.set(value, Lens!(o), sub_value) + # Note: `o` is normally non-mutating. We have to use the mutating version, + # because Cholesky distributions are broken by + # https://github.com/JuliaObjects/Accessors.jl/issues/203. + Accessors.set(value, AbstractPPL.with_mutation(o), sub_value) end return value else diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index d7585784..8db124eb 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -1,27 +1,5 @@ module AbstractPPL -# VarName -export VarName, - getsym, - getoptic, - inspace, - subsumes, - subsumedby, - varname, - vsym, - @varname, - @vsym, - index_to_dict, - dict_to_index, - varname_to_string, - string_to_varname, - prefix, - unprefix, - getvalue, - hasvalue, - varname_leaves, - varname_and_value_leaves - # Abstract model functions export AbstractProbabilisticProgram, condition, decondition, fix, unfix, logdensityof, densityof, AbstractContext, evaluate!! @@ -32,6 +10,7 @@ export AbstractModelTrace include("abstractmodeltrace.jl") include("abstractprobprog.jl") include("evaluate.jl") +include("varname/optic.jl") include("varname/varname.jl") include("varname/subsumes.jl") include("varname/hasvalue.jl") @@ -39,4 +18,44 @@ include("varname/leaves.jl") include("varname/prefix.jl") include("varname/serialize.jl") +# Optics +export AbstractOptic, + Iden, + Index, + Property, + with_mutation, + ohead, + otail, + olast, + oinit, + # VarName + VarName, + getsym, + getoptic, + concretize, + is_dynamic, + @varname, + varname, + @opticof, + varname_to_optic, + optic_to_varname, + # other functions + subsumes, + prefix, + unprefix, + hasvalue, + getvalue, + canview, + varname_leaves, + varname_and_value_leaves, + # Serialisation + index_to_dict, + dict_to_index, + varname_to_string, + string_to_varname + +# Convenience re-export +using Accessors: set +export set + end # module diff --git a/src/varname/hasvalue.jl b/src/varname/hasvalue.jl index 31c5f098..96a1246e 100644 --- a/src/varname/hasvalue.jl +++ b/src/varname/hasvalue.jl @@ -4,36 +4,56 @@ Return `true` if `optic` can be used to view `container`, and `false` otherwise. # Examples -```jldoctest; setup=:(using Accessors) -julia> AbstractPPL.canview(@o(_.a), (a = 1.0, )) +```jldoctest +julia> AbstractPPL.canview(@opticof(_.a), (a = 1.0, )) true -julia> AbstractPPL.canview(@o(_.a), (b = 1.0, )) # property `a` does not exist +julia> AbstractPPL.canview(@opticof(_.a), (b = 1.0, )) # property `a` does not exist false -julia> AbstractPPL.canview(@o(_.a[1]), (a = [1.0, 2.0], )) +julia> AbstractPPL.canview(@opticof(_.a[1]), (a = [1.0, 2.0], )) true -julia> AbstractPPL.canview(@o(_.a[3]), (a = [1.0, 2.0], )) # out of bounds +julia> AbstractPPL.canview(@opticof(_.a[3]), (a = [1.0, 2.0], )) # out of bounds false ``` """ -canview(optic, container) = false -canview(::typeof(identity), _) = true -function canview(::Accessors.PropertyLens{field}, x) where {field} - return hasproperty(x, field) +canview(::AbstractOptic, ::Any) = false +canview(::Iden, ::Any) = true +function canview(prop::Property{field}, x) where {field} + return hasproperty(x, field) && canview(prop.child, getproperty(x, field)) end - -# `IndexLens`: only relevant if `x` supports indexing. -canview(optic::Accessors.IndexLens, x) = false -function canview(optic::Accessors.IndexLens, x::AbstractArray) - return checkbounds(Bool, x, optic.indices...) +function canview(optic::Index, x::AbstractArray) + return if isempty(optic.kw) + checkbounds(Bool, x, optic.ix...) && canview(optic.child, getindex(x, optic.ix...)) + else + _canview_fallback_kw(optic, x) + end end - -# `ComposedFunction`: check that we can view `.inner` and `.outer`, but using -# value extracted using `.inner`. -function canview(optic::ComposedFunction, x) - return canview(optic.inner, x) && canview(optic.outer, optic.inner(x)) +# Handle some other edge cases. +function canview(optic::Index, x::AbstractDict) + return isempty(optic.kw) && + optic.ix.length == 1 && + haskey(x, only(optic.ix)) && + canview(optic.child, x[only(optic.ix)]) +end +function canview(optic::Index, x::NamedTuple) + return isempty(optic.kw) && + optic.ix.length == 1 && + haskey(x, only(optic.ix)) && + canview(optic.child, x[only(optic.ix)]) +end +# For cases where there are keyword arguments, we don't have much of a choice but to +# try/catch. For example, this will be hit when using keyword arguments for DimArray. +# However, there's no `checkbounds` method (yet): +# https://github.com/rafaqz/DimensionalData.jl/issues/1156 +function _canview_fallback_kw(optic::Index, x) + try + v = getindex(x, optic.ix...; optic.kw...) + return canview(optic.child, v) + catch + return false + end end """ @@ -153,21 +173,21 @@ function getvalue(vals::AbstractDict{<:VarName}, vn::VarName{sym}) where {sym} # `x[1][2]`, we start by checking if `x[1]` is present, then `x`). We will # then keep removing optics from `test_optic`, either until we find a key # that is present, or until we run out of optics to test (which happens - # after getoptic(test_vn) == identity). + # after getoptic(test_vn) isa Iden). o = getoptic(vn) - test_vn = VarName{sym}(_init(o)) - test_optic = _last(o) + test_vn = VarName{sym}(oinit(o)) + test_optic = olast(o) while true if haskey(vals, test_vn) && canview(test_optic, vals[test_vn]) return test_optic(vals[test_vn]) else # Try to move the outermost optic from test_vn into test_optic. - # If test_vn is already an identity, we can't, so we stop. + # If test_vn is already an Iden, we can't, so we stop. o = getoptic(test_vn) - o == identity && error("$(vn) was not found in the dictionary provided") - test_vn = VarName{sym}(_init(o)) - test_optic = normalise(test_optic ∘ _last(o)) + o isa Iden && error("$(vn) was not found in the dictionary provided") + test_vn = VarName{sym}(oinit(o)) + test_optic = test_optic ∘ olast(o) end end end @@ -246,21 +266,21 @@ function hasvalue(vals::AbstractDict{<:VarName}, vn::VarName{sym}) where {sym} # `x[1][2]`, we start by checking if `x[1]` is present, then `x`). We will # then keep removing optics from `test_optic`, either until we find a key # that is present, or until we run out of optics to test (which happens - # after getoptic(test_vn) == identity). + # after getoptic(test_vn) == Iden()). o = getoptic(vn) - test_vn = VarName{sym}(_init(o)) - test_optic = _last(o) + test_vn = VarName{sym}(oinit(o)) + test_optic = olast(o) while true if haskey(vals, test_vn) && canview(test_optic, vals[test_vn]) return true else # Try to move the outermost optic from test_vn into test_optic. - # If test_vn is already an identity, we can't, so we stop. + # If test_vn is already an Iden, we can't, so we stop. o = getoptic(test_vn) - o == identity && return false - test_vn = VarName{sym}(_init(o)) - test_optic = normalise(test_optic ∘ _last(o)) + o isa Iden && return false + test_vn = VarName{sym}(oinit(o)) + test_optic = test_optic ∘ olast(o) end end return false diff --git a/src/varname/leaves.jl b/src/varname/leaves.jl index ca7418ed..6721e51f 100644 --- a/src/varname/leaves.jl +++ b/src/varname/leaves.jl @@ -1,5 +1,4 @@ using LinearAlgebra: LinearAlgebra - """ varname_leaves(vn::VarName, val) @@ -28,41 +27,41 @@ x.z[2][1] varname_leaves(vn::VarName, ::Real) = [vn] function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) return ( - VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)) for + VarName{getsym(vn)}(Index(Tuple(I), (;), Iden()) ∘ getoptic(vn)) for I in CartesianIndices(val) ) end function varname_leaves(vn::VarName, val::AbstractArray) return Iterators.flatten( varname_leaves( - VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), val[I] + VarName{getsym(vn)}(Index(Tuple(I), (;), Iden()) ∘ getoptic(vn)), val[I] ) for I in CartesianIndices(val) ) end function varname_leaves(vn::VarName, val::NamedTuple) iter = Iterators.map(keys(val)) do k - optic = Accessors.PropertyLens{k}() + optic = Property{k}(Iden()) varname_leaves(VarName{getsym(vn)}(optic ∘ getoptic(vn)), optic(val)) end return Iterators.flatten(iter) end function varname_leaves(vn::VarName, val::LinearAlgebra.Cholesky) return if val.uplo == 'L' - optic = Accessors.PropertyLens{:L}() + optic = Property{:L}(Iden()) varname_leaves(VarName{getsym(vn)}(optic ∘ getoptic(vn)), val.L) else - optic = Accessors.PropertyLens{:U}() + optic = Property{:U}(Iden()) varname_leaves(VarName{getsym(vn)}(optic ∘ getoptic(vn)), val.U) end end function varname_leaves(vn::VarName, x::LinearAlgebra.LowerTriangular) return Iterators.map(Iterators.filter(I -> I[1] >= I[2], CartesianIndices(x))) do I - VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)) + VarName{getsym(vn)}(Index(Tuple(I), (;), Iden()) ∘ getoptic(vn)) end end function varname_leaves(vn::VarName, x::LinearAlgebra.UpperTriangular) return Iterators.map(Iterators.filter(I -> I[1] <= I[2], CartesianIndices(x))) do I - VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)) + VarName{getsym(vn)}(Index(Tuple(I), (;), Iden()) ∘ getoptic(vn)) end end # This is a default fallback for types that we don't know how to handle. @@ -215,7 +214,7 @@ function varname_and_value_leaves_inner( ) return ( Leaf( - VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ AbstractPPL.getoptic(vn)), + VarName{getsym(vn)}(Index(Tuple(I), (;), Iden()) ∘ AbstractPPL.getoptic(vn)), val[I], ) for I in CartesianIndices(val) ) @@ -224,14 +223,14 @@ end function varname_and_value_leaves_inner(vn::VarName, val::AbstractArray) return Iterators.flatten( varname_and_value_leaves_inner( - VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ AbstractPPL.getoptic(vn)), + VarName{getsym(vn)}(Index(Tuple(I), (;), Iden()) ∘ AbstractPPL.getoptic(vn)), val[I], ) for I in CartesianIndices(val) ) end function varname_and_value_leaves_inner(vn::VarName, val::NamedTuple) iter = Iterators.map(keys(val)) do k - optic = Accessors.PropertyLens{k}() + optic = Property{k}(Iden()) varname_and_value_leaves_inner( VarName{getsym(vn)}(optic ∘ getoptic(vn)), optic(val) ) @@ -242,21 +241,25 @@ end # Special types. function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.Cholesky) return if x.uplo == 'L' - varname_and_value_leaves_inner(Accessors.PropertyLens{:L}() ∘ vn, x.L) + varname_and_value_leaves_inner( + VarName{getsym(vn)}(Property{:L}(Iden()) ∘ getoptic(vn)), x.L + ) else - varname_and_value_leaves_inner(Accessors.PropertyLens{:U}() ∘ vn, x.U) + varname_and_value_leaves_inner( + VarName{getsym(vn)}(Property{:U}(Iden()) ∘ getoptic(vn)), x.U + ) end end function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.LowerTriangular) return ( - Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), x[I]) + Leaf(VarName{getsym(vn)}(Index(Tuple(I), (;), Iden()) ∘ getoptic(vn)), x[I]) # Iteration over the lower-triangular indices. for I in CartesianIndices(x) if I[1] >= I[2] ) end function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.UpperTriangular) return ( - Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), x[I]) + Leaf(VarName{getsym(vn)}(Index(Tuple(I), (;), Iden()) ∘ getoptic(vn)), x[I]) # Iteration over the upper-triangular indices. for I in CartesianIndices(x) if I[1] <= I[2] ) diff --git a/src/varname/optic.jl b/src/varname/optic.jl new file mode 100644 index 00000000..ba0de07e --- /dev/null +++ b/src/varname/optic.jl @@ -0,0 +1,547 @@ +using Accessors: Accessors +using BangBang: BangBang +using MacroTools: MacroTools + +""" + AbstractOptic + +An abstract type that represents the non-symbol part of a VarName, i.e., the section of the +variable that is of interest. For example, in `x.a[1][2]`, the `AbstractOptic` represents +the `.a[1][2]` part. +""" +abstract type AbstractOptic end +function Base.show(io::IO, optic::AbstractOptic) + print(io, "Optic(") + _pretty_print_optic(io, optic) + return print(io, ")") +end +# Lots of (==) and isequal methods as a workaround for +# https://github.com/JuliaLang/julia/issues/60470 :( +Base.:(==)(a::AbstractOptic, b::AbstractOptic) = false +Base.isequal(a::AbstractOptic, b::AbstractOptic) = false + +""" + Iden() + +The identity optic. This is the optic used when we are referring to the entire variable. +It is also the base case for composing optics. +""" +struct Iden <: AbstractOptic end +_pretty_print_optic(::IO, ::Iden) = nothing +is_dynamic(::Iden) = false +concretize(i::Iden, ::Any) = i +(::Iden)(obj) = obj +Base.:(==)(::Iden, ::Iden) = true +Base.isequal(::Iden, ::Iden) = true +Accessors.set(obj::Any, ::Iden, val) = Accessors.set(obj, identity, val) + +""" + DynamicIndex + +An abstract type representing dynamic indices such as `begin`, `end`, and `:`. These indices +are things which cannot be resolved until we provide the value that is being indexed into. +When parsing VarNames, we convert such indices into subtypes of `DynamicIndex`. + +Because a `DynamicIndex` cannot be resolved until we have the value being indexed into, it +is actually a wrapper around a function that, when called on the value, returns the concrete +index. + +For example: + +- the index `begin` is turned into `DynamicIndex(:begin, (val) -> Base.firstindex(val))`. +- the index `1:end` is turned into `DynamicIndex(:(1:end), (val) -> 1:Base.lastindex(val))`. + +# The stored `Expr` + +The `expr` field stores the original expression and is used both for pretty-printing as well +as comparisons. + +Note that because the stored function `f` is an anonymous function that is generated +dynamically, we should not include it in equality comparisons, as two functions that are +actually equivalent will not compare equal: + +```julia +julia> (x -> x + 1) == (x -> x + 1) +false +``` + +But, thankfully, we can just compare the `expr` field to determine whether the +DynamicIndices were constructed from the same expression (which implies that their +functions are equivalent). + +Note that these definitions also allow us some degree of resilience towards whitespace +changes, or parenthesisation, in the original expression. For example, `begin+1` and `(begin ++ 1)` will be treated as the same expression. However, it does not handle commutative +expressions; e.g., `begin + 1` and `1 + begin` will be treated as different expressions. +""" +struct DynamicIndex{E<:Union{Expr,Symbol},F} + expr::E + f::F +end +Base.:(==)(a::DynamicIndex, b::DynamicIndex) = a.expr == b.expr +Base.isequal(a::DynamicIndex, b::DynamicIndex) = isequal(a.expr, b.expr) +Base.hash(di::DynamicIndex, h::UInt) = hash(di.expr, h) + +function _make_dynamicindex_expr(symbol::Symbol, dim::Union{Nothing,Int}) + # NOTE(penelopeysm): We could just use `:end` instead of Symbol(:end), but the former + # messes up syntax highlighting with Treesitter + # https://github.com/tree-sitter/tree-sitter-julia/issues/104 + if symbol === Symbol(:begin) + func = dim === nothing ? :(Base.firstindex) : :(Base.Fix2(firstindex, $dim)) + return :($(DynamicIndex)($(QuoteNode(symbol)), $func)) + elseif symbol === Symbol(:end) + func = dim === nothing ? :(Base.lastindex) : :(Base.Fix2(lastindex, $dim)) + return :($(DynamicIndex)($(QuoteNode(symbol)), $func)) + else + # Just a variable; but we need to escape it to allow interpolation. + return esc(symbol) + end +end +function _make_dynamicindex_expr(expr::Expr, dim::Union{Nothing,Int}) + @gensym val + if has_begin_or_end(expr) + replaced_expr = MacroTools.postwalk(x -> _make_dynamicindex_expr(x, val, dim), expr) + return :($(DynamicIndex)($(QuoteNode(expr)), $val -> $replaced_expr)) + else + return esc(expr) + end +end +function _make_dynamicindex_expr(symbol::Symbol, val_sym::Symbol, dim::Union{Nothing,Int}) + # NOTE(penelopeysm): We could just use `:end` instead of Symbol(:end), but the former + # messes up syntax highlighting with Treesitter + # https://github.com/tree-sitter/tree-sitter-julia/issues/104 + if symbol === Symbol(:begin) + return if dim === nothing + :(Base.firstindex($val_sym)) + else + :(Base.Fix2(firstindex, $dim)($val_sym)) + end + elseif symbol === Symbol(:end) + return if dim === nothing + :(Base.lastindex($val_sym)) + else + :(Base.Fix2(lastindex, $dim)($val_sym)) + end + else + # Just a variable; but we need to escape it to allow interpolation. + return esc(symbol) + end +end +function _make_dynamicindex_expr(i::Any, ::Symbol, ::Union{Nothing,Int}) + # this handles things like integers, colons, etc. + return i +end + +has_begin_or_end(expr::Expr) = has_begin_or_end_inner(expr, false) +function has_begin_or_end_inner(x, found::Bool) + return found || + x ∈ (:end, :begin, Expr(:end), Expr(:begin)) || + (x isa Expr && any(arg -> has_begin_or_end_inner(arg, found), x.args)) +end + +_pretty_string_index(ix) = string(ix) +_pretty_string_index(::Colon) = ":" +_pretty_string_index(x::Symbol) = repr(x) +_pretty_string_index(x::String) = repr(x) +_pretty_string_index(di::DynamicIndex) = "DynamicIndex($(di.expr))" + +_concretize_index(idx::Any, ::Any) = idx +_concretize_index(idx::DynamicIndex, val) = idx.f(val) + +""" + Index(ix, kw, child=Iden()) + +An indexing optic representing access to indices `ix`, which may also take the form of +keyword arguments `kw`. A `VarName{:x}` with this optic represents access to `x[ix..., +kw...]`. The child optic represents any further indexing or property access after this +indexing operation. +""" +struct Index{I<:Tuple,N<:NamedTuple,C<:AbstractOptic} <: AbstractOptic + ix::I + kw::N + child::C + function Index(ix::Tuple, kw::NamedTuple, child::C=Iden()) where {C<:AbstractOptic} + return new{typeof(ix),typeof(kw),C}(ix, kw, child) + end +end + +# Workaround for https://github.com/JuliaLang/julia/issues/60470 +# In particular, these four methods are new: +_tuple_eq_inner(::Tuple, ::Tuple{}) = false +_tuple_eq_inner(::Tuple{}, ::Tuple) = false +_tuple_eq_inner_missing(::Tuple, ::Tuple{}) = false +_tuple_eq_inner_missing(::Tuple{}, ::Tuple) = false +# These methods are directly copied from Base tuple.jl and correspond to the usual equality +# checks for tuples +_tuple_eq(t1::Tuple, t2::Tuple) = (length(t1) == length(t2)) && _tuple_eq_inner(t1, t2) +_tuple_eq_inner(::Tuple{}, ::Tuple{}) = true +_tuple_eq_inner_missing(::Tuple{}, ::Tuple{}) = missing +function _tuple_eq_inner(t1::Tuple, t2::Tuple) + eq = t1[1] == t2[1] + if eq === false + return false + elseif ismissing(eq) + return _tuple_eq_inner_missing(Base.tail(t1), Base.tail(t2)) + else + return _tuple_eq_inner(Base.tail(t1), Base.tail(t2)) + end +end +function _tuple_eq_inner_missing(t1::Tuple, t2::Tuple) + eq = t1[1] == t2[1] + if eq === false + return false + else + return _tuple_eq_inner_missing(Base.tail(t1), Base.tail(t2)) + end +end +function _tuple_eq_inner(t1::Base.Any32, t2::Base.Any32) + anymissing = false + for i in eachindex(t1, t2) + eq = (t1[i] == t2[i]) + if ismissing(eq) + anymissing = true + elseif !eq + return false + end + end + return anymissing ? missing : true +end +# Because the Base methods for NamedTuple equality rely on tuple equality, we also need to +# patch that :( +_nt_eq(a::NamedTuple{n}, b::NamedTuple{n}) where {n} = _tuple_eq(Tuple(a), Tuple(b)) +_nt_eq(a::NamedTuple, b::NamedTuple) = false + +# Note: Do NOT change this to `a.ix == b.ix`! This is a workaround for +# https://github.com/JuliaLang/julia/issues/60470 +function Base.:(==)(a::Index, b::Index) + return _tuple_eq(a.ix, b.ix) && _nt_eq(a.kw, b.kw) && a.child == b.child +end +function Base.isequal(a::Index, b::Index) + return isequal(a.ix, b.ix) && isequal(a.kw, b.kw) && isequal(a.child, b.child) +end +Base.hash(a::Index, h::UInt) = hash((a.ix, a.kw, a.child), h) +function _pretty_print_optic(io::IO, idx::Index) + ixs = collect(map(_pretty_string_index, idx.ix)) + kws = map( + kv -> "$(kv.first)=$(_pretty_string_index(kv.second))", collect(pairs(idx.kw)) + ) + print(io, "[$(join(vcat(ixs, kws), ", "))]") + return _pretty_print_optic(io, idx.child) +end +is_dynamic(idx::Index) = any(ix -> ix isa DynamicIndex, idx.ix) || is_dynamic(idx.child) + +# Helper function to decide whether to use `view` or `getindex`. For AbstractArray, the +# default behaviour is to attempt to use a view. +_maybe_view(val::AbstractArray, i...; k...) = view(val, i...; k...) +# If it's just a single element, don't use a view, as that returns a weird 0-dimensional +# SubArray (rather than an element) that messes things up if there are further layers of +# optics. For example, if it's an Array of NamedTuples, then trying to access fields on that +# 0-dimensional SubArray will fail. +_maybe_view(val::AbstractArray, i::Int...) = getindex(val, i...) +# Other things like dictionaries can't be `view`ed into. +_maybe_view(val, i...; k...) = getindex(val, i...; k...) + +function concretize(idx::Index, val) + concretized_indices = tuple(map(Base.Fix2(_concretize_index, val), idx.ix)...) + inner_concretized = if idx.child isa Iden + # Explicitly having this branch allows us to shortcircuit _maybe_view(...), which + # can error if val[concretized_indices...] is an UndefInitializer. Note that if + # val[concretized_indices...] is an UndefInitializer, then it is not meaningful for + # `idx.child` to be anything other than `Iden` anyway, since there is nothing to + # further index into. + Iden() + else + concretize(idx.child, _maybe_view(val, concretized_indices...; idx.kw...)) + end + return Index(concretized_indices, idx.kw, inner_concretized) +end +function (idx::Index)(obj) + cidx = concretize(idx, obj) + return cidx.child(Base.getindex(obj, cidx.ix...; cidx.kw...)) +end +function Accessors.set(obj, idx::Index, newval) + cidx = concretize(idx, obj) + inner_newval = if idx.child isa Iden + newval + else + inner_obj = Base.getindex(obj, cidx.ix...; cidx.kw...) + Accessors.set(inner_obj, idx.child, newval) + end + return if !isempty(cidx.kw) + # `Accessors.IndexLens` does not handle keyword arguments so we need to do this + # ourselves. Note that the following code essentially assumes that `obj` is an + # AbstractArray or similar type that directly implements `setindex!`. + newobj = similar(obj) + copy!(newobj, obj) + Base.setindex!(newobj, inner_newval, cidx.ix...; cidx.kw...) + newobj + else + # Defer to Accessors' implementation, so that we don't have to reinvent the wheel + # (well, not more than what we have already done...). This is helpful because + # Accessors implements a lot of methods for different types of `obj`. + Accessors.set(obj, Accessors.IndexLens(cidx.ix), inner_newval) + end +end + +""" + Property{sym}(child=Iden()) + +A property access optic representing access to property `sym`. A `VarName{:x}` with this +optic represents access to `x.sym`. The child optic represents any further indexing or +property access after this property access operation. +""" +struct Property{sym,C<:AbstractOptic} <: AbstractOptic + child::C +end +Property{sym}(child::C=Iden()) where {sym,C<:AbstractOptic} = Property{sym,C}(child) + +Base.:(==)(a::Property{sym}, b::Property{sym}) where {sym} = a.child == b.child +Base.:(==)(a::Property, b::Property) = false +Base.isequal(a::Property{sym}, b::Property{sym}) where {sym} = isequal(a.child, b.child) +Base.isequal(a::Property, b::Property) = false +getsym(::Property{s}) where {s} = s +function _pretty_print_optic(io::IO, prop::Property{sym}) where {sym} + print(io, ".$(sym)") + return _pretty_print_optic(io, prop.child) +end +is_dynamic(prop::Property) = is_dynamic(prop.child) +function concretize(prop::Property{sym}, val) where {sym} + inner_concretized = concretize(prop.child, getproperty(val, sym)) + return Property{sym}(inner_concretized) +end +function (prop::Property{sym})(obj) where {sym} + return prop.child(getproperty(obj, sym)) +end +function Accessors.set(obj, prop::Property{sym}, newval) where {sym} + inner_obj = getproperty(obj, sym) + inner_newval = Accessors.set(inner_obj, prop.child, newval) + # Defer to Accessors' implementation again. + return Accessors.set(obj, Accessors.PropertyLens{sym}(), inner_newval) +end + +""" + ∘(outer::AbstractOptic, inner::AbstractOptic) + +Compose two `AbstractOptic`s together. + +```jldoctest +julia> p1 = @opticof(_.a[1]) +Optic(.a[1]) + +julia> p2 = @opticof(_.b[2, 3]) +Optic(.b[2, 3]) + +julia> p1 ∘ p2 +Optic(.b[2, 3].a[1]) +``` +""" +function Base.:(∘)(outer::AbstractOptic, inner::AbstractOptic) + if outer isa Iden + return inner + elseif inner isa Iden + return outer + else + if inner isa Property + return Property{getsym(inner)}(outer ∘ inner.child) + elseif inner isa Index + return Index(inner.ix, inner.kw, outer ∘ inner.child) + else + error("unreachable; unknown AbstractOptic subtype $(typeof(inner))") + end + end +end + +""" + cat(optics::AbstractOptic...) + +Compose multiple `AbstractOptic`s together. The optics should be provided from +innermost to outermost, i.e., `cat(o1, o2, o3)` corresponds to `o3 ∘ o2 ∘ o1`. + +""" +function Base.cat(optics::AbstractOptic...) + return foldl((a, b) -> b ∘ a, optics; init=Iden()) +end + +""" + ohead(optic::AbstractOptic) + +Get the innermost layer of an optic. For all optics, we have that `otail(optic) ∘ +ohead(optic) == optic`. + +```jldoctest +julia> ohead(@opticof _.a[1][2]) +Optic(.a) + +julia> ohead(@opticof _) +Optic() +``` +""" +ohead(::Property{s}) where {s} = Property{s}(Iden()) +ohead(idx::Index) = Index(idx.ix, idx.kw, Iden()) +ohead(i::Iden) = i + +""" + otail(optic::AbstractOptic) + +Get everything but the innermost layer of an optic. For all optics, we have that +`otail(optic) ∘ ohead(optic) == optic`. + +```jldoctest +julia> otail(@opticof _.a[1][2]) +Optic([1][2]) + +julia> otail(@opticof _) +Optic() +``` +""" +otail(p::Property) = p.child +otail(idx::Index) = idx.child +otail(i::Iden) = i + +""" + olast(optic::AbstractOptic) + +Get the outermost layer of an optic. For all optics, we have that `olast(optic) ∘ +oinit(optic) == optic`. + +```jldoctest +julia> olast(@opticof _.a[1][2]) +Optic([2]) + +julia> olast(@opticof _) +Optic() +``` +""" +function olast(p::Property{s}) where {s} + if p.child isa Iden + return p + else + return olast(p.child) + end +end +function olast(idx::Index) + if idx.child isa Iden + return idx + else + return olast(idx.child) + end +end +olast(i::Iden) = i + +""" + oinit(optic::AbstractOptic) + +Get everything but the outermost layer of an optic. For all optics, we have that +`olast(optic) ∘ oinit(optic) == optic`. + +```jldoctest +julia> oinit(@opticof _.a[1][2]) +Optic(.a[1]) + +julia> oinit(@opticof _) +Optic() +``` +""" +function oinit(p::Property{s}) where {s} + return if p.child isa Iden + Iden() + else + Property{s}(oinit(p.child)) + end +end +function oinit(idx::Index) + return if idx.child isa Iden + Iden() + else + Index(idx.ix, idx.kw, oinit(idx.child)) + end +end +oinit(i::Iden) = i + +""" + with_mutation(o::AbstractOptic) + +Create a version of the optic `o` which attempts to mutate its input where possible. + +On their own, `AbstractOptic`s are non-mutating: + +```jldoctest +julia> optic = @opticof(_[1]) +Optic([1]) + +julia> x = [0.0, 0.0]; + +julia> set(x, optic, 1.0); x +2-element Vector{Float64}: + 0.0 + 0.0 +``` + +With this function, we can create a mutating version of the optic: + +```jldoctest +julia> optic_mut = with_mutation(@opticof(_[1])) +Optic!!([1]) + +julia> x = [0.0, 0.0]; + +julia> set(x, optic_mut, 1.0); x +2-element Vector{Float64}: + 1.0 + 0.0 +``` + +Thanks to the BangBang.jl package, this optic will gracefully fall back to non-mutating +behaviour if mutation is not possible. For example, if we try to use it on a tuple: + +```jldoctest +julia> optic_mut = with_mutation(@opticof(_[1])) +Optic!!([1]) + +julia> x = (0.0, 0.0); + +julia> set(x, optic_mut, 1.0); x +(0.0, 0.0) +``` +""" +with_mutation(o::AbstractOptic) = _Optic!!(o) + +struct _Optic!!{O<:AbstractOptic} <: AbstractOptic + pure::O +end +_Optic!!(l::_Optic!!) = l +Base.:(==)(a::_Optic!!, b::_Optic!!) = a.pure == b.pure +Base.isequal(a::_Optic!!, b::_Optic!!) = isequal(a.pure, b.pure) +Base.hash(a::_Optic!!, h::UInt) = hash("_Optic!!", hash(a.pure, h)) +function Base.show(io::IO, l::_Optic!!) + print(io, "Optic!!(") + _pretty_print_optic(io, l.pure) + return print(io, ")") +end + +# Getter. +(l::_Optic!!)(o) = l.pure(o) +# Setters. +Accessors.set(::Any, l::_Optic!!{Iden}, val) = val +function Accessors.set(obj, l::_Optic!!{<:Property{sym}}, newval) where {sym} + inner_obj = getproperty(obj, sym) + inner_newval = Accessors.set(inner_obj, _Optic!!(l.pure.child), newval) + return BangBang.setproperty!!(obj, sym, inner_newval) +end +function Accessors.set(obj, l::_Optic!!{<:Index}, newval) + cidx = concretize(l.pure, obj) + inner_newval = if l.pure.child isa Iden + newval + else + inner_obj = Base.getindex(obj, cidx.ix...; cidx.kw...) + Accessors.set(inner_obj, _Optic!!(l.pure.child), newval) + end + return if isempty(cidx.kw) + # setindex!! doesn't support keyword arguments. + BangBang.setindex!!(obj, inner_newval, cidx.ix...) + else + # If there are keyword arguments, we just assume that setindex! will always work. + # This is a bit dangerous, but fine. + Base.setindex!(obj, inner_newval, cidx.ix...; cidx.kw...) + end +end diff --git a/src/varname/prefix.jl b/src/varname/prefix.jl index 25d4485b..943ba5c9 100644 --- a/src/varname/prefix.jl +++ b/src/varname/prefix.jl @@ -1,63 +1,12 @@ -### Functionality for prefixing and unprefixing VarNames. - -""" - optic_to_vn(optic) - -Convert an Accessors optic to a VarName. This is best explained through -examples. - -```jldoctest; setup=:(using Accessors) -julia> AbstractPPL.optic_to_vn(Accessors.@o _.a) -a - -julia> AbstractPPL.optic_to_vn(Accessors.@o _.a.b) -a.b - -julia> AbstractPPL.optic_to_vn(Accessors.@o _.a[1]) -a[1] -``` - -The outermost layer of the optic (technically, what Accessors.jl calls the -'innermost') must be a `PropertyLens`, or else it will fail. This is because a -VarName needs to have a symbol. - -```jldoctest; setup=:(using Accessors) -julia> AbstractPPL.optic_to_vn(Accessors.@o _[1]) -ERROR: ArgumentError: optic_to_vn: could not convert optic `(@o _[1])` to a VarName -[...] -``` -""" -function optic_to_vn(::Accessors.PropertyLens{sym}) where {sym} - return VarName{sym}() -end -function optic_to_vn( - o::ComposedFunction{Outer,Accessors.PropertyLens{sym}} -) where {Outer,sym} - return VarName{sym}(o.outer) -end -optic_to_vn(o::ComposedFunction) = optic_to_vn(normalise(o)) -function optic_to_vn(@nospecialize(o)) - msg = "optic_to_vn: could not convert optic `$o` to a VarName" - throw(ArgumentError(msg)) -end - -unprefix_optic(o, ::typeof(identity)) = o # Base case -function unprefix_optic(optic, optic_prefix) - # Technically `unprefix_optic` only receives optics that were part of - # VarNames, so the optics should already be normalised (in the inner - # constructor of the VarName). However I guess it doesn't hurt to do it - # again to be safe. - optic = normalise(optic) - optic_prefix = normalise(optic_prefix) - # strip one layer of the optic and check for equality - head = _head(optic) - head_prefix = _head(optic_prefix) +_unprefix_optic(o, ::Iden) = o +function _unprefix_optic(optic, optic_prefix) + head = ohead(optic) + head_prefix = ohead(optic_prefix) if head != head_prefix - msg = "could not remove prefix $(optic_prefix) from optic $(optic)" + msg = "cannot remove prefix $(optic_prefix) from optic $(optic)" throw(ArgumentError(msg)) end - # recurse - return unprefix_optic(_tail(optic), _tail(optic_prefix)) + return _unprefix_optic(otail(optic), otail(optic_prefix)) end """ @@ -66,17 +15,21 @@ end Remove a prefix from a VarName. ```jldoctest -julia> AbstractPPL.unprefix(@varname(y.x), @varname(y)) +julia> unprefix(@varname(y.x), @varname(y)) x -julia> AbstractPPL.unprefix(@varname(y.x.a), @varname(y)) +julia> unprefix(@varname(y.x.a), @varname(y)) x.a -julia> AbstractPPL.unprefix(@varname(y[1].x), @varname(y[1])) +julia> unprefix(@varname(y[1].x), @varname(y[1])) x -julia> AbstractPPL.unprefix(@varname(y), @varname(n)) -ERROR: ArgumentError: could not remove prefix n from VarName y +julia> unprefix(@varname(y), @varname(n)) +ERROR: ArgumentError: cannot remove prefix n from VarName y +[...] + +julia> unprefix(@varname(y[1]), @varname(y)) +ERROR: ArgumentError: optic_to_varname: can only convert Property optics to VarName [...] ``` """ @@ -84,12 +37,12 @@ function unprefix( vn::VarName{sym_vn}, prefix::VarName{sym_prefix} ) where {sym_vn,sym_prefix} if sym_vn != sym_prefix - msg = "could not remove prefix $(prefix) from VarName $(vn)" + msg = "cannot remove prefix $(prefix) from VarName $(vn)" throw(ArgumentError(msg)) end optic_vn = getoptic(vn) optic_prefix = getoptic(prefix) - return optic_to_vn(unprefix_optic(optic_vn, optic_prefix)) + return optic_to_varname(_unprefix_optic(optic_vn, optic_prefix)) end """ @@ -98,19 +51,17 @@ end Add a prefix to a VarName. ```jldoctest -julia> AbstractPPL.prefix(@varname(x), @varname(y)) +julia> prefix(@varname(x), @varname(y)) y.x -julia> AbstractPPL.prefix(@varname(x.a), @varname(y)) +julia> prefix(@varname(x.a), @varname(y)) y.x.a -julia> AbstractPPL.prefix(@varname(x.a), @varname(y[1])) +julia> prefix(@varname(x.a), @varname(y[1])) y[1].x.a ``` """ function prefix(vn::VarName{sym_vn}, prefix::VarName{sym_prefix}) where {sym_vn,sym_prefix} - optic_vn = getoptic(vn) - optic_prefix = getoptic(prefix) - new_optic_vn = optic_vn ∘ PropertyLens{sym_vn}() ∘ optic_prefix + new_optic_vn = varname_to_optic(vn) ∘ getoptic(prefix) return VarName{sym_prefix}(new_optic_vn) end diff --git a/src/varname/serialize.jl b/src/varname/serialize.jl index 7feb4896..7dca9d08 100644 --- a/src/varname/serialize.jl +++ b/src/varname/serialize.jl @@ -10,7 +10,6 @@ const _BASE_UNITRANGE_TYPE = "Base.UnitRange" const _BASE_STEPRANGE_TYPE = "Base.StepRange" const _BASE_ONETO_TYPE = "Base.OneTo" const _BASE_COLON_TYPE = "Base.Colon" -const _CONCRETIZED_SLICE_TYPE = "AbstractPPL.ConcretizedSlice" const _BASE_TUPLE_TYPE = "Base.Tuple" """ @@ -41,12 +40,16 @@ function index_to_dict(r::Base.OneTo{I}) where {I} return Dict("type" => _BASE_ONETO_TYPE, "stop" => r.stop) end index_to_dict(::Colon) = Dict("type" => _BASE_COLON_TYPE) -function index_to_dict(s::ConcretizedSlice{T,R}) where {T,R} - return Dict("type" => _CONCRETIZED_SLICE_TYPE, "range" => index_to_dict(s.range)) -end function index_to_dict(t::Tuple) return Dict("type" => _BASE_TUPLE_TYPE, "values" => map(index_to_dict, t)) end +function index_to_dict(::DynamicIndex) + throw( + ArgumentError( + "DynamicIndex cannot be serialised; please concretise the VarName before serialising.", + ), + ) +end """ dict_to_index(dict) @@ -90,8 +93,6 @@ function dict_to_index(dict) return Base.OneTo(dict["stop"]) elseif t == _BASE_COLON_TYPE return Colon() - elseif t == _CONCRETIZED_SLICE_TYPE - return ConcretizedSlice(Base.Slice(dict_to_index(dict["range"]))) elseif t == _BASE_TUPLE_TYPE return tuple(map(dict_to_index, dict["values"])...) else @@ -101,28 +102,39 @@ function dict_to_index(dict) end end -optic_to_dict(::typeof(identity)) = Dict("type" => "identity") -function optic_to_dict(::PropertyLens{sym}) where {sym} - return Dict("type" => "property", "field" => String(sym)) +const _OPTIC_IDEN_NAME = "Iden" +const _OPTIC_PROPERTY_NAME = "Property" +const _OPTIC_INDEX_NAME = "Index" +optic_to_dict(::Iden) = Dict("type" => _OPTIC_IDEN_NAME) +function optic_to_dict(p::Property{sym}) where {sym} + return Dict( + "type" => _OPTIC_PROPERTY_NAME, + "field" => String(sym), + "child" => optic_to_dict(p.child), + ) end -optic_to_dict(i::IndexLens) = Dict("type" => "index", "indices" => index_to_dict(i.indices)) -function optic_to_dict(c::ComposedFunction) +function optic_to_dict(i::Index) return Dict( - "type" => "composed", - "outer" => optic_to_dict(c.outer), - "inner" => optic_to_dict(c.inner), + "type" => _OPTIC_INDEX_NAME, + # For some reason if you don't do the isempty check, it gets serialised as `{}` + # rather than `[]` + "ix" => isempty(i.ix) ? [] : collect(map(index_to_dict, i.ix)), + # TODO(penelopeysm): This is potentially lossy since order is not guaranteed + "kw" => Dict(String(x) => index_to_dict(y) for (x, y) in pairs(i.kw)), + "child" => optic_to_dict(i.child), ) end function dict_to_optic(dict) - if dict["type"] == "identity" - return identity - elseif dict["type"] == "index" - return IndexLens(dict_to_index(dict["indices"])) - elseif dict["type"] == "property" - return PropertyLens{Symbol(dict["field"])}() - elseif dict["type"] == "composed" - return dict_to_optic(dict["outer"]) ∘ dict_to_optic(dict["inner"]) + if dict["type"] == _OPTIC_IDEN_NAME + return Iden() + elseif dict["type"] == _OPTIC_INDEX_NAME + ixs = tuple(map(dict_to_index, dict["ix"])...) + kws = NamedTuple(Symbol(k) => dict_to_index(v) for (k, v) in dict["kw"]) + child = dict_to_optic(dict["child"]) + return Index(ixs, kws, child) + elseif dict["type"] == _OPTIC_PROPERTY_NAME + return Property{Symbol(dict["field"])}(dict_to_optic(dict["child"])) else error("Unknown optic type: $(dict["type"])") end @@ -151,16 +163,10 @@ documentation of [`dict_to_index`](@ref) for instructions on how to do this. ```jldoctest julia> varname_to_string(@varname(x)) -"{\\"optic\\":{\\"type\\":\\"identity\\"},\\"sym\\":\\"x\\"}" +"{\\"optic\\":{\\"type\\":\\"Iden\\"},\\"sym\\":\\"x\\"}" julia> varname_to_string(@varname(x.a)) -"{\\"optic\\":{\\"field\\":\\"a\\",\\"type\\":\\"property\\"},\\"sym\\":\\"x\\"}" - -julia> y = ones(2); varname_to_string(@varname(y[:])) -"{\\"optic\\":{\\"indices\\":{\\"values\\":[{\\"type\\":\\"Base.Colon\\"}],\\"type\\":\\"Base.Tuple\\"},\\"type\\":\\"index\\"},\\"sym\\":\\"y\\"}" - -julia> y = ones(2); varname_to_string(@varname(y[:], true)) -"{\\"optic\\":{\\"indices\\":{\\"values\\":[{\\"range\\":{\\"stop\\":2,\\"type\\":\\"Base.OneTo\\"},\\"type\\":\\"AbstractPPL.ConcretizedSlice\\"}],\\"type\\":\\"Base.Tuple\\"},\\"type\\":\\"index\\"},\\"sym\\":\\"y\\"}" +"{\\"optic\\":{\\"child\\":{\\"type\\":\\"Iden\\"},\\"field\\":\\"a\\",\\"type\\":\\"Property\\"},\\"sym\\":\\"x\\"}" ``` """ varname_to_string(vn::VarName) = JSON.json(varname_to_dict(vn)) diff --git a/src/varname/subsumes.jl b/src/varname/subsumes.jl index f43a92b4..2a29b5e5 100644 --- a/src/varname/subsumes.jl +++ b/src/varname/subsumes.jl @@ -1,240 +1,78 @@ """ - inspace(vn::Union{VarName, Symbol}, space::Tuple) + subsumes(parent::VarName, child::VarName) -Check whether `vn`'s variable symbol is in `space`. The empty tuple counts as the "universal space" -containing all variables. Subsumption (see [`subsumes`](@ref)) is respected. - -## Examples +Check whether the variable name `child` describes a sub-range of the variable `parent`, +i.e., is contained within it. ```jldoctest -julia> inspace(@varname(x[1][2:3]), ()) -true - -julia> inspace(@varname(x[1][2:3]), (:x,)) -true - -julia> inspace(@varname(x[1][2:3]), (@varname(x),)) -true - -julia> inspace(@varname(x[1][2:3]), (@varname(x[1:10]), :y)) +julia> subsumes(@varname(x), @varname(x[1, 2])) true -julia> inspace(@varname(x[1][2:3]), (@varname(x[:][2:4]), :y)) -true - -julia> inspace(@varname(x[1][2:3]), (@varname(x[1:10]),)) +julia> subsumes(@varname(x[1, 2]), @varname(x[1, 2][3])) true ``` -""" -inspace(vn, space::Tuple{}) = true # empty tuple is treated as universal space -inspace(vn, space::Tuple) = vn in space -inspace(vn::VarName, space::Tuple{}) = true -inspace(vn::VarName, space::Tuple) = any(_in(vn, s) for s in space) -_in(vn::VarName, s::Symbol) = getsym(vn) == s -_in(vn::VarName, s::VarName) = subsumes(s, vn) +This is done by recursively comparing each layer of the VarNames' optics. -""" - subsumes(u::VarName, v::VarName) +Note that often this is not possible to determine statically, and so the results should +not be over-interpreted. In particular, `Index` optics pose a problem. An `i::Index` will +only subsume `j::Index` if: -Check whether the variable name `v` describes a sub-range of the variable `u`. Supported -indexing: +1. They have the same number of positional indices (`i.ix` and `j.ix`); +2. Each positional index in `i` can be determined to comprise the corresponding positional + index in `j`; and +3. The keyword indices of `i` (`i.kw`) are a superset of those in `j.kw`). - - Scalar: +In all other cases, `subsumes` will conservatively return `false`, even though in practice +it might well be that `i` does subsume `j`. Some examples where subsumption cannot be +determined statically are: - ```jldoctest - julia> subsumes(@varname(x), @varname(x[1, 2])) - true - - julia> subsumes(@varname(x[1, 2]), @varname(x[1, 2][3])) - true - ``` - - - Array of scalar: basically everything that fulfills `issubset`. - - ```jldoctest - julia> subsumes(@varname(x[[1, 2], 3]), @varname(x[1, 3])) - true - - julia> subsumes(@varname(x[1:3]), @varname(x[2][1])) - true - ``` - - - Slices: - - ```jldoctest - julia> subsumes(@varname(x[2, :]), @varname(x[2, 10][1])) - true - ``` - -Currently _not_ supported are: - - - Boolean indexing, literal `CartesianIndex` (these could be added, though) - - Linear indexing of multidimensional arrays: `x[4]` does not subsume `x[2, 2]` for a matrix `x` - - Trailing ones: `x[2, 1]` does not subsume `x[2]` for a vector `x` +- Subsumption between different forms of indexing is not supported, e.g. `x[4]` and `x[2, + 2]` are not considered to subsume each other, even though they might in practice (e.g. if + `x` is a 2x2 matrix). +- When dynamic indices (that are not equal) are present. (Dynamic indices that are equal do + subsume each other.) +- Non-standard indices, e.g. `Not(4)`, `2..3`, etc. Again, these only subsume each other + when they are equal. """ function subsumes(u::VarName, v::VarName) return getsym(u) == getsym(v) && subsumes(getoptic(u), getoptic(v)) end - -# Idea behind `subsumes` for `Lens` is that we traverse the two lenses in parallel, -# checking `subsumes` for every level. This for example means that if we are comparing -# `PropertyLens{:a}` and `PropertyLens{:b}` we immediately know that they do not subsume -# each other since at the same level/depth they access different properties. -# E.g. `x`, `x[1]`, i.e. `u` is always subsumed by `t` -subsumes(::typeof(identity), ::typeof(identity)) = true -subsumes(::typeof(identity), ::ALLOWED_OPTICS) = true -subsumes(::ALLOWED_OPTICS, ::typeof(identity)) = false - -function subsumes(t::ComposedFunction, u::ComposedFunction) - return subsumes(t.outer, u.outer) && subsumes(t.inner, u.inner) +subsumes(::Iden, ::Iden) = true +subsumes(::Iden, ::AbstractOptic) = true +subsumes(::AbstractOptic, ::Iden) = false +subsumes(t::Property{name}, u::Property{name}) where {name} = subsumes(t.child, u.child) +subsumes(t::Property, u::Property) = false +subsumes(::Property, ::Index) = false +subsumes(::Index, ::Property) = false + +function subsumes(i::Index, j::Index) + return _subsumes_positional(i.ix, j.ix) && + _subsumes_keyword(i.kw, j.kw) && + subsumes(i.child, j.child) end -# If `t` is still a composed lens, then there is no way it can subsume `u` since `u` is a -# leaf of the "lens-tree". -subsumes(t::ComposedFunction, u::PropertyLens) = false -# Here we need to check if `u.inner` (i.e. the next lens to be applied from `u`) is -# subsumed by `t`, since this would mean that the rest of the composition is also subsumed -# by `t`. -subsumes(t::PropertyLens, u::ComposedFunction) = subsumes(t, u.inner) - -# For `PropertyLens` either they have the same `name` and thus they are indeed the same. -subsumes(t::PropertyLens{name}, u::PropertyLens{name}) where {name} = true -# Otherwise they represent different properties, and thus are not the same. -subsumes(t::PropertyLens, u::PropertyLens) = false - -# PropertyLens and IndexLens can't subsume each other -subsumes(::PropertyLens, ::IndexLens) = false -subsumes(::IndexLens, ::PropertyLens) = false - -# Indices subsumes if they are subindices, i.e. we just call `_issubindex`. -# FIXME: Does not support `DynamicIndexLens`. -# FIXME: Does not correctly handle cases such as `subsumes(x, x[:])` -# (but neither did old implementation). -function subsumes( - t::Union{IndexLens,ComposedFunction{<:ALLOWED_OPTICS,<:IndexLens}}, - u::Union{IndexLens,ComposedFunction{<:ALLOWED_OPTICS,<:IndexLens}}, -) - return subsumes_indices(t, u) +function _subsumes_positional(i::Tuple, j::Tuple) + return (length(i) == length(j)) && all(_subsumes_index.(i, j)) end - -""" - subsumedby(t, u) - -True if `t` is subsumed by `u`, i.e., if `subsumes(u, t)` is true. -""" -subsumedby(t, u) = subsumes(u, t) -uncomparable(t, u) = t ⋢ u && u ⋢ t -const ⊒ = subsumes -const ⊑ = subsumedby -const ⋣ = !subsumes -const ⋢ = !subsumedby -const ≍ = uncomparable - -# Since expressions such as `x[:][:][:][1]` and `x[1]` are equal, -# the indexing behavior must be considered jointly. -# Therefore we must recurse until we reach something that is NOT -# indexing, and then consider the sequence of indices leading up to this. -""" - subsumes_indices(t, u) - -Return `true` if the indexing represented by `t` subsumes `u`. - -This is mostly useful for comparing compositions involving `IndexLens` -e.g. `_[1][2].a[2]` and `_[1][2].a`. In such a scenario we do the following: -1. Combine `[1][2]` into a `Tuple` of indices using [`combine_indices`](@ref). -2. Do the same for `[1][2]`. -3. Compare the two tuples from (1) and (2) using `subsumes_indices`. -4. Since we're still undecided, we call `subsume(@o(_.a[2]), @o(_.a))` - which then returns `false`. - -# Example -```jldoctest; setup=:(using Accessors; using AbstractPPL: subsumes_indices) -julia> t = @o(_[1].a); u = @o(_[1]); - -julia> subsumes_indices(t, u) -false - -julia> subsumes_indices(u, t) -true - -julia> # `identity` subsumes all. - subsumes_indices(identity, t) -true - -julia> # None subsumes `identity`. - subsumes_indices(t, identity) -false - -julia> AbstractPPL.subsumes(@o(_[1][2].a[2]), @o(_[1][2].a)) -false - -julia> AbstractPPL.subsumes(@o(_[1][2].a), @o(_[1][2].a[2])) -true -``` -""" -function subsumes_indices(t::ALLOWED_OPTICS, u::ALLOWED_OPTICS) - t_indices, t_next = combine_indices(t) - u_indices, u_next = combine_indices(u) - - # If we already know that `u` is not subsumed by `t`, return early. - if !subsumes_indices(t_indices, u_indices) - return false - end - - if t_next === nothing - # Means that there's nothing left for `t` and either nothing - # or something left for `u`, i.e. `t` indeed `subsumes` `u`. - return true - elseif u_next === nothing - # If `t_next` is not `nothing` but `u_next` is, then - # `t` does not subsume `u`. - return false +function _subsumes_keyword(i::NamedTuple{f1}, j::NamedTuple{f2}) where {f1,f2} + for name in f2 + if !(name in f1) || !(_subsumes_index(i[name], j[name])) + return false + end end - - # If neither is `nothing` we continue. - return subsumes(t_next, u_next) -end - -""" - combine_indices(optic) - -Return sequential indexing into a single `Tuple` of indices, -e.g. `x[:][1][2]` becomes `((Colon(), ), (1, ), (2, ))`. - -The result is compatible with [`subsumes_indices`](@ref) for `Tuple` input. -""" -combine_indices(optic::ALLOWED_OPTICS) = (), optic -combine_indices(optic::IndexLens) = (optic.indices,), nothing -function combine_indices(optic::ComposedFunction{<:ALLOWED_OPTICS,<:IndexLens}) - indices, next = combine_indices(optic.outer) - return (optic.inner.indices, indices...), next -end - -""" - subsumes_indices(left_indices::Tuple, right_indices::Tuple) - -Return `true` if `right_indices` is subsumed by `left_indices`. `left_indices` is assumed to be -concretized and consist of either `Int`s or `AbstractArray`s of scalar indices that are supported -by array A. - -Currently _not_ supported are: -- Boolean indexing, literal `CartesianIndex` (these could be added, though) -- Linear indexing of multidimensional arrays: `x[4]` does not subsume `x[2, 2]` for a matrix `x` -- Trailing ones: `x[2, 1]` does not subsume `x[2]` for a vector `x` -""" -subsumes_indices(::Tuple{}, ::Tuple{}) = true # x subsumes x -subsumes_indices(::Tuple{}, ::Tuple) = true # x subsumes x... -subsumes_indices(::Tuple, ::Tuple{}) = false # x... does not subsume x -function subsumes_indices(t1::Tuple, t2::Tuple) # does x[i]... subsume x[j]...? - first_subsumed = all(Base.splat(subsumes_index), zip(first(t1), first(t2))) - return first_subsumed && subsumes_indices(Base.tail(t1), Base.tail(t2)) + return true end -subsumes_index(i::Colon, ::Colon) = error("Colons cannot be subsumed") -subsumes_index(i, ::Colon) = error("Colons cannot be subsumed") -# Necessary to avoid ambiguity errors. -subsumes_index(::AbstractVector, ::Colon) = error("Colons cannot be subsumed") -subsumes_index(i::Colon, j) = true -subsumes_index(i::AbstractVector, j) = issubset(j, i) -subsumes_index(i, j) = i == j +_subsumes_index(a::DynamicIndex, b::DynamicIndex) = a == b +_subsumes_index(a::DynamicIndex, ::Any) = false +_subsumes_index(::DynamicIndex, ::Colon) = false +_subsumes_index(::Colon, ::DynamicIndex) = true +_subsumes_index(::Any, ::DynamicIndex) = false +_subsumes_index(::Colon, ::Any) = true +_subsumes_index(::Colon, ::Colon) = true +_subsumes_index(::Any, ::Colon) = false +_subsumes_index(a::AbstractVector, b::Any) = issubset(b, a) +_subsumes_index(a::AbstractVector, b::Colon) = false +_subsumes_index(a::AbstractVector, b::DynamicIndex) = false +_subsumes_index(a::Any, b::Any) = a == b diff --git a/src/varname/varname.jl b/src/varname/varname.jl index c2916de6..e78952b1 100644 --- a/src/varname/varname.jl +++ b/src/varname/varname.jl @@ -1,122 +1,20 @@ -using Accessors -using Accessors: PropertyLens, IndexLens, DynamicIndexLens - -# nb. ComposedFunction is the same as Accessors.ComposedOptic -const ALLOWED_OPTICS = Union{typeof(identity),PropertyLens,IndexLens,ComposedFunction} - """ VarName{sym}(optic=identity) -A variable identifier for a symbol `sym` and optic `optic`. - -The Julia variable in the model corresponding to `sym` can refer to a single value or to a -hierarchical array structure of univariate, multivariate or matrix variables. The field `lens` -stores the indices requires to access the random variable from the Julia variable indicated by `sym` -as a tuple of tuples. Each element of the tuple thereby contains the indices of one optic -operation. - -`VarName`s can be manually constructed using the `VarName{sym}(optic)` constructor, or from an -optic expression through the [`@varname`](@ref) convenience macro. - -# Examples +A variable identifier for a symbol `sym` and optic `optic`. `sym` refers to the name of the +top-level Julia variable, while `optic` allows one to specify a particular property or index +inside that variable. -```jldoctest; setup=:(using Accessors) -julia> vn = VarName{:x}(Accessors.IndexLens((Colon(), 1)) ⨟ Accessors.IndexLens((2, ))) -x[:, 1][2] - -julia> getoptic(vn) -(@o _[Colon(), 1][2]) - -julia> @varname x[:, 1][1+1] -x[:, 1][2] -``` +`VarName`s can be manually constructed using the `VarName{sym}(optic)` constructor, or from +an optic expression through the [`@varname`](@ref) convenience macro. """ -struct VarName{sym,T<:ALLOWED_OPTICS} +struct VarName{sym,T<:AbstractOptic} optic::T - - function VarName{sym}(optic=identity) where {sym} - optic = normalise(optic) - if !is_static_optic(typeof(optic)) - throw( - ArgumentError( - "attempted to construct `VarName` with unsupported optic of type $(nameof(typeof(optic)))", - ), - ) - end + function VarName{sym}(optic=Iden()) where {sym} return new{sym,typeof(optic)}(optic) end end -""" - is_static_optic(l) - -Return `true` if `l` is one or a composition of `identity`, `PropertyLens`, and `IndexLens`; `false` if `l` is -one or a composition of `DynamicIndexLens`; and undefined otherwise. -""" -is_static_optic(::Type{<:Union{typeof(identity),PropertyLens,IndexLens}}) = true -function is_static_optic(::Type{ComposedFunction{LO,LI}}) where {LO,LI} - return is_static_optic(LO) && is_static_optic(LI) -end -is_static_optic(::Type{<:DynamicIndexLens}) = false - -""" - normalise(optic) - -Enforce that compositions of optics are always nested in the same way, in that -a ComposedFunction never has a ComposedFunction as its inner lens. Thus, for -example, - -```jldoctest; setup=:(using Accessors) -julia> op1 = ((@o _.c) ∘ (@o _.b)) ∘ (@o _.a) -(@o _.a.b.c) - -julia> op2 = (@o _.c) ∘ ((@o _.b) ∘ (@o _.a)) -(@o _.c) ∘ ((@o _.a.b)) - -julia> op1 == op2 -false - -julia> AbstractPPL.normalise(op1) == AbstractPPL.normalise(op2) == @o _.a.b.c -true -``` - -This function also removes redundant `identity` optics from ComposedFunctions: - -```jldoctest; setup=:(using Accessors) -julia> op3 = ((@o _.b) ∘ identity) ∘ (@o _.a) -(@o identity(_.a).b) - -julia> op4 = (@o _.b) ∘ (identity ∘ (@o _.a)) -(@o _.b) ∘ ((@o identity(_.a))) - -julia> AbstractPPL.normalise(op3) == AbstractPPL.normalise(op4) == @o _.a.b -true -``` -""" -function normalise(o::ComposedFunction{Outer,<:ComposedFunction}) where {Outer} - # `o` is currently (outer ∘ (inner_outer ∘ inner_inner)). - # We want to change this to: - # o = (outer ∘ inner_outer) ∘ inner_inner - inner_inner = o.inner.inner - inner_outer = o.inner.outer - # Recursively call normalise because inner_inner could itself be a - # ComposedFunction - return normalise((o.outer ∘ inner_outer) ∘ inner_inner) -end -function normalise(o::ComposedFunction{Outer,typeof(identity)} where {Outer}) - # strip outer identity - return normalise(o.outer) -end -function normalise(o::ComposedFunction{typeof(identity),Inner} where {Inner}) - # strip inner identity - return normalise(o.inner) -end -normalise(o::ComposedFunction) = normalise(o.outer) ∘ o.inner -normalise(o::ALLOWED_OPTICS) = o -# These two methods are needed to avoid method ambiguity. -normalise(o::ComposedFunction{typeof(identity),<:ComposedFunction}) = normalise(o.inner) -normalise(::ComposedFunction{typeof(identity),typeof(identity)}) = identity - """ getsym(vn::VarName) @@ -143,78 +41,28 @@ Return the optic of the Julia variable used to generate `vn`. ```jldoctest julia> getoptic(@varname(x[1][2:3])) -(@o _[1][2:3]) +Optic([1][2:3]) julia> getoptic(@varname(y)) -identity (generic function with 1 method) +Optic() ``` """ getoptic(vn::VarName) = vn.optic -""" - get(obj, vn::VarName{sym}) - -Alias for `(PropertyLens{sym}() ⨟ getoptic(vn))(obj)`. -``` -""" -function Base.get(obj, vn::VarName{sym}) where {sym} - return (PropertyLens{sym}() ⨟ getoptic(vn))(obj) -end - -""" - set(obj, vn::VarName{sym}, value) - -Alias for `set(obj, PropertyLens{sym}() ⨟ getoptic(vn), value)`. - -# Example - -```jldoctest; setup = :(using AbstractPPL: Accessors; nt = (a = 1, b = (c = [1, 2, 3],)); name = :nt) -julia> Accessors.set(nt, @varname(a), 10) -(a = 10, b = (c = [1, 2, 3],)) - -julia> Accessors.set(nt, @varname(b.c[1]), 10) -(a = 1, b = (c = [10, 2, 3],)) -``` -""" -function Accessors.set(obj, vn::VarName{sym}, value) where {sym} - return Accessors.set(obj, PropertyLens{sym}() ⨟ getoptic(vn), value) +function Base.:(==)(x::VarName, y::VarName) + return getsym(x) == getsym(y) && getoptic(x) == getoptic(y) end - -# Allow compositions with optic. -function Base.:∘(optic::ALLOWED_OPTICS, vn::VarName{sym}) where {sym} - return VarName{sym}(optic ∘ getoptic(vn)) +function Base.isequal(x::VarName, y::VarName) + return getsym(x) == getsym(y) && isequal(getoptic(x), getoptic(y)) end Base.hash(vn::VarName, h::UInt) = hash((getsym(vn), getoptic(vn)), h) -function Base.:(==)(x::VarName, y::VarName) - return getsym(x) == getsym(y) && getoptic(x) == getoptic(y) -end function Base.show(io::IO, vn::VarName{sym,T}) where {sym,T} print(io, getsym(vn)) - return _show_optic(io, getoptic(vn)) -end - -# modified from https://github.com/JuliaObjects/Accessors.jl/blob/01528a81fdf17c07436e1f3d99119d3f635e4c26/src/sugar.jl#L502 -function _show_optic(io::IO, optic) - opts = Accessors.deopcompose(optic) - inner = Iterators.takewhile(x -> applicable(_shortstring, "", x), opts) - outer = Iterators.dropwhile(x -> applicable(_shortstring, "", x), opts) - if !isempty(outer) - show(io, opcompose(outer...)) - print(io, " ∘ ") - end - shortstr = reduce(_shortstring, inner; init="") - return print(io, shortstr) + return _pretty_print_optic(io, getoptic(vn)) end -_shortstring(prev, o::IndexLens) = "$prev[$(join(map(prettify_index, o.indices), ", "))]" -_shortstring(prev, ::typeof(identity)) = "$prev" -_shortstring(prev, o) = Accessors._shortstring(prev, o) - -prettify_index(x) = repr(x) -prettify_index(::Colon) = ":" - """ Symbol(vn::VarName) @@ -229,465 +77,307 @@ julia> Symbol(@varname(x[1][:])) Symbol("x[1][:]") ``` """ -Base.Symbol(vn::VarName) = Symbol(string(vn)) # simplified symbol +Base.Symbol(vn::VarName) = Symbol(string(vn)) """ - ConcretizedSlice(::Base.Slice) + concretize(vn::VarName, x) -An indexing object wrapping the range of a `Base.Slice` object representing the concrete indices a -`:` indicates. Behaves the same, but prints differently, namely, still as `:`. +Return `vn` concretized on `x`, i.e. any information related to the runtime shape of `x` is +evaluated. This will convert any `begin` and `end` indices in `vn` to concrete indices with +information about the length of the dimension being indexed into. """ -struct ConcretizedSlice{T,R} <: AbstractVector{T} - range::R -end - -function ConcretizedSlice(s::Base.Slice{R}) where {R} - return ConcretizedSlice{eltype(s.indices),R}(s.indices) -end -Base.show(io::IO, s::ConcretizedSlice) = print(io, ":") -function Base.show(io::IO, ::MIME"text/plain", s::ConcretizedSlice) - return print(io, "ConcretizedSlice(", s.range, ")") -end -Base.size(s::ConcretizedSlice) = size(s.range) -Base.iterate(s::ConcretizedSlice, state...) = Base.iterate(s.range, state...) -Base.collect(s::ConcretizedSlice) = collect(s.range) -Base.getindex(s::ConcretizedSlice, i) = s.range[i] -Base.hasfastin(::Type{<:ConcretizedSlice}) = true -Base.in(i, s::ConcretizedSlice) = i in s.range - -# and this is the reason why we are doing this: -Base.to_index(A, s::ConcretizedSlice) = Base.Slice(s.range) +concretize(vn::VarName{sym}, x) where {sym} = VarName{sym}(concretize(getoptic(vn), x)) """ - reconcretize_index(original_index, lowered_index) + is_dynamic(vn::VarName) -Create the index to be emitted in `concretize`. `original_index` is the original, unconcretized -index, and `lowered_index` the respective position of the result of `to_indices`. - -The only purpose of this are special cases like `:`, which we want to avoid becoming a -`Base.Slice(OneTo(...))` -- it would confuse people when printed. Instead, we concretize to a -`ConcretizedSlice` based on the `lowered_index`, just what you'd get with an explicit `begin:end` +Return `true` if `vn` contains any dynamic indices (i.e., `begin`, `end`, or `:`). If a +`VarName` has been concretized, this will always return `false`. """ -reconcretize_index(original_index, lowered_index) = lowered_index -function reconcretize_index(original_index::Colon, lowered_index::Base.Slice) - return ConcretizedSlice(lowered_index) -end +is_dynamic(vn::VarName) = is_dynamic(getoptic(vn)) """ - concretize(l, x) - -Return `l` instantiated on `x`, i.e. any information related to the runtime shape of `x` is -evaluated. This concerns `begin`, `end`, and `:` slices. + VarNameParseException(expr) -Basically, every index is converted to a concrete value using `Base.to_index` on `x`. However, `:` -slices are only converted to `ConcretizedSlice` (as opposed to `Base.Slice{Base.OneTo}`), to keep -the result close to the original indexing. +An exception thrown when a variable name expression cannot be parsed by the +[`@varname`](@ref) macro. """ -concretize(I::ALLOWED_OPTICS, x) = I -concretize(I::DynamicIndexLens, x) = concretize(IndexLens(I.f(x)), x) -function concretize(I::IndexLens, x) - return IndexLens(reconcretize_index.(I.indices, to_indices(x, I.indices))) +struct VarNameParseException <: Exception + expr end -function concretize(I::ComposedFunction, x) - x_inner = I.inner(x) # TODO: get view here - return ComposedFunction(concretize(I.outer, x_inner), concretize(I.inner, x)) +function Base.showerror(io::IO, e::VarNameParseException) + return print(io, "malformed variable name `$(e.expr)`") end """ - concretize(vn::VarName, x) - -Return `vn` concretized on `x`, i.e. any information related to the runtime shape of `x` is -evaluated. This concerns `begin`, `end`, and `:` slices. - -# Examples -```jldoctest; setup=:(using Accessors) -julia> x = (a = [1.0 2.0; 3.0 4.0; 5.0 6.0], ); + VarNameConcretizationException() -julia> getoptic(@varname(x.a[1:end, end][:], true)) # concrete=true required for @varname -(@o _.a[1:3, 2][:]) +When constructing a `VarName` using [`@varname`](@ref) (or [`@opticof`](@ref)), we allow +for interpolation of the top-level symbol, e.g. using `name = :x; @varname(\$name)`. However, +if this is done, it is not possible to automatically concretize the resulting `VarName` by +passing `true` as the second argument to `@varname`. -julia> y = zeros(10, 10); +Because macros are confusing, this is probably worth more explanation. For example, consider +the user input `name = :x; @varname(\$name, true)`. -julia> @varname(y[:], true) -y[:] - -julia> # The underlying value is concretized, though: - AbstractPPL.getoptic(AbstractPPL.concretize(@varname(y[:]), y)).indices[1] -ConcretizedSlice(Base.OneTo(100)) -``` +Without concretization, we can easily handle this as `VarName{name}(Iden())`. `name` is then +resolved outside the macro to produce `VarName{:x}(Iden())`. However, to correctly +concretize this, we would need to generate the output `concretize(VarName{name}(), x)`; +i.e., we need to know at macro-expansion time that `name` evaluates to `:x`. This is not +possible given the expression `\$name` alone, which is why this error is thrown. """ -concretize(vn::VarName{sym}, x) where {sym} = VarName{sym}(concretize(getoptic(vn), x)) +struct VarNameConcretizationException <: Exception end +function Base.showerror(io::IO, ::VarNameConcretizationException) + return print( + io, + "cannot automatically concretize VarName with interpolated top-level symbol; call `concretize(vn, val)` manually instead", + ) +end """ @varname(expr, concretize=false) -A macro that returns an instance of [`VarName`](@ref) given a symbol or indexing expression `expr`. - -If `concretize` is `true`, the resulting expression will be wrapped in a `concretize()` call. +Create a [`VarName`](@ref) given an expression `expr` representing a variable or part of it. -Note that expressions involving dynamic indexing, i.e. `begin` and/or `end`, will always need to be -concretized as `VarName` only supports non-dynamic indexing as determined by -`is_static_optic`. See examples below. +# Basic examples -## Examples +In general, `VarName`s must have a top-level symbol representing the identifier itself, and +can then have any number of property accesses or indexing operations chained to it. -### Dynamic indexing ```jldoctest -julia> x = (a = [1.0 2.0; 3.0 4.0; 5.0 6.0], ); - -julia> @varname(x.a[1:end, end][:], true) -x.a[1:3, 2][:] +julia> @varname(x) +x -julia> @varname(x.a[end], false) # disable concretization -ERROR: LoadError: Variable name `x.a[end]` is dynamic and requires concretization! -[...] +julia> @varname(x.a.b.c) +x.a.b.c -julia> @varname(x.a[end]) # concretization occurs by default if deemed necessary -x.a[6] +julia> @varname(x[1][2][3]) +x[1][2][3] -julia> # Note that "dynamic" here refers to usage of `begin` and/or `end`, - # _not_ "information only available at runtime", i.e. the following works. - [@varname(x.a[i]) for i = 1:length(x.a)][end] -x.a[6] - -julia> # Potentially surprising behaviour, but this is equivalent to what Base does: - @varname(x[2:2:5]), 2:2:5 -(x[2:2:4], 2:2:4) +julia> @varname(x.a[1:3].b[2]) +x.a[1:3].b[2] ``` -### General indexing +# Dynamic indices -Under the hood `optic`s are used for the indexing: +Some expressions may involve dynamic indices, specifically, `begin`, `end`. These indices +cannot be resolved, or 'concretized', until the value being indexed into is known. By +default, `@varname(...)` will not automatically concretize these expressions, and thus the +resulting `VarName` will contain markers for these. -```jldoctest -julia> getoptic(@varname(x)) -identity (generic function with 1 method) +Note that colons are not considered dynamic. -julia> getoptic(@varname(x[1])) -(@o _[1]) +```jldoctest +julia> vn = @varname(x[end]) +x[DynamicIndex(end)] -julia> getoptic(@varname(x[:, 1])) -(@o _[Colon(), 1]) +julia> vn = @varname(x[1, end-1]) +x[1, DynamicIndex(end - 1)] +``` -julia> getoptic(@varname(x[:, 1][2])) -(@o _[Colon(), 1][2]) +You can detect whether a `VarName` contains any dynamic indices using [`is_dynamic`](@ref): -julia> getoptic(@varname(x[1,2][1+5][45][3])) -(@o _[1, 2][6][45][3]) +```jldoctest +julia> vn = @varname(x[1, end-1]); AbstractPPL.is_dynamic(vn) +true ``` -This also means that we support property access: +To concretize such expressions, you can call [`concretize`](@ref) on the resulting +`VarName`. After concretization, the resulting `VarName` will no longer be dynamic. ```jldoctest -julia> getoptic(@varname(x.a)) -(@o _.a) +julia> x = randn(2, 3); + +julia> vn = @varname(x[1, end-1]); vn2 = AbstractPPL.concretize(vn, x) +x[1, 2] -julia> getoptic(@varname(x.a[1])) -(@o _.a[1]) +julia> getoptic(vn2).ix # Just an ordinary tuple. +(1, 2) -julia> x = (a = [(b = rand(2), )], ); getoptic(@varname(x.a[1].b[end], true)) -(@o _.a[1].b[2]) +julia> AbstractPPL.is_dynamic(vn2) +false ``` -Interpolation can be used for variable names, or array name, but not the lhs of a `.` expression. -Variables within indices are always evaluated in the calling scope. +Alternatively, you can pass `true` as the second positional argument to the `@varname` macro +(note that it is not a keyword argument!). This will automatically call [`concretize`](@ref) +for you, using the top-level symbol to look up the value used for concretization. ```jldoctest -julia> name, i = :a, 10; +julia> x = randn(2, 3); -julia> @varname(x.\$name[i, i+1]) -x.a[10, 11] +julia> @varname(x[1:end, end][:], true) +x[1:2, 3][:] +``` + +# Interpolation + +Property names, as well as top-level symbols, can also be constructed from interpolated +symbols: + +```jldoctest +julia> name = :hello; @varname(x.\$name) +x.hello julia> @varname(\$name) -a +hello -julia> @varname(\$name[1]) -a[1] +julia> @varname(\$name.a.\$name[1]) +hello.a.hello[1] +``` -julia> @varname(\$name.x[1]) -a.x[1] +For indices, you do not need to use `\$` to interpolate, just use the variable directly: -julia> @varname(b.\$name.x[1]) -b.a.x[1] +```jldoctest +julia> ix = 2; @varname(x[ix]) +x[2] +``` + +Note that if the top-level symbol is interpolated, automatic concretization is not possible: + +```jldoctest +julia> name = :x; @varname(\$name[1:end], true) +ERROR: LoadError: cannot automatically concretize VarName with interpolated top-level symbol; call `concretize(vn, val)` manually instead +[...] ``` """ -macro varname(expr::Union{Expr,Symbol}, concretize::Bool=Accessors.need_dynamic_optic(expr)) +macro varname(expr, concretize::Bool=false) return varname(expr, concretize) end -varname(sym::Symbol) = :($(AbstractPPL.VarName){$(QuoteNode(sym))}()) -varname(sym::Symbol, _) = varname(sym) -function varname(expr::Expr, concretize=Accessors.need_dynamic_optic(expr)) - if Meta.isexpr(expr, :ref) || Meta.isexpr(expr, :.) - # Split into object/base symbol and lens. - sym_escaped, optics = _parse_obj_optic(expr) - # Setfield.jl escapes the return symbol, so we need to unescape - # to call `QuoteNode` on it. - sym = drop_escape(sym_escaped) - - # This is to handle interpolated heads -- Setfield treats them differently: - # julia> AbstractPPL._parse_obj_optics(Meta.parse("\$name.a")) - # (:($(Expr(:escape, :_))), (:($(Expr(:escape, :name))), :((PropertyLens){:a}()))) - # julia> AbstractPPL._parse_obj_optic(:(x.a)) - # (:($(Expr(:escape, :x))), :(Accessors.opticcompose((PropertyLens){:a}()))) - if sym != :_ - sym = QuoteNode(sym) - else - sym = optics.args[2] - optics = Expr(:call, optics.args[1], optics.args[3:end]...) - end +""" + varname(expr, concretize::Bool) - if concretize - return :($(AbstractPPL.VarName){$sym}( - $(AbstractPPL.concretize)($optics, $sym_escaped) - )) - elseif Accessors.need_dynamic_optic(expr) - error("Variable name `$(expr)` is dynamic and requires concretization!") - else - return :($(AbstractPPL.VarName){$sym}($optics)) - end - elseif Meta.isexpr(expr, :$, 1) - return :($(AbstractPPL.VarName){$(esc(expr.args[1]))}()) +Implementation of the `@varname` macro. See the documentation for `@varname` for details. +This function is exported to allow other macros (e.g. in DynamicPPL) to reuse the same +logic. +""" +function varname(expr, concretize::Bool) + unconcretized_vn, sym = _varname(expr, :($(Iden)())) + return if concretize + sym === nothing && throw(VarNameConcretizationException()) + :($(AbstractPPL.concretize)($unconcretized_vn, $(esc(sym)))) else - error("Malformed variable name `$(expr)`!") + unconcretized_vn end end - -drop_escape(x) = x -function drop_escape(expr::Expr) - Meta.isexpr(expr, :escape) && return drop_escape(expr.args[1]) - return Expr(expr.head, map(x -> drop_escape(x), expr.args)...) +function _varname(@nospecialize(expr::Any), ::Any) + # fallback: it's not a variable! + throw(VarNameParseException(expr)) end - -function _parse_obj_optic(ex) - obj, optics = _parse_obj_optics(ex) - optic = Expr(:call, Accessors.opticcompose, optics...) - return obj, optic +function _varname(sym::Symbol, inner_expr) + return :($(VarName){$(QuoteNode(sym))}($inner_expr)), sym end - -# Accessors doesn't have the same support for interpolation -# so this function is copied and altered from `Setfield._parse_obj_lens` -function _parse_obj_optics(ex) - if Meta.isexpr(ex, :$, 1) - return esc(:_), (esc(ex.args[1]),) - elseif Meta.isexpr(ex, :ref) && !isempty(ex.args) - front, indices... = ex.args - obj, frontoptics = _parse_obj_optics(front) - if any(Accessors.need_dynamic_optic, indices) - @gensym collection - indices = Accessors.replace_underscore.(indices, collection) - dims = length(indices) == 1 ? nothing : 1:length(indices) - lindices = esc.(Accessors.lower_index.(collection, indices, dims)) - optics = - :($(Accessors.DynamicIndexLens)($(esc(collection)) -> ($(lindices...),))) - else - index = esc(Expr(:tuple, indices...)) - optics = :($(Accessors.IndexLens)($index)) - end - elseif Meta.isexpr(ex, :., 2) - front = ex.args[1] - property = ex.args[2].value # ex.args[2] is a QuoteNode - obj, frontoptics = _parse_obj_optics(front) - if property isa Union{Symbol,String} - optics = :($(Accessors.PropertyLens){$(QuoteNode(property))}()) - elseif Meta.isexpr(property, :$, 1) - optics = :($(Accessors.PropertyLens){$(esc(property.args[1]))}()) +function _varname(expr::Expr, inner_expr) + if Meta.isexpr(expr, :$, 1) + # Interpolation of the top-level symbol e.g. @varname($name). If we hit this branch, + # it means that there are no further property/indexing accesses (because otherwise + # expr.head would be :ref or :.) Thus we don't need to recurse further, and we can + # just return `inner_expr` as-is. + sym_expr = expr.args[1] + return :($(VarName){$(esc(sym_expr))}($inner_expr)), nothing + else + next_inner = if expr.head == :(.) + sym = _handle_property(expr.args[2], expr) + :($(Property){$(sym)}($inner_expr)) + elseif expr.head == :ref + original_ixs = expr.args[2:end] + positional_args = [] + keyword_args = [] + for (dim, ix_expr) in enumerate(original_ixs) + if _is_kw(ix_expr) + push!(keyword_args, :($(ix_expr.args[1]) = $(esc(ix_expr.args[2])))) + else + push!(positional_args, (dim, ix_expr)) + end + end + is_single_index = length(positional_args) == 1 + positional_ixs = map(positional_args) do (dim, ix_expr) + _handle_index(ix_expr, is_single_index ? nothing : dim) + end + positional_expr = Expr(:tuple, positional_ixs...) + kwarg_expr = if isempty(keyword_args) + :((;)) + else + Expr(:tuple, keyword_args...) + end + :($(Index)($positional_expr, $kwarg_expr, $inner_expr)) else - throw( - ArgumentError( - string( - "Error while parsing :($ex). Second argument to `getproperty` can only be", - "a `Symbol` or `String` literal, received `$property` instead.", - ), - ), - ) + # some other expression we can't parse + throw(VarNameParseException(expr)) end - else - obj = esc(ex) - return obj, () + return _varname(expr.args[1], next_inner) end - return obj, tuple(frontoptics..., optics) end -""" - @vsym(expr) - -A macro that returns the variable symbol given the input variable expression `expr`. -For example, `@vsym x[1]` returns `:x`. - -## Examples - -```jldoctest -julia> @vsym x -:x - -julia> @vsym x[1,1][2,3] -:x - -julia> @vsym x[end] -:x -``` -""" -macro vsym(expr::Union{Expr,Symbol}) - return QuoteNode(vsym(expr)) -end - -""" - vsym(expr) - -Return name part of the [`@varname`](@ref)-compatible expression `expr` as a symbol for input of the -[`VarName`](@ref) constructor. -""" -function vsym end - -vsym(expr::Symbol) = expr -function vsym(expr::Expr) - if Meta.isexpr(expr, :ref) || Meta.isexpr(expr, :.) - return vsym(expr.args[1]) +function _handle_property(qn::QuoteNode, original_expr) + if qn.value isa Symbol # no interpolation e.g. @varname(x.a) + return qn + elseif Meta.isexpr(qn.value, :$, 1) && qn.value.args[1] isa Symbol + # interpolated property e.g. @varname(x.$name). + # TODO(penelopeysm): Note that $name must evaluate to a Symbol, or else you will get + # a slightly inscrutable error: "ERROR: TypeError: in Type, in parameter, expected + # Type, got a value of type String". This should probably be fixed, but I don't + # actually *know* how to do it. Again, this is not a new issue, the old VarName + # also had the same problem. + return esc(qn.value.args[1]) else - error("Malformed variable name `$(expr)`!") + throw(VarNameParseException(original_expr)) end end +function _handle_property(::Any, original_expr) + throw(VarNameParseException(original_expr)) +end -""" - _head(optic) - -Get the innermost layer of an optic. - -For all (normalised) optics, we have that `normalise(_tail(optic) ∘ -_head(optic) == optic)`. - -!!! note - Does not perform optic normalisation on the input. You may wish to call - `normalise(optic)` before using this function if the optic you are passing - was not obtained from a VarName. - -```jldoctest; setup=:(using Accessors) -julia> AbstractPPL._head(Accessors.@o _.a.b.c) -(@o _.a) - -julia> AbstractPPL._head(Accessors.@o _[1][2][3]) -(@o _[1]) - -julia> AbstractPPL._head(Accessors.@o _.a) -(@o _.a) - -julia> AbstractPPL._head(Accessors.@o _[1]) -(@o _[1]) - -julia> AbstractPPL._head(Accessors.@o _) -identity (generic function with 1 method) -``` -""" -_head(o::ComposedFunction{Outer,Inner}) where {Outer,Inner} = o.inner -_head(o::Accessors.PropertyLens) = o -_head(o::Accessors.IndexLens) = o -_head(::typeof(identity)) = identity +_is_kw(e::Expr) = Meta.isexpr(e, :kw, 2) +_is_kw(::Any) = false +_handle_index(ix::Any, ::Any) = ix +_handle_index(ix::Symbol, dim) = _make_dynamicindex_expr(ix, dim) +_handle_index(ix::Expr, dim) = _make_dynamicindex_expr(ix, dim) """ - _tail(optic) - -Get everything but the innermost layer of an optic. - -For all (normalised) optics, we have that `normalise(_tail(optic) ∘ -_head(optic) == optic)`. + @opticof(expr, concretize=false) -!!! note - Does not perform optic normalisation on the input. You may wish to call - `normalise(optic)` before using this function if the optic you are passing - was not obtained from a VarName. +Extract the optic from `@varname(expr, concretize)`. This is a thin wrapper around +`getoptic(@varname(...))`. -```jldoctest; setup=:(using Accessors) -julia> AbstractPPL._tail(Accessors.@o _.a.b.c) -(@o _.b.c) +If you don't need to concretize, you should use `_` as the top-level symbol to +indicate that it is not relevant: -julia> AbstractPPL._tail(Accessors.@o _[1][2][3]) -(@o _[2][3]) - -julia> AbstractPPL._tail(Accessors.@o _.a) -identity (generic function with 1 method) +```jldoctest +julia> @opticof(_.a.b) +Optic(.a.b) +``` -julia> AbstractPPL._tail(Accessors.@o _[1]) -identity (generic function with 1 method) +If you need to concretize, then you can provide a real variable name (which is then used to +look up the value for concretization): -julia> AbstractPPL._tail(Accessors.@o _) -identity (generic function with 1 method) +```jldoctest +julia> x = randn(3, 4); @opticof(x[1:end, end], true) +Optic([1:3, 4]) ``` -""" -_tail(o::ComposedFunction{Outer,Inner}) where {Outer,Inner} = o.outer -_tail(::Accessors.PropertyLens) = identity -_tail(::Accessors.IndexLens) = identity -_tail(::typeof(identity)) = identity +Note that concretization with `@opticof` has the same limitations as with `@varname`, +specifically, if the top-level symbol is interpolated, automatic concretization is not +possible. """ - _last(optic) - -Get the outermost layer of an optic. - -For all (normalised) optics, we have that `normalise(_last(optic) ∘ -_init(optic)) == optic`. - -!!! note - Does not perform optic normalisation on the input. You may wish to call - `normalise(optic)` before using this function if the optic you are passing - was not obtained from a VarName. - -```jldoctest; setup=:(using Accessors) -julia> AbstractPPL._last(Accessors.@o _.a.b.c) -(@o _.c) - -julia> AbstractPPL._last(Accessors.@o _[1][2][3]) -(@o _[3]) - -julia> AbstractPPL._last(Accessors.@o _.a) -(@o _.a) - -julia> AbstractPPL._last(Accessors.@o _[1]) -(@o _[1]) +macro opticof(expr, concretize::Bool=false) + return :(getoptic($(varname(expr, concretize)))) +end -julia> AbstractPPL._last(Accessors.@o _) -identity (generic function with 1 method) -``` """ -_last(o::ComposedFunction{Outer,Inner}) where {Outer,Inner} = _last(o.outer) -_last(o::Accessors.PropertyLens) = o -_last(o::Accessors.IndexLens) = o -_last(::typeof(identity)) = identity + varname_to_optic(vn::VarName) +Convert a `VarName` to an optic, by converting the top-level symbol to a `Property` optic. """ - _init(optic) - -Get everything but the outermost layer of an optic. - -For all (normalised) optics, we have that `normalise(_last(optic) ∘ -_init(optic)) == optic`. +varname_to_optic(vn::VarName{sym}) where {sym} = Property{sym}(getoptic(vn)) -!!! note - Does not perform optic normalisation on the input. You may wish to call - `normalise(optic)` before using this function if the optic you are passing - was not obtained from a VarName. - -```jldoctest; setup=:(using Accessors) -julia> AbstractPPL._init(Accessors.@o _.a.b.c) -(@o _.a.b) - -julia> AbstractPPL._init(Accessors.@o _[1][2][3]) -(@o _[1][2]) - -julia> AbstractPPL._init(Accessors.@o _.a) -identity (generic function with 1 method) - -julia> AbstractPPL._init(Accessors.@o _[1]) -identity (generic function with 1 method) +""" + optic_to_varname(optic::Property{sym}) where {sym} -julia> AbstractPPL._init(Accessors.@o _) -identity (generic function with 1 method) +Convert a `Property` optic to a `VarName`, by converting the top-level property to a symbol. +This fails for all other optics. """ -# This one needs normalise because it's going 'against' the direction of the -# linked list (otherwise you will end up with identities scattered throughout) -function _init(o::ComposedFunction{Outer,Inner}) where {Outer,Inner} - return normalise(_init(o.outer) ∘ o.inner) +optic_to_varname(optic::Property{sym}) where {sym} = VarName{sym}(otail(optic)) +function optic_to_varname(::AbstractOptic) + throw(ArgumentError("optic_to_varname: can only convert Property optics to VarName")) end -_init(::Accessors.PropertyLens) = identity -_init(::Accessors.IndexLens) = identity -_init(::typeof(identity)) = identity diff --git a/test/Aqua.jl b/test/Aqua.jl index 6f608e2c..27d4103f 100644 --- a/test/Aqua.jl +++ b/test/Aqua.jl @@ -3,8 +3,6 @@ module AquaTests using Aqua: Aqua using AbstractPPL -# For now, we skip ambiguities since they come from interactions -# with third-party packages rather than issues in AbstractPPL itself -Aqua.test_all(AbstractPPL; ambiguities=false) +Aqua.test_all(AbstractPPL) end diff --git a/test/Project.toml b/test/Project.toml index 170c644f..79803fe6 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,9 +1,11 @@ [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" @@ -16,6 +18,7 @@ Aqua = "0.8" Distributions = "0.25" Documenter = "0.26.3, 0.27, 1" InvertedIndices = "1" +JET = "0.9, 0.10, 0.11" OffsetArrays = "1" OrderedCollections = "1" julia = "1" diff --git a/test/runtests.jl b/test/runtests.jl index cb07ee02..fc8d3980 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,9 +7,13 @@ const GROUP = get(ENV, "GROUP", "All") @testset "AbstractPPL.jl" begin if GROUP == "All" || GROUP == "Tests" include("Aqua.jl") - include("varname.jl") include("abstractprobprog.jl") - include("hasvalue.jl") + include("varname/optic.jl") + include("varname/varname.jl") + include("varname/subsumes.jl") + include("varname/hasvalue.jl") + include("varname/leaves.jl") + include("varname/serialize.jl") end if GROUP == "All" || GROUP == "Doctests" diff --git a/test/varname.jl b/test/varname.jl deleted file mode 100644 index c86c4b4e..00000000 --- a/test/varname.jl +++ /dev/null @@ -1,438 +0,0 @@ -using Accessors -using InvertedIndices -using OffsetArrays -using LinearAlgebra: LowerTriangular, UpperTriangular, cholesky - -using AbstractPPL: ⊑, ⊒, ⋢, ⋣, ≍ - -using AbstractPPL: Accessors -using AbstractPPL.Accessors: IndexLens, PropertyLens, ⨟ - -macro test_strict_subsumption(x, y) - quote - @test $((varname(x))) ⊑ $((varname(y))) - @test $((varname(x))) ⋣ $((varname(y))) - end -end - -function test_equal(o1::VarName{sym1}, o2::VarName{sym2}) where {sym1,sym2} - return sym1 === sym2 && test_equal(o1.optic, o2.optic) -end -function test_equal(o1::ComposedFunction, o2::ComposedFunction) - return test_equal(o1.inner, o2.inner) && test_equal(o1.outer, o2.outer) -end -function test_equal(o1::Accessors.IndexLens, o2::Accessors.IndexLens) - return test_equal(o1.indices, o2.indices) -end -function test_equal(o1, o2) - return o1 == o2 -end - -@testset "varnames" begin - @testset "string and symbol conversion" begin - vn1 = @varname x[1][2] - @test string(vn1) == "x[1][2]" - @test Symbol(vn1) == Symbol("x[1][2]") - end - - @testset "equality and hashing" begin - vn1 = @varname x[1][2] - vn2 = @varname x[1][2] - @test vn2 == vn1 - @test hash(vn2) == hash(vn1) - end - - @testset "inspace" begin - space = (:x, :y, @varname(z[1]), @varname(M[1:10, :])) - @test inspace(@varname(x), space) - @test inspace(@varname(y), space) - @test inspace(@varname(x[1]), space) - @test inspace(@varname(z[1][1]), space) - @test inspace(@varname(z[1][:]), space) - @test inspace(@varname(z[1][2:3:10]), space) - @test inspace(@varname(M[[2, 3], 1]), space) - @test_throws ErrorException inspace(@varname(M[:, 1:4]), space) - @test inspace(@varname(M[1, [2, 4, 6]]), space) - @test !inspace(@varname(z[2]), space) - @test !inspace(@varname(z), space) - end - - @testset "optic normalisation" begin - # Push the limits a bit with four optics, one of which is identity, and - # we'll parenthesise them in every possible way. (Some of these are - # going to be equal even before normalisation, but we should test that - # `normalise` works regardless of how Base or Accessors.jl define - # associativity.) - op1 = ((@o _.c) ∘ (@o _.b)) ∘ identity ∘ (@o _.a) - op2 = (@o _.c) ∘ ((@o _.b) ∘ identity) ∘ (@o _.a) - op3 = (@o _.c) ∘ (@o _.b) ∘ (identity ∘ (@o _.a)) - op4 = ((@o _.c) ∘ (@o _.b) ∘ identity) ∘ (@o _.a) - op5 = (@o _.c) ∘ ((@o _.b) ∘ identity ∘ (@o _.a)) - op6 = (@o _.c) ∘ (@o _.b) ∘ identity ∘ (@o _.a) - for op in (op1, op2, op3, op4, op5, op6) - @test AbstractPPL.normalise(op) == (@o _.c) ∘ (@o _.b) ∘ (@o _.a) - end - # Prefix and unprefix also provide further testing for normalisation. - end - - @testset "construction & concretization" begin - i = 1:10 - j = 2:2:5 - @test @varname(A[1].b[i]) == @varname(A[1].b[1:10]) - @test @varname(A[j]) == @varname(A[2:2:5]) - - @test @varname(A[:, 1][1 + 1]) == @varname(A[:, 1][2]) - @test(@varname(A[:, 1][2]) == VarName{:A}(@o(_[:, 1]) ⨟ @o(_[2]))) - - # concretization - y = zeros(10, 10) - x = (a=[1.0 2.0; 3.0 4.0; 5.0 6.0],) - - @test @varname(y[begin, i], true) == @varname(y[1, 1:10]) - @test test_equal(@varname(y[:], true), @varname(y[1:100])) - @test test_equal(@varname(y[:, begin], true), @varname(y[1:10, 1])) - @test getoptic(AbstractPPL.concretize(@varname(y[:]), y)).indices[1] === - AbstractPPL.ConcretizedSlice(to_indices(y, (:,))[1]) - @test test_equal(@varname(x.a[1:end, end][:], true), @varname(x.a[1:3, 2][1:3])) - end - - @testset "compose and opcompose" begin - @test IndexLens(1) ∘ @varname(x.a) == @varname(x.a[1]) - @test @varname(x.a) ⨟ IndexLens(1) == @varname(x.a[1]) - - @test @varname(x) ⨟ identity == @varname(x) - @test identity ∘ @varname(x) == @varname(x) - @test @varname(x.a) ⨟ identity == @varname(x.a) - @test identity ∘ @varname(x.a) == @varname(x.a) - @test @varname(x[1].b) ⨟ identity == @varname(x[1].b) - @test identity ∘ @varname(x[1].b) == @varname(x[1].b) - end - - @testset "get & set" begin - x = (a=[1.0 2.0; 3.0 4.0; 5.0 6.0], b=1.0) - @test get(x, @varname(a[1, 2])) == 2.0 - @test get(x, @varname(b)) == 1.0 - @test set(x, @varname(a[1, 2]), 10) == (a=[1.0 10.0; 3.0 4.0; 5.0 6.0], b=1.0) - @test set(x, @varname(b), 10) == (a=[1.0 2.0; 3.0 4.0; 5.0 6.0], b=10.0) - end - - @testset "subsumption with standard indexing" begin - # x ⊑ x - @test @varname(x) ⊑ @varname(x) - @test @varname(x[1]) ⊑ @varname(x[1]) - @test @varname(x.a) ⊑ @varname(x.a) - - # x ≍ y - @test @varname(x) ≍ @varname(y) - @test @varname(x.a) ≍ @varname(y.a) - @test @varname(a.x) ≍ @varname(a.y) - @test @varname(a.x[1]) ≍ @varname(a.x.z) - @test @varname(x[1]) ≍ @varname(y[1]) - @test @varname(x[1]) ≍ @varname(x.y) - - # x ∘ ℓ ⊑ x - @test_strict_subsumption x.a x - @test_strict_subsumption x[1] x - @test_strict_subsumption x[2:2:5] x - @test_strict_subsumption x[10, 20] x - - # x ∘ ℓ₁ ⊑ x ∘ ℓ₂ ⇔ ℓ₁ ⊑ ℓ₂ - @test_strict_subsumption x.a.b x.a - @test_strict_subsumption x[1].a x[1] - @test_strict_subsumption x.a[1] x.a - @test_strict_subsumption x[1:10][2] x[1:10] - - @test_strict_subsumption x[1] x[1:10] - @test_strict_subsumption x[1:5] x[1:10] - @test_strict_subsumption x[4:6] x[1:10] - - @test_strict_subsumption x[[2, 3, 5]] x[[7, 6, 5, 4, 3, 2, 1]] - - @test_strict_subsumption x[:a][1] x[:a] - - # boolean indexing works as long as it is concretized - A = rand(10, 10) - @test @varname(A[iseven.(1:10), 1], true) ⊑ @varname(A[1:10, 1]) - @test @varname(A[iseven.(1:10), 1], true) ⋣ @varname(A[1:10, 1]) - - # we can reasonably allow colons on the right side ("universal set") - @test @varname(x[1]) ⊑ @varname(x[:]) - @test @varname(x[1:10, 1]) ⊑ @varname(x[:, 1:10]) - @test_throws ErrorException (@varname(x[:]) ⊑ @varname(x[1])) - @test_throws ErrorException (@varname(x[:]) ⊑ @varname(x[:])) - end - - @testset "non-standard indexing" begin - A = rand(10, 10) - @test test_equal( - @varname(A[1, Not(3)], true), @varname(A[1, [1, 2, 4, 5, 6, 7, 8, 9, 10]]) - ) - - B = OffsetArray(A, -5, -5) # indices -4:5×-4:5 - @test test_equal(@varname(B[1, :], true), @varname(B[1, -4:5])) - end - @testset "type stability" begin - @inferred VarName{:a}() - @inferred VarName{:a}(IndexLens(1)) - @inferred VarName{:a}(IndexLens(1, 2)) - @inferred VarName{:a}(PropertyLens(:b)) - @inferred VarName{:a}(Accessors.opcompose(IndexLens(1), PropertyLens(:b))) - - b = (a=[1, 2, 3],) - @inferred get(b, @varname(a[1])) - @inferred Accessors.set(b, @varname(a[1]), 10) - - c = (b=(a=[1, 2, 3],),) - @inferred get(c, @varname(b.a[1])) - @inferred Accessors.set(c, @varname(b.a[1]), 10) - end - - @testset "de/serialisation of VarNames" begin - y = ones(10) - z = ones(5, 2) - vns = [ - @varname(x), - @varname(ä), - @varname(x.a), - @varname(x.a.b), - @varname(var"x.a"), - @varname(x[1]), - @varname(var"x[1]"), - @varname(x[1:10]), - @varname(x[1:3:10]), - @varname(x[1, 2]), - @varname(x[1, 2:5]), - @varname(x[:]), - @varname(x.a[1]), - @varname(x.a[1:10]), - @varname(x[1].a), - @varname(y[:]), - @varname(y[begin:end]), - @varname(y[end]), - @varname(y[:], false), - @varname(y[:], true), - @varname(z[:], false), - @varname(z[:], true), - @varname(z[:][:], false), - @varname(z[:][:], true), - @varname(z[:, :], false), - @varname(z[:, :], true), - @varname(z[2:5, :], false), - @varname(z[2:5, :], true), - ] - for vn in vns - @test string_to_varname(varname_to_string(vn)) == vn - end - - # For this VarName, the {de,}serialisation works correctly but we must - # test in a different way because equality comparison of structs with - # vector fields (such as Accessors.IndexLens) compares the memory - # addresses rather than the contents (thus vn_vec == vn_vec2 returns - # false). - vn_vec = @varname(x[[1, 2, 5, 6]]) - vn_vec2 = string_to_varname(varname_to_string(vn_vec)) - @test hash(vn_vec) == hash(vn_vec2) - end - - @testset "de/serialisation of VarNames with custom index types" begin - using OffsetArrays: OffsetArrays, Origin - weird = Origin(4)(ones(10)) - vn = @varname(weird[:], true) - - # This won't work as we don't yet know how to handle OffsetArray - @test_throws MethodError varname_to_string(vn) - - # Now define the relevant methods - AbstractPPL.index_to_dict(o::OffsetArrays.IdOffsetRange{I,R}) where {I,R} = Dict( - "type" => "OffsetArrays.OffsetArray", - "parent" => AbstractPPL.index_to_dict(o.parent), - "offset" => o.offset, - ) - AbstractPPL.dict_to_index(::Val{Symbol("OffsetArrays.OffsetArray")}, d) = - OffsetArrays.IdOffsetRange(AbstractPPL.dict_to_index(d["parent"]), d["offset"]) - - # Serialisation should now work - @test string_to_varname(varname_to_string(vn)) == vn - end - - @testset "head, tail, init, last" begin - @testset "specification" begin - @test AbstractPPL._head(@o _.a.b.c) == @o _.a - @test AbstractPPL._tail(@o _.a.b.c) == @o _.b.c - @test AbstractPPL._init(@o _.a.b.c) == @o _.a.b - @test AbstractPPL._last(@o _.a.b.c) == @o _.c - - @test AbstractPPL._head(@o _[1][2][3]) == @o _[1] - @test AbstractPPL._tail(@o _[1][2][3]) == @o _[2][3] - @test AbstractPPL._init(@o _[1][2][3]) == @o _[1][2] - @test AbstractPPL._last(@o _[1][2][3]) == @o _[3] - - @test AbstractPPL._head(@o _.a) == @o _.a - @test AbstractPPL._tail(@o _.a) == identity - @test AbstractPPL._init(@o _.a) == identity - @test AbstractPPL._last(@o _.a) == @o _.a - - @test AbstractPPL._head(@o _[1]) == @o _[1] - @test AbstractPPL._tail(@o _[1]) == identity - @test AbstractPPL._init(@o _[1]) == identity - @test AbstractPPL._last(@o _[1]) == @o _[1] - - @test AbstractPPL._head(identity) == identity - @test AbstractPPL._tail(identity) == identity - @test AbstractPPL._init(identity) == identity - @test AbstractPPL._last(identity) == identity - end - - @testset "composition" begin - varnames = ( - @varname(x), - @varname(x[1]), - @varname(x.a), - @varname(x.a.b), - @varname(x[1].a), - ) - for vn in varnames - optic = getoptic(vn) - @test AbstractPPL.normalise( - AbstractPPL._last(optic) ∘ AbstractPPL._init(optic) - ) == optic - @test AbstractPPL.normalise( - AbstractPPL._tail(optic) ∘ AbstractPPL._head(optic) - ) == optic - end - end - end - - @testset "prefix and unprefix" begin - @testset "basic cases" begin - @test prefix(@varname(y), @varname(x)) == @varname(x.y) - @test prefix(@varname(y), @varname(x[1])) == @varname(x[1].y) - @test prefix(@varname(y), @varname(x.a)) == @varname(x.a.y) - @test prefix(@varname(y[1]), @varname(x)) == @varname(x.y[1]) - @test prefix(@varname(y.a), @varname(x)) == @varname(x.y.a) - - @test unprefix(@varname(x.y[1]), @varname(x)) == @varname(y[1]) - @test unprefix(@varname(x[1].y), @varname(x[1])) == @varname(y) - @test unprefix(@varname(x.a.y), @varname(x.a)) == @varname(y) - @test unprefix(@varname(x.y.a), @varname(x)) == @varname(y.a) - @test_throws ArgumentError unprefix(@varname(x.y.a), @varname(n)) - @test_throws ArgumentError unprefix(@varname(x.y.a), @varname(x[1])) - end - - @testset "round-trip" begin - # These seem similar to the ones above, but in the past they used - # to error because of issues with un-normalised ComposedFunction - # optics. We explicitly test round-trip (un)prefixing here to make - # sure that there aren't any regressions. - # This tuple is probably overkill, but the tests are super fast - # anyway. - vns = ( - @varname(p), - @varname(q), - @varname(r[1]), - @varname(s.a), - @varname(t[1].a), - @varname(u[1].a.b), - @varname(v.a[1][2].b.c.d[3]) - ) - for vn1 in vns - for vn2 in vns - prefixed = prefix(vn1, vn2) - @test subsumes(vn2, prefixed) - unprefixed = unprefix(prefixed, vn2) - @test unprefixed == vn1 - end - end - end - end - - @testset "varname{_and_value}_leaves" begin - @testset "single value: float, int" begin - x = 1.0 - @test Set(varname_leaves(@varname(x), x)) == Set([@varname(x)]) - @test Set(collect(varname_and_value_leaves(@varname(x), x))) == - Set([(@varname(x), x)]) - x = 2 - @test Set(varname_leaves(@varname(x), x)) == Set([@varname(x)]) - @test Set(collect(varname_and_value_leaves(@varname(x), x))) == - Set([(@varname(x), x)]) - end - - @testset "Vector" begin - x = randn(2) - @test Set(varname_leaves(@varname(x), x)) == - Set([@varname(x[1]), @varname(x[2])]) - @test Set(collect(varname_and_value_leaves(@varname(x), x))) == - Set([(@varname(x[1]), x[1]), (@varname(x[2]), x[2])]) - x = [(; a=1), (; b=2)] - @test Set(varname_leaves(@varname(x), x)) == - Set([@varname(x[1].a), @varname(x[2].b)]) - @test Set(collect(varname_and_value_leaves(@varname(x), x))) == - Set([(@varname(x[1].a), x[1].a), (@varname(x[2].b), x[2].b)]) - end - - @testset "Matrix" begin - x = randn(2, 2) - @test Set(varname_leaves(@varname(x), x)) == Set([ - @varname(x[1, 1]), @varname(x[1, 2]), @varname(x[2, 1]), @varname(x[2, 2]) - ]) - @test Set(collect(varname_and_value_leaves(@varname(x), x))) == Set([ - (@varname(x[1, 1]), x[1, 1]), - (@varname(x[1, 2]), x[1, 2]), - (@varname(x[2, 1]), x[2, 1]), - (@varname(x[2, 2]), x[2, 2]), - ]) - end - - @testset "Lower/UpperTriangular" begin - x = randn(2, 2) - xl = LowerTriangular(x) - @test Set(varname_leaves(@varname(x), xl)) == - Set([@varname(x[1, 1]), @varname(x[2, 1]), @varname(x[2, 2])]) - @test Set(collect(varname_and_value_leaves(@varname(x), xl))) == Set([ - (@varname(x[1, 1]), x[1, 1]), - (@varname(x[2, 1]), x[2, 1]), - (@varname(x[2, 2]), x[2, 2]), - ]) - xu = UpperTriangular(x) - @test Set(varname_leaves(@varname(x), xu)) == - Set([@varname(x[1, 1]), @varname(x[1, 2]), @varname(x[2, 2])]) - @test Set(collect(varname_and_value_leaves(@varname(x), xu))) == Set([ - (@varname(x[1, 1]), x[1, 1]), - (@varname(x[1, 2]), x[1, 2]), - (@varname(x[2, 2]), x[2, 2]), - ]) - end - - @testset "NamedTuple" begin - x = (a=1.0, b=[2.0, 3.0]) - @test Set(varname_leaves(@varname(x), x)) == - Set([@varname(x.a), @varname(x.b[1]), @varname(x.b[2])]) - @test Set(collect(varname_and_value_leaves(@varname(x), x))) == Set([ - (@varname(x.a), x.a), (@varname(x.b[1]), x.b[1]), (@varname(x.b[2]), x.b[2]) - ]) - end - - @testset "Cholesky" begin - x = cholesky([1.0 0.5; 0.5 1.0]) - @test Set(varname_leaves(@varname(x), x)) == - Set([@varname(x.U[1, 1]), @varname(x.U[1, 2]), @varname(x.U[2, 2])]) - @test Set(collect(varname_and_value_leaves(@varname(x), x))) == Set([ - (@varname(x.U[1, 1]), x.U[1, 1]), - (@varname(x.U[1, 2]), x.U[1, 2]), - (@varname(x.U[2, 2]), x.U[2, 2]), - ]) - end - - @testset "fallback on other types, e.g. string" begin - x = "a string" - @test Set(varname_leaves(@varname(x), x)) == Set([@varname(x)]) - @test Set(collect(varname_and_value_leaves(@varname(x), x))) == - Set([(@varname(x), x)]) - x = 2 - @test Set(varname_leaves(@varname(x), x)) == Set([@varname(x)]) - @test Set(collect(varname_and_value_leaves(@varname(x), x))) == - Set([(@varname(x), x)]) - end - end -end diff --git a/test/hasvalue.jl b/test/varname/hasvalue.jl similarity index 86% rename from test/hasvalue.jl rename to test/varname/hasvalue.jl index 5881eaa2..a0b03b3d 100644 --- a/test/hasvalue.jl +++ b/test/varname/hasvalue.jl @@ -1,3 +1,9 @@ +module VarNameHasValueTests + +using AbstractPPL +using DimensionalData: DimensionalData as DD +using Test + @testset "base getvalue + hasvalue" begin @testset "basic NamedTuple" begin nt = (a=[1], b=2, c=(x=3, y=[4], z=(; p=[(; q=5)])), d=[1.0 0.5; 0.5 1.0]) @@ -156,6 +162,36 @@ @test !hasvalue(d, @varname(x[2][1][1][1])) end end + + @testset "DimArray indices (including keyword)" begin + x = (; a=DD.DimArray(randn(2, 3), (:i, :j))) + @test hasvalue(x, @varname(a)) + @test getvalue(x, @varname(a)) == x.a + @test hasvalue(x, @varname(a[1, 2])) + @test getvalue(x, @varname(a[1, 2])) == x.a[1, 2] + @test hasvalue(x, @varname(a[:])) + @test getvalue(x, @varname(a[:])) == x.a[:] + @test canview(@opticof(_[i=1]), x.a) + @test hasvalue(x, @varname(a[i=1])) + @test getvalue(x, @varname(a[i=1])) == x.a[i=1] + @test canview(@opticof(_[i=1, j=2]), x.a) + @test hasvalue(x, @varname(a[i=1, j=2])) + @test getvalue(x, @varname(a[i=1, j=2])) == x.a[i=1, j=2] + @test hasvalue(x, @varname(a[i=DD.Not(1)])) + @test getvalue(x, @varname(a[i=DD.Not(1)])) == x.a[i=DD.Not(1)] + + y = (; b=DD.DimArray(randn(2, 3), (DD.X, DD.Y))) + @test hasvalue(y, @varname(b)) + @test getvalue(y, @varname(b)) == y.b + @test hasvalue(y, @varname(b[1, 2])) + @test getvalue(y, @varname(b[1, 2])) == y.b[1, 2] + @test hasvalue(y, @varname(b[:])) + @test getvalue(y, @varname(b[:])) == y.b[:] + @test hasvalue(y, @varname(b[DD.X(1)])) + @test getvalue(y, @varname(b[DD.X(1)])) == y.b[DD.X(1)] + @test hasvalue(y, @varname(b[DD.X(1), DD.Y(2)])) + @test getvalue(y, @varname(b[DD.X(1), DD.Y(2)])) == y.b[DD.X(1), DD.Y(2)] + end end @testset "with Distributions: getvalue + hasvalue" begin @@ -215,3 +251,5 @@ end @test !hasvalue(d, @varname(y), LKJCholesky(3, 1.0, :U); error_on_incomplete=true) end end + +end diff --git a/test/varname/leaves.jl b/test/varname/leaves.jl new file mode 100644 index 00000000..e21ed29c --- /dev/null +++ b/test/varname/leaves.jl @@ -0,0 +1,96 @@ +module VarNameLeavesTests + +using AbstractPPL +using Test +using LinearAlgebra: LowerTriangular, UpperTriangular, cholesky + +@testset "varname/leaves.jl" verbose = true begin + @testset "single value: float, int" begin + x = 1.0 + @test Set(varname_leaves(@varname(x), x)) == Set([@varname(x)]) + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == + Set([(@varname(x), x)]) + x = 2 + @test Set(varname_leaves(@varname(x), x)) == Set([@varname(x)]) + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == + Set([(@varname(x), x)]) + end + + @testset "Vector" begin + x = randn(2) + @test Set(varname_leaves(@varname(x), x)) == Set([@varname(x[1]), @varname(x[2])]) + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == + Set([(@varname(x[1]), x[1]), (@varname(x[2]), x[2])]) + x = [(; a=1), (; b=2)] + @test Set(varname_leaves(@varname(x), x)) == + Set([@varname(x[1].a), @varname(x[2].b)]) + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == + Set([(@varname(x[1].a), x[1].a), (@varname(x[2].b), x[2].b)]) + end + + @testset "Matrix" begin + x = randn(2, 2) + @test Set(varname_leaves(@varname(x), x)) == Set([ + @varname(x[1, 1]), @varname(x[1, 2]), @varname(x[2, 1]), @varname(x[2, 2]) + ]) + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == Set([ + (@varname(x[1, 1]), x[1, 1]), + (@varname(x[1, 2]), x[1, 2]), + (@varname(x[2, 1]), x[2, 1]), + (@varname(x[2, 2]), x[2, 2]), + ]) + end + + @testset "Lower/UpperTriangular" begin + x = randn(2, 2) + xl = LowerTriangular(x) + @test Set(varname_leaves(@varname(x), xl)) == + Set([@varname(x[1, 1]), @varname(x[2, 1]), @varname(x[2, 2])]) + @test Set(collect(varname_and_value_leaves(@varname(x), xl))) == Set([ + (@varname(x[1, 1]), x[1, 1]), + (@varname(x[2, 1]), x[2, 1]), + (@varname(x[2, 2]), x[2, 2]), + ]) + xu = UpperTriangular(x) + @test Set(varname_leaves(@varname(x), xu)) == + Set([@varname(x[1, 1]), @varname(x[1, 2]), @varname(x[2, 2])]) + @test Set(collect(varname_and_value_leaves(@varname(x), xu))) == Set([ + (@varname(x[1, 1]), x[1, 1]), + (@varname(x[1, 2]), x[1, 2]), + (@varname(x[2, 2]), x[2, 2]), + ]) + end + + @testset "NamedTuple" begin + x = (a=1.0, b=[2.0, 3.0]) + @test Set(varname_leaves(@varname(x), x)) == + Set([@varname(x.a), @varname(x.b[1]), @varname(x.b[2])]) + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == Set([ + (@varname(x.a), x.a), (@varname(x.b[1]), x.b[1]), (@varname(x.b[2]), x.b[2]) + ]) + end + + @testset "Cholesky" begin + x = cholesky([1.0 0.5; 0.5 1.0]) + @test Set(varname_leaves(@varname(x), x)) == + Set([@varname(x.U[1, 1]), @varname(x.U[1, 2]), @varname(x.U[2, 2])]) + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == Set([ + (@varname(x.U[1, 1]), x.U[1, 1]), + (@varname(x.U[1, 2]), x.U[1, 2]), + (@varname(x.U[2, 2]), x.U[2, 2]), + ]) + end + + @testset "fallback on other types, e.g. string" begin + x = "a string" + @test Set(varname_leaves(@varname(x), x)) == Set([@varname(x)]) + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == + Set([(@varname(x), x)]) + x = 2 + @test Set(varname_leaves(@varname(x), x)) == Set([@varname(x)]) + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == + Set([(@varname(x), x)]) + end +end + +end # module diff --git a/test/varname/optic.jl b/test/varname/optic.jl new file mode 100644 index 00000000..c4e3b736 --- /dev/null +++ b/test/varname/optic.jl @@ -0,0 +1,413 @@ +module OpticTests + +using Test +using DimensionalData: DimensionalData as DD +using AbstractPPL + +@testset "varname/optic.jl" verbose = true begin + @testset "pretty-printing" begin + @test string(@opticof(_.a.b.c)) == "Optic(.a.b.c)" + @test string(@opticof(_[1][2][3])) == "Optic([1][2][3])" + @test string(@opticof(_["a"][:b])) == "Optic([\"a\"][:b])" + @test string(@opticof(_)) == "Optic()" + @test string(@opticof(_[begin])) == "Optic([DynamicIndex(begin)])" + @test string(@opticof(_[2:end])) == "Optic([DynamicIndex(2:end)])" + @test string(with_mutation(@opticof(_.a.b.c))) == "Optic!!(.a.b.c)" + @test string(with_mutation(@opticof(_[1][2][3]))) == "Optic!!([1][2][3])" + @test string(with_mutation(@opticof(_["a"][:b]))) == "Optic!!([\"a\"][:b])" + @test string(with_mutation(@opticof(_))) == "Optic!!()" + @test string(with_mutation(@opticof(_[begin]))) == "Optic!!([DynamicIndex(begin)])" + @test string(with_mutation(@opticof(_[2:end]))) == "Optic!!([DynamicIndex(2:end)])" + end + + @testset "equality" begin + optics = ( + @opticof(_), + @opticof(_[1]), + @opticof(_.a), + @opticof(_[begin]), + @opticof(_[end]), + @opticof(_[:]), + @opticof(_.a[2]), + @opticof(_.a[1, :]), + @opticof(_[1].a), + @opticof(_[1, x=1].a), + ) + for (i, optic1) in enumerate(optics) + for (j, optic2) in enumerate(optics) + if i == j + @test optic1 == optic2 + @test with_mutation(optic1) == with_mutation(optic2) + @test isequal(optic1, optic2) + @test isequal(with_mutation(optic1), with_mutation(optic2)) + else + @test optic1 != optic2 + @test with_mutation(optic1) != with_mutation(optic2) + @test !isequal(optic1, optic2) + @test !isequal(with_mutation(optic1), with_mutation(optic2)) + end + end + end + end + + @testset "composition" begin + @testset "with identity" begin + i = AbstractPPL.Iden() + o = @opticof(_.a.b) + @test i ∘ i == i + @test i ∘ o == o + @test o ∘ i == o + end + + o1 = @opticof(_.a.b) + o2 = @opticof(_[1][2]) + @test o1 ∘ o2 == @opticof(_[1][2].a.b) + @test o2 ∘ o1 == @opticof(_.a.b[1][2]) + @test cat(o1, o2) == @opticof(_.a.b[1][2]) + @test cat(o2, o1) == @opticof(_[1][2].a.b) + @test cat(o1, o2, o2, o1) == @opticof(_.a.b[1][2][1][2].a.b) + end + + @testset "decomposition" begin + @testset "specification" begin + @test ohead(@opticof _.a.b.c) == @opticof _.a + @test otail(@opticof _.a.b.c) == @opticof _.b.c + @test oinit(@opticof _.a.b.c) == @opticof _.a.b + @test olast(@opticof _.a.b.c) == @opticof _.c + + @test ohead(@opticof _[1][2][3]) == @opticof _[1] + @test otail(@opticof _[1][2][3]) == @opticof _[2][3] + @test oinit(@opticof _[1][2][3]) == @opticof _[1][2] + @test olast(@opticof _[1][2][3]) == @opticof _[3] + + @test ohead(@opticof _.a) == @opticof _.a + @test otail(@opticof _.a) == @opticof _ + @test oinit(@opticof _.a) == @opticof _ + @test olast(@opticof _.a) == @opticof _.a + + @test ohead(@opticof _[1]) == @opticof _[1] + @test otail(@opticof _[1]) == @opticof _ + @test oinit(@opticof _[1]) == @opticof _ + @test olast(@opticof _[1]) == @opticof _[1] + + @test ohead(@opticof _) == @opticof _ + @test otail(@opticof _) == @opticof _ + @test oinit(@opticof _) == @opticof _ + @test olast(@opticof _) == @opticof _ + end + + @testset "invariants" begin + optics = ( + @opticof(_), + @opticof(_[1]), + @opticof(_.a), + @opticof(_.a.b), + @opticof(_[1].a), + @opticof(_[1, x=1].a), + @opticof(_[].a[:]), + ) + for optic in optics + @test olast(optic) ∘ oinit(optic) == optic + @test otail(optic) ∘ ohead(optic) == optic + end + end + end + + @testset "getting and setting" begin + @testset "basic" begin + v = (a=(b=42, c=3.14), d=[0.0 1.0; 2.0 3.0]) + @test @opticof(_.a)(v) == v.a + @test set(v, @opticof(_.a), nothing) == (a=nothing, d=v.d) + @test @opticof(_.a.b)(v) == v.a.b + @test set(v, @opticof(_.a.b), 100) == (a=(b=100, c=v.a.c), d=v.d) + @test @opticof(_.a.c)(v) == v.a.c + @test set(v, @opticof(_.a.c), 2.71) == (a=(b=v.a.b, c=2.71), d=v.d) + @test @opticof(_.d)(v) == v.d + @test set(v, @opticof(_.d), zeros(2, 2)) == (a=v.a, d=zeros(2, 2)) + @test @opticof(_.d[1])(v) == v.d[1] + @test set(v, @opticof(_.d[1]), 9.0) == (a=v.a, d=[9.0 1.0; 2.0 3.0]) + @test @opticof(_.d[2])(v) == v.d[2] + @test set(v, @opticof(_.d[2]), 9.0) == (a=v.a, d=[0.0 1.0; 9.0 3.0]) + @test @opticof(_.d[3])(v) == v.d[3] + @test set(v, @opticof(_.d[3]), 9.0) == (a=v.a, d=[0.0 9.0; 2.0 3.0]) + @test @opticof(_.d[4])(v) == v.d[4] + @test set(v, @opticof(_.d[4]), 9.0) == (a=v.a, d=[0.0 1.0; 2.0 9.0]) + @test @opticof(_.d[:])(v) == v.d[:] + @test set(v, @opticof(_.d[:]), fill(9.9, 2, 2)) == (a=v.a, d=fill(9.9, 2, 2)) + end + + @testset "no indices" begin + x = [0.0] + @test @opticof(_[])(x) == x[] + @test set(x, @opticof(_[]), 9.0) == [9.0] + x = 0.0 + @test @opticof(_[])(x) == x[] + @test set(x, @opticof(_[]), 9.0) == 9.0 + end + + @testset "vector of undef" begin + # For this test to be meaningful, the eltype of `x` must be abstract. This sort + # of situation happens in DynamicPPL's demo models -- the use of `Real` as + # eltype is meant to help with ForwardDiff (even though *technically* that is + # not necessary... see + # https://github.com/TuringLang/DynamicPPL.jl/issues/823#issuecomment-3166049286) + x = Vector{Real}(undef, 3) + optic = @opticof(_[2]) + @test_throws UndefRefError optic(x) + x2 = set(x, optic, 3.14) + @test x2[2] == 3.14 + @test !isassigned(x2, 1) + @test isassigned(x2, 2) + @test !isassigned(x2, 3) + end + + @testset "dynamic indices" begin + x = [0.0 1.0; 2.0 3.0] + @test @opticof(_[begin])(x) == x[begin] + @test set(x, @opticof(_[begin]), 9.0) == [9.0 1.0; 2.0 3.0] + @test @opticof(_[end])(x) == x[end] + @test set(x, @opticof(_[end]), 9.0) == [0.0 1.0; 2.0 9.0] + @test @opticof(_[1:end, 2])(x) == x[1:end, 2] + @test set(x, @opticof(_[1:end, 2]), [9.0; 8.0]) == [0.0 9.0; 2.0 8.0] + end + + @testset "unusual indices" begin + x = randn(3, 3) + @test @opticof(_[1:2:4])(x) == x[1:2:4] + @test @opticof(_[CartesianIndex(1, 1)])(x) == x[CartesianIndex(1, 1)] + + # `Not` is actually from InvertedIndices.jl (but re-exported by DimensionalData) + @test @opticof(_[DD.Not(3)])(x) == x[DD.Not(3)] + + # DimArray selectors + dimarray = DD.DimArray(randn(2, 3), (DD.X, DD.Y)) + @test @opticof(_[DD.X(1)])(dimarray) == dimarray[DD.X(1)] + + # Symbols on NamedTuples + nt = (a=10, b=20, c=30) + @test @opticof(_[:a])(nt) == nt[:a] + @test set(nt, @opticof(_[:b]), 99) == (a=10, b=99, c=30) + + # Strings on Dicts + dict = Dict("one" => 1, "two" => 2) + @test @opticof(_["two"])(dict) == dict["two"] + @test set(dict, @opticof(_["two"]), 22) == Dict("one" => 1, "two" => 22) + end + + @testset "keyword arguments to getindex" begin + dimarray = DD.DimArray([0.0 1.0; 2.0 3.0], (:x, :y)) + @test @opticof(_[x=1])(dimarray) == dimarray[x=1] + @test set(dimarray, @opticof(_[y=2]), [9.0; 8.0]) == + DD.DimArray([0.0 9.0; 2.0 8.0], (:x, :y)) + end + + @testset "properties on struct" begin + struct SampleStruct + a::Int + b::Float64 + end + s = SampleStruct(3, 1.5) + @test @opticof(_.a)(s) == 3 + @test @opticof(_.b)(s) == 1.5 + @test set(s, @opticof(_.a), 10) == SampleStruct(10, s.b) + @test set(s, @opticof(_.b), 2.5) == SampleStruct(s.a, 2.5) + end + + @testset "nested optics" begin + x = (; a=[(; b=1)]) + optic = @opticof(_.a[1].b) + @test optic(x) == 1 + x2 = set(x, optic, 42) + @test x2 == (; a=[(; b=42)]) + + y = [(; a=[1.0])] + optic2 = @opticof(_[1].a[1]) + @test optic2(y) == 1.0 + y2 = set(y, optic2, 3.14) + @test y2 == [(; a=[3.14])] + end + end + + @testset "mutating versions" begin + @testset "construction and equality" begin + @test with_mutation(@opticof(_.a.b)) != @opticof(_.a.b) + # check that there is no nesting + @test with_mutation(with_mutation(@opticof(_.a.b))) == + with_mutation(@opticof(_.a.b)) + end + + @testset "arrays" begin + @testset "static index" begin + x = zeros(4) + old_objid = objectid(x) + optic = with_mutation(@opticof(_[2])) + @test optic(x) === x[2] + set(x, optic, 1.0) + @test x[2] == 1.0 + @test x == [0.0, 1.0, 0.0, 0.0] + @test objectid(x) == old_objid + end + + @testset "vector of undef" begin + # eltype(x) must be abstract for this test to be meaningful. See above for + # discussion. + x = Vector{Real}(undef, 3) + old_objid = objectid(x) + optic = with_mutation(@opticof(_[2])) + @test_throws UndefRefError optic(x) + set(x, optic, 3.14) + @test x[2] == 3.14 + @test !isassigned(x, 1) + @test isassigned(x, 2) + @test !isassigned(x, 3) + @test objectid(x) == old_objid + end + + @testset "dynamic index" begin + x = zeros(2, 2) + old_objid = objectid(x) + optic = with_mutation(@opticof(_[begin, end])) + @test optic(x) === x[begin, end] + set(x, optic, 2.0) + @test x[begin, end] == 2.0 + @test x == [0.0 2.0; 0.0 0.0] + @test objectid(x) == old_objid + end + + @testset "DimArray, setting a single element" begin + dimarray = DD.DimArray(zeros(2, 3), (DD.X, DD.Y)) + old_objid = objectid(dimarray) + optic = with_mutation(@opticof(_[DD.X(1), DD.Y(1)])) + @test optic(dimarray) == dimarray[DD.X(1), DD.Y(1)] + dimarray = set(dimarray, optic, 1.0) + @test dimarray[DD.X(1), DD.Y(1)] == 1.0 + @test collect(dimarray) == [1.0 0.0 0.0; 0.0 0.0 0.0] + @test objectid(dimarray) == old_objid + end + + @testset "keyword index" begin + x = DD.DimArray(zeros(2, 2), (:x, :y)) + old_objid = objectid(x) + optic = with_mutation(@opticof(_[x=1, y=2])) + @test optic(x) === x[x=1, y=2] + set(x, optic, 2.0) + @test x[x=1, y=2] == 2.0 + @test collect(x) == [0.0 2.0; 0.0 0.0] + @test objectid(x) == old_objid + end + + @testset "nested optics" begin + x = (; a=[(; b=1)]) + old_objid = objectid(x) + old_inner_objid = objectid(x.a) + optic = with_mutation(@opticof(_.a[1].b)) + @test optic(x) == 1 + set(x, optic, 42) + @test x == (; a=[(; b=42)]) + # Check that mutation happened at the very bottom level. + @test objectid(x) == old_objid + @test objectid(x.a) == old_inner_objid + + y = [(; a=[1.0])] + old_objid = objectid(y) + old_inner_objid = objectid(y[1]) + old_inner_inner_objid = objectid(y[1].a) + optic2 = with_mutation(@opticof(_[1].a[1])) + @test optic2(y) == 1.0 + set(y, optic2, 3.14) + @test y == [(; a=[3.14])] + @test objectid(y) == old_objid + @test objectid(y[1]) == old_inner_objid + @test objectid(y[1].a) == old_inner_inner_objid + end + end + + @testset "dicts" begin + x = Dict("a" => 1, "b" => 2) + old_objid = objectid(x) + optic = with_mutation(@opticof(_["b"])) + @test optic(x) === x["b"] + set(x, optic, 99) + @test x["b"] == 99 + @test x == Dict("a" => 1, "b" => 99) + @test objectid(x) == old_objid + end + + @testset "mutable structs" begin + mutable struct MutableStruct + a::Int + b::Float64 + end + x = MutableStruct(3, 1.5) + old_objid = objectid(x) + optic = with_mutation(@opticof(_.a)) + @test optic(x) === x.a + set(x, optic, 10) + @test x.a == 10 + @test x.b == 1.5 + @test objectid(x) == old_objid + end + + @testset "fallback for immutable data" begin + @testset "NamedTuple" begin + s = (a=1, b=2) + old_objid = objectid(s) + optic = with_mutation(@opticof(_.a)) + @test optic(s) === s.a + s2 = set(s, optic, 10) + @test s2 == (a=10, b=2) + @test s == (a=1, b=2) + @test objectid(s) == old_objid + end + + # NOTE(penelopeysm): This SHOULD really mutate. It is not an error with + # AbstractPPL, though, it is an interface problem between BangBang and + # DimensionalData (essentially BangBang can't detect that DimArray is mutable). + # + # To be precise, the test fails because BangBang thinks that `DD.Y(1)` is an + # index that extracts a single element from the DimArray. For example, this + # would be the case if it was dimarray[1]. So, BangBang thinks that you can't + # set an array there, and so it falls back to the immutable behavior. + # Specifically, it's this line that returns false: + # https://github.com/JuliaFolds2/BangBang.jl/blob/e92b4c1673a686533b5f9724a198b63d8974d52f/src/base.jl#L528 + @testset "DimArray, setting a vector" begin + dimarray = DD.DimArray(zeros(2, 3), (DD.X, DD.Y)) + old_objid = objectid(dimarray) + optic = with_mutation(@opticof(_[DD.Y(1)])) + @test optic(dimarray) == dimarray[DD.Y(1)] + dimarray = set(dimarray, optic, [1.0, 2.0]) + @test collect(dimarray[DD.Y(1)]) == [1.0; 2.0] + @test collect(dimarray) == [1.0 0.0 0.0; 2.0 0.0 0.0] + # @test objectid(dimarray) == old_objid + end + + @testset "tuple" begin + s = (3, 1.5) + old_objid = objectid(s) + optic = with_mutation(@opticof(_[1])) + @test optic(s) === s[1] + s2 = set(s, optic, 10) + @test s2 == (10, 1.5) + @test s == (3, 1.5) + @test objectid(s) == old_objid + end + + @testset "struct" begin + struct SampleStructAgain + a::Int + b::Float64 + end + s = SampleStructAgain(3, 1.5) + old_objid = objectid(s) + optic = with_mutation(@opticof(_.a)) + @test optic(s) === s.a + s2 = set(s, optic, 10) + @test s2 == SampleStructAgain(10, 1.5) + @test s == SampleStructAgain(3, 1.5) + @test objectid(s) == old_objid + end + end + end +end + +end # module diff --git a/test/varname/prefix.jl b/test/varname/prefix.jl new file mode 100644 index 00000000..2276d5bc --- /dev/null +++ b/test/varname/prefix.jl @@ -0,0 +1,45 @@ +module VarNamePrefixTests + +using Test +using AbstractPPL + +@testset "varname/prefix.jl" verbose = true begin + @testset "basic cases" begin + @test prefix(@varname(y), @varname(x)) == @varname(x.y) + @test prefix(@varname(y), @varname(x[1])) == @varname(x[1].y) + @test prefix(@varname(y), @varname(x.a)) == @varname(x.a.y) + @test prefix(@varname(y[1]), @varname(x)) == @varname(x.y[1]) + @test prefix(@varname(y.a), @varname(x)) == @varname(x.y.a) + + @test unprefix(@varname(x.y[1]), @varname(x)) == @varname(y[1]) + @test unprefix(@varname(x[1].y), @varname(x[1])) == @varname(y) + @test unprefix(@varname(x.a.y), @varname(x.a)) == @varname(y) + @test unprefix(@varname(x.y.a), @varname(x)) == @varname(y.a) + @test_throws ArgumentError unprefix(@varname(x.y.a), @varname(n)) + @test_throws ArgumentError unprefix(@varname(x.y.a), @varname(x[1])) + end + + @testset "round-trip + type stability" begin + # This tuple is probably overkill, but the tests are super fast + # anyway. + vns = ( + @varname(p), + @varname(q), + @varname(r[1]), + @varname(s.a), + @varname(t[1].a), + @varname(u[1].a.b), + @varname(v.a[1][2].b.c.d[3]) + ) + for vn1 in vns + for vn2 in vns + prefixed = @inferred prefix(vn1, vn2) + @test subsumes(vn2, prefixed) + unprefixed = @inferred unprefix(prefixed, vn2) + @test unprefixed == vn1 + end + end + end +end + +end # module diff --git a/test/varname/serialize.jl b/test/varname/serialize.jl new file mode 100644 index 00000000..d84d0149 --- /dev/null +++ b/test/varname/serialize.jl @@ -0,0 +1,82 @@ +module VarNameSerialisationTests + +using AbstractPPL +using InvertedIndices: Not, InvertedIndex +using Test + +@testset "varname/serialize.jl" verbose = true begin + @testset "roundtrip" begin + y = ones(10) + z = ones(5, 2) + vns = [ + @varname(x), + @varname(ä), + @varname(x.a), + @varname(x.a.b), + @varname(var"x.a"), + @varname(x[1]), + @varname(var"x[1]"), + @varname(x[1:10]), + @varname(x[1:3:10]), + @varname(x[1, 2]), + @varname(x[1, 2:5]), + @varname(x[:]), + @varname(x.a[1]), + @varname(x.a[1:10]), + @varname(x[1].a), + @varname(y[:]), + @varname(y[begin:end], true), + @varname(y[end], true), + @varname(y[:], false), + @varname(y[:], true), + @varname(z[:], false), + @varname(z[:], true), + @varname(z[:][:], false), + @varname(z[:][:], true), + @varname(z[:, :], false), + @varname(z[:, :], true), + @varname(z[2:5, :], false), + @varname(z[2:5, :], true), + @varname(x[i=1]), + @varname(x[].a[j=2].b[3, 4, 5, [6]]), + ] + for vn in vns + @test string_to_varname(varname_to_string(vn)) == vn + end + + # For this VarName, the {de,}serialisation works correctly but we must + # test in a different way because equality comparison of structs with + # vector fields (such as Accessors.IndexLens) compares the memory + # addresses rather than the contents (thus vn_vec == vn_vec2 returns + # false). + vn_vec = @varname(x[[1, 2, 5, 6]]) + vn_vec2 = string_to_varname(varname_to_string(vn_vec)) + @test hash(vn_vec) == hash(vn_vec2) + end + + @testset "deserialisation fails for unconcretised dynamic indices" begin + for vn in (@varname(x[1:end]), @varname(x[begin:end]), @varname(x[2:step:end])) + @test_throws ArgumentError varname_to_string(vn) + end + end + + @testset "custom index types" begin + vn = @varname(x[Not(3)]) + + # This won't work as we don't yet know how to handle OffsetArray + @test_throws MethodError varname_to_string(vn) + + # Now define the relevant methods + AbstractPPL.index_to_dict(o::InvertedIndex{I}) where {I} = Dict( + "type" => "InvertedIndices.InvertedIndex", + "skip" => AbstractPPL.index_to_dict(o.skip), + ) + AbstractPPL.dict_to_index(::Val{Symbol("InvertedIndices.InvertedIndex")}, d) = + InvertedIndex(AbstractPPL.dict_to_index(d["skip"])) + + # Serialisation should now work + @test string_to_varname(varname_to_string(vn)) == vn + end +end + +end # module diff --git a/test/varname/subsumes.jl b/test/varname/subsumes.jl new file mode 100644 index 00000000..d4187fea --- /dev/null +++ b/test/varname/subsumes.jl @@ -0,0 +1,78 @@ +module VarNameSubsumesTests + +using AbstractPPL +using Test + +@testset "varname/subsumes.jl" verbose = true begin + @testset "varnames that are equal" begin + @test subsumes(@varname(x), @varname(x)) + @test subsumes(@varname(x[1]), @varname(x[1])) + @test subsumes(@varname(x.a), @varname(x.a)) + @test subsumes(@varname(x[:]), @varname(x[:])) + end + + uncomparable(vn1, vn2) = !subsumes(vn1, vn2) && !subsumes(vn2, vn1) + @testset "uncomparable varnames" begin + @test uncomparable(@varname(x), @varname(y)) + @test uncomparable(@varname(x.a), @varname(y.a)) + @test uncomparable(@varname(a.x), @varname(a.y)) + @test uncomparable(@varname(a.x[1]), @varname(a.x.z)) + @test uncomparable(@varname(x[1]), @varname(y[1])) + @test uncomparable(@varname(x[1]), @varname(x.y)) + end + + strictly_subsumes(vn1, vn2) = subsumes(vn1, vn2) && !subsumes(vn2, vn1) + @testset "strict subsumption - no index comparisons" begin + @test strictly_subsumes(@varname(x), @varname(x.a)) + @test strictly_subsumes(@varname(x), @varname(x[1])) + @test strictly_subsumes(@varname(x), @varname(x[2:2:5])) + @test strictly_subsumes(@varname(x), @varname(x[10, 20])) + @test strictly_subsumes(@varname(x.a), @varname(x.a.b)) + @test strictly_subsumes(@varname(x[1]), @varname(x[1].a)) + @test strictly_subsumes(@varname(x.a), @varname(x.a[1])) + @test strictly_subsumes(@varname(x[1:10]), @varname(x[1:10][2])) + end + + @testset "strict subsumption - index comparisons" begin + @testset "integer vectors" begin + @test strictly_subsumes(@varname(x[1:10]), @varname(x[1])) + @test strictly_subsumes(@varname(x[1:10]), @varname(x[1:5])) + @test strictly_subsumes(@varname(x[1:10]), @varname(x[4:6])) + @test strictly_subsumes(@varname(x[1:10, 1:10]), @varname(x[1:5, 1:5])) + @test strictly_subsumes(@varname(x[[5, 4, 3, 2, 1]]), @varname(x[[2, 4]])) + end + + @testset "non-integer indices" begin + @test strictly_subsumes(@varname(x[:a]), @varname(x[:a][1])) + end + + @testset "colon" begin + @test strictly_subsumes(@varname(x[:]), @varname(x[1])) + @test strictly_subsumes(@varname(x[:, 1:10]), @varname(x[1:10, 1])) + end + + @testset "dynamic indices" begin + @test strictly_subsumes(@varname(x), @varname(x[begin])) + @test subsumes(@varname(x[begin]), @varname(x[begin])) + @test strictly_subsumes(@varname(x[:]), @varname(x[begin])) + @test strictly_subsumes(@varname(x), @varname(x[end])) + @test subsumes(@varname(x[end]), @varname(x[end])) + @test strictly_subsumes(@varname(x[:]), @varname(x[end])) + @test strictly_subsumes(@varname(x[:]), @varname(x[1:end])) + @test strictly_subsumes(@varname(x[:]), @varname(x[end - 3])) + @test uncomparable(@varname(x[begin]), @varname(x["a"])) + @test uncomparable(@varname(x[begin]), @varname(x[1:5])) + end + + @testset "keyword indices" begin + @test strictly_subsumes(@varname(x), @varname(x[a=1])) + @test strictly_subsumes(@varname(x[a=1:10, b=1:10]), @varname(x[a=1:10])) + @test strictly_subsumes(@varname(x[a=1:10, b=1:10]), @varname(x[a=1:5, b=1:5])) + @test strictly_subsumes(@varname(x[a=:]), @varname(x[a=1])) + @test uncomparable(@varname(x[a=1:10, b=5]), @varname(x[a=5, b=1:10])) + @test uncomparable(@varname(x[a=1]), @varname(x[b=1])) + end + end +end + +end # module diff --git a/test/varname/varname.jl b/test/varname/varname.jl new file mode 100644 index 00000000..8ddb7802 --- /dev/null +++ b/test/varname/varname.jl @@ -0,0 +1,225 @@ +module VarNameTests + +using AbstractPPL +using Test +using JET: @test_call + +@testset "varname/varname.jl" verbose = true begin + @testset "basic construction (and type stability)" begin + @test @varname(x) == (@inferred VarName{:x}(Iden())) + @test @varname(x[1]) == (@inferred VarName{:x}(Index((1,), (;), Iden()))) + @test @varname(x.a) == (@inferred VarName{:x}(Property{:a}(Iden()))) + @test @varname(x.a[1]) == + (@inferred VarName{:x}(Property{:a}(Index((1,), (;), Iden())))) + end + + @testset "errors on invalid inputs" begin + # Note: have to wrap in eval to avoid throwing an error before the actual test + errmsg = "malformed variable name" + @test_throws errmsg eval(:(@varname(1))) + @test_throws errmsg eval(:(@varname(x + y))) + @test_throws MethodError eval(:(@varname(x[1:Colon()]))) + end + + @testset "equality and hash" begin + function check_doubleeq_and_hash(vn1, vn2, is_equal) + if is_equal + @test vn1 == vn2 + @test hash(vn1) == hash(vn2) + else + @test vn1 != vn2 + @test hash(vn1) != hash(vn2) + end + end + check_doubleeq_and_hash(@varname(x), @varname(x), true) + check_doubleeq_and_hash(@varname(x), @varname(y), false) + check_doubleeq_and_hash(@varname(x[1]), @varname(x[1]), true) + check_doubleeq_and_hash(@varname(x[1]), @varname(x[2]), false) + check_doubleeq_and_hash(@varname(x.a), @varname(x.a), true) + check_doubleeq_and_hash(@varname(x.a), @varname(x.b), false) + check_doubleeq_and_hash(@varname(x.a[1]), @varname(x.a[1]), true) + check_doubleeq_and_hash(@varname(x.a[1]), @varname(x.a[2]), false) + check_doubleeq_and_hash(@varname(x.a[1]), @varname(x.b[1]), false) + check_doubleeq_and_hash(@varname(x[1, i=2]), @varname(x[1, i=2]), true) + check_doubleeq_and_hash(@varname(x[i=2, 4]), @varname(x[4, i=2]), true) + + @testset "dynamic indices" begin + check_doubleeq_and_hash(@varname(x[begin]), @varname(x[begin]), true) + check_doubleeq_and_hash(@varname(x[end]), @varname(x[end]), true) + check_doubleeq_and_hash(@varname(x[begin]), @varname(x[end]), false) + check_doubleeq_and_hash(@varname(x[begin + 1]), @varname(x[begin + 1]), true) + check_doubleeq_and_hash(@varname(x[begin + 1]), @varname(x[(begin + 1)]), true) + check_doubleeq_and_hash(@varname(x[begin + 1]), @varname(x[begin + 2]), false) + check_doubleeq_and_hash(@varname(x[end - 1]), @varname(x[end - 1]), true) + check_doubleeq_and_hash(@varname(x[end - 1]), @varname(x[end - 2]), false) + check_doubleeq_and_hash( + @varname(x[(begin * end - begin):end]), + @varname(x[((begin * end) - begin):end]), + true, + ) + end + end + + @testset "JET on equality + dynamic dispatch" begin + # This test is very specific, so some context is needed: + # + # In DynamicPPL it's quite common to want to search for a VarName in a collection of + # VarNames. Usually the collection will not have a concrete element type (because + # it's a mixture of different optics). Thus, there will be a fair amount of dynamic + # dispatch when performing the comparisons. + # + # In AbstractPPL, there is custom code in `src/varname/optic.jl` to make sure that + # equality comparisons of `Index` lenses are JET-friendly even when this happens + # (i.e., JET.jl doesn't error on `@report_call`). These were needed because base + # Julia's equality methods on tuples error with JET: + # https://github.com/JuliaLang/julia/issues/60470, and using those default methods + # would cause test failures in DynamicPPLJETExt. + # + # This test therefore makes sure we don't cause any regressions. + vns = [@varname(x), @varname(x[1]), @varname(x.a)] + for vn in vns + @test_call any(k -> k == vn, vns) + end + end + + @testset "pretty-printing" begin + @test string(@varname(x)) == "x" + @test string(@varname(x[1])) == "x[1]" + @test string(@varname(x.a)) == "x.a" + @test string(@varname(x.a[1])) == "x.a[1]" + @test string(@varname(x[begin])) == "x[DynamicIndex(begin)]" + @test string(@varname(x[end])) == "x[DynamicIndex(end)]" + @test string(@varname(x[:])) == "x[:]" + @test string(@varname(x[1, i=3])) == "x[1, i=3]" + end + + @testset "dynamic indices and manual concretization" begin + @testset "begin" begin + vn = @varname(x[begin]) + @test vn isa VarName + @test is_dynamic(vn) + @test concretize(vn, [1.0]) == @varname(x[1]) + end + + @testset "end" begin + vn = @varname(x[end]) + @test vn isa VarName + @test is_dynamic(vn) + @test concretize(vn, randn(5)) == @varname(x[5]) + end + + @testset "expressions thereof" begin + vn = @varname(x[(begin + 2):(end - 1)]) + @test vn isa VarName + @test is_dynamic(vn) + arr = randn(6) + @test concretize(vn, arr) == @varname(x[3:5]) + end + + @testset "different dimensions" begin + vn = @varname(x[end, begin:(end - 1)]) + @test vn isa VarName + @test is_dynamic(vn) + arr = randn(4, 4) + @test concretize(vn, arr) == @varname(x[4, 1:3]) + end + + @testset "linear indexing for matrices" begin + vn = @varname(x[begin:end]) + @test vn isa VarName + @test is_dynamic(vn) + arr = randn(4, 4) + @test concretize(vn, arr) == @varname(x[1:16]) + end + + @testset "nested" begin + vn = @varname(x.a[end].b) + @test vn isa VarName + @test is_dynamic(vn) + @test concretize(vn, (; a=[(; b=1)])) == @varname(x.a[1].b) + end + end + + @testset "things that shouldn't be dynamic aren't dynamic" begin + @test !is_dynamic(@varname(x)) + @test !is_dynamic(@varname(x[3])) + @test !is_dynamic(@varname(x[:])) + @test !is_dynamic(@varname(x[1:3])) + @test !is_dynamic(@varname(x[1:3, 3, 2 + 9])) + i = 10 + @test !is_dynamic(@varname(x[1:3, 3, 2 + 9, 1:3:i])) + @test !is_dynamic(@varname(x[k=i])) + end + + @testset "automatic concretization" begin + test_array = randn(5, 5) + @testset "begin" begin + vn = @varname(test_array[begin], true) + @test vn == concretize(@varname(test_array[begin]), test_array) + end + @testset "end" begin + vn = @varname(test_array[end], true) + @test vn == concretize(@varname(test_array[end]), test_array) + end + @testset "expressions thereof" begin + vn = @varname(test_array[(begin + 1):(end - 2)], true) + @test vn == concretize(@varname(test_array[(begin + 1):(end - 2)]), test_array) + end + @testset "different dimensions" begin + vn = @varname(test_array[end, begin:(end - 1)], true) + @test vn == concretize(@varname(test_array[end, begin:(end - 1)]), test_array) + end + @testset "linear indexing for matrices" begin + vn = @varname(test_array[begin:end], true) + @test vn == concretize(@varname(test_array[begin:end]), test_array) + end + end + + @testset "interpolation" begin + @testset "of property names" begin + prop = :myprop + vn = @varname(x.$prop) + @test vn == @varname(x.myprop) + end + + @testset "of indices" begin + idx = 3 + @test @varname(x[idx]) == @varname(x[3]) + @test @varname(x[2 * idx]) == @varname(x[6]) + @test @varname(x[1:idx]) == @varname(x[1:3]) + @test @varname(x[k=idx]) == @varname(x[k=3]) + @test @varname(x[k=2 * idx]) == @varname(x[k=6]) + end + + @testset "with dynamic indices" begin + idx = 3 + vn = @varname(x[end - idx]) + @test vn isa VarName + @test is_dynamic(vn) + arr = randn(6) + @test concretize(vn, arr) == @varname(x[3]) + # Note that `idx` is only resolved at concretization time (because it's stored + # in a function that looks like (val -> lastindex(val) - idx) -- the VALUE of + # `idx` is not interpolated at macro time because we have no way of obtaining + # values inside the macro). So we could change it and re-concretize... + idx = 4 + @test concretize(vn, arr) == @varname(x[2]) + end + + @testset "of top-level name" begin + name = :x + @test @varname($name) == @varname(x) + @test @varname($name[1]) == @varname(x[1]) + @test @varname($name.a) == @varname(x.a) + end + + @testset "mashup of everything" begin + name = :x + index = 2 + prop = :b + @test @varname($name.$prop[3 * index]) == @varname(x.b[6]) + end + end +end + +end # module VarNameTests