Commit f14be3a0 authored by Konstantinos Chatzilygeroudis's avatar Konstantinos Chatzilygeroudis
Browse files

Minor fixes for save/loading and mean_observation in GP/MultiGP

parent 1fb3a54b
...@@ -225,7 +225,7 @@ namespace limbo { ...@@ -225,7 +225,7 @@ namespace limbo {
/// return the mean observation (only call this if the output of the GP is of dimension 1) /// return the mean observation (only call this if the output of the GP is of dimension 1)
Eigen::VectorXd mean_observation() const Eigen::VectorXd mean_observation() const
{ {
// TODO: Check if _dim_out is correct?! assert(_dim_out > 0);
return _samples.size() > 0 ? _mean_observation return _samples.size() > 0 ? _mean_observation
: Eigen::VectorXd::Zero(_dim_out); : Eigen::VectorXd::Zero(_dim_out);
} }
......
...@@ -106,6 +106,12 @@ namespace limbo { ...@@ -106,6 +106,12 @@ namespace limbo {
// compute the new observations for the GPs // compute the new observations for the GPs
std::vector<std::vector<Eigen::VectorXd>> obs(_dim_out); std::vector<std::vector<Eigen::VectorXd>> obs(_dim_out);
// compute mean observation
_mean_observation = Eigen::VectorXd::Zero(_dim_out);
for (size_t j = 0; j < _observations.size(); j++)
_mean_observation.array() += _observations[j].array();
_mean_observation.array() /= static_cast<double>(_observations.size());
for (size_t j = 0; j < observations.size(); j++) { for (size_t j = 0; j < observations.size(); j++) {
Eigen::VectorXd mean_vector = _mean_function(samples[j], *this); Eigen::VectorXd mean_vector = _mean_function(samples[j], *this);
assert(mean_vector.size() == _dim_out); assert(mean_vector.size() == _dim_out);
...@@ -154,6 +160,12 @@ namespace limbo { ...@@ -154,6 +160,12 @@ namespace limbo {
_observations.push_back(observation); _observations.push_back(observation);
// recompute mean observation
_mean_observation = Eigen::VectorXd::Zero(_dim_out);
for (size_t j = 0; j < _observations.size(); j++)
_mean_observation.array() += _observations[j].array();
_mean_observation.array() /= static_cast<double>(_observations.size());
Eigen::VectorXd mean_vector = _mean_function(sample, *this); Eigen::VectorXd mean_vector = _mean_function(sample, *this);
assert(mean_vector.size() == _dim_out); assert(mean_vector.size() == _dim_out);
...@@ -256,8 +268,15 @@ namespace limbo { ...@@ -256,8 +268,15 @@ namespace limbo {
/// return the list of samples that have been tested so far /// return the list of samples that have been tested so far
const std::vector<Eigen::VectorXd>& samples() const const std::vector<Eigen::VectorXd>& samples() const
{ {
assert(_gp_models.size()); return _observations.size();
return _gp_models[0].samples(); }
/// return the mean observation
Eigen::VectorXd mean_observation() const
{
assert(_dim_out > 0);
return _observations.size() > 0 ? _mean_observation
: Eigen::VectorXd::Zero(_dim_out);
} }
/// return the list of GPs /// return the list of GPs
...@@ -324,6 +343,12 @@ namespace limbo { ...@@ -324,6 +343,12 @@ namespace limbo {
_dim_in = static_cast<int>(dims(0)); _dim_in = static_cast<int>(dims(0));
_dim_out = static_cast<int>(dims(1)); _dim_out = static_cast<int>(dims(1));
// recompute mean observation
_mean_observation = Eigen::VectorXd::Zero(_dim_out);
for (size_t j = 0; j < _observations.size(); j++)
_mean_observation.array() += _observations[j].array();
_mean_observation.array() /= static_cast<double>(_observations.size());
_mean_function = MeanFunction(_dim_out); _mean_function = MeanFunction(_dim_out);
if (_mean_function.h_params_size() > 0) { if (_mean_function.h_params_size() > 0) {
...@@ -348,6 +373,7 @@ namespace limbo { ...@@ -348,6 +373,7 @@ namespace limbo {
HyperParamsOptimizer _hp_optimize; HyperParamsOptimizer _hp_optimize;
MeanFunction _mean_function; MeanFunction _mean_function;
std::vector<Eigen::VectorXd> _observations; std::vector<Eigen::VectorXd> _observations;
Eigen::VectorXd _mean_observation;
}; };
} // namespace model } // namespace model
} // namespace limbo } // namespace limbo
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment