mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-10 11:28:51 -06:00
problems: Allow correlated arguments in random problem generators
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional, Union, Callable
|
||||
|
||||
import gurobipy as gp
|
||||
import numpy as np
|
||||
@@ -47,8 +47,10 @@ class MultiKnapsackGenerator:
|
||||
----------
|
||||
n: rv_discrete
|
||||
Probability distribution for the number of items (or variables).
|
||||
m: rv_discrete
|
||||
Probability distribution for the number of knapsacks (or constraints).
|
||||
m: rv_discrete or callable
|
||||
Probability distribution for the number of knapsacks (or constraints), or a
|
||||
callable that takes the numer of items and returns the number of knapsacks
|
||||
(e.g., lambda n: n//3).
|
||||
w: rv_continuous
|
||||
Probability distribution for the item weights.
|
||||
K: rv_continuous
|
||||
@@ -65,7 +67,7 @@ class MultiKnapsackGenerator:
|
||||
def __init__(
|
||||
self,
|
||||
n: rv_frozen = randint(low=100, high=101),
|
||||
m: rv_frozen = randint(low=30, high=31),
|
||||
m: Union[rv_frozen, Callable] = randint(low=30, high=31),
|
||||
w: rv_frozen = randint(low=0, high=1000),
|
||||
K: rv_frozen = randint(low=500, high=501),
|
||||
u: rv_frozen = uniform(loc=0.0, scale=1.0),
|
||||
@@ -73,7 +75,9 @@ class MultiKnapsackGenerator:
|
||||
round: bool = True,
|
||||
):
|
||||
assert isinstance(n, rv_frozen), "n should be a SciPy probability distribution"
|
||||
assert isinstance(m, rv_frozen), "m should be a SciPy probability distribution"
|
||||
assert isinstance(m, rv_frozen) or callable(
|
||||
m
|
||||
), "m should be a SciPy probability distribution or callable"
|
||||
assert isinstance(w, rv_frozen), "w should be a SciPy probability distribution"
|
||||
assert isinstance(K, rv_frozen), "K should be a SciPy probability distribution"
|
||||
assert isinstance(u, rv_frozen), "u should be a SciPy probability distribution"
|
||||
@@ -92,7 +96,10 @@ class MultiKnapsackGenerator:
|
||||
def generate(self, n_samples: int) -> List[MultiKnapsackData]:
|
||||
def _sample() -> MultiKnapsackData:
|
||||
n = self.n.rvs()
|
||||
m = self.m.rvs()
|
||||
if callable(self.m):
|
||||
m = self.m(n)
|
||||
else:
|
||||
m = self.m.rvs()
|
||||
w = np.array([self.w.rvs(n) for _ in range(m)])
|
||||
u = self.u.rvs(n)
|
||||
K = self.K.rvs()
|
||||
|
||||
Reference in New Issue
Block a user