mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Change LearningSolver.solve and fit
This commit is contained in:
@@ -5,19 +5,27 @@
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from os.path import exists
|
||||
from typing import List, cast
|
||||
|
||||
import dill
|
||||
from scipy.stats import randint
|
||||
|
||||
from miplearn.features.sample import Hdf5Sample
|
||||
from miplearn.instance.base import Instance
|
||||
from miplearn.instance.picklegz import PickleGzInstance, write_pickle_gz, read_pickle_gz
|
||||
from miplearn.solvers.gurobi import GurobiSolver
|
||||
from miplearn.instance.picklegz import (
|
||||
PickleGzInstance,
|
||||
write_pickle_gz,
|
||||
read_pickle_gz,
|
||||
save,
|
||||
)
|
||||
from miplearn.problems.stab import MaxWeightStableSetGenerator, build_stab_model
|
||||
from miplearn.solvers.internal import InternalSolver
|
||||
from miplearn.solvers.learning import LearningSolver
|
||||
from miplearn.solvers.tests import assert_equals
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
from tests.solvers.test_internal_solver import internal_solvers
|
||||
from miplearn.solvers.tests import assert_equals
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -34,7 +42,7 @@ def test_learning_solver(
|
||||
mode=mode,
|
||||
)
|
||||
|
||||
solver.solve(instance)
|
||||
solver._solve(instance)
|
||||
assert len(instance.get_samples()) > 0
|
||||
sample = instance.get_samples()[0]
|
||||
|
||||
@@ -55,8 +63,8 @@ def test_learning_solver(
|
||||
assert lp_log is not None
|
||||
assert len(lp_log) > 100
|
||||
|
||||
solver.fit([instance], n_jobs=4)
|
||||
solver.solve(instance)
|
||||
solver._fit([instance], n_jobs=4)
|
||||
solver._solve(instance)
|
||||
|
||||
# Assert solver is picklable
|
||||
with tempfile.TemporaryFile() as file:
|
||||
@@ -73,9 +81,9 @@ def test_solve_without_lp(
|
||||
solver=internal_solver,
|
||||
solve_lp=False,
|
||||
)
|
||||
solver.solve(instance)
|
||||
solver.fit([instance])
|
||||
solver.solve(instance)
|
||||
solver._solve(instance)
|
||||
solver._fit([instance])
|
||||
solver._solve(instance)
|
||||
|
||||
|
||||
def test_parallel_solve(
|
||||
@@ -104,7 +112,7 @@ def test_solve_fit_from_disk(
|
||||
|
||||
# Test: solve
|
||||
solver = LearningSolver(solver=internal_solver)
|
||||
solver.solve(instances[0])
|
||||
solver._solve(instances[0])
|
||||
instance_loaded = read_pickle_gz(cast(PickleGzInstance, instances[0]).filename)
|
||||
assert len(instance_loaded.get_samples()) > 0
|
||||
|
||||
@@ -119,17 +127,29 @@ def test_solve_fit_from_disk(
|
||||
os.remove(cast(PickleGzInstance, instance).filename)
|
||||
|
||||
|
||||
def test_simulate_perfect() -> None:
|
||||
internal_solver = GurobiSolver()
|
||||
instance = internal_solver.build_test_instance_knapsack()
|
||||
with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as tmp:
|
||||
write_pickle_gz(instance, tmp.name)
|
||||
solver = LearningSolver(
|
||||
solver=internal_solver,
|
||||
simulate_perfect=True,
|
||||
)
|
||||
stats = solver.solve(PickleGzInstance(tmp.name))
|
||||
assert stats["mip_lower_bound"] == stats["Objective: Predicted lower bound"]
|
||||
def test_basic_usage() -> None:
|
||||
with tempfile.TemporaryDirectory() as dirname:
|
||||
# Generate instances
|
||||
data = MaxWeightStableSetGenerator(n=randint(low=20, high=21)).generate(4)
|
||||
train_files = save(data[0:3], f"{dirname}/train")
|
||||
test_files = save(data[3:4], f"{dirname}/test")
|
||||
|
||||
# Solve training instances
|
||||
solver = LearningSolver()
|
||||
stats = solver.solve(train_files, build_stab_model)
|
||||
assert len(stats) == 3
|
||||
for f in train_files:
|
||||
sample_filename = f.replace(".pkl.gz", ".h5")
|
||||
assert exists(sample_filename)
|
||||
sample = Hdf5Sample(sample_filename)
|
||||
assert sample.get_scalar("mip_lower_bound") > 0
|
||||
|
||||
# Fit
|
||||
solver.fit(train_files, build_stab_model)
|
||||
|
||||
# Solve test instances
|
||||
stats = solver.solve(test_files, build_stab_model)
|
||||
assert "Objective: Predicted lower bound" in stats[0].keys()
|
||||
|
||||
|
||||
def test_gap() -> None:
|
||||
|
||||
Reference in New Issue
Block a user