From ec694647942894b0c929cd777a3943a34b20b9f7 Mon Sep 17 00:00:00 2001 From: "Alinson S. Xavier" Date: Tue, 30 Mar 2021 21:44:13 -0500 Subject: [PATCH] Refactor primal --- miplearn/components/primal.py | 111 +++++++------------------------- tests/components/test_primal.py | 33 ++++++---- 2 files changed, 46 insertions(+), 98 deletions(-) diff --git a/miplearn/components/primal.py b/miplearn/components/primal.py index 24948d2..80ab3de 100644 --- a/miplearn/components/primal.py +++ b/miplearn/components/primal.py @@ -99,12 +99,6 @@ class PrimalSolutionComponent(Component): stats["Primal: zero"] = self._n_zero stats["Primal: one"] = self._n_one - def x( - self, - instances: Union[List[str], List[Instance]], - ) -> Dict[Hashable, np.ndarray]: - return self._build_x_y_dict(instances, self._extract_variable_features) - def fit_xy( self, x: Dict[str, np.ndarray], @@ -133,7 +127,7 @@ class PrimalSolutionComponent(Component): solution[var_name][idx] = None # Compute y_pred - x = self.x([instance]) + x = self.x_sample(instance, sample) y_pred = {} for category in x.keys(): assert category in self.classifiers, ( @@ -210,85 +204,6 @@ class PrimalSolutionComponent(Component): ) return ev - @staticmethod - def _build_x_y_dict( - instances: Union[List[str], List[Instance]], - extract: Callable[ - [ - Instance, - TrainingSample, - str, - VarIndex, - Optional[float], - ], - Union[List[bool], List[float]], - ], - ) -> Dict[Hashable, np.ndarray]: - result: Dict[Hashable, List] = {} - for instance in InstanceIterator(instances): - assert isinstance(instance, Instance) - for sample in instance.training_data: - # Skip training samples without solution - if "LP solution" not in sample: - continue - if sample["LP solution"] is None: - continue - - # Iterate over all variables - for (var, var_dict) in sample["LP solution"].items(): - for (idx, lp_value) in var_dict.items(): - category = instance.get_variable_category(var, idx) - if category is None: - continue - if category not in result: - result[category] = [] - result[category] += [ - extract( - instance, - sample, - var, - idx, - lp_value, - ) - ] - - # Convert result to numpy arrays and return - return {c: np.array(ft) for (c, ft) in result.items()} - - @staticmethod - def _extract_variable_features( - instance: Instance, - sample: TrainingSample, - var: str, - idx: VarIndex, - lp_value: Optional[float], - ) -> Union[List[bool], List[float]]: - features = instance.get_variable_features(var, idx) - if lp_value is None: - return features - else: - return features + [lp_value] - - @staticmethod - def _extract_variable_labels( - instance: Instance, - sample: TrainingSample, - var: str, - idx: VarIndex, - lp_value: Optional[float], - ) -> Union[List[bool], List[float]]: - assert "Solution" in sample - solution = sample["Solution"] - assert solution is not None - opt_value = solution[var][idx] - assert opt_value is not None - assert 0.0 - 1e-5 <= opt_value <= 1.0 + 1e-5, ( - f"Variable {var} has non-binary value {opt_value} in the optimal solution. " - f"Predicting values of non-binary variables is not currently supported. " - f"Please set its category to None." - ) - return [opt_value < 0.5, opt_value > 0.5] - @staticmethod def xy_sample( instance: Any, @@ -322,3 +237,27 @@ class PrimalSolutionComponent(Component): x[category] += [features] y[category] += [[opt_value < 0.5, opt_value >= 0.5]] return x, y + + @staticmethod + def x_sample( + instance: Any, + sample: TrainingSample, + ) -> Dict: + x: Dict = {} + for (var, var_dict) in instance.model_features["Variables"].items(): + for idx in var_dict.keys(): + category = instance.get_variable_category(var, idx) + if category is None: + continue + if category not in x.keys(): + x[category] = [] + features: Any = instance.get_variable_features(var, idx) + assert isinstance(features, list) + if "LP solution" in sample and sample["LP solution"] is not None: + lp_value = sample["LP solution"][var][idx] + if lp_value is not None: + features += [sample["LP solution"][var][idx]] + x[category] += [features] + for category in x.keys(): + x[category] = np.array(x[category]) + return x diff --git a/tests/components/test_primal.py b/tests/components/test_primal.py index a768d41..6065b0d 100644 --- a/tests/components/test_primal.py +++ b/tests/components/test_primal.py @@ -15,7 +15,6 @@ from miplearn.types import TrainingSample def test_xy_sample_with_lp_solution() -> None: - comp = PrimalSolutionComponent() instance = cast(Instance, Mock(spec=Instance)) instance.get_variable_category = Mock( # type: ignore side_effect=lambda var_name, index: { @@ -131,8 +130,6 @@ def test_xy_sample_without_lp_solution() -> None: def test_predict() -> None: - comp = PrimalSolutionComponent() - clf = Mock(spec=Classifier) clf.predict_proba = Mock( return_value=np.array( @@ -143,12 +140,8 @@ def test_predict() -> None: ] ) ) - comp.classifiers = {"default": clf} - thr = Mock(spec=Threshold) thr.predict = Mock(return_value=[0.75, 0.75]) - comp.thresholds = {"default": thr} - instance = cast(Instance, Mock(spec=Instance)) instance.get_variable_category = Mock( # type: ignore return_value="default", @@ -160,6 +153,15 @@ def test_predict() -> None: 2: [2.0, 0.0], }[index] ) + instance.model_features = { + "Variables": { + "x": { + 0: None, + 1: None, + 2: None, + } + } + } instance.training_data = [ { "LP solution": { @@ -171,16 +173,23 @@ def test_predict() -> None: } } ] - - x = comp.x([instance]) + x = { + "default": np.array( + [ + [0.0, 0.0, 0.1], + [0.0, 2.0, 0.5], + [2.0, 0.0, 0.9], + ] + ) + } + comp = PrimalSolutionComponent() + comp.classifiers = {"default": clf} + comp.thresholds = {"default": thr} solution_actual = comp.predict(instance) - - # Should ask for probabilities and thresholds clf.predict_proba.assert_called_once() thr.predict.assert_called_once() assert_array_equal(x["default"], clf.predict_proba.call_args[0][0]) assert_array_equal(x["default"], thr.predict.call_args[0][0]) - assert solution_actual == { "x": { 0: 0.0,