Convert TrainingSample to dataclass

This commit is contained in:
2021-04-05 20:36:04 -05:00
parent aeed338837
commit b11779817a
15 changed files with 122 additions and 129 deletions

View File

@@ -93,7 +93,7 @@ class StaticLazyConstraintsComponent(Component):
features: Features,
training_data: TrainingSample,
) -> None:
training_data["LazyStatic: Enforced"] = self.enforced_cids
training_data.lazy_enforced = self.enforced_cids
stats["LazyStatic: Restored"] = self.n_restored
stats["LazyStatic: Iterations"] = self.n_iterations
@@ -188,8 +188,8 @@ class StaticLazyConstraintsComponent(Component):
x[category] = []
y[category] = []
x[category] += [cfeatures.user_features]
if "LazyStatic: Enforced" in sample:
if cid in sample["LazyStatic: Enforced"]:
if sample.lazy_enforced is not None:
if cid in sample.lazy_enforced:
y[category] += [[False, True]]
else:
y[category] += [[True, False]]

View File

@@ -82,12 +82,14 @@ class ObjectiveValueComponent(Component):
x: Dict[Hashable, List[List[float]]] = {}
y: Dict[Hashable, List[List[float]]] = {}
f = list(features.instance.user_features)
if "LP value" in sample and sample["LP value"] is not None:
f += [sample["LP value"]]
for c in ["Upper bound", "Lower bound"]:
x[c] = [f]
if c in sample and sample[c] is not None: # type: ignore
y[c] = [[sample[c]]] # type: ignore
if sample.lp_value is not None:
f += [sample.lp_value]
x["Upper bound"] = [f]
x["Lower bound"] = [f]
if sample.lower_bound is not None:
y["Lower bound"] = [[sample.lower_bound]]
if sample.upper_bound is not None:
y["Upper bound"] = [[sample.upper_bound]]
return x, y
def sample_evaluate(
@@ -106,7 +108,8 @@ class ObjectiveValueComponent(Component):
result: Dict[Hashable, Dict[str, float]] = {}
pred = self.sample_predict(features, sample)
for c in ["Upper bound", "Lower bound"]:
if c in sample and sample[c] is not None: # type: ignore
result[c] = compare(pred[c], sample[c]) # type: ignore
if sample.upper_bound is not None:
result["Upper bound"] = compare(pred["Upper bound"], sample.upper_bound)
if sample.lower_bound is not None:
result["Lower bound"] = compare(pred["Lower bound"], sample.lower_bound)
return result

View File

@@ -155,8 +155,8 @@ class PrimalSolutionComponent(Component):
x: Dict = {}
y: Dict = {}
solution: Optional[Solution] = None
if "Solution" in sample and sample["Solution"] is not None:
solution = sample["Solution"]
if sample.solution is not None:
solution = sample.solution
for (var_name, var_dict) in features.variables.items():
for (idx, var_features) in var_dict.items():
category = var_features.category
@@ -168,8 +168,8 @@ class PrimalSolutionComponent(Component):
f: List[float] = []
assert var_features.user_features is not None
f += var_features.user_features
if "LP solution" in sample and sample["LP solution"] is not None:
lp_value = sample["LP solution"][var_name][idx]
if sample.lp_solution is not None:
lp_value = sample.lp_solution[var_name][idx]
if lp_value is not None:
f += [lp_value]
x[category] += [f]
@@ -190,7 +190,7 @@ class PrimalSolutionComponent(Component):
features: Features,
sample: TrainingSample,
) -> Dict[Hashable, Dict[str, float]]:
solution_actual = sample["Solution"]
solution_actual = sample.solution
assert solution_actual is not None
solution_pred = self.sample_predict(features, sample)
vars_all, vars_one, vars_zero = set(), set(), set()

View File

@@ -95,8 +95,8 @@ class ConvertTightIneqsIntoEqsStep(Component):
features,
training_data,
):
if "slacks" not in training_data.keys():
training_data["slacks"] = solver.internal_solver.get_inequality_slacks()
if training_data.slacks is None:
training_data.slacks = solver.internal_solver.get_inequality_slacks()
stats["ConvertTight: Restored"] = self.n_restored
stats["ConvertTight: Inf iterations"] = self.n_infeasible_iterations
stats["ConvertTight: Subopt iterations"] = self.n_suboptimal_iterations
@@ -120,7 +120,7 @@ class ConvertTightIneqsIntoEqsStep(Component):
disable=len(instances) < 5,
):
for training_data in instance.training_data:
cids = training_data["slacks"].keys()
cids = training_data.slacks.keys()
for cid in cids:
category = instance.get_constraint_category(cid)
if category is None:
@@ -142,7 +142,7 @@ class ConvertTightIneqsIntoEqsStep(Component):
desc="Extract (rlx:conv_ineqs:y)",
disable=len(instances) < 5,
):
for (cid, slack) in instance.training_data[0]["slacks"].items():
for (cid, slack) in instance.training_data[0].slacks.items():
category = instance.get_constraint_category(cid)
if category is None:
continue

View File

@@ -96,8 +96,8 @@ class DropRedundantInequalitiesStep(Component):
features,
training_data,
):
if "slacks" not in training_data.keys():
training_data["slacks"] = solver.internal_solver.get_inequality_slacks()
if training_data.slacks is None:
training_data.slacks = solver.internal_solver.get_inequality_slacks()
stats["DropRedundant: Iterations"] = self.n_iterations
stats["DropRedundant: Restored"] = self.n_restored
@@ -131,7 +131,7 @@ class DropRedundantInequalitiesStep(Component):
x = {}
y = {}
for training_data in instance.training_data:
for (cid, slack) in training_data["slacks"].items():
for (cid, slack) in training_data.slacks.items():
category = instance.get_constraint_category(cid)
if category is None:
continue

View File

@@ -18,7 +18,7 @@ class Extractor(ABC):
@staticmethod
def split_variables(instance):
result = {}
lp_solution = instance.training_data[0]["LP solution"]
lp_solution = instance.training_data[0].lp_solution
for var_name in lp_solution:
for index in lp_solution[var_name]:
category = instance.get_variable_category(var_name, index)
@@ -37,7 +37,7 @@ class InstanceFeaturesExtractor(Extractor):
np.hstack(
[
instance.get_instance_features(),
instance.training_data[0]["LP value"],
instance.training_data[0].lp_value,
]
)
for instance in instances

View File

@@ -2,13 +2,9 @@
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
import gzip
import logging
import os
import pickle
import traceback
import tempfile
from typing import Optional, List, Any, IO, cast, BinaryIO, Union, Callable, Dict
from typing import Optional, List, Any, cast, Callable, Dict
from p_tqdm import p_map
@@ -22,7 +18,7 @@ from miplearn.instance import Instance, PickleGzInstance
from miplearn.solvers import _RedirectOutput
from miplearn.solvers.internal import InternalSolver
from miplearn.solvers.pyomo.gurobi import GurobiPyomoSolver
from miplearn.types import TrainingSample, LearningSolveStats, MIPSolveStats
from miplearn.types import TrainingSample, LearningSolveStats
logger = logging.getLogger(__name__)
@@ -134,7 +130,7 @@ class LearningSolver:
model = instance.to_model()
# Initialize training sample
training_sample: TrainingSample = {}
training_sample = TrainingSample()
instance.training_data += [training_sample]
# Initialize stats
@@ -168,16 +164,13 @@ class LearningSolver:
logger.info("Solving root LP relaxation...")
lp_stats = self.internal_solver.solve_lp(tee=tee)
stats.update(cast(LearningSolveStats, lp_stats))
training_sample["LP solution"] = self.internal_solver.get_solution()
training_sample["LP value"] = lp_stats["LP value"]
training_sample["LP log"] = lp_stats["LP log"]
training_sample.lp_solution = self.internal_solver.get_solution()
training_sample.lp_value = lp_stats["LP value"]
training_sample.lp_log = lp_stats["LP log"]
logger.debug("Running after_solve_lp callbacks...")
for component in self.components.values():
component.after_solve_lp(*callback_args)
else:
training_sample["LP solution"] = self.internal_solver.get_empty_solution()
training_sample["LP value"] = 0.0
# Define wrappers
def iteration_cb_wrapper() -> bool:
@@ -213,8 +206,8 @@ class LearningSolver:
lazy_cb=lazy_cb,
)
stats.update(cast(LearningSolveStats, mip_stats))
if "LP value" in training_sample.keys():
stats["LP value"] = training_sample["LP value"]
if training_sample.lp_value is not None:
stats["LP value"] = training_sample.lp_value
stats["Solver"] = "default"
stats["Gap"] = self._compute_gap(
ub=stats["Upper bound"],
@@ -223,10 +216,10 @@ class LearningSolver:
stats["Mode"] = self.mode
# Add some information to training_sample
training_sample["Lower bound"] = stats["Lower bound"]
training_sample["Upper bound"] = stats["Upper bound"]
training_sample["MIP log"] = stats["MIP log"]
training_sample["Solution"] = self.internal_solver.get_solution()
training_sample.lower_bound = stats["Lower bound"]
training_sample.upper_bound = stats["Upper bound"]
training_sample.mip_log = stats["MIP log"]
training_sample.solution = self.internal_solver.get_solution()
# After-solve callbacks
logger.debug("Calling after_solve_mip callbacks...")

View File

@@ -11,22 +11,19 @@ VarIndex = Union[str, int, Tuple[Union[str, int]]]
Solution = Dict[str, Dict[VarIndex, Optional[float]]]
TrainingSample = TypedDict(
"TrainingSample",
{
"LP log": str,
"LP solution": Optional[Solution],
"LP value": Optional[float],
"LazyStatic: All": Set[str],
"LazyStatic: Enforced": Set[str],
"Lower bound": Optional[float],
"MIP log": str,
"Solution": Optional[Solution],
"Upper bound": Optional[float],
"slacks": Dict,
},
total=False,
)
@dataclass
class TrainingSample:
    """Data collected while solving a single training instance.

    Replaces the previous ``TrainingSample`` TypedDict. Every field
    defaults to ``None``, meaning "not collected yet"; consumers test
    for ``None`` before reading a value (e.g.
    ``if sample.lp_value is not None``).
    """

    # Solver log of the root LP relaxation (copied from lp_stats["LP log"]).
    lp_log: Optional[str] = None
    # Solution of the root LP relaxation, as returned by
    # internal_solver.get_solution() after solve_lp.
    lp_solution: Optional[Solution] = None
    # Optimal value of the root LP relaxation.
    lp_value: Optional[float] = None
    # IDs of the static lazy constraints that were enforced during the
    # MIP solve (set by StaticLazyConstraintsComponent).
    lazy_enforced: Optional[Set[str]] = None
    # MIP lower bound (copied from stats["Lower bound"]).
    lower_bound: Optional[float] = None
    # Solver log of the MIP solve (copied from stats["MIP log"]).
    mip_log: Optional[str] = None
    # Best MIP solution found, as returned by
    # internal_solver.get_solution() after solve.
    solution: Optional[Solution] = None
    # MIP upper bound (copied from stats["Upper bound"]).
    upper_bound: Optional[float] = None
    # Inequality slacks keyed by constraint id, as returned by
    # internal_solver.get_inequality_slacks().
    # NOTE(review): the old TypedDict also declared "LazyStatic: All",
    # which has no counterpart here — confirm it was dropped on purpose.
    slacks: Optional[Dict[str, float]] = None
LPSolveStats = TypedDict(
"LPSolveStats",