|
|
|
@ -5,7 +5,7 @@
|
|
|
|
|
import logging
|
|
|
|
|
import time
|
|
|
|
|
import traceback
|
|
|
|
|
from typing import Optional, List, Any, cast, Dict, Tuple, Callable, IO
|
|
|
|
|
from typing import Optional, List, Any, cast, Dict, Tuple, Callable, IO, Union
|
|
|
|
|
|
|
|
|
|
from overrides import overrides
|
|
|
|
|
from p_tqdm import p_map
|
|
|
|
@ -24,13 +24,18 @@ from miplearn.solvers.pyomo.gurobi import GurobiPyomoSolver
|
|
|
|
|
from miplearn.types import LearningSolveStats
|
|
|
|
|
import gzip
|
|
|
|
|
import pickle
|
|
|
|
|
import miplearn
|
|
|
|
|
from os.path import exists
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InstanceWrapper(Instance):
|
|
|
|
|
def __init__(self, data_filename: Any, build_model: Callable):
|
|
|
|
|
class FileInstanceWrapper(Instance):
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
data_filename: Any,
|
|
|
|
|
build_model: Callable,
|
|
|
|
|
):
|
|
|
|
|
super().__init__()
|
|
|
|
|
assert data_filename.endswith(".pkl.gz")
|
|
|
|
|
self.filename = data_filename
|
|
|
|
@ -43,9 +48,7 @@ class InstanceWrapper(Instance):
|
|
|
|
|
|
|
|
|
|
@overrides
|
|
|
|
|
def to_model(self) -> Any:
|
|
|
|
|
with gzip.GzipFile(self.filename, "rb") as file:
|
|
|
|
|
data = pickle.load(cast(IO[bytes], file))
|
|
|
|
|
return self.build_model(data)
|
|
|
|
|
return miplearn.load(self.filename, self.build_model)
|
|
|
|
|
|
|
|
|
|
@overrides
|
|
|
|
|
def create_sample(self) -> Sample:
|
|
|
|
@ -56,6 +59,17 @@ class InstanceWrapper(Instance):
|
|
|
|
|
return [self.sample]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MemoryInstanceWrapper(Instance):
|
|
|
|
|
def __init__(self, model):
|
|
|
|
|
super().__init__()
|
|
|
|
|
assert model is not None
|
|
|
|
|
self.model = model
|
|
|
|
|
|
|
|
|
|
@overrides
|
|
|
|
|
def to_model(self) -> Any:
|
|
|
|
|
return self.model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _GlobalVariables:
|
|
|
|
|
def __init__(self) -> None:
|
|
|
|
|
self.solver: Optional[LearningSolver] = None
|
|
|
|
@ -361,18 +375,24 @@ class LearningSolver:
|
|
|
|
|
|
|
|
|
|
def solve(
|
|
|
|
|
self,
|
|
|
|
|
filenames: List[str],
|
|
|
|
|
build_model: Callable,
|
|
|
|
|
tee: bool = True,
|
|
|
|
|
arg: Union[Any, List[str]],
|
|
|
|
|
build_model: Callable = None,
|
|
|
|
|
tee: bool = False,
|
|
|
|
|
) -> List[LearningSolveStats]:
|
|
|
|
|
stats = []
|
|
|
|
|
for f in filenames:
|
|
|
|
|
s = self._solve(InstanceWrapper(f, build_model), tee=tee)
|
|
|
|
|
stats.append(s)
|
|
|
|
|
return stats
|
|
|
|
|
if isinstance(arg, list):
|
|
|
|
|
assert build_model is not None
|
|
|
|
|
stats = []
|
|
|
|
|
for i in arg:
|
|
|
|
|
s = self._solve(FileInstanceWrapper(i, build_model), tee=tee)
|
|
|
|
|
stats.append(s)
|
|
|
|
|
return stats
|
|
|
|
|
else:
|
|
|
|
|
return self._solve(MemoryInstanceWrapper(arg), tee=tee)
|
|
|
|
|
|
|
|
|
|
def fit(self, filenames: List[str], build_model: Callable) -> None:
|
|
|
|
|
instances: List[Instance] = [InstanceWrapper(f, build_model) for f in filenames]
|
|
|
|
|
instances: List[Instance] = [
|
|
|
|
|
FileInstanceWrapper(f, build_model) for f in filenames
|
|
|
|
|
]
|
|
|
|
|
self._fit(instances)
|
|
|
|
|
|
|
|
|
|
def parallel_solve(
|
|
|
|
|