mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-07 18:08:51 -06:00
Finish TSP implementation; improve performance of Extractors
This commit is contained in:
@@ -129,7 +129,7 @@ class PrimalSolutionComponent(Component):
|
||||
self.dynamic_thresholds = dynamic_thresholds
|
||||
|
||||
def before_solve(self, solver, instance, model):
|
||||
solution = self.predict(instance, model)
|
||||
solution = self.predict(instance)
|
||||
if self.mode == "heuristic":
|
||||
solver.internal_solver.fix(solution)
|
||||
else:
|
||||
@@ -139,6 +139,7 @@ class PrimalSolutionComponent(Component):
|
||||
pass
|
||||
|
||||
def fit(self, training_instances):
|
||||
logger.debug("Extracting features...")
|
||||
features = VariableFeaturesExtractor().extract(training_instances)
|
||||
solutions = SolutionExtractor().extract(training_instances)
|
||||
|
||||
@@ -180,12 +181,10 @@ class PrimalSolutionComponent(Component):
|
||||
self.thresholds[category, label] = thresholds[k]
|
||||
|
||||
|
||||
def predict(self, instance, model=None):
|
||||
if model is None:
|
||||
model = instance.to_model()
|
||||
x_test = VariableFeaturesExtractor().extract([instance], [model])
|
||||
def predict(self, instance):
|
||||
x_test = VariableFeaturesExtractor().extract([instance])
|
||||
solution = {}
|
||||
var_split = Extractor.split_variables(instance, model)
|
||||
var_split = Extractor.split_variables(instance)
|
||||
for category in var_split.keys():
|
||||
for (i, (var, index)) in enumerate(var_split[category]):
|
||||
if var not in solution.keys():
|
||||
|
||||
Reference in New Issue
Block a user