Change LearningSolver.solve and fit

This commit is contained in:
2022-02-22 15:21:56 -06:00
parent c98ff4eab4
commit 522f3a7e18
10 changed files with 157 additions and 107 deletions

View File

@@ -5,19 +5,27 @@
import logging
import os
import tempfile
from os.path import exists
from typing import List, cast
import dill
from scipy.stats import randint
from miplearn.features.sample import Hdf5Sample
from miplearn.instance.base import Instance
from miplearn.instance.picklegz import PickleGzInstance, write_pickle_gz, read_pickle_gz
from miplearn.solvers.gurobi import GurobiSolver
from miplearn.instance.picklegz import (
PickleGzInstance,
write_pickle_gz,
read_pickle_gz,
save,
)
from miplearn.problems.stab import MaxWeightStableSetGenerator, build_stab_model
from miplearn.solvers.internal import InternalSolver
from miplearn.solvers.learning import LearningSolver
from miplearn.solvers.tests import assert_equals
# noinspection PyUnresolvedReferences
from tests.solvers.test_internal_solver import internal_solvers
from miplearn.solvers.tests import assert_equals
logger = logging.getLogger(__name__)
@@ -34,7 +42,7 @@ def test_learning_solver(
mode=mode,
)
solver.solve(instance)
solver._solve(instance)
assert len(instance.get_samples()) > 0
sample = instance.get_samples()[0]
@@ -55,8 +63,8 @@ def test_learning_solver(
assert lp_log is not None
assert len(lp_log) > 100
solver.fit([instance], n_jobs=4)
solver.solve(instance)
solver._fit([instance], n_jobs=4)
solver._solve(instance)
# Assert solver is picklable
with tempfile.TemporaryFile() as file:
@@ -73,9 +81,9 @@ def test_solve_without_lp(
solver=internal_solver,
solve_lp=False,
)
solver.solve(instance)
solver.fit([instance])
solver.solve(instance)
solver._solve(instance)
solver._fit([instance])
solver._solve(instance)
def test_parallel_solve(
@@ -104,7 +112,7 @@ def test_solve_fit_from_disk(
# Test: solve
solver = LearningSolver(solver=internal_solver)
solver.solve(instances[0])
solver._solve(instances[0])
instance_loaded = read_pickle_gz(cast(PickleGzInstance, instances[0]).filename)
assert len(instance_loaded.get_samples()) > 0
@@ -119,17 +127,29 @@ def test_solve_fit_from_disk(
os.remove(cast(PickleGzInstance, instance).filename)
def test_simulate_perfect() -> None:
internal_solver = GurobiSolver()
instance = internal_solver.build_test_instance_knapsack()
with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as tmp:
write_pickle_gz(instance, tmp.name)
solver = LearningSolver(
solver=internal_solver,
simulate_perfect=True,
)
stats = solver.solve(PickleGzInstance(tmp.name))
assert stats["mip_lower_bound"] == stats["Objective: Predicted lower bound"]
def test_basic_usage() -> None:
with tempfile.TemporaryDirectory() as dirname:
# Generate instances
data = MaxWeightStableSetGenerator(n=randint(low=20, high=21)).generate(4)
train_files = save(data[0:3], f"{dirname}/train")
test_files = save(data[3:4], f"{dirname}/test")
# Solve training instances
solver = LearningSolver()
stats = solver.solve(train_files, build_stab_model)
assert len(stats) == 3
for f in train_files:
sample_filename = f.replace(".pkl.gz", ".h5")
assert exists(sample_filename)
sample = Hdf5Sample(sample_filename)
assert sample.get_scalar("mip_lower_bound") > 0
# Fit
solver.fit(train_files, build_stab_model)
# Solve test instances
stats = solver.solve(test_files, build_stab_model)
assert "Objective: Predicted lower bound" in stats[0].keys()
def test_gap() -> None: