mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-08 10:28:52 -06:00
Make all before/solve callbacks receive same parameters
This commit is contained in:
@@ -33,7 +33,15 @@ class UserCutsComponent(Component):
|
||||
self.classifier_prototype: Classifier = classifier
|
||||
self.classifiers: Dict[Any, Classifier] = {}
|
||||
|
||||
def before_solve_mip(self, solver, instance, model):
|
||||
def before_solve_mip(
|
||||
self,
|
||||
solver,
|
||||
instance,
|
||||
model,
|
||||
stats,
|
||||
features,
|
||||
training_data,
|
||||
):
|
||||
instance.found_violated_user_cuts = []
|
||||
logger.info("Predicting violated user cuts...")
|
||||
violations = self.predict(instance)
|
||||
@@ -42,16 +50,6 @@ class UserCutsComponent(Component):
|
||||
cut = instance.build_user_cut(model, v)
|
||||
solver.internal_solver.add_constraint(cut)
|
||||
|
||||
def after_solve_mip(
|
||||
self,
|
||||
solver,
|
||||
instance,
|
||||
model,
|
||||
results,
|
||||
training_data,
|
||||
):
|
||||
pass
|
||||
|
||||
def fit(self, training_instances):
|
||||
logger.debug("Fitting...")
|
||||
features = InstanceFeaturesExtractor().extract(training_instances)
|
||||
|
||||
Reference in New Issue
Block a user