Skip to content

Commit 146f576

Browse files
committed
Optional inplace, perf gain, no accuracy loss 3/3
1 parent ec0a49b commit 146f576

File tree

3 files changed

+28
-32
lines changed

3 files changed

+28
-32
lines changed

src/genetic.jl

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,6 @@ function generate_population(icn, pop_size)
88
return population
99
end
1010

11-
"""
12-
loss(X, X_sols, icn, weigths, metric)
13-
Compute the loss of `icn`.
14-
"""
15-
function loss(solutions, non_sltns, icn, weigths, metric, dom_size, param; samples=nothing)
16-
compo = compose(icn, weigths)
17-
f = composition(compo)
18-
X = if isnothing(samples)
19-
Iterators.flatten((solutions, non_sltns))
20-
else
21-
Iterators.flatten((solutions, rand(non_sltns, samples)))
22-
end
23-
σ = sum(x -> abs(f(x; param, dom_size) - metric(x, solutions)), X) + regularization(icn)
24-
return σ
25-
end
26-
2711
"""
2812
_optimize!(icn, X, X_sols; metric = hamming, pop_size = 200)
2913
Optimize and set the weigths of an ICN with a given set of configuration `X` and solutions `X_sols`.
@@ -40,11 +24,18 @@ function _optimize!(
4024
samples=nothing,
4125
memoize=false,
4226
)
43-
_metric = memoize ? (@memoize Dict memoize_metric(x, X) = metric(x, X)) : metric
44-
_bias = memoize ? (@memoize Dict memoize_bias(x) = weigths_bias(x)) : weigths_bias
45-
fitness =
46-
w ->
47-
loss(solutions, non_sltns, icn, w, _metric, dom_size, param; samples) + _bias(w)
27+
inplace = zeros(dom_size, max_icn_length())
28+
_non_sltns = isnothing(samples) ? non_sltns : rand(non_sltns, samples)
29+
30+
function fitness(w)
31+
compo = compose(icn, w)
32+
f = composition(compo)
33+
S = Iterators.flatten((solutions, _non_sltns))
34+
return sum(x -> abs(f(x; X=inplace, param, dom_size) - metric(x, solutions)), S) +
35+
regularization(icn) +
36+
weigths_bias(w)
37+
end
38+
_fitness = memoize ? (@memoize Dict memoize_fitness(w) = fitness(w)) : fitness
4839

4940
_icn_ga = GA(;
5041
populationSize=pop_size,
@@ -57,7 +48,7 @@ function _optimize!(
5748
)
5849

5950
pop = generate_population(icn, pop_size)
60-
r = Evolutionary.optimize(fitness, pop, _icn_ga, Evolutionary.Options(; iterations))
51+
r = Evolutionary.optimize(_fitness, pop, _icn_ga, Evolutionary.Options(; iterations))
6152
return weights!(icn, Evolutionary.minimizer(r))
6253
end
6354

src/icn.jl

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ mutable struct ICN
2424
co_layer=comparison_layer(param),
2525
)
2626
w = generate_weights([tr_layer, ar_layer, ag_layer, co_layer])
27-
new(tr_layer, ar_layer, ag_layer, co_layer, w)
27+
return new(tr_layer, ar_layer, ag_layer, co_layer, w)
2828
end
2929
end
3030

@@ -75,7 +75,7 @@ Set the weights of an ICN with a `BitVector`.
7575
function weights!(icn, weigths)
7676
length(weigths) == nbits(icn) || @warn icn weigths
7777
@assert length(weigths) == nbits(icn)
78-
icn.weigths = weigths
78+
return icn.weigths = weigths
7979
end
8080

8181
"""
@@ -107,14 +107,20 @@ function regularization(icn)
107107
return Σop / (Σmax + 1)
108108
end
109109

110-
max_icn_length(icn = ICN(param = true)) = length(icn.transformation)
110+
max_icn_length(icn=ICN(; param=true)) = length(icn.transformation)
111111

112112
"""
113113
_compose(icn)
114114
Internal function called by `compose` and `show_composition`.
115115
"""
116116
function _compose(icn::ICN)
117-
!is_viable(icn) && (return ((x; param=nothing, dom_size=0) -> typemax(Float64)), [])
117+
!is_viable(icn) && (
118+
return (
119+
(x; X=zeros(length(x), max_icn_length()), param=nothing, dom_size=0) ->
120+
typemax(Float64)
121+
),
122+
[]
123+
)
118124

119125
funcs = Vector{Vector{Function}}()
120126
symbols = Vector{Vector{Symbol}}()
@@ -148,11 +154,10 @@ function _compose(icn::ICN)
148154

149155
function composition(x; X=zeros(length(x), length(funcs[1])), param=nothing, dom_size)
150156
tr_in(Tuple(funcs[1]), X, x, param)
151-
for i in 1:length(x)
152-
X[i,1] = funcs[2][1](@view X[i,:])
153-
end
154-
funcs[3][1](@view X[:, 1]) |>
155-
(y -> funcs[4][1](y; param, dom_size, nvars=length(x)))
157+
X[:, 1] .= 1:length(x) .|> (i -> funcs[2][1](@view X[i, 1:length(funcs[1])]))
158+
return (y -> funcs[4][1](y; param, dom_size, nvars=length(x)))(
159+
funcs[3][1](@view X[:, 1])
160+
)
156161
end
157162

158163
return composition, symbols

src/learn.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ function explore_learn_compose(
7272
),
7373
memoize=false,
7474
)
75-
dom_size = maximum(domain_size, domains)
75+
dom_size = maximum(length, domains)
7676
solutions, non_sltns = configurations
7777
return learn_compose(
7878
solutions,

0 commit comments

Comments
 (0)