mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Primal: Refactoring
This commit is contained in:
@@ -26,7 +26,13 @@ from miplearn.components import classifier_evaluation_dict
|
||||
from miplearn.components.component import Component
|
||||
from miplearn.extractors import InstanceIterator
|
||||
from miplearn.instance import Instance
|
||||
from miplearn.types import TrainingSample, VarIndex, Solution, LearningSolveStats
|
||||
from miplearn.types import (
|
||||
TrainingSample,
|
||||
VarIndex,
|
||||
Solution,
|
||||
LearningSolveStats,
|
||||
Features,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -126,7 +132,7 @@ class PrimalSolutionComponent(Component):
|
||||
solution[var_name][idx] = None
|
||||
|
||||
# Compute y_pred
|
||||
x = self.x_sample(instance, sample)
|
||||
x = self.x_sample(instance.features, sample)
|
||||
y_pred = {}
|
||||
for category in x.keys():
|
||||
assert category in self.classifiers, (
|
||||
@@ -213,34 +219,41 @@ class PrimalSolutionComponent(Component):
|
||||
assert sample["Solution"] is not None
|
||||
return cast(
|
||||
Tuple[Dict, Dict],
|
||||
PrimalSolutionComponent._extract(
|
||||
instance,
|
||||
sample,
|
||||
sample["Solution"],
|
||||
),
|
||||
PrimalSolutionComponent._extract(instance.features, sample),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def x_sample(
|
||||
instance: Any,
|
||||
features: Features,
|
||||
sample: TrainingSample,
|
||||
) -> Dict:
|
||||
return cast(Dict, PrimalSolutionComponent._extract(instance, sample))
|
||||
return cast(Dict, PrimalSolutionComponent._extract(features, sample))
|
||||
|
||||
@staticmethod
|
||||
def _extract(
|
||||
instance: Any,
|
||||
features: Features,
|
||||
sample: TrainingSample,
|
||||
solution: Optional[Dict] = None,
|
||||
) -> Union[Dict, Tuple[Dict, Dict]]:
|
||||
x: Dict = {}
|
||||
y: Dict = {}
|
||||
opt_value = 0.0
|
||||
for (var_name, var_dict) in instance.features["Variables"].items():
|
||||
solution: Optional[Solution] = None
|
||||
if "Solution" in sample and sample["Solution"] is not None:
|
||||
solution = sample["Solution"]
|
||||
for (var_name, var_dict) in features["Variables"].items():
|
||||
for (idx, var_features) in var_dict.items():
|
||||
category = var_features["Category"]
|
||||
if category is None:
|
||||
continue
|
||||
if category not in x.keys():
|
||||
x[category] = []
|
||||
y[category] = []
|
||||
f = var_features["User features"]
|
||||
assert f is not None
|
||||
if "LP solution" in sample and sample["LP solution"] is not None:
|
||||
lp_value = sample["LP solution"][var_name][idx]
|
||||
if lp_value is not None:
|
||||
f += [lp_value]
|
||||
x[category] += [f]
|
||||
if solution is not None:
|
||||
opt_value = solution[var_name][idx]
|
||||
assert opt_value is not None
|
||||
@@ -250,16 +263,6 @@ class PrimalSolutionComponent(Component):
|
||||
"variables is not currently supported. Please set its "
|
||||
"category to None."
|
||||
)
|
||||
if category not in x.keys():
|
||||
x[category] = []
|
||||
y[category] = []
|
||||
features = var_features["User features"]
|
||||
if "LP solution" in sample and sample["LP solution"] is not None:
|
||||
lp_value = sample["LP solution"][var_name][idx]
|
||||
if lp_value is not None:
|
||||
features += [sample["LP solution"][var_name][idx]]
|
||||
x[category] += [features]
|
||||
if solution is not None:
|
||||
y[category] += [[opt_value < 0.5, opt_value >= 0.5]]
|
||||
if solution is not None:
|
||||
return x, y
|
||||
|
||||
@@ -6,7 +6,7 @@ import numbers
|
||||
import collections
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
from miplearn.types import ModelFeatures, ConstraintFeatures
|
||||
from miplearn.types import Features, ConstraintFeatures
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from miplearn import InternalSolver, Instance
|
||||
@@ -19,7 +19,7 @@ class FeaturesExtractor:
|
||||
) -> None:
|
||||
self.solver = internal_solver
|
||||
|
||||
def extract(self, instance: "Instance") -> ModelFeatures:
|
||||
def extract(self, instance: "Instance") -> Features:
|
||||
return {
|
||||
"Constraints": self._extract_constraints(instance),
|
||||
"Variables": self._extract_variables(instance),
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing import Any, List, Optional, Hashable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from miplearn.types import TrainingSample, VarIndex, ModelFeatures
|
||||
from miplearn.types import TrainingSample, VarIndex, Features
|
||||
|
||||
|
||||
# noinspection PyMethodMayBeStatic
|
||||
@@ -27,7 +27,7 @@ class Instance(ABC):
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.training_data: List[TrainingSample] = []
|
||||
self.features: ModelFeatures = {}
|
||||
self.features: Features = {}
|
||||
|
||||
@abstractmethod
|
||||
def to_model(self) -> Any:
|
||||
|
||||
@@ -94,8 +94,8 @@ ConstraintFeatures = TypedDict(
|
||||
total=False,
|
||||
)
|
||||
|
||||
ModelFeatures = TypedDict(
|
||||
"ModelFeatures",
|
||||
Features = TypedDict(
|
||||
"Features",
|
||||
{
|
||||
"Variables": Dict[str, Dict[VarIndex, VariableFeatures]],
|
||||
"Constraints": Dict[str, ConstraintFeatures],
|
||||
|
||||
Reference in New Issue
Block a user