Make extractor configurable

master
Alinson S. Xavier 5 years ago
parent 95e326f5f6
commit 39597287a6
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -132,23 +132,22 @@ class Sample:
class FeaturesExtractor: class FeaturesExtractor:
def __init__( def __init__(
self, self,
internal_solver: "InternalSolver",
with_sa: bool = True, with_sa: bool = True,
) -> None: ) -> None:
self.solver = internal_solver
self.with_sa = with_sa self.with_sa = with_sa
def extract( def extract(
self, self,
instance: "Instance", instance: "Instance",
solver: "InternalSolver",
with_static: bool = True, with_static: bool = True,
) -> Features: ) -> Features:
features = Features() features = Features()
features.variables = self.solver.get_variables( features.variables = solver.get_variables(
with_static=with_static, with_static=with_static,
with_sa=self.with_sa, with_sa=self.with_sa,
) )
features.constraints = self.solver.get_constraints( features.constraints = solver.get_constraints(
with_static=with_static, with_static=with_static,
) )
if with_static: if with_static:

@ -99,9 +99,12 @@ class LearningSolver:
use_lazy_cb: bool = False, use_lazy_cb: bool = False,
solve_lp: bool = True, solve_lp: bool = True,
simulate_perfect: bool = False, simulate_perfect: bool = False,
extractor: Optional[FeaturesExtractor] = None,
) -> None: ) -> None:
if solver is None: if solver is None:
solver = GurobiPyomoSolver() solver = GurobiPyomoSolver()
if extractor is None:
extractor = FeaturesExtractor()
assert isinstance(solver, InternalSolver) assert isinstance(solver, InternalSolver)
self.components: Dict[str, Component] = {} self.components: Dict[str, Component] = {}
self.internal_solver: Optional[InternalSolver] = None self.internal_solver: Optional[InternalSolver] = None
@ -111,6 +114,7 @@ class LearningSolver:
self.solve_lp: bool = solve_lp self.solve_lp: bool = solve_lp
self.tee = False self.tee = False
self.use_lazy_cb: bool = use_lazy_cb self.use_lazy_cb: bool = use_lazy_cb
self.extractor = extractor
if components is not None: if components is not None:
for comp in components: for comp in components:
self._add_component(comp) self._add_component(comp)
@ -156,7 +160,7 @@ class LearningSolver:
# Extract features (after-load) # Extract features (after-load)
# ------------------------------------------------------- # -------------------------------------------------------
logger.info("Extracting features (after-load)...") logger.info("Extracting features (after-load)...")
features = FeaturesExtractor(self.internal_solver).extract(instance) features = self.extractor.extract(instance, self.internal_solver)
features.extra = {} features.extra = {}
sample.after_load = features sample.after_load = features
@ -187,8 +191,9 @@ class LearningSolver:
# Extract features (after-lp) # Extract features (after-lp)
# ------------------------------------------------------- # -------------------------------------------------------
logger.info("Extracting features (after-lp)...") logger.info("Extracting features (after-lp)...")
features = FeaturesExtractor(self.internal_solver).extract( features = self.extractor.extract(
instance, instance,
self.internal_solver,
with_static=False, with_static=False,
) )
features.extra = {} features.extra = {}
@ -252,8 +257,9 @@ class LearningSolver:
# Extract features (after-mip) # Extract features (after-mip)
# ------------------------------------------------------- # -------------------------------------------------------
logger.info("Extracting features (after-mip)...") logger.info("Extracting features (after-mip)...")
features = FeaturesExtractor(self.internal_solver).extract( features = self.extractor.extract(
instance, instance,
self.internal_solver,
with_static=False, with_static=False,
) )
features.mip_solve = mip_stats features.mip_solve = mip_stats

Loading…
Cancel
Save