Refactor PrimalSolutionComponent

master
Alinson S. Xavier 5 years ago
parent ec69464794
commit b3c24814b0

@ -13,6 +13,7 @@ from typing import (
Any, Any,
TYPE_CHECKING, TYPE_CHECKING,
Tuple, Tuple,
cast,
) )
import numpy as np import numpy as np
@ -63,29 +64,30 @@ class PrimalSolutionComponent(Component):
self._n_one = 0 self._n_one = 0
def before_solve_mip(self, solver, instance, model): def before_solve_mip(self, solver, instance, model):
if len(self.thresholds) > 0: pass
logger.info("Predicting primal solution...") # if len(self.thresholds) > 0:
solution = self.predict(instance) # logger.info("Predicting primal solution...")
# solution = self.predict(instance)
# Collect prediction statistics #
self._n_free = 0 # # Collect prediction statistics
self._n_zero = 0 # self._n_free = 0
self._n_one = 0 # self._n_zero = 0
for (var, var_dict) in solution.items(): # self._n_one = 0
for (idx, value) in var_dict.items(): # for (var, var_dict) in solution.items():
if value is None: # for (idx, value) in var_dict.items():
self._n_free += 1 # if value is None:
else: # self._n_free += 1
if value < 0.5: # else:
self._n_zero += 1 # if value < 0.5:
else: # self._n_zero += 1
self._n_one += 1 # else:
# self._n_one += 1
# Provide solution to the solver #
if self.mode == "heuristic": # # Provide solution to the solver
solver.internal_solver.fix(solution) # if self.mode == "heuristic":
else: # solver.internal_solver.fix(solution)
solver.internal_solver.set_warm_start(solution) # else:
# solver.internal_solver.set_warm_start(solution)
def after_solve_mip( def after_solve_mip(
self, self,
@ -214,43 +216,56 @@ class PrimalSolutionComponent(Component):
if "Solution" not in sample: if "Solution" not in sample:
return x, y return x, y
assert sample["Solution"] is not None assert sample["Solution"] is not None
for (var, var_dict) in sample["Solution"].items(): return cast(
for (idx, opt_value) in var_dict.items(): Tuple[Dict, Dict],
assert opt_value is not None PrimalSolutionComponent._extract(
assert 0.0 - 1e-5 <= opt_value <= 1.0 + 1e-5, ( instance,
f"Variable {var} has non-binary value {opt_value} in the optimal " sample,
f"solution. Predicting values of non-binary variables is not " sample["Solution"],
f"currently supported. Please set its category to None." extract_y=True,
),
) )
category = instance.get_variable_category(var, idx)
if category is None:
continue
if category not in x.keys():
x[category] = []
y[category] = []
features: Any = instance.get_variable_features(var, idx)
assert isinstance(features, list)
if "LP solution" in sample and sample["LP solution"] is not None:
lp_value = sample["LP solution"][var][idx]
if lp_value is not None:
features += [sample["LP solution"][var][idx]]
x[category] += [features]
y[category] += [[opt_value < 0.5, opt_value >= 0.5]]
return x, y
@staticmethod @staticmethod
def x_sample( def x_sample(
instance: Any, instance: Any,
sample: TrainingSample, sample: TrainingSample,
) -> Dict: ) -> Dict:
return cast(
Dict,
PrimalSolutionComponent._extract(
instance,
sample,
instance.model_features["Variables"],
extract_y=False,
),
)
@staticmethod
def _extract(
instance: Any,
sample: TrainingSample,
variables: Dict,
extract_y: bool,
) -> Union[Dict, Tuple[Dict, Dict]]:
x: Dict = {} x: Dict = {}
for (var, var_dict) in instance.model_features["Variables"].items(): y: Dict = {}
for idx in var_dict.keys(): for (var, var_dict) in variables.items():
for (idx, opt_value) in var_dict.items():
if extract_y:
assert opt_value is not None
assert 0.0 - 1e-5 <= opt_value <= 1.0 + 1e-5, (
f"Variable {var} has non-binary value {opt_value} in the "
"optimal solution. Predicting values of non-binary "
"variables is not currently supported. Please set its "
"category to None."
)
category = instance.get_variable_category(var, idx) category = instance.get_variable_category(var, idx)
if category is None: if category is None:
continue continue
if category not in x.keys(): if category not in x.keys():
x[category] = [] x[category] = []
y[category] = []
features: Any = instance.get_variable_features(var, idx) features: Any = instance.get_variable_features(var, idx)
assert isinstance(features, list) assert isinstance(features, list)
if "LP solution" in sample and sample["LP solution"] is not None: if "LP solution" in sample and sample["LP solution"] is not None:
@ -258,6 +273,9 @@ class PrimalSolutionComponent(Component):
if lp_value is not None: if lp_value is not None:
features += [sample["LP solution"][var][idx]] features += [sample["LP solution"][var][idx]]
x[category] += [features] x[category] += [features]
for category in x.keys(): if extract_y:
x[category] = np.array(x[category]) y[category] += [[opt_value < 0.5, opt_value >= 0.5]]
if extract_y:
return x, y
else:
return x return x

Loading…
Cancel
Save