mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-10 11:28:51 -06:00
problems: Allow correlated arguments in random problem generators
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional, Union, Callable
|
||||
|
||||
import gurobipy as gp
|
||||
import numpy as np
|
||||
@@ -58,7 +58,8 @@ class PMedianGenerator:
|
||||
n
|
||||
Probability distribution for the number of customer.
|
||||
p
|
||||
Probability distribution for the number of medians.
|
||||
Probability distribution for the number of medians, or a callable that takes
|
||||
the number of customers and returns the number of medians (e.g., lambda n: n//10).
|
||||
demands
|
||||
Probability distribution for the customer demands.
|
||||
capacities
|
||||
@@ -70,10 +71,23 @@ class PMedianGenerator:
|
||||
x: rv_frozen = uniform(loc=0.0, scale=100.0),
|
||||
y: rv_frozen = uniform(loc=0.0, scale=100.0),
|
||||
n: rv_frozen = randint(low=100, high=101),
|
||||
p: rv_frozen = randint(low=10, high=11),
|
||||
p: Union[rv_frozen, Callable] = randint(low=10, high=11),
|
||||
demands: rv_frozen = uniform(loc=0, scale=20),
|
||||
capacities: rv_frozen = uniform(loc=0, scale=100),
|
||||
):
|
||||
assert isinstance(x, rv_frozen), "x should be a SciPy probability distribution"
|
||||
assert isinstance(y, rv_frozen), "y should be a SciPy probability distribution"
|
||||
assert isinstance(n, rv_frozen), "n should be a SciPy probability distribution"
|
||||
assert isinstance(p, rv_frozen) or callable(
|
||||
p
|
||||
), "p should be a SciPy probability distribution or callable"
|
||||
assert isinstance(
|
||||
demands, rv_frozen
|
||||
), "demands should be a SciPy probability distribution"
|
||||
assert isinstance(
|
||||
capacities, rv_frozen
|
||||
), "capacities should be a SciPy probability distribution"
|
||||
|
||||
self.x = x
|
||||
self.y = y
|
||||
self.n = n
|
||||
@@ -84,7 +98,10 @@ class PMedianGenerator:
|
||||
def generate(self, n_samples: int) -> List[PMedianData]:
|
||||
def _sample() -> PMedianData:
|
||||
n = self.n.rvs()
|
||||
p = self.p.rvs()
|
||||
if callable(self.p):
|
||||
p = self.p(n)
|
||||
else:
|
||||
p = self.p.rvs()
|
||||
loc = np.array([(self.x.rvs(), self.y.rvs()) for _ in range(n)])
|
||||
distances = squareform(pdist(loc))
|
||||
demands = self.demands.rvs(n)
|
||||
|
||||
Reference in New Issue
Block a user