Primal: Refactoring

master
Alinson S. Xavier 5 years ago
parent 4f46866921
commit 0f5a6745a4

@ -26,7 +26,13 @@ from miplearn.components import classifier_evaluation_dict
from miplearn.components.component import Component
from miplearn.extractors import InstanceIterator
from miplearn.instance import Instance
from miplearn.types import TrainingSample, VarIndex, Solution, LearningSolveStats
from miplearn.types import (
TrainingSample,
VarIndex,
Solution,
LearningSolveStats,
Features,
)
logger = logging.getLogger(__name__)
@ -126,7 +132,7 @@ class PrimalSolutionComponent(Component):
solution[var_name][idx] = None
# Compute y_pred
x = self.x_sample(instance, sample)
x = self.x_sample(instance.features, sample)
y_pred = {}
for category in x.keys():
assert category in self.classifiers, (
@ -213,34 +219,41 @@ class PrimalSolutionComponent(Component):
assert sample["Solution"] is not None
return cast(
Tuple[Dict, Dict],
PrimalSolutionComponent._extract(
instance,
sample,
sample["Solution"],
),
PrimalSolutionComponent._extract(instance.features, sample),
)
@staticmethod
def x_sample(
instance: Any,
features: Features,
sample: TrainingSample,
) -> Dict:
return cast(Dict, PrimalSolutionComponent._extract(instance, sample))
return cast(Dict, PrimalSolutionComponent._extract(features, sample))
@staticmethod
def _extract(
instance: Any,
features: Features,
sample: TrainingSample,
solution: Optional[Dict] = None,
) -> Union[Dict, Tuple[Dict, Dict]]:
x: Dict = {}
y: Dict = {}
opt_value = 0.0
for (var_name, var_dict) in instance.features["Variables"].items():
solution: Optional[Solution] = None
if "Solution" in sample and sample["Solution"] is not None:
solution = sample["Solution"]
for (var_name, var_dict) in features["Variables"].items():
for (idx, var_features) in var_dict.items():
category = var_features["Category"]
if category is None:
continue
if category not in x.keys():
x[category] = []
y[category] = []
f = var_features["User features"]
assert f is not None
if "LP solution" in sample and sample["LP solution"] is not None:
lp_value = sample["LP solution"][var_name][idx]
if lp_value is not None:
f += [lp_value]
x[category] += [f]
if solution is not None:
opt_value = solution[var_name][idx]
assert opt_value is not None
@ -250,16 +263,6 @@ class PrimalSolutionComponent(Component):
"variables is not currently supported. Please set its "
"category to None."
)
if category not in x.keys():
x[category] = []
y[category] = []
features = var_features["User features"]
if "LP solution" in sample and sample["LP solution"] is not None:
lp_value = sample["LP solution"][var_name][idx]
if lp_value is not None:
features += [sample["LP solution"][var_name][idx]]
x[category] += [features]
if solution is not None:
y[category] += [[opt_value < 0.5, opt_value >= 0.5]]
if solution is not None:
return x, y

@ -6,7 +6,7 @@ import numbers
import collections
from typing import TYPE_CHECKING, Dict
from miplearn.types import ModelFeatures, ConstraintFeatures
from miplearn.types import Features, ConstraintFeatures
if TYPE_CHECKING:
from miplearn import InternalSolver, Instance
@ -19,7 +19,7 @@ class FeaturesExtractor:
) -> None:
self.solver = internal_solver
def extract(self, instance: "Instance") -> ModelFeatures:
def extract(self, instance: "Instance") -> Features:
return {
"Constraints": self._extract_constraints(instance),
"Variables": self._extract_variables(instance),

@ -9,7 +9,7 @@ from typing import Any, List, Optional, Hashable
import numpy as np
from miplearn.types import TrainingSample, VarIndex, ModelFeatures
from miplearn.types import TrainingSample, VarIndex, Features
# noinspection PyMethodMayBeStatic
@ -27,7 +27,7 @@ class Instance(ABC):
def __init__(self) -> None:
self.training_data: List[TrainingSample] = []
self.features: ModelFeatures = {}
self.features: Features = {}
@abstractmethod
def to_model(self) -> Any:

@ -94,8 +94,8 @@ ConstraintFeatures = TypedDict(
total=False,
)
ModelFeatures = TypedDict(
"ModelFeatures",
Features = TypedDict(
"Features",
{
"Variables": Dict[str, Dict[VarIndex, VariableFeatures]],
"Constraints": Dict[str, ConstraintFeatures],

Loading…
Cancel
Save