Implement load; update fit

This commit is contained in:
2022-02-25 08:26:33 -06:00
parent 522f3a7e18
commit 04dd3ad5d5
4 changed files with 54 additions and 19 deletions

View File

@@ -5,7 +5,7 @@
import logging
import time
import traceback
from typing import Optional, List, Any, cast, Dict, Tuple, Callable, IO
from typing import Optional, List, Any, cast, Dict, Tuple, Callable, IO, Union
from overrides import overrides
from p_tqdm import p_map
@@ -24,13 +24,18 @@ from miplearn.solvers.pyomo.gurobi import GurobiPyomoSolver
from miplearn.types import LearningSolveStats
import gzip
import pickle
import miplearn
from os.path import exists
logger = logging.getLogger(__name__)
class InstanceWrapper(Instance):
def __init__(self, data_filename: Any, build_model: Callable):
class FileInstanceWrapper(Instance):
def __init__(
self,
data_filename: Any,
build_model: Callable,
):
super().__init__()
assert data_filename.endswith(".pkl.gz")
self.filename = data_filename
@@ -43,9 +48,7 @@ class InstanceWrapper(Instance):
@overrides
def to_model(self) -> Any:
with gzip.GzipFile(self.filename, "rb") as file:
data = pickle.load(cast(IO[bytes], file))
return self.build_model(data)
return miplearn.load(self.filename, self.build_model)
@overrides
def create_sample(self) -> Sample:
@@ -56,6 +59,17 @@ class InstanceWrapper(Instance):
return [self.sample]
class MemoryInstanceWrapper(Instance):
def __init__(self, model):
super().__init__()
assert model is not None
self.model = model
@overrides
def to_model(self) -> Any:
return self.model
class _GlobalVariables:
def __init__(self) -> None:
self.solver: Optional[LearningSolver] = None
@@ -361,18 +375,24 @@ class LearningSolver:
def solve(
self,
filenames: List[str],
build_model: Callable,
tee: bool = True,
arg: Union[Any, List[str]],
build_model: Callable = None,
tee: bool = False,
) -> List[LearningSolveStats]:
stats = []
for f in filenames:
s = self._solve(InstanceWrapper(f, build_model), tee=tee)
stats.append(s)
return stats
if isinstance(arg, list):
assert build_model is not None
stats = []
for i in arg:
s = self._solve(FileInstanceWrapper(i, build_model), tee=tee)
stats.append(s)
return stats
else:
return self._solve(MemoryInstanceWrapper(arg), tee=tee)
def fit(self, filenames: List[str], build_model: Callable) -> None:
instances: List[Instance] = [InstanceWrapper(f, build_model) for f in filenames]
instances: List[Instance] = [
FileInstanceWrapper(f, build_model) for f in filenames
]
self._fit(instances)
def parallel_solve(

View File

@@ -322,8 +322,14 @@ class BasePyomoSolver(InternalSolver):
# Bounds
lb, ub = v.bounds
upper_bounds.append(float(ub))
lower_bounds.append(float(lb))
if ub is not None:
upper_bounds.append(float(ub))
else:
upper_bounds.append(float("inf"))
if lb is not None:
lower_bounds.append(float(lb))
else:
lower_bounds.append(-float("inf"))
# Objective coefficient
if v.name in self._obj:
@@ -391,7 +397,9 @@ class BasePyomoSolver(InternalSolver):
) -> None:
if model is None:
model = instance.to_model()
assert isinstance(model, pe.ConcreteModel)
assert isinstance(
model, pe.ConcreteModel
), f"expected pe.ConcreteModel; found {model.__class__} instead"
self.instance = instance
self.model = model
self.model.extra_constraints = ConstraintList()