Commit 7b8a57a4 authored by Konstantinos Chatzilygeroudis's avatar Konstantinos Chatzilygeroudis
Browse files

Acqui Opt with Grads

parent 6f8a475e
...@@ -105,8 +105,9 @@ public: ...@@ -105,8 +105,9 @@ public:
size_t dim_out() const { return _model.dim_out(); } size_t dim_out() const { return _model.dim_out(); }
template <typename AggregatorFunction> template <typename AggregatorFunction>
limbo::opt::eval_t operator()(const Eigen::VectorXd& v, const AggregatorFunction& afun) const limbo::opt::eval_t operator()(const Eigen::VectorXd& v, const AggregatorFunction& afun, bool gradient) const
{ {
assert(!gradient);
// double mu, sigma; // double mu, sigma;
// std::tie(mu, sigma) = _model.query(v); // std::tie(mu, sigma) = _model.query(v);
// return (mu + Params::ucb::alpha() * sqrt(sigma)); // return (mu + Params::ucb::alpha() * sqrt(sigma));
......
...@@ -81,8 +81,9 @@ namespace limbo { ...@@ -81,8 +81,9 @@ namespace limbo {
size_t dim_out() const { return _model.dim_out(); } size_t dim_out() const { return _model.dim_out(); }
template <typename AggregatorFunction> template <typename AggregatorFunction>
opt::eval_t operator()(const Eigen::VectorXd& v, const AggregatorFunction& afun) const opt::eval_t operator()(const Eigen::VectorXd& v, const AggregatorFunction& afun, bool gradient) const
{ {
assert(!gradient);
Eigen::VectorXd mu; Eigen::VectorXd mu;
double sigma_sq; double sigma_sq;
std::tie(mu, sigma_sq) = _model.query(v); std::tie(mu, sigma_sq) = _model.query(v);
......
...@@ -92,8 +92,9 @@ namespace limbo { ...@@ -92,8 +92,9 @@ namespace limbo {
size_t dim_out() const { return _model.dim_out(); } size_t dim_out() const { return _model.dim_out(); }
template <typename AggregatorFunction> template <typename AggregatorFunction>
opt::eval_t operator()(const Eigen::VectorXd& v, const AggregatorFunction& afun) const opt::eval_t operator()(const Eigen::VectorXd& v, const AggregatorFunction& afun, bool gradient) const
{ {
assert(!gradient);
Eigen::VectorXd mu; Eigen::VectorXd mu;
double sigma; double sigma;
std::tie(mu, sigma) = _model.query(v); std::tie(mu, sigma) = _model.query(v);
......
...@@ -79,8 +79,9 @@ namespace limbo { ...@@ -79,8 +79,9 @@ namespace limbo {
size_t dim_out() const { return _model.dim_out(); } size_t dim_out() const { return _model.dim_out(); }
template <typename AggregatorFunction> template <typename AggregatorFunction>
opt::eval_t operator()(const Eigen::VectorXd& v, const AggregatorFunction& afun) const opt::eval_t operator()(const Eigen::VectorXd& v, const AggregatorFunction& afun, bool gradient) const
{ {
assert(!gradient);
Eigen::VectorXd mu; Eigen::VectorXd mu;
double sigma; double sigma;
std::tie(mu, sigma) = _model.query(v); std::tie(mu, sigma) = _model.query(v);
......
...@@ -152,7 +152,7 @@ namespace limbo { ...@@ -152,7 +152,7 @@ namespace limbo {
acquisition_function_t acqui(_model, this->_current_iteration); acquisition_function_t acqui(_model, this->_current_iteration);
auto acqui_optimization = auto acqui_optimization =
[&](const Eigen::VectorXd& x, bool g) { return acqui(x,afun); }; [&](const Eigen::VectorXd& x, bool g) { return acqui(x,afun,g); };
Eigen::VectorXd starting_point = tools::random_vector(StateFunction::dim_in); Eigen::VectorXd starting_point = tools::random_vector(StateFunction::dim_in);
Eigen::VectorXd new_sample = acqui_optimizer(acqui_optimization, starting_point, true); Eigen::VectorXd new_sample = acqui_optimizer(acqui_optimization, starting_point, true);
bool blacklisted = !this->eval_and_add(sfun, new_sample); bool blacklisted = !this->eval_and_add(sfun, new_sample);
......
...@@ -80,7 +80,7 @@ namespace limbo { ...@@ -80,7 +80,7 @@ namespace limbo {
point[dim_in] = x; point[dim_in] = x;
if (dim_in == current.size() - 1) { if (dim_in == current.size() - 1) {
auto q = bo.model().query(point); auto q = bo.model().query(point);
double acqui = std::get<0>(typename BO::acquisition_function_t(bo.model(), bo.current_iteration())(point, afun)); double acqui = std::get<0>(typename BO::acquisition_function_t(bo.model(), bo.current_iteration())(point, afun, false));
ofs << point.transpose() << " " ofs << point.transpose() << " "
<< std::get<0>(q).transpose() << " " << std::get<0>(q).transpose() << " "
<< std::get<1>(q) << " " << std::get<1>(q) << " "
......
...@@ -71,11 +71,11 @@ namespace limbo { ...@@ -71,11 +71,11 @@ namespace limbo {
if (!blacklisted && !bo.samples().empty()) { if (!blacklisted && !bo.samples().empty()) {
std::tie(mu, sigma) = bo.model().query(bo.samples().back()); std::tie(mu, sigma) = bo.model().query(bo.samples().back());
acqui = std::get<0>(typename BO::acquisition_function_t(bo.model(), bo.current_iteration())(bo.samples().back(), afun)); acqui = std::get<0>(typename BO::acquisition_function_t(bo.model(), bo.current_iteration())(bo.samples().back(), afun, false));
} }
else if (!bo.bl_samples().empty()) { else if (!bo.bl_samples().empty()) {
std::tie(mu, sigma) = bo.model().query(bo.bl_samples().back()); std::tie(mu, sigma) = bo.model().query(bo.bl_samples().back());
acqui = std::get<0>(typename BO::acquisition_function_t(bo.model(), bo.current_iteration())(bo.bl_samples().back(), afun)); acqui = std::get<0>(typename BO::acquisition_function_t(bo.model(), bo.current_iteration())(bo.bl_samples().back(), afun, false));
} }
else else
return; return;
......
...@@ -252,7 +252,7 @@ BOOST_AUTO_TEST_CASE(test_gp_no_samples_acqui_opt) ...@@ -252,7 +252,7 @@ BOOST_AUTO_TEST_CASE(test_gp_no_samples_acqui_opt)
// we do not have gradient in our current acquisition function // we do not have gradient in our current acquisition function
auto acqui_optimization = auto acqui_optimization =
[&](const Eigen::VectorXd& x, bool g) { return acqui(x, FirstElem()); }; [&](const Eigen::VectorXd& x, bool g) { return acqui(x, FirstElem(), g); };
Eigen::VectorXd starting_point = tools::random_vector(2); Eigen::VectorXd starting_point = tools::random_vector(2);
Eigen::VectorXd test = acqui_optimizer(acqui_optimization, starting_point, true); Eigen::VectorXd test = acqui_optimizer(acqui_optimization, starting_point, true);
BOOST_CHECK(test(0) < 1e-5); BOOST_CHECK(test(0) < 1e-5);
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment