mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Implement load; update fit
This commit is contained in:
@@ -20,6 +20,7 @@ from .instance.picklegz import (
|
|||||||
read_pickle_gz,
|
read_pickle_gz,
|
||||||
write_pickle_gz_multiple,
|
write_pickle_gz_multiple,
|
||||||
save,
|
save,
|
||||||
|
load,
|
||||||
)
|
)
|
||||||
from .log import setup_logger
|
from .log import setup_logger
|
||||||
from .solvers.gurobi import GurobiSolver
|
from .solvers.gurobi import GurobiSolver
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import gc
|
|||||||
import gzip
|
import gzip
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
from typing import Optional, Any, List, cast, IO, TYPE_CHECKING, Dict
|
from typing import Optional, Any, List, cast, IO, TYPE_CHECKING, Dict, Callable
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from overrides import overrides
|
from overrides import overrides
|
||||||
@@ -177,3 +177,9 @@ def save(objs: List[Any], dirname: str) -> List[str]:
|
|||||||
filenames.append(filename)
|
filenames.append(filename)
|
||||||
write_pickle_gz(obj, filename)
|
write_pickle_gz(obj, filename)
|
||||||
return filenames
|
return filenames
|
||||||
|
|
||||||
|
|
||||||
|
def load(filename: str, build_model: Callable) -> Any:
|
||||||
|
with gzip.GzipFile(filename, "rb") as file:
|
||||||
|
data = pickle.load(cast(IO[bytes], file))
|
||||||
|
return build_model(data)
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Optional, List, Any, cast, Dict, Tuple, Callable, IO
|
from typing import Optional, List, Any, cast, Dict, Tuple, Callable, IO, Union
|
||||||
|
|
||||||
from overrides import overrides
|
from overrides import overrides
|
||||||
from p_tqdm import p_map
|
from p_tqdm import p_map
|
||||||
@@ -24,13 +24,18 @@ from miplearn.solvers.pyomo.gurobi import GurobiPyomoSolver
|
|||||||
from miplearn.types import LearningSolveStats
|
from miplearn.types import LearningSolveStats
|
||||||
import gzip
|
import gzip
|
||||||
import pickle
|
import pickle
|
||||||
|
import miplearn
|
||||||
from os.path import exists
|
from os.path import exists
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class InstanceWrapper(Instance):
|
class FileInstanceWrapper(Instance):
|
||||||
def __init__(self, data_filename: Any, build_model: Callable):
|
def __init__(
|
||||||
|
self,
|
||||||
|
data_filename: Any,
|
||||||
|
build_model: Callable,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert data_filename.endswith(".pkl.gz")
|
assert data_filename.endswith(".pkl.gz")
|
||||||
self.filename = data_filename
|
self.filename = data_filename
|
||||||
@@ -43,9 +48,7 @@ class InstanceWrapper(Instance):
|
|||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def to_model(self) -> Any:
|
def to_model(self) -> Any:
|
||||||
with gzip.GzipFile(self.filename, "rb") as file:
|
return miplearn.load(self.filename, self.build_model)
|
||||||
data = pickle.load(cast(IO[bytes], file))
|
|
||||||
return self.build_model(data)
|
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def create_sample(self) -> Sample:
|
def create_sample(self) -> Sample:
|
||||||
@@ -56,6 +59,17 @@ class InstanceWrapper(Instance):
|
|||||||
return [self.sample]
|
return [self.sample]
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryInstanceWrapper(Instance):
|
||||||
|
def __init__(self, model):
|
||||||
|
super().__init__()
|
||||||
|
assert model is not None
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
@overrides
|
||||||
|
def to_model(self) -> Any:
|
||||||
|
return self.model
|
||||||
|
|
||||||
|
|
||||||
class _GlobalVariables:
|
class _GlobalVariables:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.solver: Optional[LearningSolver] = None
|
self.solver: Optional[LearningSolver] = None
|
||||||
@@ -361,18 +375,24 @@ class LearningSolver:
|
|||||||
|
|
||||||
def solve(
|
def solve(
|
||||||
self,
|
self,
|
||||||
filenames: List[str],
|
arg: Union[Any, List[str]],
|
||||||
build_model: Callable,
|
build_model: Callable = None,
|
||||||
tee: bool = True,
|
tee: bool = False,
|
||||||
) -> List[LearningSolveStats]:
|
) -> List[LearningSolveStats]:
|
||||||
stats = []
|
if isinstance(arg, list):
|
||||||
for f in filenames:
|
assert build_model is not None
|
||||||
s = self._solve(InstanceWrapper(f, build_model), tee=tee)
|
stats = []
|
||||||
stats.append(s)
|
for i in arg:
|
||||||
return stats
|
s = self._solve(FileInstanceWrapper(i, build_model), tee=tee)
|
||||||
|
stats.append(s)
|
||||||
|
return stats
|
||||||
|
else:
|
||||||
|
return self._solve(MemoryInstanceWrapper(arg), tee=tee)
|
||||||
|
|
||||||
def fit(self, filenames: List[str], build_model: Callable) -> None:
|
def fit(self, filenames: List[str], build_model: Callable) -> None:
|
||||||
instances: List[Instance] = [InstanceWrapper(f, build_model) for f in filenames]
|
instances: List[Instance] = [
|
||||||
|
FileInstanceWrapper(f, build_model) for f in filenames
|
||||||
|
]
|
||||||
self._fit(instances)
|
self._fit(instances)
|
||||||
|
|
||||||
def parallel_solve(
|
def parallel_solve(
|
||||||
|
|||||||
@@ -322,8 +322,14 @@ class BasePyomoSolver(InternalSolver):
|
|||||||
|
|
||||||
# Bounds
|
# Bounds
|
||||||
lb, ub = v.bounds
|
lb, ub = v.bounds
|
||||||
upper_bounds.append(float(ub))
|
if ub is not None:
|
||||||
lower_bounds.append(float(lb))
|
upper_bounds.append(float(ub))
|
||||||
|
else:
|
||||||
|
upper_bounds.append(float("inf"))
|
||||||
|
if lb is not None:
|
||||||
|
lower_bounds.append(float(lb))
|
||||||
|
else:
|
||||||
|
lower_bounds.append(-float("inf"))
|
||||||
|
|
||||||
# Objective coefficient
|
# Objective coefficient
|
||||||
if v.name in self._obj:
|
if v.name in self._obj:
|
||||||
@@ -391,7 +397,9 @@ class BasePyomoSolver(InternalSolver):
|
|||||||
) -> None:
|
) -> None:
|
||||||
if model is None:
|
if model is None:
|
||||||
model = instance.to_model()
|
model = instance.to_model()
|
||||||
assert isinstance(model, pe.ConcreteModel)
|
assert isinstance(
|
||||||
|
model, pe.ConcreteModel
|
||||||
|
), f"expected pe.ConcreteModel; found {model.__class__} instead"
|
||||||
self.instance = instance
|
self.instance = instance
|
||||||
self.model = model
|
self.model = model
|
||||||
self.model.extra_constraints = ConstraintList()
|
self.model.extra_constraints = ConstraintList()
|
||||||
|
|||||||
Reference in New Issue
Block a user