Replace InstanceIterator by PickleGzInstance

This commit is contained in:
2021-04-04 14:48:46 -05:00
parent b4770c6c0a
commit 08e808690e
14 changed files with 253 additions and 257 deletions

View File

@@ -8,6 +8,7 @@ import pickle
import tempfile
import os
from miplearn.instance import PickleGzInstance, write_pickle_gz, read_pickle_gz
from miplearn.solvers.gurobi import GurobiSolver
from miplearn.solvers.learning import LearningSolver
from . import _get_knapsack_instance, get_internal_solvers
@@ -78,61 +79,40 @@ def test_parallel_solve():
def test_solve_fit_from_disk():
for internal_solver in get_internal_solvers():
# Create instances and pickle them
filenames = []
instances = []
for k in range(3):
instance = _get_knapsack_instance(internal_solver)
with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as file:
filenames += [file.name]
pickle.dump(instance, file)
instances += [PickleGzInstance(file.name)]
write_pickle_gz(instance, file.name)
# Test: solve
solver = LearningSolver(solver=internal_solver)
solver.solve(filenames[0])
with open(filenames[0], "rb") as file:
instance = pickle.load(file)
assert len(instance.training_data) > 0
solver.solve(instances[0])
instance_loaded = read_pickle_gz(instances[0].filename)
assert len(instance_loaded.training_data) > 0
# Test: parallel_solve
solver.parallel_solve(filenames)
for filename in filenames:
with open(filename, "rb") as file:
instance = pickle.load(file)
assert len(instance.training_data) > 0
# Test: solve (with specified output)
output = [f + ".out" for f in filenames]
solver.solve(
filenames[0],
output_filename=output[0],
)
assert os.path.isfile(output[0])
# Test: parallel_solve (with specified output)
solver.parallel_solve(
filenames,
output_filenames=output,
)
for filename in output:
assert os.path.isfile(filename)
solver.parallel_solve(instances)
for instance in instances:
instance_loaded = read_pickle_gz(instance.filename)
assert len(instance.training_data) > 0
# Delete temporary files
for filename in filenames:
os.remove(filename)
for filename in output:
os.remove(filename)
for instance in instances:
os.remove(instance.filename)
def test_simulate_perfect():
internal_solver = GurobiSolver
instance = _get_knapsack_instance(internal_solver)
with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as tmp:
pickle.dump(instance, tmp)
tmp.flush()
write_pickle_gz(instance, tmp.name)
solver = LearningSolver(
solver=internal_solver,
simulate_perfect=True,
)
stats = solver.solve(tmp.name)
stats = solver.solve(PickleGzInstance(tmp.name))
assert stats["Lower bound"] == stats["Objective: Predicted lower bound"]