mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Change LearningSolver.solve and fit
This commit is contained in:
@@ -87,7 +87,7 @@ def test_usage(
|
||||
stab_instance: Instance,
|
||||
solver: LearningSolver,
|
||||
) -> None:
|
||||
stats_before = solver.solve(stab_instance)
|
||||
stats_before = solver._solve(stab_instance)
|
||||
sample = stab_instance.get_samples()[0]
|
||||
user_cuts_encoded = sample.get_scalar("mip_user_cuts")
|
||||
assert user_cuts_encoded is not None
|
||||
@@ -97,8 +97,8 @@ def test_usage(
|
||||
assert stats_before["UserCuts: Added ahead-of-time"] == 0
|
||||
assert stats_before["UserCuts: Added in callback"] > 0
|
||||
|
||||
solver.fit([stab_instance])
|
||||
stats_after = solver.solve(stab_instance)
|
||||
solver._fit([stab_instance])
|
||||
stats_after = solver._solve(stab_instance)
|
||||
assert (
|
||||
stats_after["UserCuts: Added ahead-of-time"]
|
||||
== stats_before["UserCuts: Added in callback"]
|
||||
|
||||
@@ -134,8 +134,8 @@ def test_sample_evaluate(sample: Sample) -> None:
|
||||
def test_usage() -> None:
|
||||
solver = LearningSolver(components=[ObjectiveValueComponent()])
|
||||
instance = GurobiPyomoSolver().build_test_instance_knapsack()
|
||||
solver.solve(instance)
|
||||
solver.fit([instance])
|
||||
stats = solver.solve(instance)
|
||||
solver._solve(instance)
|
||||
solver._fit([instance])
|
||||
stats = solver._solve(instance)
|
||||
assert stats["mip_lower_bound"] == stats["Objective: Predicted lower bound"]
|
||||
assert stats["mip_upper_bound"] == stats["Objective: Predicted upper bound"]
|
||||
|
||||
@@ -110,9 +110,9 @@ def test_usage() -> None:
|
||||
gen = TravelingSalesmanGenerator(n=randint(low=5, high=6))
|
||||
data = gen.generate(1)
|
||||
instance = TravelingSalesmanInstance(data[0].n_cities, data[0].distances)
|
||||
solver.solve(instance)
|
||||
solver.fit([instance])
|
||||
stats = solver.solve(instance)
|
||||
solver._solve(instance)
|
||||
solver._fit([instance])
|
||||
stats = solver._solve(instance)
|
||||
assert stats["Primal: Free"] == 0
|
||||
assert stats["Primal: One"] + stats["Primal: Zero"] == 10
|
||||
assert stats["mip_lower_bound"] == stats["mip_warm_start_value"]
|
||||
|
||||
@@ -22,7 +22,7 @@ def test_usage() -> None:
|
||||
|
||||
# Solve instance from disk
|
||||
solver = LearningSolver(solver=GurobiSolver())
|
||||
solver.solve(FileInstance(filename))
|
||||
solver._solve(FileInstance(filename))
|
||||
|
||||
# Assert HDF5 contains training data
|
||||
sample = FileInstance(filename).get_samples()[0]
|
||||
|
||||
@@ -36,4 +36,4 @@ def test_knapsack() -> None:
|
||||
weights=data[0].weights,
|
||||
)
|
||||
solver = LearningSolver()
|
||||
solver.solve(instance)
|
||||
solver._solve(instance)
|
||||
|
||||
@@ -15,7 +15,7 @@ def test_stab() -> None:
|
||||
weights = np.array([1.0, 1.0, 1.0, 1.0, 1.0])
|
||||
instance = MaxWeightStableSetInstance(graph, weights)
|
||||
solver = LearningSolver()
|
||||
stats = solver.solve(instance)
|
||||
stats = solver._solve(instance)
|
||||
assert stats["mip_lower_bound"] == 2.0
|
||||
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ def test_instance() -> None:
|
||||
)
|
||||
instance = TravelingSalesmanInstance(n_cities, distances)
|
||||
solver = LearningSolver()
|
||||
solver.solve(instance)
|
||||
solver._solve(instance)
|
||||
assert len(instance.get_samples()) == 1
|
||||
sample = instance.get_samples()[0]
|
||||
assert_equals(sample.get_array("mip_var_values"), [1.0, 0.0, 1.0, 1.0, 0.0, 1.0])
|
||||
@@ -63,7 +63,7 @@ def test_subtour() -> None:
|
||||
distances = squareform(pdist(cities))
|
||||
instance = TravelingSalesmanInstance(n_cities, distances)
|
||||
solver = LearningSolver()
|
||||
solver.solve(instance)
|
||||
solver._solve(instance)
|
||||
samples = instance.get_samples()
|
||||
assert len(samples) == 1
|
||||
sample = samples[0]
|
||||
@@ -96,5 +96,5 @@ def test_subtour() -> None:
|
||||
1.0,
|
||||
],
|
||||
)
|
||||
solver.fit([instance])
|
||||
solver.solve(instance)
|
||||
solver._fit([instance])
|
||||
solver._solve(instance)
|
||||
|
||||
@@ -5,19 +5,27 @@
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from os.path import exists
|
||||
from typing import List, cast
|
||||
|
||||
import dill
|
||||
from scipy.stats import randint
|
||||
|
||||
from miplearn.features.sample import Hdf5Sample
|
||||
from miplearn.instance.base import Instance
|
||||
from miplearn.instance.picklegz import PickleGzInstance, write_pickle_gz, read_pickle_gz
|
||||
from miplearn.solvers.gurobi import GurobiSolver
|
||||
from miplearn.instance.picklegz import (
|
||||
PickleGzInstance,
|
||||
write_pickle_gz,
|
||||
read_pickle_gz,
|
||||
save,
|
||||
)
|
||||
from miplearn.problems.stab import MaxWeightStableSetGenerator, build_stab_model
|
||||
from miplearn.solvers.internal import InternalSolver
|
||||
from miplearn.solvers.learning import LearningSolver
|
||||
from miplearn.solvers.tests import assert_equals
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
from tests.solvers.test_internal_solver import internal_solvers
|
||||
from miplearn.solvers.tests import assert_equals
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -34,7 +42,7 @@ def test_learning_solver(
|
||||
mode=mode,
|
||||
)
|
||||
|
||||
solver.solve(instance)
|
||||
solver._solve(instance)
|
||||
assert len(instance.get_samples()) > 0
|
||||
sample = instance.get_samples()[0]
|
||||
|
||||
@@ -55,8 +63,8 @@ def test_learning_solver(
|
||||
assert lp_log is not None
|
||||
assert len(lp_log) > 100
|
||||
|
||||
solver.fit([instance], n_jobs=4)
|
||||
solver.solve(instance)
|
||||
solver._fit([instance], n_jobs=4)
|
||||
solver._solve(instance)
|
||||
|
||||
# Assert solver is picklable
|
||||
with tempfile.TemporaryFile() as file:
|
||||
@@ -73,9 +81,9 @@ def test_solve_without_lp(
|
||||
solver=internal_solver,
|
||||
solve_lp=False,
|
||||
)
|
||||
solver.solve(instance)
|
||||
solver.fit([instance])
|
||||
solver.solve(instance)
|
||||
solver._solve(instance)
|
||||
solver._fit([instance])
|
||||
solver._solve(instance)
|
||||
|
||||
|
||||
def test_parallel_solve(
|
||||
@@ -104,7 +112,7 @@ def test_solve_fit_from_disk(
|
||||
|
||||
# Test: solve
|
||||
solver = LearningSolver(solver=internal_solver)
|
||||
solver.solve(instances[0])
|
||||
solver._solve(instances[0])
|
||||
instance_loaded = read_pickle_gz(cast(PickleGzInstance, instances[0]).filename)
|
||||
assert len(instance_loaded.get_samples()) > 0
|
||||
|
||||
@@ -119,17 +127,29 @@ def test_solve_fit_from_disk(
|
||||
os.remove(cast(PickleGzInstance, instance).filename)
|
||||
|
||||
|
||||
def test_simulate_perfect() -> None:
|
||||
internal_solver = GurobiSolver()
|
||||
instance = internal_solver.build_test_instance_knapsack()
|
||||
with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as tmp:
|
||||
write_pickle_gz(instance, tmp.name)
|
||||
solver = LearningSolver(
|
||||
solver=internal_solver,
|
||||
simulate_perfect=True,
|
||||
)
|
||||
stats = solver.solve(PickleGzInstance(tmp.name))
|
||||
assert stats["mip_lower_bound"] == stats["Objective: Predicted lower bound"]
|
||||
def test_basic_usage() -> None:
|
||||
with tempfile.TemporaryDirectory() as dirname:
|
||||
# Generate instances
|
||||
data = MaxWeightStableSetGenerator(n=randint(low=20, high=21)).generate(4)
|
||||
train_files = save(data[0:3], f"{dirname}/train")
|
||||
test_files = save(data[3:4], f"{dirname}/test")
|
||||
|
||||
# Solve training instances
|
||||
solver = LearningSolver()
|
||||
stats = solver.solve(train_files, build_stab_model)
|
||||
assert len(stats) == 3
|
||||
for f in train_files:
|
||||
sample_filename = f.replace(".pkl.gz", ".h5")
|
||||
assert exists(sample_filename)
|
||||
sample = Hdf5Sample(sample_filename)
|
||||
assert sample.get_scalar("mip_lower_bound") > 0
|
||||
|
||||
# Fit
|
||||
solver.fit(train_files, build_stab_model)
|
||||
|
||||
# Solve test instances
|
||||
stats = solver.solve(test_files, build_stab_model)
|
||||
assert "Objective: Predicted lower bound" in stats[0].keys()
|
||||
|
||||
|
||||
def test_gap() -> None:
|
||||
|
||||
Reference in New Issue
Block a user