From 9ca4cc3c24f405929c85d0e0704d9750ef1dcb89 Mon Sep 17 00:00:00 2001 From: "Alinson S. Xavier" Date: Sat, 10 Apr 2021 19:11:38 -0500 Subject: [PATCH] Include additional features in instance.features --- miplearn/features.py | 55 +++++++++++++++--------------------------- tests/test_features.py | 49 ++++++++++++++++++++++++++++--------- 2 files changed, 57 insertions(+), 47 deletions(-) diff --git a/miplearn/features.py b/miplearn/features.py index 342aa0e..7fa4374 100644 --- a/miplearn/features.py +++ b/miplearn/features.py @@ -84,16 +84,14 @@ class FeaturesExtractor: self.solver = internal_solver def extract(self, instance: "Instance") -> None: - instance.features.variables = self._extract_variables(instance) - instance.features.constraints = self._extract_constraints(instance) - instance.features.instance = self._extract_instance(instance, instance.features) - - def _extract_variables( - self, - instance: "Instance", - ) -> Dict[VariableName, Variable]: - result: Dict[VariableName, Variable] = {} - for var_name in self.solver.get_variable_names(): + instance.features.variables = self.solver.get_variables() + instance.features.constraints = self.solver.get_constraints() + self._extract_user_features_vars(instance) + self._extract_user_features_constrs(instance) + self._extract_user_features_instance(instance) + + def _extract_user_features_vars(self, instance: "Instance"): + for (var_name, var) in instance.features.variables.items(): user_features: Optional[List[float]] = None category: Category = instance.get_variable_category(var_name) if category is not None: @@ -115,20 +113,12 @@ class FeaturesExtractor: f"Found {type(v).__name__} instead " f"for var={var_name}." ) - result[var_name] = Variable( - category=category, - user_features=user_features, - ) - return result + var.category = category + var.user_features = user_features - def _extract_constraints( - self, - instance: "Instance", - ) -> Dict[str, Constraint]: + def _extract_user_features_constrs(self, instance: "Instance"): has_static_lazy = instance.has_static_lazy_constraints() - constraints = self.solver.get_constraints() - - for (cid, constr) in constraints.items(): + for (cid, constr) in instance.features.constraints.items(): user_features = None category = instance.get_constraint_category(cid) if category is not None: @@ -147,18 +137,13 @@ class FeaturesExtractor: f"Constraint features must be a list of floats. " f"Found {type(user_features[0]).__name__} instead for cid={cid}." ) - constraints[cid].category = category - constraints[cid].user_features = user_features if has_static_lazy: - constraints[cid].lazy = instance.is_constraint_lazy(cid) - return constraints - - @staticmethod - def _extract_instance( - instance: "Instance", - features: Features, - ) -> InstanceFeatures: - assert features.constraints is not None + constr.lazy = instance.is_constraint_lazy(cid) + constr.category = category + constr.user_features = user_features + + def _extract_user_features_instance(self, instance: "Instance"): + assert instance.features.constraints is not None user_features = instance.get_instance_features() if isinstance(user_features, np.ndarray): user_features = user_features.tolist() @@ -172,10 +157,10 @@ class FeaturesExtractor: f"Found {type(v).__name__} instead." ) lazy_count = 0 - for (cid, cdict) in features.constraints.items(): + for (cid, cdict) in instance.features.constraints.items(): if cdict.lazy: lazy_count += 1 - return InstanceFeatures( + instance.features.instance = InstanceFeatures( user_features=user_features, lazy_constraint_count=lazy_count, ) diff --git a/tests/test_features.py b/tests/test_features.py index e36d423..2205c6b 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -9,34 +9,55 @@ from miplearn.features import ( Constraint, ) from miplearn.solvers.gurobi import GurobiSolver +from miplearn.solvers.tests import assert_equals def test_knapsack() -> None: - for solver_factory in [GurobiSolver]: - solver = solver_factory() - instance = solver.build_test_instance_knapsack() - model = instance.to_model() - solver.set_instance(instance, model) - FeaturesExtractor(solver).extract(instance) - assert instance.features.variables == { + solver = GurobiSolver() + instance = solver.build_test_instance_knapsack() + model = instance.to_model() + solver.set_instance(instance, model) + FeaturesExtractor(solver).extract(instance) + assert_equals( + instance.features.variables, + { "x[0]": Variable( category="default", + lower_bound=0.0, + obj_coeff=505.0, + type="B", + upper_bound=1.0, user_features=[23.0, 505.0], ), "x[1]": Variable( category="default", + lower_bound=0.0, + obj_coeff=352.0, + type="B", + upper_bound=1.0, user_features=[26.0, 352.0], ), "x[2]": Variable( category="default", + lower_bound=0.0, + obj_coeff=458.0, + type="B", + upper_bound=1.0, user_features=[20.0, 458.0], ), "x[3]": Variable( category="default", + lower_bound=0.0, + obj_coeff=220.0, + type="B", + upper_bound=1.0, user_features=[18.0, 220.0], ), - } - assert instance.features.constraints == { + }, + ) + assert_equals( + instance.features.constraints, + { "eq_capacity": Constraint( lhs={ "x[0]": 23.0, @@ -50,8 +71,12 @@ def test_knapsack() -> None: category="eq_capacity", user_features=[0.0], ) - } - assert instance.features.instance == InstanceFeatures( + }, + ) + assert_equals( + instance.features.instance, + InstanceFeatures( user_features=[67.0, 21.75], lazy_constraint_count=0, - ) + ), + )