diff --git a/miplearn/components/primal.py b/miplearn/components/primal.py index 78abb0b..5aa85db 100644 --- a/miplearn/components/primal.py +++ b/miplearn/components/primal.py @@ -236,7 +236,7 @@ class PrimalSolutionComponent(Component): PrimalSolutionComponent._extract( instance, sample, - instance.model_features["Variables"], + instance.features["Variables"], extract_y=False, ), ) diff --git a/miplearn/features.py b/miplearn/features.py index f9e4cb2..6c693dd 100644 --- a/miplearn/features.py +++ b/miplearn/features.py @@ -2,30 +2,84 @@ # Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved. # Released under the modified BSD license. See COPYING.md for more details. +import numbers +import collections from typing import TYPE_CHECKING, Dict from miplearn.types import ModelFeatures, ConstraintFeatures if TYPE_CHECKING: - from miplearn import InternalSolver + from miplearn import InternalSolver, Instance -class ModelFeaturesExtractor: +class FeaturesExtractor: def __init__( self, internal_solver: "InternalSolver", ) -> None: self.solver = internal_solver - def extract(self) -> ModelFeatures: + def extract(self, instance: "Instance") -> ModelFeatures: + return { + "Constraints": self._extract_constraints(instance), + "Variables": self._extract_variables(instance), + } + + def _extract_variables(self, instance: "Instance") -> Dict: + variables = self.solver.get_empty_solution() + for (var_name, var_dict) in variables.items(): + for idx in var_dict.keys(): + user_features = None + category = instance.get_variable_category(var_name, idx) + if category is not None: + assert isinstance(category, collections.Hashable), ( + f"Variable category must be be hashable. " + f"Found {type(category).__name__} instead for var={var_name}." + ) + user_features = instance.get_variable_features(var_name, idx) + assert isinstance(user_features, list), ( + f"Variable features must be a list. " + f"Found {type(user_features).__name__} instead for " + f"var={var_name}." + ) + assert isinstance(user_features[0], numbers.Real), ( + f"Variable features must be a list of numbers." + f"Found {type(user_features[0]).__name__} instead " + f"for var={var_name}." + ) + var_dict[idx] = { + "Category": category, + "User features": user_features, + } + return variables + + def _extract_constraints( + self, + instance: "Instance", + ) -> Dict[str, ConstraintFeatures]: constraints: Dict[str, ConstraintFeatures] = {} for cid in self.solver.get_constraint_ids(): + user_features = None + category = instance.get_constraint_category(cid) + if category is not None: + assert isinstance(category, collections.Hashable), ( + f"Constraint category must be hashable. " + f"Found {type(category).__name__} instead for cid={cid}.", + ) + user_features = instance.get_constraint_features(cid) + assert isinstance(user_features, list), ( + f"Constraint features must be a list. " + f"Found {type(user_features).__name__} instead for cid={cid}." + ) + assert isinstance(user_features[0], float), ( + f"Constraint features must be a list of floats. " + f"Found {type(user_features[0]).__name__} instead for cid={cid}." + ) constraints[cid] = { "RHS": self.solver.get_constraint_rhs(cid), "LHS": self.solver.get_constraint_lhs(cid), "Sense": self.solver.get_constraint_sense(cid), + "Category": category, + "User features": user_features, } - return { - "Constraints": constraints, - "Variables": self.solver.get_empty_solution(), - } + return constraints diff --git a/miplearn/instance.py b/miplearn/instance.py index c30c2f1..3121dd8 100644 --- a/miplearn/instance.py +++ b/miplearn/instance.py @@ -12,6 +12,7 @@ import numpy as np from miplearn.types import TrainingSample, VarIndex, ModelFeatures +# noinspection PyMethodMayBeStatic class Instance(ABC): """ Abstract class holding all the data necessary to generate a concrete model of the @@ -26,7 +27,7 @@ class Instance(ABC): def __init__(self) -> None: self.training_data: List[TrainingSample] = [] - self.model_features: ModelFeatures = {} + self.features: ModelFeatures = {} @abstractmethod def to_model(self) -> Any: @@ -94,10 +95,10 @@ class Instance(ABC): """ return "default" - def get_constraint_features(self, cid): - return np.zeros(1) + def get_constraint_features(self, cid: str) -> Optional[List[float]]: + return [0.0] - def get_constraint_category(self, cid): + def get_constraint_category(self, cid: str) -> Optional[str]: return cid def has_static_lazy_constraints(self): diff --git a/miplearn/solvers/learning.py b/miplearn/solvers/learning.py index 882b03d..ab9757e 100644 --- a/miplearn/solvers/learning.py +++ b/miplearn/solvers/learning.py @@ -17,7 +17,7 @@ from miplearn.components.cuts import UserCutsComponent from miplearn.components.lazy_dynamic import DynamicLazyConstraintsComponent from miplearn.components.objective import ObjectiveValueComponent from miplearn.components.primal import PrimalSolutionComponent -from miplearn.features import ModelFeaturesExtractor +from miplearn.features import FeaturesExtractor from miplearn.instance import Instance from miplearn.solvers import _RedirectOutput from miplearn.solvers.internal import InternalSolver @@ -174,9 +174,9 @@ class LearningSolver: assert isinstance(self.internal_solver, InternalSolver) self.internal_solver.set_instance(instance, model) - # Extract model features - extractor = ModelFeaturesExtractor(self.internal_solver) - instance.model_features = extractor.extract() + # Extract features + extractor = FeaturesExtractor(self.internal_solver) + instance.features = extractor.extract(instance) # Solve root LP relaxation if self.solve_lp: diff --git a/miplearn/types.py b/miplearn/types.py index 88d6283..acb8c71 100644 --- a/miplearn/types.py +++ b/miplearn/types.py @@ -2,7 +2,7 @@ # Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved. # Released under the modified BSD license. See COPYING.md for more details. -from typing import Optional, Dict, Callable, Any, Union, Tuple, List, Set +from typing import Optional, Dict, Callable, Any, Union, Tuple, List, Set, Hashable from mypy_extensions import TypedDict @@ -79,6 +79,8 @@ ConstraintFeatures = TypedDict( "RHS": float, "LHS": Dict[str, float], "Sense": str, + "Category": Optional[Hashable], + "User features": Optional[List[float]], }, total=False, ) diff --git a/tests/components/test_primal.py b/tests/components/test_primal.py index 6065b0d..8914da3 100644 --- a/tests/components/test_primal.py +++ b/tests/components/test_primal.py @@ -153,7 +153,7 @@ def test_predict() -> None: 2: [2.0, 0.0], }[index] ) - instance.model_features = { + instance.features = { "Variables": { "x": { 0: None, diff --git a/tests/solvers/test_learning_solver.py b/tests/solvers/test_learning_solver.py index 168f356..1758067 100644 --- a/tests/solvers/test_learning_solver.py +++ b/tests/solvers/test_learning_solver.py @@ -27,7 +27,7 @@ def test_learning_solver(): solver.solve(instance) - assert hasattr(instance, "model_features") + assert hasattr(instance, "features") data = instance.training_data[0] assert data["Solution"]["x"][0] == 1.0 diff --git a/tests/test_features.py b/tests/test_features.py index 883d92e..e5be4c2 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -3,37 +3,47 @@ # Released under the modified BSD license. See COPYING.md for more details. from miplearn import GurobiSolver -from miplearn.features import ModelFeaturesExtractor +from miplearn.features import FeaturesExtractor from tests.fixtures.knapsack import get_knapsack_instance def test_knapsack() -> None: for solver_factory in [GurobiSolver]: - # Initialize model, instance and internal solver solver = solver_factory() instance = get_knapsack_instance(solver) model = instance.to_model() solver.set_instance(instance, model) - - # Extract all model features - extractor = ModelFeaturesExtractor(solver) - features = extractor.extract() - - # Test constraint features - print(solver, features) + extractor = FeaturesExtractor(solver) + features = extractor.extract(instance) assert features["Variables"] == { "x": { - 0: None, - 1: None, - 2: None, - 3: None, + 0: { + "Category": "default", + "User features": [23.0, 505.0], + }, + 1: { + "Category": "default", + "User features": [26.0, 352.0], + }, + 2: { + "Category": "default", + "User features": [20.0, 458.0], + }, + 3: { + "Category": "default", + "User features": [18.0, 220.0], + }, } } - assert features["Constraints"]["eq_capacity"]["LHS"] == { - "x[0]": 23.0, - "x[1]": 26.0, - "x[2]": 20.0, - "x[3]": 18.0, + assert features["Constraints"]["eq_capacity"] == { + "LHS": { + "x[0]": 23.0, + "x[1]": 26.0, + "x[2]": 20.0, + "x[3]": 18.0, + }, + "Sense": "<", + "RHS": 67.0, + "Category": "eq_capacity", + "User features": [0.0], } - assert features["Constraints"]["eq_capacity"]["Sense"] == "<" - assert features["Constraints"]["eq_capacity"]["RHS"] == 67.0