|
|
|
@ -69,12 +69,14 @@ class PrimalSolutionComponent(Component):
|
|
|
|
|
features: Features,
|
|
|
|
|
training_data: TrainingSample,
|
|
|
|
|
) -> None:
|
|
|
|
|
logger.info("Predicting primal solution...")
|
|
|
|
|
|
|
|
|
|
# Do nothing if models are not trained
|
|
|
|
|
if len(self.classifiers) == 0:
|
|
|
|
|
logger.info("Classifiers not fitted. Skipping.")
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
# Predict solution and provide it to the solver
|
|
|
|
|
logger.info("Predicting MIP solution...")
|
|
|
|
|
solution = self.sample_predict(instance, training_data)
|
|
|
|
|
assert solver.internal_solver is not None
|
|
|
|
|
if self.mode == "heuristic":
|
|
|
|
@ -130,6 +132,8 @@ class PrimalSolutionComponent(Component):
|
|
|
|
|
category_offset: Dict[Hashable, int] = {cat: 0 for cat in x.keys()}
|
|
|
|
|
for (var_name, var_features) in instance.features.variables.items():
|
|
|
|
|
category = var_features.category
|
|
|
|
|
if category not in category_offset:
|
|
|
|
|
continue
|
|
|
|
|
offset = category_offset[category]
|
|
|
|
|
category_offset[category] += 1
|
|
|
|
|
if y_pred[category][offset, 0]:
|
|
|
|
|