Implement MaxCutPerturber

This commit is contained in:
2025-12-08 13:21:04 -06:00
parent 9192bb02eb
commit 15cdb7e679
3 changed files with 48 additions and 153 deletions

View File

@@ -22,12 +22,11 @@ def _set_seed() -> None:
np.random.seed(42)
def test_maxcut_generator_not_fixed() -> None:
def test_maxcut_generator() -> None:
_set_seed()
gen = MaxCutGenerator(
n=randint(low=5, high=6),
p=uniform(loc=0.5, scale=0.0),
fix_graph=False,
)
data = gen.generate(3)
assert len(data) == 3
@@ -41,35 +40,6 @@ def test_maxcut_generator_not_fixed() -> None:
(3, 4),
]
assert data[0].weights.tolist() == [-1, 1, -1, -1, -1, 1]
assert list(data[1].graph.nodes()) == [0, 1, 2, 3, 4]
assert list(data[1].graph.edges()) == [(0, 1), (0, 3), (0, 4), (1, 4), (3, 4)]
assert data[1].weights.tolist() == [-1, -1, -1, 1, -1]
def test_maxcut_generator_fixed() -> None:
random.seed(42)
np.random.seed(42)
gen = MaxCutGenerator(
n=randint(low=5, high=6),
p=uniform(loc=0.5, scale=0.0),
fix_graph=True,
w_jitter=0.25,
)
data = gen.generate(3)
assert len(data) == 3
for i in range(3):
assert list(data[i].graph.nodes()) == [0, 1, 2, 3, 4]
assert list(data[i].graph.edges()) == [
(0, 2),
(0, 3),
(0, 4),
(2, 3),
(2, 4),
(3, 4),
]
assert data[0].weights.tolist() == [-1, -1, 1, 1, -1, 1]
assert data[1].weights.tolist() == [-1, -1, -1, -1, 1, -1]
assert data[2].weights.tolist() == [1, 1, -1, -1, -1, 1]
def test_maxcut_model() -> None:
@@ -77,7 +47,6 @@ def test_maxcut_model() -> None:
data = MaxCutGenerator(
n=randint(low=10, high=11),
p=uniform(loc=0.5, scale=0.0),
fix_graph=True,
).generate(1)[0]
for model in [
build_maxcut_model_gurobipy(data),