Convert TrainingSample to dataclass

2021-04-05 20:36:04 -05:00
parent aeed338837
commit b11779817a
15 changed files with 122 additions and 129 deletions
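For context, here is a minimal sketch of what the new TrainingSample dataclass might look like, reconstructed only from the fields touched in the diff below. The exact type annotations and the Solution alias are assumptions for illustration, not the actual definitions in miplearn.types.

# Sketch only: field names come from the diff below; the type annotations
# and the Solution alias are assumptions, not the real miplearn.types code.
from dataclasses import dataclass
from typing import Dict, Optional

Solution = Dict[str, Dict[str, Optional[float]]]  # assumed alias

@dataclass
class TrainingSample:
    lp_log: Optional[str] = None
    lp_solution: Optional[Solution] = None
    lp_value: Optional[float] = None
    lower_bound: Optional[float] = None
    upper_bound: Optional[float] = None
    mip_log: Optional[str] = None
    solution: Optional[Solution] = None

With every field defaulting to None, the "was this recorded?" question becomes a None check instead of a dictionary key lookup, which is exactly the pattern the later hunks switch to.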

@@ -2,13 +2,9 @@
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
import gzip
import logging
import os
import pickle
import traceback
import tempfile
from typing import Optional, List, Any, IO, cast, BinaryIO, Union, Callable, Dict
from typing import Optional, List, Any, cast, Callable, Dict
from p_tqdm import p_map
@@ -22,7 +18,7 @@ from miplearn.instance import Instance, PickleGzInstance
from miplearn.solvers import _RedirectOutput
from miplearn.solvers.internal import InternalSolver
from miplearn.solvers.pyomo.gurobi import GurobiPyomoSolver
from miplearn.types import TrainingSample, LearningSolveStats, MIPSolveStats
from miplearn.types import TrainingSample, LearningSolveStats
logger = logging.getLogger(__name__)
@@ -134,7 +130,7 @@ class LearningSolver:
model = instance.to_model()
# Initialize training sample
training_sample: TrainingSample = {}
training_sample = TrainingSample()
instance.training_data += [training_sample]
# Initialize stats
@@ -168,16 +164,13 @@ class LearningSolver:
logger.info("Solving root LP relaxation...")
lp_stats = self.internal_solver.solve_lp(tee=tee)
stats.update(cast(LearningSolveStats, lp_stats))
training_sample["LP solution"] = self.internal_solver.get_solution()
training_sample["LP value"] = lp_stats["LP value"]
training_sample["LP log"] = lp_stats["LP log"]
training_sample.lp_solution = self.internal_solver.get_solution()
training_sample.lp_value = lp_stats["LP value"]
training_sample.lp_log = lp_stats["LP log"]
logger.debug("Running after_solve_lp callbacks...")
for component in self.components.values():
component.after_solve_lp(*callback_args)
else:
training_sample["LP solution"] = self.internal_solver.get_empty_solution()
training_sample["LP value"] = 0.0
# Define wrappers
def iteration_cb_wrapper() -> bool:
@@ -213,8 +206,8 @@ class LearningSolver:
lazy_cb=lazy_cb,
)
stats.update(cast(LearningSolveStats, mip_stats))
if "LP value" in training_sample.keys():
stats["LP value"] = training_sample["LP value"]
if training_sample.lp_value is not None:
stats["LP value"] = training_sample.lp_value
stats["Solver"] = "default"
stats["Gap"] = self._compute_gap(
ub=stats["Upper bound"],
@@ -223,10 +216,10 @@ class LearningSolver:
stats["Mode"] = self.mode
# Add some information to training_sample
training_sample["Lower bound"] = stats["Lower bound"]
training_sample["Upper bound"] = stats["Upper bound"]
training_sample["MIP log"] = stats["MIP log"]
training_sample["Solution"] = self.internal_solver.get_solution()
training_sample.lower_bound = stats["Lower bound"]
training_sample.upper_bound = stats["Upper bound"]
training_sample.mip_log = stats["MIP log"]
training_sample.solution = self.internal_solver.get_solution()
# After-solve callbacks
logger.debug("Calling after_solve_mip callbacks...")
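The call-site migration visible in the last two hunks is: string-keyed lookups become attribute access, and the '"LP value" in training_sample.keys()' membership test becomes an is-not-None check on an Optional field. A small hypothetical helper (summarize is not part of this commit) illustrating that pattern:

# Hypothetical helper, not part of the commit; shows the migrated access style.
from miplearn.types import TrainingSample

def summarize(sample: TrainingSample) -> str:
    # Before: '"LP value" in sample.keys()' and sample["LP value"]
    if sample.lp_value is not None:
        return "LP value: %.2f" % sample.lp_value
    return "LP relaxation not solved"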