|
|
|
@ -26,7 +26,13 @@ from miplearn.components import classifier_evaluation_dict
|
|
|
|
|
from miplearn.components.component import Component
|
|
|
|
|
from miplearn.extractors import InstanceIterator
|
|
|
|
|
from miplearn.instance import Instance
|
|
|
|
|
from miplearn.types import TrainingSample, VarIndex, Solution, LearningSolveStats
|
|
|
|
|
from miplearn.types import (
|
|
|
|
|
TrainingSample,
|
|
|
|
|
VarIndex,
|
|
|
|
|
Solution,
|
|
|
|
|
LearningSolveStats,
|
|
|
|
|
Features,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
@ -126,7 +132,7 @@ class PrimalSolutionComponent(Component):
|
|
|
|
|
solution[var_name][idx] = None
|
|
|
|
|
|
|
|
|
|
# Compute y_pred
|
|
|
|
|
x = self.x_sample(instance, sample)
|
|
|
|
|
x = self.x_sample(instance.features, sample)
|
|
|
|
|
y_pred = {}
|
|
|
|
|
for category in x.keys():
|
|
|
|
|
assert category in self.classifiers, (
|
|
|
|
@ -213,34 +219,41 @@ class PrimalSolutionComponent(Component):
|
|
|
|
|
assert sample["Solution"] is not None
|
|
|
|
|
return cast(
|
|
|
|
|
Tuple[Dict, Dict],
|
|
|
|
|
PrimalSolutionComponent._extract(
|
|
|
|
|
instance,
|
|
|
|
|
sample,
|
|
|
|
|
sample["Solution"],
|
|
|
|
|
),
|
|
|
|
|
PrimalSolutionComponent._extract(instance.features, sample),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def x_sample(
|
|
|
|
|
instance: Any,
|
|
|
|
|
features: Features,
|
|
|
|
|
sample: TrainingSample,
|
|
|
|
|
) -> Dict:
|
|
|
|
|
return cast(Dict, PrimalSolutionComponent._extract(instance, sample))
|
|
|
|
|
return cast(Dict, PrimalSolutionComponent._extract(features, sample))
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _extract(
|
|
|
|
|
instance: Any,
|
|
|
|
|
features: Features,
|
|
|
|
|
sample: TrainingSample,
|
|
|
|
|
solution: Optional[Dict] = None,
|
|
|
|
|
) -> Union[Dict, Tuple[Dict, Dict]]:
|
|
|
|
|
x: Dict = {}
|
|
|
|
|
y: Dict = {}
|
|
|
|
|
opt_value = 0.0
|
|
|
|
|
for (var_name, var_dict) in instance.features["Variables"].items():
|
|
|
|
|
solution: Optional[Solution] = None
|
|
|
|
|
if "Solution" in sample and sample["Solution"] is not None:
|
|
|
|
|
solution = sample["Solution"]
|
|
|
|
|
for (var_name, var_dict) in features["Variables"].items():
|
|
|
|
|
for (idx, var_features) in var_dict.items():
|
|
|
|
|
category = var_features["Category"]
|
|
|
|
|
if category is None:
|
|
|
|
|
continue
|
|
|
|
|
if category not in x.keys():
|
|
|
|
|
x[category] = []
|
|
|
|
|
y[category] = []
|
|
|
|
|
f = var_features["User features"]
|
|
|
|
|
assert f is not None
|
|
|
|
|
if "LP solution" in sample and sample["LP solution"] is not None:
|
|
|
|
|
lp_value = sample["LP solution"][var_name][idx]
|
|
|
|
|
if lp_value is not None:
|
|
|
|
|
f += [lp_value]
|
|
|
|
|
x[category] += [f]
|
|
|
|
|
if solution is not None:
|
|
|
|
|
opt_value = solution[var_name][idx]
|
|
|
|
|
assert opt_value is not None
|
|
|
|
@ -250,16 +263,6 @@ class PrimalSolutionComponent(Component):
|
|
|
|
|
"variables is not currently supported. Please set its "
|
|
|
|
|
"category to None."
|
|
|
|
|
)
|
|
|
|
|
if category not in x.keys():
|
|
|
|
|
x[category] = []
|
|
|
|
|
y[category] = []
|
|
|
|
|
features = var_features["User features"]
|
|
|
|
|
if "LP solution" in sample and sample["LP solution"] is not None:
|
|
|
|
|
lp_value = sample["LP solution"][var_name][idx]
|
|
|
|
|
if lp_value is not None:
|
|
|
|
|
features += [sample["LP solution"][var_name][idx]]
|
|
|
|
|
x[category] += [features]
|
|
|
|
|
if solution is not None:
|
|
|
|
|
y[category] += [[opt_value < 0.5, opt_value >= 0.5]]
|
|
|
|
|
if solution is not None:
|
|
|
|
|
return x, y
|
|
|
|
|