mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Minor changes
This commit is contained in:
@@ -6,4 +6,6 @@
|
|||||||
from .component import Component
|
from .component import Component
|
||||||
from .instance import Instance
|
from .instance import Instance
|
||||||
from .solvers import LearningSolver
|
from .solvers import LearningSolver
|
||||||
from .benchmark import BenchmarkRunner
|
from .benchmark import BenchmarkRunner
|
||||||
|
from .warmstart import WarmStartComponent
|
||||||
|
from .branching import BranchPriorityComponent
|
||||||
@@ -14,22 +14,22 @@ from scipy.stats.distributions import rv_frozen
|
|||||||
class MaxWeightStableSetChallengeA:
|
class MaxWeightStableSetChallengeA:
|
||||||
"""
|
"""
|
||||||
- Fixed random graph (200 vertices, 5% density)
|
- Fixed random graph (200 vertices, 5% density)
|
||||||
- Uniformly random weights in the [100., 125.] interval
|
- Random weights ~ U(100., 150.)
|
||||||
- 500 training instances
|
- 300 training instances
|
||||||
- 100 test instances
|
- 50 test instances
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.generator = MaxWeightStableSetGenerator(w=uniform(loc=100., scale=25.),
|
self.generator = MaxWeightStableSetGenerator(w=uniform(loc=100., scale=50.),
|
||||||
n=randint(low=200, high=201),
|
n=randint(low=200, high=201),
|
||||||
density=uniform(loc=0.05, scale=0.0),
|
density=uniform(loc=0.05, scale=0.0),
|
||||||
fix_graph=True)
|
fix_graph=True)
|
||||||
|
|
||||||
def get_training_instances(self):
|
def get_training_instances(self):
|
||||||
return self.generator.generate(500)
|
return self.generator.generate(300)
|
||||||
|
|
||||||
def get_test_instances(self):
|
def get_test_instances(self):
|
||||||
return self.generator.generate(100)
|
return self.generator.generate(50)
|
||||||
|
|
||||||
|
|
||||||
class MaxWeightStableSetGenerator:
|
class MaxWeightStableSetGenerator:
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ branch_and_bound(mip,
|
|||||||
node_limit = 1000,
|
node_limit = 1000,
|
||||||
branch_rule = full_strong_branching_track,
|
branch_rule = full_strong_branching_track,
|
||||||
node_rule = best_bound,
|
node_rule = best_bound,
|
||||||
print_interval = 1)
|
print_interval = 100)
|
||||||
|
|
||||||
priority = [(pseudocost_count_up[v] == 0 || pseudocost_count_down[v] == 0) ? 0 :
|
priority = [(pseudocost_count_up[v] == 0 || pseudocost_count_down[v] == 0) ? 0 :
|
||||||
(pseudocost_sum_up[v] / pseudocost_count_up[v]) *
|
(pseudocost_sum_up[v] / pseudocost_count_up[v]) *
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ class LearningSolver:
|
|||||||
|
|
||||||
if mode is not None:
|
if mode is not None:
|
||||||
assert mode in ["exact", "heuristic"]
|
assert mode in ["exact", "heuristic"]
|
||||||
for component in self.components:
|
for component in self.components.values():
|
||||||
component.mode = mode
|
component.mode = mode
|
||||||
|
|
||||||
def _create_solver(self):
|
def _create_solver(self):
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ class WarmStartPredictor(ABC):
|
|||||||
class LogisticWarmStartPredictor(WarmStartPredictor):
|
class LogisticWarmStartPredictor(WarmStartPredictor):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
min_samples=100,
|
min_samples=100,
|
||||||
thr_fix=[0.99, 0.99],
|
thr_fix=[0.95, 0.95],
|
||||||
thr_balance=[0.95, 0.95],
|
thr_balance=[0.95, 0.95],
|
||||||
thr_score=[0.95, 0.95]):
|
thr_score=[0.95, 0.95]):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -151,8 +151,12 @@ class WarmStartComponent(Component):
|
|||||||
if self.mode == "heuristic":
|
if self.mode == "heuristic":
|
||||||
if ws[i,0] == 1:
|
if ws[i,0] == 1:
|
||||||
var[index].fix(0)
|
var[index].fix(0)
|
||||||
|
if solver.is_persistent:
|
||||||
|
solver.internal_solver.update_var(var[index])
|
||||||
elif ws[i,1] == 1:
|
elif ws[i,1] == 1:
|
||||||
var[index].fix(1)
|
var[index].fix(1)
|
||||||
|
if solver.is_persistent:
|
||||||
|
solver.internal_solver.update_var(var[index])
|
||||||
else:
|
else:
|
||||||
if ws[i,0] == 1:
|
if ws[i,0] == 1:
|
||||||
var[index].value = 0
|
var[index].value = 0
|
||||||
|
|||||||
Reference in New Issue
Block a user