You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
75 lines
2.1 KiB
75 lines
2.1 KiB
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
|
# Copyright (C) 2020-2022, UChicago Argonne, LLC. All rights reserved.
|
|
# Released under the modified BSD license. See COPYING.md for more details.
|
|
import os
|
|
import shutil
|
|
import tempfile
|
|
from glob import glob
|
|
from os.path import dirname, basename, isfile
|
|
from tempfile import NamedTemporaryFile
|
|
from typing import List, Any
|
|
|
|
import pytest
|
|
|
|
from miplearn.extractors.abstract import FeaturesExtractor
|
|
from miplearn.extractors.fields import H5FieldsExtractor
|
|
|
|
|
|
def _h5_fixture(pattern: str, request: Any) -> List[str]:
    """
    Copy the fixture .h5 files matching ``pattern`` into a fresh temporary
    folder and return the sorted list of paths to the copies.

    For every matched ``<name>.h5`` the companion ``<name>.pkl.gz`` file is
    copied into the same temporary folder. A finalizer is registered on
    ``request`` so that the temporary folder is removed after the test.

    :param pattern: glob pattern relative to the ``fixtures`` folder located
        next to this file, e.g. ``"tsp-gp*.h5"``.
    :param request: the pytest ``request`` object of the calling fixture;
        only its ``addfinalizer`` method is used.
    :return: sorted paths of the copied .h5 files. The list is empty if
        nothing matched ``pattern``.
    :raises FileNotFoundError: if a matched .h5 file has no companion
        .pkl.gz file in the fixtures folder.
    """
    from glob import escape

    # Escape the directory part so glob metacharacters ("[", "*", "?") in the
    # checkout path are matched literally; only `pattern` acts as a glob.
    fixtures_dir = os.path.join(escape(dirname(__file__)), "fixtures")
    filenames = glob(os.path.join(fixtures_dir, pattern))

    tmpdir = tempfile.mkdtemp()

    def cleanup() -> None:
        shutil.rmtree(tmpdir)

    # Register the finalizer before copying, so the folder is removed even
    # if one of the copies below fails.
    request.addfinalizer(cleanup)

    copied: List[str] = []
    for f in filenames:
        fbase, _ = os.path.splitext(f)
        stem = basename(fbase)
        for ext in [".h5", ".pkl.gz"]:
            dest = os.path.join(tmpdir, f"{stem}{ext}")
            shutil.copy(f"{fbase}{ext}", dest)
            assert isfile(dest)
        # Collect the .h5 copies directly instead of re-globbing tmpdir,
        # whose path could itself contain glob metacharacters.
        copied.append(os.path.join(tmpdir, f"{stem}.h5"))
    return sorted(copied)
|
|
|
|
|
|
@pytest.fixture()
def multiknapsack_h5(request: Any) -> List[str]:
    """Paths to temporary copies of the ``multiknapsack*.h5`` fixtures.

    The companion .pkl.gz files are copied alongside; the temporary folder
    is removed when the test finishes.
    """
    glob_pattern = "multiknapsack*.h5"
    return _h5_fixture(pattern=glob_pattern, request=request)
|
|
|
|
|
|
@pytest.fixture()
def tsp_gp_h5(request: Any) -> List[str]:
    """Paths to temporary copies of the ``tsp-gp*.h5`` fixtures.

    The companion .pkl.gz files are copied alongside; the temporary folder
    is removed when the test finishes.
    """
    glob_pattern = "tsp-gp*.h5"
    return _h5_fixture(pattern=glob_pattern, request=request)
|
|
|
|
|
|
@pytest.fixture()
def tsp_pyo_h5(request: Any) -> List[str]:
    """Paths to temporary copies of the ``tsp-pyo*.h5`` fixtures.

    The companion .pkl.gz files are copied alongside; the temporary folder
    is removed when the test finishes.
    """
    glob_pattern = "tsp-pyo*.h5"
    return _h5_fixture(pattern=glob_pattern, request=request)
|
|
|
|
|
|
@pytest.fixture()
def stab_gp_h5(request: Any) -> List[str]:
    """Paths to temporary copies of the ``stab-gp*.h5`` fixtures.

    The companion .pkl.gz files are copied alongside; the temporary folder
    is removed when the test finishes.
    """
    glob_pattern = "stab-gp*.h5"
    return _h5_fixture(pattern=glob_pattern, request=request)
|
|
|
|
|
|
@pytest.fixture()
def stab_pyo_h5(request: Any) -> List[str]:
    """Paths to temporary copies of the ``stab-pyo*.h5`` fixtures.

    The companion .pkl.gz files are copied alongside; the temporary folder
    is removed when the test finishes.
    """
    glob_pattern = "stab-pyo*.h5"
    return _h5_fixture(pattern=glob_pattern, request=request)
|
|
|
|
|
|
@pytest.fixture()
def default_extractor() -> FeaturesExtractor:
    """Build the default ``H5FieldsExtractor`` used by the tests.

    ``static_var_obj_coeffs`` is passed as the instance field and
    ``lp_var_features`` as the variable field.
    """
    instance_keys = ["static_var_obj_coeffs"]
    variable_keys = ["lp_var_features"]
    extractor = H5FieldsExtractor(
        instance_fields=instance_keys,
        var_fields=variable_keys,
    )
    return extractor
|