Skip to content

Commit d6b61a2

Browse files
committed
add breakpoint rank estimation method and more tests
1 parent 80e94c5 commit d6b61a2

File tree

4 files changed

+118
-4
lines changed

4 files changed

+118
-4
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
Manifest.toml
22
docs/build
33
.vscode
4+
Breakpoint_calc.m

src/Core/curvaturetools.jl

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,11 @@ end
141141
142142
Approximates the signed curvature of a function given evenly spaced samples.
143143
144-
Uses [`d_dx`](@ref) and [`d2_dx2`](@ref) to approximate the first two derivatives.
144+
# Possible `method`s
145+
- `:finite_differences`: Approximates first and second derivative with 3rd order finite differences. See [`d_dx`](@ref) and [`d2_dx2`](@ref).
146+
- `:splines`: Curvature of a third order spline. See [`d_dx_and_d2_dx2_spline`](@ref).
147+
- `:circles`: Inverse radius of the circle through each rolling window of three points. See [`circle_curvature`](@ref).
148+
- `:breakpoints`: WARNING does not compute a value that approximates the curvature of a continuous function. Computes the inverse least-squares error of `f.(eachindex(y); z)` and `y` for all `z in eachindex(y)` where `f(x; z) = a + b(min(x, z) - z) + c(max(x, z) - z)`. Useful if `y` looks like two lines. See [`breakpoint_curvature`](@ref).
145149
"""
146150
function curvature(y::AbstractVector{<:Real}; method=:finite_differences, kwargs...)
147151
if method == :finite_differences
@@ -153,6 +157,8 @@ function curvature(y::AbstractVector{<:Real}; method=:finite_differences, kwargs
153157
return @. dy2_dx2 / (1 + dy_dx^2)^1.5
154158
elseif method == :circles
155159
return circle_curvature(y; h=1)
160+
elseif method == :breakpoints
161+
return breakpoint_curvature(y)
156162
else
157163
throw(ArgumentError("method $method not implemented"))
158164
end
@@ -166,6 +172,13 @@ Approximates the signed curvature of a function, scaled to the unit box ``[0,1]^
166172
Assumes the function is 1 at 0 and (after x dimension is scaled) 0 at 1.
167173
168174
See [`curvature`](@ref).
175+
176+
177+
# Possible `method`s
178+
- `:finite_differences`: Approximates first and second derivative with 3rd order finite differences. See [`d_dx`](@ref) and [`d2_dx2`](@ref).
179+
- `:splines`: Curvature of a third order spline. See [`d_dx_and_d2_dx2_spline`](@ref).
180+
- `:circles`: Inverse radius of the circle through each rolling window of three points. See [`circle_curvature`](@ref).
181+
- `:breakpoints`: WARNING does not compute a value that approximates the curvature of a continuous function. Computes the inverse least-squares error of `f.(eachindex(y); z)` and `y` for all `z in eachindex(y)` where `f(x; z) = a + b(min(x, z) - z) + c(max(x, z) - z)`. Useful if `y` looks like two lines. See [`breakpoint_curvature`](@ref).
169182
"""
170183
function standard_curvature(y::AbstractVector{<:Real}; method=:finite_differences, kwargs...)
171184
Δx = 1/length(y)
@@ -183,6 +196,8 @@ function standard_curvature(y::AbstractVector{<:Real}; method=:finite_difference
183196
return @. dy2_dx2 / (1 + dy_dx^2)^1.5
184197
elseif method == :circles
185198
return circle_curvature(y / max(1,maximum(y)); h=Δx)
199+
elseif method == :breakpoints
200+
return breakpoint_curvature(y) # best breakpoint unaffected by scaling and stretching
186201
else
187202
throw(ArgumentError("method $method not implemented"))
188203
end
@@ -264,3 +279,51 @@ function signed_circle_curvature((a,f),(b,g),(c,h))
264279
sign = g > (f+h)/2 ? -1 : 1
265280
return sign / r
266281
end
282+
283+
"""
284+
breakpoint_model_coefficients(xs, ys, breakpoint)
285+
286+
Least squares fit data ``(x_i, y_i)``
287+
288+
``\\min_{a,b,c} 0.5\\sum_{i} (f(x_i; a,b,c) - y_i)^2``
289+
290+
with the model
291+
292+
``f(x; a,b,c) = a + b(\\min(x, z) - x) + c(\\max(x, z) - x)``
293+
294+
for some fixed ``z``.
295+
"""
296+
function breakpoint_model_coefficients(xs, ys, z)
297+
n = length(xs)
298+
@assert n == length(ys)
299+
M = hcat(ones(n), (min.(xs, z) .- z), (max.(xs, z) .- z))
300+
a, b, c = M \ ys
301+
return a, b, c
302+
end
303+
304+
"""
    breakpoint_model(a, b, c, z)

Build the piecewise-linear function
`x -> a + b*(min(x, z) - z) + c*(max(x, z) - z)`.

The returned closure evaluates to `a` at `x = z`, has slope `b` for `x ≤ z`,
and slope `c` for `x ≥ z`.
"""
function breakpoint_model(a, b, c, z)
    return function (x)
        left = min(x, z) - z    # nonzero only to the left of the breakpoint
        right = max(x, z) - z   # nonzero only to the right of the breakpoint
        return a + b * left + c * right
    end
end
305+
306+
"""
    breakpoint_error(xs, ys, z)

Fit the breakpoint model with breakpoint `z` to the data `(xs, ys)` and return
`norm2` of the residuals `f.(xs) .- ys`.

NOTE(review): `norm2` is defined outside this block; the original comment
describes the result as equivalent to `sum(((x, y),) -> (f(x) - y)^2, zip(xs, ys))`,
i.e. the sum of squared residuals — confirm against the definition of `norm2`.
"""
function breakpoint_error(xs, ys, z)
    coefficients = breakpoint_model_coefficients(xs, ys, z)
    model = breakpoint_model(coefficients..., z)
    residuals = model.(xs) .- ys
    return norm2(residuals)
end
312+
313+
"""
    best_breakpoint(xs, ys; breakpoints=xs)

Return the element `z` of `breakpoints` with the smallest
`breakpoint_error(xs, ys, z)`. Candidates default to the sample locations `xs`.
"""
function best_breakpoint(xs, ys; breakpoints=xs)
    fit_error(z) = breakpoint_error(xs, ys, z)
    return argmin(fit_error, breakpoints)
end
314+
315+
"""
316+
breakpoint_curvature(y)
317+
318+
This is a hacked way to fit the data `y` with a breakpoint model,
319+
which can be called by `k = standard_curvature(...; model=:breakpoints)`
320+
321+
This lets us call `argmax(k)` to get the breakpoint that minimizes the model error.
322+
323+
See [`breakpoint_model_coefficients`](@ref).
324+
"""
325+
function breakpoint_curvature(y)
326+
x = eachindex(y)
327+
errors = [breakpoint_error(x, y, z) for z in x]
328+
return 1 ./ errors
329+
end

src/Core/rankdetection.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ Selects the rank that maximizes the standard curvature of the Relative Error (as
77
88
# Keywords
99
- `online_rank_estimation`: `false`. Set to `true` to stop testing larger ranks after the first peak in curvature
10-
- `curvature_method`: `:splines`. Can also pick `:finite_differences` (faster but less accurate) or `circles` (fastest and smallest memory but more sensitive to results from `factorize`)
10+
- `curvature_method`: `:splines`. Can also pick `:finite_differences` (faster but less accurate) or `:circles` (fastest and smallest memory but more sensitive to results from `factorize`). Set to `:breakpoints` to pick the rank `R` that minimizes the least-squares error between the model `f(r) = a + b(min(r, R) - R) + c(max(r, R) - R)` and the errors.
1111
- `model`: `Tucker1`. Only rank detection with `Tucker1` and `CPDecomposition` is currently implemented
1212
- `max_rank`: `max_possible_rank(Y, model)`. Test ranks from `1` up to `max_rank`. Defaults to largest possible rank under the model
1313
- `rank`: `nothing`. If a rank is passed, rank detection is ignored and `factorize(Y; kwargs...)` is called

test/runtests.jl

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -894,6 +894,42 @@ end
894894

895895
check_slice = 1:10
896896
@test isapprox(decomposition[check_slice], Y[check_slice]; rtol=0.01) # should be within 1% error
897+
898+
N = 33 # 1 plus a power of 2
899+
R = 5
900+
D = 2
901+
902+
# More constrained problem
903+
matrices = [abs_randn(N, R) for _ in 1:D]
904+
l1scale_cols!.(matrices)
905+
Ydecomp = CPDecomposition(Tuple(matrices))#abs_randn
906+
@assert all(check.(simplex_cols!, factors(Ydecomp)))
907+
Y = array(Ydecomp)
908+
909+
scaleB_rescaleA! = ConstraintUpdate(0, l1scale_1slices! ∘ nonnegative!;
910+
whats_rescaled=(x -> eachcol(factor(x, 1)))
911+
)
912+
nonnegativeB! = ConstraintUpdate(0, nonnegative!)
913+
nonnegativeA! = ConstraintUpdate(1, nonnegative!)
914+
#[l1scale_1slices! ∘ nonnegative!, nonnegative!]
915+
916+
options = (
917+
rank=3,
918+
momentum=true,
919+
model=Tucker1,
920+
tolerance=(1e-5),
921+
converged=(GradientNNCone),
922+
do_subblock_updates=false,
923+
constrain_init=true,
924+
constraints=[scaleB_rescaleA!, nonnegativeA!],
925+
stats=[Iteration, ObjectiveValue, GradientNNCone, RelativeError],
926+
maxiter=200
927+
)
928+
929+
decomposition, stats, kwargs = multiscale_factorize(Y; options...)
930+
931+
check_slice = 1:10
932+
@test isapprox(decomposition[check_slice], Y[check_slice]; rtol=0.02) # should be within 2% error
897933
end
898934
end
899935

@@ -927,6 +963,13 @@ end
927963
@test MAPE(k_circles, k_true) < 0.03 # 3%
928964
@test MAPE(k_finite_differences, k_true) < 0.03 # 3%
929965

966+
# Break point method
967+
ys = [13,10,5,4.5,4,3.6,3,2.5]; xs = [0,1,2,3,4,5,6,7]; z=2
968+
a,b,c = BlockTensorFactorization.Core.breakpoint_model_coefficients(xs, ys, z)
969+
970+
@test isapprox([a,b,c], [5.2,-4.1,-0.5], rtol=0.01)
971+
972+
# Rank detect
930973
T = Tucker1((10, 10, 10), 3)
931974
Y = array(T)
932975
decomposition, stats, kwargs, final_rel_errors = rank_detect_factorize(Y; model=Tucker1, curvature_method=:splines)
@@ -938,10 +981,17 @@ end
938981
decomposition, stats, kwargs, final_rel_errors = rank_detect_factorize(Y; model=Tucker1, curvature_method=:finite_differences)
939982
@test kwargs[:rank] == 3
940983

984+
decomposition, stats, kwargs, final_rel_errors = rank_detect_factorize(Y; model=Tucker1, curvature_method=:breakpoints)
985+
@test kwargs[:rank] == 3
986+
941987
T = CPDecomposition((10, 11, 12), 4)
942988
Y = array(T)
943-
decomposition, stats, kwargs, final_rel_errors = rank_detect_factorize(Y; model=CPDecomposition, curvature_method=:splines, online_rank_estimation=true)
944-
@test kwargs[:rank] == 4
989+
V = zeros(Int, 5)
990+
for i in 1:5
991+
decomposition, stats, kwargs, final_rel_errors = rank_detect_factorize(Y; model=CPDecomposition, curvature_method=:splines, online_rank_estimation=true, tolerance=0.01)
992+
V[i] = kwargs[:rank]
993+
end
994+
@test count(x -> x == 4, V) ≥ 3 # should predict 4 most of the time
945995
end
946996

947997
end

0 commit comments

Comments
 (0)