mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Replace InstanceIterator by PickleGzInstance
This commit is contained in:
@@ -8,6 +8,7 @@ import pickle
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
from miplearn.instance import PickleGzInstance, write_pickle_gz, read_pickle_gz
|
||||
from miplearn.solvers.gurobi import GurobiSolver
|
||||
from miplearn.solvers.learning import LearningSolver
|
||||
from . import _get_knapsack_instance, get_internal_solvers
|
||||
@@ -78,61 +79,40 @@ def test_parallel_solve():
|
||||
def test_solve_fit_from_disk():
|
||||
for internal_solver in get_internal_solvers():
|
||||
# Create instances and pickle them
|
||||
filenames = []
|
||||
instances = []
|
||||
for k in range(3):
|
||||
instance = _get_knapsack_instance(internal_solver)
|
||||
with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as file:
|
||||
filenames += [file.name]
|
||||
pickle.dump(instance, file)
|
||||
instances += [PickleGzInstance(file.name)]
|
||||
write_pickle_gz(instance, file.name)
|
||||
|
||||
# Test: solve
|
||||
solver = LearningSolver(solver=internal_solver)
|
||||
solver.solve(filenames[0])
|
||||
with open(filenames[0], "rb") as file:
|
||||
instance = pickle.load(file)
|
||||
assert len(instance.training_data) > 0
|
||||
solver.solve(instances[0])
|
||||
instance_loaded = read_pickle_gz(instances[0].filename)
|
||||
assert len(instance_loaded.training_data) > 0
|
||||
|
||||
# Test: parallel_solve
|
||||
solver.parallel_solve(filenames)
|
||||
for filename in filenames:
|
||||
with open(filename, "rb") as file:
|
||||
instance = pickle.load(file)
|
||||
assert len(instance.training_data) > 0
|
||||
|
||||
# Test: solve (with specified output)
|
||||
output = [f + ".out" for f in filenames]
|
||||
solver.solve(
|
||||
filenames[0],
|
||||
output_filename=output[0],
|
||||
)
|
||||
assert os.path.isfile(output[0])
|
||||
|
||||
# Test: parallel_solve (with specified output)
|
||||
solver.parallel_solve(
|
||||
filenames,
|
||||
output_filenames=output,
|
||||
)
|
||||
for filename in output:
|
||||
assert os.path.isfile(filename)
|
||||
solver.parallel_solve(instances)
|
||||
for instance in instances:
|
||||
instance_loaded = read_pickle_gz(instance.filename)
|
||||
assert len(instance.training_data) > 0
|
||||
|
||||
# Delete temporary files
|
||||
for filename in filenames:
|
||||
os.remove(filename)
|
||||
for filename in output:
|
||||
os.remove(filename)
|
||||
for instance in instances:
|
||||
os.remove(instance.filename)
|
||||
|
||||
|
||||
def test_simulate_perfect():
|
||||
internal_solver = GurobiSolver
|
||||
instance = _get_knapsack_instance(internal_solver)
|
||||
with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as tmp:
|
||||
pickle.dump(instance, tmp)
|
||||
tmp.flush()
|
||||
write_pickle_gz(instance, tmp.name)
|
||||
solver = LearningSolver(
|
||||
solver=internal_solver,
|
||||
simulate_perfect=True,
|
||||
)
|
||||
stats = solver.solve(tmp.name)
|
||||
stats = solver.solve(PickleGzInstance(tmp.name))
|
||||
assert stats["Lower bound"] == stats["Objective: Predicted lower bound"]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user