mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 17:38:51 -06:00
Plot predicted objective value
This commit is contained in:
@@ -118,52 +118,72 @@ def charts():
|
|||||||
sense = results.loc[0, "Sense"]
|
sense = results.loc[0, "Sense"]
|
||||||
if sense == "min":
|
if sense == "min":
|
||||||
primal_column = "Relative Upper Bound"
|
primal_column = "Relative Upper Bound"
|
||||||
|
obj_column = "Upper Bound"
|
||||||
|
predicted_obj_column = "Predicted UB"
|
||||||
else:
|
else:
|
||||||
primal_column = "Relative Lower Bound"
|
primal_column = "Relative Lower Bound"
|
||||||
|
obj_column = "Lower Bound"
|
||||||
|
predicted_obj_column = "Predicted LB"
|
||||||
|
|
||||||
palette={
|
palette={
|
||||||
"baseline": "#9b59b6",
|
"baseline": "#9b59b6",
|
||||||
"ml-exact": "#3498db",
|
"ml-exact": "#3498db",
|
||||||
"ml-heuristic": "#95a5a6"
|
"ml-heuristic": "#95a5a6"
|
||||||
}
|
}
|
||||||
fig, axes = plt.subplots(nrows=1,
|
fig, (ax1, ax2, ax3, ax4) = plt.subplots(nrows=1,
|
||||||
ncols=3,
|
ncols=4,
|
||||||
figsize=(10,4),
|
figsize=(12,4),
|
||||||
gridspec_kw={'width_ratios': [3, 3, 2]},
|
gridspec_kw={'width_ratios': [2, 1, 1, 2]},
|
||||||
)
|
)
|
||||||
sns.stripplot(x="Solver",
|
sns.stripplot(x="Solver",
|
||||||
y="Wallclock Time",
|
y="Wallclock Time",
|
||||||
data=results,
|
data=results,
|
||||||
ax=axes[0],
|
ax=ax1,
|
||||||
jitter=0.25,
|
jitter=0.25,
|
||||||
palette=palette,
|
palette=palette,
|
||||||
|
size=4.0,
|
||||||
);
|
);
|
||||||
sns.barplot(x="Solver",
|
sns.barplot(x="Solver",
|
||||||
y="Wallclock Time",
|
y="Wallclock Time",
|
||||||
data=results,
|
data=results,
|
||||||
ax=axes[0],
|
ax=ax1,
|
||||||
errwidth=0.,
|
errwidth=0.,
|
||||||
alpha=0.3,
|
alpha=0.4,
|
||||||
palette=palette,
|
palette=palette,
|
||||||
estimator=median,
|
estimator=median,
|
||||||
);
|
);
|
||||||
axes[0].set(ylabel='Wallclock Time (s)')
|
ax1.set(ylabel='Wallclock Time (s)')
|
||||||
axes[1].set_ylim(-0.5, 5.5)
|
ax2.set_ylim(-0.5, 5.5)
|
||||||
sns.stripplot(x="Solver",
|
sns.stripplot(x="Solver",
|
||||||
y="Gap (%)",
|
y="Gap (%)",
|
||||||
jitter=0.25,
|
jitter=0.25,
|
||||||
data=results[results["Solver"] != "ml-heuristic"],
|
data=results[results["Solver"] != "ml-heuristic"],
|
||||||
ax=axes[1],
|
ax=ax2,
|
||||||
palette=palette,
|
palette=palette,
|
||||||
|
size=4.0,
|
||||||
);
|
);
|
||||||
axes[2].set_ylim(0.95,1.01)
|
ax3.set_ylim(0.95,1.05)
|
||||||
sns.stripplot(x="Solver",
|
sns.stripplot(x="Solver",
|
||||||
y=primal_column,
|
y=primal_column,
|
||||||
jitter=0.25,
|
jitter=0.25,
|
||||||
data=results[results["Solver"] == "ml-heuristic"],
|
data=results[results["Solver"] == "ml-heuristic"],
|
||||||
ax=axes[2],
|
ax=ax3,
|
||||||
palette=palette,
|
palette=palette,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
sns.scatterplot(x=obj_column,
|
||||||
|
y=predicted_obj_column,
|
||||||
|
hue="Solver",
|
||||||
|
data=results[results["Solver"] == "ml-exact"],
|
||||||
|
ax=ax4,
|
||||||
|
palette=palette,
|
||||||
|
);
|
||||||
|
xlim, ylim = ax4.get_xlim(), ax4.get_ylim()
|
||||||
|
ax4.plot([-1e10, 1e10], [-1e10, 1e10], ls='-', color="#cccccc");
|
||||||
|
ax4.set_xlim(xlim)
|
||||||
|
ax4.set_ylim(ylim)
|
||||||
|
ax4.get_legend().remove()
|
||||||
|
|
||||||
fig.tight_layout()
|
fig.tight_layout()
|
||||||
plt.savefig("%s/performance.png" % basepath,
|
plt.savefig("%s/performance.png" % basepath,
|
||||||
bbox_inches='tight',
|
bbox_inches='tight',
|
||||||
|
|||||||
@@ -64,6 +64,8 @@ class BenchmarkRunner:
|
|||||||
"Nodes",
|
"Nodes",
|
||||||
"Mode",
|
"Mode",
|
||||||
"Sense",
|
"Sense",
|
||||||
|
"Predicted LB",
|
||||||
|
"Predicted UB",
|
||||||
])
|
])
|
||||||
lb = result["Lower bound"]
|
lb = result["Lower bound"]
|
||||||
ub = result["Upper bound"]
|
ub = result["Upper bound"]
|
||||||
@@ -78,6 +80,8 @@ class BenchmarkRunner:
|
|||||||
"Nodes": result["Nodes"],
|
"Nodes": result["Nodes"],
|
||||||
"Mode": solver.mode,
|
"Mode": solver.mode,
|
||||||
"Sense": result["Sense"],
|
"Sense": result["Sense"],
|
||||||
|
"Predicted LB": result["Predicted LB"],
|
||||||
|
"Predicted UB": result["Predicted UB"],
|
||||||
}, ignore_index=True)
|
}, ignore_index=True)
|
||||||
groups = self.results.groupby("Instance")
|
groups = self.results.groupby("Instance")
|
||||||
best_lower_bound = groups["Lower Bound"].transform("max")
|
best_lower_bound = groups["Lower Bound"].transform("max")
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ class Component(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def after_solve(self, solver, instance, model):
|
def after_solve(self, solver, instance, model, results):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ class LazyConstraintsComponent(Component):
|
|||||||
cut = instance.build_lazy_constraint(model, v)
|
cut = instance.build_lazy_constraint(model, v)
|
||||||
solver.internal_solver.add_constraint(cut)
|
solver.internal_solver.add_constraint(cut)
|
||||||
|
|
||||||
def after_solve(self, solver, instance, model):
|
def after_solve(self, solver, instance, model, results):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def fit(self, training_instances):
|
def fit(self, training_instances):
|
||||||
|
|||||||
@@ -27,8 +27,13 @@ class ObjectiveValueComponent(Component):
|
|||||||
instance.predicted_lb = lb
|
instance.predicted_lb = lb
|
||||||
logger.info("Predicted objective: [%.2f, %.2f]" % (lb, ub))
|
logger.info("Predicted objective: [%.2f, %.2f]" % (lb, ub))
|
||||||
|
|
||||||
def after_solve(self, solver, instance, model):
|
def after_solve(self, solver, instance, model, results):
|
||||||
pass
|
if self.ub_regressor is not None:
|
||||||
|
results["Predicted UB"] = instance.predicted_ub
|
||||||
|
results["Predicted LB"] = instance.predicted_lb
|
||||||
|
else:
|
||||||
|
results["Predicted UB"] = None
|
||||||
|
results["Predicted LB"] = None
|
||||||
|
|
||||||
def fit(self, training_instances):
|
def fit(self, training_instances):
|
||||||
logger.debug("Extracting features...")
|
logger.debug("Extracting features...")
|
||||||
|
|||||||
@@ -135,7 +135,7 @@ class PrimalSolutionComponent(Component):
|
|||||||
else:
|
else:
|
||||||
solver.internal_solver.set_warm_start(solution)
|
solver.internal_solver.set_warm_start(solution)
|
||||||
|
|
||||||
def after_solve(self, solver, instance, model):
|
def after_solve(self, solver, instance, model, results):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def fit(self, training_instances):
|
def fit(self, training_instances):
|
||||||
|
|||||||
@@ -306,7 +306,7 @@ class LearningSolver:
|
|||||||
|
|
||||||
logger.debug("Calling after_solve callbacks...")
|
logger.debug("Calling after_solve callbacks...")
|
||||||
for component in self.components.values():
|
for component in self.components.values():
|
||||||
component.after_solve(self, instance, model)
|
component.after_solve(self, instance, model, results)
|
||||||
|
|
||||||
# Store instance for future training
|
# Store instance for future training
|
||||||
self.training_instances += [instance]
|
self.training_instances += [instance]
|
||||||
|
|||||||
@@ -27,11 +27,11 @@ def test_benchmark():
|
|||||||
benchmark = BenchmarkRunner(test_solvers)
|
benchmark = BenchmarkRunner(test_solvers)
|
||||||
benchmark.fit(train_instances)
|
benchmark.fit(train_instances)
|
||||||
benchmark.parallel_solve(test_instances, n_jobs=2, n_trials=2)
|
benchmark.parallel_solve(test_instances, n_jobs=2, n_trials=2)
|
||||||
assert benchmark.raw_results().values.shape == (12,13)
|
assert benchmark.raw_results().values.shape == (12,16)
|
||||||
|
|
||||||
benchmark.save_results("/tmp/benchmark.csv")
|
benchmark.save_results("/tmp/benchmark.csv")
|
||||||
assert os.path.isfile("/tmp/benchmark.csv")
|
assert os.path.isfile("/tmp/benchmark.csv")
|
||||||
|
|
||||||
benchmark = BenchmarkRunner(test_solvers)
|
benchmark = BenchmarkRunner(test_solvers)
|
||||||
benchmark.load_results("/tmp/benchmark.csv")
|
benchmark.load_results("/tmp/benchmark.csv")
|
||||||
assert benchmark.raw_results().values.shape == (12,13)
|
assert benchmark.raw_results().values.shape == (12,16)
|
||||||
|
|||||||
Reference in New Issue
Block a user