problems: Allow correlated arguments in random problem generators

This commit is contained in:
2025-12-08 16:08:05 -06:00
parent 485625e07f
commit 9f0fa0e500
9 changed files with 133 additions and 30 deletions

View File

@@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
from dataclasses import dataclass
from typing import List, Optional, Union
from typing import List, Optional, Union, Callable
import gurobipy as gp
import numpy as np
@@ -47,8 +47,10 @@ class MultiKnapsackGenerator:
----------
n: rv_discrete
Probability distribution for the number of items (or variables).
m: rv_discrete
Probability distribution for the number of knapsacks (or constraints).
m: rv_discrete or callable
Probability distribution for the number of knapsacks (or constraints), or a
callable that takes the numer of items and returns the number of knapsacks
(e.g., lambda n: n//3).
w: rv_continuous
Probability distribution for the item weights.
K: rv_continuous
@@ -65,7 +67,7 @@ class MultiKnapsackGenerator:
def __init__(
self,
n: rv_frozen = randint(low=100, high=101),
m: rv_frozen = randint(low=30, high=31),
m: Union[rv_frozen, Callable] = randint(low=30, high=31),
w: rv_frozen = randint(low=0, high=1000),
K: rv_frozen = randint(low=500, high=501),
u: rv_frozen = uniform(loc=0.0, scale=1.0),
@@ -73,7 +75,9 @@ class MultiKnapsackGenerator:
round: bool = True,
):
assert isinstance(n, rv_frozen), "n should be a SciPy probability distribution"
assert isinstance(m, rv_frozen), "m should be a SciPy probability distribution"
assert isinstance(m, rv_frozen) or callable(
m
), "m should be a SciPy probability distribution or callable"
assert isinstance(w, rv_frozen), "w should be a SciPy probability distribution"
assert isinstance(K, rv_frozen), "K should be a SciPy probability distribution"
assert isinstance(u, rv_frozen), "u should be a SciPy probability distribution"
@@ -92,7 +96,10 @@ class MultiKnapsackGenerator:
def generate(self, n_samples: int) -> List[MultiKnapsackData]:
def _sample() -> MultiKnapsackData:
n = self.n.rvs()
m = self.m.rvs()
if callable(self.m):
m = self.m(n)
else:
m = self.m.rvs()
w = np.array([self.w.rvs(n) for _ in range(m)])
u = self.u.rvs(n)
K = self.K.rvs()