Skip to content

Commit 0eb7992

Browse files
committed
add IdentityUpdate for NoConstraint
1 parent 8fcc5a0 commit 0eb7992

File tree

4 files changed

+27
-7
lines changed

4 files changed

+27
-7
lines changed

src/BlockTensorFactorization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ export LipschitzStep, ConstantStep, SPGStep
7272
export AbstractUpdate
7373
export GradientDescent, MomentumUpdate
7474

75-
export ConstraintUpdate, GenericConstraintUpdate
75+
export ConstraintUpdate, GenericConstraintUpdate, IdentityUpdate
7676
export Projection, NNProjection, SafeNNProjection, Rescale
7777

7878
export BlockedUpdate

src/Core/Core.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ export LipschitzStep, ConstantStep, SPGStep
8282
export AbstractUpdate, AbstractGradientDescent
8383
export GradientDescent, BlockGradientDescent, MomentumUpdate
8484

85-
export ConstraintUpdate, GenericConstraintUpdate
85+
export ConstraintUpdate, GenericConstraintUpdate, IdentityUpdate
8686
export Projection, NNProjection, SafeNNProjection, Rescale
8787

8888
export BlockedUpdate

src/Core/blockupdates.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,7 @@ end
464464
Converts an AbstractConstraint to a ConstraintUpdate on the factor n
465465
"""
466466
ConstraintUpdate(n, constraint::AbstractConstraint; kwargs...) = error("converting $(typeof(constraint)) to a ConstraintUpdate is not yet supported")
467+
ConstraintUpdate(n, constraint::NoConstraint; kwargs...) = IdentityUpdate(n)
467468
ConstraintUpdate(n, constraint::GenericConstraint; kwargs...) = GenericConstraintUpdate(n, constraint)
468469
ConstraintUpdate(n, constraint::ProjectedNormalization; kwargs...) = Projection(n, constraint)
469470
ConstraintUpdate(n, constraint::Entrywise; kwargs...) = Projection(n, constraint)
@@ -491,6 +492,15 @@ end
491492

492493
check(_::ConstraintUpdate, _::AbstractDecomposition) = error("checking $(typeof(constraint)) is not yet supported")
493494

495+
struct IdentityUpdate <: ConstraintUpdate
496+
n::Integer
497+
end
498+
499+
check(U::IdentityUpdate, D::AbstractDecomposition) = true
500+
getconstraint(_::IdentityUpdate) = noconstraint
501+
502+
(_::IdentityUpdate)(x::T; kwargs...) where T = x
503+
494504
struct GenericConstraintUpdate <: ConstraintUpdate
495505
n::Integer
496506
constraint::GenericConstraint

test/runtests.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ const VERBOSE = true
109109
@test cpproduct(A, B, C) Y_array
110110
@test Y Y_array
111111
@test Y[1, 2, 3] sum(A[1, r] * B[2, r] * C[3, r] for r in 1:R)
112+
@test Y_array sum(outer_product((A[:, r], B[:, r], C[:, r])) for r in 1:R)
112113
end
113114

114115
@testset "Constraints" begin
@@ -415,15 +416,26 @@ end
415416
end
416417

417418
@testset "BlockUpdates" begin
419+
# Update from a NoConstraint
420+
G1 = Tucker1((3, 3, 3), 2)
421+
G2 = deepcopy(G1)
422+
U = ConstraintUpdate(1, noconstraint)
423+
424+
U(G1)
425+
@test G1 == G2 # should actually do nothing so using == rather than ≈
426+
427+
# Update from ProjectedNormalization
418428
G1 = CPDecomposition((3,3,3), 2)
419429
G2 = deepcopy(G1)
430+
420431
U = ConstraintUpdate(2, l2normalize_cols!)
421432

422433
U(G1)
423434
l2normalize_cols!(factor(G2, 2))
424435

425436
@test G1 G2
426437

438+
# Update from a ComposedConstraint
427439
U = ConstraintUpdate(1, l2scale! nonnegative!)
428440
@test U.n == 1
429441

@@ -656,13 +668,11 @@ end
656668
converged=RelativeError,
657669
#converged=(GradientNNCone, RelativeError),
658670
constrain_init=false,
659-
maxiter=1,
660-
#constraints=nonnegative!,
661-
stats=[Iteration, ObjectiveValue, GradientNNCone, RelativeError, PrintStats]
671+
maxiter=2,
672+
constraints=[l2normalize_cols!, l2normalize_cols!, noconstraint],#nonnegative!,
673+
stats=[Iteration, ObjectiveValue, GradientNorm, RelativeError, PrintStats,DisplayDecomposition]
662674
);
663675

664-
665-
666676
# Semi-interesting run of CPDecomposition
667677
N = 100
668678
R = 5

0 commit comments

Comments
 (0)