mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Implement more compact get_variables
This commit is contained in:
@@ -2,10 +2,10 @@
|
||||
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, List
|
||||
import numpy as np
|
||||
|
||||
from miplearn.features import Constraint, Variable
|
||||
from miplearn.features import Constraint, Variable, VariableFeatures
|
||||
from miplearn.solvers.internal import InternalSolver
|
||||
|
||||
inf = float("inf")
|
||||
@@ -44,6 +44,30 @@ def _round_variables(vars: Dict[str, Variable]) -> Dict[str, Variable]:
|
||||
return vars
|
||||
|
||||
|
||||
def _round(obj: Any) -> Any:
|
||||
if isinstance(obj, tuple):
|
||||
if obj is None:
|
||||
return None
|
||||
return tuple([round(v, 6) for v in obj])
|
||||
if isinstance(obj, VariableFeatures):
|
||||
obj.reduced_costs = _round(obj.reduced_costs)
|
||||
obj.sa_obj_up = _round(obj.sa_obj_up)
|
||||
obj.sa_obj_down = _round(obj.sa_obj_down)
|
||||
obj.sa_lb_up = _round(obj.sa_lb_up)
|
||||
obj.sa_lb_down = _round(obj.sa_lb_down)
|
||||
obj.sa_ub_up = _round(obj.sa_ub_up)
|
||||
obj.sa_ub_down = _round(obj.sa_ub_down)
|
||||
obj.values = _round(obj.values)
|
||||
return obj
|
||||
|
||||
|
||||
def _filter_attrs(allowed_keys: List[str], obj: Any) -> Any:
|
||||
for key in obj.__dict__.keys():
|
||||
if key not in allowed_keys:
|
||||
setattr(obj, key, None)
|
||||
return obj
|
||||
|
||||
|
||||
def _remove_unsupported_constr_attrs(
|
||||
solver: InternalSolver,
|
||||
constraints: Dict[str, Constraint],
|
||||
@@ -58,20 +82,6 @@ def _remove_unsupported_constr_attrs(
|
||||
return constraints
|
||||
|
||||
|
||||
def _remove_unsupported_var_attrs(
|
||||
solver: InternalSolver,
|
||||
variables: Dict[str, Variable],
|
||||
) -> Dict[str, Variable]:
|
||||
for (cname, c) in variables.items():
|
||||
to_remove = []
|
||||
for k in c.__dict__.keys():
|
||||
if k not in solver.get_variable_attrs():
|
||||
to_remove.append(k)
|
||||
for k in to_remove:
|
||||
setattr(c, k, None)
|
||||
return variables
|
||||
|
||||
|
||||
def run_internal_solver_tests(solver: InternalSolver) -> None:
|
||||
run_basic_usage_tests(solver.clone())
|
||||
run_warm_start_tests(solver.clone())
|
||||
@@ -89,41 +99,13 @@ def run_basic_usage_tests(solver: InternalSolver) -> None:
|
||||
|
||||
# Fetch variables (after-load)
|
||||
assert_equals(
|
||||
_round_variables(solver.get_variables()),
|
||||
_remove_unsupported_var_attrs(
|
||||
solver,
|
||||
{
|
||||
"x[0]": Variable(
|
||||
lower_bound=0.0,
|
||||
obj_coeff=505.0,
|
||||
type="B",
|
||||
upper_bound=1.0,
|
||||
),
|
||||
"x[1]": Variable(
|
||||
lower_bound=0.0,
|
||||
obj_coeff=352.0,
|
||||
type="B",
|
||||
upper_bound=1.0,
|
||||
),
|
||||
"x[2]": Variable(
|
||||
lower_bound=0.0,
|
||||
obj_coeff=458.0,
|
||||
type="B",
|
||||
upper_bound=1.0,
|
||||
),
|
||||
"x[3]": Variable(
|
||||
lower_bound=0.0,
|
||||
obj_coeff=220.0,
|
||||
type="B",
|
||||
upper_bound=1.0,
|
||||
),
|
||||
"z": Variable(
|
||||
lower_bound=0.0,
|
||||
obj_coeff=0.0,
|
||||
type="C",
|
||||
upper_bound=67.0,
|
||||
),
|
||||
},
|
||||
solver.get_variables(),
|
||||
VariableFeatures(
|
||||
names=("x[0]", "x[1]", "x[2]", "x[3]", "z"),
|
||||
lower_bounds=(0.0, 0.0, 0.0, 0.0, 0.0),
|
||||
upper_bounds=(1.0, 1.0, 1.0, 1.0, 67.0),
|
||||
types=("B", "B", "B", "B", "C"),
|
||||
obj_coeffs=(505.0, 352.0, 458.0, 220.0, 0.0),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -150,88 +132,22 @@ def run_basic_usage_tests(solver: InternalSolver) -> None:
|
||||
assert lp_stats.lp_wallclock_time is not None
|
||||
assert lp_stats.lp_wallclock_time > 0
|
||||
|
||||
# Fetch variables (after-load)
|
||||
# Fetch variables (after-lp)
|
||||
assert_equals(
|
||||
_round_variables(solver.get_variables()),
|
||||
_remove_unsupported_var_attrs(
|
||||
solver,
|
||||
{
|
||||
"x[0]": Variable(
|
||||
basis_status="U",
|
||||
lower_bound=0.0,
|
||||
obj_coeff=505.0,
|
||||
reduced_cost=193.615385,
|
||||
sa_lb_down=-inf,
|
||||
sa_lb_up=1.0,
|
||||
sa_obj_down=311.384615,
|
||||
sa_obj_up=inf,
|
||||
sa_ub_down=0.913043,
|
||||
sa_ub_up=2.043478,
|
||||
type="B",
|
||||
upper_bound=1.0,
|
||||
value=1.0,
|
||||
),
|
||||
"x[1]": Variable(
|
||||
basis_status="B",
|
||||
lower_bound=0.0,
|
||||
obj_coeff=352.0,
|
||||
reduced_cost=0.0,
|
||||
sa_lb_down=-inf,
|
||||
sa_lb_up=0.923077,
|
||||
sa_obj_down=317.777778,
|
||||
sa_obj_up=570.869565,
|
||||
sa_ub_down=0.923077,
|
||||
sa_ub_up=inf,
|
||||
type="B",
|
||||
upper_bound=1.0,
|
||||
value=0.923077,
|
||||
),
|
||||
"x[2]": Variable(
|
||||
basis_status="U",
|
||||
lower_bound=0.0,
|
||||
obj_coeff=458.0,
|
||||
reduced_cost=187.230769,
|
||||
sa_lb_down=-inf,
|
||||
sa_lb_up=1.0,
|
||||
sa_obj_down=270.769231,
|
||||
sa_obj_up=inf,
|
||||
sa_ub_down=0.9,
|
||||
sa_ub_up=2.2,
|
||||
type="B",
|
||||
upper_bound=1.0,
|
||||
value=1.0,
|
||||
),
|
||||
"x[3]": Variable(
|
||||
basis_status="L",
|
||||
lower_bound=0.0,
|
||||
obj_coeff=220.0,
|
||||
reduced_cost=-23.692308,
|
||||
sa_lb_down=-0.111111,
|
||||
sa_lb_up=1.0,
|
||||
sa_obj_down=-inf,
|
||||
sa_obj_up=243.692308,
|
||||
sa_ub_down=0.0,
|
||||
sa_ub_up=inf,
|
||||
type="B",
|
||||
upper_bound=1.0,
|
||||
value=0.0,
|
||||
),
|
||||
"z": Variable(
|
||||
basis_status="U",
|
||||
lower_bound=0.0,
|
||||
obj_coeff=0.0,
|
||||
reduced_cost=13.538462,
|
||||
sa_lb_down=-inf,
|
||||
sa_lb_up=67.0,
|
||||
sa_obj_down=-13.538462,
|
||||
sa_obj_up=inf,
|
||||
sa_ub_down=43.0,
|
||||
sa_ub_up=69.0,
|
||||
type="C",
|
||||
upper_bound=67.0,
|
||||
value=67.0,
|
||||
),
|
||||
},
|
||||
_round(solver.get_variables(with_static=False)),
|
||||
_filter_attrs(
|
||||
solver.get_variable_attrs(),
|
||||
VariableFeatures(
|
||||
basis_status=("U", "B", "U", "L", "U"),
|
||||
reduced_costs=(193.615385, 0.0, 187.230769, -23.692308, 13.538462),
|
||||
sa_lb_down=(-inf, -inf, -inf, -0.111111, -inf),
|
||||
sa_lb_up=(1.0, 0.923077, 1.0, 1.0, 67.0),
|
||||
sa_obj_down=(311.384615, 317.777778, 270.769231, -inf, -13.538462),
|
||||
sa_obj_up=(inf, 570.869565, inf, 243.692308, inf),
|
||||
sa_ub_down=(0.913043, 0.923077, 0.9, 0.0, 43.0),
|
||||
sa_ub_up=(2.043478, inf, 2.2, inf, 69.0),
|
||||
values=(1.0, 0.923077, 1.0, 0.0, 67.0),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -281,16 +197,10 @@ def run_basic_usage_tests(solver: InternalSolver) -> None:
|
||||
|
||||
# Fetch variables (after-mip)
|
||||
assert_equals(
|
||||
_round_variables(solver.get_variables(with_static=False)),
|
||||
_remove_unsupported_var_attrs(
|
||||
solver,
|
||||
{
|
||||
"x[0]": Variable(value=1.0),
|
||||
"x[1]": Variable(value=0.0),
|
||||
"x[2]": Variable(value=1.0),
|
||||
"x[3]": Variable(value=1.0),
|
||||
"z": Variable(value=61.0),
|
||||
},
|
||||
_round(solver.get_variables(with_static=False)),
|
||||
_filter_attrs(
|
||||
solver.get_variable_attrs(),
|
||||
VariableFeatures(values=(1.0, 0.0, 1.0, 1.0, 61.0)),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user