Commit 9f5cef2c authored by Antoine Cully's avatar Antoine Cully
Browse files

Merge branch 'serialize_model' of https://github.com/resibots/limbo into serialize_model

parents c3d1de0a fc16102a
......@@ -73,7 +73,7 @@ std::tuple<double, Eigen::MatrixXd, Eigen::MatrixXd> check_grad(const Mean& mean
analytic_result = me.grad(v, v);
finite_diff_result = Eigen::MatrixXd::Zero(v.size(), x.size());
finite_diff_result = Eigen::MatrixXd::Zero(me(v, v).size(), x.size());
for (int j = 0; j < x.size(); j++) {
Eigen::VectorXd test1 = x, test2 = x;
test1[j] -= e;
......@@ -91,9 +91,9 @@ std::tuple<double, Eigen::MatrixXd, Eigen::MatrixXd> check_grad(const Mean& mean
}
template <typename Mean>
void check_mean(size_t N, size_t K)
void check_mean(size_t dim_in, size_t dim_out, size_t K)
{
Mean mean(N);
Mean mean(dim_out);
for (size_t i = 0; i < K; i++) {
Eigen::VectorXd hp = tools::random_vector(mean.h_params_size()).array() * 10. - 5.;
......@@ -101,7 +101,7 @@ void check_mean(size_t N, size_t K)
double error;
Eigen::MatrixXd analytic, finite_diff;
Eigen::VectorXd v = tools::random_vector(N).array() * 10. - 5.;
Eigen::VectorXd v = tools::random_vector(dim_in).array() * 10. - 5.;
std::tie(error, analytic, finite_diff) = check_grad(mean, hp, v);
// std::cout << error << ": " << analytic << " vs " << finite_diff << std::endl;
......@@ -111,8 +111,10 @@ void check_mean(size_t N, size_t K)
BOOST_AUTO_TEST_CASE(test_mean_constant)
{
for (int i = 1; i <= 10; i++) {
check_mean<mean::Constant<Params>>(i, 100);
for (int k = 1; k <= 10; k++) {
for (int i = 1; i <= 10; i++) {
check_mean<mean::Constant<Params>>(k, i, 100);
}
}
}
......@@ -120,8 +122,10 @@ BOOST_AUTO_TEST_CASE(test_mean_function_ard)
{
// This test checks the gradients computation of FunctionARD when the base mean function
// also has tunable parameters
for (int i = 1; i <= 10; i++) {
check_mean<mean::FunctionARD<Params, mean::Constant<Params>>>(i, 100);
for (int k = 1; k <= 10; k++) {
for (int i = 1; i <= 10; i++) {
check_mean<mean::FunctionARD<Params, mean::Constant<Params>>>(k, i, 100);
}
}
}
......@@ -129,7 +133,9 @@ BOOST_AUTO_TEST_CASE(test_mean_function_ard_dummy)
{
// This test checks the gradients computation of FunctionARD when the base mean function
// has no tunable parameters
for (int i = 1; i <= 10; i++) {
check_mean<mean::FunctionARD<Params, mean::NullFunction<Params>>>(i, 100);
for (int k = 1; k <= 10; k++) {
for (int i = 1; i <= 10; i++) {
check_mean<mean::FunctionARD<Params, mean::NullFunction<Params>>>(k, i, 100);
}
}
}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment