Add save/load capabilities to MultiGP

parent 8d30278f
......@@ -272,6 +272,76 @@ namespace limbo {
return _gp_models;
}
/// save the parameters and the data for the GP to the archive (text or binary)
template <typename A>
void save(const std::string& directory)
{
A archive(directory);
save(archive);
}
/// save the parameters and the data for the GP to the archive (text or binary)
template <typename A>
void save(const A& archive)
{
Eigen::VectorXd dims(2);
dims << _dim_in, _dim_out;
archive.save(dims, "dims");
archive.save(_observations, "observations");
if (_mean_function.h_params_size() > 0) {
archive.save(_mean_function.h_params(), "mean_params");
}
for (int i = 0; i < _dim_out; i++) {
_gp_models[i].save<A>(archive.directory() + "/gp_" + std::to_string(i));
}
}
/// load the parameters and the data for the GP from the archive (text or binary)
/// if recompute is true, we do not read the kernel matrix
/// but we recompute it given the data and the hyperparameters
template <typename A>
void load(const std::string& directory, bool recompute = true)
{
A archive(directory);
load(archive, recompute);
}
/// load the parameters and the data for the GP from the archive (text or binary)
/// if recompute is true, we do not read the kernel matrix
/// but we recompute it given the data and the hyperparameters
template <typename A>
void load(const A& archive, bool recompute = true)
{
_observations.clear();
archive.load(_observations, "observations");
Eigen::VectorXd dims;
archive.load(dims, "dims");
_dim_in = static_cast<int>(dims(0));
_dim_out = static_cast<int>(dims(1));
_mean_function = MeanFunction(_dim_out);
if (_mean_function.h_params_size() > 0) {
Eigen::VectorXd h_params;
archive.load(h_params, "mean_params");
assert(h_params.size() == (int)_mean_function.h_params_size());
_mean_function.set_h_params(h_params);
}
for (int i = 0; i < _dim_out; i++) {
// do not recompute the individual GPs on their own
_gp_models[i].load<A>(archive.directory() + "/gp_" + std::to_string(i), false);
}
if (recompute)
this->recompute(true, true);
}
protected:
std::vector<GP_t> _gp_models;
int _dim_in, _dim_out;
......
......@@ -124,6 +124,11 @@ namespace limbo {
return _dir_name + "/" + object_name + ".bin";
}
std::string directory() const
{
return _dir_name;
}
protected:
std::string _dir_name;
......
......@@ -113,6 +113,11 @@ namespace limbo {
return _dir_name + "/" + object_name + ".dat";
}
std::string directory() const
{
return _dir_name;
}
protected:
std::string _dir_name;
Eigen::IOFormat _fmt;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment