diff --git a/miplearn/components/primal.py b/miplearn/components/primal.py index 5aa85db..dea597d 100644 --- a/miplearn/components/primal.py +++ b/miplearn/components/primal.py @@ -117,15 +117,12 @@ class PrimalSolutionComponent(Component): def predict(self, instance: Instance) -> Solution: assert len(instance.training_data) > 0 sample = instance.training_data[-1] - assert "LP solution" in sample - lp_solution = sample["LP solution"] - assert lp_solution is not None # Initialize empty solution solution: Solution = {} - for (var_name, var_dict) in lp_solution.items(): + for (var_name, var_dict) in instance.features["Variables"].items(): solution[var_name] = {} - for (idx, lp_value) in var_dict.items(): + for idx in var_dict.keys(): solution[var_name][idx] = None # Compute y_pred @@ -147,9 +144,9 @@ class PrimalSolutionComponent(Component): # Convert y_pred into solution category_offset: Dict[Hashable, int] = {cat: 0 for cat in x.keys()} - for (var_name, var_dict) in lp_solution.items(): - for (idx, lp_value) in var_dict.items(): - category = instance.get_variable_category(var_name, idx) + for (var_name, var_dict) in instance.features["Variables"].items(): + for (idx, var_features) in var_dict.items(): + category = var_features["Category"] offset = category_offset[category] category_offset[category] += 1 if y_pred[category][offset, 0]: @@ -211,10 +208,8 @@ class PrimalSolutionComponent(Component): instance: Any, sample: TrainingSample, ) -> Tuple[Dict, Dict]: - x: Dict = {} - y: Dict = {} if "Solution" not in sample: - return x, y + return {}, {} assert sample["Solution"] is not None return cast( Tuple[Dict, Dict], @@ -222,7 +217,6 @@ class PrimalSolutionComponent(Component): instance, sample, sample["Solution"], - extract_y=True, ), ) @@ -231,51 +225,43 @@ class PrimalSolutionComponent(Component): instance: Any, sample: TrainingSample, ) -> Dict: - return cast( - Dict, - PrimalSolutionComponent._extract( - instance, - sample, - instance.features["Variables"], - extract_y=False, - ), - ) + return cast(Dict, PrimalSolutionComponent._extract(instance, sample)) @staticmethod def _extract( instance: Any, sample: TrainingSample, - variables: Dict, - extract_y: bool, + solution: Optional[Dict] = None, ) -> Union[Dict, Tuple[Dict, Dict]]: x: Dict = {} y: Dict = {} - for (var, var_dict) in variables.items(): - for (idx, opt_value) in var_dict.items(): - if extract_y: + opt_value = 0.0 + for (var_name, var_dict) in instance.features["Variables"].items(): + for (idx, var_features) in var_dict.items(): + category = var_features["Category"] + if category is None: + continue + if solution is not None: + opt_value = solution[var_name][idx] assert opt_value is not None assert 0.0 - 1e-5 <= opt_value <= 1.0 + 1e-5, ( - f"Variable {var} has non-binary value {opt_value} in the " + f"Variable {var_name} has non-binary value {opt_value} in the " "optimal solution. Predicting values of non-binary " "variables is not currently supported. Please set its " "category to None." ) - category = instance.get_variable_category(var, idx) - if category is None: - continue if category not in x.keys(): x[category] = [] y[category] = [] - features: Any = instance.get_variable_features(var, idx) - assert isinstance(features, list) + features = var_features["User features"] if "LP solution" in sample and sample["LP solution"] is not None: - lp_value = sample["LP solution"][var][idx] + lp_value = sample["LP solution"][var_name][idx] if lp_value is not None: - features += [sample["LP solution"][var][idx]] + features += [sample["LP solution"][var_name][idx]] x[category] += [features] - if extract_y: + if solution is not None: y[category] += [[opt_value < 0.5, opt_value >= 0.5]] - if extract_y: + if solution is not None: return x, y else: return x diff --git a/miplearn/types.py b/miplearn/types.py index acb8c71..1fc0a23 100644 --- a/miplearn/types.py +++ b/miplearn/types.py @@ -73,6 +73,15 @@ LearningSolveStats = TypedDict( total=False, ) +VariableFeatures = TypedDict( + "VariableFeatures", + { + "Category": Optional[Hashable], + "User features": Optional[List[float]], + }, + total=False, +) + ConstraintFeatures = TypedDict( "ConstraintFeatures", { @@ -88,7 +97,7 @@ ConstraintFeatures = TypedDict( ModelFeatures = TypedDict( "ModelFeatures", { - "Variables": Solution, + "Variables": Dict[str, Dict[VarIndex, VariableFeatures]], "Constraints": Dict[str, ConstraintFeatures], }, total=False, diff --git a/tests/components/test_primal.py b/tests/components/test_primal.py index 8914da3..dc8a48d 100644 --- a/tests/components/test_primal.py +++ b/tests/components/test_primal.py @@ -16,22 +16,27 @@ from miplearn.types import TrainingSample def test_xy_sample_with_lp_solution() -> None: instance = cast(Instance, Mock(spec=Instance)) - instance.get_variable_category = Mock( # type: ignore - side_effect=lambda var_name, index: { - 0: "default", - 1: None, - 2: "default", - 3: "default", - }[index] - ) - instance.get_variable_features = Mock( # type: ignore - side_effect=lambda var, index: { - 0: [0.0, 0.0], - 1: [0.0, 1.0], - 2: [1.0, 0.0], - 3: [1.0, 1.0], - }[index] - ) + instance.features = { + "Variables": { + "x": { + 0: { + "Category": "default", + "User features": [0.0, 0.0], + }, + 1: { + "Category": None, + }, + 2: { + "Category": "default", + "User features": [1.0, 0.0], + }, + 3: { + "Category": "default", + "User features": [1.0, 1.0], + }, + } + } + } sample: TrainingSample = { "Solution": { "x": { @@ -78,22 +83,27 @@ def test_xy_sample_with_lp_solution() -> None: def test_xy_sample_without_lp_solution() -> None: comp = PrimalSolutionComponent() instance = cast(Instance, Mock(spec=Instance)) - instance.get_variable_category = Mock( # type: ignore - side_effect=lambda var_name, index: { - 0: "default", - 1: None, - 2: "default", - 3: "default", - }[index] - ) - instance.get_variable_features = Mock( # type: ignore - side_effect=lambda var, index: { - 0: [0.0, 0.0], - 1: [0.0, 1.0], - 2: [1.0, 0.0], - 3: [1.0, 1.0], - }[index] - ) + instance.features = { + "Variables": { + "x": { + 0: { + "Category": "default", + "User features": [0.0, 0.0], + }, + 1: { + "Category": None, + }, + 2: { + "Category": "default", + "User features": [1.0, 0.0], + }, + 3: { + "Category": "default", + "User features": [1.0, 1.0], + }, + } + } + } sample: TrainingSample = { "Solution": { "x": { @@ -143,22 +153,21 @@ def test_predict() -> None: thr = Mock(spec=Threshold) thr.predict = Mock(return_value=[0.75, 0.75]) instance = cast(Instance, Mock(spec=Instance)) - instance.get_variable_category = Mock( # type: ignore - return_value="default", - ) - instance.get_variable_features = Mock( # type: ignore - side_effect=lambda var, index: { - 0: [0.0, 0.0], - 1: [0.0, 2.0], - 2: [2.0, 0.0], - }[index] - ) instance.features = { "Variables": { "x": { - 0: None, - 1: None, - 2: None, + 0: { + "Category": "default", + "User features": [0.0, 0.0], + }, + 1: { + "Category": "default", + "User features": [0.0, 2.0], + }, + 2: { + "Category": "default", + "User features": [2.0, 0.0], + }, } } }