mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Refactor PrimalSolutionComponent
This commit is contained in:
@@ -13,6 +13,7 @@ from typing import (
|
||||
Any,
|
||||
TYPE_CHECKING,
|
||||
Tuple,
|
||||
cast,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
@@ -63,29 +64,30 @@ class PrimalSolutionComponent(Component):
|
||||
self._n_one = 0
|
||||
|
||||
def before_solve_mip(self, solver, instance, model):
|
||||
if len(self.thresholds) > 0:
|
||||
logger.info("Predicting primal solution...")
|
||||
solution = self.predict(instance)
|
||||
|
||||
# Collect prediction statistics
|
||||
self._n_free = 0
|
||||
self._n_zero = 0
|
||||
self._n_one = 0
|
||||
for (var, var_dict) in solution.items():
|
||||
for (idx, value) in var_dict.items():
|
||||
if value is None:
|
||||
self._n_free += 1
|
||||
else:
|
||||
if value < 0.5:
|
||||
self._n_zero += 1
|
||||
else:
|
||||
self._n_one += 1
|
||||
|
||||
# Provide solution to the solver
|
||||
if self.mode == "heuristic":
|
||||
solver.internal_solver.fix(solution)
|
||||
else:
|
||||
solver.internal_solver.set_warm_start(solution)
|
||||
pass
|
||||
# if len(self.thresholds) > 0:
|
||||
# logger.info("Predicting primal solution...")
|
||||
# solution = self.predict(instance)
|
||||
#
|
||||
# # Collect prediction statistics
|
||||
# self._n_free = 0
|
||||
# self._n_zero = 0
|
||||
# self._n_one = 0
|
||||
# for (var, var_dict) in solution.items():
|
||||
# for (idx, value) in var_dict.items():
|
||||
# if value is None:
|
||||
# self._n_free += 1
|
||||
# else:
|
||||
# if value < 0.5:
|
||||
# self._n_zero += 1
|
||||
# else:
|
||||
# self._n_one += 1
|
||||
#
|
||||
# # Provide solution to the solver
|
||||
# if self.mode == "heuristic":
|
||||
# solver.internal_solver.fix(solution)
|
||||
# else:
|
||||
# solver.internal_solver.set_warm_start(solution)
|
||||
|
||||
def after_solve_mip(
|
||||
self,
|
||||
@@ -214,14 +216,50 @@ class PrimalSolutionComponent(Component):
|
||||
if "Solution" not in sample:
|
||||
return x, y
|
||||
assert sample["Solution"] is not None
|
||||
for (var, var_dict) in sample["Solution"].items():
|
||||
return cast(
|
||||
Tuple[Dict, Dict],
|
||||
PrimalSolutionComponent._extract(
|
||||
instance,
|
||||
sample,
|
||||
sample["Solution"],
|
||||
extract_y=True,
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def x_sample(
|
||||
instance: Any,
|
||||
sample: TrainingSample,
|
||||
) -> Dict:
|
||||
return cast(
|
||||
Dict,
|
||||
PrimalSolutionComponent._extract(
|
||||
instance,
|
||||
sample,
|
||||
instance.model_features["Variables"],
|
||||
extract_y=False,
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract(
|
||||
instance: Any,
|
||||
sample: TrainingSample,
|
||||
variables: Dict,
|
||||
extract_y: bool,
|
||||
) -> Union[Dict, Tuple[Dict, Dict]]:
|
||||
x: Dict = {}
|
||||
y: Dict = {}
|
||||
for (var, var_dict) in variables.items():
|
||||
for (idx, opt_value) in var_dict.items():
|
||||
assert opt_value is not None
|
||||
assert 0.0 - 1e-5 <= opt_value <= 1.0 + 1e-5, (
|
||||
f"Variable {var} has non-binary value {opt_value} in the optimal "
|
||||
f"solution. Predicting values of non-binary variables is not "
|
||||
f"currently supported. Please set its category to None."
|
||||
)
|
||||
if extract_y:
|
||||
assert opt_value is not None
|
||||
assert 0.0 - 1e-5 <= opt_value <= 1.0 + 1e-5, (
|
||||
f"Variable {var} has non-binary value {opt_value} in the "
|
||||
"optimal solution. Predicting values of non-binary "
|
||||
"variables is not currently supported. Please set its "
|
||||
"category to None."
|
||||
)
|
||||
category = instance.get_variable_category(var, idx)
|
||||
if category is None:
|
||||
continue
|
||||
@@ -235,29 +273,9 @@ class PrimalSolutionComponent(Component):
|
||||
if lp_value is not None:
|
||||
features += [sample["LP solution"][var][idx]]
|
||||
x[category] += [features]
|
||||
y[category] += [[opt_value < 0.5, opt_value >= 0.5]]
|
||||
return x, y
|
||||
|
||||
@staticmethod
|
||||
def x_sample(
|
||||
instance: Any,
|
||||
sample: TrainingSample,
|
||||
) -> Dict:
|
||||
x: Dict = {}
|
||||
for (var, var_dict) in instance.model_features["Variables"].items():
|
||||
for idx in var_dict.keys():
|
||||
category = instance.get_variable_category(var, idx)
|
||||
if category is None:
|
||||
continue
|
||||
if category not in x.keys():
|
||||
x[category] = []
|
||||
features: Any = instance.get_variable_features(var, idx)
|
||||
assert isinstance(features, list)
|
||||
if "LP solution" in sample and sample["LP solution"] is not None:
|
||||
lp_value = sample["LP solution"][var][idx]
|
||||
if lp_value is not None:
|
||||
features += [sample["LP solution"][var][idx]]
|
||||
x[category] += [features]
|
||||
for category in x.keys():
|
||||
x[category] = np.array(x[category])
|
||||
return x
|
||||
if extract_y:
|
||||
y[category] += [[opt_value < 0.5, opt_value >= 0.5]]
|
||||
if extract_y:
|
||||
return x, y
|
||||
else:
|
||||
return x
|
||||
|
||||
Reference in New Issue
Block a user