Commit 2903436c authored by Jean-Baptiste Mouret's avatar Jean-Baptiste Mouret
Browse files

corrected a bug in ehvi

parent ca2ab23c
......@@ -24,6 +24,10 @@ struct Params {
struct cmaes : public defaults::cmaes {};
struct gp_auto : public defaults::gp_auto {};
struct meanconstant : public defaults::meanconstant {};
struct ehvi {
BO_PARAM(double, x_ref, -11);
BO_PARAM(double, y_ref, -11);
};
};
......@@ -139,22 +143,22 @@ namespace limbo {
obs << opt.observations()[i].transpose() << " "
<< opt.samples()[i].transpose()
<< std::endl;
/*
std::string m1 = "model_" + it + ".dat";
std::ofstream m1f(m1.c_str());
for (float x = 0; x < 1; x += 0.01)
for (float y = 0; y < 1; y += 0.01) {
Eigen::VectorXd v(2);
v << x, y;
m1f << x << " " << y << " "
<< opt.models()[0].mu(v) << " "
<< opt.models()[0].sigma(v) << " "
<< opt.models()[1].mu(v) << " "
<< opt.models()[1].sigma(v) << std::endl;
}
*/
std::cout<<"stats done"<<std::endl;
/*
std::string m1 = "model_" + it + ".dat";
std::ofstream m1f(m1.c_str());
for (float x = 0; x < 1; x += 0.01)
for (float y = 0; y < 1; y += 0.01) {
Eigen::VectorXd v(2);
v << x, y;
m1f << x << " " << y << " "
<< opt.models()[0].mu(v) << " "
<< opt.models()[0].sigma(v) << " "
<< opt.models()[1].mu(v) << " "
<< opt.models()[1].sigma(v) << std::endl;
}
*/
std::cout << "stats done" << std::endl;
}
};
}
......
......@@ -26,16 +26,15 @@ namespace limbo {
double r[3] = { _ref_point(0), _ref_point(1), _ref_point(2) };
double mu[3] = { _models[0].mu(v), _models[1].mu(v), 0 };
double s[3] = { _models[0].sigma(v), _models[1].sigma(v), 0 };
for (size_t i = 0; i < _models.size(); ++i)
mu[i] = std::min(_models[i].mu(v), _models[i].max_observation());
double ehvi = ehvi2d(_pop, r, mu, s);
//for (size_t i = 0; i < _models.size(); ++i)
// mu[i] = std::min(_models[i].mu(v), _models[i].max_observation());
double ehvi = ehvi2d(_pop, r, mu, s);
return ehvi;
}
protected:
const std::vector<Model>& _models;
const std::deque<individual*>& _pop;
const Eigen::VectorXd& _ref_point;
Eigen::VectorXd _ref_point;
};
}
......@@ -68,7 +67,6 @@ namespace limbo {
this->update_pareto_data();
std::cout << "ok" << std::endl;
std::cout<<"copying pop...("<<this->pareto_data().size()<<")"<<std::endl;;
// copy in the ehvi structure to compute expected improvement
std::deque<individual*> pop;
for (auto x : this->pareto_data()) {
......@@ -83,22 +81,24 @@ namespace limbo {
std::cout << "optimizing ehvi" << std::endl;
auto acqui =
acquisition_functions::Ehvi<Params, model_t>(this->_models, pop,
Eigen::Vector3d(-11, -11, 0));
acquisition_functions::Ehvi<Params, model_t>
(this->_models, pop,
Eigen::Vector3d(Params::ehvi::x_ref(), Params::ehvi::y_ref(), 0));
double best_hv = -1;
Eigen::VectorXd best_s;
for (auto x : this->pareto_data()) {
Eigen::VectorXd s = inner_opt(acqui, acqui.dim(), std::get<0>(x));
double hv = acqui(s);
if (hv > best_hv && this->_models[0].mu(s) <= 1)
if (hv > best_hv)
{
best_s = s;
best_hv = hv;
}
}
std::cout<<"sample selected" << std::endl;
Eigen::VectorXd new_sample = best_s;
std::cout<<"new sample:"<<new_sample.transpose()<<std::endl;
std::cout<<"expected improvement: "<<acqui(new_sample)<<std::endl;
std::cout<<"expected value: "<<this->_models[0].mu(new_sample)
......
......@@ -30,7 +30,7 @@ namespace limbo {
Cmaes() {}
template <typename AcquisitionFunction>
Eigen::VectorXd operator()(const AcquisitionFunction& acqui, int dim) const {
return this->operator()(acqui, dim, Eigen::VectorXd::Constant(dim, 0.5));
return this->operator()(acqui, dim, Eigen::VectorXd::Constant(dim, 0.5));
}
template <typename AcquisitionFunction>
Eigen::VectorXd operator()(const AcquisitionFunction& acqui, int dim, const Eigen::VectorXd& init) const {
......@@ -53,7 +53,7 @@ namespace limbo {
double init_point[dim];
for (int i = 0; i < dim; ++i)
init_point[i] = init(i);
for (irun = 0; irun < nrestarts + 1; ++irun) {
for (irun = 0; irun < nrestarts + 1; ++irun) {
fitvals = cmaes_init(&evo, acqui.dim(), init_point, NULL, 0, lambda, NULL);
evo.countevals = countevals;
evo.sp.stopMaxFunEvals =
......
......@@ -45,7 +45,7 @@ def configure(conf):
if conf.is_defined('USE_SFERES'):
common_flags += " -DUSE_SFERES -DSFERES_FAST_DOMSORT"
opt_flags = common_flags + ' -O3 -msse2 -ggdb3 -g'
opt_flags = common_flags + ' -O3 -msse2 -g'
conf.env['CXXFLAGS'] = cxxflags + opt_flags.split(' ')
print conf.env['CXXFLAGS']
......@@ -63,7 +63,7 @@ def build(bld):
bld.recurse('exp/' + i)
from waflib.Tools import waf_unit_test
bld.add_post_fun(waf_unit_test.summary)
def shutdown (ctx):
if ctx.options.qsub:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment