Skip to content

Commit 618b4c9

Browse files
committed
structure for int-handling
1 parent 77d5925 commit 618b4c9

File tree

10 files changed

+118
-8
lines changed

10 files changed

+118
-8
lines changed

include/bounds.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@ namespace parameters
1414

1515
namespace bounds
1616
{
17-
using Mask = Eigen::Array<bool, Eigen::Dynamic, 1>;
18-
1917
Mask is_out_of_bounds(const Vector &xi, const Vector &lb, const Vector &ub);
2018
bool any_out_of_bounds(const Vector &xi, const Vector &lb, const Vector &ub);
2119

include/common.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ using Matrix = Eigen::Matrix<Float, -1, -1>;
2828
using Vector = Eigen::Matrix<Float, -1, 1>;
2929
using Array = Eigen::Array<Float, -1, 1>;
3030
using size_to = std::optional<size_t>;
31+
using Mask = Eigen::Array<bool, Eigen::Dynamic, 1>;
32+
using Indices = Eigen::ArrayXi;
3133

3234
template <typename T>
3335
std::ostream &operator<<(std::ostream &os, const std::vector<T> &x);

include/integer.hpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
#pragma once
2+
3+
#include "common.hpp"
4+
5+
namespace parameters
6+
{
7+
struct Parameters;
8+
}
9+
10+
11+
namespace integer
12+
{
13+
14+
struct IntegerHandling
15+
{
16+
Float lb_sigma;
17+
Vector ones;
18+
Vector effective_y;
19+
20+
IntegerHandling(const size_t d, const Float mueff)
21+
:
22+
lb_sigma(std::min(0.2, mueff / static_cast<Float>(d))),
23+
ones(Vector::Ones(d)),
24+
effective_y(Vector::Ones(d))
25+
{
26+
}
27+
28+
virtual void update_diagonal(const parameters::Parameters& p);
29+
30+
virtual Array get_effective_sigma(const parameters::Parameters& p, const size_t idx);
31+
32+
virtual void round_to_integer(Eigen::Ref<Vector> x, const Indices iidx)
33+
{
34+
for (const auto& idx: iidx)
35+
x[idx] = std::round(x[idx]);
36+
}
37+
};
38+
39+
struct NoIntegerHandling : IntegerHandling
40+
{
41+
using IntegerHandling::IntegerHandling;
42+
43+
// virtual void get_effective_sigma(const parameters::Parameters& p) override {}
44+
// void round_to_integer(Eigen::Ref<Vector> x, const Indices iidx) override {}
45+
};
46+
47+
inline std::shared_ptr<IntegerHandling> get(const Indices &idx, const size_t d, const Float mueff)
48+
{
49+
if (idx.size() == 0)
50+
return std::make_shared<NoIntegerHandling>(d, mueff);
51+
52+
return std::make_shared<IntegerHandling>(d, mueff);
53+
}
54+
}

include/parameters.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "weights.hpp"
1313
#include "repelling.hpp"
1414
#include "center_placement.hpp"
15+
#include "integer.hpp"
1516

1617
namespace parameters
1718
{
@@ -39,7 +40,8 @@ namespace parameters
3940
std::shared_ptr<bounds::BoundCorrection> bounds;
4041
std::shared_ptr<repelling::Repelling> repelling;
4142
std::shared_ptr<center::Placement> center_placement;
42-
43+
std::shared_ptr<integer::IntegerHandling> integer_handling;
44+
4345
Parameters(const size_t dim);
4446

4547
Parameters(const Settings &settings);

include/settings.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ namespace parameters
2222

2323
Vector lb;
2424
Vector ub;
25+
Indices integer_variables;
2526
Vector db;
2627
Vector center;
2728
Float diameter;
@@ -49,6 +50,7 @@ namespace parameters
4950
std::optional<Vector> x0 = std::nullopt,
5051
std::optional<Vector> lb = std::nullopt,
5152
std::optional<Vector> ub = std::nullopt,
53+
std::optional<Indices> integer_variables = std::nullopt,
5254
std::optional<Float> cs = std::nullopt,
5355
std::optional<Float> cc = std::nullopt,
5456
std::optional<Float> cmu = std::nullopt,
@@ -68,6 +70,7 @@ namespace parameters
6870
x0(x0),
6971
lb(lb.value_or(Vector::Ones(dim) * -5)),
7072
ub(ub.value_or(Vector::Ones(dim) * 5)),
73+
integer_variables(integer_variables.value_or(Indices{})),
7174
db(this->ub - this->lb),
7275
center(this->lb + (db * 0.5)),
7376
diameter(db.norm()),

src/integer.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#include "integer.hpp"
2+
#include "parameters.hpp"
3+
4+
namespace integer
5+
{
6+
void IntegerHandling::update_diagonal(const parameters::Parameters& p)
7+
{
8+
if (p.settings.integer_variables.size() == 0)
9+
return;
10+
11+
effective_y = p.adaptation->compute_y(ones);
12+
}
13+
14+
Array IntegerHandling::get_effective_sigma(const parameters::Parameters& p, const size_t idx)
15+
{
16+
Array effective_sigma = Array::Constant(p.settings.dim, p.pop.s(idx));
17+
for (const auto& iidx: p.settings.integer_variables)
18+
{
19+
20+
}
21+
22+
return effective_sigma;
23+
}
24+
}

src/interface.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
#include "c_maes.hpp"
1212
#include "to_string.hpp"
1313
#include "es.hpp"
14-
1514
namespace py = pybind11;
1615

1716
PYBIND11_MAKE_OPAQUE(restart::vCriteria);
@@ -604,7 +603,7 @@ void define_parameters(py::module &main)
604603
py::class_<Settings, std::shared_ptr<Settings>>(m, "Settings")
605604
.def(py::init<size_t, std::optional<Modules>, std::optional<Float>, size_to, size_to, std::optional<Float>,
606605
std::optional<size_t>, std::optional<size_t>, std::optional<Vector>,
607-
std::optional<Vector>, std::optional<Vector>,
606+
std::optional<Vector>, std::optional<Vector>, std::optional<Indices>,
608607
std::optional<Float>, std::optional<Float>, std::optional<Float>,
609608
std::optional<Float>, std::optional<Float>, std::optional<Float>,
610609
bool, bool>(),
@@ -619,6 +618,7 @@ void define_parameters(py::module &main)
619618
py::arg("x0") = std::nullopt,
620619
py::arg("lb") = std::nullopt,
621620
py::arg("ub") = std::nullopt,
621+
py::arg("integer_variables") = std::nullopt,
622622
py::arg("cs") = std::nullopt,
623623
py::arg("cc") = std::nullopt,
624624
py::arg("cmu") = std::nullopt,
@@ -668,6 +668,7 @@ void define_parameters(py::module &main)
668668
ss << " x0: " << to_string(settings.x0);
669669
ss << " lb: " << settings.lb.transpose();
670670
ss << " ub: " << settings.ub.transpose();
671+
ss << " integer_variables: " << settings.integer_variables.transpose();
671672
ss << " cs: " << to_string(settings.cs);
672673
ss << " cc: " << to_string(settings.cc);
673674
ss << " cmu: " << to_string(settings.cmu);

src/mutation.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,12 @@ namespace mutation
3737
zi, p.settings.diameter, p.settings.budget, p.stats.evaluations);
3838
p.pop.Z.col(i).noalias() = zi_scaled;
3939
p.pop.Y.col(i).noalias() = p.adaptation->compute_y(p.pop.Z.col(i));
40-
p.pop.X.col(i).noalias() = p.pop.Y.col(i) * p.pop.s(i) + p.adaptation->m;
40+
41+
const auto effective_sigma = p.integer_handling->get_effective_sigma(p, i);
42+
p.pop.X.col(i).array() = p.pop.Y.col(i).array() * effective_sigma + p.adaptation->m.array();
4143
p.bounds->correct(i, p);
44+
45+
p.integer_handling->round_to_integer(p.pop.X.col(i), p.settings.integer_variables);
4246
} while (
4347
(p.settings.modules.bound_correction == parameters::CorrectionMethod::RESAMPLE &&
4448
n_rej++ < 5 * p.settings.dim && p.bounds->is_out_of_bounds(p.pop.X.col(i), p.settings).any()) ||

src/parameters.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ namespace parameters
3030
settings.budget)),
3131
bounds(bounds::get(settings.modules.bound_correction, settings.dim)),
3232
repelling(repelling::get(settings.modules)),
33-
center_placement(center::get(settings.modules.center_placement))
33+
center_placement(center::get(settings.modules.center_placement)),
34+
integer_handling(integer::get(settings.integer_variables, settings.dim, weights.mueff))
3435
{
3536
criteria.reset(*this);
3637
}
@@ -80,8 +81,8 @@ namespace parameters
8081
mutation->sigma = std::min(std::max(mutation->sigma, restart::MinSigma::tolerance), restart::MaxSigma::tolerance);
8182

8283
successfull_adaptation = adaptation->adapt_matrix(weights, settings.modules, pop, mu, settings, stats);
83-
8484
criteria.update(*this);
85+
integer_handling->update_diagonal(*this);
8586
stats.t++;
8687
}
8788

tests/test_c_integer.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import unittest
2+
from modcma import c_maes
3+
4+
5+
def sphere(x):
6+
return sum(xi**2 for xi in x)
7+
8+
9+
class TestInteger(unittest.TestCase):
10+
def test_int(self):
11+
settings = c_maes.settings_from_dict(2, integer_variables=[0])
12+
cma = c_maes.ModularCMAES(settings)
13+
14+
for _ in range(3):
15+
cma.mutate(sphere)
16+
cma.select()
17+
cma.recombine()
18+
cma.adapt()
19+
20+
# breakpoint()
21+

0 commit comments

Comments
 (0)