mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Refactor primal
This commit is contained in:
@@ -99,12 +99,6 @@ class PrimalSolutionComponent(Component):
|
||||
stats["Primal: zero"] = self._n_zero
|
||||
stats["Primal: one"] = self._n_one
|
||||
|
||||
def x(
|
||||
self,
|
||||
instances: Union[List[str], List[Instance]],
|
||||
) -> Dict[Hashable, np.ndarray]:
|
||||
return self._build_x_y_dict(instances, self._extract_variable_features)
|
||||
|
||||
def fit_xy(
|
||||
self,
|
||||
x: Dict[str, np.ndarray],
|
||||
@@ -133,7 +127,7 @@ class PrimalSolutionComponent(Component):
|
||||
solution[var_name][idx] = None
|
||||
|
||||
# Compute y_pred
|
||||
x = self.x([instance])
|
||||
x = self.x_sample(instance, sample)
|
||||
y_pred = {}
|
||||
for category in x.keys():
|
||||
assert category in self.classifiers, (
|
||||
@@ -210,85 +204,6 @@ class PrimalSolutionComponent(Component):
|
||||
)
|
||||
return ev
|
||||
|
||||
@staticmethod
|
||||
def _build_x_y_dict(
|
||||
instances: Union[List[str], List[Instance]],
|
||||
extract: Callable[
|
||||
[
|
||||
Instance,
|
||||
TrainingSample,
|
||||
str,
|
||||
VarIndex,
|
||||
Optional[float],
|
||||
],
|
||||
Union[List[bool], List[float]],
|
||||
],
|
||||
) -> Dict[Hashable, np.ndarray]:
|
||||
result: Dict[Hashable, List] = {}
|
||||
for instance in InstanceIterator(instances):
|
||||
assert isinstance(instance, Instance)
|
||||
for sample in instance.training_data:
|
||||
# Skip training samples without solution
|
||||
if "LP solution" not in sample:
|
||||
continue
|
||||
if sample["LP solution"] is None:
|
||||
continue
|
||||
|
||||
# Iterate over all variables
|
||||
for (var, var_dict) in sample["LP solution"].items():
|
||||
for (idx, lp_value) in var_dict.items():
|
||||
category = instance.get_variable_category(var, idx)
|
||||
if category is None:
|
||||
continue
|
||||
if category not in result:
|
||||
result[category] = []
|
||||
result[category] += [
|
||||
extract(
|
||||
instance,
|
||||
sample,
|
||||
var,
|
||||
idx,
|
||||
lp_value,
|
||||
)
|
||||
]
|
||||
|
||||
# Convert result to numpy arrays and return
|
||||
return {c: np.array(ft) for (c, ft) in result.items()}
|
||||
|
||||
@staticmethod
|
||||
def _extract_variable_features(
|
||||
instance: Instance,
|
||||
sample: TrainingSample,
|
||||
var: str,
|
||||
idx: VarIndex,
|
||||
lp_value: Optional[float],
|
||||
) -> Union[List[bool], List[float]]:
|
||||
features = instance.get_variable_features(var, idx)
|
||||
if lp_value is None:
|
||||
return features
|
||||
else:
|
||||
return features + [lp_value]
|
||||
|
||||
@staticmethod
|
||||
def _extract_variable_labels(
|
||||
instance: Instance,
|
||||
sample: TrainingSample,
|
||||
var: str,
|
||||
idx: VarIndex,
|
||||
lp_value: Optional[float],
|
||||
) -> Union[List[bool], List[float]]:
|
||||
assert "Solution" in sample
|
||||
solution = sample["Solution"]
|
||||
assert solution is not None
|
||||
opt_value = solution[var][idx]
|
||||
assert opt_value is not None
|
||||
assert 0.0 - 1e-5 <= opt_value <= 1.0 + 1e-5, (
|
||||
f"Variable {var} has non-binary value {opt_value} in the optimal solution. "
|
||||
f"Predicting values of non-binary variables is not currently supported. "
|
||||
f"Please set its category to None."
|
||||
)
|
||||
return [opt_value < 0.5, opt_value > 0.5]
|
||||
|
||||
@staticmethod
|
||||
def xy_sample(
|
||||
instance: Any,
|
||||
@@ -322,3 +237,27 @@ class PrimalSolutionComponent(Component):
|
||||
x[category] += [features]
|
||||
y[category] += [[opt_value < 0.5, opt_value >= 0.5]]
|
||||
return x, y
|
||||
|
||||
@staticmethod
|
||||
def x_sample(
|
||||
instance: Any,
|
||||
sample: TrainingSample,
|
||||
) -> Dict:
|
||||
x: Dict = {}
|
||||
for (var, var_dict) in instance.model_features["Variables"].items():
|
||||
for idx in var_dict.keys():
|
||||
category = instance.get_variable_category(var, idx)
|
||||
if category is None:
|
||||
continue
|
||||
if category not in x.keys():
|
||||
x[category] = []
|
||||
features: Any = instance.get_variable_features(var, idx)
|
||||
assert isinstance(features, list)
|
||||
if "LP solution" in sample and sample["LP solution"] is not None:
|
||||
lp_value = sample["LP solution"][var][idx]
|
||||
if lp_value is not None:
|
||||
features += [sample["LP solution"][var][idx]]
|
||||
x[category] += [features]
|
||||
for category in x.keys():
|
||||
x[category] = np.array(x[category])
|
||||
return x
|
||||
|
||||
@@ -15,7 +15,6 @@ from miplearn.types import TrainingSample
|
||||
|
||||
|
||||
def test_xy_sample_with_lp_solution() -> None:
|
||||
comp = PrimalSolutionComponent()
|
||||
instance = cast(Instance, Mock(spec=Instance))
|
||||
instance.get_variable_category = Mock( # type: ignore
|
||||
side_effect=lambda var_name, index: {
|
||||
@@ -131,8 +130,6 @@ def test_xy_sample_without_lp_solution() -> None:
|
||||
|
||||
|
||||
def test_predict() -> None:
|
||||
comp = PrimalSolutionComponent()
|
||||
|
||||
clf = Mock(spec=Classifier)
|
||||
clf.predict_proba = Mock(
|
||||
return_value=np.array(
|
||||
@@ -143,12 +140,8 @@ def test_predict() -> None:
|
||||
]
|
||||
)
|
||||
)
|
||||
comp.classifiers = {"default": clf}
|
||||
|
||||
thr = Mock(spec=Threshold)
|
||||
thr.predict = Mock(return_value=[0.75, 0.75])
|
||||
comp.thresholds = {"default": thr}
|
||||
|
||||
instance = cast(Instance, Mock(spec=Instance))
|
||||
instance.get_variable_category = Mock( # type: ignore
|
||||
return_value="default",
|
||||
@@ -160,6 +153,15 @@ def test_predict() -> None:
|
||||
2: [2.0, 0.0],
|
||||
}[index]
|
||||
)
|
||||
instance.model_features = {
|
||||
"Variables": {
|
||||
"x": {
|
||||
0: None,
|
||||
1: None,
|
||||
2: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
instance.training_data = [
|
||||
{
|
||||
"LP solution": {
|
||||
@@ -171,16 +173,23 @@ def test_predict() -> None:
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
x = comp.x([instance])
|
||||
x = {
|
||||
"default": np.array(
|
||||
[
|
||||
[0.0, 0.0, 0.1],
|
||||
[0.0, 2.0, 0.5],
|
||||
[2.0, 0.0, 0.9],
|
||||
]
|
||||
)
|
||||
}
|
||||
comp = PrimalSolutionComponent()
|
||||
comp.classifiers = {"default": clf}
|
||||
comp.thresholds = {"default": thr}
|
||||
solution_actual = comp.predict(instance)
|
||||
|
||||
# Should ask for probabilities and thresholds
|
||||
clf.predict_proba.assert_called_once()
|
||||
thr.predict.assert_called_once()
|
||||
assert_array_equal(x["default"], clf.predict_proba.call_args[0][0])
|
||||
assert_array_equal(x["default"], thr.predict.call_args[0][0])
|
||||
|
||||
assert solution_actual == {
|
||||
"x": {
|
||||
0: 0.0,
|
||||
|
||||
Reference in New Issue
Block a user