Implement load; update fit

master
Alinson S. Xavier 4 years ago
parent 522f3a7e18
commit 04dd3ad5d5
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -20,6 +20,7 @@ from .instance.picklegz import (
read_pickle_gz, read_pickle_gz,
write_pickle_gz_multiple, write_pickle_gz_multiple,
save, save,
load,
) )
from .log import setup_logger from .log import setup_logger
from .solvers.gurobi import GurobiSolver from .solvers.gurobi import GurobiSolver

@ -6,7 +6,7 @@ import gc
import gzip import gzip
import os import os
import pickle import pickle
from typing import Optional, Any, List, cast, IO, TYPE_CHECKING, Dict from typing import Optional, Any, List, cast, IO, TYPE_CHECKING, Dict, Callable
import numpy as np import numpy as np
from overrides import overrides from overrides import overrides
@ -177,3 +177,9 @@ def save(objs: List[Any], dirname: str) -> List[str]:
filenames.append(filename) filenames.append(filename)
write_pickle_gz(obj, filename) write_pickle_gz(obj, filename)
return filenames return filenames
def load(filename: str, build_model: Callable) -> Any:
with gzip.GzipFile(filename, "rb") as file:
data = pickle.load(cast(IO[bytes], file))
return build_model(data)

@ -5,7 +5,7 @@
import logging import logging
import time import time
import traceback import traceback
from typing import Optional, List, Any, cast, Dict, Tuple, Callable, IO from typing import Optional, List, Any, cast, Dict, Tuple, Callable, IO, Union
from overrides import overrides from overrides import overrides
from p_tqdm import p_map from p_tqdm import p_map
@ -24,13 +24,18 @@ from miplearn.solvers.pyomo.gurobi import GurobiPyomoSolver
from miplearn.types import LearningSolveStats from miplearn.types import LearningSolveStats
import gzip import gzip
import pickle import pickle
import miplearn
from os.path import exists from os.path import exists
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class InstanceWrapper(Instance): class FileInstanceWrapper(Instance):
def __init__(self, data_filename: Any, build_model: Callable): def __init__(
self,
data_filename: Any,
build_model: Callable,
):
super().__init__() super().__init__()
assert data_filename.endswith(".pkl.gz") assert data_filename.endswith(".pkl.gz")
self.filename = data_filename self.filename = data_filename
@ -43,9 +48,7 @@ class InstanceWrapper(Instance):
@overrides @overrides
def to_model(self) -> Any: def to_model(self) -> Any:
with gzip.GzipFile(self.filename, "rb") as file: return miplearn.load(self.filename, self.build_model)
data = pickle.load(cast(IO[bytes], file))
return self.build_model(data)
@overrides @overrides
def create_sample(self) -> Sample: def create_sample(self) -> Sample:
@ -56,6 +59,17 @@ class InstanceWrapper(Instance):
return [self.sample] return [self.sample]
class MemoryInstanceWrapper(Instance):
def __init__(self, model):
super().__init__()
assert model is not None
self.model = model
@overrides
def to_model(self) -> Any:
return self.model
class _GlobalVariables: class _GlobalVariables:
def __init__(self) -> None: def __init__(self) -> None:
self.solver: Optional[LearningSolver] = None self.solver: Optional[LearningSolver] = None
@ -361,18 +375,24 @@ class LearningSolver:
def solve( def solve(
self, self,
filenames: List[str], arg: Union[Any, List[str]],
build_model: Callable, build_model: Callable = None,
tee: bool = True, tee: bool = False,
) -> List[LearningSolveStats]: ) -> List[LearningSolveStats]:
stats = [] if isinstance(arg, list):
for f in filenames: assert build_model is not None
s = self._solve(InstanceWrapper(f, build_model), tee=tee) stats = []
stats.append(s) for i in arg:
return stats s = self._solve(FileInstanceWrapper(i, build_model), tee=tee)
stats.append(s)
return stats
else:
return self._solve(MemoryInstanceWrapper(arg), tee=tee)
def fit(self, filenames: List[str], build_model: Callable) -> None: def fit(self, filenames: List[str], build_model: Callable) -> None:
instances: List[Instance] = [InstanceWrapper(f, build_model) for f in filenames] instances: List[Instance] = [
FileInstanceWrapper(f, build_model) for f in filenames
]
self._fit(instances) self._fit(instances)
def parallel_solve( def parallel_solve(

@ -322,8 +322,14 @@ class BasePyomoSolver(InternalSolver):
# Bounds # Bounds
lb, ub = v.bounds lb, ub = v.bounds
upper_bounds.append(float(ub)) if ub is not None:
lower_bounds.append(float(lb)) upper_bounds.append(float(ub))
else:
upper_bounds.append(float("inf"))
if lb is not None:
lower_bounds.append(float(lb))
else:
lower_bounds.append(-float("inf"))
# Objective coefficient # Objective coefficient
if v.name in self._obj: if v.name in self._obj:
@ -391,7 +397,9 @@ class BasePyomoSolver(InternalSolver):
) -> None: ) -> None:
if model is None: if model is None:
model = instance.to_model() model = instance.to_model()
assert isinstance(model, pe.ConcreteModel) assert isinstance(
model, pe.ConcreteModel
), f"expected pe.ConcreteModel; found {model.__class__} instead"
self.instance = instance self.instance = instance
self.model = model self.model = model
self.model.extra_constraints = ConstraintList() self.model.extra_constraints = ConstraintList()

Loading…
Cancel
Save