mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Finish TSP implementation; improve performance of Extractors
This commit is contained in:
@@ -27,10 +27,14 @@ class LazyConstraintsComponent(Component):
|
||||
|
||||
def __init__(self):
|
||||
self.violations = set()
|
||||
self.count = {}
|
||||
self.n_samples = 0
|
||||
|
||||
def before_solve(self, solver, instance, model):
|
||||
logger.info("Enforcing %d lazy constraints" % len(self.violations))
|
||||
for v in self.violations:
|
||||
if self.count[v] < self.n_samples * 0.05:
|
||||
continue
|
||||
cut = instance.build_lazy_constraint(model, v)
|
||||
solver.internal_solver.add_constraint(cut)
|
||||
|
||||
@@ -38,11 +42,16 @@ class LazyConstraintsComponent(Component):
|
||||
pass
|
||||
|
||||
def fit(self, training_instances):
|
||||
logger.debug("Fitting...")
|
||||
self.n_samples = len(training_instances)
|
||||
for instance in training_instances:
|
||||
if not hasattr(instance, "found_violations"):
|
||||
continue
|
||||
for v in instance.found_violations:
|
||||
self.violations.add(v)
|
||||
if v not in self.count.keys():
|
||||
self.count[v] = 0
|
||||
self.count[v] += 1
|
||||
|
||||
def predict(self, instance, model=None):
|
||||
return self.violations
|
||||
|
||||
@@ -31,12 +31,15 @@ class ObjectiveValueComponent(Component):
|
||||
pass
|
||||
|
||||
def fit(self, training_instances):
|
||||
logger.debug("Extracting features...")
|
||||
features = InstanceFeaturesExtractor().extract(training_instances)
|
||||
ub = ObjectiveValueExtractor(kind="upper bound").extract(training_instances)
|
||||
lb = ObjectiveValueExtractor(kind="lower bound").extract(training_instances)
|
||||
self.ub_regressor = deepcopy(self.regressor_prototype)
|
||||
self.lb_regressor = deepcopy(self.regressor_prototype)
|
||||
logger.debug("Fitting ub_regressor...")
|
||||
self.ub_regressor.fit(features, ub)
|
||||
logger.debug("Fitting ub_regressor...")
|
||||
self.lb_regressor.fit(features, lb)
|
||||
|
||||
def predict(self, instances):
|
||||
|
||||
@@ -129,7 +129,7 @@ class PrimalSolutionComponent(Component):
|
||||
self.dynamic_thresholds = dynamic_thresholds
|
||||
|
||||
def before_solve(self, solver, instance, model):
|
||||
solution = self.predict(instance, model)
|
||||
solution = self.predict(instance)
|
||||
if self.mode == "heuristic":
|
||||
solver.internal_solver.fix(solution)
|
||||
else:
|
||||
@@ -139,6 +139,7 @@ class PrimalSolutionComponent(Component):
|
||||
pass
|
||||
|
||||
def fit(self, training_instances):
|
||||
logger.debug("Extracting features...")
|
||||
features = VariableFeaturesExtractor().extract(training_instances)
|
||||
solutions = SolutionExtractor().extract(training_instances)
|
||||
|
||||
@@ -180,12 +181,10 @@ class PrimalSolutionComponent(Component):
|
||||
self.thresholds[category, label] = thresholds[k]
|
||||
|
||||
|
||||
def predict(self, instance, model=None):
|
||||
if model is None:
|
||||
model = instance.to_model()
|
||||
x_test = VariableFeaturesExtractor().extract([instance], [model])
|
||||
def predict(self, instance):
|
||||
x_test = VariableFeaturesExtractor().extract([instance])
|
||||
solution = {}
|
||||
var_split = Extractor.split_variables(instance, model)
|
||||
var_split = Extractor.split_variables(instance)
|
||||
for category in var_split.keys():
|
||||
for (i, (var, index)) in enumerate(var_split[category]):
|
||||
if var not in solution.keys():
|
||||
|
||||
@@ -27,29 +27,7 @@ def test_predict():
|
||||
instances, models = _get_instances()
|
||||
comp = PrimalSolutionComponent()
|
||||
comp.fit(instances)
|
||||
solution = comp.predict(instances[0], models[0])
|
||||
assert models[0].x in solution.keys()
|
||||
solution = comp.predict(instances[0])
|
||||
assert "x" in solution
|
||||
for idx in range(4):
|
||||
assert idx in solution[models[0].x].keys()
|
||||
|
||||
# def test_warm_start_save_load():
|
||||
# state_file = tempfile.NamedTemporaryFile(mode="r")
|
||||
# solver = LearningSolver(components={"warm-start": WarmStartComponent()})
|
||||
# solver.parallel_solve(_get_instances(), n_jobs=2)
|
||||
# solver.fit()
|
||||
# comp = solver.components["warm-start"]
|
||||
# assert comp.x_train["default"].shape == (8, 6)
|
||||
# assert comp.y_train["default"].shape == (8, 2)
|
||||
# assert ("default", 0) in comp.predictors.keys()
|
||||
# assert ("default", 1) in comp.predictors.keys()
|
||||
# solver.save_state(state_file.name)
|
||||
|
||||
# solver.solve(_get_instances()[0])
|
||||
|
||||
# solver = LearningSolver(components={"warm-start": WarmStartComponent()})
|
||||
# solver.load_state(state_file.name)
|
||||
# comp = solver.components["warm-start"]
|
||||
# assert comp.x_train["default"].shape == (8, 6)
|
||||
# assert comp.y_train["default"].shape == (8, 2)
|
||||
# assert ("default", 0) in comp.predictors.keys()
|
||||
# assert ("default", 1) in comp.predictors.keys()
|
||||
assert idx in solution["x"]
|
||||
|
||||
Reference in New Issue
Block a user