Commit 7b8a57a4 authored by Konstantinos Chatzilygeroudis

Acqui Opt with Grads

parent 6f8a475e
@@ -105,8 +105,9 @@ public:
     size_t dim_out() const { return _model.dim_out(); }
     template <typename AggregatorFunction>
-    limbo::opt::eval_t operator()(const Eigen::VectorXd& v, const AggregatorFunction& afun) const
+    limbo::opt::eval_t operator()(const Eigen::VectorXd& v, const AggregatorFunction& afun, bool gradient) const
     {
+        assert(!gradient);
         // double mu, sigma;
         // std::tie(mu, sigma) = _model.query(v);
         // return (mu + Params::ucb::alpha() * sqrt(sigma));
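This hunk, and the three that follow, add a `bool gradient` parameter to each acquisition function's `operator()` while keeping the `opt::eval_t` return type; since none of the current acquisitions provide analytical gradients yet, they simply `assert(!gradient)`. Below is a minimal, self-contained sketch of the calling convention; `eval_t`, `ToyUCB`, and the stand-in mean/variance values are illustrative assumptions, not limbo's actual implementation (the only assumption about `opt::eval_t` is that it pairs the value with an optional gradient).

```cpp
// Sketch of the new acquisition calling convention (illustrative only).
// Assumes eval_t is a (value, optional gradient) pair, as in limbo::opt.
#include <cassert>
#include <cmath>
#include <iostream>
#include <optional>
#include <utility>
#include <Eigen/Core>

using eval_t = std::pair<double, std::optional<Eigen::VectorXd>>;

struct ToyUCB { // hypothetical acquisition with the new signature
    template <typename AggregatorFunction>
    eval_t operator()(const Eigen::VectorXd& v, const AggregatorFunction& afun, bool gradient) const
    {
        assert(!gradient);             // no analytical gradient available yet
        double mu = afun(v);           // stand-in for the aggregated GP mean
        double sigma = 0.1 * v.norm(); // stand-in for the GP variance
        return {mu + 0.5 * std::sqrt(sigma), std::nullopt};
    }
};

int main()
{
    ToyUCB acqui;
    auto first_elem = [](const Eigen::VectorXd& x) { return x(0); };
    Eigen::VectorXd v = Eigen::VectorXd::Constant(2, 0.5);
    eval_t res = acqui(v, first_elem, false); // gradient-free evaluation
    std::cout << "acquisition value: " << res.first << "\n";
}
```

Calling such an acquisition with `gradient == true` trips the assert, which makes explicit that gradient-based acquisition optimizers cannot be used with these acquisitions yet.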
@@ -81,8 +81,9 @@ namespace limbo {
             size_t dim_out() const { return _model.dim_out(); }
             template <typename AggregatorFunction>
-            opt::eval_t operator()(const Eigen::VectorXd& v, const AggregatorFunction& afun) const
+            opt::eval_t operator()(const Eigen::VectorXd& v, const AggregatorFunction& afun, bool gradient) const
             {
+                assert(!gradient);
                 Eigen::VectorXd mu;
                 double sigma_sq;
                 std::tie(mu, sigma_sq) = _model.query(v);
@@ -92,8 +92,9 @@ namespace limbo {
             size_t dim_out() const { return _model.dim_out(); }
             template <typename AggregatorFunction>
-            opt::eval_t operator()(const Eigen::VectorXd& v, const AggregatorFunction& afun) const
+            opt::eval_t operator()(const Eigen::VectorXd& v, const AggregatorFunction& afun, bool gradient) const
             {
+                assert(!gradient);
                 Eigen::VectorXd mu;
                 double sigma;
                 std::tie(mu, sigma) = _model.query(v);
@@ -79,8 +79,9 @@ namespace limbo {
             size_t dim_out() const { return _model.dim_out(); }
             template <typename AggregatorFunction>
-            opt::eval_t operator()(const Eigen::VectorXd& v, const AggregatorFunction& afun) const
+            opt::eval_t operator()(const Eigen::VectorXd& v, const AggregatorFunction& afun, bool gradient) const
             {
+                assert(!gradient);
                 Eigen::VectorXd mu;
                 double sigma;
                 std::tie(mu, sigma) = _model.query(v);
@@ -152,7 +152,7 @@ namespace limbo {
                acquisition_function_t acqui(_model, this->_current_iteration);
                auto acqui_optimization =
-                    [&](const Eigen::VectorXd& x, bool g) { return acqui(x,afun); };
+                    [&](const Eigen::VectorXd& x, bool g) { return acqui(x,afun,g); };
                Eigen::VectorXd starting_point = tools::random_vector(StateFunction::dim_in);
                Eigen::VectorXd new_sample = acqui_optimizer(acqui_optimization, starting_point, true);
                bool blacklisted = !this->eval_and_add(sfun, new_sample);
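In the optimization loop, the wrapper lambda handed to `acqui_optimizer` now forwards its `g` flag to the acquisition instead of dropping it, so an optimizer that requests gradients will receive them once an acquisition can supply them. The sketch below illustrates that pattern with a toy quadratic acquisition whose gradient is known in closed form; `quad_acqui` and `gradient_step` are hypothetical names, not limbo APIs, and `eval_t` is again assumed to be a (value, optional gradient) pair.

```cpp
// Illustrative only: how a gradient-aware optimizer could use the forwarded flag.
#include <optional>
#include <utility>
#include <Eigen/Core>

using eval_t = std::pair<double, std::optional<Eigen::VectorXd>>;

// Toy acquisition: f(x) = -||x||^2, with gradient -2x when requested.
eval_t quad_acqui(const Eigen::VectorXd& x, bool gradient)
{
    double value = -x.squaredNorm();
    if (!gradient)
        return {value, std::nullopt};
    return {value, Eigen::VectorXd(-2.0 * x)};
}

// One ascent step of a hypothetical gradient-based acquisition optimizer.
template <typename F>
Eigen::VectorXd gradient_step(const F& f, const Eigen::VectorXd& x, double rate = 0.1)
{
    auto res = f(x, true); // request the gradient, as the forwarded `g` flag now allows
    return x + rate * (*res.second);
}

int main()
{
    // The wrapper mirrors the lambda in the optimization loop: it forwards `g`.
    auto acqui_optimization = [&](const Eigen::VectorXd& x, bool g) { return quad_acqui(x, g); };
    Eigen::VectorXd start = Eigen::VectorXd::Ones(3);
    Eigen::VectorXd next = gradient_step(acqui_optimization, start);
    return next.size() == 3 ? 0 : 1;
}
```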
@@ -80,7 +80,7 @@ namespace limbo {
                    point[dim_in] = x;
                    if (dim_in == current.size() - 1) {
                        auto q = bo.model().query(point);
-                        double acqui = std::get<0>(typename BO::acquisition_function_t(bo.model(), bo.current_iteration())(point, afun));
+                        double acqui = std::get<0>(typename BO::acquisition_function_t(bo.model(), bo.current_iteration())(point, afun, false));
                        ofs << point.transpose() << " "
                            << std::get<0>(q).transpose() << " "
                            << std::get<1>(q) << " "
@@ -71,11 +71,11 @@ namespace limbo {
                if (!blacklisted && !bo.samples().empty()) {
                    std::tie(mu, sigma) = bo.model().query(bo.samples().back());
-                    acqui = std::get<0>(typename BO::acquisition_function_t(bo.model(), bo.current_iteration())(bo.samples().back(), afun));
+                    acqui = std::get<0>(typename BO::acquisition_function_t(bo.model(), bo.current_iteration())(bo.samples().back(), afun, false));
                }
                else if (!bo.bl_samples().empty()) {
                    std::tie(mu, sigma) = bo.model().query(bo.bl_samples().back());
-                    acqui = std::get<0>(typename BO::acquisition_function_t(bo.model(), bo.current_iteration())(bo.bl_samples().back(), afun));
+                    acqui = std::get<0>(typename BO::acquisition_function_t(bo.model(), bo.current_iteration())(bo.bl_samples().back(), afun, false));
                }
                else
                    return;
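The two statistics writers only log the scalar acquisition value, so they pin the new argument to `false` and extract the value from the returned pair with `std::get<0>`. A small sketch of that extraction, under the same assumed `eval_t` layout:

```cpp
// Illustrative only: extracting just the scalar value for logging.
#include <iostream>
#include <optional>
#include <utility>
#include <Eigen/Core>

using eval_t = std::pair<double, std::optional<Eigen::VectorXd>>;

int main()
{
    auto acqui = [](const Eigen::VectorXd& x, bool /*gradient*/) -> eval_t {
        return {x.prod(), std::nullopt}; // stand-in acquisition value
    };
    Eigen::VectorXd sample = Eigen::VectorXd::Constant(2, 0.5);
    double value = std::get<0>(acqui(sample, false)); // gradient not needed for stats
    std::cout << value << "\n"; // prints 0.25
}
```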
@@ -252,7 +252,7 @@ BOOST_AUTO_TEST_CASE(test_gp_no_samples_acqui_opt)
    // we do not have gradient in our current acquisition function
    auto acqui_optimization =
-        [&](const Eigen::VectorXd& x, bool g) { return acqui(x, FirstElem()); };
+        [&](const Eigen::VectorXd& x, bool g) { return acqui(x, FirstElem(), g); };
    Eigen::VectorXd starting_point = tools::random_vector(2);
    Eigen::VectorXd test = acqui_optimizer(acqui_optimization, starting_point, true);
    BOOST_CHECK(test(0) < 1e-5);