mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Refactor primal
This commit is contained in:
@@ -99,12 +99,6 @@ class PrimalSolutionComponent(Component):
|
|||||||
stats["Primal: zero"] = self._n_zero
|
stats["Primal: zero"] = self._n_zero
|
||||||
stats["Primal: one"] = self._n_one
|
stats["Primal: one"] = self._n_one
|
||||||
|
|
||||||
def x(
|
|
||||||
self,
|
|
||||||
instances: Union[List[str], List[Instance]],
|
|
||||||
) -> Dict[Hashable, np.ndarray]:
|
|
||||||
return self._build_x_y_dict(instances, self._extract_variable_features)
|
|
||||||
|
|
||||||
def fit_xy(
|
def fit_xy(
|
||||||
self,
|
self,
|
||||||
x: Dict[str, np.ndarray],
|
x: Dict[str, np.ndarray],
|
||||||
@@ -133,7 +127,7 @@ class PrimalSolutionComponent(Component):
|
|||||||
solution[var_name][idx] = None
|
solution[var_name][idx] = None
|
||||||
|
|
||||||
# Compute y_pred
|
# Compute y_pred
|
||||||
x = self.x([instance])
|
x = self.x_sample(instance, sample)
|
||||||
y_pred = {}
|
y_pred = {}
|
||||||
for category in x.keys():
|
for category in x.keys():
|
||||||
assert category in self.classifiers, (
|
assert category in self.classifiers, (
|
||||||
@@ -210,85 +204,6 @@ class PrimalSolutionComponent(Component):
|
|||||||
)
|
)
|
||||||
return ev
|
return ev
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _build_x_y_dict(
|
|
||||||
instances: Union[List[str], List[Instance]],
|
|
||||||
extract: Callable[
|
|
||||||
[
|
|
||||||
Instance,
|
|
||||||
TrainingSample,
|
|
||||||
str,
|
|
||||||
VarIndex,
|
|
||||||
Optional[float],
|
|
||||||
],
|
|
||||||
Union[List[bool], List[float]],
|
|
||||||
],
|
|
||||||
) -> Dict[Hashable, np.ndarray]:
|
|
||||||
result: Dict[Hashable, List] = {}
|
|
||||||
for instance in InstanceIterator(instances):
|
|
||||||
assert isinstance(instance, Instance)
|
|
||||||
for sample in instance.training_data:
|
|
||||||
# Skip training samples without solution
|
|
||||||
if "LP solution" not in sample:
|
|
||||||
continue
|
|
||||||
if sample["LP solution"] is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Iterate over all variables
|
|
||||||
for (var, var_dict) in sample["LP solution"].items():
|
|
||||||
for (idx, lp_value) in var_dict.items():
|
|
||||||
category = instance.get_variable_category(var, idx)
|
|
||||||
if category is None:
|
|
||||||
continue
|
|
||||||
if category not in result:
|
|
||||||
result[category] = []
|
|
||||||
result[category] += [
|
|
||||||
extract(
|
|
||||||
instance,
|
|
||||||
sample,
|
|
||||||
var,
|
|
||||||
idx,
|
|
||||||
lp_value,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
# Convert result to numpy arrays and return
|
|
||||||
return {c: np.array(ft) for (c, ft) in result.items()}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _extract_variable_features(
|
|
||||||
instance: Instance,
|
|
||||||
sample: TrainingSample,
|
|
||||||
var: str,
|
|
||||||
idx: VarIndex,
|
|
||||||
lp_value: Optional[float],
|
|
||||||
) -> Union[List[bool], List[float]]:
|
|
||||||
features = instance.get_variable_features(var, idx)
|
|
||||||
if lp_value is None:
|
|
||||||
return features
|
|
||||||
else:
|
|
||||||
return features + [lp_value]
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _extract_variable_labels(
|
|
||||||
instance: Instance,
|
|
||||||
sample: TrainingSample,
|
|
||||||
var: str,
|
|
||||||
idx: VarIndex,
|
|
||||||
lp_value: Optional[float],
|
|
||||||
) -> Union[List[bool], List[float]]:
|
|
||||||
assert "Solution" in sample
|
|
||||||
solution = sample["Solution"]
|
|
||||||
assert solution is not None
|
|
||||||
opt_value = solution[var][idx]
|
|
||||||
assert opt_value is not None
|
|
||||||
assert 0.0 - 1e-5 <= opt_value <= 1.0 + 1e-5, (
|
|
||||||
f"Variable {var} has non-binary value {opt_value} in the optimal solution. "
|
|
||||||
f"Predicting values of non-binary variables is not currently supported. "
|
|
||||||
f"Please set its category to None."
|
|
||||||
)
|
|
||||||
return [opt_value < 0.5, opt_value > 0.5]
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def xy_sample(
|
def xy_sample(
|
||||||
instance: Any,
|
instance: Any,
|
||||||
@@ -322,3 +237,27 @@ class PrimalSolutionComponent(Component):
|
|||||||
x[category] += [features]
|
x[category] += [features]
|
||||||
y[category] += [[opt_value < 0.5, opt_value >= 0.5]]
|
y[category] += [[opt_value < 0.5, opt_value >= 0.5]]
|
||||||
return x, y
|
return x, y
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def x_sample(
|
||||||
|
instance: Any,
|
||||||
|
sample: TrainingSample,
|
||||||
|
) -> Dict:
|
||||||
|
x: Dict = {}
|
||||||
|
for (var, var_dict) in instance.model_features["Variables"].items():
|
||||||
|
for idx in var_dict.keys():
|
||||||
|
category = instance.get_variable_category(var, idx)
|
||||||
|
if category is None:
|
||||||
|
continue
|
||||||
|
if category not in x.keys():
|
||||||
|
x[category] = []
|
||||||
|
features: Any = instance.get_variable_features(var, idx)
|
||||||
|
assert isinstance(features, list)
|
||||||
|
if "LP solution" in sample and sample["LP solution"] is not None:
|
||||||
|
lp_value = sample["LP solution"][var][idx]
|
||||||
|
if lp_value is not None:
|
||||||
|
features += [sample["LP solution"][var][idx]]
|
||||||
|
x[category] += [features]
|
||||||
|
for category in x.keys():
|
||||||
|
x[category] = np.array(x[category])
|
||||||
|
return x
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ from miplearn.types import TrainingSample
|
|||||||
|
|
||||||
|
|
||||||
def test_xy_sample_with_lp_solution() -> None:
|
def test_xy_sample_with_lp_solution() -> None:
|
||||||
comp = PrimalSolutionComponent()
|
|
||||||
instance = cast(Instance, Mock(spec=Instance))
|
instance = cast(Instance, Mock(spec=Instance))
|
||||||
instance.get_variable_category = Mock( # type: ignore
|
instance.get_variable_category = Mock( # type: ignore
|
||||||
side_effect=lambda var_name, index: {
|
side_effect=lambda var_name, index: {
|
||||||
@@ -131,8 +130,6 @@ def test_xy_sample_without_lp_solution() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_predict() -> None:
|
def test_predict() -> None:
|
||||||
comp = PrimalSolutionComponent()
|
|
||||||
|
|
||||||
clf = Mock(spec=Classifier)
|
clf = Mock(spec=Classifier)
|
||||||
clf.predict_proba = Mock(
|
clf.predict_proba = Mock(
|
||||||
return_value=np.array(
|
return_value=np.array(
|
||||||
@@ -143,12 +140,8 @@ def test_predict() -> None:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
comp.classifiers = {"default": clf}
|
|
||||||
|
|
||||||
thr = Mock(spec=Threshold)
|
thr = Mock(spec=Threshold)
|
||||||
thr.predict = Mock(return_value=[0.75, 0.75])
|
thr.predict = Mock(return_value=[0.75, 0.75])
|
||||||
comp.thresholds = {"default": thr}
|
|
||||||
|
|
||||||
instance = cast(Instance, Mock(spec=Instance))
|
instance = cast(Instance, Mock(spec=Instance))
|
||||||
instance.get_variable_category = Mock( # type: ignore
|
instance.get_variable_category = Mock( # type: ignore
|
||||||
return_value="default",
|
return_value="default",
|
||||||
@@ -160,6 +153,15 @@ def test_predict() -> None:
|
|||||||
2: [2.0, 0.0],
|
2: [2.0, 0.0],
|
||||||
}[index]
|
}[index]
|
||||||
)
|
)
|
||||||
|
instance.model_features = {
|
||||||
|
"Variables": {
|
||||||
|
"x": {
|
||||||
|
0: None,
|
||||||
|
1: None,
|
||||||
|
2: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
instance.training_data = [
|
instance.training_data = [
|
||||||
{
|
{
|
||||||
"LP solution": {
|
"LP solution": {
|
||||||
@@ -171,16 +173,23 @@ def test_predict() -> None:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
x = {
|
||||||
x = comp.x([instance])
|
"default": np.array(
|
||||||
|
[
|
||||||
|
[0.0, 0.0, 0.1],
|
||||||
|
[0.0, 2.0, 0.5],
|
||||||
|
[2.0, 0.0, 0.9],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
comp = PrimalSolutionComponent()
|
||||||
|
comp.classifiers = {"default": clf}
|
||||||
|
comp.thresholds = {"default": thr}
|
||||||
solution_actual = comp.predict(instance)
|
solution_actual = comp.predict(instance)
|
||||||
|
|
||||||
# Should ask for probabilities and thresholds
|
|
||||||
clf.predict_proba.assert_called_once()
|
clf.predict_proba.assert_called_once()
|
||||||
thr.predict.assert_called_once()
|
thr.predict.assert_called_once()
|
||||||
assert_array_equal(x["default"], clf.predict_proba.call_args[0][0])
|
assert_array_equal(x["default"], clf.predict_proba.call_args[0][0])
|
||||||
assert_array_equal(x["default"], thr.predict.call_args[0][0])
|
assert_array_equal(x["default"], thr.predict.call_args[0][0])
|
||||||
|
|
||||||
assert solution_actual == {
|
assert solution_actual == {
|
||||||
"x": {
|
"x": {
|
||||||
0: 0.0,
|
0: 0.0,
|
||||||
|
|||||||
Reference in New Issue
Block a user