mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Merge branch 'feature/new-py-api' into feature/docs
This commit is contained in:
@@ -10,6 +10,7 @@ import pytest
|
||||
from miplearn.classifiers import Classifier
|
||||
from miplearn.classifiers.threshold import MinProbabilityThreshold
|
||||
from miplearn.components import classifier_evaluation_dict
|
||||
from miplearn.components.dynamic_common import DynamicConstraintsComponent
|
||||
from miplearn.components.dynamic_lazy import DynamicLazyConstraintsComponent
|
||||
from miplearn.features.sample import MemorySample
|
||||
from miplearn.instance.base import Instance
|
||||
@@ -24,13 +25,23 @@ def training_instances() -> List[Instance]:
|
||||
samples_0 = [
|
||||
MemorySample(
|
||||
{
|
||||
"mip_constr_lazy_enforced": np.array(["c1", "c2"], dtype="S"),
|
||||
"mip_constr_lazy": DynamicConstraintsComponent.encode(
|
||||
{
|
||||
b"c1": 0,
|
||||
b"c2": 0,
|
||||
}
|
||||
),
|
||||
"static_instance_features": np.array([5.0]),
|
||||
},
|
||||
),
|
||||
MemorySample(
|
||||
{
|
||||
"mip_constr_lazy_enforced": np.array(["c2", "c3"], dtype="S"),
|
||||
"mip_constr_lazy": DynamicConstraintsComponent.encode(
|
||||
{
|
||||
b"c2": 0,
|
||||
b"c3": 0,
|
||||
}
|
||||
),
|
||||
"static_instance_features": np.array([5.0]),
|
||||
},
|
||||
),
|
||||
@@ -55,7 +66,12 @@ def training_instances() -> List[Instance]:
|
||||
samples_1 = [
|
||||
MemorySample(
|
||||
{
|
||||
"mip_constr_lazy_enforced": np.array(["c3", "c4"], dtype="S"),
|
||||
"mip_constr_lazy": DynamicConstraintsComponent.encode(
|
||||
{
|
||||
b"c3": 0,
|
||||
b"c4": 0,
|
||||
}
|
||||
),
|
||||
"static_instance_features": np.array([8.0]),
|
||||
},
|
||||
)
|
||||
@@ -83,8 +99,8 @@ def test_sample_xy(training_instances: List[Instance]) -> None:
|
||||
comp = DynamicLazyConstraintsComponent()
|
||||
comp.pre_fit(
|
||||
[
|
||||
np.array(["c1", "c3", "c4"], dtype="S"),
|
||||
np.array(["c1", "c2", "c4"], dtype="S"),
|
||||
{b"c1": 0, b"c3": 0, b"c4": 0},
|
||||
{b"c1": 0, b"c2": 0, b"c4": 0},
|
||||
]
|
||||
)
|
||||
x_expected = {
|
||||
@@ -105,7 +121,10 @@ def test_sample_xy(training_instances: List[Instance]) -> None:
|
||||
|
||||
def test_sample_predict_evaluate(training_instances: List[Instance]) -> None:
|
||||
comp = DynamicLazyConstraintsComponent()
|
||||
comp.known_cids.extend([b"c1", b"c2", b"c3", b"c4"])
|
||||
comp.known_violations[b"c1"] = 0
|
||||
comp.known_violations[b"c2"] = 0
|
||||
comp.known_violations[b"c3"] = 0
|
||||
comp.known_violations[b"c4"] = 0
|
||||
comp.thresholds[b"type-a"] = MinProbabilityThreshold([0.5, 0.5])
|
||||
comp.thresholds[b"type-b"] = MinProbabilityThreshold([0.5, 0.5])
|
||||
comp.classifiers[b"type-a"] = Mock(spec=Classifier)
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, FrozenSet, List
|
||||
from typing import Any, List, Dict
|
||||
|
||||
import gurobipy
|
||||
import gurobipy as gp
|
||||
import networkx as nx
|
||||
import pytest
|
||||
@@ -12,12 +13,11 @@ from gurobipy import GRB
|
||||
from networkx import Graph
|
||||
from overrides import overrides
|
||||
|
||||
from miplearn.solvers.learning import InternalSolver
|
||||
from miplearn.components.dynamic_user_cuts import UserCutsComponent
|
||||
from miplearn.instance.base import Instance
|
||||
from miplearn.solvers.gurobi import GurobiSolver
|
||||
from miplearn.solvers.learning import LearningSolver
|
||||
from miplearn.types import ConstraintName, ConstraintCategory
|
||||
from miplearn.types import ConstraintName
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -41,25 +41,32 @@ class GurobiStableSetProblem(Instance):
|
||||
return True
|
||||
|
||||
@overrides
|
||||
def find_violated_user_cuts(self, model: Any) -> List[ConstraintName]:
|
||||
def find_violated_user_cuts(self, model: Any) -> Dict[ConstraintName, Any]:
|
||||
assert isinstance(model, gp.Model)
|
||||
vals = model.cbGetNodeRel(model.getVars())
|
||||
violations = []
|
||||
try:
|
||||
vals = model.cbGetNodeRel(model.getVars())
|
||||
except gurobipy.GurobiError:
|
||||
return {}
|
||||
violations = {}
|
||||
for clique in nx.find_cliques(self.graph):
|
||||
if sum(vals[i] for i in clique) > 1:
|
||||
violations.append(",".join([str(i) for i in clique]).encode())
|
||||
vname = (",".join([str(i) for i in clique])).encode()
|
||||
violations[vname] = list(clique)
|
||||
return violations
|
||||
|
||||
@overrides
|
||||
def enforce_user_cut(
|
||||
self,
|
||||
solver: InternalSolver,
|
||||
solver: GurobiSolver,
|
||||
model: Any,
|
||||
cid: ConstraintName,
|
||||
clique: List[int],
|
||||
) -> Any:
|
||||
clique = [int(i) for i in cid.decode().split(",")]
|
||||
x = model.getVars()
|
||||
model.addConstr(gp.quicksum([x[i] for i in clique]) <= 1)
|
||||
constr = gp.quicksum([x[i] for i in clique]) <= 1
|
||||
if solver.cb_where:
|
||||
model.cbCut(constr)
|
||||
else:
|
||||
model.addConstr(constr)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -71,7 +78,7 @@ def stab_instance() -> Instance:
|
||||
@pytest.fixture
|
||||
def solver() -> LearningSolver:
|
||||
return LearningSolver(
|
||||
solver=GurobiSolver(),
|
||||
solver=GurobiSolver(params={"Threads": 1}),
|
||||
components=[UserCutsComponent()],
|
||||
)
|
||||
|
||||
@@ -80,16 +87,18 @@ def test_usage(
|
||||
stab_instance: Instance,
|
||||
solver: LearningSolver,
|
||||
) -> None:
|
||||
stats_before = solver.solve(stab_instance)
|
||||
stats_before = solver._solve(stab_instance)
|
||||
sample = stab_instance.get_samples()[0]
|
||||
user_cuts_enforced = sample.get_array("mip_user_cuts_enforced")
|
||||
assert user_cuts_enforced is not None
|
||||
assert len(user_cuts_enforced) > 0
|
||||
user_cuts_encoded = sample.get_scalar("mip_user_cuts")
|
||||
assert user_cuts_encoded is not None
|
||||
user_cuts = json.loads(user_cuts_encoded)
|
||||
assert user_cuts is not None
|
||||
assert len(user_cuts) > 0
|
||||
assert stats_before["UserCuts: Added ahead-of-time"] == 0
|
||||
assert stats_before["UserCuts: Added in callback"] > 0
|
||||
|
||||
solver.fit([stab_instance])
|
||||
stats_after = solver.solve(stab_instance)
|
||||
solver._fit([stab_instance])
|
||||
stats_after = solver._solve(stab_instance)
|
||||
assert (
|
||||
stats_after["UserCuts: Added ahead-of-time"]
|
||||
== stats_before["UserCuts: Added in callback"]
|
||||
|
||||
@@ -134,8 +134,8 @@ def test_sample_evaluate(sample: Sample) -> None:
|
||||
def test_usage() -> None:
|
||||
solver = LearningSolver(components=[ObjectiveValueComponent()])
|
||||
instance = GurobiPyomoSolver().build_test_instance_knapsack()
|
||||
solver.solve(instance)
|
||||
solver.fit([instance])
|
||||
stats = solver.solve(instance)
|
||||
solver._solve(instance)
|
||||
solver._fit([instance])
|
||||
stats = solver._solve(instance)
|
||||
assert stats["mip_lower_bound"] == stats["Objective: Predicted lower bound"]
|
||||
assert stats["mip_upper_bound"] == stats["Objective: Predicted upper bound"]
|
||||
|
||||
@@ -13,7 +13,7 @@ from miplearn.classifiers.threshold import Threshold
|
||||
from miplearn.components import classifier_evaluation_dict
|
||||
from miplearn.components.primal import PrimalSolutionComponent
|
||||
from miplearn.features.sample import Sample, MemorySample
|
||||
from miplearn.problems.tsp import TravelingSalesmanGenerator
|
||||
from miplearn.problems.tsp import TravelingSalesmanGenerator, TravelingSalesmanInstance
|
||||
from miplearn.solvers.learning import LearningSolver
|
||||
from miplearn.solvers.tests import assert_equals
|
||||
|
||||
@@ -108,10 +108,11 @@ def test_usage() -> None:
|
||||
]
|
||||
)
|
||||
gen = TravelingSalesmanGenerator(n=randint(low=5, high=6))
|
||||
instance = gen.generate(1)[0]
|
||||
solver.solve(instance)
|
||||
solver.fit([instance])
|
||||
stats = solver.solve(instance)
|
||||
data = gen.generate(1)
|
||||
instance = TravelingSalesmanInstance(data[0].n_cities, data[0].distances)
|
||||
solver._solve(instance)
|
||||
solver._fit([instance])
|
||||
stats = solver._solve(instance)
|
||||
assert stats["Primal: Free"] == 0
|
||||
assert stats["Primal: One"] + stats["Primal: Zero"] == 10
|
||||
assert stats["mip_lower_bound"] == stats["mip_warm_start_value"]
|
||||
|
||||
@@ -22,7 +22,7 @@ def test_usage() -> None:
|
||||
|
||||
# Solve instance from disk
|
||||
solver = LearningSolver(solver=GurobiSolver())
|
||||
solver.solve(FileInstance(filename))
|
||||
solver._solve(FileInstance(filename))
|
||||
|
||||
# Assert HDF5 contains training data
|
||||
sample = FileInstance(filename).get_samples()[0]
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import tempfile
|
||||
from typing import cast, IO
|
||||
|
||||
from miplearn.instance.picklegz import write_pickle_gz, PickleGzInstance
|
||||
from miplearn.solvers.gurobi import GurobiSolver
|
||||
from miplearn import save
|
||||
from os.path import exists
|
||||
import gzip
|
||||
import pickle
|
||||
|
||||
|
||||
def test_usage() -> None:
|
||||
@@ -14,3 +20,14 @@ def test_usage() -> None:
|
||||
pickled = PickleGzInstance(file.name)
|
||||
pickled.load()
|
||||
assert pickled.to_model() is not None
|
||||
|
||||
|
||||
def test_save() -> None:
|
||||
objs = [1, "ABC", True]
|
||||
with tempfile.TemporaryDirectory() as dirname:
|
||||
filenames = save(objs, dirname)
|
||||
assert len(filenames) == 3
|
||||
for (idx, f) in enumerate(filenames):
|
||||
assert exists(f)
|
||||
with gzip.GzipFile(f, "rb") as file:
|
||||
assert pickle.load(cast(IO[bytes], file)) == objs[idx]
|
||||
|
||||
@@ -6,7 +6,7 @@ import numpy as np
|
||||
from scipy.stats import uniform, randint
|
||||
|
||||
from miplearn import LearningSolver
|
||||
from miplearn.problems.knapsack import MultiKnapsackGenerator
|
||||
from miplearn.problems.knapsack import MultiKnapsackGenerator, MultiKnapsackInstance
|
||||
|
||||
|
||||
def test_knapsack_generator() -> None:
|
||||
@@ -18,17 +18,22 @@ def test_knapsack_generator() -> None:
|
||||
u=uniform(loc=1.0, scale=1.0),
|
||||
alpha=uniform(loc=0.50, scale=0.0),
|
||||
)
|
||||
instances = gen.generate(100)
|
||||
w_sum = sum(instance.weights for instance in instances) / len(instances)
|
||||
b_sum = sum(instance.capacities for instance in instances) / len(instances)
|
||||
data = gen.generate(100)
|
||||
w_sum = sum(d.weights for d in data) / len(data)
|
||||
b_sum = sum(d.capacities for d in data) / len(data)
|
||||
assert round(float(np.mean(w_sum)), -1) == 500.0
|
||||
assert round(float(np.mean(b_sum)), -3) == 25000.0
|
||||
|
||||
|
||||
def test_knapsack() -> None:
|
||||
instance = MultiKnapsackGenerator(
|
||||
data = MultiKnapsackGenerator(
|
||||
n=randint(low=5, high=6),
|
||||
m=randint(low=5, high=6),
|
||||
).generate(1)[0]
|
||||
).generate(1)
|
||||
instance = MultiKnapsackInstance(
|
||||
prices=data[0].prices,
|
||||
capacities=data[0].capacities,
|
||||
weights=data[0].weights,
|
||||
)
|
||||
solver = LearningSolver()
|
||||
solver.solve(instance)
|
||||
solver._solve(instance)
|
||||
|
||||
@@ -15,7 +15,7 @@ def test_stab() -> None:
|
||||
weights = np.array([1.0, 1.0, 1.0, 1.0, 1.0])
|
||||
instance = MaxWeightStableSetInstance(graph, weights)
|
||||
solver = LearningSolver()
|
||||
stats = solver.solve(instance)
|
||||
stats = solver._solve(instance)
|
||||
assert stats["mip_lower_bound"] == 2.0
|
||||
|
||||
|
||||
@@ -29,8 +29,8 @@ def test_stab_generator_fixed_graph() -> None:
|
||||
p=uniform(loc=0.05, scale=0.0),
|
||||
fix_graph=True,
|
||||
)
|
||||
instances = gen.generate(1_000)
|
||||
weights = np.array([instance.weights for instance in instances])
|
||||
data = gen.generate(1_000)
|
||||
weights = np.array([d.weights for d in data])
|
||||
weights_avg_actual = np.round(np.average(weights, axis=0))
|
||||
weights_avg_expected = [55.0] * 10
|
||||
assert list(weights_avg_actual) == weights_avg_expected
|
||||
@@ -46,8 +46,8 @@ def test_stab_generator_random_graph() -> None:
|
||||
p=uniform(loc=0.5, scale=0.0),
|
||||
fix_graph=False,
|
||||
)
|
||||
instances = gen.generate(1_000)
|
||||
n_nodes = [instance.graph.number_of_nodes() for instance in instances]
|
||||
n_edges = [instance.graph.number_of_edges() for instance in instances]
|
||||
data = gen.generate(1_000)
|
||||
n_nodes = [d.graph.number_of_nodes() for d in data]
|
||||
n_edges = [d.graph.number_of_edges() for d in data]
|
||||
assert np.round(np.mean(n_nodes)) == 35.0
|
||||
assert np.round(np.mean(n_edges), -1) == 300.0
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
from numpy.linalg import norm
|
||||
@@ -13,17 +14,17 @@ from miplearn.solvers.tests import assert_equals
|
||||
|
||||
|
||||
def test_generator() -> None:
|
||||
instances = TravelingSalesmanGenerator(
|
||||
data = TravelingSalesmanGenerator(
|
||||
x=uniform(loc=0.0, scale=1000.0),
|
||||
y=uniform(loc=0.0, scale=1000.0),
|
||||
n=randint(low=100, high=101),
|
||||
gamma=uniform(loc=0.95, scale=0.1),
|
||||
fix_cities=True,
|
||||
).generate(100)
|
||||
assert len(instances) == 100
|
||||
assert instances[0].n_cities == 100
|
||||
assert norm(instances[0].distances - instances[0].distances.T) < 1e-6
|
||||
d = [instance.distances[0, 1] for instance in instances]
|
||||
assert len(data) == 100
|
||||
assert data[0].n_cities == 100
|
||||
assert norm(data[0].distances - data[0].distances.T) < 1e-6
|
||||
d = [d.distances[0, 1] for d in data]
|
||||
assert np.std(d) > 0
|
||||
|
||||
|
||||
@@ -39,7 +40,7 @@ def test_instance() -> None:
|
||||
)
|
||||
instance = TravelingSalesmanInstance(n_cities, distances)
|
||||
solver = LearningSolver()
|
||||
solver.solve(instance)
|
||||
solver._solve(instance)
|
||||
assert len(instance.get_samples()) == 1
|
||||
sample = instance.get_samples()[0]
|
||||
assert_equals(sample.get_array("mip_var_values"), [1.0, 0.0, 1.0, 1.0, 0.0, 1.0])
|
||||
@@ -62,13 +63,19 @@ def test_subtour() -> None:
|
||||
distances = squareform(pdist(cities))
|
||||
instance = TravelingSalesmanInstance(n_cities, distances)
|
||||
solver = LearningSolver()
|
||||
solver.solve(instance)
|
||||
solver._solve(instance)
|
||||
samples = instance.get_samples()
|
||||
assert len(samples) == 1
|
||||
sample = samples[0]
|
||||
lazy_enforced = sample.get_array("mip_constr_lazy_enforced")
|
||||
assert lazy_enforced is not None
|
||||
assert len(lazy_enforced) > 0
|
||||
|
||||
lazy_encoded = sample.get_scalar("mip_constr_lazy")
|
||||
assert lazy_encoded is not None
|
||||
lazy = json.loads(lazy_encoded)
|
||||
assert lazy == {
|
||||
"st[0,1,4]": [0, 1, 4],
|
||||
"st[2,3,5]": [2, 3, 5],
|
||||
}
|
||||
|
||||
assert_equals(
|
||||
sample.get_array("mip_var_values"),
|
||||
[
|
||||
@@ -89,5 +96,5 @@ def test_subtour() -> None:
|
||||
1.0,
|
||||
],
|
||||
)
|
||||
solver.fit([instance])
|
||||
solver.solve(instance)
|
||||
solver._fit([instance])
|
||||
solver._solve(instance)
|
||||
|
||||
@@ -5,19 +5,27 @@
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from os.path import exists
|
||||
from typing import List, cast
|
||||
|
||||
import dill
|
||||
from scipy.stats import randint
|
||||
|
||||
from miplearn.features.sample import Hdf5Sample
|
||||
from miplearn.instance.base import Instance
|
||||
from miplearn.instance.picklegz import PickleGzInstance, write_pickle_gz, read_pickle_gz
|
||||
from miplearn.solvers.gurobi import GurobiSolver
|
||||
from miplearn.instance.picklegz import (
|
||||
PickleGzInstance,
|
||||
write_pickle_gz,
|
||||
read_pickle_gz,
|
||||
save,
|
||||
)
|
||||
from miplearn.problems.stab import MaxWeightStableSetGenerator, build_stab_model
|
||||
from miplearn.solvers.internal import InternalSolver
|
||||
from miplearn.solvers.learning import LearningSolver
|
||||
from miplearn.solvers.tests import assert_equals
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
from tests.solvers.test_internal_solver import internal_solvers
|
||||
from miplearn.solvers.tests import assert_equals
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -34,7 +42,7 @@ def test_learning_solver(
|
||||
mode=mode,
|
||||
)
|
||||
|
||||
solver.solve(instance)
|
||||
solver._solve(instance)
|
||||
assert len(instance.get_samples()) > 0
|
||||
sample = instance.get_samples()[0]
|
||||
|
||||
@@ -55,8 +63,8 @@ def test_learning_solver(
|
||||
assert lp_log is not None
|
||||
assert len(lp_log) > 100
|
||||
|
||||
solver.fit([instance], n_jobs=4)
|
||||
solver.solve(instance)
|
||||
solver._fit([instance], n_jobs=4)
|
||||
solver._solve(instance)
|
||||
|
||||
# Assert solver is picklable
|
||||
with tempfile.TemporaryFile() as file:
|
||||
@@ -73,9 +81,9 @@ def test_solve_without_lp(
|
||||
solver=internal_solver,
|
||||
solve_lp=False,
|
||||
)
|
||||
solver.solve(instance)
|
||||
solver.fit([instance])
|
||||
solver.solve(instance)
|
||||
solver._solve(instance)
|
||||
solver._fit([instance])
|
||||
solver._solve(instance)
|
||||
|
||||
|
||||
def test_parallel_solve(
|
||||
@@ -104,7 +112,7 @@ def test_solve_fit_from_disk(
|
||||
|
||||
# Test: solve
|
||||
solver = LearningSolver(solver=internal_solver)
|
||||
solver.solve(instances[0])
|
||||
solver._solve(instances[0])
|
||||
instance_loaded = read_pickle_gz(cast(PickleGzInstance, instances[0]).filename)
|
||||
assert len(instance_loaded.get_samples()) > 0
|
||||
|
||||
@@ -119,17 +127,30 @@ def test_solve_fit_from_disk(
|
||||
os.remove(cast(PickleGzInstance, instance).filename)
|
||||
|
||||
|
||||
def test_simulate_perfect() -> None:
|
||||
internal_solver = GurobiSolver()
|
||||
instance = internal_solver.build_test_instance_knapsack()
|
||||
with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as tmp:
|
||||
write_pickle_gz(instance, tmp.name)
|
||||
solver = LearningSolver(
|
||||
solver=internal_solver,
|
||||
simulate_perfect=True,
|
||||
)
|
||||
stats = solver.solve(PickleGzInstance(tmp.name))
|
||||
assert stats["mip_lower_bound"] == stats["Objective: Predicted lower bound"]
|
||||
def test_basic_usage() -> None:
|
||||
with tempfile.TemporaryDirectory() as dirname:
|
||||
# Generate instances
|
||||
data = MaxWeightStableSetGenerator(n=randint(low=20, high=21)).generate(4)
|
||||
train_files = save(data[0:3], f"{dirname}/train")
|
||||
test_files = save(data[3:4], f"{dirname}/test")
|
||||
|
||||
# Solve training instances
|
||||
solver = LearningSolver()
|
||||
stats = solver.solve(train_files, build_stab_model)
|
||||
assert len(stats) == 3
|
||||
for f in train_files:
|
||||
sample_filename = f.replace(".pkl.gz", ".h5")
|
||||
assert exists(sample_filename)
|
||||
sample = Hdf5Sample(sample_filename)
|
||||
assert sample.get_scalar("mip_lower_bound") > 0
|
||||
|
||||
# Fit
|
||||
solver.fit(train_files, build_stab_model)
|
||||
|
||||
# Solve test instances
|
||||
stats = solver.solve(test_files, build_stab_model)
|
||||
assert isinstance(stats, list)
|
||||
assert "Objective: Predicted lower bound" in stats[0].keys()
|
||||
|
||||
|
||||
def test_gap() -> None:
|
||||
|
||||
@@ -7,7 +7,10 @@ import os.path
|
||||
from scipy.stats import randint
|
||||
|
||||
from miplearn.benchmark import BenchmarkRunner
|
||||
from miplearn.problems.stab import MaxWeightStableSetGenerator
|
||||
from miplearn.problems.stab import (
|
||||
MaxWeightStableSetInstance,
|
||||
MaxWeightStableSetGenerator,
|
||||
)
|
||||
from miplearn.solvers.learning import LearningSolver
|
||||
|
||||
|
||||
@@ -15,8 +18,14 @@ def test_benchmark() -> None:
|
||||
for n_jobs in [1, 4]:
|
||||
# Generate training and test instances
|
||||
generator = MaxWeightStableSetGenerator(n=randint(low=25, high=26))
|
||||
train_instances = generator.generate(5)
|
||||
test_instances = generator.generate(3)
|
||||
train_instances = [
|
||||
MaxWeightStableSetInstance(data.graph, data.weights)
|
||||
for data in generator.generate(5)
|
||||
]
|
||||
test_instances = [
|
||||
MaxWeightStableSetInstance(data.graph, data.weights)
|
||||
for data in generator.generate(3)
|
||||
]
|
||||
|
||||
# Solve training instances
|
||||
training_solver = LearningSolver()
|
||||
|
||||
Reference in New Issue
Block a user