mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Make lazy constr component compatible with Pyomo+Gurobi
This commit is contained in:
@@ -10,53 +10,60 @@ from sklearn.neighbors import KNeighborsClassifier
|
||||
|
||||
from miplearn.components.lazy.mem import MemorizingLazyComponent
|
||||
from miplearn.extractors.abstract import FeaturesExtractor
|
||||
from miplearn.problems.tsp import build_tsp_model
|
||||
from miplearn.problems.tsp import build_tsp_model_gurobipy, build_tsp_model_pyomo
|
||||
from miplearn.solvers.learning import LearningSolver
|
||||
|
||||
|
||||
def test_mem_component(
|
||||
tsp_h5: List[str],
|
||||
tsp_gp_h5: List[str],
|
||||
tsp_pyo_h5: List[str],
|
||||
default_extractor: FeaturesExtractor,
|
||||
) -> None:
|
||||
clf = Mock(wraps=DummyClassifier())
|
||||
comp = MemorizingLazyComponent(clf=clf, extractor=default_extractor)
|
||||
comp.fit(tsp_h5)
|
||||
for h5 in [tsp_gp_h5, tsp_pyo_h5]:
|
||||
clf = Mock(wraps=DummyClassifier())
|
||||
comp = MemorizingLazyComponent(clf=clf, extractor=default_extractor)
|
||||
comp.fit(tsp_gp_h5)
|
||||
|
||||
# Should call fit method with correct arguments
|
||||
clf.fit.assert_called()
|
||||
x, y = clf.fit.call_args.args
|
||||
assert x.shape == (3, 190)
|
||||
assert y.tolist() == [
|
||||
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1],
|
||||
]
|
||||
# Should call fit method with correct arguments
|
||||
clf.fit.assert_called()
|
||||
x, y = clf.fit.call_args.args
|
||||
assert x.shape == (3, 190)
|
||||
assert y.tolist() == [
|
||||
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
|
||||
[1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0],
|
||||
[1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1],
|
||||
]
|
||||
|
||||
# Should store violations
|
||||
assert comp.constrs_ is not None
|
||||
assert comp.n_features_ == 190
|
||||
assert comp.n_targets_ == 22
|
||||
assert len(comp.constrs_) == 22
|
||||
# Should store violations
|
||||
assert comp.constrs_ is not None
|
||||
assert comp.n_features_ == 190
|
||||
assert comp.n_targets_ == 20
|
||||
assert len(comp.constrs_) == 20
|
||||
|
||||
# Call before-mip
|
||||
stats: Dict[str, Any] = {}
|
||||
model = Mock()
|
||||
comp.before_mip(tsp_h5[0], model, stats)
|
||||
# Call before-mip
|
||||
stats: Dict[str, Any] = {}
|
||||
model = Mock()
|
||||
comp.before_mip(tsp_gp_h5[0], model, stats)
|
||||
|
||||
# Should call predict with correct args
|
||||
clf.predict.assert_called()
|
||||
(x_test,) = clf.predict.call_args.args
|
||||
assert x_test.shape == (1, 190)
|
||||
# Should call predict with correct args
|
||||
clf.predict.assert_called()
|
||||
(x_test,) = clf.predict.call_args.args
|
||||
assert x_test.shape == (1, 190)
|
||||
|
||||
|
||||
def test_usage_tsp(
|
||||
tsp_h5: List[str],
|
||||
tsp_gp_h5: List[str],
|
||||
tsp_pyo_h5: List[str],
|
||||
default_extractor: FeaturesExtractor,
|
||||
) -> None:
|
||||
# Should not crash
|
||||
data_filenames = [f.replace(".h5", ".pkl.gz") for f in tsp_h5]
|
||||
clf = KNeighborsClassifier(n_neighbors=1)
|
||||
comp = MemorizingLazyComponent(clf=clf, extractor=default_extractor)
|
||||
solver = LearningSolver(components=[comp])
|
||||
solver.fit(data_filenames)
|
||||
solver.optimize(data_filenames[0], build_tsp_model)
|
||||
for (h5, build_model) in [
|
||||
(tsp_pyo_h5, build_tsp_model_pyomo),
|
||||
(tsp_gp_h5, build_tsp_model_gurobipy),
|
||||
]:
|
||||
data_filenames = [f.replace(".h5", ".pkl.gz") for f in h5]
|
||||
clf = KNeighborsClassifier(n_neighbors=1)
|
||||
comp = MemorizingLazyComponent(clf=clf, extractor=default_extractor)
|
||||
solver = LearningSolver(components=[comp])
|
||||
solver.fit(data_filenames)
|
||||
stats = solver.optimize(data_filenames[0], build_model) # type: ignore
|
||||
assert stats["Lazy Constraints: AOT"] > 0
|
||||
|
||||
@@ -47,8 +47,13 @@ def multiknapsack_h5(request: Any) -> List[str]:
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def tsp_h5(request: Any) -> List[str]:
|
||||
return _h5_fixture("tsp*.h5", request)
|
||||
def tsp_gp_h5(request: Any) -> List[str]:
|
||||
return _h5_fixture("tsp-gp*.h5", request)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def tsp_pyo_h5(request: Any) -> List[str]:
|
||||
return _h5_fixture("tsp-pyo*.h5", request)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
||||
30
tests/fixtures/gen_tsp.py
vendored
30
tests/fixtures/gen_tsp.py
vendored
@@ -5,7 +5,11 @@ from scipy.stats import uniform, randint
|
||||
|
||||
from miplearn.collectors.basic import BasicCollector
|
||||
from miplearn.io import write_pkl_gz
|
||||
from miplearn.problems.tsp import TravelingSalesmanGenerator, build_tsp_model
|
||||
from miplearn.problems.tsp import (
|
||||
TravelingSalesmanGenerator,
|
||||
build_tsp_model_gurobipy,
|
||||
build_tsp_model_pyomo,
|
||||
)
|
||||
|
||||
np.random.seed(42)
|
||||
gen = TravelingSalesmanGenerator(
|
||||
@@ -16,7 +20,27 @@ gen = TravelingSalesmanGenerator(
|
||||
fix_cities=True,
|
||||
round=True,
|
||||
)
|
||||
|
||||
data = gen.generate(3)
|
||||
data_filenames = write_pkl_gz(data, dirname(__file__), prefix="tsp-n20-")
|
||||
|
||||
params = {"seed": 42, "threads": 1}
|
||||
|
||||
# Gurobipy
|
||||
data_filenames = write_pkl_gz(data, dirname(__file__), prefix="tsp-gp-n20-")
|
||||
collector = BasicCollector()
|
||||
collector.collect(data_filenames, build_tsp_model)
|
||||
collector.collect(
|
||||
data_filenames,
|
||||
lambda d: build_tsp_model_gurobipy(d, params=params),
|
||||
progress=True,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
# Pyomo
|
||||
data_filenames = write_pkl_gz(data, dirname(__file__), prefix="tsp-pyo-n20-")
|
||||
collector = BasicCollector()
|
||||
collector.collect(
|
||||
data_filenames,
|
||||
lambda d: build_tsp_model_pyomo(d, params=params),
|
||||
progress=True,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
BIN
tests/fixtures/tsp-gp-n20-00000.h5
vendored
Normal file
BIN
tests/fixtures/tsp-gp-n20-00000.h5
vendored
Normal file
Binary file not shown.
BIN
tests/fixtures/tsp-gp-n20-00000.mps.gz
vendored
Normal file
BIN
tests/fixtures/tsp-gp-n20-00000.mps.gz
vendored
Normal file
Binary file not shown.
Binary file not shown.
BIN
tests/fixtures/tsp-gp-n20-00001.h5
vendored
Normal file
BIN
tests/fixtures/tsp-gp-n20-00001.h5
vendored
Normal file
Binary file not shown.
BIN
tests/fixtures/tsp-gp-n20-00001.mps.gz
vendored
Normal file
BIN
tests/fixtures/tsp-gp-n20-00001.mps.gz
vendored
Normal file
Binary file not shown.
Binary file not shown.
BIN
tests/fixtures/tsp-gp-n20-00002.h5
vendored
Normal file
BIN
tests/fixtures/tsp-gp-n20-00002.h5
vendored
Normal file
Binary file not shown.
BIN
tests/fixtures/tsp-gp-n20-00002.mps.gz
vendored
Normal file
BIN
tests/fixtures/tsp-gp-n20-00002.mps.gz
vendored
Normal file
Binary file not shown.
BIN
tests/fixtures/tsp-gp-n20-00002.pkl.gz
vendored
Normal file
BIN
tests/fixtures/tsp-gp-n20-00002.pkl.gz
vendored
Normal file
Binary file not shown.
BIN
tests/fixtures/tsp-n20-00000.h5
vendored
BIN
tests/fixtures/tsp-n20-00000.h5
vendored
Binary file not shown.
BIN
tests/fixtures/tsp-n20-00001.h5
vendored
BIN
tests/fixtures/tsp-n20-00001.h5
vendored
Binary file not shown.
BIN
tests/fixtures/tsp-n20-00001.mps.gz
vendored
BIN
tests/fixtures/tsp-n20-00001.mps.gz
vendored
Binary file not shown.
BIN
tests/fixtures/tsp-n20-00002.h5
vendored
BIN
tests/fixtures/tsp-n20-00002.h5
vendored
Binary file not shown.
BIN
tests/fixtures/tsp-n20-00002.mps.gz
vendored
BIN
tests/fixtures/tsp-n20-00002.mps.gz
vendored
Binary file not shown.
BIN
tests/fixtures/tsp-n20-00002.pkl.gz
vendored
BIN
tests/fixtures/tsp-n20-00002.pkl.gz
vendored
Binary file not shown.
BIN
tests/fixtures/tsp-pyo-n20-00000.h5
vendored
Normal file
BIN
tests/fixtures/tsp-pyo-n20-00000.h5
vendored
Normal file
Binary file not shown.
BIN
tests/fixtures/tsp-pyo-n20-00000.mps.gz
vendored
Normal file
BIN
tests/fixtures/tsp-pyo-n20-00000.mps.gz
vendored
Normal file
Binary file not shown.
BIN
tests/fixtures/tsp-pyo-n20-00000.pkl.gz
vendored
Normal file
BIN
tests/fixtures/tsp-pyo-n20-00000.pkl.gz
vendored
Normal file
Binary file not shown.
BIN
tests/fixtures/tsp-pyo-n20-00001.h5
vendored
Normal file
BIN
tests/fixtures/tsp-pyo-n20-00001.h5
vendored
Normal file
Binary file not shown.
BIN
tests/fixtures/tsp-pyo-n20-00001.mps.gz
vendored
Normal file
BIN
tests/fixtures/tsp-pyo-n20-00001.mps.gz
vendored
Normal file
Binary file not shown.
BIN
tests/fixtures/tsp-pyo-n20-00001.pkl.gz
vendored
Normal file
BIN
tests/fixtures/tsp-pyo-n20-00001.pkl.gz
vendored
Normal file
Binary file not shown.
BIN
tests/fixtures/tsp-pyo-n20-00002.h5
vendored
Normal file
BIN
tests/fixtures/tsp-pyo-n20-00002.h5
vendored
Normal file
Binary file not shown.
BIN
tests/fixtures/tsp-pyo-n20-00002.mps.gz
vendored
Normal file
BIN
tests/fixtures/tsp-pyo-n20-00002.mps.gz
vendored
Normal file
Binary file not shown.
BIN
tests/fixtures/tsp-pyo-n20-00002.pkl.gz
vendored
Normal file
BIN
tests/fixtures/tsp-pyo-n20-00002.pkl.gz
vendored
Normal file
Binary file not shown.
@@ -6,7 +6,7 @@ import numpy as np
|
||||
from miplearn.problems.tsp import (
|
||||
TravelingSalesmanData,
|
||||
TravelingSalesmanGenerator,
|
||||
build_tsp_model,
|
||||
build_tsp_model_gurobipy,
|
||||
)
|
||||
from scipy.spatial.distance import pdist, squareform
|
||||
from scipy.stats import randint, uniform
|
||||
@@ -51,7 +51,7 @@ def test_tsp() -> None:
|
||||
)
|
||||
),
|
||||
)
|
||||
model = build_tsp_model(data)
|
||||
model = build_tsp_model_gurobipy(data)
|
||||
model.optimize()
|
||||
assert model.inner.getAttr("x", model.inner.getVars()) == [
|
||||
1.0,
|
||||
|
||||
Reference in New Issue
Block a user