mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-08 18:38:51 -06:00
Add user cut callbacks; begin rewrite of UserCutsComponent
This commit is contained in:
@@ -125,18 +125,22 @@ class LearningSolver:
|
||||
) -> LearningSolveStats:
|
||||
|
||||
# Generate model
|
||||
# -------------------------------------------------------
|
||||
if model is None:
|
||||
with _RedirectOutput([]):
|
||||
model = instance.to_model()
|
||||
|
||||
# Initialize training sample
|
||||
# -------------------------------------------------------
|
||||
training_sample = TrainingSample()
|
||||
instance.training_data += [training_sample]
|
||||
|
||||
# Initialize stats
|
||||
# -------------------------------------------------------
|
||||
stats: LearningSolveStats = {}
|
||||
|
||||
# Initialize internal solver
|
||||
# -------------------------------------------------------
|
||||
self.tee = tee
|
||||
self.internal_solver = self.solver_factory()
|
||||
assert self.internal_solver is not None
|
||||
@@ -144,6 +148,7 @@ class LearningSolver:
|
||||
self.internal_solver.set_instance(instance, model)
|
||||
|
||||
# Extract features
|
||||
# -------------------------------------------------------
|
||||
FeaturesExtractor(self.internal_solver).extract(instance)
|
||||
|
||||
callback_args = (
|
||||
@@ -156,6 +161,7 @@ class LearningSolver:
|
||||
)
|
||||
|
||||
# Solve root LP relaxation
|
||||
# -------------------------------------------------------
|
||||
if self.solve_lp:
|
||||
logger.debug("Running before_solve_lp callbacks...")
|
||||
for component in self.components.values():
|
||||
@@ -172,37 +178,50 @@ class LearningSolver:
|
||||
for component in self.components.values():
|
||||
component.after_solve_lp(*callback_args)
|
||||
|
||||
# Define wrappers
|
||||
# Callback wrappers
|
||||
# -------------------------------------------------------
|
||||
def iteration_cb_wrapper() -> bool:
|
||||
should_repeat = False
|
||||
assert isinstance(instance, Instance)
|
||||
for comp in self.components.values():
|
||||
if comp.iteration_cb(self, instance, model):
|
||||
should_repeat = True
|
||||
return should_repeat
|
||||
|
||||
def lazy_cb_wrapper(
|
||||
cb_solver: LearningSolver,
|
||||
cb_solver: InternalSolver,
|
||||
cb_model: Any,
|
||||
) -> None:
|
||||
assert isinstance(instance, Instance)
|
||||
for comp in self.components.values():
|
||||
comp.lazy_cb(self, instance, model)
|
||||
|
||||
def user_cut_cb_wrapper(
|
||||
cb_solver: InternalSolver,
|
||||
cb_model: Any,
|
||||
) -> None:
|
||||
for comp in self.components.values():
|
||||
comp.user_cut_cb(self, instance, model)
|
||||
|
||||
lazy_cb = None
|
||||
if self.use_lazy_cb:
|
||||
lazy_cb = lazy_cb_wrapper
|
||||
|
||||
user_cut_cb = None
|
||||
if instance.has_user_cuts():
|
||||
user_cut_cb = user_cut_cb_wrapper
|
||||
|
||||
# Before-solve callbacks
|
||||
# -------------------------------------------------------
|
||||
logger.debug("Running before_solve_mip callbacks...")
|
||||
for component in self.components.values():
|
||||
component.before_solve_mip(*callback_args)
|
||||
|
||||
# Solve MIP
|
||||
# -------------------------------------------------------
|
||||
logger.info("Solving MIP...")
|
||||
mip_stats = self.internal_solver.solve(
|
||||
tee=tee,
|
||||
iteration_cb=iteration_cb_wrapper,
|
||||
user_cut_cb=user_cut_cb,
|
||||
lazy_cb=lazy_cb,
|
||||
)
|
||||
stats.update(cast(LearningSolveStats, mip_stats))
|
||||
@@ -216,17 +235,20 @@ class LearningSolver:
|
||||
stats["Mode"] = self.mode
|
||||
|
||||
# Add some information to training_sample
|
||||
# -------------------------------------------------------
|
||||
training_sample.lower_bound = stats["Lower bound"]
|
||||
training_sample.upper_bound = stats["Upper bound"]
|
||||
training_sample.mip_log = stats["MIP log"]
|
||||
training_sample.solution = self.internal_solver.get_solution()
|
||||
|
||||
# After-solve callbacks
|
||||
# -------------------------------------------------------
|
||||
logger.debug("Calling after_solve_mip callbacks...")
|
||||
for component in self.components.values():
|
||||
component.after_solve_mip(*callback_args)
|
||||
|
||||
# Write to file, if necessary
|
||||
# Flush
|
||||
# -------------------------------------------------------
|
||||
if not discard_output:
|
||||
instance.flush()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user