mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
LearningSolver: only compute all_vars once
This commit is contained in:
@@ -106,8 +106,10 @@ class InternalSolver:
|
|||||||
else:
|
else:
|
||||||
self.sense = "max"
|
self.sense = "max"
|
||||||
self.var_name_to_var = {}
|
self.var_name_to_var = {}
|
||||||
|
self.all_vars = []
|
||||||
for var in model.component_objects(Var):
|
for var in model.component_objects(Var):
|
||||||
self.var_name_to_var[var.name] = var
|
self.var_name_to_var[var.name] = var
|
||||||
|
self.all_vars += [var[idx] for idx in var]
|
||||||
|
|
||||||
def set_instance(self, instance):
|
def set_instance(self, instance):
|
||||||
self.instance = instance
|
self.instance = instance
|
||||||
@@ -173,8 +175,7 @@ class GurobiSolver(InternalSolver):
|
|||||||
from gurobipy import GRB
|
from gurobipy import GRB
|
||||||
def cb(cb_model, cb_opt, cb_where):
|
def cb(cb_model, cb_opt, cb_where):
|
||||||
if cb_where == GRB.Callback.MIPSOL:
|
if cb_where == GRB.Callback.MIPSOL:
|
||||||
all_vars = [v[idx] for v in self.model.component_objects(Var) for idx in v]
|
cb_opt.cbGetSolution(self.all_vars)
|
||||||
cb_opt.cbGetSolution(all_vars)
|
|
||||||
logger.debug("Finding violated constraints...")
|
logger.debug("Finding violated constraints...")
|
||||||
violations = self.instance.find_violations(cb_model)
|
violations = self.instance.find_violations(cb_model)
|
||||||
self.instance.found_violations += violations
|
self.instance.found_violations += violations
|
||||||
|
|||||||
Reference in New Issue
Block a user