diff --git a/miplearn/components/component.py b/miplearn/components/component.py index 355d587..c2eea0d 100644 --- a/miplearn/components/component.py +++ b/miplearn/components/component.py @@ -132,17 +132,17 @@ class Component: return @staticmethod - def xy_sample( + def xy( features: Features, sample: TrainingSample, - ) -> Optional[Tuple[Dict, Dict]]: + ) -> Tuple[Dict, Dict]: """ Given a set of features and a training sample, returns a pair of x and y dictionaries containing, respectively, the matrices of ML features and the labels for the sample. If the training sample does not include label - information, returns None. + information, returns (x, {}). """ - return None + pass def xy_instances( self, @@ -153,7 +153,7 @@ class Component: for instance in InstanceIterator(instances): assert isinstance(instance, Instance) for sample in instance.training_data: - xy = self.xy_sample(instance.features, sample) + xy = self.xy(instance.features, sample) if xy is None: continue x_sample, y_sample = xy diff --git a/miplearn/components/lazy_static.py b/miplearn/components/lazy_static.py index 29432dc..c2006a0 100644 --- a/miplearn/components/lazy_static.py +++ b/miplearn/components/lazy_static.py @@ -207,12 +207,10 @@ class StaticLazyConstraintsComponent(Component): return result @staticmethod - def xy_sample( + def xy( features: Features, sample: TrainingSample, - ) -> Optional[Tuple[Dict, Dict]]: - if "LazyStatic: Enforced" not in sample: - return None + ) -> Tuple[Dict, Dict]: x: Dict = {} y: Dict = {} for (cid, cfeatures) in features["Constraints"].items(): @@ -225,8 +223,9 @@ class StaticLazyConstraintsComponent(Component): x[category] = [] y[category] = [] x[category] += [cfeatures["User features"]] - if cid in sample["LazyStatic: Enforced"]: - y[category] += [[False, True]] - else: - y[category] += [[True, False]] + if "LazyStatic: Enforced" in sample: + if cid in sample["LazyStatic: Enforced"]: + y[category] += [[False, True]] + else: + y[category] += [[True, False]] return x, y diff --git a/miplearn/components/objective.py b/miplearn/components/objective.py index 270a8f3..d928c5a 100644 --- a/miplearn/components/objective.py +++ b/miplearn/components/objective.py @@ -166,12 +166,10 @@ class ObjectiveValueComponent(Component): return ev @staticmethod - def xy_sample( + def xy( features: Features, sample: TrainingSample, - ) -> Optional[Tuple[Dict, Dict]]: - if "Lower bound" not in sample: - return None + ) -> Tuple[Dict, Dict]: f = features["Instance"]["User features"] if "LP value" in sample and sample["LP value"] is not None: f += [sample["LP value"]] @@ -179,8 +177,11 @@ class ObjectiveValueComponent(Component): "Lower bound": [f], "Upper bound": [f], } - y = { - "Lower bound": [[sample["Lower bound"]]], - "Upper bound": [[sample["Upper bound"]]], - } - return x, y + if "Lower bound" in sample: + y = { + "Lower bound": [[sample["Lower bound"]]], + "Upper bound": [[sample["Upper bound"]]], + } + return x, y + else: + return x, {} diff --git a/miplearn/components/primal.py b/miplearn/components/primal.py index 797a082..70b93a2 100644 --- a/miplearn/components/primal.py +++ b/miplearn/components/primal.py @@ -134,7 +134,7 @@ class PrimalSolutionComponent(Component): solution[var_name][idx] = None # Compute y_pred - x = self.x_sample(features, sample) + x, _ = self.xy(features, sample) y_pred = {} for category in x.keys(): assert category in self.classifiers, ( @@ -173,7 +173,7 @@ class PrimalSolutionComponent(Component): ): instance = instances[instance_idx] solution_actual = instance.training_data[0]["Solution"] - solution_pred = self.predict(instance) + solution_pred = self.predict(instance, instance.training_data[0]) vars_all, vars_one, vars_zero = set(), set(), set() pred_one_positive, pred_zero_positive = set(), set() @@ -213,33 +213,10 @@ class PrimalSolutionComponent(Component): return ev @staticmethod - def xy_sample( + def xy( features: Features, sample: TrainingSample, - ) -> Optional[Tuple[Dict, Dict]]: - if "Solution" not in sample: - return None - assert sample["Solution"] is not None - return cast( - Tuple[Dict, Dict], - PrimalSolutionComponent._extract(features, sample), - ) - - @staticmethod - def x_sample( - features: Features, - sample: TrainingSample, - ) -> Dict: - return cast( - Dict, - PrimalSolutionComponent._extract(features, sample), - ) - - @staticmethod - def _extract( - features: Features, - sample: TrainingSample, - ) -> Union[Dict, Tuple[Dict, Dict]]: + ) -> Tuple[Dict, Dict]: x: Dict = {} y: Dict = {} solution: Optional[Solution] = None @@ -271,7 +248,4 @@ class PrimalSolutionComponent(Component): "category to None." ) y[category] += [[opt_value < 0.5, opt_value >= 0.5]] - if solution is not None: - return x, y - else: - return x + return x, y diff --git a/tests/components/test_component.py b/tests/components/test_component.py index e03257a..3688aa9 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -57,7 +57,7 @@ def test_xy_instance(): instance_2 = Mock(spec=Instance) instance_2.training_data = ["s3"] instance_2.features = {} - comp.xy_sample = _xy_sample + comp.xy = _xy_sample x_expected = { "category_a": [ [1, 2, 3], diff --git a/tests/components/test_lazy_static.py b/tests/components/test_lazy_static.py index 8cd7991..b1001d1 100644 --- a/tests/components/test_lazy_static.py +++ b/tests/components/test_lazy_static.py @@ -286,7 +286,7 @@ def test_xy_sample() -> None: [False, True], ], } - xy = StaticLazyConstraintsComponent.xy_sample(features, sample) + xy = StaticLazyConstraintsComponent.xy(features, sample) assert xy is not None x_actual, y_actual = xy assert x_actual == x_expected diff --git a/tests/components/test_objective.py b/tests/components/test_objective.py index f9d7dcc..bbd93cd 100644 --- a/tests/components/test_objective.py +++ b/tests/components/test_objective.py @@ -125,7 +125,7 @@ def test_xy_sample_with_lp() -> None: "Lower bound": [[1.0]], "Upper bound": [[2.0]], } - xy = ObjectiveValueComponent.xy_sample(features, sample) + xy = ObjectiveValueComponent.xy(features, sample) assert xy is not None x_actual, y_actual = xy assert x_actual == x_expected @@ -150,7 +150,7 @@ def test_xy_sample_without_lp() -> None: "Lower bound": [[1.0]], "Upper bound": [[2.0]], } - xy = ObjectiveValueComponent.xy_sample(features, sample) + xy = ObjectiveValueComponent.xy(features, sample) assert xy is not None x_actual, y_actual = xy assert x_actual == x_expected diff --git a/tests/components/test_primal.py b/tests/components/test_primal.py index 6785843..4f7b620 100644 --- a/tests/components/test_primal.py +++ b/tests/components/test_primal.py @@ -8,15 +8,14 @@ import numpy as np from numpy.testing import assert_array_equal from scipy.stats import randint -from miplearn import Classifier, LearningSolver, GurobiSolver, GurobiPyomoSolver +from miplearn import Classifier, LearningSolver from miplearn.classifiers.threshold import Threshold from miplearn.components.primal import PrimalSolutionComponent from miplearn.problems.tsp import TravelingSalesmanGenerator from miplearn.types import TrainingSample, Features -from tests.fixtures.knapsack import get_knapsack_instance -def test_xy_sample_with_lp_solution() -> None: +def test_xy() -> None: features: Features = { "Variables": { "x": { @@ -70,14 +69,14 @@ def test_xy_sample_with_lp_solution() -> None: [True, False], ] } - xy = PrimalSolutionComponent.xy_sample(features, sample) + xy = PrimalSolutionComponent.xy(features, sample) assert xy is not None x_actual, y_actual = xy assert x_actual == x_expected assert y_actual == y_expected -def test_xy_sample_without_lp_solution() -> None: +def test_xy_without_lp_solution() -> None: features: Features = { "Variables": { "x": { @@ -123,7 +122,7 @@ def test_xy_sample_without_lp_solution() -> None: [True, False], ] } - xy = PrimalSolutionComponent.xy_sample(features, sample) + xy = PrimalSolutionComponent.xy(features, sample) assert xy is not None x_actual, y_actual = xy assert x_actual == x_expected @@ -170,7 +169,7 @@ def test_predict() -> None: } } } - x = PrimalSolutionComponent.x_sample(features, sample) + x, _ = PrimalSolutionComponent.xy(features, sample) comp = PrimalSolutionComponent() comp.classifiers = {"default": clf} comp.thresholds = {"default": thr}