Refactor PrimalSolutionComponent

master
Alinson S. Xavier 5 years ago
parent ec69464794
commit b3c24814b0

@ -13,6 +13,7 @@ from typing import (
Any,
TYPE_CHECKING,
Tuple,
cast,
)
import numpy as np
@ -63,29 +64,30 @@ class PrimalSolutionComponent(Component):
self._n_one = 0
def before_solve_mip(self, solver, instance, model):
if len(self.thresholds) > 0:
logger.info("Predicting primal solution...")
solution = self.predict(instance)
# Collect prediction statistics
self._n_free = 0
self._n_zero = 0
self._n_one = 0
for (var, var_dict) in solution.items():
for (idx, value) in var_dict.items():
if value is None:
self._n_free += 1
else:
if value < 0.5:
self._n_zero += 1
else:
self._n_one += 1
# Provide solution to the solver
if self.mode == "heuristic":
solver.internal_solver.fix(solution)
else:
solver.internal_solver.set_warm_start(solution)
pass
# if len(self.thresholds) > 0:
# logger.info("Predicting primal solution...")
# solution = self.predict(instance)
#
# # Collect prediction statistics
# self._n_free = 0
# self._n_zero = 0
# self._n_one = 0
# for (var, var_dict) in solution.items():
# for (idx, value) in var_dict.items():
# if value is None:
# self._n_free += 1
# else:
# if value < 0.5:
# self._n_zero += 1
# else:
# self._n_one += 1
#
# # Provide solution to the solver
# if self.mode == "heuristic":
# solver.internal_solver.fix(solution)
# else:
# solver.internal_solver.set_warm_start(solution)
def after_solve_mip(
self,
@ -214,43 +216,56 @@ class PrimalSolutionComponent(Component):
if "Solution" not in sample:
return x, y
assert sample["Solution"] is not None
for (var, var_dict) in sample["Solution"].items():
for (idx, opt_value) in var_dict.items():
assert opt_value is not None
assert 0.0 - 1e-5 <= opt_value <= 1.0 + 1e-5, (
f"Variable {var} has non-binary value {opt_value} in the optimal "
f"solution. Predicting values of non-binary variables is not "
f"currently supported. Please set its category to None."
return cast(
Tuple[Dict, Dict],
PrimalSolutionComponent._extract(
instance,
sample,
sample["Solution"],
extract_y=True,
),
)
category = instance.get_variable_category(var, idx)
if category is None:
continue
if category not in x.keys():
x[category] = []
y[category] = []
features: Any = instance.get_variable_features(var, idx)
assert isinstance(features, list)
if "LP solution" in sample and sample["LP solution"] is not None:
lp_value = sample["LP solution"][var][idx]
if lp_value is not None:
features += [sample["LP solution"][var][idx]]
x[category] += [features]
y[category] += [[opt_value < 0.5, opt_value >= 0.5]]
return x, y
@staticmethod
def x_sample(
instance: Any,
sample: TrainingSample,
) -> Dict:
return cast(
Dict,
PrimalSolutionComponent._extract(
instance,
sample,
instance.model_features["Variables"],
extract_y=False,
),
)
@staticmethod
def _extract(
instance: Any,
sample: TrainingSample,
variables: Dict,
extract_y: bool,
) -> Union[Dict, Tuple[Dict, Dict]]:
x: Dict = {}
for (var, var_dict) in instance.model_features["Variables"].items():
for idx in var_dict.keys():
y: Dict = {}
for (var, var_dict) in variables.items():
for (idx, opt_value) in var_dict.items():
if extract_y:
assert opt_value is not None
assert 0.0 - 1e-5 <= opt_value <= 1.0 + 1e-5, (
f"Variable {var} has non-binary value {opt_value} in the "
"optimal solution. Predicting values of non-binary "
"variables is not currently supported. Please set its "
"category to None."
)
category = instance.get_variable_category(var, idx)
if category is None:
continue
if category not in x.keys():
x[category] = []
y[category] = []
features: Any = instance.get_variable_features(var, idx)
assert isinstance(features, list)
if "LP solution" in sample and sample["LP solution"] is not None:
@ -258,6 +273,9 @@ class PrimalSolutionComponent(Component):
if lp_value is not None:
features += [sample["LP solution"][var][idx]]
x[category] += [features]
for category in x.keys():
x[category] = np.array(x[category])
if extract_y:
y[category] += [[opt_value < 0.5, opt_value >= 0.5]]
if extract_y:
return x, y
else:
return x

Loading…
Cancel
Save