Refactor PrimalSolutionComponent

This commit is contained in:
2021-01-25 14:54:58 -06:00
parent f68cc5bd59
commit 3ab3bb3c1f
9 changed files with 501 additions and 233 deletions

View File

@@ -3,8 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
import logging
from copy import deepcopy
from typing import Union, Dict, Any
from typing import Union, Dict, Callable, List, Hashable, Optional
import numpy as np
from tqdm.auto import tqdm
@@ -14,35 +13,46 @@ from miplearn.classifiers.adaptive import AdaptiveClassifier
from miplearn.classifiers.threshold import MinPrecisionThreshold, Threshold
from miplearn.components import classifier_evaluation_dict
from miplearn.components.component import Component
from miplearn.extractors import VariableFeaturesExtractor, SolutionExtractor, Extractor
from miplearn.extractors import InstanceIterator
from miplearn.instance import Instance
from miplearn.types import TrainingSample, VarIndex, Solution
logger = logging.getLogger(__name__)
class PrimalSolutionComponent(Component):
"""
A component that predicts primal solutions.
A component that predicts the optimal primal values for the binary decision
variables.
In exact mode, predicted primal solutions are provided to the solver as MIP
starts. In heuristic mode, this component fixes the decision variables to their
predicted values.
"""
def __init__(
self,
classifier: Classifier = AdaptiveClassifier(),
classifier: Callable[[], Classifier] = lambda: AdaptiveClassifier(),
mode: str = "exact",
threshold: Union[float, Threshold] = MinPrecisionThreshold(0.98),
threshold: Callable[[], Threshold] = lambda: MinPrecisionThreshold(
[0.98, 0.98]
),
) -> None:
assert mode in ["exact", "heuristic"]
self.mode = mode
self.classifiers: Dict[Any, Classifier] = {}
self.thresholds: Dict[Any, Union[float, Threshold]] = {}
self.threshold_prototype = threshold
self.classifier_prototype = classifier
self.classifiers: Dict[Hashable, Classifier] = {}
self.thresholds: Dict[Hashable, Threshold] = {}
self.threshold_factory = threshold
self.classifier_factory = classifier
def before_solve(self, solver, instance, model):
logger.info("Predicting primal solution...")
solution = self.predict(instance)
if self.mode == "heuristic":
solver.internal_solver.fix(solution)
else:
solver.internal_solver.set_warm_start(solution)
if len(self.thresholds) > 0:
logger.info("Predicting primal solution...")
solution = self.predict(instance)
if self.mode == "heuristic":
solver.internal_solver.fix(solution)
else:
solver.internal_solver.set_warm_start(solution)
def after_solve(
self,
@@ -54,79 +64,76 @@ class PrimalSolutionComponent(Component):
):
pass
def x(self, training_instances):
return VariableFeaturesExtractor().extract(training_instances)
def x(
self,
instances: Union[List[str], List[Instance]],
) -> Dict[Hashable, np.ndarray]:
return self._build_x_y_dict(instances, self._extract_variable_features)
def y(self, training_instances):
return SolutionExtractor().extract(training_instances)
def y(
self,
instances: Union[List[str], List[Instance]],
) -> Dict[Hashable, np.ndarray]:
return self._build_x_y_dict(instances, self._extract_variable_labels)
def fit(self, training_instances, n_jobs=1):
logger.debug("Extracting features...")
features = VariableFeaturesExtractor().extract(training_instances)
solutions = SolutionExtractor().extract(training_instances)
def fit(
self,
training_instances: Union[List[str], List[Instance]],
n_jobs: int = 1,
) -> None:
x = self.x(training_instances)
y = self.y(training_instances)
for category in x.keys():
clf = self.classifier_factory()
thr = self.threshold_factory()
clf.fit(x[category], y[category])
thr.fit(clf, x[category], y[category])
self.classifiers[category] = clf
self.thresholds[category] = thr
for category in tqdm(
features.keys(),
desc="Fit (primal)",
):
x_train = features[category]
for label in [0, 1]:
y_train = solutions[category][:, label].astype(int)
def predict(self, instance: Instance) -> Solution:
assert len(instance.training_data) > 0
sample = instance.training_data[-1]
assert "LP solution" in sample
lp_solution = sample["LP solution"]
assert lp_solution is not None
# If all samples are either positive or negative, make constant
# predictions
y_avg = np.average(y_train)
if y_avg < 0.001 or y_avg >= 0.999:
self.classifiers[category, label] = round(y_avg)
self.thresholds[category, label] = 0.50
continue
# Initialize empty solution
solution: Solution = {}
for (var_name, var_dict) in lp_solution.items():
solution[var_name] = {}
for (idx, lp_value) in var_dict.items():
solution[var_name][idx] = None
# Create a copy of classifier prototype and train it
if isinstance(self.classifier_prototype, list):
clf = deepcopy(self.classifier_prototype[label])
else:
clf = deepcopy(self.classifier_prototype)
clf.fit(x_train, y_train)
# Compute y_pred
x = self.x([instance])
y_pred = {}
for category in x.keys():
assert category in self.classifiers, (
f"Classifier for category {category} has not been trained. "
f"Please call component.fit before component.predict."
)
proba = self.classifiers[category].predict_proba(x[category])
thr = self.thresholds[category].predict(x[category])
y_pred[category] = np.vstack(
[
proba[:, 0] > thr[0],
proba[:, 1] > thr[1],
]
).T
# Find threshold (dynamic or static)
if isinstance(self.threshold_prototype, Threshold):
self.thresholds[category, label] = self.threshold_prototype.fit(
clf,
x_train,
y_train,
)
else:
self.thresholds[category, label] = deepcopy(
self.threshold_prototype
)
# Convert y_pred into solution
category_offset: Dict[Hashable, int] = {cat: 0 for cat in x.keys()}
for (var_name, var_dict) in lp_solution.items():
for (idx, lp_value) in var_dict.items():
category = instance.get_variable_category(var_name, idx)
offset = category_offset[category]
category_offset[category] += 1
if y_pred[category][offset, 0]:
solution[var_name][idx] = 0.0
if y_pred[category][offset, 1]:
solution[var_name][idx] = 1.0
self.classifiers[category, label] = clf
def predict(self, instance):
solution = {}
x_test = VariableFeaturesExtractor().extract([instance])
var_split = Extractor.split_variables(instance)
for category in var_split.keys():
n = len(var_split[category])
for (i, (var, index)) in enumerate(var_split[category]):
if var not in solution.keys():
solution[var] = {}
solution[var][index] = None
for label in [0, 1]:
if (category, label) not in self.classifiers.keys():
continue
clf = self.classifiers[category, label]
if isinstance(clf, float) or isinstance(clf, int):
ws = np.array([[1 - clf, clf] for _ in range(n)])
else:
ws = clf.predict_proba(x_test[category])
assert ws.shape == (n, 2), "ws.shape should be (%d, 2) not %s" % (
n,
ws.shape,
)
for (i, (var, index)) in enumerate(var_split[category]):
if ws[i, 1] >= self.thresholds[category, label]:
solution[var][index] = label
return solution
def evaluate(self, instances):
@@ -175,3 +182,82 @@ class PrimalSolutionComponent(Component):
tp_one, tn_one, fp_one, fn_one
)
return ev
@staticmethod
def _build_x_y_dict(
    instances: Union[List[str], List[Instance]],
    extract: Callable[
        [
            Instance,
            TrainingSample,
            str,
            VarIndex,
            Optional[float],
        ],
        Union[List[bool], List[float]],
    ],
) -> Dict[Hashable, np.ndarray]:
    """
    Collect one row per (sample, variable, index) and group rows by category.

    Every instance yielded by InstanceIterator(instances) is visited, and for
    each of its training samples that carries a non-None "LP solution", the
    callback `extract` is invoked once per variable index found in that LP
    solution. The list returned by `extract` becomes one row of the category
    reported by `instance.get_variable_category(var, idx)`. Variables whose
    category is None are ignored.

    Parameters
    ----------
    instances
        Instances (or whatever InstanceIterator accepts in their place).
    extract
        Callable receiving (instance, sample, var, idx, lp_value) and returning
        the row to store for that variable.

    Returns
    -------
    Dict[Hashable, np.ndarray]
        For each category seen, the rows converted with np.array, in the
        order in which they were produced.
    """
    rows_by_category: Dict[Hashable, List] = {}
    for instance in InstanceIterator(instances):
        assert isinstance(instance, Instance)
        for sample in instance.training_data:
            # Samples with a missing or empty LP solution contribute no rows
            if "LP solution" not in sample or sample["LP solution"] is None:
                continue

            # Visit every (variable, index) pair of the LP solution
            for (var, var_dict) in sample["LP solution"].items():
                for (idx, lp_value) in var_dict.items():
                    category = instance.get_variable_category(var, idx)
                    if category is None:
                        continue
                    rows_by_category.setdefault(category, []).append(
                        extract(instance, sample, var, idx, lp_value)
                    )

    # Turn each list of rows into a numpy array
    return {
        category: np.array(rows)
        for (category, rows) in rows_by_category.items()
    }
@staticmethod
def _extract_variable_features(
    instance: Instance,
    sample: TrainingSample,
    var: str,
    idx: VarIndex,
    lp_value: Optional[float],
) -> Union[List[bool], List[float]]:
    """
    Build the feature row for a single variable.

    The row is whatever `instance.get_variable_features(var, idx)` returns;
    when `lp_value` is not None, a new row with `lp_value` appended as the
    last entry is returned instead. The `sample` argument is not used here;
    it is accepted so that the signature matches the callback expected by
    `_build_x_y_dict`.
    """
    row = instance.get_variable_features(var, idx)
    if lp_value is not None:
        return row + [lp_value]
    return row
@staticmethod
def _extract_variable_labels(
    instance: Instance,
    sample: TrainingSample,
    var: str,
    idx: VarIndex,
    lp_value: Optional[float],
) -> Union[List[bool], List[float]]:
    """
    Build the label row for a single variable from the sample's solution.

    Reads `sample["Solution"][var][idx]` and returns a two-element list
    `[value < 0.5, value > 0.5]`. The `instance` and `lp_value` arguments are
    not used here; they are accepted so that the signature matches the
    callback expected by `_build_x_y_dict`.

    Raises AssertionError if the sample has no "Solution" entry, if the
    solution or the variable's value is None, or if the value lies outside
    the interval [0, 1].
    """
    assert "Solution" in sample
    optimal = sample["Solution"]
    assert optimal is not None
    opt_value = optimal[var][idx]
    assert opt_value is not None
    assert 0.0 <= opt_value <= 1.0, (
        f"Variable {var} has non-binary value {opt_value} in the optimal solution. "
        f"Predicting values of non-binary variables is not currently supported. "
        f"Please set its category to None."
    )
    # First entry flags a value below one half, second entry a value above it
    is_zero = opt_value < 0.5
    is_one = opt_value > 0.5
    return [is_zero, is_one]