mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Refactor PrimalSolutionComponent
This commit is contained in:
@@ -3,8 +3,7 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from typing import Union, Dict, Any
|
||||
from typing import Union, Dict, Callable, List, Hashable, Optional
|
||||
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
@@ -14,35 +13,46 @@ from miplearn.classifiers.adaptive import AdaptiveClassifier
|
||||
from miplearn.classifiers.threshold import MinPrecisionThreshold, Threshold
|
||||
from miplearn.components import classifier_evaluation_dict
|
||||
from miplearn.components.component import Component
|
||||
from miplearn.extractors import VariableFeaturesExtractor, SolutionExtractor, Extractor
|
||||
from miplearn.extractors import InstanceIterator
|
||||
from miplearn.instance import Instance
|
||||
from miplearn.types import TrainingSample, VarIndex, Solution
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PrimalSolutionComponent(Component):
|
||||
"""
|
||||
A component that predicts primal solutions.
|
||||
A component that predicts the optimal primal values for the binary decision
|
||||
variables.
|
||||
|
||||
In exact mode, predicted primal solutions are provided to the solver as MIP
|
||||
starts. In heuristic mode, this component fixes the decision variables to their
|
||||
predicted values.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
classifier: Classifier = AdaptiveClassifier(),
|
||||
classifier: Callable[[], Classifier] = lambda: AdaptiveClassifier(),
|
||||
mode: str = "exact",
|
||||
threshold: Union[float, Threshold] = MinPrecisionThreshold(0.98),
|
||||
threshold: Callable[[], Threshold] = lambda: MinPrecisionThreshold(
|
||||
[0.98, 0.98]
|
||||
),
|
||||
) -> None:
|
||||
assert mode in ["exact", "heuristic"]
|
||||
self.mode = mode
|
||||
self.classifiers: Dict[Any, Classifier] = {}
|
||||
self.thresholds: Dict[Any, Union[float, Threshold]] = {}
|
||||
self.threshold_prototype = threshold
|
||||
self.classifier_prototype = classifier
|
||||
self.classifiers: Dict[Hashable, Classifier] = {}
|
||||
self.thresholds: Dict[Hashable, Threshold] = {}
|
||||
self.threshold_factory = threshold
|
||||
self.classifier_factory = classifier
|
||||
|
||||
def before_solve(self, solver, instance, model):
|
||||
logger.info("Predicting primal solution...")
|
||||
solution = self.predict(instance)
|
||||
if self.mode == "heuristic":
|
||||
solver.internal_solver.fix(solution)
|
||||
else:
|
||||
solver.internal_solver.set_warm_start(solution)
|
||||
if len(self.thresholds) > 0:
|
||||
logger.info("Predicting primal solution...")
|
||||
solution = self.predict(instance)
|
||||
if self.mode == "heuristic":
|
||||
solver.internal_solver.fix(solution)
|
||||
else:
|
||||
solver.internal_solver.set_warm_start(solution)
|
||||
|
||||
def after_solve(
|
||||
self,
|
||||
@@ -54,79 +64,76 @@ class PrimalSolutionComponent(Component):
|
||||
):
|
||||
pass
|
||||
|
||||
def x(self, training_instances):
|
||||
return VariableFeaturesExtractor().extract(training_instances)
|
||||
def x(
|
||||
self,
|
||||
instances: Union[List[str], List[Instance]],
|
||||
) -> Dict[Hashable, np.ndarray]:
|
||||
return self._build_x_y_dict(instances, self._extract_variable_features)
|
||||
|
||||
def y(self, training_instances):
|
||||
return SolutionExtractor().extract(training_instances)
|
||||
def y(
|
||||
self,
|
||||
instances: Union[List[str], List[Instance]],
|
||||
) -> Dict[Hashable, np.ndarray]:
|
||||
return self._build_x_y_dict(instances, self._extract_variable_labels)
|
||||
|
||||
def fit(self, training_instances, n_jobs=1):
|
||||
logger.debug("Extracting features...")
|
||||
features = VariableFeaturesExtractor().extract(training_instances)
|
||||
solutions = SolutionExtractor().extract(training_instances)
|
||||
def fit(
|
||||
self,
|
||||
training_instances: Union[List[str], List[Instance]],
|
||||
n_jobs: int = 1,
|
||||
) -> None:
|
||||
x = self.x(training_instances)
|
||||
y = self.y(training_instances)
|
||||
for category in x.keys():
|
||||
clf = self.classifier_factory()
|
||||
thr = self.threshold_factory()
|
||||
clf.fit(x[category], y[category])
|
||||
thr.fit(clf, x[category], y[category])
|
||||
self.classifiers[category] = clf
|
||||
self.thresholds[category] = thr
|
||||
|
||||
for category in tqdm(
|
||||
features.keys(),
|
||||
desc="Fit (primal)",
|
||||
):
|
||||
x_train = features[category]
|
||||
for label in [0, 1]:
|
||||
y_train = solutions[category][:, label].astype(int)
|
||||
def predict(self, instance: Instance) -> Solution:
|
||||
assert len(instance.training_data) > 0
|
||||
sample = instance.training_data[-1]
|
||||
assert "LP solution" in sample
|
||||
lp_solution = sample["LP solution"]
|
||||
assert lp_solution is not None
|
||||
|
||||
# If all samples are either positive or negative, make constant
|
||||
# predictions
|
||||
y_avg = np.average(y_train)
|
||||
if y_avg < 0.001 or y_avg >= 0.999:
|
||||
self.classifiers[category, label] = round(y_avg)
|
||||
self.thresholds[category, label] = 0.50
|
||||
continue
|
||||
# Initialize empty solution
|
||||
solution: Solution = {}
|
||||
for (var_name, var_dict) in lp_solution.items():
|
||||
solution[var_name] = {}
|
||||
for (idx, lp_value) in var_dict.items():
|
||||
solution[var_name][idx] = None
|
||||
|
||||
# Create a copy of classifier prototype and train it
|
||||
if isinstance(self.classifier_prototype, list):
|
||||
clf = deepcopy(self.classifier_prototype[label])
|
||||
else:
|
||||
clf = deepcopy(self.classifier_prototype)
|
||||
clf.fit(x_train, y_train)
|
||||
# Compute y_pred
|
||||
x = self.x([instance])
|
||||
y_pred = {}
|
||||
for category in x.keys():
|
||||
assert category in self.classifiers, (
|
||||
f"Classifier for category {category} has not been trained. "
|
||||
f"Please call component.fit before component.predict."
|
||||
)
|
||||
proba = self.classifiers[category].predict_proba(x[category])
|
||||
thr = self.thresholds[category].predict(x[category])
|
||||
y_pred[category] = np.vstack(
|
||||
[
|
||||
proba[:, 0] > thr[0],
|
||||
proba[:, 1] > thr[1],
|
||||
]
|
||||
).T
|
||||
|
||||
# Find threshold (dynamic or static)
|
||||
if isinstance(self.threshold_prototype, Threshold):
|
||||
self.thresholds[category, label] = self.threshold_prototype.fit(
|
||||
clf,
|
||||
x_train,
|
||||
y_train,
|
||||
)
|
||||
else:
|
||||
self.thresholds[category, label] = deepcopy(
|
||||
self.threshold_prototype
|
||||
)
|
||||
# Convert y_pred into solution
|
||||
category_offset: Dict[Hashable, int] = {cat: 0 for cat in x.keys()}
|
||||
for (var_name, var_dict) in lp_solution.items():
|
||||
for (idx, lp_value) in var_dict.items():
|
||||
category = instance.get_variable_category(var_name, idx)
|
||||
offset = category_offset[category]
|
||||
category_offset[category] += 1
|
||||
if y_pred[category][offset, 0]:
|
||||
solution[var_name][idx] = 0.0
|
||||
if y_pred[category][offset, 1]:
|
||||
solution[var_name][idx] = 1.0
|
||||
|
||||
self.classifiers[category, label] = clf
|
||||
|
||||
def predict(self, instance):
|
||||
solution = {}
|
||||
x_test = VariableFeaturesExtractor().extract([instance])
|
||||
var_split = Extractor.split_variables(instance)
|
||||
for category in var_split.keys():
|
||||
n = len(var_split[category])
|
||||
for (i, (var, index)) in enumerate(var_split[category]):
|
||||
if var not in solution.keys():
|
||||
solution[var] = {}
|
||||
solution[var][index] = None
|
||||
for label in [0, 1]:
|
||||
if (category, label) not in self.classifiers.keys():
|
||||
continue
|
||||
clf = self.classifiers[category, label]
|
||||
if isinstance(clf, float) or isinstance(clf, int):
|
||||
ws = np.array([[1 - clf, clf] for _ in range(n)])
|
||||
else:
|
||||
ws = clf.predict_proba(x_test[category])
|
||||
assert ws.shape == (n, 2), "ws.shape should be (%d, 2) not %s" % (
|
||||
n,
|
||||
ws.shape,
|
||||
)
|
||||
for (i, (var, index)) in enumerate(var_split[category]):
|
||||
if ws[i, 1] >= self.thresholds[category, label]:
|
||||
solution[var][index] = label
|
||||
return solution
|
||||
|
||||
def evaluate(self, instances):
|
||||
@@ -175,3 +182,82 @@ class PrimalSolutionComponent(Component):
|
||||
tp_one, tn_one, fp_one, fn_one
|
||||
)
|
||||
return ev
|
||||
|
||||
@staticmethod
|
||||
def _build_x_y_dict(
|
||||
instances: Union[List[str], List[Instance]],
|
||||
extract: Callable[
|
||||
[
|
||||
Instance,
|
||||
TrainingSample,
|
||||
str,
|
||||
VarIndex,
|
||||
Optional[float],
|
||||
],
|
||||
Union[List[bool], List[float]],
|
||||
],
|
||||
) -> Dict[Hashable, np.ndarray]:
|
||||
result: Dict[Hashable, List] = {}
|
||||
for instance in InstanceIterator(instances):
|
||||
assert isinstance(instance, Instance)
|
||||
for sample in instance.training_data:
|
||||
# Skip training samples without solution
|
||||
if "LP solution" not in sample:
|
||||
continue
|
||||
if sample["LP solution"] is None:
|
||||
continue
|
||||
|
||||
# Iterate over all variables
|
||||
for (var, var_dict) in sample["LP solution"].items():
|
||||
for (idx, lp_value) in var_dict.items():
|
||||
category = instance.get_variable_category(var, idx)
|
||||
if category is None:
|
||||
continue
|
||||
if category not in result:
|
||||
result[category] = []
|
||||
result[category] += [
|
||||
extract(
|
||||
instance,
|
||||
sample,
|
||||
var,
|
||||
idx,
|
||||
lp_value,
|
||||
)
|
||||
]
|
||||
|
||||
# Convert result to numpy arrays and return
|
||||
return {c: np.array(ft) for (c, ft) in result.items()}
|
||||
|
||||
@staticmethod
|
||||
def _extract_variable_features(
|
||||
instance: Instance,
|
||||
sample: TrainingSample,
|
||||
var: str,
|
||||
idx: VarIndex,
|
||||
lp_value: Optional[float],
|
||||
) -> Union[List[bool], List[float]]:
|
||||
features = instance.get_variable_features(var, idx)
|
||||
if lp_value is None:
|
||||
return features
|
||||
else:
|
||||
return features + [lp_value]
|
||||
|
||||
@staticmethod
|
||||
def _extract_variable_labels(
|
||||
instance: Instance,
|
||||
sample: TrainingSample,
|
||||
var: str,
|
||||
idx: VarIndex,
|
||||
lp_value: Optional[float],
|
||||
) -> Union[List[bool], List[float]]:
|
||||
assert "Solution" in sample
|
||||
solution = sample["Solution"]
|
||||
assert solution is not None
|
||||
opt_value = solution[var][idx]
|
||||
assert opt_value is not None
|
||||
assert 0.0 <= opt_value <= 1.0, (
|
||||
f"Variable {var} has non-binary value {opt_value} in the optimal solution. "
|
||||
f"Predicting values of non-binary variables is not currently supported. "
|
||||
f"Please set its category to None."
|
||||
)
|
||||
return [opt_value < 0.5, opt_value > 0.5]
|
||||
|
||||
Reference in New Issue
Block a user