Skip to content

Commit 8fcc5a0

Browse files
committed
clean outer_product and cpproduct
1 parent ff1cf50 commit 8fcc5a0

File tree

6 files changed

+58
-5
lines changed

6 files changed

+58
-5
lines changed

src/BlockTensorFactorization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ export SuperDiagonal, abs_randn, all_recursive, eachfibre, getnotindex, geomean,
1212
export d_dx, d2_dx2, curvature, standard_curvature
1313

1414
#include("./tensorproducts.jl")
15-
export nmp, nmode_product, mtt, slicewise_dot, tuckerproduct, cpproduct
15+
export nmp, nmode_product, mtt, slicewise_dot, tuckerproduct, cpproduct, outer_product
1616
export ×₁, ×₂, ×₃, ×₄, ×₅, ×₆, ×₇, ×₈, ×
1717
export ₁, ₂, ₃, ₄, ₅, ₆, ₇, ₈,
1818

src/Core/Core.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ include("./curvaturetools.jl")
2222
export d_dx, d2_dx2, curvature, standard_curvature
2323

2424
include("./tensorproducts.jl")
25-
export nmp, nmode_product, mtt, slicewise_dot, tuckerproduct, cpproduct
25+
export nmp, nmode_product, mtt, slicewise_dot, tuckerproduct, cpproduct, outer_product
2626
export ×₁, ×₂, ×₃, ×₄, ×₅, ×₆, ×₇, ×₈, ×
2727
export ₁, ₂, ₃, ₄, ₅, ₆, ₇, ₈,
2828

src/Core/decomposition.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,6 @@ end
441441
factors(CPD::CPDecomposition) = CPD.factors
442442
array(CPD::CPDecomposition) = cpproduct(factors(CPD))
443443
frozen(CPD::CPDecomposition) = CPD.frozen
444-
vector_outer(v) = reshape(kron(reverse(v)...),length.(v))
445444
eachfactorindex(CPD::CPDecomposition) = 1:nfactors(CPD) # unlike other AbstractTucker's, back to 1 based since there's only matrix factors
446445
isfrozen(CPD::CPDecomposition, n::Integer) = n == 0 ? true : frozen(CPD)[n] # similar to eachfactorindex
447446

src/Core/factorize.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ function default_kwargs(Y; kwargs...)
153153

154154
get!(kwargs, :rank) do # Can also be a tuple. For example, Tucker rank could be (1, 2, 3) for an order 3 array Y
155155
isnothing(kwargs[:decomposition]) ? error("`rank_detect_factorize` should be called if no initial decomposition nor the rank is provided.") : rankof(kwargs[:decomposition])
156-
end
156+
end # TODO add a check that the keyword rank is compatible with the model and decomposition
157157
get!(kwargs, :init) do
158158
isnonnegative(Y) ? abs_randn : randn
159159
end

src/Core/tensorproducts.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,25 @@ function _gettuckerindex(core, matrices, I)
195195
return Y
196196
end
197197

198+
"""
199+
outer_product(vectors)
200+
outer_product(vectors...)
201+
202+
Outer product of a collection of vectors.
203+
204+
For example,
205+
206+
`outer_product(u, v) == u * v'`
207+
208+
and
209+
210+
`outer_product(u, v, w)[i, j, k] == u[i] * v[j] * w[k]`.
211+
212+
Returned array will have same number dimensions as the length of the collection.
213+
"""
214+
outer_product(vectors) = reshape(kron(reverse(vectors)...),length.(vectors))
215+
outer_product(vectors...) = outer_product(vectors)
216+
198217
"""
199218
cpproduct((A, B, C, ...))
200219
cpproduct(A, B, C, ...)
@@ -205,7 +224,7 @@ Example
205224
-------
206225
cpproduct(A, B, C) == @einsum T[i, j, k] := A[i, r] * B[j, r] * C[k, r]
207226
"""
208-
cpproduct(matrices) = mapreduce(vector_outer, +, zip((eachcol.(matrices))...))
227+
cpproduct(matrices) = mapreduce(outer_product, +, zip((eachcol.(matrices))...))
209228
cpproduct(matrices...) = cpproduct(matrices)
210229

211230
"""

test/runtests.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,18 @@ const VERBOSE = true
9797
B = randn(10, 20)
9898
@test slicewise_dot(A, A) A*A' # test this separately since it uses a different routine when the argument is the same
9999
@test slicewise_dot(A, B) A*B'
100+
101+
# CP product
102+
R = 2
103+
A = randn(4, R)
104+
B = randn(5, R)
105+
C = randn(6, R)
106+
Y = CPDecomposition((A, B, C))
107+
Y_array = array(Y)
108+
109+
@test cpproduct(A, B, C) Y_array
110+
@test Y Y_array
111+
@test Y[1, 2, 3] sum(A[1, r] * B[2, r] * C[3, r] for r in 1:R)
100112
end
101113

102114
@testset "Constraints" begin
@@ -628,6 +640,29 @@ end
628640
@testset "CPFactorization" begin
629641
fact = BlockTensorFactorization.factorize
630642

643+
# Regular run of CPDecomposition
644+
A = randn(5, 2)
645+
B = randn(5, 2)
646+
C = randn(5, 2)
647+
Y = CPDecomposition((A, B, C))
648+
Y = array(Y)
649+
650+
decomposition, stats, kwargs = fact(Y;
651+
rank=2,
652+
model=CPDecomposition,
653+
momentum=false,
654+
#tolerance=(1, 0.045),
655+
tolerance=0.045,
656+
converged=RelativeError,
657+
#converged=(GradientNNCone, RelativeError),
658+
constrain_init=false,
659+
maxiter=1,
660+
#constraints=nonnegative!,
661+
stats=[Iteration, ObjectiveValue, GradientNNCone, RelativeError, PrintStats]
662+
);
663+
664+
665+
631666
# Semi-interesting run of CPDecomposition
632667
N = 100
633668
R = 5

0 commit comments

Comments
 (0)