Skip to content

Commit ca35171

Browse files
committed
big fix Tucker family gradients
1 parent c513f94 commit ca35171

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

src/Core/blockupdates.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -342,8 +342,7 @@ function make_gradient(T::Tucker, n::Integer, Y::AbstractArray; objective::L2, k
342342
matrices = matrix_factors(X)
343343
gram_matrices = map(A -> A'A, matrices) # gram matrices AA = A'A,
344344
# BB = B'B, ...
345-
grad = tuckerproduct(B, gram_matrices)
346-
- tuckerproduct(Y, adjoint.(matrices))
345+
grad = tuckerproduct(B, gram_matrices) - tuckerproduct(Y, adjoint.(matrices))
347346
return grad
348347
end
349348
return gradient_core
@@ -354,8 +353,7 @@ function make_gradient(T::Tucker, n::Integer, Y::AbstractArray; objective::L2, k
354353
matrices = matrix_factors(X)
355354
Aₙ = factor(X, n)
356355
X̃ₙ = tuckerproduct(B, matrices; exclude=n)
357-
grad = Aₙ * slicewise_dot(X̃ₙ, X̃ₙ; dims=n)
358-
- slicewise_dot(Y, X̃ₙ; dims=n)
356+
grad = Aₙ * slicewise_dot(X̃ₙ, X̃ₙ; dims=n) - slicewise_dot(Y, X̃ₙ; dims=n)
359357
return grad
360358
end
361359
return gradient_matrix
@@ -373,8 +371,7 @@ function make_gradient(T::CPDecomposition, n::Integer, Y::AbstractArray; objecti
373371
matrices = matrix_factors(X)
374372
Aₙ = factor(X, n)
375373
X̃ₙ = tuckerproduct(B, matrices; exclude=n)
376-
grad = Aₙ * slicewise_dot(X̃ₙ, X̃ₙ; dims=n)
377-
- slicewise_dot(Y, X̃ₙ; dims=n)
374+
grad = Aₙ * slicewise_dot(X̃ₙ, X̃ₙ; dims=n) - slicewise_dot(Y, X̃ₙ; dims=n)
378375
return grad
379376
end
380377
return gradient_matrix

0 commit comments

Comments
 (0)