Implement more compact get_variables

This commit is contained in:
2021-04-15 06:19:54 -05:00
parent e6eca2ee7f
commit 08f0bedbe0
5 changed files with 306 additions and 170 deletions

View File

@@ -2,10 +2,10 @@
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
from typing import Any, Dict
from typing import Any, Dict, List
import numpy as np
from miplearn.features import Constraint, Variable
from miplearn.features import Constraint, Variable, VariableFeatures
from miplearn.solvers.internal import InternalSolver
inf = float("inf")
@@ -44,6 +44,30 @@ def _round_variables(vars: Dict[str, Variable]) -> Dict[str, Variable]:
return vars
def _round(obj: Any) -> Any:
if isinstance(obj, tuple):
if obj is None:
return None
return tuple([round(v, 6) for v in obj])
if isinstance(obj, VariableFeatures):
obj.reduced_costs = _round(obj.reduced_costs)
obj.sa_obj_up = _round(obj.sa_obj_up)
obj.sa_obj_down = _round(obj.sa_obj_down)
obj.sa_lb_up = _round(obj.sa_lb_up)
obj.sa_lb_down = _round(obj.sa_lb_down)
obj.sa_ub_up = _round(obj.sa_ub_up)
obj.sa_ub_down = _round(obj.sa_ub_down)
obj.values = _round(obj.values)
return obj
def _filter_attrs(allowed_keys: List[str], obj: Any) -> Any:
for key in obj.__dict__.keys():
if key not in allowed_keys:
setattr(obj, key, None)
return obj
def _remove_unsupported_constr_attrs(
solver: InternalSolver,
constraints: Dict[str, Constraint],
@@ -58,20 +82,6 @@ def _remove_unsupported_constr_attrs(
return constraints
def _remove_unsupported_var_attrs(
solver: InternalSolver,
variables: Dict[str, Variable],
) -> Dict[str, Variable]:
for (cname, c) in variables.items():
to_remove = []
for k in c.__dict__.keys():
if k not in solver.get_variable_attrs():
to_remove.append(k)
for k in to_remove:
setattr(c, k, None)
return variables
def run_internal_solver_tests(solver: InternalSolver) -> None:
run_basic_usage_tests(solver.clone())
run_warm_start_tests(solver.clone())
@@ -89,41 +99,13 @@ def run_basic_usage_tests(solver: InternalSolver) -> None:
# Fetch variables (after-load)
assert_equals(
_round_variables(solver.get_variables()),
_remove_unsupported_var_attrs(
solver,
{
"x[0]": Variable(
lower_bound=0.0,
obj_coeff=505.0,
type="B",
upper_bound=1.0,
),
"x[1]": Variable(
lower_bound=0.0,
obj_coeff=352.0,
type="B",
upper_bound=1.0,
),
"x[2]": Variable(
lower_bound=0.0,
obj_coeff=458.0,
type="B",
upper_bound=1.0,
),
"x[3]": Variable(
lower_bound=0.0,
obj_coeff=220.0,
type="B",
upper_bound=1.0,
),
"z": Variable(
lower_bound=0.0,
obj_coeff=0.0,
type="C",
upper_bound=67.0,
),
},
solver.get_variables(),
VariableFeatures(
names=("x[0]", "x[1]", "x[2]", "x[3]", "z"),
lower_bounds=(0.0, 0.0, 0.0, 0.0, 0.0),
upper_bounds=(1.0, 1.0, 1.0, 1.0, 67.0),
types=("B", "B", "B", "B", "C"),
obj_coeffs=(505.0, 352.0, 458.0, 220.0, 0.0),
),
)
@@ -150,88 +132,22 @@ def run_basic_usage_tests(solver: InternalSolver) -> None:
assert lp_stats.lp_wallclock_time is not None
assert lp_stats.lp_wallclock_time > 0
# Fetch variables (after-load)
# Fetch variables (after-lp)
assert_equals(
_round_variables(solver.get_variables()),
_remove_unsupported_var_attrs(
solver,
{
"x[0]": Variable(
basis_status="U",
lower_bound=0.0,
obj_coeff=505.0,
reduced_cost=193.615385,
sa_lb_down=-inf,
sa_lb_up=1.0,
sa_obj_down=311.384615,
sa_obj_up=inf,
sa_ub_down=0.913043,
sa_ub_up=2.043478,
type="B",
upper_bound=1.0,
value=1.0,
),
"x[1]": Variable(
basis_status="B",
lower_bound=0.0,
obj_coeff=352.0,
reduced_cost=0.0,
sa_lb_down=-inf,
sa_lb_up=0.923077,
sa_obj_down=317.777778,
sa_obj_up=570.869565,
sa_ub_down=0.923077,
sa_ub_up=inf,
type="B",
upper_bound=1.0,
value=0.923077,
),
"x[2]": Variable(
basis_status="U",
lower_bound=0.0,
obj_coeff=458.0,
reduced_cost=187.230769,
sa_lb_down=-inf,
sa_lb_up=1.0,
sa_obj_down=270.769231,
sa_obj_up=inf,
sa_ub_down=0.9,
sa_ub_up=2.2,
type="B",
upper_bound=1.0,
value=1.0,
),
"x[3]": Variable(
basis_status="L",
lower_bound=0.0,
obj_coeff=220.0,
reduced_cost=-23.692308,
sa_lb_down=-0.111111,
sa_lb_up=1.0,
sa_obj_down=-inf,
sa_obj_up=243.692308,
sa_ub_down=0.0,
sa_ub_up=inf,
type="B",
upper_bound=1.0,
value=0.0,
),
"z": Variable(
basis_status="U",
lower_bound=0.0,
obj_coeff=0.0,
reduced_cost=13.538462,
sa_lb_down=-inf,
sa_lb_up=67.0,
sa_obj_down=-13.538462,
sa_obj_up=inf,
sa_ub_down=43.0,
sa_ub_up=69.0,
type="C",
upper_bound=67.0,
value=67.0,
),
},
_round(solver.get_variables(with_static=False)),
_filter_attrs(
solver.get_variable_attrs(),
VariableFeatures(
basis_status=("U", "B", "U", "L", "U"),
reduced_costs=(193.615385, 0.0, 187.230769, -23.692308, 13.538462),
sa_lb_down=(-inf, -inf, -inf, -0.111111, -inf),
sa_lb_up=(1.0, 0.923077, 1.0, 1.0, 67.0),
sa_obj_down=(311.384615, 317.777778, 270.769231, -inf, -13.538462),
sa_obj_up=(inf, 570.869565, inf, 243.692308, inf),
sa_ub_down=(0.913043, 0.923077, 0.9, 0.0, 43.0),
sa_ub_up=(2.043478, inf, 2.2, inf, 69.0),
values=(1.0, 0.923077, 1.0, 0.0, 67.0),
),
),
)
@@ -281,16 +197,10 @@ def run_basic_usage_tests(solver: InternalSolver) -> None:
# Fetch variables (after-mip)
assert_equals(
_round_variables(solver.get_variables(with_static=False)),
_remove_unsupported_var_attrs(
solver,
{
"x[0]": Variable(value=1.0),
"x[1]": Variable(value=0.0),
"x[2]": Variable(value=1.0),
"x[3]": Variable(value=1.0),
"z": Variable(value=61.0),
},
_round(solver.get_variables(with_static=False)),
_filter_attrs(
solver.get_variable_attrs(),
VariableFeatures(values=(1.0, 0.0, 1.0, 1.0, 61.0)),
),
)