From 0f5a6745a4e87344be9e256724d242fe0a480542 Mon Sep 17 00:00:00 2001 From: "Alinson S. Xavier" Date: Wed, 31 Mar 2021 09:08:01 -0500 Subject: [PATCH] Primal: Refactoring --- miplearn/components/primal.py | 49 +++++++++++++++++++---------------- miplearn/features.py | 4 +-- miplearn/instance.py | 4 +-- miplearn/types.py | 4 +-- 4 files changed, 32 insertions(+), 29 deletions(-) diff --git a/miplearn/components/primal.py b/miplearn/components/primal.py index dea597d..8c7736b 100644 --- a/miplearn/components/primal.py +++ b/miplearn/components/primal.py @@ -26,7 +26,13 @@ from miplearn.components import classifier_evaluation_dict from miplearn.components.component import Component from miplearn.extractors import InstanceIterator from miplearn.instance import Instance -from miplearn.types import TrainingSample, VarIndex, Solution, LearningSolveStats +from miplearn.types import ( + TrainingSample, + VarIndex, + Solution, + LearningSolveStats, + Features, +) logger = logging.getLogger(__name__) @@ -126,7 +132,7 @@ class PrimalSolutionComponent(Component): solution[var_name][idx] = None # Compute y_pred - x = self.x_sample(instance, sample) + x = self.x_sample(instance.features, sample) y_pred = {} for category in x.keys(): assert category in self.classifiers, ( @@ -213,34 +219,41 @@ class PrimalSolutionComponent(Component): assert sample["Solution"] is not None return cast( Tuple[Dict, Dict], - PrimalSolutionComponent._extract( - instance, - sample, - sample["Solution"], - ), + PrimalSolutionComponent._extract(instance.features, sample), ) @staticmethod def x_sample( - instance: Any, + features: Features, sample: TrainingSample, ) -> Dict: - return cast(Dict, PrimalSolutionComponent._extract(instance, sample)) + return cast(Dict, PrimalSolutionComponent._extract(features, sample)) @staticmethod def _extract( - instance: Any, + features: Features, sample: TrainingSample, - solution: Optional[Dict] = None, ) -> Union[Dict, Tuple[Dict, Dict]]: x: Dict = {} y: Dict = {} - opt_value = 0.0 - for (var_name, var_dict) in instance.features["Variables"].items(): + solution: Optional[Solution] = None + if "Solution" in sample and sample["Solution"] is not None: + solution = sample["Solution"] + for (var_name, var_dict) in features["Variables"].items(): for (idx, var_features) in var_dict.items(): category = var_features["Category"] if category is None: continue + if category not in x.keys(): + x[category] = [] + y[category] = [] + f = var_features["User features"] + assert f is not None + if "LP solution" in sample and sample["LP solution"] is not None: + lp_value = sample["LP solution"][var_name][idx] + if lp_value is not None: + f += [lp_value] + x[category] += [f] if solution is not None: opt_value = solution[var_name][idx] assert opt_value is not None @@ -250,16 +263,6 @@ class PrimalSolutionComponent(Component): "variables is not currently supported. Please set its " "category to None." ) - if category not in x.keys(): - x[category] = [] - y[category] = [] - features = var_features["User features"] - if "LP solution" in sample and sample["LP solution"] is not None: - lp_value = sample["LP solution"][var_name][idx] - if lp_value is not None: - features += [sample["LP solution"][var_name][idx]] - x[category] += [features] - if solution is not None: y[category] += [[opt_value < 0.5, opt_value >= 0.5]] if solution is not None: return x, y diff --git a/miplearn/features.py b/miplearn/features.py index 6c693dd..d4ac84f 100644 --- a/miplearn/features.py +++ b/miplearn/features.py @@ -6,7 +6,7 @@ import numbers import collections from typing import TYPE_CHECKING, Dict -from miplearn.types import ModelFeatures, ConstraintFeatures +from miplearn.types import Features, ConstraintFeatures if TYPE_CHECKING: from miplearn import InternalSolver, Instance @@ -19,7 +19,7 @@ class FeaturesExtractor: ) -> None: self.solver = internal_solver - def extract(self, instance: "Instance") -> ModelFeatures: + def extract(self, instance: "Instance") -> Features: return { "Constraints": self._extract_constraints(instance), "Variables": self._extract_variables(instance), diff --git a/miplearn/instance.py b/miplearn/instance.py index 3121dd8..111c7af 100644 --- a/miplearn/instance.py +++ b/miplearn/instance.py @@ -9,7 +9,7 @@ from typing import Any, List, Optional, Hashable import numpy as np -from miplearn.types import TrainingSample, VarIndex, ModelFeatures +from miplearn.types import TrainingSample, VarIndex, Features # noinspection PyMethodMayBeStatic @@ -27,7 +27,7 @@ class Instance(ABC): def __init__(self) -> None: self.training_data: List[TrainingSample] = [] - self.features: ModelFeatures = {} + self.features: Features = {} @abstractmethod def to_model(self) -> Any: diff --git a/miplearn/types.py b/miplearn/types.py index 1fc0a23..de32cb8 100644 --- a/miplearn/types.py +++ b/miplearn/types.py @@ -94,8 +94,8 @@ ConstraintFeatures = TypedDict( total=False, ) -ModelFeatures = TypedDict( - "ModelFeatures", +Features = TypedDict( + "Features", { "Variables": Dict[str, Dict[VarIndex, VariableFeatures]], "Constraints": Dict[str, ConstraintFeatures],