You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
MIPLearn/miplearn/components/primal/joint.py

89 lines
3.2 KiB

# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020-2022, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
import logging
from typing import List, Dict, Any, Optional
import numpy as np
from miplearn.components.primal import _extract_bin_var_names_values
from miplearn.components.primal.actions import PrimalComponentAction
from miplearn.extractors.abstract import FeaturesExtractor
from miplearn.solvers.abstract import AbstractModel
from miplearn.h5 import H5File
logger = logging.getLogger(__name__)
class JointVarsPrimalComponent:
    """Primal component that predicts all binary variable values jointly.

    One classifier is trained with a single feature vector per training
    instance as input and the vector of that instance's binary variable
    values as a multi-output target. At solve time, the predicted vector is
    handed to ``action``, which decides how to apply it to the model.
    """

    def __init__(
        self, clf: Any, extractor: FeaturesExtractor, action: PrimalComponentAction
    ):
        """Store the collaborators; no training happens here.

        Parameters
        ----------
        clf
            Estimator exposing ``fit(x, y)`` and ``predict(x)``. ``predict``
            must return a 2-D array with one column per binary variable.
        extractor
            Provides ``get_instance_features(h5)``, which must return a 1-D
            feature vector for the instance stored in ``h5``.
        action
            Provides ``perform(model, var_names, var_values, stats)``, called
            from ``before_mip`` with the predicted values.
        """
        self.clf = clf
        self.extractor = extractor
        # Binary variable names seen during training; None until fit() has
        # read at least one training file.
        self.bin_var_names_: Optional[np.ndarray] = None
        self.action = action

    def fit(self, train_h5: List[str]) -> None:
        """Train ``clf`` on the instances stored in the given HDF5 files.

        Every file must describe the same binary variables, in the same
        order, and yield feature vectors of the same length.

        Parameters
        ----------
        train_h5
            Paths of the training files. Must not be empty.

        Raises
        ------
        AssertionError
            If ``train_h5`` is empty, if the files disagree on variable names
            or on the number of features, or if the assembled matrices do not
            have the expected shapes.
        """
        logger.info("Reading training data...")
        self.bin_var_names_ = None
        x, y, n_samples, n_features = [], [], len(train_h5), None
        for h5_filename in train_h5:
            with H5File(h5_filename, "r") as h5:
                bin_var_names, bin_var_values, _ = _extract_bin_var_names_values(h5)

                # Store/check variable names
                if self.bin_var_names_ is None:
                    self.bin_var_names_ = bin_var_names
                else:
                    # np.array_equal compares shapes before contents. A plain
                    # elementwise `==` would broadcast (so a length-1 array
                    # could match a longer one) and, depending on the NumPy
                    # version, fail with a non-assertion error on a shape
                    # mismatch.
                    assert np.array_equal(bin_var_names, self.bin_var_names_), (
                        f"Binary variable names in {h5_filename} differ from "
                        "those of the first training file"
                    )

                # Build x and y vectors
                x_sample = self.extractor.get_instance_features(h5)
                assert (
                    len(x_sample.shape) == 1
                ), "Instance features must be a 1-D vector"
                if n_features is None:
                    n_features = len(x_sample)
                else:
                    assert len(x_sample) == n_features, (
                        f"{h5_filename} has {len(x_sample)} features, "
                        f"expected {n_features}"
                    )
                x.append(x_sample)
                y.append(bin_var_values)
        # Still None only if the loop never ran (or the first file returned
        # no names), in which case there is nothing to train on.
        assert (
            self.bin_var_names_ is not None
        ), "No binary variable names found; is train_h5 empty?"

        logger.info("Constructing matrices...")
        x_np = np.vstack(x)
        y_np = np.array(y)
        assert len(x_np.shape) == 2
        assert x_np.shape[0] == n_samples
        assert x_np.shape[1] == n_features
        assert y_np.shape == (n_samples, len(self.bin_var_names_))
        logger.info(
            f"Dataset has {n_samples:,d} samples, "
            f"{n_features:,d} features and {y_np.shape[1]:,d} targets"
        )

        logger.info("Training classifier...")
        self.clf.fit(x_np, y_np)
        logger.info("Done fitting.")

    def before_mip(
        self, test_h5: str, model: AbstractModel, stats: Dict[str, Any]
    ) -> None:
        """Predict binary variable values for one instance and apply them.

        Parameters
        ----------
        test_h5
            Path of the HDF5 file describing the instance to be solved.
        model
            Model passed through, unchanged, to ``action.perform``.
        stats
            Dictionary passed through to ``action.perform``.

        Raises
        ------
        AssertionError
            If ``fit`` has not recorded variable names, if the extracted
            features are not 1-D, or if the prediction is not a single row
            with one entry per binary variable.
        """
        assert (
            self.bin_var_names_ is not None
        ), "fit() must be called before before_mip()"

        # Read features
        with H5File(test_h5, "r") as h5:
            x_sample = self.extractor.get_instance_features(h5)
            assert len(x_sample.shape) == 1, "Instance features must be a 1-D vector"
            # The classifier expects a 2-D (n_samples, n_features) matrix.
            x_sample = x_sample.reshape(1, -1)

        # Predict optimal solution
        logger.info("Predicting warm starts...")
        y_pred = self.clf.predict(x_sample)
        assert len(y_pred.shape) == 2
        assert y_pred.shape[0] == 1
        assert y_pred.shape[1] == len(self.bin_var_names_)

        # Construct warm starts, based on prediction
        self.action.perform(model, self.bin_var_names_, y_pred, stats)