Commit 77e80557 authored by Konstantinos Chatzilygeroudis's avatar Konstantinos Chatzilygeroudis
Browse files

Fixing plotting of regression benchmarks and more reasonable configuration [ci skip]

parent ced98724
......@@ -131,24 +131,21 @@ def custom_ax(ax):
ax.set_axisbelow(True)
ax.grid(axis='x', color="0.9", linestyle='-')
def plot_data(name, data, points, labely):
fig = figure()
ax = gca()
def plot_ax(ax, data, points, labely, disp_legend=True, disp_xaxis=False):
labels = []
kk = 0
# for each variant
for var in points.keys():
labels.append(var)
var_p = points[var]
var_mses = data[var]
var_data = data[var]
pp = {}
for i in range(len(var_p)):
if var_p[i] not in pp:
pp[var_p[i]] = []
pp[var_p[i]].append(var_mses[i])
pp[var_p[i]].append(var_data[i])
pp = OrderedDict(sorted(pp.items()))
......@@ -164,16 +161,27 @@ def plot_data(name, data, points, labely):
y_axis_25.append(np.percentile(dd[i], 25))
c_kk = colors[kk%len(colors)]
ax.plot(x_axis, y_axis, '-', color=c_kk, linewidth=3)
ax.plot(x_axis, y_axis, '-o', color=c_kk, linewidth=3, markersize=5)
ax.fill_between(x_axis, y_axis_75, y_axis_25, color=c_kk, alpha=0.15, linewidth=2)
kk = kk + 1
ax.legend(labels)
ax.set_xlabel('Number of points')
if disp_legend:
ax.legend(labels)
if disp_xaxis:
ax.set_xlabel('Number of points')
ax.set_ylabel(labely)
custom_ax(ax)
def plot_data(bench, func, dim, mses, query_times, learning_times, points):
name = func+'_'+str(dim)
fig, ax = plt.subplots(3, sharex=True)
plot_ax(ax[0], mses, points, 'Mean Squared Error')
plot_ax(ax[1], query_times, points, 'Querying time in ms', False)
plot_ax(ax[2], learning_times, points, 'Learning time in seconds', False, True)
fig.tight_layout()
fig.savefig(name+'.png')
fig.savefig('regression_benchmark_results/'+bench+'/'+name+'.png')
close()
def plot(points,times_learn,times_query,mses):
......@@ -186,14 +194,7 @@ def plot(points,times_learn,times_query,mses):
print('plotting for benchmark: ' + bench + ', the function: ' + func + ' for dimension: ' + str(dim))
name = bench+'_'+func+'_'+str(dim)
# plotting MSE
plot_data(name+'_mse', mses[bench][func][dim], points[bench][func][dim], 'Mean Squared Error')
# plotting learning times
plot_data(name+'_learn_time', times_learn[bench][func][dim], points[bench][func][dim], 'Learning time in seconds')
# plotting querying times
plot_data(name+'_query_time', times_query[bench][func][dim], points[bench][func][dim], 'Querying time in ms')
plot_data(bench, func, dim, mses[bench][func][dim], times_query[bench][func][dim], times_learn[bench][func][dim], points[bench][func][dim])
def plot_all():
if not plot_ok:
......
[
{ "name" : "regression_benchmark",
"functions" : ["Rastrigin", "Ackley", "GramacyLee", "Step", "SixHumpCamel", "RobotArm", "OTLCircuit", "PistonSimulation", "PlanarInverseDynamicsI", "PlanarInverseDynamicsII"],
"dimensions" : [[1,2,4,8], [1,2,4,8], [1], [1], [2], [8], [6], [7], [6], [6]],
"functions" : ["Rastrigin", "GramacyLee", "Step", "RobotArm", "OTLCircuit", "PistonSimulation", "PlanarInverseDynamicsI", "PlanarInverseDynamicsII"],
"dimensions" : [[1,2,4,8], [1], [1], [8], [6], [7], [6], [6]],
"points" : [50, 100, 200, 400, 600],
"randomness": ["uniform"],
"noise" : "true",
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment