mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-07 18:08:51 -06:00
Make sample_ method accept instance
This commit is contained in:
@@ -73,7 +73,7 @@ class PrimalSolutionComponent(Component):
|
||||
|
||||
# Predict solution and provide it to the solver
|
||||
logger.info("Predicting MIP solution...")
|
||||
solution = self.sample_predict(features, training_data)
|
||||
solution = self.sample_predict(instance, training_data)
|
||||
assert solver.internal_solver is not None
|
||||
if self.mode == "heuristic":
|
||||
solver.internal_solver.fix(solution)
|
||||
@@ -101,20 +101,20 @@ class PrimalSolutionComponent(Component):
|
||||
|
||||
def sample_predict(
|
||||
self,
|
||||
features: Features,
|
||||
instance: Instance,
|
||||
sample: TrainingSample,
|
||||
) -> Solution:
|
||||
assert features.variables is not None
|
||||
assert instance.features.variables is not None
|
||||
|
||||
# Initialize empty solution
|
||||
solution: Solution = {}
|
||||
for (var_name, var_dict) in features.variables.items():
|
||||
for (var_name, var_dict) in instance.features.variables.items():
|
||||
solution[var_name] = {}
|
||||
for idx in var_dict.keys():
|
||||
solution[var_name][idx] = None
|
||||
|
||||
# Compute y_pred
|
||||
x, _ = self.sample_xy(features, sample)
|
||||
x, _ = self.sample_xy(instance, sample)
|
||||
y_pred = {}
|
||||
for category in x.keys():
|
||||
assert category in self.classifiers, (
|
||||
@@ -133,7 +133,7 @@ class PrimalSolutionComponent(Component):
|
||||
|
||||
# Convert y_pred into solution
|
||||
category_offset: Dict[Hashable, int] = {cat: 0 for cat in x.keys()}
|
||||
for (var_name, var_dict) in features.variables.items():
|
||||
for (var_name, var_dict) in instance.features.variables.items():
|
||||
for (idx, var_features) in var_dict.items():
|
||||
category = var_features.category
|
||||
offset = category_offset[category]
|
||||
@@ -147,16 +147,16 @@ class PrimalSolutionComponent(Component):
|
||||
|
||||
@staticmethod
|
||||
def sample_xy(
|
||||
features: Features,
|
||||
instance: Instance,
|
||||
sample: TrainingSample,
|
||||
) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
|
||||
assert features.variables is not None
|
||||
assert instance.features.variables is not None
|
||||
x: Dict = {}
|
||||
y: Dict = {}
|
||||
solution: Optional[Solution] = None
|
||||
if sample.solution is not None:
|
||||
solution = sample.solution
|
||||
for (var_name, var_dict) in features.variables.items():
|
||||
for (var_name, var_dict) in instance.features.variables.items():
|
||||
for (idx, var_features) in var_dict.items():
|
||||
category = var_features.category
|
||||
if category is None:
|
||||
@@ -186,12 +186,12 @@ class PrimalSolutionComponent(Component):
|
||||
|
||||
def sample_evaluate(
|
||||
self,
|
||||
features: Features,
|
||||
instance: Instance,
|
||||
sample: TrainingSample,
|
||||
) -> Dict[Hashable, Dict[str, float]]:
|
||||
solution_actual = sample.solution
|
||||
assert solution_actual is not None
|
||||
solution_pred = self.sample_predict(features, sample)
|
||||
solution_pred = self.sample_predict(instance, sample)
|
||||
vars_all, vars_one, vars_zero = set(), set(), set()
|
||||
pred_one_positive, pred_zero_positive = set(), set()
|
||||
for (varname, var_dict) in solution_actual.items():
|
||||
|
||||
Reference in New Issue
Block a user