Skip to content

Commit ff22b79

Browse files
committed
rank estimation clean up, add circles method
1 parent c08ac90 commit ff22b79

File tree

6 files changed

+111
-31
lines changed

6 files changed

+111
-31
lines changed

src/BlockTensorFactorization.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,14 @@ export getconstraint, smart_insert!, smart_interlace!, group_by_factor
8383

8484
# High level / user-interface
8585
#include("./factorize.jl")
86-
export factorize
86+
export factorize, default_kwargs
8787

8888
#include("./multiscale.jl")
8989
export multiscale_factorize
9090
export coarsen, interpolate, scale_constraint
9191

9292
#include("./rankdetection.jl")
9393
export rank_detect_factorize
94-
export possible_ranks
94+
export max_possible_rank
9595

9696
end # module BlockTensorFactorization

src/Core/Core.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,14 +93,14 @@ export getconstraint, smart_insert!, smart_interlace!, group_by_factor
9393

9494
# High level / user-interface
9595
include("./factorize.jl")
96-
export factorize
96+
export factorize, default_kwargs
9797

9898
include("./multiscale.jl")
9999
export multiscale_factorize
100100
export coarsen, interpolate, scale_constraint
101101

102102
include("./rankdetection.jl")
103103
export rank_detect_factorize
104-
export possible_ranks
104+
export max_possible_rank
105105

106106
end # module Core

src/Core/curvaturetools.jl

Lines changed: 49 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -296,35 +296,37 @@ end
296296

297297
"""Extracts the first and second derivatives of the splines at the knots"""
298298
function d_dx_and_d2_dx2_spline(y::AbstractVector{<:Real}; h=1)
299-
_, b, c, _ = cubic_spline_coefficients(y::AbstractVector{<:Real}; h=1)
299+
_, b, c, _ = cubic_spline_coefficients(y::AbstractVector{<:Real}; h)
300300
dy_dx = c
301301
dy2_dx2 = 2b
302302
return dy_dx, dy2_dx2
303303
end
304304

305305

306306
"""
307-
curvature(y::AbstractVector{<:Real})
307+
curvature(y::AbstractVector{<:Real}; method=:finite_differences)
308308
309309
Approximates the signed curvature of a function given evenly spaced samples.
310310
311311
Uses [`d_dx`](@ref) and [`d2_dx2`](@ref) to approximate the first two derivatives.
312312
"""
313313
function curvature(y::AbstractVector{<:Real}; method=:finite_differences, kwargs...)
314-
if method == finite_differences
314+
if method == :finite_differences
315315
dy_dx = d_dx(y; kwargs...)
316316
dy2_dx2 = d2_dx2(y; kwargs...)
317317
return @. dy2_dx2 / (1 + dy_dx^2)^1.5
318318
elseif method == :splines
319319
dy_dx, dy2_dx2 = d_dx_and_d2_dx2_spline(y; h=1)
320320
return @. dy2_dx2 / (1 + dy_dx^2)^1.5
321+
elseif method == :circles
322+
return circumscribed_standard_curvature(y)
321323
else
322324
throw(ArgumentError("method $method not implemented"))
323325
end
324326
end
325327

326328
"""
327-
standard_curvature(y::AbstractVector{<:Real})
329+
standard_curvature(y::AbstractVector{<:Real}; method=:finite_differences)
328330
329331
Approximates the signed curvature of a function, scaled to the unit box ``[0,1]^2``.
330332
@@ -341,32 +343,62 @@ function standard_curvature(y::AbstractVector{<:Real}; method=:finite_difference
341343
# y_max = 1
342344
dy_dx, dy2_dx2 = d_dx_and_d2_dx2_spline(y; h=Δx)
343345
return @. dy2_dx2 / (1 + dy_dx^2)^1.5
346+
elseif method == :circles
347+
return circumscribed_standard_curvature(y)
344348
else
345349
throw(ArgumentError("method $method not implemented"))
346350
end
347351
end
348352

349-
"""
350-
Finds the radius of the circumscribed circle between points (a,f), (b,g), (c,h)
351-
"""
352-
function circumscribed_radius((a,f),(b,g),(c,h))
353-
d = 2*(a*(g-h)+b*(h-f)+c*(f-g))
354-
p = ((a^2+f^2)*(g-h)+(b^2+g^2)*(h-f)+(c^2+h^2)*(f-g)) / d
355-
q = ((a^2+f^2)*(b-c)+(b^2+g^2)*(c-a)+(c^2+h^2)*(a-b)) / d
356-
r = sqrt((a-p)^2+(f-q)^2)
357-
return r
358-
end
353+
# """
354+
# Finds the radius of the circumscribed circle between points (a,f), (b,g), (c,h)
355+
# """
356+
# function circumscribed_radius((a,f),(b,g),(c,h))
357+
# d = 2*(a*(g-h)+b*(h-f)+c*(f-g))
358+
# p = ((a^2+f^2)*(g-h)+(b^2+g^2)*(h-f)+(c^2+h^2)*(f-g)) / d
359+
# q = ((a^2+f^2)*(b-c)+(b^2+g^2)*(c-a)+(c^2+h^2)*(a-b)) / d
360+
# r = sqrt((a-p)^2+(f-q)^2)
361+
# return r
362+
# end
359363

360364
function circumscribed_standard_curvature(y)
361-
n = length(v)
365+
n = length(y)
362366
ymax = maximum(y)
363367
y = y / ymax
364-
k = zero(ymax)
368+
k = zero(y)
365369
a, b, c = 0, 1/n, 2/n
366370
for i in eachindex(k)[2:end-1]
367-
k[i] = 1 / circumscribed_radius((a,y[i-1]),(b,y[i]),(c,y[i+1]))
371+
k[i] = signed_circle_curvature((a,y[i-1]),(b,y[i]),(c,y[i+1]))
372+
#k[i] = 1 / circumscribed_radius((a,y[i-1]),(b,y[i]),(c,y[i+1]))
368373
end
369374
k[1] = k[2]
370375
k[end] = k[end-1]
371376
return k
372377
end
378+
379+
"""radius r and center point (p,q) of the circle passing through the three points"""
380+
function three_point_circle((a,f),(b,g),(c,h))
381+
fg = f-g
382+
gh = g-h
383+
hf = h-f
384+
ab = a-b
385+
bc = b-c
386+
ca = c-a
387+
a2 = a^2
388+
b2 = b^2
389+
c2 = c^2
390+
f2 = f^2
391+
g2 = g^2
392+
h2 = h^2
393+
p = (a2*gh + b2*hf + c2*fg - gh*hf*fg) / (a*gh + b*hf + c*fg) / 2
394+
q = (f2*bc + g2*ca + h2*ab - bc*ca*ab) / (f*bc + g*ca + h*ab) / 2
395+
r = sqrt((a-p)^2 + (f-q)^2)
396+
return r, (p, q)
397+
end
398+
399+
function signed_circle_curvature((a,f),(b,g),(c,h))
400+
@assert a < b < c
401+
r, _ = three_point_circle((a,f),(b,g),(c,h))
402+
sign = g > (f+h)/2 ? -1 : 1
403+
return sign / r
404+
end

src/Core/rankdetection.jl

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,26 @@
11
"""
2-
rank_detect_factorize(Y; online_rank_estimation=false, rank=nothing, model=Tucker1, kwargs...)
2+
rank_detect_factorize(Y; kwargs...)
33
44
Wraps `factorize()` with rank detection.
55
66
Selects the rank that maximizes the standard curvature of the Relative Error (as a function of rank).
7+
8+
# Keywords
9+
- `online_rank_estimation`: `false`. Set to `true` to stop testing larger ranks after the first peak in curvature
10+
- `curvature_method`: `:splines`. Can also pick `:finite_differences` (faster but less accurate) or `circles` (fastest and smallest memory but more sensitive to results from `factorize`)
11+
- `model`: `Tucker1`. Only rank detection with `Tucker1` and `CPDecomposition` is currently implemented
12+
- `max_rank`: `max_possible_rank(Y, model)`. Test ranks from `1` up to `max_rank`. Defaults to largest possible rank under the model
13+
- `rank`: `nothing`. If a rank is passed, rank detection is ignored and `factorize(Y; kwargs...)` is called
14+
15+
Any other keywords from [`factorize`](@ref), full list given by [`default_kwargs`](@ref).
716
"""
8-
function rank_detect_factorize(Y; online_rank_estimation=false, rank=nothing, model=Tucker1, kwargs...)
17+
function rank_detect_factorize(Y;
18+
online_rank_estimation=false,
19+
rank=nothing,
20+
model=Tucker1,
21+
max_rank=max_possible_rank(Y, model),
22+
curvature_method=:splines,
23+
kwargs...)
924
if isnothing(rank)
1025
# Initialize output and final error lists
1126
all_outputs = []
@@ -20,8 +35,10 @@ function rank_detect_factorize(Y; online_rank_estimation=false, rank=nothing, mo
2035
kwargs[:stats] = [RelativeError, kwargs[:stats]...] # not using pushfirst! since kwargs[:stats] could be a Tuple
2136
end
2237
kwargs[:model] = model # add the model back into kwargs
38+
kwargs[:curvature_method] = curvature_method# :splines
39+
kwargs[:max_rank] = max_rank
2340

24-
for rank in possible_ranks(Y, model)
41+
for rank in 1:max_rank
2542
@info "Trying rank=$rank..."
2643

2744
kwargs[:rank] = rank # add the rank into kwargs
@@ -35,7 +52,7 @@ function rank_detect_factorize(Y; online_rank_estimation=false, rank=nothing, mo
3552
@info "Final relative error = $final_rel_error"
3653

3754
if (online_rank_estimation == true) && length(final_rel_errors) >= 3 # Need at least 3 points to evaluate curvature
38-
curvatures = standard_curvature(final_rel_errors; method=:finite_differences) # method=:splines
55+
curvatures = standard_curvature(final_rel_errors; method=curvature_method) # method=:splines
3956
if curvatures[end] maximum(curvatures) # want the last curvature to be significantly smaller than the max
4057
continue
4158
else
@@ -57,9 +74,9 @@ function rank_detect_factorize(Y; online_rank_estimation=false, rank=nothing, mo
5774
end
5875

5976
"""
60-
possible_ranks(Y, model)
77+
max_possible_rank(Y, model)
6178
62-
Returns the rank of possible ranks `Y` could have under the `model`.
79+
Returns the maximum rank possible `Y` could have under the `model`.
6380
6481
For matrices `I × J` this is `1:min(I, J)`. This is can be extended to tensors for different type
6582
of decompositions.
@@ -70,16 +87,16 @@ The CP-rank is `≤ minimum_{n} (prod(I1,...,IN) / In)` for tensors `I1 × …
7087
some shapes have have tighter upper bounds. For example, `2 × I × I` tensors over ℝ have a maximum
7188
rank of `floor(3I/2)`.
7289
"""
73-
function possible_ranks(Y, model)
90+
function max_possible_rank(Y, model) # TODO store this info with the Corresponding AbstractDecomposition
7491
if model <: Tucker1
7592
I, Js... = size(Y)
7693
max_rank = min(I, prod(Js))
77-
return 1:max_rank
94+
return max_rank
7895
elseif model <: CPDecomposition
7996
Is = size(Y)
8097
# There exist tighter upper bounds for particular shapes like I×I×K, but this a simple upper bound that works for all shapes
8198
max_rank = minimum(prod(Is) Is) # ÷ is Integer division
82-
return 1:max_rank
99+
return max_rank
83100
else
84101
error("Possible ranks for models of type $model are not implemented")
85102
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22
BlockTensorFactorization = "07b766a1-0096-4e53-b687-1cd07becaccb"
33
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
44
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
5+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
56
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

test/runtests.jl

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using Test
77
using Random
88
Random.seed!(3141592653589) # Reproducibility of random initializations
99
using LinearAlgebra
10+
using Statistics
1011

1112
using BlockTensorFactorization
1213

@@ -803,10 +804,39 @@ end
803804
end
804805

805806
@testset "RankDetection" begin
807+
@test BlockTensorFactorization.Core.three_point_circle((1,2), (2,1), (5,2)) == (5, (3, 3))
808+
809+
# any smooth function s.t.
810+
# f(0) = 1, f(1) = 0, f''(1) = 0
811+
f(x) = 3x^3 - 5x^2 + x + 1
812+
# k(x) = f''(x) / (1 + (f'(x))^2)^(3/2)
813+
k(x) = (18x-10)/(1 + (1 - 10x + 9x^2)^2)^(1.5) # closed form true curvature
814+
x = range(0, 1, length=20)[2:end] # exclude x=0 point (necessary for :splines)
815+
y = f.(x)
816+
k_true = k.(x)
817+
methods = (:finite_differences, :splines, :circles)
818+
k_finite_differences, k_splines, k_circles = (standard_curvature(y; method=method) for method in methods)
819+
# Mean Absolute Percentage Error
820+
MAPE(test_vals, true_vals) = mean(@. abs((test_vals - true_vals)/true_vals))
821+
@test MAPE(k_splines, k_true) < 0.07 # 7%
822+
@test MAPE(k_circles, k_true) < 0.08 # 8%
823+
@test MAPE(k_finite_differences, k_true) < 0.09 # 9%
824+
806825
T = Tucker1((10, 10, 10), 3)
807826
Y = array(T)
808-
decomposition, stats, kwargs, final_rel_errors = rank_detect_factorize(Y; model=Tucker1)
827+
decomposition, stats, kwargs, final_rel_errors = rank_detect_factorize(Y; model=Tucker1, curvature_method=:splines)
828+
@test kwargs[:rank] == 3
829+
830+
decomposition, stats, kwargs, final_rel_errors = rank_detect_factorize(Y; model=Tucker1, curvature_method=:circles)
809831
@test kwargs[:rank] == 3
832+
833+
decomposition, stats, kwargs, final_rel_errors = rank_detect_factorize(Y; model=Tucker1, curvature_method=:finite_differences)
834+
@test kwargs[:rank] == 3
835+
836+
T = CPDecomposition((10, 11, 12), 4)
837+
Y = array(T)
838+
decomposition, stats, kwargs, final_rel_errors = rank_detect_factorize(Y; model=CPDecomposition, curvature_method=:splines, online_rank_estimation=true)
839+
@test kwargs[:rank] == 4
810840
end
811841

812842
end

0 commit comments

Comments
 (0)