diff --git a/miplearn/components/component.py b/miplearn/components/component.py index d532c35..7b143c3 100644 --- a/miplearn/components/component.py +++ b/miplearn/components/component.py @@ -62,6 +62,38 @@ class Component: """ return + def before_solve_lp( + self, + solver: "LearningSolver", + instance: Instance, + model: Any, + stats: LearningSolveStats, + sample: Sample, + ) -> None: + """ + Method called by LearningSolver before the root LP relaxation is solved. + + Parameters + ---------- + solver: LearningSolver + The solver calling this method. + instance: Instance + The instance being solved. + model + The concrete optimization model being solved. + stats: LearningSolveStats + A dictionary containing statistics about the solution process, such as + number of nodes explored and running time. Components are free to add + their own statistics here. For example, PrimalSolutionComponent adds + statistics regarding the number of predicted variables. All statistics in + this dictionary are exported to the benchmark CSV file. + sample: miplearn.features.Sample + An object containing data that may be useful for training machine + learning models and accelerating the solution process. Components are + free to add their own training data here. + """ + return + def after_solve_lp_old( self, solver: "LearningSolver", @@ -77,6 +109,20 @@ class Component: """ return + def after_solve_lp( + self, + solver: "LearningSolver", + instance: Instance, + model: Any, + stats: LearningSolveStats, + sample: Sample, + ) -> None: + """ + Method called by LearningSolver after the root LP relaxation is solved. + See before_solve_lp for a description of the parameters. + """ + return + def before_solve_mip_old( self, solver: "LearningSolver", @@ -92,6 +138,20 @@ class Component: """ return + def before_solve_mip( + self, + solver: "LearningSolver", + instance: Instance, + model: Any, + stats: LearningSolveStats, + sample: Sample, + ) -> None: + """ + Method called by LearningSolver before the MIP is solved. + See before_solve_lp for a description of the parameters. + """ + return + def after_solve_mip_old( self, solver: "LearningSolver", @@ -107,6 +167,20 @@ class Component: """ return + def after_solve_mip( + self, + solver: "LearningSolver", + instance: Instance, + model: Any, + stats: LearningSolveStats, + sample: Sample, + ) -> None: + """ + Method called by LearningSolver after the MIP is solved. + See before_solve_lp for a description of the parameters. + """ + return + def sample_xy_old( self, instance: Instance, diff --git a/miplearn/solvers/learning.py b/miplearn/solvers/learning.py index fe58582..d8b3dfe 100644 --- a/miplearn/solvers/learning.py +++ b/miplearn/solvers/learning.py @@ -162,6 +162,14 @@ class LearningSolver: instance.features.__dict__ = features.__dict__ sample.after_load = features + callback_args = ( + self, + instance, + model, + stats, + sample, + ) + callback_args_old = ( self, instance, @@ -177,6 +185,7 @@ class LearningSolver: if self.solve_lp: logger.debug("Running before_solve_lp callbacks...") for component in self.components.values(): + component.before_solve_lp(*callback_args) component.before_solve_lp_old(*callback_args_old) logger.info("Solving root LP relaxation...") @@ -188,6 +197,7 @@ class LearningSolver: logger.debug("Running after_solve_lp callbacks...") for component in self.components.values(): + component.after_solve_lp(*callback_args) component.after_solve_lp_old(*callback_args_old) # Extract features (after-lp) @@ -232,6 +242,7 @@ class LearningSolver: # ------------------------------------------------------- logger.debug("Running before_solve_mip callbacks...") for component in self.components.values(): + component.before_solve_mip(*callback_args) component.before_solve_mip_old(*callback_args_old) # Solve MIP @@ -269,6 +280,7 @@ class LearningSolver: # ------------------------------------------------------- logger.debug("Calling after_solve_mip callbacks...") for component in self.components.values(): + component.after_solve_mip(*callback_args) component.after_solve_mip_old(*callback_args_old) # Flush