Make all before/solve callbacks receive same parameters

This commit is contained in:
2021-04-02 07:05:16 -05:00
parent 8eb2b63a85
commit 0c687692f7
17 changed files with 201 additions and 189 deletions

View File

@@ -62,9 +62,16 @@ class PrimalSolutionComponent(Component):
self.thresholds: Dict[Hashable, Threshold] = {}
self.threshold_prototype = threshold
self.classifier_prototype = classifier
self.stats: Dict[str, float] = {}
def before_solve_mip(self, solver, instance, model):
def before_solve_mip(
self,
solver: "LearningSolver",
instance: Instance,
model: Any,
stats: LearningSolveStats,
features: Features,
training_data: TrainingSample,
) -> None:
if len(self.thresholds) > 0:
logger.info("Predicting MIP solution...")
solution = self.predict(
@@ -72,41 +79,32 @@ class PrimalSolutionComponent(Component):
instance.training_data[-1],
)
# Collect prediction statistics
self.stats["Primal: Free"] = 0
self.stats["Primal: Zero"] = 0
self.stats["Primal: One"] = 0
# Update statistics
stats["Primal: Free"] = 0
stats["Primal: Zero"] = 0
stats["Primal: One"] = 0
for (var, var_dict) in solution.items():
for (idx, value) in var_dict.items():
if value is None:
self.stats["Primal: Free"] += 1
stats["Primal: Free"] += 1
else:
if value < 0.5:
self.stats["Primal: Zero"] += 1
stats["Primal: Zero"] += 1
else:
self.stats["Primal: One"] += 1
stats["Primal: One"] += 1
logger.info(
f"Predicted: free: {self.stats['Primal: Free']}, "
f"zero: {self.stats['Primal: zero']}, "
f"one: {self.stats['Primal: One']}"
f"Predicted: free: {stats['Primal: Free']}, "
f"zero: {stats['Primal: Zero']}, "
f"one: {stats['Primal: One']}"
)
# Provide solution to the solver
assert solver.internal_solver is not None
if self.mode == "heuristic":
solver.internal_solver.fix(solution)
else:
solver.internal_solver.set_warm_start(solution)
def after_solve_mip(
self,
solver: "LearningSolver",
instance: Instance,
model: Any,
stats: LearningSolveStats,
training_data: TrainingSample,
) -> None:
stats.update(self.stats)
def fit_xy(
self,
x: Dict[str, np.ndarray],