mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-07 18:08:51 -06:00
Make all before/solve callbacks receive same parameters
This commit is contained in:
@@ -62,9 +62,16 @@ class PrimalSolutionComponent(Component):
|
||||
self.thresholds: Dict[Hashable, Threshold] = {}
|
||||
self.threshold_prototype = threshold
|
||||
self.classifier_prototype = classifier
|
||||
self.stats: Dict[str, float] = {}
|
||||
|
||||
def before_solve_mip(self, solver, instance, model):
|
||||
def before_solve_mip(
|
||||
self,
|
||||
solver: "LearningSolver",
|
||||
instance: Instance,
|
||||
model: Any,
|
||||
stats: LearningSolveStats,
|
||||
features: Features,
|
||||
training_data: TrainingSample,
|
||||
) -> None:
|
||||
if len(self.thresholds) > 0:
|
||||
logger.info("Predicting MIP solution...")
|
||||
solution = self.predict(
|
||||
@@ -72,41 +79,32 @@ class PrimalSolutionComponent(Component):
|
||||
instance.training_data[-1],
|
||||
)
|
||||
|
||||
# Collect prediction statistics
|
||||
self.stats["Primal: Free"] = 0
|
||||
self.stats["Primal: Zero"] = 0
|
||||
self.stats["Primal: One"] = 0
|
||||
# Update statistics
|
||||
stats["Primal: Free"] = 0
|
||||
stats["Primal: Zero"] = 0
|
||||
stats["Primal: One"] = 0
|
||||
for (var, var_dict) in solution.items():
|
||||
for (idx, value) in var_dict.items():
|
||||
if value is None:
|
||||
self.stats["Primal: Free"] += 1
|
||||
stats["Primal: Free"] += 1
|
||||
else:
|
||||
if value < 0.5:
|
||||
self.stats["Primal: Zero"] += 1
|
||||
stats["Primal: Zero"] += 1
|
||||
else:
|
||||
self.stats["Primal: One"] += 1
|
||||
stats["Primal: One"] += 1
|
||||
logger.info(
|
||||
f"Predicted: free: {self.stats['Primal: Free']}, "
|
||||
f"zero: {self.stats['Primal: zero']}, "
|
||||
f"one: {self.stats['Primal: One']}"
|
||||
f"Predicted: free: {stats['Primal: Free']}, "
|
||||
f"zero: {stats['Primal: Zero']}, "
|
||||
f"one: {stats['Primal: One']}"
|
||||
)
|
||||
|
||||
# Provide solution to the solver
|
||||
assert solver.internal_solver is not None
|
||||
if self.mode == "heuristic":
|
||||
solver.internal_solver.fix(solution)
|
||||
else:
|
||||
solver.internal_solver.set_warm_start(solution)
|
||||
|
||||
def after_solve_mip(
|
||||
self,
|
||||
solver: "LearningSolver",
|
||||
instance: Instance,
|
||||
model: Any,
|
||||
stats: LearningSolveStats,
|
||||
training_data: TrainingSample,
|
||||
) -> None:
|
||||
stats.update(self.stats)
|
||||
|
||||
def fit_xy(
|
||||
self,
|
||||
x: Dict[str, np.ndarray],
|
||||
|
||||
Reference in New Issue
Block a user