Commit b2a401f1 authored by Jean-Baptiste Mouret's avatar Jean-Baptiste Mouret
Browse files

first prototype of the serialization

parent 14f51918
......@@ -125,6 +125,10 @@ namespace limbo {
// Get signal noise
double noise() const { return _noise; }
template<typename A>
void save(A& a) {
a.save(h_params());
}
protected:
double _noise;
double _noise_p;
......
......@@ -241,11 +241,10 @@ namespace limbo {
if (update_obs_mean)
this->_compute_obs_mean();
if (update_full_kernel)
this->_compute_full_kernel();
else
this->_compute_alpha();
this->_compute_alpha();
}
void compute_inv_kernel()
......@@ -415,6 +414,34 @@ namespace limbo {
bool inv_kernel_computed() { return _inv_kernel_updated; }
/// save the parameters and the data for the GP to the archive (text or binary)
template<typename A>
void save(A& archive)
{
archive.save(_kernel_function.h_params(), "kernel_params");
// archive.save(_mean_function.h_params(), "mean_params");
archive.save(_samples, "samples");
archive.save(_observations, "observations");
}
/// load the parameters and the data for the GP from the archive (text or binary)
template <typename A>
void load(A& archive)
{
Eigen::VectorXd h_params;
archive.load(h_params, "kernel_params");
assert(h_params.size() == _kernel_function.h_params().size());
_kernel_function.set_h_params(h_params);
// should we save parameters of the mean function as well?
std::vector<Eigen::VectorXd> samples;
archive.load(samples, "samples");
std::vector<Eigen::VectorXd> observations;
archive.load(observations, "observations");
compute(samples, observations);
}
protected:
int _dim_in;
int _dim_out;
......@@ -441,9 +468,14 @@ namespace limbo {
void _compute_obs_mean()
{
assert(!_samples.empty());
_mean_vector.resize(_samples.size(), _dim_out);
for (int i = 0; i < _mean_vector.rows(); i++)
for (int i = 0; i < _mean_vector.rows(); i++){
assert(_samples[i].cols() == 1);
assert(_samples[i].rows() != 0);
assert(_samples[i].rows() == _dim_in);
_mean_vector.row(i) = _mean_function(_samples[i], *this);
}
_obs_mean = _observations - _mean_vector;
}
......
#include <cassert>
#include <iostream>
#include <sstream>
#include <string>
#include <boost/filesystem.hpp>
#include <Eigen/Core>
namespace limbo {
namespace serialize {
class TextArchive {
public:
TextArchive(const std::string& dir_name) : _dir_name(dir_name),
_fmt(Eigen::FullPrecision, Eigen::DontAlignCols, " ", "\n", "", "") {}
/// write an Eigen::Matrix*
void save(const Eigen::MatrixXd& v, const std::string& object_name)
{
_create_directory();
std::ofstream ofs(fname(object_name).c_str());
ofs << v.format(_fmt) << std::endl;
}
/// write a vector of Eigen::Vector*
template <typename T>
void save(const std::vector<T>& v, const std::string& object_name)
{
_create_directory();
std::ofstream ofs(fname(object_name).c_str());
for (auto& x : v) {
ofs << x.transpose().format(_fmt) << std::endl;
}
}
/// load an Eigen matrix (or vector)
template <typename M>
void load(M& m, const std::string& object_name)
{
auto values = _load(object_name);
m.resize(values.size(), values[0].size());
for (size_t i = 0; i < values.size(); ++i)
for (size_t j = 0; j < values[i].size(); ++j)
m(i, j) = values[i][j];
}
/// load a vector of Eigen::Vector*
template <typename V>
void load(std::vector<V>& m_list, const std::string& object_name)
{
m_list.clear();
auto values = _load(object_name);
assert(!values.empty());
for (size_t i = 0; i < values.size(); ++i) {
V v(values[i].size());
for (size_t j = 0; j < values[i].size(); ++j)
v(j) = values[i][j];
m_list.push_back(v);
}
assert(!m_list.empty());
}
std::string fname(const std::string& object_name) const
{
return _dir_name + "/" + object_name + ".dat";
}
protected:
std::string _dir_name;
Eigen::IOFormat _fmt;
//
void _create_directory()
{
boost::filesystem::path my_path(_dir_name);
boost::filesystem::create_directory(my_path);
}
std::vector<std::vector<double>> _load(const std::string& object_name)
{
std::ifstream ifs(fname(object_name).c_str());
assert(ifs.good() && "file not found");
std::string line;
std::vector<std::vector<double>> v;
while (std::getline(ifs, line)) {
std::stringstream line_stream(line);
std::string cell;
std::vector<double> line;
while (std::getline(line_stream, cell, ' '))
line.push_back(std::stod(cell));
v.push_back(line);
}
assert(!v.empty() && "empty file");
return v;
}
};
} // namespace serialize
} // namespace limbo
\ No newline at end of file
//| Copyright Inria May 2015
//| This project has received funding from the European Research Council (ERC) under
//| the European Union's Horizon 2020 research and innovation programme (grant
//| agreement No 637972) - see http://www.resibots.eu
//|
//| Contributor(s):
//| - Jean-Baptiste Mouret (jean-baptiste.mouret@inria.fr)
//| - Antoine Cully (antoinecully@gmail.com)
//| - Kontantinos Chatzilygeroudis (konstantinos.chatzilygeroudis@inria.fr)
//| - Federico Allocati (fede.allocati@gmail.com)
//| - Vaios Papaspyros (b.papaspyros@gmail.com)
//| - Roberto Rama (bertoski@gmail.com)
//|
//| This software is a computer library whose purpose is to optimize continuous,
//| black-box functions. It mainly implements Gaussian processes and Bayesian
//| optimization.
//| Main repository: http://github.com/resibots/limbo
//| Documentation: http://www.resibots.eu/limbo
//|
//| This software is governed by the CeCILL-C license under French law and
//| abiding by the rules of distribution of free software. You can use,
//| modify and/ or redistribute the software under the terms of the CeCILL-C
//| license as circulated by CEA, CNRS and INRIA at the following URL
//| "http://www.cecill.info".
//|
//| As a counterpart to the access to the source code and rights to copy,
//| modify and redistribute granted by the license, users are provided only
//| with a limited warranty and the software's author, the holder of the
//| economic rights, and the successive licensors have only limited
//| liability.
//|
//| In this respect, the user's attention is drawn to the risks associated
//| with loading, using, modifying and/or developing or reproducing the
//| software by the user in light of its specific status of free software,
//| that may mean that it is complicated to manipulate, and that also
//| therefore means that it is reserved for developers and experienced
//| professionals having in-depth computer knowledge. Users are therefore
//| encouraged to load and test the software's suitability as regards their
//| requirements in conditions enabling the security of their systems and/or
//| data to be ensured and, more generally, to use and operate it in the
//| same conditions as regards security.
//|
//| The fact that you are presently reading this means that you have had
//| knowledge of the CeCILL-C license and that you accept its terms.
//|
#define BOOST_TEST_DYN_LINK
#define BOOST_TEST_MODULE test_serialize
#include <cstring>
#include <fstream>
#include <boost/test/unit_test.hpp>
#include <limbo/model/gp.hpp>
#include <limbo/serialize/text_archive.hpp>
struct Params {
struct kernel_exp {
BO_PARAM(double, sigma_sq, 1.0);
BO_PARAM(double, l, 0.2);
};
struct kernel : public limbo::defaults::kernel {
};
struct kernel_squared_exp_ard : public limbo::defaults::kernel_squared_exp_ard {
};
struct opt_rprop : public limbo::defaults::opt_rprop {
};
struct opt_parallelrepeater : public limbo::defaults::opt_parallelrepeater {
};
};
BOOST_AUTO_TEST_CASE(test_text_archive)
{
using namespace limbo;
// our data (3-D inputs, 1-D outputs)
std::vector<Eigen::VectorXd> samples;
std::vector<Eigen::VectorXd> observations;
size_t n = 8;
for (size_t i = 0; i < n; i++) {
Eigen::VectorXd s = tools::random_vector(3).array() * 4.0 - 2.0;
samples.push_back(s);
observations.push_back(tools::make_vector(std::cos(s(0)*s(1)*s(2))));
}
// 3-D inputs, 1-D outputs
model::GPOpt<Params> gp(3, 1);
gp.compute(samples, observations);
gp.optimize_hyperparams();
// attempt to save
serialize::TextArchive a1("/tmp/test_model.dat");
gp.save(a1);
// attempt to read
model::GPOpt<Params> gp2(3, 1);
serialize::TextArchive a2("/tmp/test_model.dat");
gp2.load(a2);
}
BOOST_AUTO_TEST_CASE(test_bin_archive)
{
}
......@@ -86,6 +86,12 @@ def build(bld):
target='test_macros',
uselib='BOOST EIGEN TBB',
use='limbo')
bld.program(features='cxx test',
source='test_serialize.cpp',
includes='. .. ../../',
target='test_serialize',
uselib='BOOST EIGEN TBB',
use='limbo')
if bld.env.DEFINES_NLOPT or bld.env.DEFINES_LIBCMAES:
bld.program(features='cxx test',
source='test_boptimizer.cpp',
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment