Use np.array for Variables.names

This commit is contained in:
2021-08-08 07:24:14 -05:00
parent f69067aafd
commit 7d55d6f34c
10 changed files with 96 additions and 76 deletions

View File

@@ -22,7 +22,7 @@ from miplearn.solvers.tests import assert_equals
def sample() -> Sample:
sample = MemorySample(
{
"static_var_names": ["x[0]", "x[1]", "x[2]", "x[3]"],
"static_var_names": np.array(["x[0]", "x[1]", "x[2]", "x[3]"], dtype="S"),
"static_var_categories": ["default", None, "default", "default"],
"mip_var_values": np.array([0.0, 1.0, 1.0, 0.0]),
"static_instance_features": [5.0],
@@ -112,10 +112,10 @@ def test_usage() -> None:
def test_evaluate(sample: Sample) -> None:
comp = PrimalSolutionComponent()
comp.sample_predict = lambda _: { # type: ignore
"x[0]": 1.0,
"x[1]": 1.0,
"x[2]": 0.0,
"x[3]": None,
b"x[0]": 1.0,
b"x[1]": 1.0,
b"x[2]": 0.0,
b"x[3]": None,
}
ev = comp.sample_evaluate(None, sample)
assert_equals(
@@ -150,8 +150,8 @@ def test_predict(sample: Sample) -> None:
assert_array_equal(x["default"], clf.predict_proba.call_args[0][0])
assert_array_equal(x["default"], thr.predict.call_args[0][0])
assert pred == {
"x[0]": 0.0,
"x[1]": None,
"x[2]": None,
"x[3]": 1.0,
b"x[0]": 0.0,
b"x[1]": None,
b"x[2]": None,
b"x[3]": 1.0,
}

View File

@@ -33,7 +33,7 @@ def test_knapsack() -> None:
extractor.extract_after_load_features(instance, solver, sample)
assert_equals(
sample.get_vector("static_var_names"),
["x[0]", "x[1]", "x[2]", "x[3]", "z"],
np.array(["x[0]", "x[1]", "x[2]", "x[3]", "z"], dtype="S"),
)
assert_equals(
sample.get_vector("static_var_lower_bounds"), [0.0, 0.0, 0.0, 0.0, 0.0]
@@ -126,16 +126,16 @@ def test_constraint_getindex() -> None:
senses=["=", "<", ">"],
lhs=[
[
("x1", 1.0),
("x2", 1.0),
(b"x1", 1.0),
(b"x2", 1.0),
],
[
("x2", 2.0),
("x3", 2.0),
(b"x2", 2.0),
(b"x3", 2.0),
],
[
("x3", 3.0),
("x4", 3.0),
(b"x3", 3.0),
(b"x4", 3.0),
],
],
)
@@ -147,12 +147,12 @@ def test_constraint_getindex() -> None:
senses=["=", ">"],
lhs=[
[
("x1", 1.0),
("x2", 1.0),
(b"x1", 1.0),
(b"x2", 1.0),
],
[
("x3", 3.0),
("x4", 3.0),
(b"x3", 3.0),
(b"x4", 3.0),
],
],
),