mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
MIPLearn v0.3
This commit is contained in:
3
tests/extractors/__init__.py
Normal file
3
tests/extractors/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||
# Copyright (C) 2020-2022, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
19
tests/extractors/test_dummy.py
Normal file
19
tests/extractors/test_dummy.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||
# Copyright (C) 2020-2022, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
from typing import List
|
||||
|
||||
from miplearn.extractors.dummy import DummyExtractor
|
||||
from miplearn.h5 import H5File
|
||||
|
||||
|
||||
def test_dummy(multiknapsack_h5: List[str]) -> None:
|
||||
ext = DummyExtractor()
|
||||
with H5File(multiknapsack_h5[0], "r") as h5:
|
||||
x = ext.get_instance_features(h5)
|
||||
assert x.shape == (1,)
|
||||
x = ext.get_var_features(h5)
|
||||
assert x.shape == (100, 1)
|
||||
x = ext.get_constr_features(h5)
|
||||
assert x.shape == (4, 1)
|
||||
33
tests/extractors/test_fields.py
Normal file
33
tests/extractors/test_fields.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||
# Copyright (C) 2020-2022, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from miplearn.extractors.fields import H5FieldsExtractor
|
||||
from miplearn.h5 import H5File
|
||||
|
||||
|
||||
def test_fields_instance(multiknapsack_h5: List[str]) -> None:
|
||||
ext = H5FieldsExtractor(
|
||||
instance_fields=[
|
||||
"lp_obj_value",
|
||||
"lp_var_values",
|
||||
"static_var_obj_coeffs",
|
||||
],
|
||||
var_fields=["lp_var_values"],
|
||||
)
|
||||
with H5File(multiknapsack_h5[0], "r") as h5:
|
||||
x = ext.get_instance_features(h5)
|
||||
assert x.shape == (201,)
|
||||
|
||||
x = ext.get_var_features(h5)
|
||||
assert x.shape == (100, 1)
|
||||
|
||||
|
||||
def test_fields_instance_none(multiknapsack_h5: List[str]) -> None:
|
||||
ext = H5FieldsExtractor(instance_fields=None)
|
||||
with H5File(multiknapsack_h5[0], "r") as h5:
|
||||
with pytest.raises(Exception):
|
||||
ext.get_instance_features(h5)
|
||||
Reference in New Issue
Block a user