mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Make extractor configurable
This commit is contained in:
@@ -132,23 +132,22 @@ class Sample:
|
|||||||
class FeaturesExtractor:
|
class FeaturesExtractor:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
internal_solver: "InternalSolver",
|
|
||||||
with_sa: bool = True,
|
with_sa: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.solver = internal_solver
|
|
||||||
self.with_sa = with_sa
|
self.with_sa = with_sa
|
||||||
|
|
||||||
def extract(
|
def extract(
|
||||||
self,
|
self,
|
||||||
instance: "Instance",
|
instance: "Instance",
|
||||||
|
solver: "InternalSolver",
|
||||||
with_static: bool = True,
|
with_static: bool = True,
|
||||||
) -> Features:
|
) -> Features:
|
||||||
features = Features()
|
features = Features()
|
||||||
features.variables = self.solver.get_variables(
|
features.variables = solver.get_variables(
|
||||||
with_static=with_static,
|
with_static=with_static,
|
||||||
with_sa=self.with_sa,
|
with_sa=self.with_sa,
|
||||||
)
|
)
|
||||||
features.constraints = self.solver.get_constraints(
|
features.constraints = solver.get_constraints(
|
||||||
with_static=with_static,
|
with_static=with_static,
|
||||||
)
|
)
|
||||||
if with_static:
|
if with_static:
|
||||||
|
|||||||
@@ -99,9 +99,12 @@ class LearningSolver:
|
|||||||
use_lazy_cb: bool = False,
|
use_lazy_cb: bool = False,
|
||||||
solve_lp: bool = True,
|
solve_lp: bool = True,
|
||||||
simulate_perfect: bool = False,
|
simulate_perfect: bool = False,
|
||||||
|
extractor: Optional[FeaturesExtractor] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if solver is None:
|
if solver is None:
|
||||||
solver = GurobiPyomoSolver()
|
solver = GurobiPyomoSolver()
|
||||||
|
if extractor is None:
|
||||||
|
extractor = FeaturesExtractor()
|
||||||
assert isinstance(solver, InternalSolver)
|
assert isinstance(solver, InternalSolver)
|
||||||
self.components: Dict[str, Component] = {}
|
self.components: Dict[str, Component] = {}
|
||||||
self.internal_solver: Optional[InternalSolver] = None
|
self.internal_solver: Optional[InternalSolver] = None
|
||||||
@@ -111,6 +114,7 @@ class LearningSolver:
|
|||||||
self.solve_lp: bool = solve_lp
|
self.solve_lp: bool = solve_lp
|
||||||
self.tee = False
|
self.tee = False
|
||||||
self.use_lazy_cb: bool = use_lazy_cb
|
self.use_lazy_cb: bool = use_lazy_cb
|
||||||
|
self.extractor = extractor
|
||||||
if components is not None:
|
if components is not None:
|
||||||
for comp in components:
|
for comp in components:
|
||||||
self._add_component(comp)
|
self._add_component(comp)
|
||||||
@@ -156,7 +160,7 @@ class LearningSolver:
|
|||||||
# Extract features (after-load)
|
# Extract features (after-load)
|
||||||
# -------------------------------------------------------
|
# -------------------------------------------------------
|
||||||
logger.info("Extracting features (after-load)...")
|
logger.info("Extracting features (after-load)...")
|
||||||
features = FeaturesExtractor(self.internal_solver).extract(instance)
|
features = self.extractor.extract(instance, self.internal_solver)
|
||||||
features.extra = {}
|
features.extra = {}
|
||||||
sample.after_load = features
|
sample.after_load = features
|
||||||
|
|
||||||
@@ -187,8 +191,9 @@ class LearningSolver:
|
|||||||
# Extract features (after-lp)
|
# Extract features (after-lp)
|
||||||
# -------------------------------------------------------
|
# -------------------------------------------------------
|
||||||
logger.info("Extracting features (after-lp)...")
|
logger.info("Extracting features (after-lp)...")
|
||||||
features = FeaturesExtractor(self.internal_solver).extract(
|
features = self.extractor.extract(
|
||||||
instance,
|
instance,
|
||||||
|
self.internal_solver,
|
||||||
with_static=False,
|
with_static=False,
|
||||||
)
|
)
|
||||||
features.extra = {}
|
features.extra = {}
|
||||||
@@ -252,8 +257,9 @@ class LearningSolver:
|
|||||||
# Extract features (after-mip)
|
# Extract features (after-mip)
|
||||||
# -------------------------------------------------------
|
# -------------------------------------------------------
|
||||||
logger.info("Extracting features (after-mip)...")
|
logger.info("Extracting features (after-mip)...")
|
||||||
features = FeaturesExtractor(self.internal_solver).extract(
|
features = self.extractor.extract(
|
||||||
instance,
|
instance,
|
||||||
|
self.internal_solver,
|
||||||
with_static=False,
|
with_static=False,
|
||||||
)
|
)
|
||||||
features.mip_solve = mip_stats
|
features.mip_solve = mip_stats
|
||||||
|
|||||||
Reference in New Issue
Block a user