mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Convert TrainingSample to dataclass
This commit is contained in:
@@ -2,13 +2,9 @@
|
||||
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import gzip
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import traceback
|
||||
import tempfile
|
||||
from typing import Optional, List, Any, IO, cast, BinaryIO, Union, Callable, Dict
|
||||
from typing import Optional, List, Any, cast, Callable, Dict
|
||||
|
||||
from p_tqdm import p_map
|
||||
|
||||
@@ -22,7 +18,7 @@ from miplearn.instance import Instance, PickleGzInstance
|
||||
from miplearn.solvers import _RedirectOutput
|
||||
from miplearn.solvers.internal import InternalSolver
|
||||
from miplearn.solvers.pyomo.gurobi import GurobiPyomoSolver
|
||||
from miplearn.types import TrainingSample, LearningSolveStats, MIPSolveStats
|
||||
from miplearn.types import TrainingSample, LearningSolveStats
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -134,7 +130,7 @@ class LearningSolver:
|
||||
model = instance.to_model()
|
||||
|
||||
# Initialize training sample
|
||||
training_sample: TrainingSample = {}
|
||||
training_sample = TrainingSample()
|
||||
instance.training_data += [training_sample]
|
||||
|
||||
# Initialize stats
|
||||
@@ -168,16 +164,13 @@ class LearningSolver:
|
||||
logger.info("Solving root LP relaxation...")
|
||||
lp_stats = self.internal_solver.solve_lp(tee=tee)
|
||||
stats.update(cast(LearningSolveStats, lp_stats))
|
||||
training_sample["LP solution"] = self.internal_solver.get_solution()
|
||||
training_sample["LP value"] = lp_stats["LP value"]
|
||||
training_sample["LP log"] = lp_stats["LP log"]
|
||||
training_sample.lp_solution = self.internal_solver.get_solution()
|
||||
training_sample.lp_value = lp_stats["LP value"]
|
||||
training_sample.lp_log = lp_stats["LP log"]
|
||||
|
||||
logger.debug("Running after_solve_lp callbacks...")
|
||||
for component in self.components.values():
|
||||
component.after_solve_lp(*callback_args)
|
||||
else:
|
||||
training_sample["LP solution"] = self.internal_solver.get_empty_solution()
|
||||
training_sample["LP value"] = 0.0
|
||||
|
||||
# Define wrappers
|
||||
def iteration_cb_wrapper() -> bool:
|
||||
@@ -213,8 +206,8 @@ class LearningSolver:
|
||||
lazy_cb=lazy_cb,
|
||||
)
|
||||
stats.update(cast(LearningSolveStats, mip_stats))
|
||||
if "LP value" in training_sample.keys():
|
||||
stats["LP value"] = training_sample["LP value"]
|
||||
if training_sample.lp_value is not None:
|
||||
stats["LP value"] = training_sample.lp_value
|
||||
stats["Solver"] = "default"
|
||||
stats["Gap"] = self._compute_gap(
|
||||
ub=stats["Upper bound"],
|
||||
@@ -223,10 +216,10 @@ class LearningSolver:
|
||||
stats["Mode"] = self.mode
|
||||
|
||||
# Add some information to training_sample
|
||||
training_sample["Lower bound"] = stats["Lower bound"]
|
||||
training_sample["Upper bound"] = stats["Upper bound"]
|
||||
training_sample["MIP log"] = stats["MIP log"]
|
||||
training_sample["Solution"] = self.internal_solver.get_solution()
|
||||
training_sample.lower_bound = stats["Lower bound"]
|
||||
training_sample.upper_bound = stats["Upper bound"]
|
||||
training_sample.mip_log = stats["MIP log"]
|
||||
training_sample.solution = self.internal_solver.get_solution()
|
||||
|
||||
# After-solve callbacks
|
||||
logger.debug("Calling after_solve_mip callbacks...")
|
||||
|
||||
Reference in New Issue
Block a user