Implement MemorizingLazyConstrComponent

This commit is contained in:
2023-10-26 15:37:05 -05:00
parent 2d07a44f7d
commit c1adc0b79e
20 changed files with 202 additions and 169 deletions

View File

View File

@@ -0,0 +1,62 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020-2022, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
from typing import List, Dict, Any
from unittest.mock import Mock
from sklearn.dummy import DummyClassifier
from sklearn.neighbors import KNeighborsClassifier
from miplearn.components.lazy.mem import MemorizingLazyConstrComponent
from miplearn.extractors.abstract import FeaturesExtractor
from miplearn.problems.tsp import build_tsp_model
from miplearn.solvers.learning import LearningSolver
def test_mem_component(
tsp_h5: List[str],
default_extractor: FeaturesExtractor,
) -> None:
clf = Mock(wraps=DummyClassifier())
comp = MemorizingLazyConstrComponent(clf=clf, extractor=default_extractor)
comp.fit(tsp_h5)
# Should call fit method with correct arguments
clf.fit.assert_called()
x, y = clf.fit.call_args.args
assert x.shape == (3, 190)
assert y.tolist() == [
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0],
[1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1],
]
# Should store violations
assert comp.violations_ is not None
assert comp.n_features_ == 190
assert comp.n_targets_ == 22
assert len(comp.violations_) == 22
# Call before-mip
stats: Dict[str, Any] = {}
model = Mock()
comp.before_mip(tsp_h5[0], model, stats)
# Should call predict with correct args
clf.predict.assert_called()
(x_test,) = clf.predict.call_args.args
assert x_test.shape == (1, 190)
def test_usage_tsp(
tsp_h5: List[str],
default_extractor: FeaturesExtractor,
) -> None:
# Should not crash
data_filenames = [f.replace(".h5", ".pkl.gz") for f in tsp_h5]
clf = KNeighborsClassifier(n_neighbors=1)
comp = MemorizingLazyConstrComponent(clf=clf, extractor=default_extractor)
solver = LearningSolver(components=[comp])
solver.fit(data_filenames)
solver.optimize(data_filenames[0], build_tsp_model)

View File

@@ -20,7 +20,8 @@ logger = logging.getLogger(__name__)
def test_mem_component(
multiknapsack_h5: List[str], default_extractor: FeaturesExtractor
multiknapsack_h5: List[str],
default_extractor: FeaturesExtractor,
) -> None:
# Create mock classifier
clf = Mock(wraps=DummyClassifier())

View File

@@ -8,13 +8,18 @@ from typing import List
import pytest
from miplearn.extractors.fields import H5FieldsExtractor
from miplearn.extractors.abstract import FeaturesExtractor
from miplearn.extractors.fields import H5FieldsExtractor
@pytest.fixture()
def multiknapsack_h5() -> List[str]:
return sorted(glob(f"{dirname(__file__)}/fixtures/multiknapsack*.h5"))
return sorted(glob(f"{dirname(__file__)}/fixtures/multiknapsack-n100*.h5"))
@pytest.fixture()
def tsp_h5() -> List[str]:
return sorted(glob(f"{dirname(__file__)}/fixtures/tsp-n20*.h5"))
@pytest.fixture()

22
tests/fixtures/gen_tsp.py vendored Normal file
View File

@@ -0,0 +1,22 @@
from os.path import dirname
import numpy as np
from scipy.stats import uniform, randint
from miplearn.collectors.basic import BasicCollector
from miplearn.io import write_pkl_gz
from miplearn.problems.tsp import TravelingSalesmanGenerator, build_tsp_model
np.random.seed(42)
gen = TravelingSalesmanGenerator(
x=uniform(loc=0.0, scale=1000.0),
y=uniform(loc=0.0, scale=1000.0),
n=randint(low=20, high=21),
gamma=uniform(loc=1.0, scale=0.25),
fix_cities=True,
round=True,
)
data = gen.generate(3)
data_filenames = write_pkl_gz(data, dirname(__file__), prefix="tsp-n20-")
collector = BasicCollector()
collector.collect(data_filenames, build_tsp_model)

BIN
tests/fixtures/tsp-n20-00000.h5 vendored Normal file

Binary file not shown.

BIN
tests/fixtures/tsp-n20-00000.mps.gz vendored Normal file

Binary file not shown.

BIN
tests/fixtures/tsp-n20-00000.pkl.gz vendored Normal file

Binary file not shown.

BIN
tests/fixtures/tsp-n20-00001.h5 vendored Normal file

Binary file not shown.

BIN
tests/fixtures/tsp-n20-00001.mps.gz vendored Normal file

Binary file not shown.

BIN
tests/fixtures/tsp-n20-00001.pkl.gz vendored Normal file

Binary file not shown.

BIN
tests/fixtures/tsp-n20-00002.h5 vendored Normal file

Binary file not shown.

BIN
tests/fixtures/tsp-n20-00002.mps.gz vendored Normal file

Binary file not shown.

BIN
tests/fixtures/tsp-n20-00002.pkl.gz vendored Normal file

Binary file not shown.