simulate_perfect: Do not overwrite original file

This commit is contained in:
2021-01-13 11:04:33 -06:00
parent b01d97cc2b
commit beee252fa2
6 changed files with 46 additions and 26 deletions

View File

@@ -11,6 +11,7 @@ import gzip
from copy import deepcopy
from typing import Optional, List
from p_tqdm import p_map
from tempfile import NamedTemporaryFile
from . import RedirectOutput
from .. import (
@@ -211,13 +212,16 @@ class LearningSolver:
details.
"""
if self.simulate_perfect:
self._solve(
instance=instance,
model=model,
output=output,
tee=tee,
)
self.fit([instance])
if not isinstance(instance, str):
raise Exception("Not implemented")
with tempfile.NamedTemporaryFile(suffix=os.path.basename(instance)) as tmp:
self._solve(
instance=instance,
model=model,
output=tmp.name,
tee=tee,
)
self.fit([tmp.name])
return self._solve(
instance=instance,
model=model,

View File

@@ -7,8 +7,11 @@ import pickle
import tempfile
import os
from miplearn import DynamicLazyConstraintsComponent
from miplearn import LearningSolver
from miplearn import (
LearningSolver,
GurobiSolver,
DynamicLazyConstraintsComponent,
)
from . import _get_instance, _get_internal_solvers
@@ -109,3 +112,18 @@ def test_solve_fit_from_disk():
os.remove(filename)
for filename in output:
os.remove(filename)
def test_simulate_perfect():
internal_solver = GurobiSolver()
instance = _get_instance(internal_solver)
with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as tmp:
pickle.dump(instance, tmp)
tmp.flush()
solver = LearningSolver(
solver=internal_solver,
simulate_perfect=True,
)
stats = solver.solve(tmp.name)
assert stats["Lower bound"] == stats["Predicted LB"]