mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Primal: Refactoring
This commit is contained in:
@@ -26,7 +26,13 @@ from miplearn.components import classifier_evaluation_dict
|
|||||||
from miplearn.components.component import Component
|
from miplearn.components.component import Component
|
||||||
from miplearn.extractors import InstanceIterator
|
from miplearn.extractors import InstanceIterator
|
||||||
from miplearn.instance import Instance
|
from miplearn.instance import Instance
|
||||||
from miplearn.types import TrainingSample, VarIndex, Solution, LearningSolveStats
|
from miplearn.types import (
|
||||||
|
TrainingSample,
|
||||||
|
VarIndex,
|
||||||
|
Solution,
|
||||||
|
LearningSolveStats,
|
||||||
|
Features,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -126,7 +132,7 @@ class PrimalSolutionComponent(Component):
|
|||||||
solution[var_name][idx] = None
|
solution[var_name][idx] = None
|
||||||
|
|
||||||
# Compute y_pred
|
# Compute y_pred
|
||||||
x = self.x_sample(instance, sample)
|
x = self.x_sample(instance.features, sample)
|
||||||
y_pred = {}
|
y_pred = {}
|
||||||
for category in x.keys():
|
for category in x.keys():
|
||||||
assert category in self.classifiers, (
|
assert category in self.classifiers, (
|
||||||
@@ -213,34 +219,41 @@ class PrimalSolutionComponent(Component):
|
|||||||
assert sample["Solution"] is not None
|
assert sample["Solution"] is not None
|
||||||
return cast(
|
return cast(
|
||||||
Tuple[Dict, Dict],
|
Tuple[Dict, Dict],
|
||||||
PrimalSolutionComponent._extract(
|
PrimalSolutionComponent._extract(instance.features, sample),
|
||||||
instance,
|
|
||||||
sample,
|
|
||||||
sample["Solution"],
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def x_sample(
|
def x_sample(
|
||||||
instance: Any,
|
features: Features,
|
||||||
sample: TrainingSample,
|
sample: TrainingSample,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
return cast(Dict, PrimalSolutionComponent._extract(instance, sample))
|
return cast(Dict, PrimalSolutionComponent._extract(features, sample))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract(
|
def _extract(
|
||||||
instance: Any,
|
features: Features,
|
||||||
sample: TrainingSample,
|
sample: TrainingSample,
|
||||||
solution: Optional[Dict] = None,
|
|
||||||
) -> Union[Dict, Tuple[Dict, Dict]]:
|
) -> Union[Dict, Tuple[Dict, Dict]]:
|
||||||
x: Dict = {}
|
x: Dict = {}
|
||||||
y: Dict = {}
|
y: Dict = {}
|
||||||
opt_value = 0.0
|
solution: Optional[Solution] = None
|
||||||
for (var_name, var_dict) in instance.features["Variables"].items():
|
if "Solution" in sample and sample["Solution"] is not None:
|
||||||
|
solution = sample["Solution"]
|
||||||
|
for (var_name, var_dict) in features["Variables"].items():
|
||||||
for (idx, var_features) in var_dict.items():
|
for (idx, var_features) in var_dict.items():
|
||||||
category = var_features["Category"]
|
category = var_features["Category"]
|
||||||
if category is None:
|
if category is None:
|
||||||
continue
|
continue
|
||||||
|
if category not in x.keys():
|
||||||
|
x[category] = []
|
||||||
|
y[category] = []
|
||||||
|
f = var_features["User features"]
|
||||||
|
assert f is not None
|
||||||
|
if "LP solution" in sample and sample["LP solution"] is not None:
|
||||||
|
lp_value = sample["LP solution"][var_name][idx]
|
||||||
|
if lp_value is not None:
|
||||||
|
f += [lp_value]
|
||||||
|
x[category] += [f]
|
||||||
if solution is not None:
|
if solution is not None:
|
||||||
opt_value = solution[var_name][idx]
|
opt_value = solution[var_name][idx]
|
||||||
assert opt_value is not None
|
assert opt_value is not None
|
||||||
@@ -250,16 +263,6 @@ class PrimalSolutionComponent(Component):
|
|||||||
"variables is not currently supported. Please set its "
|
"variables is not currently supported. Please set its "
|
||||||
"category to None."
|
"category to None."
|
||||||
)
|
)
|
||||||
if category not in x.keys():
|
|
||||||
x[category] = []
|
|
||||||
y[category] = []
|
|
||||||
features = var_features["User features"]
|
|
||||||
if "LP solution" in sample and sample["LP solution"] is not None:
|
|
||||||
lp_value = sample["LP solution"][var_name][idx]
|
|
||||||
if lp_value is not None:
|
|
||||||
features += [sample["LP solution"][var_name][idx]]
|
|
||||||
x[category] += [features]
|
|
||||||
if solution is not None:
|
|
||||||
y[category] += [[opt_value < 0.5, opt_value >= 0.5]]
|
y[category] += [[opt_value < 0.5, opt_value >= 0.5]]
|
||||||
if solution is not None:
|
if solution is not None:
|
||||||
return x, y
|
return x, y
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import numbers
|
|||||||
import collections
|
import collections
|
||||||
from typing import TYPE_CHECKING, Dict
|
from typing import TYPE_CHECKING, Dict
|
||||||
|
|
||||||
from miplearn.types import ModelFeatures, ConstraintFeatures
|
from miplearn.types import Features, ConstraintFeatures
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from miplearn import InternalSolver, Instance
|
from miplearn import InternalSolver, Instance
|
||||||
@@ -19,7 +19,7 @@ class FeaturesExtractor:
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.solver = internal_solver
|
self.solver = internal_solver
|
||||||
|
|
||||||
def extract(self, instance: "Instance") -> ModelFeatures:
|
def extract(self, instance: "Instance") -> Features:
|
||||||
return {
|
return {
|
||||||
"Constraints": self._extract_constraints(instance),
|
"Constraints": self._extract_constraints(instance),
|
||||||
"Variables": self._extract_variables(instance),
|
"Variables": self._extract_variables(instance),
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from typing import Any, List, Optional, Hashable
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from miplearn.types import TrainingSample, VarIndex, ModelFeatures
|
from miplearn.types import TrainingSample, VarIndex, Features
|
||||||
|
|
||||||
|
|
||||||
# noinspection PyMethodMayBeStatic
|
# noinspection PyMethodMayBeStatic
|
||||||
@@ -27,7 +27,7 @@ class Instance(ABC):
|
|||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.training_data: List[TrainingSample] = []
|
self.training_data: List[TrainingSample] = []
|
||||||
self.features: ModelFeatures = {}
|
self.features: Features = {}
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def to_model(self) -> Any:
|
def to_model(self) -> Any:
|
||||||
|
|||||||
@@ -94,8 +94,8 @@ ConstraintFeatures = TypedDict(
|
|||||||
total=False,
|
total=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
ModelFeatures = TypedDict(
|
Features = TypedDict(
|
||||||
"ModelFeatures",
|
"Features",
|
||||||
{
|
{
|
||||||
"Variables": Dict[str, Dict[VarIndex, VariableFeatures]],
|
"Variables": Dict[str, Dict[VarIndex, VariableFeatures]],
|
||||||
"Constraints": Dict[str, ConstraintFeatures],
|
"Constraints": Dict[str, ConstraintFeatures],
|
||||||
|
|||||||
Reference in New Issue
Block a user